import warnings
from dataclasses import dataclass
from functools import wraps
from string import punctuation
from typing import Callable, Iterable, Literal, Sequence, Tuple, overload

import torch
from transformers import BatchEncoding
from transformers.activations import ACT2FN

from ..base import LightningIRModel, LightningIROutput
from ..base.model import _batch_encoding
from . import BiEncoderConfig


class MLMHead(torch.nn.Module):
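    """Masked language modeling head. Mirrors BERT's prediction head (dense transform, activation, layer norm,
    vocabulary decoder) and serves as the projection when ``config.projection == "mlm"``."""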
    def __init__(self, config: BiEncoderConfig) -> None:
        super().__init__()
        self.config = config
        self.dense = torch.nn.Linear(config.hidden_size, config.hidden_size)
        if isinstance(config.hidden_act, str):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act
        self.LayerNorm = torch.nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.decoder = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.bias = torch.nn.Parameter(torch.zeros(config.vocab_size))

        self.decoder.bias = self.bias

    def _tie_weights(self):
        self.decoder.bias = self.bias

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        hidden_states = self.decoder(hidden_states)
        return hidden_states


@dataclass
class BiEncoderEmbedding:
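    """Container for token embeddings together with the boolean scoring mask marking which embeddings take part
    in scoring."""
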
    embeddings: torch.Tensor
    scoring_mask: torch.Tensor

    @overload
    def to(self, device: torch.device, /) -> "BiEncoderEmbedding": ...

    @overload
    def to(self, other: "BiEncoderEmbedding", /) -> "BiEncoderEmbedding": ...

    def to(self, device) -> "BiEncoderEmbedding":
        """Moves both tensors to the given device (or to the device of another embedding) in place."""
        if isinstance(device, BiEncoderEmbedding):
            device = device.device
        self.embeddings = self.embeddings.to(device)
        self.scoring_mask = self.scoring_mask.to(device)
        return self

    @property
    def device(self) -> torch.device:
        if self.embeddings.device != self.scoring_mask.device:
            raise ValueError("Embeddings and scoring_mask must be on the same device")
        return self.embeddings.device

    def items(self) -> Iterable[Tuple[str, torch.Tensor]]:
        for field in self.__dataclass_fields__:
            yield field, getattr(self, field)


@dataclass
class BiEncoderOutput(LightningIROutput):
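    """Output of a bi-encoder forward pass: relevance scores (inherited) plus the query and document embeddings."""
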
    query_embeddings: BiEncoderEmbedding | None = None
    doc_embeddings: BiEncoderEmbedding | None = None


class BiEncoderModel(LightningIRModel):
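    """Bi-encoder model that encodes queries and documents separately and scores query-document pairs with a
    configurable :class:`ScoringFunction`."""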

    _tied_weights_keys = ["projection.decoder.bias", "projection.decoder.weight", "encoder.embed_tokens.weight"]
    _keys_to_ignore_on_load_unexpected = [r"decoder"]

    config_class = BiEncoderConfig

    def __init__(self, config: BiEncoderConfig, *args, **kwargs) -> None:
        super().__init__(config, *args, **kwargs)
        self.config: BiEncoderConfig
        self.scoring_function = ScoringFunction(self.config)
        self.projection: torch.nn.Linear | MLMHead | None = None
        if self.config.projection is not None:
            if "linear" in self.config.projection:
                self.projection = torch.nn.Linear(
                    self.config.hidden_size,
                    self.config.embedding_dim,
                    bias="no_bias" not in self.config.projection,
                )
            elif self.config.projection == "mlm":
                self.projection = MLMHead(config)
            else:
                raise ValueError(f"Unknown projection {self.config.projection}")
        else:
            if self.config.embedding_dim != self.config.hidden_size:
                warnings.warn(
                    "No projection is used but embedding_dim != hidden_size. "
                    "The output embeddings will not have embedding_dim dimensions."
                )

        self.query_mask_scoring_input_ids: torch.Tensor | None = None
        self.doc_mask_scoring_input_ids: torch.Tensor | None = None
        self.add_mask_scoring_input_ids()

    @classmethod
    def _load_pretrained_model(
        cls, model, state_dict, loaded_keys, resolved_archive_file, pretrained_model_name_or_path, *args, **kwargs
    ):
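        # Checkpoints trained with a masked language modeling head store the head's weights under
        # ``cls.predictions``; remap those keys onto this model's ``projection`` head so the weights are reused.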
        if model.config.projection == "mlm":
            has_base_model_prefix = any(s.startswith(model.base_model_prefix) for s in state_dict.keys())
            prefix = model.base_model_prefix + "." if has_base_model_prefix else ""
            for key in list(state_dict.keys()):
                if key.startswith("cls"):
                    new_key = prefix + key.replace("cls.predictions", "projection").replace(".transform", "")
                    state_dict[new_key] = state_dict.pop(key)
                    loaded_keys[loaded_keys.index(key)] = new_key
        return super()._load_pretrained_model(
            model, state_dict, loaded_keys, resolved_archive_file, pretrained_model_name_or_path, *args, **kwargs
        )

    def get_output_embeddings(self) -> torch.nn.Module | None:
        if isinstance(self.projection, MLMHead):
            return self.projection.decoder
        return None

    def add_mask_scoring_input_ids(self) -> None:
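        """Resolves the configured mask-scoring tokens (e.g., ``"punctuation"``) to token ids so that the
        embeddings of these tokens can later be excluded from scoring."""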
        for sequence in ("query", "doc"):
            mask_scoring_tokens = getattr(self.config, f"{sequence}_mask_scoring_tokens")
            if mask_scoring_tokens is None:
                continue
            if mask_scoring_tokens == "punctuation":
                mask_scoring_tokens = list(punctuation)
            try:
                from transformers import AutoTokenizer

                tokenizer = AutoTokenizer.from_pretrained(self.config.name_or_path)
            except OSError:
                raise ValueError("Can't use token scoring masking if the checkpoint does not have a tokenizer.")
            mask_scoring_input_ids = []
            for token in mask_scoring_tokens:
                if token not in tokenizer.vocab:
                    raise ValueError(f"Token {token} not in tokenizer vocab")
                mask_scoring_input_ids.append(tokenizer.vocab[token])
            setattr(
                self,
                f"{sequence}_mask_scoring_input_ids",
                torch.tensor(mask_scoring_input_ids, dtype=torch.long),
            )

    def forward(
        self,
        query_encoding: BatchEncoding | None,
        doc_encoding: BatchEncoding | None,
        num_docs: Sequence[int] | int | None = None,
    ) -> BiEncoderOutput:
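        """Embeds queries and/or documents and, if both are given, computes their relevance scores."""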
        query_embeddings = None
        if query_encoding is not None:
            query_embeddings = self.encode_query(query_encoding)
        doc_embeddings = None
        if doc_encoding is not None:
            doc_embeddings = self.encode_doc(doc_encoding)
        scores = None
        if doc_embeddings is not None and query_embeddings is not None:
            scores = self.score(query_embeddings, doc_embeddings, num_docs)
        return BiEncoderOutput(scores=scores, query_embeddings=query_embeddings, doc_embeddings=doc_embeddings)

    def encode_query(self, encoding: BatchEncoding) -> BiEncoderEmbedding:
        return self._encode(
            encoding,
            expansion=self.config.query_expansion,
            pooling_strategy=self.config.query_pooling_strategy,
            mask_scoring_input_ids=self.query_mask_scoring_input_ids,
        )

    def encode_doc(self, encoding: BatchEncoding) -> BiEncoderEmbedding:
        return self._encode(
            encoding,
            expansion=self.config.doc_expansion,
            pooling_strategy=self.config.doc_pooling_strategy,
            mask_scoring_input_ids=self.doc_mask_scoring_input_ids,
        )

    @_batch_encoding
    def _encode(
        self,
        encoding: BatchEncoding,
        expansion: bool = False,
        pooling_strategy: Literal["first", "mean", "max", "sum"] | None = None,
        mask_scoring_input_ids: torch.Tensor | None = None,
    ) -> BiEncoderEmbedding:
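        # Embedding pipeline: backbone forward -> optional projection -> optional sparsification -> optional
        # pooling -> optional L2 normalization; the scoring mask marks which embeddings participate in scoring.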
        embeddings = self._backbone_forward(**encoding).last_hidden_state
        if self.projection is not None:
            embeddings = self.projection(embeddings)
        embeddings = self._sparsification(embeddings, self.config.sparsification)
        embeddings = self._pooling(embeddings, encoding["attention_mask"], pooling_strategy)
        if self.config.normalize:
            embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
        scoring_mask = self._scoring_mask(
            encoding["input_ids"],
            encoding["attention_mask"],
            expansion,
            pooling_strategy,
            mask_scoring_input_ids,
        )
        return BiEncoderEmbedding(embeddings, scoring_mask)

    def query_scoring_mask(self, input_ids: torch.Tensor | None, attention_mask: torch.Tensor | None) -> torch.Tensor:
        return self._scoring_mask(
            input_ids,
            attention_mask,
            expansion=self.config.query_expansion,
            pooling_strategy=self.config.query_pooling_strategy,
            mask_scoring_input_ids=self.query_mask_scoring_input_ids,
        )

    def doc_scoring_mask(self, input_ids: torch.Tensor | None, attention_mask: torch.Tensor | None) -> torch.Tensor:
        return self._scoring_mask(
            input_ids,
            attention_mask,
            expansion=self.config.doc_expansion,
            pooling_strategy=self.config.doc_pooling_strategy,
            mask_scoring_input_ids=self.doc_mask_scoring_input_ids,
        )

    def _scoring_mask(
        self,
        input_ids: torch.Tensor | None,
        attention_mask: torch.Tensor | None,
        expansion: bool,
        pooling_strategy: Literal["first", "mean", "max", "sum"] | None = None,
        mask_scoring_input_ids: torch.Tensor | None = None,
    ) -> torch.Tensor:
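        """Computes a boolean mask over the embeddings that take part in scoring: pooled sequences contribute a
        single embedding, expanded sequences contribute every position, and mask-scoring tokens are excluded."""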
        if input_ids is not None:
            shape = input_ids.shape
            device = input_ids.device
        elif attention_mask is not None:
            shape = attention_mask.shape
            device = attention_mask.device
        else:
            raise ValueError("Pass either input_ids or attention_mask")
        if pooling_strategy is not None:
            return torch.ones((shape[0], 1), dtype=torch.bool, device=device)
        scoring_mask = attention_mask
        if expansion or scoring_mask is None:
            scoring_mask = torch.ones(shape, dtype=torch.bool, device=device)
        scoring_mask = scoring_mask.bool()
        if mask_scoring_input_ids is not None and input_ids is not None:
            ignore_mask = input_ids[..., None].eq(mask_scoring_input_ids.to(device)).any(-1)
            scoring_mask = scoring_mask & ~ignore_mask
        return scoring_mask

    def score(
        self,
        query_embeddings: BiEncoderEmbedding,
        doc_embeddings: BiEncoderEmbedding,
        num_docs: Sequence[int] | int | None = None,
    ) -> torch.Tensor:
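        """Computes relevance scores for all query-document pairs; ``num_docs`` gives the number of documents
        per query."""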
        scores = self.scoring_function.score(query_embeddings, doc_embeddings, num_docs=num_docs)
        return scores


def _batch_scoring(
    similarity_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor]:
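    """Decorator that evaluates a similarity function in chunks of at most ``BATCH_SIZE`` along the first (batch)
    dimension to bound the peak memory of the intermediate similarity tensors."""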
    BATCH_SIZE = 1024

    @wraps(similarity_function)
    def batch_similarity_function(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        if x.shape[0] <= BATCH_SIZE:
            return similarity_function(x, y)
        out = torch.zeros(x.shape[0], x.shape[1], y.shape[2], device=x.device, dtype=x.dtype)
        for i in range(0, x.shape[0], BATCH_SIZE):
            out[i : i + BATCH_SIZE] = similarity_function(x[i : i + BATCH_SIZE], y[i : i + BATCH_SIZE])
        return out

    return batch_similarity_function


class ScoringFunction(torch.nn.Module):
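    """Computes query-document relevance scores from bi-encoder embeddings: token-level similarities (cosine, l2,
    or dot) are max-aggregated over document tokens and then aggregated over query tokens."""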
    def __init__(self, config: BiEncoderConfig) -> None:
        super().__init__()
        self.config = config
        if self.config.similarity_function == "cosine":
            self.similarity_function = self.cosine_similarity
        elif self.config.similarity_function == "l2":
            self.similarity_function = self.l2_similarity
        elif self.config.similarity_function == "dot":
            self.similarity_function = self.dot_similarity
        else:
            raise ValueError(f"Unknown similarity function {self.config.similarity_function}")
        self.query_aggregation_function = self.config.query_aggregation_function

    def compute_similarity(
        self, query_embeddings: BiEncoderEmbedding, doc_embeddings: BiEncoderEmbedding
    ) -> torch.Tensor:
        # TODO on GPU, bfloat16 similarities yield values that are off by a factor of 1/4; casting to fp16 works
        # around this, but the cause still needs investigating
        # TODO compute similarity only for non-masked values
        similarity = self.similarity_function(query_embeddings.embeddings, doc_embeddings.embeddings)
        return similarity

    @staticmethod
    @_batch_scoring
    def cosine_similarity(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.cosine_similarity(x, y, dim=-1)

    @staticmethod
    @_batch_scoring
    def l2_similarity(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # negated L2 distance, so that larger values mean higher similarity
        return -1 * torch.cdist(x, y).squeeze(-2)

    @staticmethod
    @_batch_scoring
    def dot_similarity(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return torch.matmul(x, y.transpose(-1, -2)).squeeze(-2)

    def parse_num_docs(
        self,
        query_embeddings: BiEncoderEmbedding,
        doc_embeddings: BiEncoderEmbedding,
        num_docs: int | Sequence[int] | None,
    ) -> torch.Tensor:
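        """Normalizes ``num_docs`` to a tensor of per-query document counts and validates it against the flattened
        document batch; when ``num_docs`` is None the documents must be evenly distributed over the queries."""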
        batch_size = query_embeddings.embeddings.shape[0]
        if isinstance(num_docs, int):
            num_docs = [num_docs] * batch_size
        if isinstance(num_docs, list):
            if sum(num_docs) != doc_embeddings.embeddings.shape[0] or len(num_docs) != batch_size:
                raise ValueError("Num docs does not match doc embeddings")
        if num_docs is None:
            if doc_embeddings.embeddings.shape[0] % batch_size != 0:
                raise ValueError("Docs are not evenly distributed in batch, but no num_docs provided")
            num_docs = [doc_embeddings.embeddings.shape[0] // batch_size] * batch_size
        return torch.tensor(num_docs, device=query_embeddings.embeddings.device)

    def expand_query_embeddings(
        self,
        embeddings: BiEncoderEmbedding,
        num_docs: torch.Tensor,
    ) -> BiEncoderEmbedding:
        # repeat each query once per document and add a singleton dimension for broadcasting over doc tokens
        return BiEncoderEmbedding(
            embeddings.embeddings.repeat_interleave(num_docs, dim=0).unsqueeze(2),
            embeddings.scoring_mask.repeat_interleave(num_docs, dim=0).unsqueeze(2),
        )

    def expand_doc_embeddings(
        self,
        embeddings: BiEncoderEmbedding,
        num_docs: torch.Tensor,
    ) -> BiEncoderEmbedding:
        # add a singleton query-token dimension for broadcasting
        return BiEncoderEmbedding(embeddings.embeddings.unsqueeze(1), embeddings.scoring_mask.unsqueeze(1))

    def aggregate(
        self,
        scores: torch.Tensor,
        mask: torch.Tensor | None,
        query_aggregation_function: Literal["max", "sum", "mean", "harmonic_mean"] | None,
        dim: int,
    ) -> torch.Tensor:
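        """Aggregates scores along ``dim`` with the given aggregation function, taking only positions where
        ``mask`` is True into account."""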
        if query_aggregation_function is None:
            return scores
        if query_aggregation_function == "max":
            if mask is not None:
                scores = scores.masked_fill(~mask, float("-inf"))
            return scores.max(dim, keepdim=True).values
        if query_aggregation_function == "sum":
            if mask is not None:
                scores = scores.masked_fill(~mask, 0)
            return scores.sum(dim, keepdim=True)
        if mask is None:
            shape = list(scores.shape)
            shape[dim] = 1
            num_non_masked = torch.full(shape, scores.shape[dim], device=scores.device)
        else:
            num_non_masked = mask.sum(dim, keepdim=True)
        if query_aggregation_function == "mean":
            if mask is not None:
                # exclude masked positions from the sum
                scores = scores.masked_fill(~mask, 0)
            return torch.where(num_non_masked == 0, 0, scores.sum(dim, keepdim=True) / num_non_masked)
        if query_aggregation_function == "harmonic_mean":
            if mask is not None:
                # masked positions contribute nothing to the reciprocal sum
                scores = scores.masked_fill(~mask, torch.inf)
            return torch.where(
                num_non_masked == 0,
                0,
                num_non_masked / (1 / scores).sum(dim, keepdim=True),
            )
        raise ValueError(f"Unknown aggregation {query_aggregation_function}")

    def score(
        self,
        query_embeddings: BiEncoderEmbedding,
        doc_embeddings: BiEncoderEmbedding,
        num_docs: Sequence[int] | int | None = None,
    ) -> torch.Tensor:
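        """Expands query and document embeddings to aligned shapes, computes token-level similarities, max-pools
        over document tokens, and aggregates the result over query tokens."""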
        num_docs_t = self.parse_num_docs(query_embeddings, doc_embeddings, num_docs)
        query_embeddings = self.expand_query_embeddings(query_embeddings, num_docs_t)
        doc_embeddings = self.expand_doc_embeddings(doc_embeddings, num_docs_t)
        similarity = self.compute_similarity(query_embeddings, doc_embeddings)
        scores = self.aggregate(similarity, doc_embeddings.scoring_mask, "max", -1)
        scores = self.aggregate(scores, query_embeddings.scoring_mask, self.query_aggregation_function, -2)
        return scores[..., 0, 0]