Source code for lightning_ir.retrieve.sparse_searcher

from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING, Literal, Tuple

import torch

from .searcher import SearchConfig, Searcher
from .sparse_indexer import SparseIndexConfig

if TYPE_CHECKING:
    from ..bi_encoder import BiEncoderEmbedding, BiEncoderModule

class SparseIndex:
    def __init__(self, index_dir: Path, similarity_function: Literal["dot", "cosine"], use_gpu: bool = False) -> None:
        self.index = torch.load(index_dir / "index.pt")
        self.config = SparseIndexConfig.from_pretrained(index_dir)
        if similarity_function == "dot":
            self.similarity_function = self.dot_similarity
        elif similarity_function == "cosine":
            self.similarity_function = self.cosine_similarity
        else:
            raise ValueError("Unknown similarity function")
        self.device = torch.device("cuda") if use_gpu and torch.cuda.is_available() else torch.device("cpu")

    def score(self, embeddings: torch.Tensor) -> torch.Tensor:
        embeddings = embeddings.to(self.device)
        similarity = self.similarity_function(embeddings, self.index).to_dense()
        return similarity

    @property
    def num_embeddings(self) -> int:
        return self.index.shape[0]

    def cosine_similarity(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        dot_product = self.dot_similarity(x, y)
        # normalize the dot product by the embedding norms to obtain the cosine similarity
        return dot_product / (torch.norm(x, dim=-1)[:, None] * torch.norm(y, dim=-1)[None, :])

    def dot_similarity(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return torch.matmul(y, x.T).T

    def to_gpu(self) -> None:
        self.index = self.index.to(self.device)

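
# Illustrative sketch (not part of the module): scoring against a sparse index
# boils down to a sparse-dense matrix product between the stored index and the
# query token embeddings, as in SparseIndex.dot_similarity above. Toy, made-up
# shapes and values for a 3-dimensional embedding space.
_example_doc_token_embeddings = torch.tensor(
    [[0.0, 1.0, 0.0], [2.0, 0.0, 0.0], [0.0, 0.0, 3.0]]
).to_sparse()  # stand-in for the sparse (num_doc_tokens x dim) tensor stored in index.pt
_example_query_token_embeddings = torch.tensor([[1.0, 1.0, 0.0]])  # (num_query_tokens, dim)
# one score per (query token, doc token) pair: tensor([[1., 2., 0.]])
_example_scores = torch.matmul(_example_doc_token_embeddings, _example_query_token_embeddings.T).T
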
class SparseSearcher(Searcher):
    def __init__(
        self,
        index_dir: Path,
        search_config: SparseSearchConfig,
        module: BiEncoderModule,
        use_gpu: bool = True,
    ) -> None:
        self.search_config: SparseSearchConfig
        self.index = SparseIndex(index_dir, module.config.similarity_function, use_gpu)
        super().__init__(index_dir, search_config, module, use_gpu)
        # map each document token embedding to the index of the document it belongs to
        self.doc_token_idcs = (
            torch.arange(self.doc_lengths.shape[0]).to(self.doc_lengths).repeat_interleave(self.doc_lengths)
        )
        self.use_gpu = use_gpu
        self.device = torch.device("cuda") if use_gpu and torch.cuda.is_available() else torch.device("cpu")

    @property
    def doc_is_single_vector(self) -> bool:
        # true if every document is represented by exactly one embedding
        return self.cumulative_doc_lengths[-1].item() == self.cumulative_doc_lengths.shape[0]

    def to_gpu(self) -> None:
        super().to_gpu()
        self.index.to_gpu()

    @property
    def num_embeddings(self) -> int:
        return self.index.num_embeddings

    def _search(self, query_embeddings: BiEncoderEmbedding) -> Tuple[torch.Tensor, None, None]:
        embeddings = query_embeddings.embeddings[query_embeddings.scoring_mask]
        query_lengths = query_embeddings.scoring_mask.sum(-1)
        scores = self.index.score(embeddings)

        # aggregate doc token scores: max-pool token-level scores into one score per document
        if not self.doc_is_single_vector:
            scores = torch.scatter_reduce(
                torch.zeros(scores.shape[0], self.num_docs, device=scores.device),
                1,
                self.doc_token_idcs[None].expand_as(scores),
                scores,
                "amax",
            )

        # aggregate query token scores using the configured query aggregation function
        query_is_single_vector = (query_lengths == 1).all()
        if not query_is_single_vector:
            query_token_idcs = torch.arange(query_lengths.shape[0]).to(query_lengths).repeat_interleave(query_lengths)
            scores = torch.scatter_reduce(
                torch.zeros(query_lengths.shape[0], self.num_docs, device=scores.device),
                0,
                query_token_idcs[:, None].expand_as(scores),
                scores,
                self.module.config.query_aggregation_function,
            )
        scores = scores.reshape(-1)
        return scores, None, None
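
# Illustrative sketch (not part of the module): for multi-vector documents,
# SparseSearcher._search above max-pools token-level scores into one score per
# document via scatter_reduce, using doc_token_idcs to map tokens to documents.
# Toy, made-up values: document 0 has two tokens, document 1 has one.
_example_num_docs = 2
_example_doc_token_idcs = torch.tensor([0, 0, 1])
_example_token_scores = torch.tensor([[1.0, 2.0, 0.5]])  # (num_query_tokens, num_doc_tokens)
# result: tensor([[2.0000, 0.5000]]) -> max token score per document
_example_doc_scores = torch.scatter_reduce(
    torch.zeros(_example_token_scores.shape[0], _example_num_docs),
    1,
    _example_doc_token_idcs[None].expand_as(_example_token_scores),
    _example_token_scores,
    "amax",
)
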
class SparseSearchConfig(SearchConfig):
    search_class = SparseSearcher
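
A hypothetical usage sketch, shown for orientation only: it assumes a sparse index has already been written to index_dir by the corresponding sparse indexer, that BiEncoderModule can be constructed from a pretrained model name or path, and that the base SearchConfig accepts a k parameter; the exact API should be checked against the lightning_ir documentation.

    from pathlib import Path

    from lightning_ir import BiEncoderModule
    from lightning_ir.retrieve.sparse_searcher import SparseSearchConfig

    index_dir = Path("path/to/sparse-index")  # hypothetical index location
    module = BiEncoderModule(model_name_or_path="path/to/bi-encoder")  # hypothetical model

    # the config selects the searcher implementation via its search_class attribute
    search_config = SparseSearchConfig(k=10)
    searcher = search_config.search_class(index_dir, search_config, module, use_gpu=False)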