1from __future__ import annotations
2
3import itertools
4from dataclasses import is_dataclass
5from pathlib import Path
6from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Tuple, TypeVar
7
8import pandas as pd
9import torch
10from lightning import LightningModule, Trainer
11from lightning.pytorch.callbacks import BasePredictionWriter, Callback, TQDMProgressBar
12
13from ..data import RankBatch, SearchBatch
14from ..data.dataset import RUN_HEADER, DocDataset, QueryDataset, RunDataset
15from ..retrieve import IndexConfig, Indexer, SearchConfig, Searcher
16
17if TYPE_CHECKING:
18 from ..base import LightningIRModule, LightningIROutput
19 from ..bi_encoder import BiEncoderModule, BiEncoderOutput
20
21T = TypeVar("T")
22
23
38
39
[docs]
class GatherMixin:
    """Mixin adding a recursive ``all_gather`` over (possibly nested) dataclasses."""

    def gather(self, pl_module: LightningIRModule, dataclass: T) -> T:
        """Gather ``dataclass`` through ``pl_module.all_gather``.

        Non-dataclass values are handed to ``pl_module.all_gather`` directly.
        Dataclasses are rebuilt as a new object of the same class in which
        every field has been gathered recursively.
        """
        # Leaf value: delegate straight to the module.
        if not is_dataclass(dataclass):
            return pl_module.all_gather(dataclass)

        # Dataclass: gather field by field, then rebuild the same class.
        gathered_fields = {}
        for field_name in dataclass.__dataclass_fields__:
            gathered_fields[field_name] = self.gather(pl_module, getattr(dataclass, field_name))
        return dataclass.__class__(**gathered_fields)
