import warnings
from abc import abstractmethod
from pathlib import Path

import torch

from ..bi_encoder import BiEncoderConfig, BiEncoderOutput
from ..data import IndexBatch
from .indexer import IndexConfig, Indexer


class FaissIndexer(Indexer):
    INDEX_FACTORY: str

    def __init__(
        self,
        index_dir: Path,
        index_config: "FaissIndexConfig",
        bi_encoder_config: BiEncoderConfig,
        verbose: bool = False,
    ) -> None:
        super().__init__(index_dir, index_config, bi_encoder_config, verbose)
        import faiss

        similarity_function = bi_encoder_config.similarity_function
        if similarity_function in ("cosine", "dot"):
            self.metric_type = faiss.METRIC_INNER_PRODUCT
        else:
            raise ValueError(f"similarity_function {similarity_function} unknown")

        index_factory = self.INDEX_FACTORY.format(**index_config.to_dict())
        if similarity_function == "cosine":
            index_factory = "L2norm," + index_factory
        self.index = faiss.index_factory(self.bi_encoder_config.embedding_dim, index_factory, self.metric_type)
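        # e.g. (hypothetical values) FaissIVFPQIndexConfig(num_centroids=1024,
        # num_subquantizers=16, n_bits=8) yields the factory string
        # "OPQ16,IVF1024_HNSW32,PQ16x8"; for cosine similarity "L2norm," is
        # prepended so that inner product on normalized vectors equals cosine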

        self.set_verbosity()

        if torch.cuda.is_available():
            self.to_gpu()

    @abstractmethod
    def to_gpu(self) -> None: ...

    @abstractmethod
    def to_cpu(self) -> None: ...

    @abstractmethod
    def set_verbosity(self, verbose: bool | None = None) -> None: ...

    def process_embeddings(self, embeddings: torch.Tensor) -> torch.Tensor:
        return embeddings

    def save(self) -> None:
        super().save()
        import faiss

        if self.num_embeddings != self.index.ntotal:
            raise ValueError("number of embeddings does not match index.ntotal")
        if torch.cuda.is_available():
            self.index = faiss.index_gpu_to_cpu(self.index)

        faiss.write_index(self.index, str(self.index_dir / "index.faiss"))

    def add(self, index_batch: IndexBatch, output: BiEncoderOutput) -> None:
        doc_embeddings = output.doc_embeddings
        if doc_embeddings is None:
            raise ValueError("Expected doc_embeddings in BiEncoderOutput")
        doc_lengths = doc_embeddings.scoring_mask.sum(dim=1)
        embeddings = doc_embeddings.embeddings[doc_embeddings.scoring_mask]
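        # assuming scoring_mask is a (batch_size, max_doc_len) boolean tensor,
        # the boolean index above flattens the per-token embeddings to shape
        # (total_valid_tokens, embedding_dim) while doc_lengths records how
        # many tokens belong to each document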
        doc_ids = index_batch.doc_ids
        embeddings = self.process_embeddings(embeddings)

        if embeddings.shape[0]:
            self.index.add(embeddings.float().cpu())

        self.num_embeddings += embeddings.shape[0]
        self.num_docs += len(doc_ids)

        self.doc_lengths.extend(doc_lengths.cpu().tolist())
        self.doc_ids.extend(doc_ids)


class FaissFlatIndexer(FaissIndexer):
    INDEX_FACTORY = "Flat"

    def __init__(
        self,
        index_dir: Path,
        index_config: "FaissFlatIndexConfig",
        bi_encoder_config: BiEncoderConfig,
        verbose: bool = False,
    ) -> None:
        super().__init__(index_dir, index_config, bi_encoder_config, verbose)
        self.index_config: FaissFlatIndexConfig

    def to_gpu(self) -> None:
        pass

    def to_cpu(self) -> None:
        pass

    def set_verbosity(self, verbose: bool | None = None) -> None:
        self.index.verbose = self.verbose if verbose is None else verbose


class FaissIVFIndexer(FaissIndexer):
    INDEX_FACTORY = "IVF{num_centroids},Flat"

    def __init__(
        self,
        index_dir: Path,
        index_config: "FaissIVFIndexConfig",
        bi_encoder_config: BiEncoderConfig,
        verbose: bool = False,
    ) -> None:
        super().__init__(index_dir, index_config, bi_encoder_config, verbose)

        import faiss

        ivf_index = faiss.extract_index_ivf(self.index)
        if hasattr(ivf_index, "quantizer"):
            quantizer = faiss.downcast_index(ivf_index.quantizer)
            if hasattr(quantizer, "hnsw"):
                quantizer.hnsw.efConstruction = index_config.ef_construction

        # default faiss value
        # https://github.com/facebookresearch/faiss/blob/dafdff110489db7587b169a0afee8470f220d295/faiss/Clustering.h#L43
        max_points_per_centroid = 256
        self.num_train_embeddings = (
            index_config.num_train_embeddings or index_config.num_centroids * max_points_per_centroid
        )
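        # with the default num_centroids=262144 this amounts to
        # 262144 * 256 = 67,108,864 training embeddings unless
        # num_train_embeddings is set explicitly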

        self._train_embeddings = torch.full(
            (self.num_train_embeddings, self.bi_encoder_config.embedding_dim),
            torch.nan,
            dtype=torch.float32,
        )
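        # NaN marks slots that have not been filled yet; _train drops any
        # remaining NaN rows if the corpus is smaller than num_train_embeddings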

    def to_gpu(self) -> None:
        import faiss

        # clustering_index overrides the index used during clustering so that
        # training runs on all gpus, while the index itself (and its quantizer)
        # stays on the cpu
        # https://faiss.ai/cpp_api/namespace/namespacefaiss_1_1gpu.html
        clustering_index = faiss.index_cpu_to_all_gpus(
            faiss.IndexFlat(self.bi_encoder_config.embedding_dim, self.metric_type)
        )
        clustering_index.verbose = self.verbose
        index_ivf = faiss.extract_index_ivf(self.index)
        index_ivf.clustering_index = clustering_index

    def to_cpu(self) -> None:
        import faiss

        self.index = faiss.index_gpu_to_cpu(self.index)

        # move the index to the cpu but leave the quantizer on the gpu
        # https://gist.github.com/mdouze/334ad6a979ac3637f6d95e9091356d3e
        index_ivf = faiss.extract_index_ivf(self.index)
        quantizer = index_ivf.quantizer
        gpu_quantizer = faiss.index_cpu_to_gpu(faiss.StandardGpuResources(), 0, quantizer)
        index_ivf.quantizer = gpu_quantizer

    def process_embeddings(self, embeddings: torch.Tensor) -> torch.Tensor:
        embeddings = self._grab_train_embeddings(embeddings)
        self._train()
        return embeddings

    def _grab_train_embeddings(self, embeddings: torch.Tensor) -> torch.Tensor:
        if self._train_embeddings is not None:
            # save training embeddings until num_train_embeddings is reached;
            # embeddings that would overflow the buffer are returned so they
            # can be added to the index as usual
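            # (hypothetical numbers) with num_train_embeddings=1000 and 900
            # embeddings collected so far, a batch of 300 fills the last 100
            # training slots and the remaining 200 rows are returned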
            start = self.num_embeddings
            end = start + embeddings.shape[0]
            if end > self.num_train_embeddings:
                end = self.num_train_embeddings
            length = end - start
            self._train_embeddings[start:end] = embeddings[:length]
            self.num_embeddings += length
            embeddings = embeddings[length:]
        return embeddings

    def _train(self, force: bool = False) -> None:
        if self._train_embeddings is not None and (force or self.num_embeddings >= self.num_train_embeddings):
            if torch.isnan(self._train_embeddings).any():
                warnings.warn(
                    "Corpus contains fewer tokens/documents than num_train_embeddings. Removing NaN embeddings."
                )
                self._train_embeddings = self._train_embeddings[~torch.isnan(self._train_embeddings).any(dim=1)]
            self.index.train(self._train_embeddings)
            if torch.cuda.is_available():
                self.to_cpu()
            self.index.add(self._train_embeddings)
            self._train_embeddings = None
            self.set_verbosity(False)

    def save(self) -> None:
        if not self.index.is_trained:
            self._train(force=True)
        return super().save()

    def set_verbosity(self, verbose: bool | None = None) -> None:
        import faiss

        verbose = verbose if verbose is not None else self.verbose
        index = faiss.extract_index_ivf(self.index)
        for elem in (index, index.quantizer):
            setattr(elem, "verbose", verbose)


class FaissIVFPQIndexer(FaissIVFIndexer):
    INDEX_FACTORY = "OPQ{num_subquantizers},IVF{num_centroids}_HNSW32,PQ{num_subquantizers}x{n_bits}"

    def __init__(
        self,
        index_dir: Path,
        index_config: "FaissIVFPQIndexConfig",
        bi_encoder_config: BiEncoderConfig,
        verbose: bool = False,
    ) -> None:
        import faiss

        super().__init__(index_dir, index_config, bi_encoder_config, verbose)
        self.index_config: FaissIVFPQIndexConfig

        index_ivf = faiss.extract_index_ivf(self.index)
        index_ivf.make_direct_map()
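        # the direct map lets vectors be reconstructed from the IVF index by
        # id, which IVF indexes do not support out of the box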

    def set_verbosity(self, verbose: bool | None = None) -> None:
        super().set_verbosity(verbose)
        import faiss

        verbose = verbose if verbose is not None else self.verbose
        index_ivf_pq = faiss.downcast_index(self.index.index)
        for elem in (index_ivf_pq.pq, index_ivf_pq.quantizer):
            setattr(elem, "verbose", verbose)


class FaissIndexConfig(IndexConfig):
    indexer_class = FaissIndexer

    def to_dict(self) -> dict:
        return self.__dict__.copy()


class FaissFlatIndexConfig(FaissIndexConfig):
    indexer_class = FaissFlatIndexer


class FaissIVFIndexConfig(FaissIndexConfig):
    indexer_class = FaissIVFIndexer

    def __init__(
        self,
        num_train_embeddings: int | None = None,
        num_centroids: int = 262144,
        ef_construction: int = 40,
    ) -> None:
        super().__init__()
        self.num_train_embeddings = num_train_embeddings
        self.num_centroids = num_centroids
        self.ef_construction = ef_construction


class FaissIVFPQIndexConfig(FaissIVFIndexConfig):
    indexer_class = FaissIVFPQIndexer

    def __init__(
        self,
        num_train_embeddings: int | None = None,
        num_centroids: int = 262144,
        ef_construction: int = 40,
        num_subquantizers: int = 16,
        n_bits: int = 8,
    ) -> None:
        super().__init__(num_train_embeddings, num_centroids, ef_construction)
        self.num_subquantizers = num_subquantizers
        self.n_bits = n_bits
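

# A minimal usage sketch (hypothetical values; BiEncoderConfig fields beyond
# embedding_dim and similarity_function depend on the surrounding package):
#
#     from pathlib import Path
#
#     bi_encoder_config = BiEncoderConfig(embedding_dim=128, similarity_function="dot")
#     index_config = FaissIVFIndexConfig(num_centroids=1024)
#     indexer = index_config.indexer_class(Path("indexes/my-index"), index_config, bi_encoder_config)
#     for index_batch, output in encoded_batches:  # batches produced by the bi-encoder
#         indexer.add(index_batch, output)
#     indexer.save()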