BiEncoderModel
- class lightning_ir.bi_encoder.model.BiEncoderModel(config: BiEncoderConfig, *args, **kwargs)[source]
Bases: LightningIRModel
- __init__(config: BiEncoderConfig, *args, **kwargs) → None [source]
Methods
- __init__(config, *args, **kwargs)
- add_mask_scoring_input_ids()
- doc_scoring_mask(input_ids, attention_mask)
- encode_doc(encoding)
- encode_query(encoding)
- forward(query_encoding, doc_encoding[, num_docs])
- from_pretrained(model_name_or_path, *args, ...): Loads a pretrained model.
- get_output_embeddings()
- query_scoring_mask(input_ids, attention_mask)
- score(query_embeddings, doc_embeddings[, ...])
Attributes
ALLOW_SUB_BATCHING
Flag to allow mini batches of documents for a single query.
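To give an intuition for why sub-batching is possible, the core of a bi-encoder is that queries and documents are embedded independently and only compared at scoring time. The sketch below is an illustrative example, not the library's implementation: the real `score` operates on `BiEncoderEmbedding` objects and the similarity function is determined by the `BiEncoderConfig`, while here a plain dot product over NumPy arrays stands in, with `num_docs` giving the number of candidate documents per query.

```python
# Illustrative sketch of bi-encoder scoring (not the lightning_ir implementation):
# each query is scored only against its own slice of the document embeddings,
# so documents can be processed in mini batches per query.
import numpy as np

def score(query_embeddings: np.ndarray, doc_embeddings: np.ndarray,
          num_docs: list[int]) -> np.ndarray:
    """Dot-product scores; num_docs[i] documents belong to query i."""
    scores = []
    start = 0
    for query, n in zip(query_embeddings, num_docs):
        docs = doc_embeddings[start:start + n]
        scores.append(docs @ query)  # one score per candidate document
        start += n
    return np.concatenate(scores)

queries = np.array([[1.0, 0.0], [0.0, 1.0]])           # 2 query embeddings
docs = np.array([[1.0, 0.0], [0.5, 0.5], [0.0, 2.0]])  # 3 document embeddings
print(score(queries, docs, num_docs=[2, 1]))  # [1.  0.5 2. ]
```

Because the document slices are independent, the loop body can be run on arbitrarily small chunks of documents without changing the result, which is what `ALLOW_SUB_BATCHING` permits.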
- config_class
alias of BiEncoderConfig
- classmethod from_pretrained(model_name_or_path: str | Path, *args, **kwargs) → LightningIRModel
Loads a pretrained model. Wraps the transformers.PreTrainedModel.from_pretrained method to return a derived LightningIRModel. See LightningIRModelClassFactory for more details.
- Parameters:
model_name_or_path (str | Path) – Name or path of the pretrained model
- Raises:
ValueError – If called on the abstract class
LightningIRModel
and no config is passed
- Returns:
A derived LightningIRModel consisting of a backbone model and a LightningIRModel mixin
- Return type:
LightningIRModel

Examples

>>> # Loading using model class and backbone checkpoint
>>> type(CrossEncoderModel.from_pretrained("bert-base-uncased"))
<class 'lightning_ir.base.class_factory.CrossEncoderBertModel'>
>>> # Loading using base class and backbone checkpoint
>>> type(LightningIRModel.from_pretrained("bert-base-uncased", config=CrossEncoderConfig()))
<class 'lightning_ir.base.class_factory.CrossEncoderBertModel'>