1import json
2import os
3from os import PathLike
4from typing import Any, Dict, Literal, Sequence, Tuple
5
6from ..base import LightningIRConfig
7
8
class BiEncoderConfig(LightningIRConfig):
    """The configuration class to instantiate a Bi-Encoder model."""

    model_type = "bi-encoder"

    # Arguments that must also be forwarded to the tokenizer.
    TOKENIZER_ARGS = LightningIRConfig.TOKENIZER_ARGS.union(
        {
            "query_expansion",
            "attend_to_query_expanded_tokens",
            "doc_expansion",
            "attend_to_doc_expanded_tokens",
            "add_marker_tokens",
        }
    )

    # All configuration arguments this class adds on top of the base config.
    ADDED_ARGS = LightningIRConfig.ADDED_ARGS.union(
        {
            "similarity_function",
            "query_pooling_strategy",
            "query_mask_scoring_tokens",
            "query_aggregation_function",
            "doc_pooling_strategy",
            "doc_mask_scoring_tokens",
            "normalize",
            "sparsification",
            "embedding_dim",
            "projection",
        }
    ).union(TOKENIZER_ARGS)

    def __init__(
        self,
        similarity_function: Literal["cosine", "dot"] = "dot",
        query_expansion: bool = False,
        attend_to_query_expanded_tokens: bool = False,
        query_pooling_strategy: Literal["first", "mean", "max", "sum"] | None = "mean",
        query_mask_scoring_tokens: Sequence[str] | Literal["punctuation"] | None = None,
        query_aggregation_function: Literal["sum", "mean", "max", "harmonic_mean"] = "sum",
        doc_expansion: bool = False,
        attend_to_doc_expanded_tokens: bool = False,
        doc_pooling_strategy: Literal["first", "mean", "max", "sum"] | None = "mean",
        doc_mask_scoring_tokens: Sequence[str] | Literal["punctuation"] | None = None,
        normalize: bool = False,
        sparsification: Literal["relu", "relu_log"] | None = None,
        add_marker_tokens: bool = False,
        embedding_dim: int = 768,
        projection: Literal["linear", "linear_no_bias", "mlm"] | None = "linear",
        **kwargs,
    ):
        """Initializes a bi-encoder configuration.

        :param similarity_function: Similarity function to compute scores between query and document embeddings,
            defaults to "dot"
        :type similarity_function: Literal['cosine', 'dot'], optional
        :param query_expansion: Whether to expand queries with mask tokens, defaults to False
        :type query_expansion: bool, optional
        :param attend_to_query_expanded_tokens: Whether to allow query tokens to attend to mask tokens,
            defaults to False
        :type attend_to_query_expanded_tokens: bool, optional
        :param query_pooling_strategy: Whether and how to pool the query token embeddings, defaults to "mean"
        :type query_pooling_strategy: Literal['first', 'mean', 'max', 'sum'] | None, optional
        :param query_mask_scoring_tokens: Whether and which query tokens to ignore during scoring, defaults to None
        :type query_mask_scoring_tokens: Sequence[str] | Literal['punctuation'] | None, optional
        :param query_aggregation_function: How to aggregate similarity scores over query tokens, defaults to "sum"
        :type query_aggregation_function: Literal['sum', 'mean', 'max', 'harmonic_mean'], optional
        :param doc_expansion: Whether to expand documents with mask tokens, defaults to False
        :type doc_expansion: bool, optional
        :param attend_to_doc_expanded_tokens: Whether to allow document tokens to attend to mask tokens,
            defaults to False
        :type attend_to_doc_expanded_tokens: bool, optional
        :param doc_pooling_strategy: Whether and how to pool document token embeddings, defaults to "mean"
        :type doc_pooling_strategy: Literal['first', 'mean', 'max', 'sum'] | None, optional
        :param doc_mask_scoring_tokens: Whether and which document tokens to ignore during scoring, defaults to None
        :type doc_mask_scoring_tokens: Sequence[str] | Literal['punctuation'] | None, optional
        :param normalize: Whether to normalize query and document embeddings, defaults to False
        :type normalize: bool, optional
        :param sparsification: Whether and which sparsification function to apply, defaults to None
        :type sparsification: Literal['relu', 'relu_log'] | None, optional
        :param add_marker_tokens: Whether to add extra marker tokens [Q] / [D] to queries / documents, defaults to False
        :type add_marker_tokens: bool, optional
        :param embedding_dim: The output embedding dimension, defaults to 768
        :type embedding_dim: int, optional
        :param projection: Whether and how to project the output embeddings, defaults to "linear"
        :type projection: Literal['linear', 'linear_no_bias', 'mlm'] | None, optional
        """
        super().__init__(**kwargs)
        self.similarity_function = similarity_function
        self.query_expansion = query_expansion
        self.attend_to_query_expanded_tokens = attend_to_query_expanded_tokens
        self.query_pooling_strategy = query_pooling_strategy
        self.query_mask_scoring_tokens = query_mask_scoring_tokens
        self.query_aggregation_function = query_aggregation_function
        self.doc_expansion = doc_expansion
        self.attend_to_doc_expanded_tokens = attend_to_doc_expanded_tokens
        self.doc_pooling_strategy = doc_pooling_strategy
        self.doc_mask_scoring_tokens = doc_mask_scoring_tokens
        self.normalize = normalize
        self.sparsification = sparsification
        self.add_marker_tokens = add_marker_tokens
        self.embedding_dim = embedding_dim
        self.projection = projection

    def to_dict(self) -> Dict[str, Any]:
        """Serializes the configuration to a dictionary.

        The mask scoring tokens are excluded from the dictionary; they are persisted
        separately to ``mask_scoring_tokens.json`` by :meth:`save_pretrained` and
        restored by :meth:`get_config_dict`.

        :return: Configuration dictionary without the mask scoring token entries
        :rtype: Dict[str, Any]
        """
        output = super().to_dict()
        # pop with a default avoids the membership pre-check and a double lookup
        output.pop("query_mask_scoring_tokens", None)
        output.pop("doc_mask_scoring_tokens", None)
        return output

    def save_pretrained(self, save_directory: str | PathLike, push_to_hub: bool = False, **kwargs):
        """Saves the configuration to a directory.

        In addition to the standard configuration file, the query and document mask
        scoring tokens are written to ``mask_scoring_tokens.json`` in the same
        directory (they are omitted from the main config by :meth:`to_dict`).

        :param save_directory: Directory to save the configuration to
        :type save_directory: str | PathLike
        :param push_to_hub: Whether to upload the saved files, defaults to False
        :type push_to_hub: bool, optional
        """
        with open(os.path.join(save_directory, "mask_scoring_tokens.json"), "w") as f:
            json.dump({"query": self.query_mask_scoring_tokens, "doc": self.doc_mask_scoring_tokens}, f)
        return super().save_pretrained(save_directory, push_to_hub=push_to_hub, **kwargs)

    @classmethod
    def get_config_dict(
        cls, pretrained_model_name_or_path: str | PathLike, **kwargs
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Loads the configuration dictionary, restoring the mask scoring tokens.

        If a ``mask_scoring_tokens.json`` file (written by :meth:`save_pretrained`)
        exists next to the configuration, its ``query`` and ``doc`` entries are merged
        back into the configuration dictionary. Only local paths are probed; for
        remote identifiers the file lookup is skipped.

        :param pretrained_model_name_or_path: Path or identifier of the pretrained configuration
        :type pretrained_model_name_or_path: str | PathLike
        :return: Tuple of the configuration dictionary and remaining keyword arguments
        :rtype: Tuple[Dict[str, Any], Dict[str, Any]]
        """
        config_dict, kwargs = super().get_config_dict(pretrained_model_name_or_path, **kwargs)
        mask_scoring_tokens_path = os.path.join(pretrained_model_name_or_path, "mask_scoring_tokens.json")
        if os.path.exists(mask_scoring_tokens_path):
            with open(mask_scoring_tokens_path) as f:
                mask_scoring_tokens = json.load(f)
            config_dict["query_mask_scoring_tokens"] = mask_scoring_tokens["query"]
            config_dict["doc_mask_scoring_tokens"] = mask_scoring_tokens["doc"]
        return config_dict, kwargs