Source code for lightning_ir.retrieve.faiss_searcher

from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING, List, Literal, Tuple

import torch

from ..bi_encoder.model import BiEncoderEmbedding
from .searcher import SearchConfig, Searcher

if TYPE_CHECKING:
    from ..bi_encoder import BiEncoderModule

class FaissSearcher(Searcher):
    """Searcher that retrieves documents from a FAISS index."""

    def __init__(
        self, index_dir: Path | str, search_config: FaissSearchConfig, module: BiEncoderModule, use_gpu: bool = False
    ) -> None:
        import faiss

        self.search_config: FaissSearchConfig
        self.index = faiss.read_index(str(Path(index_dir) / "index.faiss"))
        # apply search-time hyperparameters if the index is IVF-based
        ivf_index = None
        try:
            ivf_index = faiss.extract_index_ivf(self.index)
        except RuntimeError:
            pass
        if ivf_index is not None:
            # number of inverted lists probed per query
            ivf_index.nprobe = search_config.n_probe
            quantizer = getattr(ivf_index, "quantizer", None)
            if quantizer is not None:
                downcasted_quantizer = faiss.downcast_index(quantizer)
                hnsw = getattr(downcasted_quantizer, "hnsw", None)
                if hnsw is not None:
                    # search breadth of an HNSW coarse quantizer
                    hnsw.efSearch = search_config.ef_search
        super().__init__(index_dir, search_config, module, use_gpu)
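
    # A single-vector index stores exactly one embedding per document, so FAISS
    # ids coincide with document ids. A multi-vector index stores a contiguous
    # span of embeddings per document, and FAISS ids must be mapped back to
    # document ids (see `candidate_retrieval`).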

    @property
    def num_embeddings(self) -> int:
        return self.index.ntotal

    @property
    def doc_is_single_vector(self) -> bool:
        return self.num_docs == self.num_embeddings

    def _search(self, query_embeddings: BiEncoderEmbedding) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
        query_embeddings = query_embeddings.to(self.device)
        candidate_scores, candidate_doc_idcs = self.candidate_retrieval(query_embeddings)
        query_lengths = query_embeddings.scoring_mask.sum(-1)
        if self.search_config.imputation_strategy == "gather":
            # re-score candidates on fully reconstructed document embeddings
            doc_embeddings, doc_idcs, num_docs = self.gather_imputation(candidate_doc_idcs, query_lengths)
            doc_scores = self.module.model.score(query_embeddings, doc_embeddings, num_docs)
        else:
            doc_scores, doc_idcs, num_docs = self.intra_ranking_imputation(
                candidate_scores, candidate_doc_idcs, query_lengths
            )
        return doc_scores, doc_idcs, num_docs

    def candidate_retrieval(self, query_embeddings: BiEncoderEmbedding) -> Tuple[torch.Tensor, torch.Tensor]:
        embeddings = query_embeddings.embeddings[query_embeddings.scoring_mask]
        candidate_scores, candidate_idcs = self.index.search(embeddings.float().cpu(), self.search_config.candidate_k)
        candidate_scores = torch.from_numpy(candidate_scores)
        candidate_idcs = torch.from_numpy(candidate_idcs)
        if self.doc_is_single_vector:
            candidate_doc_idcs = candidate_idcs.to(self.cumulative_doc_lengths.device)
        else:
            # map flat embedding ids to document ids via the cumulative lengths
            candidate_doc_idcs = torch.searchsorted(
                self.cumulative_doc_lengths,
                candidate_idcs.to(self.cumulative_doc_lengths.device),
                side="right",
            )
        return candidate_scores, candidate_doc_idcs

    def gather_imputation(
        self, candidate_doc_idcs: torch.Tensor, query_lengths: torch.Tensor
    ) -> Tuple[BiEncoderEmbedding, torch.Tensor, List[int]]:
        # unique doc_idcs per query
        doc_idcs_per_query = [
            list(sorted(set(idcs.reshape(-1).tolist())))
            for idcs in torch.split(candidate_doc_idcs, query_lengths.tolist())
        ]
        num_docs = [len(idcs) for idcs in doc_idcs_per_query]
        doc_idcs = torch.tensor(sum(doc_idcs_per_query, [])).to(candidate_doc_idcs)
        unique_doc_idcs, inverse_idcs = torch.unique(doc_idcs, return_inverse=True)

        # gather all vectors for unique doc_idcs
        doc_lengths = self.doc_lengths[unique_doc_idcs]
        start_doc_idcs = self.cumulative_doc_lengths[unique_doc_idcs - 1]
        start_doc_idcs[unique_doc_idcs == 0] = 0
        all_doc_idcs = torch.cat(
            [
                torch.arange(start.item(), start.item() + length.item())
                for start, length in zip(start_doc_idcs.cpu(), doc_lengths.cpu())
            ]
        )
        all_doc_embeddings = torch.from_numpy(self.index.reconstruct_batch(all_doc_idcs))
        unique_embeddings = torch.nn.utils.rnn.pad_sequence(
            [embeddings for embeddings in torch.split(all_doc_embeddings, doc_lengths.tolist())],
            batch_first=True,
        ).to(inverse_idcs.device)
        embeddings = unique_embeddings[inverse_idcs]

        # mask out padding
        doc_lengths = doc_lengths[inverse_idcs]
        scoring_mask = torch.arange(embeddings.shape[1], device=embeddings.device) < doc_lengths[:, None]
        doc_embeddings = BiEncoderEmbedding(embeddings=embeddings, scoring_mask=scoring_mask)
        return doc_embeddings, doc_idcs, num_docs

    def intra_ranking_imputation(
        self,
        candidate_scores: torch.Tensor,
        candidate_doc_idcs: torch.Tensor,
        query_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
        max_query_length = int(query_lengths.max().item())
        is_query_single_vector = max_query_length == 1

        if self.doc_is_single_vector:
            scores = candidate_scores.view(-1)
            doc_idcs = candidate_doc_idcs.view(-1)
            num_docs = torch.full((candidate_scores.shape[0],), candidate_scores.shape[1])
        else:
            # grab unique doc ids per query candidate
            query_idcs = torch.arange(query_lengths.shape[0], device=query_lengths.device).repeat_interleave(
                query_lengths
            )
            query_candidate_idcs = torch.cat(
                [torch.arange(length.item(), device=query_lengths.device) for length in query_lengths]
            )
            paired_idcs = torch.stack(
                [
                    query_idcs.repeat_interleave(candidate_scores.shape[1]),
                    query_candidate_idcs.repeat_interleave(candidate_scores.shape[1]),
                    candidate_doc_idcs.view(-1),
                ]
            ).T
            unique_paired_idcs, inverse_idcs = torch.unique(paired_idcs[:, [0, 2]], return_inverse=True, dim=0)
            doc_idcs = unique_paired_idcs[:, 1]
            num_docs = unique_paired_idcs[:, 0].bincount()

            # accumulate max score per doc
            ranking_doc_idcs = torch.arange(doc_idcs.shape[0], device=query_lengths.device)[inverse_idcs]
            idcs = ranking_doc_idcs * max_query_length + paired_idcs[:, 1]
            shape = torch.Size((doc_idcs.shape[0], max_query_length))
            scores = torch.scatter_reduce(
                torch.full((shape.numel(),), float("inf"), device=query_lengths.device),
                0,
                idcs,
                candidate_scores.view(-1).to(query_lengths.device),
                "max",
                include_self=False,
            ).view(shape)

        if is_query_single_vector:
            scores = scores.squeeze(-1)
        else:
            # impute missing values
            if self.search_config.imputation_strategy == "min":
                impute_values = (
                    scores.masked_fill(scores == torch.finfo(scores.dtype).min, float("inf"))
                    .min(0, keepdim=True)
                    .values.expand_as(scores)
                )
            elif self.search_config.imputation_strategy is None:
                impute_values = torch.zeros_like(scores)
            else:
                raise ValueError(f"Invalid imputation strategy: {self.search_config.imputation_strategy}")
            is_inf = torch.isinf(scores)
            scores[is_inf] = impute_values[is_inf]

            # aggregate score per query vector
            mask = (
                torch.arange(max_query_length, device=query_lengths.device) < query_lengths[:, None]
            ).repeat_interleave(num_docs, dim=0)
            scores = self.module.scoring_function.aggregate(
                scores, mask, self.module.config.query_aggregation_function, dim=1
            ).squeeze(-1)
        return scores, doc_idcs, num_docs.tolist()
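
# Illustration (not part of the module): `candidate_retrieval` maps the flat
# embedding ids returned by FAISS back to document ids with `torch.searchsorted`
# over the cumulative document lengths. A minimal, self-contained sketch with
# made-up lengths; the real tensors live on the `Searcher` base class.
def _example_embedding_to_doc_ids() -> torch.Tensor:
    # hypothetical corpus: three documents with 2, 3, and 1 embeddings each
    doc_lengths = torch.tensor([2, 3, 1])
    cumulative_doc_lengths = doc_lengths.cumsum(0)  # tensor([2, 5, 6])
    # flat embedding ids as returned by a FAISS search
    candidate_idcs = torch.tensor([0, 1, 2, 4, 5])
    # side="right" so an id equal to a boundary falls into the next document
    return torch.searchsorted(cumulative_doc_lengths, candidate_idcs, side="right")
    # -> tensor([0, 0, 1, 1, 2]): ids 0-1 -> doc 0, ids 2-4 -> doc 1, id 5 -> doc 2
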
class FaissSearchConfig(SearchConfig):
    """Configuration for FAISS-based retrieval."""

    search_class = FaissSearcher

    def __init__(
        self,
        k: int = 10,
        candidate_k: int = 100,
        imputation_strategy: Literal["min", "gather"] | None = None,
        n_probe: int = 1,
        ef_search: int = 16,
    ) -> None:
        super().__init__(k)
        self.candidate_k = candidate_k
        self.imputation_strategy = imputation_strategy
        self.n_probe = n_probe
        self.ef_search = ef_search
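

# Usage sketch (not part of the module; values and the `module` object are
# illustrative). A larger `candidate_k` is typically paired with the "gather"
# strategy so candidates are re-scored on fully reconstructed embeddings;
# `n_probe` and `ef_search` only take effect for IVF / HNSW-quantized indexes.
#
#     search_config = FaissSearchConfig(
#         k=10,                          # final number of documents per query
#         candidate_k=100,               # candidates fetched from FAISS first
#         imputation_strategy="gather",  # re-score reconstructed doc embeddings
#         n_probe=8,
#         ef_search=32,
#     )
#     searcher = FaissSearcher("path/to/index", search_config, module, use_gpu=False)
#
# where `module` is the BiEncoderModule whose model produced the indexed
# embeddings.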