Source code for lightning_ir.retrieve.searcher

from __future__ import annotations

from abc import ABC, abstractmethod
from pathlib import Path
from typing import TYPE_CHECKING, List, Sequence, Tuple, Type

import torch

from ..bi_encoder.model import BiEncoderEmbedding

if TYPE_CHECKING:
    from ..bi_encoder import BiEncoderModule, BiEncoderOutput

class Searcher(ABC):
    def __init__(
        self, index_dir: Path | str, search_config: SearchConfig, module: BiEncoderModule, use_gpu: bool = True
    ) -> None:
        super().__init__()
        self.index_dir = Path(index_dir)
        self.search_config = search_config
        self.module = module
        self.device = torch.device("cuda") if use_gpu and torch.cuda.is_available() else torch.device("cpu")

        # load the document ids and per-document embedding counts stored alongside the index
        self.doc_ids = (self.index_dir / "doc_ids.txt").read_text().split()
        self.doc_lengths = torch.load(self.index_dir / "doc_lengths.pt")

        self.to_gpu()

        self.num_docs = len(self.doc_ids)
        self.cumulative_doc_lengths = torch.cumsum(self.doc_lengths, dim=0)

        # sanity check: the stored lengths must cover every document and every embedding in the index
        if self.doc_lengths.shape[0] != self.num_docs or self.doc_lengths.sum() != self.num_embeddings:
            raise ValueError("doc_lengths do not match index")

    def to_gpu(self) -> None:
        self.doc_lengths = self.doc_lengths.to(self.device)

    @property
    @abstractmethod
    def num_embeddings(self) -> int: ...

    @abstractmethod
    def _search(self, query_embeddings: BiEncoderEmbedding) -> Tuple[torch.Tensor, torch.Tensor, List[int]]: ...

    def _filter_and_sort(
        self,
        doc_scores: torch.Tensor,
        doc_idcs: torch.Tensor | None,
        num_docs: Sequence[int] | None,
    ) -> Tuple[torch.Tensor, List[str], List[int]]:
        if (doc_idcs is None) != (num_docs is None):
            raise ValueError("doc_idcs and num_docs must be both None or not None")
        if doc_idcs is None and num_docs is None:
            # assume we have searched the whole index: scores reshape to (num_queries, num_docs)
            k = min(self.search_config.k, doc_scores.shape[0])
            values, idcs = torch.topk(doc_scores.view(-1, self.num_docs), k)
            num_queries = values.shape[0]
            values = values.view(-1)
            idcs = idcs.view(-1)
            doc_ids = [self.doc_ids[doc_idx] for doc_idx in idcs.cpu()]
            return values, doc_ids, [k] * num_queries

        assert doc_idcs is not None and num_docs is not None
        # otherwise each query has its own candidate set; split the flat scores
        # per query and take top-k within each set
        per_query_doc_scores = torch.split(doc_scores, num_docs)
        per_query_doc_idcs = torch.split(doc_idcs, num_docs)
        new_num_docs = []
        _doc_scores = []
        doc_ids = []
        for query_idx, scores in enumerate(per_query_doc_scores):
            k = min(self.search_config.k, scores.shape[0])
            values, idcs = torch.topk(scores, k)
            _doc_scores.append(values)
            doc_ids.extend([self.doc_ids[doc_idx] for doc_idx in per_query_doc_idcs[query_idx][idcs].cpu()])
            new_num_docs.append(k)
        doc_scores = torch.cat(_doc_scores)
        return doc_scores, doc_ids, new_num_docs

    def search(self, output: BiEncoderOutput) -> Tuple[torch.Tensor, List[str], List[int]]:
        query_embeddings = output.query_embeddings
        if query_embeddings is None:
            raise ValueError("Expected query_embeddings in BiEncoderOutput")
        doc_scores, doc_idcs, num_docs = self._search(query_embeddings)
        doc_scores, doc_ids, num_docs = self._filter_and_sort(doc_scores, doc_idcs, num_docs)

        return doc_scores, doc_ids, num_docs
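
# --- illustrative only, not part of the original module ----------------------
# A toy walk-through of the two branches of ``Searcher._filter_and_sort``: when
# ``doc_idcs``/``num_docs`` are None, the flat scores are reshaped to
# (num_queries, num_docs) and top-k is taken over the whole index; otherwise
# the flat candidate scores are split per query before top-k. All tensor
# values below are made up for illustration.
def _demo_filter_and_sort_branches() -> None:
    total_docs, k = 5, 3

    # branch 1: whole-index search, one score per (query, document) pair
    doc_scores = torch.rand(2 * total_docs)  # 2 queries x 5 documents, flattened
    values, idcs = torch.topk(doc_scores.view(-1, total_docs), k)
    print(values.shape, idcs.shape)  # torch.Size([2, 3]) torch.Size([2, 3])

    # branch 2: per-query candidate sets of differing sizes
    flat_scores = torch.rand(7)
    num_docs = [3, 4]  # query 0 scored 3 candidates, query 1 scored 4
    for scores in torch.split(flat_scores, num_docs):
        top_values, top_idcs = torch.topk(scores, min(k, scores.shape[0]))
        print(top_values, top_idcs)
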
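
# ``__init__`` precomputes ``cumulative_doc_lengths`` even though this base
# class never reads it again; presumably concrete searchers use it to map a
# flat embedding index back to its owning document (multi-vector models store
# several embeddings per document). A minimal sketch of that mapping, under
# that assumption -- this demo function is hypothetical, not part of the module:
def _demo_embedding_to_doc_mapping() -> None:
    doc_lengths = torch.tensor([2, 3, 1])  # doc 0 owns embeddings 0-1, doc 1 owns 2-4, doc 2 owns 5
    cumulative_doc_lengths = torch.cumsum(doc_lengths, dim=0)  # tensor([2, 5, 6])
    embedding_idcs = torch.tensor([0, 1, 2, 4, 5])
    # right=True finds the first document whose cumulative length exceeds the index
    doc_idcs = torch.searchsorted(cumulative_doc_lengths, embedding_idcs, right=True)
    print(doc_idcs)  # tensor([0, 0, 1, 1, 2])
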
class SearchConfig:
    search_class: Type[Searcher] = Searcher

    def __init__(self, k: int = 10) -> None:
        self.k = k
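
# A hedged sketch of how a concrete searcher could plug into these base
# classes, assuming a brute-force index that stores a single embedding per
# document in one (num_docs, dim) matrix. ``FlatSearcher``, ``FlatSearchConfig``
# and the "embeddings.pt" layout are hypothetical illustrations, not the
# searchers that lightning_ir actually ships.
class FlatSearcher(Searcher):
    def __init__(
        self, index_dir: Path | str, search_config: SearchConfig, module: BiEncoderModule, use_gpu: bool = True
    ) -> None:
        # load before super().__init__, which validates against num_embeddings
        self.embeddings = torch.load(Path(index_dir) / "embeddings.pt")
        super().__init__(index_dir, search_config, module, use_gpu)

    @property
    def num_embeddings(self) -> int:
        return self.embeddings.shape[0]

    def _search(self, query_embeddings: BiEncoderEmbedding) -> Tuple[torch.Tensor, None, None]:
        # assumes a single-vector model, i.e. query_embeddings.embeddings has
        # shape (num_queries, 1, dim); scores every document for every query
        queries = query_embeddings.embeddings.squeeze(1).to(self.embeddings)
        doc_scores = queries @ self.embeddings.T  # (num_queries, num_docs)
        # doc_idcs=None / num_docs=None tells _filter_and_sort that the whole
        # index was scored, so it reshapes and takes top-k per query
        return doc_scores.view(-1), None, None


class FlatSearchConfig(SearchConfig):
    search_class: Type[Searcher] = FlatSearcher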