Source code for lightning_ir.lightning_utils.callbacks

  1from __future__ import annotations
  2
  3import itertools
  4from dataclasses import is_dataclass
  5from pathlib import Path
  6from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Tuple, TypeVar
  7
  8import pandas as pd
  9import torch
 10from lightning import LightningModule, Trainer
 11from lightning.pytorch.callbacks import BasePredictionWriter, Callback, TQDMProgressBar
 12
 13from ..data import RankBatch, SearchBatch
 14from ..data.dataset import RUN_HEADER, DocDataset, QueryDataset, RunDataset
 15from ..retrieve import IndexConfig, Indexer, SearchConfig, Searcher
 16
 17if TYPE_CHECKING:
 18    from ..base import LightningIRModule, LightningIROutput
 19    from ..bi_encoder import BiEncoderModule, BiEncoderOutput
 20
 21T = TypeVar("T")
 22
 23
def format_large_number(number: float) -> str:
    """Format ``number`` with two decimals and a thousands suffix (K, M, B, T).

    The magnitude is divided by 1000 until its *rounded* value drops below 1000 or the
    largest suffix ("T") is reached. So ``1500`` becomes ``"1.50 K"`` and ``999_999``
    becomes ``"1.00 M"`` rather than ``"1000.00 K"``. Negative numbers are abbreviated
    by magnitude and keep their sign.

    Args:
        number: Value to format.

    Returns:
        The formatted string, e.g. ``"12.35 M"``; values below 1000 get no suffix.
    """
    suffixes = ["", "K", "M", "B", "T"]
    suffix_index = 0
    magnitude = abs(number)
    # Compare the value as it will be printed (2 decimals), so rounding can never
    # produce "1000.00" with a smaller suffix than necessary.
    while round(magnitude, 2) >= 1000 and suffix_index < len(suffixes) - 1:
        magnitude /= 1000.0
        suffix_index += 1

    sign = "-" if number < 0 else ""
    formatted_number = f"{sign}{magnitude:.2f}"

    suffix = suffixes[suffix_index]
    if suffix:
        formatted_number += f" {suffix}"
    return formatted_number
38 39
class GatherMixin:
    """Mixin adding a recursive ``all_gather`` over (possibly nested) dataclass instances."""

    def gather(self, pl_module: LightningIRModule, dataclass: T) -> T:
        """Gather ``dataclass`` across processes with ``pl_module.all_gather``.

        Dataclass *instances* are rebuilt as the same class, each init field being
        gathered recursively. Anything else, including dataclass types, is passed
        straight to ``pl_module.all_gather``.

        Fields declared with ``init=False`` cannot be passed to the constructor, so the
        rebuilt instance gets whatever its ``__init__``/defaults assign for them. The
        previous implementation raised ``TypeError`` for such fields instead.

        Args:
            pl_module: Module providing ``all_gather``.
            dataclass: Value (dataclass instance or anything ``all_gather`` accepts).

        Returns:
            The gathered value, with the same dataclass structure as the input.
        """
        from dataclasses import fields

        # is_dataclass() is also True for dataclass *classes*; only rebuild instances.
        if is_dataclass(dataclass) and not isinstance(dataclass, type):
            # fields() excludes ClassVar/InitVar pseudo-fields, which __dataclass_fields__
            # contains and which must not be passed back to __init__.
            gathered_fields = {
                field_.name: self.gather(pl_module, getattr(dataclass, field_.name))
                for field_ in fields(dataclass)
                if field_.init
            }
            return dataclass.__class__(**gathered_fields)
        return pl_module.all_gather(dataclass)