47
48
[docs]
class IndexCallback(Callback, GatherMixin):
    """Callback that feeds test-stage batches and outputs into an ``Indexer``.

    A new indexer is created on the first batch of every test dataloader;
    the previous one (if any) is saved at that point, and the last one is
    saved when testing ends.
    """

    def __init__(
        self,
        index_dir: Path | str | None,
        index_config: IndexConfig,
        verbose: bool = False,
    ) -> None:
        """Store the indexing configuration.

        Args:
            index_dir: Base directory for the indexes. When ``None``, a
                default below the model path is derived in ``get_index_dir``.
            index_config: Configuration whose ``indexer_class`` is used to
                build the indexer.
            verbose: Forwarded to the indexer constructor.
        """
        super().__init__()
        self.index_config = index_config
        self.index_dir = index_dir
        self.verbose = verbose
        # Annotation only: the attribute is first assigned in
        # ``on_test_batch_end`` so ``hasattr(self, "indexer")`` can be used to
        # tell whether an indexer exists yet.
        self.indexer: Indexer

    def setup(self, trainer: Trainer, pl_module: BiEncoderModule, stage: str) -> None:
        """Reject every stage other than ``"test"``.

        Raises:
            ValueError: If ``stage`` is not ``"test"``.
        """
        if stage != "test":
            raise ValueError("IndexCallback can only be used in test stage")

    def on_test_start(self, trainer: Trainer, pl_module: BiEncoderModule) -> None:
        """Check that every test dataloader wraps a ``DocDataset``.

        Raises:
            ValueError: If there are no test dataloaders or any dataset is
                not a ``DocDataset``.
        """
        dataloaders = trainer.test_dataloaders
        if dataloaders is None:
            raise ValueError("No test_dataloaders found")
        datasets = [dataloader.dataset for dataloader in dataloaders]
        if not all(isinstance(dataset, DocDataset) for dataset in datasets):
            raise ValueError("Expected DocDatasets for indexing")

    def get_index_dir(self, pl_module: BiEncoderModule, dataset: DocDataset) -> Path:
        """Return the index directory for ``dataset``.

        The result is ``<base>/<dataset.docs_dataset_id>`` where ``<base>`` is
        ``self.index_dir`` or, when that is ``None``,
        ``<pl_module.config.name_or_path>/indexes``.

        Raises:
            ValueError: If no ``index_dir`` was given and the model's
                ``name_or_path`` does not exist on disk.
        """
        index_dir = self.index_dir
        if index_dir is None:
            default_index_dir = Path(pl_module.config.name_or_path)
            if not default_index_dir.exists():
                raise ValueError("No index_dir provided and model_name_or_path is not a path")
            index_dir = default_index_dir / "indexes"
        # Convert before joining: ``index_dir`` may be a plain ``str`` and
        # ``str / str`` raises ``TypeError``.
        return Path(index_dir) / dataset.docs_dataset_id

    def get_indexer(self, trainer: Trainer, pl_module: BiEncoderModule, dataset_idx: int) -> Indexer:
        """Build an indexer for the test dataloader at ``dataset_idx``.

        Raises:
            ValueError: If the trainer has no test dataloaders.
        """
        dataloaders = trainer.test_dataloaders
        if dataloaders is None:
            raise ValueError("No test_dataloaders found")
        dataset = dataloaders[dataset_idx].dataset

        index_dir = self.get_index_dir(pl_module, dataset)

        indexer = self.index_config.indexer_class(index_dir, self.index_config, pl_module.config, self.verbose)
        return indexer

    def log_to_pg(self, info: Dict[str, Any], trainer: Trainer):
        """Show ``info`` as the postfix of the TQDM test progress bar.

        Does nothing when the trainer's progress bar callback is missing or is
        not a ``TQDMProgressBar``. Values are passed through
        ``format_large_number`` before being displayed.
        """
        pg_callback = trainer.progress_bar_callback
        if pg_callback is None or not isinstance(pg_callback, TQDMProgressBar):
            return
        pg = pg_callback.test_progress_bar
        info = {k: format_large_number(v) for k, v in info.items()}
        if pg is not None:
            pg.set_postfix(info)

    def on_test_batch_end(
        self,
        trainer: Trainer,
        pl_module: BiEncoderModule,
        outputs: BiEncoderOutput,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Gather the batch and outputs and add them to the current indexer."""
        if batch_idx == 0:
            # First batch of a dataloader: save the indexer of the previous
            # dataloader (if one exists) and start a fresh one.
            if hasattr(self, "indexer"):
                self.indexer.save()
            self.indexer = self.get_indexer(trainer, pl_module, dataloader_idx)

        # Gathering runs on every rank before the global-zero check below.
        batch = self.gather(pl_module, batch)
        outputs = self.gather(pl_module, outputs)

        if not trainer.is_global_zero:
            return

        self.indexer.add(batch, outputs)
        self.log_to_pg(
            {
                "num_docs": self.indexer.num_docs,
                "num_embeddings": self.indexer.num_embeddings,
            },
            trainer,
        )
        return super().on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)

    def on_test_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Save the last indexer, if any batch created one."""
        # ``indexer`` only exists after the first test batch; without this
        # guard an empty test run fails with ``AttributeError``.
        if hasattr(self, "indexer"):
            self.indexer.save()
137
138
[docs]
class RankCallback(BasePredictionWriter, GatherMixin):
    """Callback that turns test-stage scores into run files.

    Per-batch rankings are accumulated in ``self.run_dfs`` and written as one
    tab-separated file per test dataloader.
    """

    def __init__(self, save_dir: Path | str | None = None, run_name: str | None = None) -> None:
        """Store where and under which name runs are written.

        Args:
            save_dir: Directory for the run files. When ``None``, a default
                below the model path is derived in ``on_test_epoch_start``.
            run_name: File name for the run. When ``None``, the name is
                derived from the dataset in ``get_run_path``.
        """
        super().__init__()
        self.save_dir = Path(save_dir) if save_dir is not None else None
        self.run_name = run_name
        self.run_dfs: List[pd.DataFrame] = []

    def setup(self, trainer: Trainer, pl_module: LightningIRModule, stage: str) -> None:
        """Reject every stage other than ``"test"``.

        Raises:
            ValueError: If ``stage`` is not ``"test"``.
        """
        if stage != "test":
            raise ValueError(f"{self.__class__.__name__} can only be used in test stage")

    def on_test_epoch_start(self, trainer: Trainer, pl_module: LightningIRModule) -> None:
        """Reset the buffered runs and resolve a default ``save_dir`` if needed.

        Raises:
            ValueError: If no ``save_dir`` was given and the model's
                ``name_or_path`` does not exist on disk.
        """
        super().on_test_epoch_start(trainer, pl_module)
        self.run_dfs = []
        if self.save_dir is None:
            default_save_dir = Path(pl_module.config.name_or_path)
            if default_save_dir.exists():
                self.save_dir = default_save_dir / "runs"
                print(f"Using default save_dir {self.save_dir}")
            else:
                raise ValueError("No save_dir provided and model_name_or_path is not a path")

    def on_test_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Write the runs buffered for the last dataloader and clear the buffer."""
        super().on_test_epoch_end(trainer, pl_module)
        if trainer.is_global_zero:
            # Index -1 selects the last test dataloader.
            self.write_run_dfs(trainer, -1)
            self.run_dfs = []

    def get_run_path(self, trainer: Trainer, dataset_idx: int) -> Path:
        """Return the run file path for the test dataloader at ``dataset_idx``.

        The file name is ``self.run_name`` when set. Otherwise it is derived
        from the dataset: ``dataset_id`` with ``/`` replaced by ``-`` for a
        ``QueryDataset`` (or a ``RunDataset`` without ``run_path``), or the
        part of the run file's name before its first ``.``; ``.run`` is
        appended in both cases.

        Raises:
            ValueError: If ``save_dir`` is unset, there are no test
                dataloaders, or no file name can be derived for the dataset.
        """
        dataloaders = trainer.test_dataloaders
        if self.save_dir is None:
            raise ValueError("No save_dir found; call setup before using this method")
        if dataloaders is None:
            raise ValueError("No test_dataloaders found")
        dataset = dataloaders[dataset_idx].dataset
        if self.run_name is not None:
            run_file = self.run_name
        elif isinstance(dataset, QueryDataset):
            run_file = f"{dataset.dataset_id.replace('/', '-')}.run"
        elif isinstance(dataset, RunDataset):
            if dataset.run_path is None:
                run_file = f"{dataset.dataset_id.replace('/', '-')}.run"
            else:
                run_file = f"{dataset.run_path.name.split('.')[0]}.run"
        else:
            # Without this branch ``run_file`` would be unbound below and the
            # caller would see an opaque ``UnboundLocalError``.
            raise ValueError(
                f"Cannot derive a run file name for dataset of type {type(dataset).__name__}; "
                "provide run_name"
            )
        run_file_path = self.save_dir / run_file
        return run_file_path

    def rank(self, batch: RankBatch, output: LightningIROutput) -> Tuple[torch.Tensor, List[str], List[int]]:
        """Flatten the scores and document ids of a batch.

        Returns:
            A tuple ``(scores, doc_ids, num_docs)``: the scores viewed as one
            dimension, the document ids of all queries chained into one list,
            and the number of document ids per query.

        Raises:
            ValueError: If scores or doc ids are missing, or their flattened
                lengths differ.
        """
        scores = output.scores
        if scores is None:
            raise ValueError("Expected output to have scores")
        doc_ids = batch.doc_ids
        if doc_ids is None:
            raise ValueError("Expected batch to have doc_ids")
        scores = scores.view(-1)
        num_docs = [len(_doc_ids) for _doc_ids in doc_ids]
        doc_ids = list(itertools.chain.from_iterable(doc_ids))
        if scores.shape[0] != len(doc_ids):
            raise ValueError("scores and doc_ids must have the same length")
        return scores, doc_ids, num_docs

    def write_run_dfs(self, trainer: Trainer, dataloader_idx: int):
        """Concatenate the buffered runs and write them as a headerless TSV.

        Does nothing on ranks other than global zero or when nothing is
        buffered. Parent directories of the run file are created as needed.
        """
        if not trainer.is_global_zero or not self.run_dfs:
            return
        run_file_path = self.get_run_path(trainer, dataloader_idx)
        run_file_path.parent.mkdir(parents=True, exist_ok=True)
        run_df = pd.concat(self.run_dfs, ignore_index=True)
        run_df.to_csv(run_file_path, header=False, index=False, sep="\t")

    def on_test_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningIRModule,
        outputs: LightningIROutput,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Route test batches through ``write_on_batch_end``."""
        # NOTE(review): the indices are read from the predict loop although
        # this hook runs in the test stage; ``write_on_batch_end`` does not
        # use them.
        batch_indices = trainer.predict_loop.current_batch_indices
        self.write_on_batch_end(trainer, pl_module, outputs, batch_indices, batch, batch_idx, dataloader_idx)
        super().on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)

    def write_on_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningIRModule,
        prediction: LightningIROutput,
        batch_indices: Sequence[int] | None,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int,
    ) -> None:
        """Rank one batch and buffer the resulting run rows.

        Raises:
            ValueError: If the batch has no ``query_ids`` (or ``rank`` fails).
        """
        # Gathering runs on every rank before the global-zero check below.
        batch = self.gather(pl_module, batch)
        prediction = self.gather(pl_module, prediction)
        if not trainer.is_global_zero:
            return

        query_ids = batch.query_ids
        if query_ids is None:
            raise ValueError("Expected batch to have query_ids")
        scores, doc_ids, num_docs = self.rank(batch, prediction)
        scores = scores.float().cpu().numpy()

        # Repeat every query id once per document so rows line up with the
        # flattened doc ids and scores.
        query_ids = list(
            itertools.chain.from_iterable(itertools.repeat(query_id, num) for query_id, num in zip(query_ids, num_docs))
        )
        run_df = pd.DataFrame(zip(query_ids, doc_ids, scores), columns=["query_id", "doc_id", "score"])
        run_df = run_df.sort_values(["query_id", "score"], ascending=[True, False])
        # Rank within each query by descending score; ties keep row order.
        run_df["rank"] = run_df.groupby("query_id")["score"].rank(ascending=False, method="first").astype(int)
        run_df["q0"] = 0
        run_df["system"] = pl_module.model.__class__.__name__
        run_df = run_df[RUN_HEADER]

        if batch_idx == 0:
            # A new dataloader has started: flush what was buffered for the
            # previous one. For the first dataloader the buffer is empty and
            # ``write_run_dfs`` returns without writing.
            self.write_run_dfs(trainer, dataloader_idx - 1)
            self.run_dfs = []
        self.run_dfs.append(run_df)
