Source code for lightning_ir.bi_encoder.config

import json
import os
from os import PathLike
from typing import Any, Dict, Literal, Sequence, Tuple

from ..base import LightningIRConfig


class BiEncoderConfig(LightningIRConfig):
    """The configuration class to instantiate a Bi-Encoder model."""

    model_type = "bi-encoder"

    TOKENIZER_ARGS = LightningIRConfig.TOKENIZER_ARGS.union(
        {
            "query_expansion",
            "attend_to_query_expanded_tokens",
            "doc_expansion",
            "attend_to_doc_expanded_tokens",
            "add_marker_tokens",
        }
    )

    ADDED_ARGS = LightningIRConfig.ADDED_ARGS.union(
        {
            "similarity_function",
            "query_pooling_strategy",
            "query_mask_scoring_tokens",
            "query_aggregation_function",
            "doc_pooling_strategy",
            "doc_mask_scoring_tokens",
            "normalize",
            "sparsification",
            "embedding_dim",
            "projection",
        }
    ).union(TOKENIZER_ARGS)

    def __init__(
        self,
        similarity_function: Literal["cosine", "dot"] = "dot",
        query_expansion: bool = False,
        attend_to_query_expanded_tokens: bool = False,
        query_pooling_strategy: Literal["first", "mean", "max", "sum"] | None = "mean",
        query_mask_scoring_tokens: Sequence[str] | Literal["punctuation"] | None = None,
        query_aggregation_function: Literal["sum", "mean", "max", "harmonic_mean"] = "sum",
        doc_expansion: bool = False,
        attend_to_doc_expanded_tokens: bool = False,
        doc_pooling_strategy: Literal["first", "mean", "max", "sum"] | None = "mean",
        doc_mask_scoring_tokens: Sequence[str] | Literal["punctuation"] | None = None,
        normalize: bool = False,
        sparsification: Literal["relu", "relu_log"] | None = None,
        add_marker_tokens: bool = False,
        embedding_dim: int = 768,
        projection: Literal["linear", "linear_no_bias", "mlm"] | None = "linear",
        **kwargs,
    ):
        """Initializes a bi-encoder configuration.

        :param similarity_function: Similarity function to compute scores between query and document embeddings,
            defaults to "dot"
        :type similarity_function: Literal['cosine', 'dot'], optional
        :param query_expansion: Whether to expand queries with mask tokens, defaults to False
        :type query_expansion: bool, optional
        :param attend_to_query_expanded_tokens: Whether to allow query tokens to attend to mask tokens,
            defaults to False
        :type attend_to_query_expanded_tokens: bool, optional
        :param query_pooling_strategy: Whether and how to pool the query token embeddings, defaults to "mean"
        :type query_pooling_strategy: Literal['first', 'mean', 'max', 'sum'] | None, optional
        :param query_mask_scoring_tokens: Whether and which query tokens to ignore during scoring, defaults to None
        :type query_mask_scoring_tokens: Sequence[str] | Literal['punctuation'] | None, optional
        :param query_aggregation_function: How to aggregate similarity scores over query tokens, defaults to "sum"
        :type query_aggregation_function: Literal['sum', 'mean', 'max', 'harmonic_mean'], optional
        :param doc_expansion: Whether to expand documents with mask tokens, defaults to False
        :type doc_expansion: bool, optional
        :param attend_to_doc_expanded_tokens: Whether to allow document tokens to attend to mask tokens,
            defaults to False
        :type attend_to_doc_expanded_tokens: bool, optional
        :param doc_pooling_strategy: Whether and how to pool document token embeddings, defaults to "mean"
        :type doc_pooling_strategy: Literal['first', 'mean', 'max', 'sum'] | None, optional
        :param doc_mask_scoring_tokens: Whether and which document tokens to ignore during scoring, defaults to None
        :type doc_mask_scoring_tokens: Sequence[str] | Literal['punctuation'] | None, optional
        :param normalize: Whether to normalize query and document embeddings, defaults to False
        :type normalize: bool, optional
        :param sparsification: Whether and which sparsification function to apply, defaults to None
        :type sparsification: Literal['relu', 'relu_log'] | None, optional
        :param add_marker_tokens: Whether to add extra marker tokens [Q] / [D] to queries / documents, defaults to False
        :type add_marker_tokens: bool, optional
        :param embedding_dim: The output embedding dimension, defaults to 768
        :type embedding_dim: int, optional
        :param projection: Whether and how to project the output embeddings, defaults to "linear"
        :type projection: Literal['linear', 'linear_no_bias', 'mlm'] | None, optional
        """
        super().__init__(**kwargs)
        self.similarity_function = similarity_function
        self.query_expansion = query_expansion
        self.attend_to_query_expanded_tokens = attend_to_query_expanded_tokens
        self.query_pooling_strategy = query_pooling_strategy
        self.query_mask_scoring_tokens = query_mask_scoring_tokens
        self.query_aggregation_function = query_aggregation_function
        self.doc_expansion = doc_expansion
        self.attend_to_doc_expanded_tokens = attend_to_doc_expanded_tokens
        self.doc_pooling_strategy = doc_pooling_strategy
        self.doc_mask_scoring_tokens = doc_mask_scoring_tokens
        self.normalize = normalize
        self.sparsification = sparsification
        self.add_marker_tokens = add_marker_tokens
        self.embedding_dim = embedding_dim
        self.projection = projection

    def to_dict(self) -> Dict[str, Any]:
        output = super().to_dict()
        # mask scoring tokens are persisted separately (see save_pretrained)
        # and therefore removed from the serialized config dict
        if "query_mask_scoring_tokens" in output:
            output.pop("query_mask_scoring_tokens")
        if "doc_mask_scoring_tokens" in output:
            output.pop("doc_mask_scoring_tokens")
        return output

    def save_pretrained(self, save_directory: str | PathLike, push_to_hub: bool = False, **kwargs):
        # store the mask scoring tokens in a sidecar file next to config.json
        with open(os.path.join(save_directory, "mask_scoring_tokens.json"), "w") as f:
            json.dump({"query": self.query_mask_scoring_tokens, "doc": self.doc_mask_scoring_tokens}, f)
        return super().save_pretrained(save_directory, push_to_hub, **kwargs)

    @classmethod
    def get_config_dict(
        cls, pretrained_model_name_or_path: str | PathLike, **kwargs
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        config_dict, kwargs = super().get_config_dict(pretrained_model_name_or_path, **kwargs)
        mask_scoring_tokens = None
        mask_scoring_tokens_path = os.path.join(pretrained_model_name_or_path, "mask_scoring_tokens.json")
        # restore the mask scoring tokens from the sidecar file, if present
        if os.path.exists(mask_scoring_tokens_path):
            with open(mask_scoring_tokens_path) as f:
                mask_scoring_tokens = json.load(f)
            config_dict["query_mask_scoring_tokens"] = mask_scoring_tokens["query"]
            config_dict["doc_mask_scoring_tokens"] = mask_scoring_tokens["doc"]
        return config_dict, kwargs
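
What follows is an illustrative usage sketch, not part of the module. It assumes a BiEncoderConfig can be constructed standalone (LightningIRConfig behaves like a Hugging Face PretrainedConfig), and shows that to_dict drops the mask scoring tokens from the serialized output. All argument values are arbitrary examples.

from lightning_ir.bi_encoder.config import BiEncoderConfig

# arbitrary example values; any combination of the documented arguments works
config = BiEncoderConfig(
    similarity_function="cosine",
    query_mask_scoring_tokens="punctuation",
    embedding_dim=128,
)

# the mask scoring tokens are excluded from the config dict;
# save_pretrained stores them in a sidecar file instead
config_dict = config.to_dict()
assert "query_mask_scoring_tokens" not in config_dict
assert "doc_mask_scoring_tokens" not in config_dict
assert config_dict["embedding_dim"] == 128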
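
Continuing the sketch, a round trip through save_pretrained and get_config_dict illustrates the sidecar handling: the mask scoring tokens are written to mask_scoring_tokens.json next to config.json and merged back into the config dict on load. The temporary directory stands in for a real checkpoint path.

import tempfile

with tempfile.TemporaryDirectory() as save_dir:
    # writes config.json plus the mask_scoring_tokens.json sidecar
    config.save_pretrained(save_dir)
    # get_config_dict reads config.json and, if the sidecar exists, merges
    # its "query" and "doc" entries back into the config dict
    config_dict, _ = BiEncoderConfig.get_config_dict(save_dir)
    assert config_dict["query_mask_scoring_tokens"] == "punctuation"
    assert config_dict["doc_mask_scoring_tokens"] is None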