47 48
class IndexCallback(Callback, GatherMixin):
    """Test-stage callback that feeds bi-encoder outputs into an :class:`Indexer`.

    Every test dataloader must wrap a :class:`DocDataset`; one index is built per
    dataloader. A new indexer is created on the first batch of each dataloader (saving
    the previous one), and the last indexer is saved when testing ends.
    """

    def __init__(
        self,
        index_dir: Path | str | None,
        index_config: IndexConfig,
        verbose: bool = False,
    ) -> None:
        """
        Args:
            index_dir: Root directory for indexes. If ``None``, ``<name_or_path>/indexes``
                of the model config is used, which requires ``name_or_path`` to exist on disk.
            index_config: Configuration whose ``indexer_class`` is instantiated per dataset.
            verbose: Forwarded to the indexer constructor.
        """
        super().__init__()
        self.index_config = index_config
        self.index_dir = index_dir
        self.verbose = verbose
        # Annotation only: the attribute is created on the first test batch.
        self.indexer: Indexer

    def setup(self, trainer: Trainer, pl_module: BiEncoderModule, stage: str) -> None:
        """Reject any stage other than ``"test"``."""
        if stage != "test":
            raise ValueError("IndexCallback can only be used in test stage")

    def on_test_start(self, trainer: Trainer, pl_module: BiEncoderModule) -> None:
        """Validate that all test dataloaders wrap :class:`DocDataset` instances."""
        dataloaders = trainer.test_dataloaders
        if dataloaders is None:
            raise ValueError("No test_dataloaders found")
        datasets = [dataloader.dataset for dataloader in dataloaders]
        if not all(isinstance(dataset, DocDataset) for dataset in datasets):
            raise ValueError("Expected DocDatasets for indexing")

    def get_index_dir(self, pl_module: BiEncoderModule, dataset: DocDataset) -> Path:
        """Return ``<root>/<dataset.docs_dataset_id>`` for the index of ``dataset``.

        Raises:
            ValueError: If no ``index_dir`` was given and the model's ``name_or_path``
                is not an existing path.
        """
        if self.index_dir is None:
            default_index_dir = Path(pl_module.config.name_or_path)
            if not default_index_dir.exists():
                raise ValueError("No index_dir provided and model_name_or_path is not a path")
            root = default_index_dir / "indexes"
        else:
            # Normalise first: a plain ``str`` index_dir would fail on the ``/`` below.
            root = Path(self.index_dir)
        return root / dataset.docs_dataset_id

    def get_indexer(self, trainer: Trainer, pl_module: BiEncoderModule, dataset_idx: int) -> Indexer:
        """Instantiate ``index_config.indexer_class`` for test dataloader ``dataset_idx``."""
        dataloaders = trainer.test_dataloaders
        if dataloaders is None:
            raise ValueError("No test_dataloaders found")
        dataset = dataloaders[dataset_idx].dataset

        index_dir = self.get_index_dir(pl_module, dataset)

        indexer = self.index_config.indexer_class(index_dir, self.index_config, pl_module.config, self.verbose)
        return indexer

    def log_to_pg(self, info: Dict[str, Any], trainer: Trainer):
        """Show ``info`` (formatted with ``format_large_number``) in the TQDM test bar postfix."""
        pg_callback = trainer.progress_bar_callback
        if pg_callback is None or not isinstance(pg_callback, TQDMProgressBar):
            return
        pg = pg_callback.test_progress_bar
        info = {k: format_large_number(v) for k, v in info.items()}
        if pg is not None:
            pg.set_postfix(info)

    def on_test_batch_end(
        self,
        trainer: Trainer,
        pl_module: BiEncoderModule,
        outputs: BiEncoderOutput,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Gather the batch across processes and add it to the current index on rank zero."""
        if batch_idx == 0:
            # A new dataloader starts: persist the previous index before switching.
            # NOTE(review): this runs on every rank, although only rank zero adds
            # documents; confirm Indexer.save is harmless on the other ranks.
            if hasattr(self, "indexer"):
                self.indexer.save()
            self.indexer = self.get_indexer(trainer, pl_module, dataloader_idx)

        # Gathering is a collective operation, so every rank must reach this point.
        batch = self.gather(pl_module, batch)
        outputs = self.gather(pl_module, outputs)

        if not trainer.is_global_zero:
            return

        self.indexer.add(batch, outputs)
        self.log_to_pg(
            {
                "num_docs": self.indexer.num_docs,
                "num_embeddings": self.indexer.num_embeddings,
            },
            trainer,
        )
        return super().on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)

    def on_test_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Save the index of the last test dataloader, if any batch created one."""
        # Without any test batch no indexer exists and ``self.indexer`` is unset.
        if hasattr(self, "indexer"):
            self.indexer.save()
class RankCallback(BasePredictionWriter, GatherMixin):
    """Test-stage callback that turns per-batch scores into run files.

    On the global-zero process, each batch's scores are converted into ranked rows
    (columns selected by ``RUN_HEADER``; presumably TREC run format given the ``q0``
    column). The rows are buffered per test dataloader. They are written as a headerless,
    tab-separated file under ``save_dir`` when the next dataloader starts or the test
    epoch ends.
    """

    def __init__(self, save_dir: Path | str | None = None, run_name: str | None = None) -> None:
        """
        Args:
            save_dir: Directory for run files. If ``None``, ``<name_or_path>/runs`` of the
                model config is used at test-epoch start (that path must exist).
            run_name: Fixed run file name. If ``None``, the name is derived from each
                dataset. NOTE(review): with several dataloaders a fixed name makes every
                dataloader write to the same file.
        """
        super().__init__()
        self.save_dir = Path(save_dir) if save_dir is not None else None
        self.run_name = run_name
        # Run fragments collected for the dataloader currently being processed.
        self.run_dfs: List[pd.DataFrame] = []

    def setup(self, trainer: Trainer, pl_module: LightningIRModule, stage: str) -> None:
        """Reject any stage other than ``"test"``."""
        if stage != "test":
            raise ValueError(f"{self.__class__.__name__} can only be used in test stage")

    def on_test_epoch_start(self, trainer: Trainer, pl_module: LightningIRModule) -> None:
        """Reset buffered runs and resolve the default ``save_dir`` if none was given."""
        super().on_test_epoch_start(trainer, pl_module)
        self.run_dfs = []
        if self.save_dir is None:
            default_save_dir = Path(pl_module.config.name_or_path)
            if default_save_dir.exists():
                self.save_dir = default_save_dir / "runs"
                print(f"Using default save_dir {self.save_dir}")
            else:
                raise ValueError("No save_dir provided and model_name_or_path is not a path")

    def on_test_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Write the buffered run of the last test dataloader."""
        super().on_test_epoch_end(trainer, pl_module)
        if trainer.is_global_zero:
            self.write_run_dfs(trainer, -1)
            self.run_dfs = []

    def get_run_path(self, trainer: Trainer, dataset_idx: int) -> Path:
        """Return the run file path for test dataloader ``dataset_idx``.

        Raises:
            ValueError: If ``save_dir`` or the test dataloaders are missing, or if no
                ``run_name`` was given and the dataset type does not provide a name.
        """
        dataloaders = trainer.test_dataloaders
        if self.save_dir is None:
            raise ValueError("No save_dir found; call setup before using this method")
        if dataloaders is None:
            raise ValueError("No test_dataloaders found")
        dataset = dataloaders[dataset_idx].dataset
        if self.run_name is not None:
            run_file = self.run_name
        elif isinstance(dataset, QueryDataset):
            run_file = f"{dataset.dataset_id.replace('/', '-')}.run"
        elif isinstance(dataset, RunDataset):
            if dataset.run_path is None:
                run_file = f"{dataset.dataset_id.replace('/', '-')}.run"
            else:
                run_file = f"{dataset.run_path.name.split('.')[0]}.run"
        else:
            # Previously fell through with ``run_file`` unbound (UnboundLocalError).
            raise ValueError(
                f"Cannot derive a run file name for dataset of type {type(dataset).__name__}; "
                "pass run_name explicitly"
            )
        return self.save_dir / run_file

    def rank(self, batch: RankBatch, output: LightningIROutput) -> Tuple[torch.Tensor, List[str], List[int]]:
        """Flatten scores and per-query doc ids of a batch.

        Returns:
            ``(scores, doc_ids, num_docs)``: flattened scores, flattened doc ids and the
            number of documents per query (in batch order).

        Raises:
            ValueError: If scores or doc ids are missing or their lengths differ.
        """
        scores = output.scores
        if scores is None:
            raise ValueError("Expected output to have scores")
        doc_ids = batch.doc_ids
        if doc_ids is None:
            raise ValueError("Expected batch to have doc_ids")
        scores = scores.view(-1)
        num_docs = [len(_doc_ids) for _doc_ids in doc_ids]
        doc_ids = list(itertools.chain.from_iterable(doc_ids))
        if scores.shape[0] != len(doc_ids):
            raise ValueError("scores and doc_ids must have the same length")
        return scores, doc_ids, num_docs

    def write_run_dfs(self, trainer: Trainer, dataloader_idx: int):
        """On rank zero, concatenate buffered runs and write them for ``dataloader_idx``."""
        if not trainer.is_global_zero or not self.run_dfs:
            return
        run_file_path = self.get_run_path(trainer, dataloader_idx)
        run_file_path.parent.mkdir(parents=True, exist_ok=True)
        run_df = pd.concat(self.run_dfs, ignore_index=True)
        run_df.to_csv(run_file_path, header=False, index=False, sep="\t")

    def on_test_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningIRModule,
        outputs: LightningIROutput,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Route test batches through :meth:`write_on_batch_end`."""
        # NOTE(review): predict_loop indices are read during testing; the value is not
        # used by write_on_batch_end, so this is only passed through.
        batch_indices = trainer.predict_loop.current_batch_indices
        self.write_on_batch_end(trainer, pl_module, outputs, batch_indices, batch, batch_idx, dataloader_idx)
        super().on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)

    def write_on_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningIRModule,
        prediction: LightningIROutput,
        batch_indices: Sequence[int] | None,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int,
    ) -> None:
        """Gather a batch, rank its documents per query and buffer the resulting rows."""
        # Gathering is collective, so every rank participates before the rank-zero check.
        batch = self.gather(pl_module, batch)
        prediction = self.gather(pl_module, prediction)
        if not trainer.is_global_zero:
            return

        query_ids = batch.query_ids
        if query_ids is None:
            raise ValueError("Expected batch to have query_ids")
        scores, doc_ids, num_docs = self.rank(batch, prediction)
        scores = scores.float().cpu().numpy()

        # Repeat each query id once per document so it aligns with the flat doc ids.
        query_ids = list(
            itertools.chain.from_iterable(itertools.repeat(query_id, num) for query_id, num in zip(query_ids, num_docs))
        )
        run_df = pd.DataFrame(zip(query_ids, doc_ids, scores), columns=["query_id", "doc_id", "score"])
        run_df = run_df.sort_values(["query_id", "score"], ascending=[True, False])
        run_df["rank"] = run_df.groupby("query_id")["score"].rank(ascending=False, method="first").astype(int)
        run_df["q0"] = 0
        run_df["system"] = pl_module.model.__class__.__name__
        run_df = run_df[RUN_HEADER]

        if batch_idx == 0:
            # First batch of a new dataloader: flush the previous dataloader's run.
            self.write_run_dfs(trainer, dataloader_idx - 1)
            self.run_dfs = []
        self.run_dfs.append(run_df)
