Source code for lightning_ir.retrieve.indexer
1from __future__ import annotations
2
3import array
4import json
5from abc import ABC, abstractmethod
6from pathlib import Path
7from typing import TYPE_CHECKING
8
9import torch
10
11if TYPE_CHECKING:
12 from ..bi_encoder import BiEncoderConfig, BiEncoderOutput
13 from ..data import IndexBatch
14
15
[docs]
class Indexer(ABC):
    """Abstract base class for indexers that write a bi-encoder index to disk.

    Subclasses implement :meth:`add`. This base class holds the bookkeeping
    shared by all indexers (document ids, per-document lengths, counters) and
    writes the config, ids and lengths out in :meth:`save`.
    """

    def __init__(
        self,
        index_dir: Path,
        index_config: IndexConfig,
        bi_encoder_config: BiEncoderConfig,
        verbose: bool = False,
    ) -> None:
        """Initialize the indexer with empty bookkeeping state.

        Args:
            index_dir: Directory that :meth:`save` writes the index files into.
            index_config: Index configuration; saved into ``index_dir`` by :meth:`save`.
            bi_encoder_config: Configuration of the bi-encoder whose outputs are indexed.
            verbose: Verbosity flag, stored for use by subclasses.
        """
        self.index_dir = index_dir
        self.index_config = index_config
        self.bi_encoder_config = bi_encoder_config
        # Document ids, joined with newlines into doc_ids.txt by save().
        # Presumably appended by subclasses' add(); verify against implementations.
        self.doc_ids: list[str] = []
        # Compact unsigned-int buffer with one length per document.
        self.doc_lengths = array.array("I")
        # Running counters; not updated in this base class (TODO confirm subclasses do).
        self.num_embeddings = 0
        self.num_docs = 0
        self.verbose = verbose

    @abstractmethod
    def add(self, index_batch: IndexBatch, output: BiEncoderOutput) -> None:
        """Add one batch of encoded documents to the index.

        Args:
            index_batch: The batch of documents that was encoded.
            output: The bi-encoder output for ``index_batch``.
        """
        ...

    def save(self) -> None:
        """Write the index config, document ids and document lengths to ``index_dir``.

        Produces ``doc_ids.txt`` (ids joined by newlines, UTF-8 encoded) and
        ``doc_lengths.pt`` (an int64 tensor), in addition to whatever
        ``index_config.save`` writes.
        """
        # The config is saved first; the base IndexConfig.save also creates
        # index_dir (mkdir with parents=True) before the files below are written.
        self.index_config.save(self.index_dir)
        # Explicit UTF-8 so ids with non-ASCII characters round-trip independently
        # of the platform's default locale encoding.
        (self.index_dir / "doc_ids.txt").write_text("\n".join(self.doc_ids), encoding="utf-8")
        # Explicit dtype: dtype inference yields int64 for non-empty input but
        # float32 for an empty buffer, which would make empty indexes inconsistent.
        doc_lengths = torch.tensor(self.doc_lengths.tolist(), dtype=torch.long)
        torch.save(doc_lengths, self.index_dir / "doc_lengths.pt")
41
42
[docs]
class IndexConfig:
    """Base configuration for an index, persisted as ``config.json`` in the index directory.

    ``indexer_class`` names the indexer class associated with this config.
    :meth:`save` serializes the instance attributes (``__dict__``) to JSON, and
    :meth:`from_pretrained` passes them back to the constructor as keyword
    arguments. Attributes must therefore be JSON-serializable and match the
    constructor's parameter names.
    """

    indexer_class = Indexer

    @classmethod
    def from_pretrained(cls, index_dir: Path) -> "IndexConfig":
        """Load a config previously written by :meth:`save`.

        Args:
            index_dir: Directory containing ``config.json``.

        Returns:
            An instance of ``cls`` constructed from the stored attributes.

        Raises:
            ValueError: If the stored ``index_type`` is missing or is not ``cls.__name__``.
        """
        with open(index_dir / "config.json", "r", encoding="utf-8") as f:
            data = json.load(f)
        # Bookkeeping keys added by save(); they are not constructor arguments.
        index_type = data.pop("index_type", None)
        data.pop("index_dir", None)
        if index_type != cls.__name__:
            raise ValueError(f"Expected index_type {cls.__name__}, got {index_type}")
        return cls(**data)

    def save(self, index_dir: Path) -> None:
        """Write this config to ``index_dir / "config.json"``, creating ``index_dir`` if needed.

        Alongside the instance attributes, the stored JSON records ``index_dir``
        (as a string) and ``index_type`` (the concrete class name), which
        :meth:`from_pretrained` checks when loading.
        """
        index_dir.mkdir(parents=True, exist_ok=True)
        data = self.__dict__.copy()
        data["index_dir"] = str(index_dir)
        data["index_type"] = self.__class__.__name__
        # Serialize before opening the file, so a non-JSON-serializable attribute
        # raises without truncating an existing config.json.
        serialized = json.dumps(data)
        with open(index_dir / "config.json", "w", encoding="utf-8") as f:
            f.write(serialized)