Source code for lightning_ir.bi_encoder.module

from pathlib import Path
from typing import List, Sequence, Tuple

import torch

from ..base import LightningIRModule
from ..data import IndexBatch, RankBatch, SearchBatch, TrainBatch
from ..loss.loss import EmbeddingLossFunction, InBatchLossFunction, LossFunction, ScoringLossFunction
from ..retrieve import SearchConfig, Searcher
from .config import BiEncoderConfig
from .model import BiEncoderEmbedding, BiEncoderModel, BiEncoderOutput
from .tokenizer import BiEncoderTokenizer


class BiEncoderModule(LightningIRModule):
    def __init__(
        self,
        model_name_or_path: str | None = None,
        config: BiEncoderConfig | None = None,
        model: BiEncoderModel | None = None,
        loss_functions: Sequence[LossFunction | Tuple[LossFunction, float]] | None = None,
        evaluation_metrics: Sequence[str] | None = None,
        index_dir: Path | None = None,
        search_config: SearchConfig | None = None,
    ):
        super().__init__(model_name_or_path, config, model, loss_functions, evaluation_metrics)
        self.model: BiEncoderModel
        self.config: BiEncoderConfig
        self.tokenizer: BiEncoderTokenizer
        self.scoring_function = self.model.scoring_function
        if self.config.add_marker_tokens and len(self.tokenizer) > self.config.vocab_size:
            # marker tokens were added to the tokenizer, so the embedding matrix
            # must grow to cover them (padded to a multiple of 8)
            self.model.resize_token_embeddings(len(self.tokenizer), 8)
        self._searcher = None
        self.search_config = search_config
        self.index_dir = index_dir

    @property
    def searcher(self) -> Searcher | None:
        return self._searcher

    @searcher.setter
    def searcher(self, searcher: Searcher):
        self._searcher = searcher

    def on_test_start(self) -> None:
        # instantiate the searcher from the configured index before testing
        if self.search_config is not None and self.index_dir is not None:
            self.searcher = self.search_config.search_class(self.index_dir, self.search_config, self)
        return super().on_test_start()

    def forward(self, batch: RankBatch | IndexBatch | SearchBatch) -> BiEncoderOutput:
        queries = getattr(batch, "queries", None)
        docs = getattr(batch, "docs", None)
        num_docs = None
        if isinstance(batch, RankBatch):
            # flatten the per-query document lists and remember how many docs belong to each query
            num_docs = None if docs is None else [len(d) for d in docs]
            docs = [d for nested in docs for d in nested] if docs is not None else None
        encodings = self.prepare_input(queries, docs, num_docs)

        if not encodings:
            raise ValueError("No encodings were generated.")
        output = self.model.forward(
            encodings.get("query_encoding", None), encodings.get("doc_encoding", None), num_docs
        )
        if isinstance(batch, SearchBatch) and self.searcher is not None:
            # retrieve documents for the encoded queries and regroup the flat doc ids per query
            scores, doc_ids, num_docs = self.searcher.search(output)
            output.scores = scores
            cum_num_docs = [0] + [sum(num_docs[: i + 1]) for i in range(len(num_docs))]
            doc_ids = tuple(tuple(doc_ids[cum_num_docs[i] : cum_num_docs[i + 1]]) for i in range(len(num_docs)))
            batch.doc_ids = doc_ids
        return output

    def score(self, queries: Sequence[str] | str, docs: Sequence[Sequence[str]] | Sequence[str]) -> BiEncoderOutput:
        return super().score(queries, docs)

    def compute_losses(self, batch: TrainBatch) -> List[torch.Tensor]:
        if self.loss_functions is None:
            raise ValueError("Loss function is not set")
        output = self.forward(batch)

        scores = output.scores
        query_embeddings = output.query_embeddings
        doc_embeddings = output.doc_embeddings
        if batch.targets is None or query_embeddings is None or doc_embeddings is None or scores is None:
            raise ValueError(
                "targets, scores, query_embeddings, and doc_embeddings must be set in the output and batch"
            )

        num_queries = len(batch.queries)
        scores = scores.view(num_queries, -1)
        targets = batch.targets.view(*scores.shape, -1)
        losses = []
        for loss_function, _ in self.loss_functions:
            if isinstance(loss_function, InBatchLossFunction):
                # sample in-batch positive and negative documents and re-score them against the queries
                pos_idcs, neg_idcs = loss_function.get_ib_idcs(*scores.shape)
                ib_doc_embeddings = self.get_ib_doc_embeddings(doc_embeddings, pos_idcs, neg_idcs, num_queries)
                ib_scores = self.model.score(query_embeddings, ib_doc_embeddings)
                ib_scores = ib_scores.view(num_queries, -1)
                losses.append(loss_function.compute_loss(ib_scores))
            elif isinstance(loss_function, EmbeddingLossFunction):
                losses.append(loss_function.compute_loss(query_embeddings.embeddings, doc_embeddings.embeddings))
            elif isinstance(loss_function, ScoringLossFunction):
                losses.append(loss_function.compute_loss(scores, targets))
            else:
                raise ValueError(f"Unknown loss function type {loss_function.__class__.__name__}")
        if self.config.sparsification is not None:
            # log the average number of non-zero entries per query and document embedding
            query_num_nonzero = (
                torch.nonzero(query_embeddings.embeddings).shape[0] / query_embeddings.embeddings.shape[0]
            )
            doc_num_nonzero = torch.nonzero(doc_embeddings.embeddings).shape[0] / doc_embeddings.embeddings.shape[0]
            self.log("query_num_nonzero", query_num_nonzero)
            self.log("doc_num_nonzero", doc_num_nonzero)
        return losses

    def get_ib_doc_embeddings(
        self,
        embeddings: BiEncoderEmbedding,
        pos_idcs: torch.Tensor,
        neg_idcs: torch.Tensor,
        num_queries: int,
    ) -> BiEncoderEmbedding:
        # gather the sampled positive and negative documents for each query into a single embedding batch
        _, seq_len, emb_dim = embeddings.embeddings.shape
        ib_embeddings = torch.cat(
            [
                embeddings.embeddings[pos_idcs].view(num_queries, -1, seq_len, emb_dim),
                embeddings.embeddings[neg_idcs].view(num_queries, -1, seq_len, emb_dim),
            ],
            dim=1,
        ).view(-1, seq_len, emb_dim)
        ib_scoring_mask = torch.cat(
            [
                embeddings.scoring_mask[pos_idcs].view(num_queries, -1, seq_len),
                embeddings.scoring_mask[neg_idcs].view(num_queries, -1, seq_len),
            ],
            dim=1,
        ).view(-1, seq_len)
        return BiEncoderEmbedding(ib_embeddings, ib_scoring_mask)

    def validation_step(
        self,
        batch: TrainBatch | IndexBatch | SearchBatch | RankBatch,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> BiEncoderOutput:
        if isinstance(batch, IndexBatch):
            return self.forward(batch)
        if isinstance(batch, (RankBatch, TrainBatch, SearchBatch)):
            return super().validation_step(batch, batch_idx, dataloader_idx)
        raise ValueError(f"Unknown batch type {type(batch)}")
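
For context, a minimal usage sketch of the module's scoring path (not part of the module source above). It assumes the lightning-ir package is installed; the checkpoint identifier is a placeholder, not a published model.

# Minimal sketch, assuming a bi-encoder checkpoint is available; replace
# "my-bi-encoder-checkpoint" with a real local path or hub id.
from lightning_ir.bi_encoder.module import BiEncoderModule

module = BiEncoderModule(model_name_or_path="my-bi-encoder-checkpoint")
module.eval()

# score() accepts a single query string (or a sequence of queries) and the
# documents to rank; it returns a BiEncoderOutput whose .scores tensor holds
# one relevance score per query-document pair.
output = module.score(
    queries="what is the capital of france",
    docs=["Paris is the capital of France.", "Berlin is the capital of Germany."],
)
print(output.scores)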