256 257
class SearchCallback(RankCallback):
    """RankCallback that retrieves documents with a :class:`Searcher` instead of using given ones.

    For :class:`SearchBatch` inputs the searcher returns scores and document ids, which
    are repackaged as a :class:`RankBatch` and written like any other run.
    """

    def __init__(
        self,
        index_dir: Path | str,
        search_config: SearchConfig,
        save_dir: Path | str | None = None,
        use_gpu: bool = True,
        run_name: str | None = None,
    ) -> None:
        """
        Args:
            index_dir: Index directory passed to ``search_config.search_class``.
            search_config: Configuration whose ``search_class`` builds the searcher.
            save_dir: Directory for run files (see :class:`RankCallback`).
            use_gpu: Forwarded to the searcher constructor.
            run_name: Optional fixed run file name (see :class:`RankCallback`).
        """
        super().__init__(save_dir, run_name)
        self.index_dir = index_dir
        self.search_config = search_config
        self.use_gpu = use_gpu
        # Annotation only: the searcher is built in ``setup``.
        self.searcher: Searcher

    def setup(self, trainer: Trainer, pl_module: BiEncoderModule, stage: str) -> None:
        """Build the searcher and attach it to ``pl_module.searcher`` (test stage only)."""
        if stage != "test":
            raise ValueError(f"{self.__class__.__name__} can only be used in test stage")
        self.searcher = self.search_config.search_class(self.index_dir, self.search_config, pl_module, self.use_gpu)
        pl_module.searcher = self.searcher

    def rank(
        self, batch: SearchBatch | RankBatch, output: LightningIROutput
    ) -> Tuple[torch.Tensor, List[str], List[int]]:
        """Search the index for ``SearchBatch`` inputs, then rank like :class:`RankCallback`.

        Returns:
            ``(scores, doc_ids, num_docs)`` as produced by ``RankCallback.rank``.
        """
        if isinstance(batch, SearchBatch):
            doc_scores, flat_doc_ids, num_docs = self.searcher.search(output)
            # Prefix sums: query i owns flat_doc_ids[offsets[i]:offsets[i + 1]].
            offsets = [0, *itertools.accumulate(num_docs)]
            doc_ids = tuple(tuple(flat_doc_ids[start:end]) for start, end in zip(offsets, offsets[1:]))
            output.scores = doc_scores
            # Document texts are not needed for writing runs; empty placeholders suffice.
            dummy_docs = [[""] * num for num in num_docs]
            batch = RankBatch(batch.queries, dummy_docs, batch.query_ids, doc_ids, batch.qrels)
        return super().rank(batch, output)
289 290
class ReRankCallback(RankCallback):
    """Run-writing callback for re-ranking.

    It adds nothing to :class:`RankCallback`, which ranks the documents supplied in
    each batch (``batch.doc_ids``) by the module's ``output.scores``.
    """