256
257
[docs]
class SearchCallback(RankCallback):
    """``RankCallback`` that first retrieves documents with a ``Searcher``.

    For ``SearchBatch`` inputs the searcher's results are converted into a
    ``RankBatch`` so the inherited ranking and run-writing logic can be reused.
    """

    def __init__(
        self,
        index_dir: Path | str,
        search_config: SearchConfig,
        save_dir: Path | str | None = None,
        use_gpu: bool = True,
        run_name: str | None = None,
    ) -> None:
        """Store the search configuration.

        Args:
            index_dir: Index location passed to the searcher constructor.
            search_config: Configuration whose ``search_class`` is used to
                build the searcher.
            save_dir: Forwarded to ``RankCallback``.
            use_gpu: Forwarded to the searcher constructor.
            run_name: Forwarded to ``RankCallback``; ``None`` keeps the
                dataset-derived file name.
        """
        super().__init__(save_dir, run_name)
        self.index_dir = index_dir
        self.search_config = search_config
        self.use_gpu = use_gpu
        # Annotation only; assigned in ``setup``.
        self.searcher: Searcher

    def setup(self, trainer: Trainer, pl_module: BiEncoderModule, stage: str) -> None:
        """Create the searcher and attach it to ``pl_module``.

        Raises:
            ValueError: If ``stage`` is not ``"test"``.
        """
        if stage != "test":
            raise ValueError(f"{self.__class__.__name__} can only be used in test stage")
        self.searcher = self.search_config.search_class(self.index_dir, self.search_config, pl_module, self.use_gpu)
        pl_module.searcher = self.searcher

    def rank(
        self, batch: SearchBatch | RankBatch, output: LightningIROutput
    ) -> Tuple[torch.Tensor, List[str], List[int]]:
        """Search (for a ``SearchBatch``) and then rank like ``RankCallback``.

        For a ``SearchBatch`` the searcher is queried with ``output``,
        ``output.scores`` is overwritten in place with the retrieved scores,
        and a ``RankBatch`` with empty-string placeholder documents is built
        around the retrieved ids. Any other batch goes straight to the parent
        implementation.
        """
        if isinstance(batch, SearchBatch):
            doc_scores, flat_doc_ids, num_docs = self.searcher.search(output)
            # Running offsets [0, n0, n0 + n1, ...] in one pass.
            offsets = list(itertools.accumulate(num_docs, initial=0))
            # Cut the flat id list back into one tuple of ids per query.
            doc_ids = tuple(tuple(flat_doc_ids[start:end]) for start, end in zip(offsets[:-1], offsets[1:]))
            output.scores = doc_scores
            dummy_docs = [[""] * num for num in num_docs]
            batch = RankBatch(batch.queries, dummy_docs, batch.query_ids, doc_ids, batch.qrels)
        return super().rank(batch, output)
289
290
[docs]
class ReRankCallback(RankCallback):
    """``RankCallback`` under a separate name; adds and overrides nothing."""