1from __future__ import annotations
2
3from abc import ABC, abstractmethod
4from pathlib import Path
5from typing import TYPE_CHECKING, List, Sequence, Tuple, Type
6
7import torch
8
9from ..bi_encoder.model import BiEncoderEmbedding
10
11if TYPE_CHECKING:
12 from ..bi_encoder import BiEncoderModule, BiEncoderOutput
13
14
class Searcher(ABC):
    """Abstract base class for searching a pre-built bi-encoder index.

    The index directory must contain ``doc_ids.txt`` (whitespace-separated
    document ids) and ``doc_lengths.pt`` (a tensor, loaded with ``torch.load``,
    holding one length per document). Subclasses implement :meth:`_search` to
    score documents and report the total number of stored embeddings through
    :attr:`num_embeddings`.
    """

    def __init__(
        self, index_dir: Path | str, search_config: SearchConfig, module: BiEncoderModule, use_gpu: bool = True
    ) -> None:
        """Load and validate the document metadata stored in ``index_dir``.

        Args:
            index_dir: Directory containing ``doc_ids.txt`` and ``doc_lengths.pt``.
            search_config: Search settings; ``search_config.k`` bounds the results per query.
            module: Bi-encoder module kept on the searcher as ``self.module``.
            use_gpu: Place tensors on CUDA if it is available, otherwise on CPU.

        Raises:
            ValueError: If ``doc_lengths`` does not have exactly one entry per document id,
                or its sum differs from :attr:`num_embeddings`.
        """
        super().__init__()
        self.index_dir = Path(index_dir)
        self.search_config = search_config
        self.module = module
        self.device = torch.device("cuda") if use_gpu and torch.cuda.is_available() else torch.device("cpu")

        self.doc_ids = (self.index_dir / "doc_ids.txt").read_text().split()
        # NOTE: torch.load unpickles its input; only point this at trusted index directories.
        self.doc_lengths = torch.load(self.index_dir / "doc_lengths.pt")

        self.to_gpu()

        self.num_docs = len(self.doc_ids)
        # Inclusive running total of doc_lengths.
        self.cumulative_doc_lengths = torch.cumsum(self.doc_lengths, dim=0)

        # num_embeddings is abstract: subclasses must be able to answer it at this point.
        if self.doc_lengths.shape[0] != self.num_docs or self.doc_lengths.sum() != self.num_embeddings:
            raise ValueError("doc_lengths do not match index")

    def to_gpu(self) -> None:
        """Move ``doc_lengths`` to ``self.device`` (CPU when no GPU is in use)."""
        self.doc_lengths = self.doc_lengths.to(self.device)

    @property
    @abstractmethod
    def num_embeddings(self) -> int:
        """Total number of embeddings in the index; must equal ``doc_lengths.sum()``."""

    @abstractmethod
    def _search(
        self, query_embeddings: BiEncoderEmbedding
    ) -> Tuple[torch.Tensor, torch.Tensor | None, List[int] | None]:
        """Score documents for the given query embeddings.

        Returns:
            ``(doc_scores, doc_idcs, num_docs)``. Either both ``doc_idcs`` and
            ``num_docs`` are ``None``, in which case ``doc_scores`` holds
            ``self.num_docs`` scores per query covering the whole index, or
            ``doc_idcs`` gives the index row of every score and ``num_docs`` the
            number of scores belonging to each query.
        """

    def _filter_and_sort(
        self,
        doc_scores: torch.Tensor,
        doc_idcs: torch.Tensor | None,
        num_docs: Sequence[int] | None,
    ) -> Tuple[torch.Tensor, List[str], List[int]]:
        """Keep the top-``k`` documents of every query, sorted by descending score.

        Args:
            doc_scores: Scores as returned by :meth:`_search`.
            doc_idcs: Index row of each score, or ``None`` when ``doc_scores`` covers the whole index.
            num_docs: Number of scores per query, or ``None`` when ``doc_scores`` covers the whole index.

        Returns:
            The kept scores concatenated over queries, the matching document ids,
            and how many documents were kept for each query.

        Raises:
            ValueError: If exactly one of ``doc_idcs`` and ``num_docs`` is ``None``.
        """
        if (doc_idcs is None) != (num_docs is None):
            raise ValueError("doc_idcs and num_docs must be both None or not None")

        if doc_idcs is None and num_docs is None:
            # Whole index was searched: one row of ``self.num_docs`` scores per query.
            scores_per_query = doc_scores.reshape(-1, self.num_docs)
            # topk runs along the document axis, so k is bounded by num_docs; the
            # flat score count can exceed num_docs when there are several queries.
            k = min(self.search_config.k, self.num_docs)
            values, idcs = torch.topk(scores_per_query, k)
            num_queries = values.shape[0]
            doc_ids = [self.doc_ids[doc_idx] for doc_idx in idcs.reshape(-1).tolist()]
            return values.reshape(-1), doc_ids, [k] * num_queries

        assert doc_idcs is not None and num_docs is not None
        top_scores: List[torch.Tensor] = []
        doc_ids = []
        new_num_docs: List[int] = []
        for scores, candidate_idcs in zip(torch.split(doc_scores, num_docs), torch.split(doc_idcs, num_docs)):
            k = min(self.search_config.k, scores.shape[0])
            values, top_positions = torch.topk(scores, k)
            top_scores.append(values)
            # Map positions within this query's candidates to index rows, then to ids.
            doc_ids.extend(self.doc_ids[doc_idx] for doc_idx in candidate_idcs[top_positions].tolist())
            new_num_docs.append(k)
        if not top_scores:
            # No queries at all: torch.cat rejects an empty list.
            return doc_scores.new_empty((0,)), doc_ids, new_num_docs
        return torch.cat(top_scores), doc_ids, new_num_docs

    def search(self, output: BiEncoderOutput) -> Tuple[torch.Tensor, List[str], List[int]]:
        """Retrieve the top-``k`` documents for the query embeddings in ``output``.

        Returns:
            Kept scores, matching document ids, and the number of documents per query.

        Raises:
            ValueError: If ``output.query_embeddings`` is ``None``.
        """
        query_embeddings = output.query_embeddings
        if query_embeddings is None:
            raise ValueError("Expected query_embeddings in BiEncoderOutput")
        doc_scores, doc_idcs, num_docs = self._search(query_embeddings)
        return self._filter_and_sort(doc_scores, doc_idcs, num_docs)
87
88
class SearchConfig:
    """Base configuration for a :class:`Searcher`.

    ``k`` is the maximum number of top-scoring documents kept per query, and
    ``search_class`` records which :class:`Searcher` implementation belongs to
    this configuration (the abstract base ``Searcher`` by default).
    """

    # Searcher implementation associated with this configuration.
    search_class: Type[Searcher] = Searcher

    def __init__(self, k: int = 10) -> None:
        """Create a search configuration.

        Args:
            k: Maximum number of documents kept per query.
        """
        self.k: int = k