1from __future__ import annotations
2
3from pathlib import Path
4from typing import TYPE_CHECKING, List, Literal, Tuple
5
6import torch
7
8from ..bi_encoder.model import BiEncoderEmbedding
9from .searcher import SearchConfig, Searcher
10
11if TYPE_CHECKING:
12 from ..bi_encoder import BiEncoderModule
13
14
[docs]
class FaissSearcher(Searcher):
    """Searcher that retrieves candidate vectors from a FAISS index and turns them into document scores.

    Query vectors are matched against the index; the retrieved vector ids are mapped to document ids and
    scored according to ``search_config.imputation_strategy`` (``"gather"``, ``"min"`` or ``None``).
    """

    def __init__(
        self, index_dir: Path | str, search_config: FaissSearchConfig, module: BiEncoderModule, use_gpu: bool = False
    ) -> None:
        """Load ``index.faiss`` from ``index_dir`` and apply the search-time index settings.

        Args:
            index_dir: Directory containing the ``index.faiss`` file.
            search_config: Search configuration; ``n_probe`` is written to the IVF index (if one can be
                extracted) and ``ef_search`` to the HNSW structure of its quantizer (if it has one).
            module: Bi-encoder module, forwarded to the parent ``Searcher``.
            use_gpu: Forwarded to the parent ``Searcher``.
        """
        import faiss

        self.search_config: FaissSearchConfig
        self.index = faiss.read_index(str(Path(index_dir) / "index.faiss"))
        ivf_index = None
        try:
            ivf_index = faiss.extract_index_ivf(self.index)
        except RuntimeError:
            # Deliberate best effort: presumably the index has no IVF component, so nothing to configure.
            pass
        if ivf_index is not None:
            ivf_index.nprobe = search_config.n_probe
            quantizer = getattr(ivf_index, "quantizer", None)
            if quantizer is not None:
                # Downcast so that a concrete quantizer type can expose its ``hnsw`` attribute, if any.
                downcasted_quantizer = faiss.downcast_index(quantizer)
                hnsw = getattr(downcasted_quantizer, "hnsw", None)
                if hnsw is not None:
                    hnsw.efSearch = search_config.ef_search
        # The parent initializer runs only after the index has been loaded and configured.
        super().__init__(index_dir, search_config, module, use_gpu)

    @property
    def num_embeddings(self) -> int:
        """Number of vectors stored in the FAISS index (``index.ntotal``)."""
        return self.index.ntotal

    @property
    def doc_is_single_vector(self) -> bool:
        """Whether the index holds exactly one vector per document."""
        return self.num_docs == self.num_embeddings

    def _search(self, query_embeddings: BiEncoderEmbedding) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
        """Retrieve candidates for the queries and score the candidate documents.

        Args:
            query_embeddings: Query embeddings; ``scoring_mask`` selects the vectors used for retrieval.

        Returns:
            A tuple ``(doc_scores, doc_idcs, num_docs)`` where ``num_docs`` lists the number of scored
            documents per query.
        """
        query_embeddings = query_embeddings.to(self.device)
        candidate_scores, candidate_doc_idcs = self.candidate_retrieval(query_embeddings)
        query_lengths = query_embeddings.scoring_mask.sum(-1)
        if self.search_config.imputation_strategy == "gather":
            doc_embeddings, doc_idcs, num_docs = self.gather_imputation(candidate_doc_idcs, query_lengths)
            doc_scores = self.module.model.score(query_embeddings, doc_embeddings, num_docs)
        else:
            doc_scores, doc_idcs, num_docs = self.intra_ranking_imputation(
                candidate_scores, candidate_doc_idcs, query_lengths
            )
        return doc_scores, doc_idcs, num_docs

    def candidate_retrieval(self, query_embeddings: BiEncoderEmbedding) -> Tuple[torch.Tensor, torch.Tensor]:
        """Search the index with every unmasked query vector.

        Args:
            query_embeddings: Query embeddings; only vectors selected by ``scoring_mask`` are searched.

        Returns:
            A tuple ``(candidate_scores, candidate_doc_idcs)`` with one row per searched query vector and
            ``search_config.candidate_k`` columns. Scores come straight from the index; ids are document ids.
        """
        embeddings = query_embeddings.embeddings[query_embeddings.scoring_mask]
        candidate_scores, candidate_idcs = self.index.search(embeddings.float().cpu(), self.search_config.candidate_k)
        candidate_scores = torch.from_numpy(candidate_scores)
        candidate_idcs = torch.from_numpy(candidate_idcs)
        if self.doc_is_single_vector:
            # One vector per document: the vector id already is the document id.
            candidate_doc_idcs = candidate_idcs.to(self.cumulative_doc_lengths.device)
        else:
            # Map each vector id to the document whose cumulative-length range contains it.
            candidate_doc_idcs = torch.searchsorted(
                self.cumulative_doc_lengths,
                candidate_idcs.to(self.cumulative_doc_lengths.device),
                side="right",
            )
        return candidate_scores, candidate_doc_idcs

    def gather_imputation(
        self, candidate_doc_idcs: torch.Tensor, query_lengths: torch.Tensor
    ) -> Tuple[BiEncoderEmbedding, torch.Tensor, List[int]]:
        """Reconstruct all vectors of every candidate document so the documents can be fully re-scored.

        Args:
            candidate_doc_idcs: Candidate document ids, one row per query vector.
            query_lengths: Number of query vectors per query; used to split the rows by query.

        Returns:
            A tuple ``(doc_embeddings, doc_idcs, num_docs)``: padded document embeddings with a matching
            scoring mask, the flat document ids (sorted and de-duplicated within each query), and the
            number of documents per query.
        """
        # unique doc_idcs per query
        doc_idcs_per_query = [
            sorted(set(idcs.reshape(-1).tolist()))
            for idcs in torch.split(candidate_doc_idcs, query_lengths.tolist())
        ]
        num_docs = [len(idcs) for idcs in doc_idcs_per_query]
        # Flatten in a single pass; ``sum(lists, [])`` re-copies the accumulator for every query.
        flat_doc_idcs = [doc_idx for idcs in doc_idcs_per_query for doc_idx in idcs]
        doc_idcs = torch.tensor(flat_doc_idcs).to(candidate_doc_idcs)
        # A document may be a candidate for several queries; reconstruct its vectors only once.
        unique_doc_idcs, inverse_idcs = torch.unique(doc_idcs, return_inverse=True)

        # gather all vectors for unique doc_idcs
        doc_lengths = self.doc_lengths[unique_doc_idcs]
        # Index ``-1`` (for document 0) wraps around to the last entry and is reset to 0 right after.
        start_doc_idcs = self.cumulative_doc_lengths[unique_doc_idcs - 1]
        start_doc_idcs[unique_doc_idcs == 0] = 0
        all_doc_idcs = torch.cat(
            [
                torch.arange(start.item(), start.item() + length.item())
                for start, length in zip(start_doc_idcs.cpu(), doc_lengths.cpu())
            ]
        )
        all_doc_embeddings = torch.from_numpy(self.index.reconstruct_batch(all_doc_idcs))
        unique_embeddings = torch.nn.utils.rnn.pad_sequence(
            list(torch.split(all_doc_embeddings, doc_lengths.tolist())),
            batch_first=True,
        ).to(inverse_idcs.device)
        # Expand back to one entry per (query, document) pair.
        embeddings = unique_embeddings[inverse_idcs]

        # mask out padding
        doc_lengths = doc_lengths[inverse_idcs]
        scoring_mask = torch.arange(embeddings.shape[1], device=embeddings.device) < doc_lengths[:, None]
        doc_embeddings = BiEncoderEmbedding(embeddings=embeddings, scoring_mask=scoring_mask)
        return doc_embeddings, doc_idcs, num_docs

    def intra_ranking_imputation(
        self,
        candidate_scores: torch.Tensor,
        candidate_doc_idcs: torch.Tensor,
        query_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
        """Score documents from the retrieved candidate scores alone, imputing missing entries.

        For every (query, document) pair the maximum candidate score per query vector is kept. Query
        vectors that did not retrieve the document get an imputed value, after which the per-vector
        scores are aggregated with the module's query aggregation function.

        Args:
            candidate_scores: Scores returned by the index, one row per query vector.
            candidate_doc_idcs: Document ids matching ``candidate_scores``.
            query_lengths: Number of query vectors per query.

        Returns:
            A tuple ``(scores, doc_idcs, num_docs)`` where ``num_docs`` lists the number of scored
            documents per query.

        Raises:
            ValueError: If queries have more than one vector and the imputation strategy is neither
                ``"min"`` nor ``None``.
        """
        max_query_length = int(query_lengths.max().item())
        is_query_single_vector = max_query_length == 1

        if self.doc_is_single_vector:
            # NOTE(review): this path yields 1-D scores and one ``num_docs`` entry per query vector, while the
            # aggregation below reduces over dim=1; it looks like it assumes single-vector queries -- confirm.
            scores = candidate_scores.view(-1)
            doc_idcs = candidate_doc_idcs.view(-1)
            num_docs = torch.full((candidate_scores.shape[0],), candidate_scores.shape[1])
        else:
            # grab unique doc ids per query candidate
            query_idcs = torch.arange(query_lengths.shape[0], device=query_lengths.device).repeat_interleave(
                query_lengths
            )
            query_candidate_idcs = torch.cat(
                [torch.arange(length.item(), device=query_lengths.device) for length in query_lengths]
            )
            # Columns: query index, position of the vector within its query, document id.
            paired_idcs = torch.stack(
                [
                    query_idcs.repeat_interleave(candidate_scores.shape[1]),
                    query_candidate_idcs.repeat_interleave(candidate_scores.shape[1]),
                    candidate_doc_idcs.view(-1),
                ]
            ).T
            unique_paired_idcs, inverse_idcs = torch.unique(paired_idcs[:, [0, 2]], return_inverse=True, dim=0)
            doc_idcs = unique_paired_idcs[:, 1]
            # ``minlength`` keeps one count per query even if trailing queries contributed no candidates;
            # the mask expansion further down needs exactly one entry per query.
            num_docs = unique_paired_idcs[:, 0].bincount(minlength=query_lengths.shape[0])

            # accumulate max score per doc
            ranking_doc_idcs = torch.arange(doc_idcs.shape[0], device=query_lengths.device)[inverse_idcs]
            idcs = ranking_doc_idcs * max_query_length + paired_idcs[:, 1]
            shape = torch.Size((doc_idcs.shape[0], max_query_length))
            # Slots that never receive a score keep the ``inf`` fill value (``include_self=False``) and are
            # recognised as missing by the ``isinf`` check below.
            scores = torch.scatter_reduce(
                torch.full((shape.numel(),), float("inf"), device=query_lengths.device),
                0,
                idcs,
                candidate_scores.view(-1).to(query_lengths.device),
                "max",
                include_self=False,
            ).view(shape)

        if is_query_single_vector:
            scores = scores.squeeze(-1)
        else:
            # impute missing values
            if self.search_config.imputation_strategy == "min":
                # NOTE(review): the minimum is taken over dim 0, i.e. over every row in the batch rather
                # than per query -- confirm this is intended.
                impute_values = (
                    scores.masked_fill(scores == torch.finfo(scores.dtype).min, float("inf"))
                    .min(0, keepdim=True)
                    .values.expand_as(scores)
                )
            elif self.search_config.imputation_strategy is None:
                impute_values = torch.zeros_like(scores)
            else:
                raise ValueError(f"Invalid imputation strategy: {self.search_config.imputation_strategy}")
            is_inf = torch.isinf(scores)
            scores[is_inf] = impute_values[is_inf]

            # aggregate score per query vector
            mask = (
                torch.arange(max_query_length, device=query_lengths.device) < query_lengths[:, None]
            ).repeat_interleave(num_docs, dim=0)
            scores = self.module.scoring_function.aggregate(
                scores, mask, self.module.config.query_aggregation_function, dim=1
            ).squeeze(-1)
        return scores, doc_idcs, num_docs.tolist()
179
180
[docs]
class FaissSearchConfig(SearchConfig):
    """Search settings consumed by :class:`FaissSearcher`.

    All values are stored unchanged on the instance; no validation happens here.
    """

    search_class = FaissSearcher

    def __init__(
        self,
        k: int = 10,
        candidate_k: int = 100,
        imputation_strategy: Literal["min", "gather"] | None = None,
        n_probe: int = 1,
        ef_search: int = 16,
    ) -> None:
        """Store the FAISS-specific search settings.

        Args:
            k: Passed on to ``SearchConfig.__init__``.
            candidate_k: Number of nearest vectors requested from the index per query vector.
            imputation_strategy: ``"gather"`` re-scores candidate documents from their reconstructed
                vectors; ``"min"`` and ``None`` score from the retrieved candidate scores and fill missing
                entries with the observed minimum or with zero, respectively.
            n_probe: Value assigned to ``nprobe`` of the IVF index, when the index has one.
            ef_search: Value assigned to ``efSearch`` of the quantizer's HNSW structure, when present.
        """
        super().__init__(k)
        # Candidate retrieval and scoring behaviour.
        self.candidate_k, self.imputation_strategy = candidate_k, imputation_strategy
        # Index-level search parameters applied by the searcher at load time.
        self.n_probe, self.ef_search = n_probe, ef_search