Source code for lightning_ir.base.module

  1from collections import defaultdict
  2from pathlib import Path
  3from typing import Any, Dict, List, Sequence, Tuple
  4
  5import torch
  6from lightning import LightningModule
  7from transformers import BatchEncoding
  8
  9from ..data import RankBatch, SearchBatch, TrainBatch
 10from ..loss.loss import InBatchLossFunction, LossFunction
 11from .config import LightningIRConfig
 12from .model import LightningIRModel, LightningIROutput
 13from .tokenizer import LightningIRTokenizer
 14from .validation_utils import create_qrels_from_dicts, create_run_from_scores, evaluate_run
 15
 16
class LightningIRModule(LightningModule):
    """LightningIRModule base class. LightningIRModules contain a LightningIRModel and a LightningIRTokenizer and
    implement the training, validation, and testing steps for the model. Derived classes must implement the
    :meth:`forward` and :meth:`compute_losses` methods for the model.
    """

    def __init__(
        self,
        model_name_or_path: str | None = None,
        config: LightningIRConfig | None = None,
        model: LightningIRModel | None = None,
        loss_functions: Sequence[LossFunction | Tuple[LossFunction, float]] | None = None,
        evaluation_metrics: Sequence[str] | None = None,
    ):
        """Initializes the LightningIRModule.

        .. _ir-measures: https://ir-measur.es/en/latest/index.html

        :param model_name_or_path: Name or path of backbone model or fine-tuned LightningIR model, defaults to None
        :type model_name_or_path: str | None, optional
        :param config: LightningIRConfig to apply when loading from backbone model; only used together with
            model_name_or_path, defaults to None
        :type config: LightningIRConfig | None, optional
        :param model: Already instantiated LightningIR model, defaults to None
        :type model: LightningIRModel | None, optional
        :param loss_functions: Loss functions to apply during fine-tuning, optional loss weights can be provided per
            loss function (loss functions without an explicit weight get a weight of 1.0), defaults to None
        :type loss_functions: Sequence[LossFunction | Tuple[LossFunction, float]] | None, optional
        :param evaluation_metrics: Metrics corresponding to ir-measures_ measure strings to apply during validation or
            testing; the special value ``"loss"`` additionally reports the validation loss, defaults to None
        :type evaluation_metrics: Sequence[str] | None, optional
        :raises ValueError: If both model and model_name_or_path are provided
        :raises ValueError: If neither model nor model_name_or_path are provided
        """
        super().__init__()
        if model is not None and model_name_or_path is not None:
            raise ValueError("Only one of model or model_name_or_path must be provided.")
        if model is None:
            if model_name_or_path is None:
                raise ValueError("Either model or model_name_or_path must be provided.")
            model = LightningIRModel.from_pretrained(model_name_or_path, config=config)

        self.model: LightningIRModel = model
        self.config = self.model.config
        self.loss_functions: List[Tuple[LossFunction, float]] | None = None
        if loss_functions is not None:
            # Normalize every entry to a (loss_function, weight) pair; bare loss functions get weight 1.0.
            self.loss_functions = [
                (loss_function, 1.0) if isinstance(loss_function, LossFunction) else loss_function
                for loss_function in loss_functions
            ]
        self.evaluation_metrics = evaluation_metrics
        # The tokenizer is loaded from the same name/path as the model so both stay in sync.
        self.tokenizer = LightningIRTokenizer.from_pretrained(self.config.name_or_path, config=self.config)

    def on_fit_start(self) -> None:
        """Called at the very beginning of fit.

        If on DDP it is called on every process
        """
        # NOTE huggingface models are in eval mode by default
        self.train()
        return super().on_fit_start()

    def score(self, queries: Sequence[str] | str, docs: Sequence[Sequence[str]] | Sequence[str]) -> LightningIROutput:
        """Computes relevance scores for queries and documents without tracking gradients.

        :param queries: Queries to score; a single query string is treated as a batch of one query
        :type queries: Sequence[str] | str
        :param docs: Documents to score, one sequence of documents per query; a flat sequence of document strings is
            treated as the documents of a single query
        :type docs: Sequence[Sequence[str]] | Sequence[str]
        :return: Model output
        :rtype: LightningIROutput
        """
        if isinstance(queries, str):
            queries = (queries,)
        if isinstance(docs[0], str):
            docs = (docs,)
        batch = RankBatch(queries, docs, None, None)
        with torch.no_grad():
            return self.forward(batch)

    def forward(self, batch: TrainBatch | RankBatch | SearchBatch) -> LightningIROutput:
        """Handles the forward pass of the model.

        :param batch: Batch of training, ranking, or search data
        :type batch: TrainBatch | RankBatch | SearchBatch
        :raises NotImplementedError: Must be implemented by derived class
        :return: Model output
        :rtype: LightningIROutput
        """
        raise NotImplementedError

    def prepare_input(
        self, queries: Sequence[str] | None, docs: Sequence[str] | None, num_docs: Sequence[int] | int | None
    ) -> Dict[str, BatchEncoding]:
        """Tokenizes queries and documents and returns the tokenized BatchEncoding_.

        .. _BatchEncoding: https://huggingface.co/transformers/main_classes/tokenizer#transformers.BatchEncoding

        :param queries: Queries to tokenize
        :type queries: Sequence[str] | None
        :param docs: Documents to tokenize
        :type docs: Sequence[str] | None
        :param num_docs: Number of documents per query, if None num_docs is inferred by `len(docs) // len(queries)`,
            defaults to None
        :type num_docs: Sequence[int] | int | None
        :return: Tokenized queries and documents moved to the module's device, format depends on the tokenizer
        :rtype: Dict[str, BatchEncoding]
        """
        encodings = self.tokenizer.tokenize(
            queries, docs, return_tensors="pt", padding=True, truncation=True, num_docs=num_docs
        )
        # Replace values in place (no keys are added or removed) so the tokenizer's mapping type is preserved.
        for key in encodings:
            encodings[key] = encodings[key].to(self.device)
        return encodings

    def compute_losses(self, batch: TrainBatch) -> List[torch.Tensor]:
        """Computes the losses for the batch.

        :param batch: Batch of training data
        :type batch: TrainBatch
        :raises NotImplementedError: Must be implemented by derived class
        :return: List of losses, one for each loss function
        :rtype: List[torch.Tensor]
        """
        raise NotImplementedError

    def training_step(self, batch: TrainBatch, batch_idx: int) -> torch.Tensor:
        """Handles the training step for the model.

        Each individual loss is logged under its loss function's class name; the weighted sum is logged as ``loss``.

        :param batch: Batch of training data
        :type batch: TrainBatch
        :param batch_idx: Index of the batch
        :type batch_idx: int
        :raises ValueError: If no loss functions are set
        :raises ValueError: If compute_losses does not return exactly one loss per loss function
        :return: Sum of the losses weighted by the loss weights
        :rtype: torch.Tensor
        """
        # An empty list is rejected too: it would otherwise yield a constant loss that cannot be backpropagated.
        if not self.loss_functions:
            raise ValueError("Loss functions are not set")
        losses = self.compute_losses(batch)
        # Explicit check instead of `assert`, which is stripped when Python runs with -O.
        if len(losses) != len(self.loss_functions):
            raise ValueError(
                f"Expected {len(self.loss_functions)} losses (one per loss function), got {len(losses)}"
            )
        weighted_losses = []
        for (loss_function, loss_weight), loss in zip(self.loss_functions, losses):
            self.log(loss_function.__class__.__name__, loss)
            weighted_losses.append(loss * loss_weight)
        # Python's sum starts from the int 0, so the result keeps the losses' dtype and device.
        total_loss = sum(weighted_losses)
        self.log("loss", total_loss, prog_bar=True)
        return total_loss

    def validation_step(
        self, batch: TrainBatch | RankBatch | SearchBatch, batch_idx: int, dataloader_idx: int = 0
    ) -> LightningIROutput:
        """Handles the validation step for the model.

        When evaluation metrics are set, each metric is logged under ``"{dataset_id}/{metric}"``.

        :param batch: Batch of validation or testing data
        :type batch: TrainBatch | RankBatch | SearchBatch
        :param batch_idx: Index of the batch
        :type batch_idx: int
        :param dataloader_idx: Index of the dataloader, defaults to 0
        :type dataloader_idx: int, optional
        :return: Model output
        :rtype: LightningIROutput
        """
        output = self.forward(batch)

        if self.evaluation_metrics is None:
            return output

        dataset_id = self.get_dataset_id(dataloader_idx)
        metrics = self.validate(
            scores=output.scores,
            query_ids=batch.query_ids,
            doc_ids=batch.doc_ids,
            qrels=batch.qrels,
            targets=getattr(batch, "targets", None),
        )
        for key, value in metrics.items():
            self.log(f"{dataset_id}/{key}", value, batch_size=len(batch.queries))
        return output

    def test_step(
        self,
        batch: TrainBatch | RankBatch | SearchBatch,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> LightningIROutput:
        """Handles the testing step for the model. Passes the batch to the validation step.

        :param batch: Batch of testing data
        :type batch: TrainBatch | RankBatch | SearchBatch
        :param batch_idx: Index of the batch
        :type batch_idx: int
        :param dataloader_idx: Index of the dataloader, defaults to 0
        :type dataloader_idx: int, optional
        :return: Model output
        :rtype: LightningIROutput
        """
        return self.validation_step(batch, batch_idx, dataloader_idx)

    def get_dataset_id(self, dataloader_idx: int) -> str:
        """Gets the dataset id from the dataloader index for logging.

        This is best effort: if the module is not attached to a trainer, the trainer has no datamodule, or the
        datamodule does not expose a matching inference dataset, the dataloader index is used instead.

        .. _ir-datasets: https://ir-datasets.com/

        :param dataloader_idx: Index of the dataloader
        :type dataloader_idx: int
        :return: ir-datasets_ dataset id or dataloader index
        :rtype: str
        """
        try:
            # self.trainer raises RuntimeError when the module is not attached to a trainer.
            datamodule = getattr(self.trainer, "datamodule", None)
            return datamodule.inference_datasets[dataloader_idx].dataset_id
        except (RuntimeError, AttributeError, IndexError, KeyError, TypeError):
            return str(dataloader_idx)

    def validate(
        self,
        scores: torch.Tensor | None = None,
        query_ids: Sequence[str] | None = None,
        doc_ids: Sequence[Sequence[str]] | None = None,
        qrels: Sequence[Dict[str, int]] | None = None,
        targets: torch.Tensor | None = None,
        num_docs: Sequence[int] | None = None,
    ) -> Dict[str, float]:
        """Validates the model output with the evaluation metrics and loss functions.

        Missing query or document ids are synthesized from num_docs (query ``"i"``, documents ``"i-j"``).

        :param scores: Model output scores, defaults to None
        :type scores: torch.Tensor | None, optional
        :param query_ids: ids of the queries, defaults to None
        :type query_ids: Sequence[str] | None, optional
        :param doc_ids: ids of the documents, defaults to None
        :type doc_ids: Sequence[Sequence[str]] | None, optional
        :param qrels: Mappings of doc_id -> relevance for each query, defaults to None
        :type qrels: Sequence[Dict[str, int]] | None, optional
        :param targets: Target tensor used during fine-tuning, defaults to None
        :type targets: torch.Tensor | None, optional
        :param num_docs: Number of documents per query, defaults to None
        :type num_docs: Sequence[int] | None, optional
        :raises ValueError: If num_docs is not set and query_ids are not set
        :raises ValueError: If num_docs is not set and doc_ids are not set
        :return: Evaluation metric values and validation losses keyed by metric name; empty if no evaluation
            metrics are set or no scores are given
        :rtype: Dict[str, float]
        """
        metrics: Dict[str, float] = {}
        if self.evaluation_metrics is None or scores is None:
            return metrics
        if query_ids is None:
            if num_docs is None:
                raise ValueError("num_docs must be set if query_ids is not set")
            query_ids = tuple(str(i) for i in range(len(num_docs)))
        if doc_ids is None:
            if num_docs is None:
                raise ValueError("num_docs must be set if doc_ids is not set")
            doc_ids = tuple(tuple(f"{i}-{j}" for j in range(docs)) for i, docs in enumerate(num_docs))
        metrics.update(self.validate_metrics(scores, query_ids, doc_ids, qrels))
        metrics.update(self.validate_loss(scores, query_ids, targets))
        return metrics

    def validate_metrics(
        self,
        scores: torch.Tensor,
        query_ids: Sequence[str],
        doc_ids: Sequence[Sequence[str]],
        qrels: Sequence[Dict[str, int]] | None,
    ) -> Dict[str, float]:
        """Validates the model output with the evaluation metrics.

        :param scores: Model output scores
        :type scores: torch.Tensor
        :param query_ids: ids of the queries
        :type query_ids: Sequence[str]
        :param doc_ids: ids of the documents
        :type doc_ids: Sequence[Sequence[str]]
        :param qrels: Mappings of doc_id -> relevance for each query, defaults to None
        :type qrels: Sequence[Dict[str, int]] | None
        :return: Evaluation metrics
        :rtype: Dict[str, float]
        """
        metrics: Dict[str, float] = {}
        if self.evaluation_metrics is None or qrels is None:
            return metrics
        # "loss" is not an ir-measures metric; it is handled by validate_loss.
        evaluation_metrics = [metric for metric in self.evaluation_metrics if metric != "loss"]
        if not evaluation_metrics:
            return metrics
        ir_measures_qrels = create_qrels_from_dicts(qrels)
        run = create_run_from_scores(query_ids, doc_ids, scores)
        metrics.update(evaluate_run(run, ir_measures_qrels, evaluation_metrics))
        return metrics

    def validate_loss(
        self, scores: torch.Tensor, query_ids: Sequence[str], targets: torch.Tensor | None
    ) -> Dict[str, float]:
        """Validates the model output with the loss functions.

        Only computed when ``"loss"`` is one of the evaluation metrics and both targets and loss functions are set.
        Each loss is reported as ``"validation-{LossFunctionClassName}"``.

        :param scores: Model output scores
        :type scores: torch.Tensor
        :param query_ids: ids of the queries
        :type query_ids: Sequence[str]
        :param targets: Target tensor used during fine-tuning
        :type targets: torch.Tensor | None
        :return: Loss metrics
        :rtype: Dict[str, float]
        """
        metrics: Dict[str, float] = {}
        if (
            self.evaluation_metrics is None
            or "loss" not in self.evaluation_metrics
            or targets is None
            or self.loss_functions is None
        ):
            return metrics
        # One row per query; reshape (unlike view) also works for non-contiguous score tensors.
        scores = scores.reshape(len(query_ids), -1)
        for loss_function, _ in self.loss_functions:
            # NOTE skip in-batch losses because they can use a lot of memory
            if isinstance(loss_function, InBatchLossFunction):
                continue
            metrics[f"validation-{loss_function.__class__.__name__}"] = loss_function.compute_loss(
                scores, targets
            ).item()
        return metrics

    def on_validation_epoch_end(self) -> None:
        """Logs the accumulated metrics for each dataloader.

        For every callback metric whose last ``/``-separated component contains ``dataloader_idx`` (the suffix
        Lightning adds when several dataloaders are used), the values are grouped by the preceding component (the
        metric name) and their mean is logged under that name, not sent to the logger.
        """
        try:
            trainer = self.trainer
        except RuntimeError:
            # Not attached to a trainer: nothing to aggregate.
            return
        if trainer is None:
            return
        accum_metrics: Dict[str, List[torch.Tensor]] = defaultdict(list)
        for key, value in trainer.callback_metrics.items():
            split = key.split("/")
            # Guard against keys without a metric component before the dataloader suffix.
            if len(split) >= 2 and "dataloader_idx" in split[-1]:
                accum_metrics[split[-2]].append(value)
        for key, values in accum_metrics.items():
            self.log(key, torch.stack(values).mean(), logger=False)

    def on_test_epoch_end(self) -> None:
        """Logs the accumulated metrics for each dataloader."""
        self.on_validation_epoch_end()

    def save_pretrained(self, save_path: str | Path) -> None:
        """Saves the model and tokenizer to the save path.

        :param save_path: Path to save the model and tokenizer
        :type save_path: str | Path
        """
        self.model.save_pretrained(save_path)
        self.tokenizer.save_pretrained(save_path)

    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        """Saves the model and tokenizer to ``<log_dir>/huggingface_checkpoint`` (global rank 0 only).

        The current global step is stored on the config as ``save_step`` before saving.

        :param checkpoint: Lightning checkpoint dictionary (unused)
        :type checkpoint: Dict[str, Any]
        """
        try:
            trainer = self.trainer
        except RuntimeError:
            # Not attached to a trainer: there is no log directory to save into.
            return
        # NOTE(review): log_dir is evaluated on every rank before the rank check, matching the original ordering;
        # Lightning's log_dir may involve a cross-rank broadcast, so keep this order.
        if trainer is None or trainer.log_dir is None:
            return
        if trainer.global_rank != 0:
            return
        self.config.save_step = trainer.global_step
        save_path = Path(trainer.log_dir) / "huggingface_checkpoint"
        self.save_pretrained(save_path)