Source code for lightning_ir.models.splade

 1from typing import Literal
 2
 3from ..bi_encoder import BiEncoderConfig, BiEncoderModel
 4
 5
class SpladeConfig(BiEncoderConfig):
    """Configuration for SPLADE bi-encoder models.

    Only the pooling, projection, sparsification and embedding-dimension
    options are exposed as explicit arguments. A fixed set of other
    bi-encoder options is pinned by this class: any value a caller passes
    for one of them through ``kwargs`` is overwritten before the arguments
    reach :class:`BiEncoderConfig`.
    """

    model_type = "splade"

    def __init__(
        self,
        query_pooling_strategy: Literal["first", "mean", "max", "sum"] | None = "max",
        doc_pooling_strategy: Literal["first", "mean", "max", "sum"] | None = "max",
        projection: Literal["linear", "linear_no_bias", "mlm"] | None = "mlm",
        sparsification: Literal["relu", "relu_log"] | None = "relu_log",
        embedding_dim: int = 30522,
        **kwargs,
    ) -> None:
        """Build the config and forward everything to ``BiEncoderConfig``.

        Args:
            query_pooling_strategy: Forwarded unchanged; defaults to ``"max"``.
            doc_pooling_strategy: Forwarded unchanged; defaults to ``"max"``.
            projection: Forwarded unchanged; defaults to ``"mlm"``.
            sparsification: Forwarded unchanged; defaults to ``"relu_log"``.
            embedding_dim: Forwarded unchanged; defaults to 30522.
                NOTE(review): presumably chosen to match a BERT-style
                vocabulary size for the ``"mlm"`` projection. Confirm this
                when using a backbone with a different vocabulary.
            **kwargs: Any further ``BiEncoderConfig`` options. The keys set
                below are always overwritten with fixed values.
        """
        # Pin the options this config does not let callers change. The keys
        # are listed in the same order as the original per-key assignments,
        # so the resulting kwargs key order is unchanged.
        kwargs.update(
            query_expansion=False,
            attend_to_query_expanded_tokens=False,
            query_mask_scoring_tokens=None,
            doc_expansion=False,
            attend_to_doc_expanded_tokens=False,
            doc_mask_scoring_tokens=None,
            query_aggregation_function="sum",
            normalize=False,
            add_marker_tokens=False,
        )
        super().__init__(
            query_pooling_strategy=query_pooling_strategy,
            doc_pooling_strategy=doc_pooling_strategy,
            embedding_dim=embedding_dim,
            projection=projection,
            sparsification=sparsification,
            **kwargs,
        )
class SpladeModel(BiEncoderModel):
    """Bi-encoder model class associated with :class:`SpladeConfig`.

    In the visible source, this class only sets ``config_class``. All model
    behavior comes from :class:`BiEncoderModel`.
    """

    # Associates this model class with its configuration class.
    config_class = SpladeConfig