from itertools import accumulate
from pathlib import Path
from typing import List, Sequence, Tuple

import torch

from ..base import LightningIRModule
from ..data import IndexBatch, RankBatch, SearchBatch, TrainBatch
from ..loss.loss import EmbeddingLossFunction, InBatchLossFunction, LossFunction, ScoringLossFunction
from ..retrieve import SearchConfig, Searcher
from .config import BiEncoderConfig
from .model import BiEncoderEmbedding, BiEncoderModel, BiEncoderOutput
from .tokenizer import BiEncoderTokenizer
13
14
class BiEncoderModule(LightningIRModule):
    """LightningIR module wrapping a :class:`BiEncoderModel`.

    Handles training (loss computation over scores and/or embeddings),
    validation, and retrieval: when an ``index_dir`` and ``search_config``
    are provided, a :class:`Searcher` is created at test time and used to
    resolve :class:`SearchBatch` queries against the index.
    """

    def __init__(
        self,
        model_name_or_path: str | None = None,
        config: BiEncoderConfig | None = None,
        model: BiEncoderModel | None = None,
        loss_functions: Sequence[LossFunction | Tuple[LossFunction, float]] | None = None,
        evaluation_metrics: Sequence[str] | None = None,
        index_dir: Path | None = None,
        search_config: SearchConfig | None = None,
    ):
        """Initialize the bi-encoder module.

        Args:
            model_name_or_path: Name or path of a pretrained model to load.
            config: Bi-encoder configuration (used when building a fresh model).
            model: Already-instantiated model; overrides ``model_name_or_path``.
            loss_functions: Loss functions, optionally paired with a weight.
            evaluation_metrics: Metric names to compute during validation/testing.
            index_dir: Directory of a pre-built index for test-time retrieval.
            search_config: Configuration selecting and parameterizing the searcher.
        """
        super().__init__(model_name_or_path, config, model, loss_functions, evaluation_metrics)
        # Narrow the types set up by the base class for static checkers.
        self.model: BiEncoderModel
        self.config: BiEncoderConfig
        self.tokenizer: BiEncoderTokenizer
        self.scoring_function = self.model.scoring_function
        # The tokenizer may have added marker tokens beyond the original vocab;
        # grow the embedding matrix to match. The second argument pads the new
        # vocab size to a multiple of 8 -- presumably for hardware efficiency
        # (tensor cores); confirm against resize_token_embeddings docs.
        if self.config.add_marker_tokens and len(self.tokenizer) > self.config.vocab_size:
            self.model.resize_token_embeddings(len(self.tokenizer), 8)
        self._searcher = None
        self.search_config = search_config
        self.index_dir = index_dir

    @property
    def searcher(self) -> Searcher | None:
        """Searcher used to resolve ``SearchBatch`` queries; ``None`` until set."""
        return self._searcher

    @searcher.setter
    def searcher(self, searcher: Searcher):
        self._searcher = searcher

    def on_test_start(self) -> None:
        """Instantiate the searcher lazily when both index and config are set."""
        if self.search_config is not None and self.index_dir is not None:
            self.searcher = self.search_config.search_class(self.index_dir, self.search_config, self)
        return super().on_test_start()

    def forward(self, batch: RankBatch | IndexBatch | SearchBatch) -> BiEncoderOutput:
        """Encode queries and/or documents and score them.

        For :class:`RankBatch`, nested per-query document lists are flattened
        before encoding. For :class:`SearchBatch` (with a searcher available),
        retrieved scores and doc ids are written back onto the output/batch.

        Raises:
            ValueError: If neither queries nor documents produced encodings.
        """
        queries = getattr(batch, "queries", None)
        docs = getattr(batch, "docs", None)
        num_docs = None
        if isinstance(batch, RankBatch):
            # Remember how many docs belong to each query, then flatten.
            num_docs = None if docs is None else [len(d) for d in docs]
            docs = [d for nested in docs for d in nested] if docs is not None else None
        encodings = self.prepare_input(queries, docs, num_docs)

        if not encodings:
            raise ValueError("No encodings were generated.")
        output = self.model.forward(
            encodings.get("query_encoding"), encodings.get("doc_encoding"), num_docs
        )
        if isinstance(batch, SearchBatch) and self.searcher is not None:
            scores, doc_ids, num_docs = self.searcher.search(output)
            output.scores = scores
            # Prefix sums give the slice boundaries that partition the flat
            # doc_ids sequence back into one tuple of doc ids per query.
            cum_num_docs = list(accumulate(num_docs, initial=0))
            doc_ids = tuple(
                tuple(doc_ids[cum_num_docs[i] : cum_num_docs[i + 1]]) for i in range(len(num_docs))
            )
            batch.doc_ids = doc_ids
        return output

    def score(self, queries: Sequence[str] | str, docs: Sequence[Sequence[str]] | Sequence[str]) -> BiEncoderOutput:
        """Score queries against documents (delegates to the base implementation)."""
        return super().score(queries, docs)

    def compute_losses(self, batch: TrainBatch) -> List[torch.Tensor]:
        """Compute one loss tensor per configured loss function.

        Dispatches on the loss type: in-batch losses re-score against
        in-batch negatives, embedding losses operate on raw embeddings, and
        scoring losses compare scores with targets. When sparsification is
        enabled, also logs the average number of non-zero embedding entries.

        Raises:
            ValueError: If no loss functions are configured, required outputs
                are missing, or a loss function type is unknown.
        """
        if self.loss_functions is None:
            raise ValueError("Loss function is not set")
        output = self.forward(batch)

        scores = output.scores
        query_embeddings = output.query_embeddings
        doc_embeddings = output.doc_embeddings
        if batch.targets is None or query_embeddings is None or doc_embeddings is None or scores is None:
            raise ValueError(
                "targets, scores, query_embeddings, and doc_embeddings must be set in the output and batch"
            )

        num_queries = len(batch.queries)
        scores = scores.view(num_queries, -1)
        targets = batch.targets.view(*scores.shape, -1)
        losses = []
        for loss_function, _ in self.loss_functions:
            if isinstance(loss_function, InBatchLossFunction):
                pos_idcs, neg_idcs = loss_function.get_ib_idcs(*scores.shape)
                ib_doc_embeddings = self.get_ib_doc_embeddings(doc_embeddings, pos_idcs, neg_idcs, num_queries)
                ib_scores = self.model.score(query_embeddings, ib_doc_embeddings)
                ib_scores = ib_scores.view(num_queries, -1)
                losses.append(loss_function.compute_loss(ib_scores))
            elif isinstance(loss_function, EmbeddingLossFunction):
                losses.append(loss_function.compute_loss(query_embeddings.embeddings, doc_embeddings.embeddings))
            elif isinstance(loss_function, ScoringLossFunction):
                losses.append(loss_function.compute_loss(scores, targets))
            else:
                raise ValueError(f"Unknown loss function type {loss_function.__class__.__name__}")
        if self.config.sparsification is not None:
            # Average count of non-zero entries per embedding row -- a proxy
            # for how sparse the learned representations are.
            query_num_nonzero = (
                torch.nonzero(query_embeddings.embeddings).shape[0] / query_embeddings.embeddings.shape[0]
            )
            doc_num_nonzero = torch.nonzero(doc_embeddings.embeddings).shape[0] / doc_embeddings.embeddings.shape[0]
            self.log("query_num_nonzero", query_num_nonzero)
            self.log("doc_num_nonzero", doc_num_nonzero)
        return losses

    def get_ib_doc_embeddings(
        self,
        embeddings: BiEncoderEmbedding,
        pos_idcs: torch.Tensor,
        neg_idcs: torch.Tensor,
        num_queries: int,
    ) -> BiEncoderEmbedding:
        """Gather positive and negative in-batch document embeddings.

        Selects rows by ``pos_idcs``/``neg_idcs``, groups them per query
        (positives first, then negatives), and flattens back to a
        ``(num_queries * docs_per_query, seq_len, emb_dim)`` layout. The
        scoring mask is reindexed the same way to stay aligned.
        """
        _, seq_len, emb_dim = embeddings.embeddings.shape
        ib_embeddings = torch.cat(
            [
                embeddings.embeddings[pos_idcs].view(num_queries, -1, seq_len, emb_dim),
                embeddings.embeddings[neg_idcs].view(num_queries, -1, seq_len, emb_dim),
            ],
            dim=1,
        ).view(-1, seq_len, emb_dim)
        ib_scoring_mask = torch.cat(
            [
                embeddings.scoring_mask[pos_idcs].view(num_queries, -1, seq_len),
                embeddings.scoring_mask[neg_idcs].view(num_queries, -1, seq_len),
            ],
            dim=1,
        ).view(-1, seq_len)
        return BiEncoderEmbedding(ib_embeddings, ib_scoring_mask)

    def validation_step(
        self,
        batch: TrainBatch | IndexBatch | SearchBatch | RankBatch,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> BiEncoderOutput:
        """Run a validation step; index batches only need a forward pass.

        Raises:
            ValueError: If the batch is of an unrecognized type.
        """
        if isinstance(batch, IndexBatch):
            return self.forward(batch)
        if isinstance(batch, (RankBatch, TrainBatch, SearchBatch)):
            return super().validation_step(batch, batch_idx, dataloader_idx)
        raise ValueError(f"Unknown batch type {type(batch)}")