Source code for lightning_ir.bi_encoder.model

  1import warnings
  2from dataclasses import dataclass
  3from functools import wraps
  4from string import punctuation
  5from typing import Callable, Iterable, Literal, Sequence, Tuple, overload
  6
  7import torch
  8from transformers import BatchEncoding
  9from transformers.activations import ACT2FN
 10
 11from ..base import LightningIRModel, LightningIROutput
 12from ..base.model import _batch_encoding
 13from . import BiEncoderConfig
 14
 15
[docs] 16class MLMHead(torch.nn.Module):
[docs] 17 def __init__(self, config: BiEncoderConfig) -> None: 18 super().__init__() 19 self.config = config 20 self.dense = torch.nn.Linear(config.hidden_size, config.hidden_size) 21 if isinstance(config.hidden_act, str): 22 self.transform_act_fn = ACT2FN[config.hidden_act] 23 else: 24 self.transform_act_fn = config.hidden_act 25 self.LayerNorm = torch.nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 26 self.decoder = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False) 27 self.bias = torch.nn.Parameter(torch.zeros(config.vocab_size)) 28 29 self.decoder.bias = self.bias
30 31 def _tie_weights(self): 32 self.decoder.bias = self.bias 33 34 def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: 35 hidden_states = self.dense(hidden_states) 36 hidden_states = self.transform_act_fn(hidden_states) 37 hidden_states = self.LayerNorm(hidden_states) 38 hidden_states = self.decoder(hidden_states) 39 return hidden_states
40 41
[docs] 42@dataclass 43class BiEncoderEmbedding: 44 embeddings: torch.Tensor 45 scoring_mask: torch.Tensor 46 47 @overload 48 def to(self, device: torch.device, /) -> "BiEncoderEmbedding": ... 49 50 @overload 51 def to(self, other: "BiEncoderEmbedding", /) -> "BiEncoderEmbedding": ... 52 53 def to(self, device) -> "BiEncoderEmbedding": 54 if isinstance(device, BiEncoderEmbedding): 55 device = device.device 56 self.embeddings.to() 57 self.embeddings = self.embeddings.to(device) 58 self.scoring_mask = self.scoring_mask.to(device) 59 return self 60 61 @property 62 def device(self) -> torch.device: 63 if self.embeddings.device != self.scoring_mask.device: 64 raise ValueError("Embeddings and scoring_mask must be on the same device") 65 return self.embeddings.device 66 67 def items(self) -> Iterable[Tuple[str, torch.Tensor]]: 68 for field in self.__dataclass_fields__: 69 yield field, getattr(self, field)
70 71
[docs] 72@dataclass 73class BiEncoderOutput(LightningIROutput): 74 query_embeddings: BiEncoderEmbedding | None = None 75 doc_embeddings: BiEncoderEmbedding | None = None
76 77
[docs] 78class BiEncoderModel(LightningIRModel): 79 80 _tied_weights_keys = ["projection.decoder.bias", "projection.decoder.weight", "encoder.embed_tokens.weight"] 81 _keys_to_ignore_on_load_unexpected = [r"decoder"] 82 83 config_class = BiEncoderConfig 84
[docs] 85 def __init__(self, config: BiEncoderConfig, *args, **kwargs) -> None: 86 super().__init__(config, *args, **kwargs) 87 self.config: BiEncoderConfig 88 self.scoring_function = ScoringFunction(self.config) 89 self.projection: torch.nn.Linear | MLMHead | None = None 90 if self.config.projection is not None: 91 if "linear" in self.config.projection: 92 self.projection = torch.nn.Linear( 93 self.config.hidden_size, 94 self.config.embedding_dim, 95 bias="no_bias" not in self.config.projection, 96 ) 97 elif self.config.projection == "mlm": 98 self.projection = MLMHead(config) 99 else: 100 raise ValueError(f"Unknown projection {self.config.projection}") 101 else: 102 if self.config.embedding_dim != self.config.hidden_size: 103 warnings.warn( 104 "No projection is used but embedding_dim != hidden_size. " 105 "The output embeddings will not have embedding_size dimensions." 106 ) 107 108 self.query_mask_scoring_input_ids: torch.Tensor | None = None 109 self.doc_mask_scoring_input_ids: torch.Tensor | None = None 110 self.add_mask_scoring_input_ids()
111 112 @classmethod 113 def _load_pretrained_model( 114 cls, model, state_dict, loaded_keys, resolved_archive_file, pretrained_model_name_or_path, *args, **kwargs 115 ): 116 if model.config.projection == "mlm": 117 has_base_model_prefix = any(s.startswith(model.base_model_prefix) for s in state_dict.keys()) 118 prefix = model.base_model_prefix + "." if has_base_model_prefix else "" 119 for key in list(state_dict.keys()): 120 if key.startswith("cls"): 121 new_key = prefix + key.replace("cls.predictions", "projection").replace(".transform", "") 122 state_dict[new_key] = state_dict.pop(key) 123 loaded_keys[loaded_keys.index(key)] = new_key 124 return super()._load_pretrained_model( 125 model, state_dict, loaded_keys, resolved_archive_file, pretrained_model_name_or_path, *args, **kwargs 126 ) 127 128 def get_output_embeddings(self) -> torch.nn.Module | None: 129 if isinstance(self.projection, MLMHead): 130 return self.projection.decoder 131 return None 132 133 def add_mask_scoring_input_ids(self) -> None: 134 for sequence in ("query", "doc"): 135 mask_scoring_tokens = getattr(self.config, f"{sequence}_mask_scoring_tokens") 136 if mask_scoring_tokens is None: 137 continue 138 if mask_scoring_tokens == "punctuation": 139 mask_scoring_tokens = list(punctuation) 140 try: 141 from transformers import AutoTokenizer 142 143 tokenizer = AutoTokenizer.from_pretrained(self.config.name_or_path) 144 except OSError: 145 raise ValueError("Can't use token scoring masking if the checkpoint does not have a tokenizer.") 146 mask_scoring_input_ids = [] 147 for token in mask_scoring_tokens: 148 if token not in tokenizer.vocab: 149 raise ValueError(f"Token {token} not in tokenizer vocab") 150 mask_scoring_input_ids.append(tokenizer.vocab[token]) 151 setattr( 152 self, 153 f"{sequence}_mask_scoring_input_ids", 154 torch.tensor(mask_scoring_input_ids, dtype=torch.long), 155 ) 156 157 def forward( 158 self, 159 query_encoding: BatchEncoding | None, 160 doc_encoding: BatchEncoding | None, 161 num_docs: Sequence[int] | int | None = None, 162 ) -> BiEncoderOutput: 163 query_embeddings = None 164 if query_encoding is not None: 165 query_embeddings = self.encode_query(query_encoding) 166 doc_embeddings = None 167 if doc_encoding is not None: 168 doc_embeddings = self.encode_doc(doc_encoding) 169 scores = None 170 if doc_embeddings is not None and query_embeddings is not None: 171 scores = self.score(query_embeddings, doc_embeddings, num_docs) 172 return BiEncoderOutput(scores=scores, query_embeddings=query_embeddings, doc_embeddings=doc_embeddings) 173 174 def encode_query(self, encoding: BatchEncoding) -> BiEncoderEmbedding: 175 return self._encode( 176 encoding, 177 expansion=self.config.query_expansion, 178 pooling_strategy=self.config.query_pooling_strategy, 179 mask_scoring_input_ids=self.query_mask_scoring_input_ids, 180 ) 181 182 def encode_doc(self, encoding: BatchEncoding) -> BiEncoderEmbedding: 183 return self._encode( 184 encoding, 185 expansion=self.config.doc_expansion, 186 pooling_strategy=self.config.doc_pooling_strategy, 187 mask_scoring_input_ids=self.doc_mask_scoring_input_ids, 188 ) 189 190 @_batch_encoding 191 def _encode( 192 self, 193 encoding: BatchEncoding, 194 expansion: bool = False, 195 pooling_strategy: Literal["first", "mean", "max", "sum"] | None = None, 196 mask_scoring_input_ids: torch.Tensor | None = None, 197 ) -> BiEncoderEmbedding: 198 embeddings = self._backbone_forward(**encoding).last_hidden_state 199 if self.projection is not None: 200 embeddings = self.projection(embeddings) 201 embeddings = self._sparsification(embeddings, self.config.sparsification) 202 embeddings = self._pooling(embeddings, encoding["attention_mask"], pooling_strategy) 203 if self.config.normalize: 204 embeddings = torch.nn.functional.normalize(embeddings, dim=-1) 205 scoring_mask = self._scoring_mask( 206 encoding["input_ids"], 207 encoding["attention_mask"], 208 expansion, 209 pooling_strategy, 210 mask_scoring_input_ids, 211 ) 212 return BiEncoderEmbedding(embeddings, scoring_mask) 213 214 def query_scoring_mask(self, input_ids: torch.Tensor | None, attention_mask: torch.Tensor | None) -> torch.Tensor: 215 return self._scoring_mask( 216 input_ids, 217 attention_mask, 218 expansion=self.config.query_expansion, 219 pooling_strategy=self.config.query_pooling_strategy, 220 mask_scoring_input_ids=self.config.query_mask_scoring_input_ids, 221 ) 222 223 def doc_scoring_mask(self, input_ids: torch.Tensor | None, attention_mask: torch.Tensor | None) -> torch.Tensor: 224 return self._scoring_mask( 225 input_ids, 226 attention_mask, 227 expansion=self.config.query_expansion, 228 pooling_strategy=self.config.doc_pooling_strategy, 229 mask_scoring_input_ids=self.config.doc_mask_scoring_input_ids, 230 ) 231 232 def _scoring_mask( 233 self, 234 input_ids: torch.Tensor | None, 235 attention_mask: torch.Tensor | None, 236 expansion: bool, 237 pooling_strategy: Literal["first", "mean", "max", "sum"] | None = None, 238 mask_scoring_input_ids: torch.Tensor | None = None, 239 ) -> torch.Tensor: 240 if input_ids is not None: 241 shape = input_ids.shape 242 device = input_ids.device 243 elif attention_mask is not None: 244 shape = attention_mask.shape 245 device = attention_mask.device 246 else: 247 raise ValueError("Pass either input_ids or attention_mask") 248 if pooling_strategy is not None: 249 return torch.ones((shape[0], 1), dtype=torch.bool, device=device) 250 scoring_mask = attention_mask 251 if expansion or scoring_mask is None: 252 scoring_mask = torch.ones(shape, dtype=torch.bool, device=device) 253 scoring_mask = scoring_mask.bool() 254 if mask_scoring_input_ids is not None and input_ids is not None: 255 ignore_mask = input_ids[..., None].eq(mask_scoring_input_ids.to(device)).any(-1) 256 scoring_mask = scoring_mask & ~ignore_mask 257 return scoring_mask 258 259 def score( 260 self, 261 query_embeddings: BiEncoderEmbedding, 262 doc_embeddings: BiEncoderEmbedding, 263 num_docs: Sequence[int] | int | None = None, 264 ) -> torch.Tensor: 265 scores = self.scoring_function.score(query_embeddings, doc_embeddings, num_docs=num_docs) 266 return scores
267 268 269def _batch_scoring( 270 similarity_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] 271) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor]: 272 BATCH_SIZE = 1024 273 274 @wraps(similarity_function) 275 def batch_similarity_function(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 276 if x.shape[0] <= BATCH_SIZE: 277 return similarity_function(x, y) 278 out = torch.zeros(x.shape[0], x.shape[1], y.shape[2], device=x.device, dtype=x.dtype) 279 for i in range(0, x.shape[0], BATCH_SIZE): 280 out[i : i + BATCH_SIZE] = similarity_function(x[i : i + BATCH_SIZE], y[i : i + BATCH_SIZE]) 281 return out 282 283 return batch_similarity_function 284 285
[docs] 286class ScoringFunction(torch.nn.Module):
[docs] 287 def __init__(self, config: BiEncoderConfig) -> None: 288 super().__init__() 289 self.config = config 290 if self.config.similarity_function == "cosine": 291 self.similarity_function = self.cosine_similarity 292 elif self.config.similarity_function == "l2": 293 self.similarity_function = self.l2_similarity 294 elif self.config.similarity_function == "dot": 295 self.similarity_function = self.dot_similarity 296 else: 297 raise ValueError(f"Unknown similarity function {self.config.similarity_function}") 298 self.query_aggregation_function = self.config.query_aggregation_function
299 300 def compute_similarity( 301 self, query_embeddings: BiEncoderEmbedding, doc_embeddings: BiEncoderEmbedding 302 ) -> torch.Tensor: 303 # if torch.cuda.is_available(): 304 # # bfloat16 similarity yields weird values with gpu, so we use fp16 instead 305 # # TODO investigate why, all values are a factor of 1/4 306 # query_tensor = query_tensor.cuda().half() 307 # doc_tensor = doc_tensor.cuda().half() 308 309 # TODO compute similarity only for non-masked values 310 similarity = self.similarity_function(query_embeddings.embeddings, doc_embeddings.embeddings) 311 return similarity 312 313 @staticmethod 314 @_batch_scoring 315 def cosine_similarity(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 316 return torch.nn.functional.cosine_similarity(x, y, dim=-1) 317 318 @staticmethod 319 @_batch_scoring 320 def l2_similarity(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 321 return -1 * torch.cdist(x, y).squeeze(-2) 322 323 @staticmethod 324 @_batch_scoring 325 def dot_similarity(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 326 return torch.matmul(x, y.transpose(-1, -2)).squeeze(-2) 327 328 def parse_num_docs( 329 self, 330 query_embeddings: BiEncoderEmbedding, 331 doc_embeddings: BiEncoderEmbedding, 332 num_docs: int | Sequence[int] | None, 333 ) -> torch.Tensor: 334 batch_size = query_embeddings.embeddings.shape[0] 335 if isinstance(num_docs, int): 336 num_docs = [num_docs] * batch_size 337 if isinstance(num_docs, list): 338 if sum(num_docs) != doc_embeddings.embeddings.shape[0] or len(num_docs) != batch_size: 339 raise ValueError("Num docs does not match doc embeddings") 340 if num_docs is None: 341 if doc_embeddings.embeddings.shape[0] % batch_size != 0: 342 raise ValueError("Docs are not evenly distributed in _batch, but no num_docs provided") 343 num_docs = [doc_embeddings.embeddings.shape[0] // batch_size] * batch_size 344 return torch.tensor(num_docs, device=query_embeddings.embeddings.device) 345 346 def expand_query_embeddings( 347 self, 348 embeddings: BiEncoderEmbedding, 349 num_docs: torch.Tensor, 350 ) -> BiEncoderEmbedding: 351 return BiEncoderEmbedding( 352 embeddings.embeddings.repeat_interleave(num_docs, dim=0).unsqueeze(2), 353 embeddings.scoring_mask.repeat_interleave(num_docs, dim=0).unsqueeze(2), 354 ) 355 356 def expand_doc_embeddings( 357 self, 358 embeddings: BiEncoderEmbedding, 359 num_docs: torch.Tensor, 360 ) -> BiEncoderEmbedding: 361 return BiEncoderEmbedding(embeddings.embeddings.unsqueeze(1), embeddings.scoring_mask.unsqueeze(1)) 362 363 def aggregate( 364 self, 365 scores: torch.Tensor, 366 mask: torch.Tensor | None, 367 query_aggregation_function: Literal["max", "sum", "mean", "harmonic_mean"] | None, 368 dim: int, 369 ) -> torch.Tensor: 370 if query_aggregation_function is None: 371 return scores 372 if query_aggregation_function == "max": 373 if mask is not None: 374 scores = scores.masked_fill(~mask, float("-inf")) 375 return scores.max(dim, keepdim=True).values 376 if query_aggregation_function == "sum": 377 if mask is not None: 378 scores = scores.masked_fill(~mask, 0) 379 return scores.sum(dim, keepdim=True) 380 if mask is None: 381 shape = list(scores.shape) 382 shape[dim] = 1 383 num_non_masked = torch.full(shape, scores.shape[dim], device=scores.device) 384 else: 385 num_non_masked = mask.sum(dim, keepdim=True) 386 if query_aggregation_function == "mean": 387 return torch.where(num_non_masked == 0, 0, scores.sum(dim, keepdim=True) / num_non_masked) 388 if query_aggregation_function == "harmonic_mean": 389 return torch.where( 390 num_non_masked == 0, 391 0, 392 num_non_masked / (1 / scores).sum(dim, keepdim=True), 393 ) 394 raise ValueError(f"Unknown aggregation {query_aggregation_function}") 395 396 def score( 397 self, 398 query_embeddings: BiEncoderEmbedding, 399 doc_embeddings: BiEncoderEmbedding, 400 num_docs: Sequence[int] | int | None = None, 401 ) -> torch.Tensor: 402 num_docs_t = self.parse_num_docs(query_embeddings, doc_embeddings, num_docs) 403 query_embeddings = self.expand_query_embeddings(query_embeddings, num_docs_t) 404 doc_embeddings = self.expand_doc_embeddings(doc_embeddings, num_docs_t) 405 similarity = self.compute_similarity(query_embeddings, doc_embeddings) 406 scores = self.aggregate(similarity, doc_embeddings.scoring_mask, "max", -1) 407 scores = self.aggregate(scores, query_embeddings.scoring_mask, self.query_aggregation_function, -2) 408 return scores[..., 0, 0]