Source code for lightning_ir.retrieve.sparse_indexer

 1import array
 2from pathlib import Path
 3
 4import torch
 5
 6from ..bi_encoder import BiEncoderConfig, BiEncoderOutput
 7from ..data import IndexBatch
 8from .indexer import IndexConfig, Indexer
 9
10
class SparseIndexer(Indexer):
    """Indexer that accumulates token embeddings as a sparse CSR matrix.

    Every scored document token contributes one row to the matrix. The three
    CSR components are collected incrementally in compact ``array.array``
    buffers and are only turned into a ``torch`` sparse tensor in ``save``.
    """

    def __init__(
        self,
        index_dir: Path,
        index_config: "SparseIndexConfig",
        bi_encoder_config: BiEncoderConfig,
        verbose: bool = False,
    ) -> None:
        """Set up the base indexer and empty CSR buffers.

        Args:
            index_dir: Directory the index file is written to by ``save``.
            index_config: Configuration forwarded to the base ``Indexer``.
            bi_encoder_config: Bi-encoder configuration; its ``embedding_dim``
                is used as the number of columns of the saved matrix.
            verbose: Forwarded to the base ``Indexer``.
        """
        super().__init__(index_dir, index_config, bi_encoder_config, verbose)
        # CSR row pointers always start with 0; each added row appends one
        # cumulative count of non-zero entries.
        self.crow_indices = array.array("L")
        self.crow_indices.append(0)
        self.col_idcs = array.array("I")
        self.values = array.array("f")

    def add(self, index_batch: IndexBatch, output: BiEncoderOutput) -> None:
        """Append the document embeddings of one batch to the index.

        Args:
            index_batch: Batch whose ``doc_ids`` are recorded.
            output: Bi-encoder output carrying ``doc_embeddings``.

        Raises:
            ValueError: If ``output.doc_embeddings`` is ``None``.
        """
        doc_embeddings = output.doc_embeddings
        if doc_embeddings is None:
            raise ValueError("Expected doc_embeddings in BiEncoderOutput")

        # NOTE(review): doc_ids, doc_lengths, num_embeddings and num_docs are
        # presumably initialized by Indexer.__init__ -- confirm in the base class.
        doc_lengths = doc_embeddings.scoring_mask.sum(dim=1)
        # Boolean-mask indexing keeps only scored tokens, one row per token.
        embeddings = doc_embeddings.embeddings[doc_embeddings.scoring_mask]
        num_tokens = embeddings.shape[0]
        num_docs = len(index_batch.doc_ids)
        self.doc_ids.extend(index_batch.doc_ids)

        token_idcs, dim_idcs = torch.nonzero(embeddings, as_tuple=True)
        # minlength guarantees exactly one count per token row. Without it,
        # trailing rows that contain no non-zero value are dropped, leaving
        # fewer than num_embeddings + 1 row pointers and an invalid CSR index.
        nnz_per_token = torch.bincount(token_idcs, minlength=num_tokens)
        # Offset by the running total so pointers continue across batches.
        crow_indices = nnz_per_token.cumsum(0) + self.crow_indices[-1]
        values = embeddings[token_idcs, dim_idcs]
        self.crow_indices.extend(crow_indices.cpu().tolist())
        self.col_idcs.extend(dim_idcs.cpu().tolist())
        self.values.extend(values.cpu().tolist())

        self.doc_lengths.extend(doc_lengths.cpu().tolist())
        self.num_embeddings += num_tokens
        self.num_docs += num_docs

    def to_gpu(self) -> None:
        """Do nothing; the CSR buffers are plain Python arrays."""
        pass

    def to_cpu(self) -> None:
        """Do nothing; the CSR buffers are plain Python arrays."""
        pass

    def save(self) -> None:
        """Run the base class ``save`` and write the CSR matrix to ``index.pt``.

        The matrix has ``num_embeddings`` rows and
        ``bi_encoder_config.embedding_dim`` columns and is stored with
        ``torch.save`` inside ``index_dir``.
        """
        super().save()
        index = torch.sparse_csr_tensor(
            torch.tensor(self.crow_indices),
            torch.tensor(self.col_idcs),
            torch.tensor(self.values),
            torch.Size([self.num_embeddings, self.bi_encoder_config.embedding_dim]),
        )
        torch.save(index, self.index_dir / "index.pt")
61 62
class SparseIndexConfig(IndexConfig):
    """Index configuration that selects ``SparseIndexer`` as its indexer."""

    # Indexer implementation associated with this configuration.
    indexer_class = SparseIndexer