SpladeModel

class lightning_ir.models.splade.SpladeModel(config: BiEncoderConfig, *args, **kwargs)

Bases: BiEncoderModel

__init__(config: BiEncoderConfig, *args, **kwargs) → None
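
SpladeModel instances are usually not created by calling __init__ directly; they are obtained through the inherited from_pretrained (documented below). A minimal sketch, assuming a SPLADE-style checkpoint such as naver/splade-v3 (the checkpoint name is an assumption, not part of this page):

from lightning_ir.models.splade import SpladeModel

# Obtain a SPLADE bi-encoder from a pretrained checkpoint.
# "naver/splade-v3" is an assumed checkpoint identifier.
model = SpladeModel.from_pretrained("naver/splade-v3")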

Methods

__init__(config, *args, **kwargs)

add_mask_scoring_input_ids()

doc_scoring_mask(input_ids, attention_mask)

encode_doc(encoding)

encode_query(encoding)

forward(query_encoding, doc_encoding[, num_docs])

from_pretrained(model_name_or_path, *args, ...)

Loads a pretrained model.

get_output_embeddings()

query_scoring_mask(input_ids, attention_mask)

score(query_embeddings, doc_embeddings[, ...])

Attributes

ALLOW_SUB_BATCHING

Flag to allow mini-batches of documents for a single query.
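
The encoding and scoring methods listed above can be combined for ad-hoc scoring. A minimal sketch, assuming the naver/splade-v3 checkpoint and a plain transformers tokenizer (both assumptions; the model only needs BatchEncoding inputs with input_ids and attention_mask):

import torch
from transformers import AutoTokenizer
from lightning_ir.models.splade import SpladeModel

# Assumed checkpoint identifier; any MLM-backed SPLADE checkpoint
# should behave analogously.
model = SpladeModel.from_pretrained("naver/splade-v3").eval()
tokenizer = AutoTokenizer.from_pretrained("naver/splade-v3")

query_encoding = tokenizer(["what is splade"], padding=True, return_tensors="pt")
doc_encoding = tokenizer(
    ["SPLADE learns sparse lexical representations.", "An unrelated sentence."],
    padding=True,
    return_tensors="pt",
)

with torch.no_grad():
    # One query paired with two documents.
    output = model.forward(query_encoding, doc_encoding, num_docs=[2])

print(output.scores)  # one relevance score per query-document pair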

classmethod from_pretrained(model_name_or_path: str | Path, *args, **kwargs) → LightningIRModel

Loads a pretrained model. Wraps the transformers.PreTrainedModel.from_pretrained method to return a derived LightningIRModel. See LightningIRModelClassFactory for more details.

Parameters:

model_name_or_path (str | Path) – Name or path of the pretrained model

Raises:

ValueError – If called on the abstract class LightningIRModel and no config is passed

Returns:

A derived LightningIRModel consisting of a backbone model and a LightningIRModel mixin

Return type:

LightningIRModel

>>> # Loading using model class and backbone checkpoint
>>> type(CrossEncoderModel.from_pretrained("bert-base-uncased"))
<class 'lightning_ir.base.class_factory.CrossEncoderBertModel'>
>>> # Loading using base class and backbone checkpoint
>>> type(LightningIRModel.from_pretrained("bert-base-uncased", config=CrossEncoderConfig()))
<class 'lightning_ir.base.class_factory.CrossEncoderBertModel'>