Source code for lightning_ir.retrieve.indexer

 1from __future__ import annotations
 2
 3import array
 4import json
 5from abc import ABC, abstractmethod
 6from pathlib import Path
 7from typing import TYPE_CHECKING
 8
 9import torch
10
11if TYPE_CHECKING:
12    from ..bi_encoder import BiEncoderConfig, BiEncoderOutput
13    from ..data import IndexBatch
14
15
class Indexer(ABC):
    """Abstract base class for building an index and persisting it to disk.

    Concrete subclasses implement :meth:`add`. :meth:`save` writes the index
    configuration, the collected document ids and the per-document lengths
    into ``index_dir``.
    """

    def __init__(
        self,
        index_dir: Path,
        index_config: IndexConfig,
        bi_encoder_config: BiEncoderConfig,
        verbose: bool = False,
    ) -> None:
        """Remember the given settings and start with empty bookkeeping state.

        Args:
            index_dir: Directory that :meth:`save` writes the index files to.
            index_config: Index configuration; its ``save`` method is invoked
                by :meth:`save`.
            bi_encoder_config: Bi-encoder configuration, stored unchanged.
            verbose: Verbosity flag, stored unchanged (not read in this class).
        """
        # Settings handed in by the caller.
        self.index_dir = index_dir
        self.index_config = index_config
        self.bi_encoder_config = bi_encoder_config

        # Bookkeeping state, empty at construction time.
        # NOTE(review): presumably filled by subclass ``add`` implementations.
        self.doc_ids = list()
        # Typecode "I" stores the lengths as unsigned C ints.
        self.doc_lengths = array.array("I")
        self.num_embeddings = self.num_docs = 0

        self.verbose = verbose

    @abstractmethod
    def add(self, index_batch: IndexBatch, output: BiEncoderOutput) -> None:
        """Incorporate a batch and its bi-encoder output into the index.

        Abstract: every concrete subclass must provide an implementation.
        """
        ...

    def save(self) -> None:
        """Persist the index configuration, document ids and document lengths.

        Writes, in this order: whatever ``index_config.save`` produces,
        ``doc_ids.txt`` (one id per line, no trailing newline) and
        ``doc_lengths.pt`` (the lengths as a tensor stored via ``torch.save``).
        """
        target_dir = self.index_dir

        # The configuration's own save routine runs before any other file is written.
        self.index_config.save(target_dir)

        ids_file = target_dir / "doc_ids.txt"
        ids_file.write_text("\n".join(self.doc_ids))

        lengths_tensor = torch.tensor(self.doc_lengths)
        torch.save(lengths_tensor, target_dir / "doc_lengths.pt")
41 42
class IndexConfig:
    """Base class for index configurations stored as ``config.json``.

    :meth:`save` writes the instance attributes plus two bookkeeping keys
    (``index_dir`` and ``index_type``) as JSON; :meth:`from_pretrained`
    reads that file back and rebuilds an instance from the remaining keys.
    """

    indexer_class = Indexer

    @classmethod
    def from_pretrained(cls, index_dir: Path) -> "IndexConfig":
        """Load a configuration from ``index_dir / "config.json"``.

        Args:
            index_dir: Directory containing a ``config.json`` file.

        Returns:
            An instance of ``cls`` constructed from the stored key/value
            pairs, excluding ``index_type`` and ``index_dir``.

        Raises:
            KeyError: If the stored JSON has no ``index_type`` entry.
            ValueError: If the stored ``index_type`` differs from
                ``cls.__name__``.
        """
        with open(index_dir / "config.json", "r", encoding="utf-8") as f:
            data = json.load(f)
        if data["index_type"] != cls.__name__:
            raise ValueError(f"Expected index_type {cls.__name__}, got {data['index_type']}")
        # Both keys are added by ``save`` for bookkeeping and are not passed
        # on to the constructor.
        data.pop("index_type", None)
        data.pop("index_dir", None)
        return cls(**data)

    def save(self, index_dir: Path) -> None:
        """Write this configuration to ``index_dir / "config.json"``.

        The directory is created (including parents) if it does not exist.
        The stored JSON object contains a copy of the instance attributes
        plus ``index_dir`` (as a string) and ``index_type`` (the class name).

        Args:
            index_dir: Directory to write ``config.json`` into.
        """
        index_dir.mkdir(parents=True, exist_ok=True)
        data = self.__dict__.copy()
        data["index_dir"] = str(index_dir)
        data["index_type"] = self.__class__.__name__
        # Serialize before opening the file: opening with "w" truncates it,
        # so a serialization error must surface first or it would leave an
        # empty/partial config.json and destroy a previously valid one.
        serialized = json.dumps(data)
        with open(index_dir / "config.json", "w", encoding="utf-8") as f:
            f.write(serialized)