Source code for lightning_ir.retrieve.faiss_indexer

import warnings
from abc import abstractmethod
from pathlib import Path

import torch

from ..bi_encoder import BiEncoderConfig, BiEncoderOutput
from ..data import IndexBatch
from .indexer import IndexConfig, Indexer


class FaissIndexer(Indexer):
    INDEX_FACTORY: str

    def __init__(
        self,
        index_dir: Path,
        index_config: "FaissIndexConfig",
        bi_encoder_config: BiEncoderConfig,
        verbose: bool = False,
    ) -> None:
        super().__init__(index_dir, index_config, bi_encoder_config, verbose)
        import faiss

        similarity_function = bi_encoder_config.similarity_function
        if similarity_function in ("cosine", "dot"):
            self.metric_type = faiss.METRIC_INNER_PRODUCT
        else:
            raise ValueError(f"similarity_function {similarity_function} unknown")

        index_factory = self.INDEX_FACTORY.format(**index_config.to_dict())
        if similarity_function == "cosine":
            index_factory = "L2norm," + index_factory
        self.index = faiss.index_factory(self.bi_encoder_config.embedding_dim, index_factory, self.metric_type)

        self.set_verbosity()

        if torch.cuda.is_available():
            self.to_gpu()

    @abstractmethod
    def to_gpu(self) -> None: ...

    @abstractmethod
    def to_cpu(self) -> None: ...

    @abstractmethod
    def set_verbosity(self, verbose: bool | None = None) -> None: ...

    def process_embeddings(self, embeddings: torch.Tensor) -> torch.Tensor:
        return embeddings

    def save(self) -> None:
        super().save()
        import faiss

        if self.num_embeddings != self.index.ntotal:
            raise ValueError("number of embeddings does not match index.ntotal")
        if torch.cuda.is_available():
            self.index = faiss.index_gpu_to_cpu(self.index)

        faiss.write_index(self.index, str(self.index_dir / "index.faiss"))

    def add(self, index_batch: IndexBatch, output: BiEncoderOutput) -> None:
        doc_embeddings = output.doc_embeddings
        if doc_embeddings is None:
            raise ValueError("Expected doc_embeddings in BiEncoderOutput")
        doc_lengths = doc_embeddings.scoring_mask.sum(dim=1)
        embeddings = doc_embeddings.embeddings[doc_embeddings.scoring_mask]
        doc_ids = index_batch.doc_ids
        embeddings = self.process_embeddings(embeddings)

        if embeddings.shape[0]:
            self.index.add(embeddings.float().cpu())

        self.num_embeddings += embeddings.shape[0]
        self.num_docs += len(doc_ids)

        self.doc_lengths.extend(doc_lengths.cpu().tolist())
        self.doc_ids.extend(doc_ids)

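# Illustrative sketch (not part of the original module): for cosine similarity the
# indexer prepends "L2norm," to the factory string and keeps METRIC_INNER_PRODUCT,
# so faiss L2-normalizes vectors on add/search and inner-product scores equal
# cosine similarities. A standalone faiss snippet, assuming a 128-dim encoder:
#
#     import faiss
#     import numpy as np
#
#     index = faiss.index_factory(128, "L2norm,Flat", faiss.METRIC_INNER_PRODUCT)
#     vectors = np.random.rand(32, 128).astype("float32")
#     index.add(vectors)
#     scores, ids = index.search(vectors[:1], 5)  # scores are cosine similarities
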
class FaissFlatIndexer(FaissIndexer):
    INDEX_FACTORY = "Flat"

    def __init__(
        self,
        index_dir: Path,
        index_config: "FaissFlatIndexConfig",
        bi_encoder_config: BiEncoderConfig,
        verbose: bool = False,
    ) -> None:
        super().__init__(index_dir, index_config, bi_encoder_config, verbose)
        self.index_config: FaissFlatIndexConfig

    def to_gpu(self) -> None:
        pass

    def to_cpu(self) -> None:
        pass

    def set_verbosity(self, verbose: bool | None = None) -> None:
        self.index.verbose = self.verbose if verbose is None else verbose

class FaissIVFIndexer(FaissIndexer):
    INDEX_FACTORY = "IVF{num_centroids},Flat"

    def __init__(
        self,
        index_dir: Path,
        index_config: "FaissIVFIndexConfig",
        bi_encoder_config: BiEncoderConfig,
        verbose: bool = False,
    ) -> None:
        super().__init__(index_dir, index_config, bi_encoder_config, verbose)

        import faiss

        ivf_index = faiss.extract_index_ivf(self.index)
        if hasattr(ivf_index, "quantizer"):
            quantizer = ivf_index.quantizer
            if hasattr(faiss.downcast_index(quantizer), "hnsw"):
                downcasted_quantizer = faiss.downcast_index(quantizer)
                downcasted_quantizer.hnsw.efConstruction = index_config.ef_construction

        # default faiss values
        # https://github.com/facebookresearch/faiss/blob/dafdff110489db7587b169a0afee8470f220d295/faiss/Clustering.h#L43
        max_points_per_centroid = 256
        self.num_train_embeddings = (
            index_config.num_train_embeddings or index_config.num_centroids * max_points_per_centroid
        )

        self._train_embeddings = torch.full(
            (self.num_train_embeddings, self.bi_encoder_config.embedding_dim),
            torch.nan,
            dtype=torch.float32,
        )

    def to_gpu(self) -> None:
        import faiss

        # clustering_index overrides the index used during clustering but leaves
        # the quantizer on the gpu
        # https://faiss.ai/cpp_api/namespace/namespacefaiss_1_1gpu.html
        clustering_index = faiss.index_cpu_to_all_gpus(
            faiss.IndexFlat(self.bi_encoder_config.embedding_dim, self.metric_type)
        )
        clustering_index.verbose = self.verbose
        index_ivf = faiss.extract_index_ivf(self.index)
        index_ivf.clustering_index = clustering_index

    def to_cpu(self) -> None:
        import faiss

        self.index = faiss.index_gpu_to_cpu(self.index)

        # https://gist.github.com/mdouze/334ad6a979ac3637f6d95e9091356d3e
        # move the index to cpu but leave the quantizer on the gpu
        index_ivf = faiss.extract_index_ivf(self.index)
        quantizer = index_ivf.quantizer
        gpu_quantizer = faiss.index_cpu_to_gpu(faiss.StandardGpuResources(), 0, quantizer)
        index_ivf.quantizer = gpu_quantizer

    def process_embeddings(self, embeddings: torch.Tensor) -> torch.Tensor:
        embeddings = self._grab_train_embeddings(embeddings)
        self._train()
        return embeddings

    def _grab_train_embeddings(self, embeddings: torch.Tensor) -> torch.Tensor:
        if self._train_embeddings is not None:
            # save training embeddings until num_train_embeddings is reached
            # if num_train_embeddings overflows, save the remaining embeddings
            start = self.num_embeddings
            end = start + embeddings.shape[0]
            if end > self.num_train_embeddings:
                end = self.num_train_embeddings
            length = end - start
            self._train_embeddings[start:end] = embeddings[:length]
            self.num_embeddings += length
            embeddings = embeddings[length:]
        return embeddings

    def _train(self, force: bool = False):
        if self._train_embeddings is not None and (force or self.num_embeddings >= self.num_train_embeddings):
            if torch.isnan(self._train_embeddings).any():
                warnings.warn(
                    "Corpus contains fewer tokens/documents than num_train_embeddings. Removing NaN embeddings."
                )
                self._train_embeddings = self._train_embeddings[~torch.isnan(self._train_embeddings).any(dim=1)]
            self.index.train(self._train_embeddings)
            if torch.cuda.is_available():
                self.to_cpu()
            self.index.add(self._train_embeddings)
            self._train_embeddings = None
            self.set_verbosity(False)

    def save(self) -> None:
        if not self.index.is_trained:
            self._train(force=True)
        return super().save()

    def set_verbosity(self, verbose: bool | None = None) -> None:
        import faiss

        verbose = verbose if verbose is not None else self.verbose
        index = faiss.extract_index_ivf(self.index)
        for elem in (index, index.quantizer):
            setattr(elem, "verbose", verbose)

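# Illustrative note (not part of the original module): when num_train_embeddings is
# not set, the training buffer defaults to num_centroids * 256 vectors, matching
# faiss's default max_points_per_centroid. With the default num_centroids of 262144
# that is 262144 * 256 = 67,108,864 embeddings, so for smaller corpora you would
# typically lower num_centroids or pass num_train_embeddings explicitly, e.g.:
#
#     config = FaissIVFIndexConfig(num_centroids=1024, num_train_embeddings=100_000)
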
class FaissIVFPQIndexer(FaissIVFIndexer):
    INDEX_FACTORY = "OPQ{num_subquantizers},IVF{num_centroids}_HNSW32,PQ{num_subquantizers}x{n_bits}"

    def __init__(
        self,
        index_dir: Path,
        index_config: "FaissIVFPQIndexConfig",
        bi_encoder_config: BiEncoderConfig,
        verbose: bool = False,
    ) -> None:
        import faiss

        super().__init__(index_dir, index_config, bi_encoder_config, verbose)
        self.index_config: FaissIVFPQIndexConfig

        index_ivf = faiss.extract_index_ivf(self.index)
        index_ivf.make_direct_map()

    def set_verbosity(self, verbose: bool | None = None) -> None:
        super().set_verbosity(verbose)
        import faiss

        verbose = verbose if verbose is not None else self.verbose
        index_ivf_pq = faiss.downcast_index(self.index.index)
        for elem in (
            index_ivf_pq.pq,
            index_ivf_pq.quantizer,
        ):
            setattr(elem, "verbose", verbose)

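# Illustrative note (not part of the original module): with the default
# FaissIVFPQIndexConfig (num_subquantizers=16, num_centroids=262144, n_bits=8),
# INDEX_FACTORY.format(...) expands to "OPQ16,IVF262144_HNSW32,PQ16x8", i.e. an OPQ
# rotation, an IVF coarse quantizer whose centroids are searched via an HNSW graph,
# and 16 product-quantizer codes of 8 bits each per vector. The same factory string
# built directly with faiss, assuming a 128-dim encoder:
#
#     import faiss
#
#     index = faiss.index_factory(128, "OPQ16,IVF262144_HNSW32,PQ16x8", faiss.METRIC_INNER_PRODUCT)
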
class FaissIndexConfig(IndexConfig):
    indexer_class = FaissIndexer

    def to_dict(self) -> dict:
        return self.__dict__.copy()


class FaissFlatIndexConfig(FaissIndexConfig):
    indexer_class = FaissFlatIndexer


class FaissIVFIndexConfig(FaissIndexConfig):
    indexer_class = FaissIVFIndexer

    def __init__(
        self,
        num_train_embeddings: int | None = None,
        num_centroids: int = 262144,
        ef_construction: int = 40,
    ) -> None:
        super().__init__()
        self.num_train_embeddings = num_train_embeddings
        self.num_centroids = num_centroids
        self.ef_construction = ef_construction


class FaissIVFPQIndexConfig(FaissIVFIndexConfig):
    indexer_class = FaissIVFPQIndexer

    def __init__(
        self,
        num_train_embeddings: int | None = None,
        num_centroids: int = 262144,
        ef_construction: int = 40,
        num_subquantizers: int = 16,
        n_bits: int = 8,
    ) -> None:
        super().__init__(num_train_embeddings, num_centroids, ef_construction)
        self.num_subquantizers = num_subquantizers
        self.n_bits = n_bits

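# Illustrative end-to-end sketch (not part of the original module): each config class
# points at its indexer via indexer_class, so an indexing loop can be driven from the
# config alone. The batch/output objects and the BiEncoderConfig instance below are
# assumptions about the surrounding framework, not prescribed by this module:
#
#     config = FaissFlatIndexConfig()
#     indexer = config.indexer_class(
#         index_dir=Path("indexes/my-corpus"),
#         index_config=config,
#         bi_encoder_config=bi_encoder_config,  # BiEncoderConfig of the encoding model
#         verbose=True,
#     )
#     for index_batch, output in batches:  # pairs of IndexBatch and BiEncoderOutput
#         indexer.add(index_batch, output)
#     indexer.save()  # writes index.faiss (plus whatever the Indexer base class persists) to index_dir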