1from collections import defaultdict
2from pathlib import Path
3from typing import Any, Dict, List, Sequence, Tuple
4
5import torch
6from lightning import LightningModule
7from transformers import BatchEncoding
8
9from ..data import RankBatch, SearchBatch, TrainBatch
10from ..loss.loss import InBatchLossFunction, LossFunction
11from .config import LightningIRConfig
12from .model import LightningIRModel, LightningIROutput
13from .tokenizer import LightningIRTokenizer
14from .validation_utils import create_qrels_from_dicts, create_run_from_scores, evaluate_run
15
16
[docs]
17class LightningIRModule(LightningModule):
18 """LightningIRModule base class. LightningIRModules contain a LightningIRModel and a LightningIRTokenizer and
19 implements the training, validation, and testing steps for the model. Derived classes must implement the forward
20 method for the model.
21 """
22
[docs]
def __init__(
    self,
    model_name_or_path: str | None = None,
    config: LightningIRConfig | None = None,
    model: LightningIRModel | None = None,
    loss_functions: Sequence[LossFunction | Tuple[LossFunction, float]] | None = None,
    evaluation_metrics: Sequence[str] | None = None,
):
    """Builds the module around either an instantiated model or a model name / path.

    Exactly one of ``model`` and ``model_name_or_path`` has to be given. When only a name or path is given, the
    model is loaded with ``LightningIRModel.from_pretrained``. The tokenizer is always loaded from the name or
    path stored in the model's config.

    .. _ir-measures: https://ir-measur.es/en/latest/index.html

    :param model_name_or_path: Name or path of backbone model or fine-tuned LightningIR model, defaults to None
    :type model_name_or_path: str | None, optional
    :param config: LightningIRConfig forwarded to ``from_pretrained`` when the model is loaded by name or path,
        defaults to None
    :type config: LightningIRConfig | None, optional
    :param model: Already instantiated LightningIR model, defaults to None
    :type model: LightningIRModel | None, optional
    :param loss_functions: Loss functions to apply during fine-tuning, either bare (weight 1.0) or as
        ``(loss_function, weight)`` pairs, defaults to None
    :type loss_functions: Sequence[LossFunction | Tuple[LossFunction, float]] | None, optional
    :param evaluation_metrics: Metrics corresponding to ir-measures_ measure strings to apply during validation or
        testing, defaults to None
    :type evaluation_metrics: Sequence[str] | None, optional
    :raises ValueError: If both model and model_name_or_path are provided
    :raises ValueError: If neither model nor model_name_or_path are provided
    """
    super().__init__()
    if model is not None and model_name_or_path is not None:
        raise ValueError("Only one of model or model_name_or_path must be provided.")
    if model is None and model_name_or_path is None:
        raise ValueError("Either model or model_name_or_path must be provided.")
    if model is None:
        model = LightningIRModel.from_pretrained(model_name_or_path, config=config)

    self.model: LightningIRModel = model
    self.config = self.model.config

    # Normalise to (loss_function, weight) pairs: bare loss functions get weight 1.0,
    # anything else is stored exactly as it was passed in.
    weighted_losses: List[Tuple[LossFunction, float]] | None = None
    if loss_functions is not None:
        weighted_losses = [
            (entry, 1.0) if isinstance(entry, LossFunction) else entry for entry in loss_functions
        ]
    self.loss_functions: List[Tuple[LossFunction, float]] | None = weighted_losses

    self.evaluation_metrics = evaluation_metrics
    self.tokenizer = LightningIRTokenizer.from_pretrained(self.config.name_or_path, config=self.config)
70
[docs]
def on_fit_start(self) -> None:
    """Hook invoked at the very beginning of fit (on every process when running DDP).

    Switches the module to training mode and then delegates to the parent class hook.
    """
    # NOTE huggingface models are in eval mode by default
    self.train()
    parent_result = super().on_fit_start()
    return parent_result
79
[docs]
def score(self, queries: Sequence[str] | str, docs: Sequence[Sequence[str]] | Sequence[str]) -> LightningIROutput:
    """Computes relevance scores for queries and documents with gradient tracking disabled.

    A single query string is wrapped into a one-element tuple. If the first element of ``docs`` is a string,
    ``docs`` is treated as the documents of a single query and wrapped into a one-element tuple as well.

    :param queries: Query or queries to score
    :type queries: Sequence[str] | str
    :param docs: Documents to score, either one sequence of documents per query or a flat sequence of documents
    :type docs: Sequence[Sequence[str]] | Sequence[str]
    :return: Model output
    :rtype: LightningIROutput
    """
    query_batch = (queries,) if isinstance(queries, str) else queries
    doc_batch = (docs,) if isinstance(docs[0], str) else docs
    rank_batch = RankBatch(query_batch, doc_batch, None, None)
    with torch.no_grad():
        output = self.forward(rank_batch)
    return output
97
[docs]
def forward(self, batch: TrainBatch | RankBatch | SearchBatch) -> LightningIROutput:
    """Runs the forward pass of the model; this base implementation is a placeholder.

    :param batch: Batch of training, ranking, or search data
    :type batch: TrainBatch | RankBatch | SearchBatch
    :raises NotImplementedError: Always; must be implemented by derived class
    :return: Model output
    :rtype: LightningIROutput
    """
    raise NotImplementedError()
108
132
[docs]
def compute_losses(self, batch: TrainBatch) -> List[torch.Tensor]:
    """Computes one loss per configured loss function; this base implementation is a placeholder.

    :param batch: Batch of training data
    :type batch: TrainBatch
    :raises NotImplementedError: Always; must be implemented by derived class
    :return: List of losses, one for each loss function
    :rtype: List[torch.Tensor]
    """
    raise NotImplementedError()
143
[docs]
def training_step(self, batch: TrainBatch, batch_idx: int) -> torch.Tensor:
    """Handles the training step for the model.

    Every individual loss is logged under the class name of its loss function and the weighted sum is logged
    as ``loss`` (shown in the progress bar).

    :param batch: Batch of training data
    :type batch: TrainBatch
    :param batch_idx: Index of the batch
    :type batch_idx: int
    :raises ValueError: If no loss functions are set (``None`` or empty)
    :raises ValueError: If ``compute_losses`` does not return exactly one loss per loss function
    :return: Sum of the losses weighted by the loss weights
    :rtype: torch.Tensor
    """
    # An empty list would yield a constant zero tensor without a gradient, so reject it like None.
    if not self.loss_functions:
        raise ValueError("Loss functions are not set")
    losses = self.compute_losses(batch)
    # Explicit check instead of `assert`: asserts are stripped under -O and zip would then silently truncate.
    if len(losses) != len(self.loss_functions):
        raise ValueError(
            f"Expected {len(self.loss_functions)} losses (one per loss function), got {len(losses)}"
        )
    total_loss = torch.tensor(0)
    for (loss_function, loss_weight), loss in zip(self.loss_functions, losses):
        self.log(loss_function.__class__.__name__, loss)
        total_loss = total_loss + loss * loss_weight
    self.log("loss", total_loss, prog_bar=True)
    return total_loss
165
[docs]
def validation_step(
    self, batch: TrainBatch | RankBatch | SearchBatch, batch_idx: int, dataloader_idx: int = 0
) -> LightningIROutput:
    """Runs the forward pass on a validation batch and logs the resulting metrics.

    When no evaluation metrics are configured the model output is returned without any logging. Otherwise each
    metric returned by ``validate`` is logged as ``<dataset_id>/<metric>`` with the number of queries in the
    batch as batch size.

    :param batch: Batch of validation or testing data
    :type batch: TrainBatch | RankBatch | SearchBatch
    :param batch_idx: Index of the batch
    :type batch_idx: int
    :param dataloader_idx: Index of the dataloader, defaults to 0
    :type dataloader_idx: int, optional
    :return: Model output
    :rtype: LightningIROutput
    """
    output = self.forward(batch)

    if self.evaluation_metrics is not None:
        dataset_id = self.get_dataset_id(dataloader_idx)
        metrics = self.validate(
            scores=output.scores,
            query_ids=batch.query_ids,
            doc_ids=batch.doc_ids,
            qrels=batch.qrels,
            # not every batch type is required to carry targets
            targets=getattr(batch, "targets", None),
        )
        for metric_name, metric_value in metrics.items():
            self.log(f"{dataset_id}/{metric_name}", metric_value, batch_size=len(batch.queries))

    return output
197
[docs]
198 def test_step(
199 self,
200 batch: TrainBatch | RankBatch,
201 batch_idx: int,
202 dataloader_idx: int = 0,
203 ) -> LightningIROutput:
204 """Handles the testing step for the model. Passes the batch to the validation step.
205
206 :param batch: Batch of testing data
207 :type batch: TrainBatch | RankBatch
208 :param batch_idx: Index of the batch
209 :type batch_idx: int
210 :param dataloader_idx: Index of the dataloader, defaults to 0
211 :type dataloader_idx: int, optional
212 :return: Model output
213 :rtype: LightningIROutput
214 """
215 return self.validation_step(batch, batch_idx, dataloader_idx)
216
[docs]
def get_dataset_id(self, dataloader_idx: int) -> str:
    """Gets the dataset id from the dataloader index for logging.

    The id is read from ``trainer.datamodule.inference_datasets[dataloader_idx].dataset_id``. If any step of
    that lookup is unavailable (no trainer, no datamodule, missing attribute, index out of range), the
    stringified dataloader index is returned instead. Errors unrelated to the lookup are not suppressed.

    .. _ir-datasets: https://ir-datasets.com/

    :param dataloader_idx: Index of the dataloader
    :type dataloader_idx: int
    :return: ir-datasets_ dataset id or dataloader index
    :rtype: str
    """
    fallback = str(dataloader_idx)
    try:
        datamodule = getattr(self.trainer, "datamodule", None)
        return datamodule.inference_datasets[dataloader_idx].dataset_id
    except (RuntimeError, AttributeError, IndexError, KeyError, TypeError):
        # RuntimeError: accessing `self.trainer` may raise when no trainer is attached.
        # AttributeError / TypeError: datamodule or inference_datasets is missing or None.
        # IndexError / KeyError: no inference dataset registered for this dataloader index.
        return fallback
235
[docs]
236 def validate(
237 self,
238 scores: torch.Tensor | None = None,
239 query_ids: Sequence[str] | None = None,
240 doc_ids: Sequence[Sequence[str]] | None = None,
241 qrels: Sequence[Dict[str, int]] | None = None,
242 targets: torch.Tensor | None = None,
243 num_docs: Sequence[int] | None = None,
244 ) -> Dict[str, float]:
245 """Validates the model output with the evaluation metrics and loss functions.
246
247 :param scores: Model output scores, defaults to None
248 :type scores: torch.Tensor | None, optional
249 :param query_ids: ids of the queries, defaults to None
250 :type query_ids: Sequence[str] | None, optional
251 :param doc_ids: ids of the documents, defaults to None
252 :type doc_ids: Sequence[Sequence[str]] | None, optional
253 :param qrels: Mappings of doc_id -> relevance for each query, defaults to None
254 :type qrels: Sequence[Dict[str, int]] | None, optional
255 :param targets: Target tensor used during fine-tuning, defaults to None
256 :type targets: torch.Tensor | None, optional
257 :param num_docs: Number of documents per query, defaults to None
258 :type num_docs: Sequence[int] | None, optional
259 :raises ValueError: If num_docs can not be parsed and query_ids are not set
260 :raises ValueError: If num_docs can not be parsed and doc_ids are not set
261 :return: _description_
262 :rtype: Dict[str, float]
263 """
264 metrics: Dict[str, float] = {}
265 if self.evaluation_metrics is None or scores is None:
266 return metrics
267 if query_ids is None:
268 if num_docs is None:
269 raise ValueError("num_docs must be set if query_ids is not set")
270 query_ids = tuple(str(i) for i in range(len(num_docs)))
271 if doc_ids is None:
272 if num_docs is None:
273 raise ValueError("num_docs must be set if doc_ids is not set")
274 doc_ids = tuple(tuple(f"{i}-{j}" for j in range(docs)) for i, docs in enumerate(num_docs))
275 metrics.update(self.validate_metrics(scores, query_ids, doc_ids, qrels))
276 metrics.update(self.validate_loss(scores, query_ids, targets))
277 return metrics
278
[docs]
279 def validate_metrics(
280 self,
281 scores: torch.Tensor,
282 query_ids: Sequence[str],
283 doc_ids: Sequence[Sequence[str]],
284 qrels: Sequence[Dict[str, int]] | None,
285 ) -> Dict[str, float]:
286 """Validates the model output with the evaluation metrics.
287
288 :param scores: Model output scores
289 :type scores: torch.Tensor
290 :param query_ids: ids of the queries
291 :type query_ids: Sequence[str]
292 :param doc_ids: ids of the documents
293 :type doc_ids: Sequence[Sequence[str]]
294 :param qrels: Mappings of doc_id -> relevance for each query, defaults to None
295 :type qrels: Sequence[Dict[str, int]] | None
296 :return: Evaluation metrics
297 :rtype: Dict[str, float]
298 """
299 metrics: Dict[str, float] = {}
300 if self.evaluation_metrics is None or qrels is None:
301 return metrics
302 evaluation_metrics = [metric for metric in self.evaluation_metrics if metric != "loss"]
303 ir_measures_qrels = create_qrels_from_dicts(qrels)
304 if evaluation_metrics and qrels is not None:
305 run = create_run_from_scores(query_ids, doc_ids, scores)
306 metrics.update(evaluate_run(run, ir_measures_qrels, evaluation_metrics))
307 return metrics
308
[docs]
309 def validate_loss(
310 self, scores: torch.Tensor, query_ids: Sequence[str], targets: torch.Tensor | None
311 ) -> Dict[str, float]:
312 """Validates the model output with the loss functions.
313
314 :param scores: Model output scores
315 :type scores: torch.Tensor
316 :param query_ids: ids of the queries
317 :type query_ids: Sequence[str]
318 :param targets: Target tensor used during fine-tuning
319 :type targets: torch.Tensor | None
320 :return: Loss metrics
321 :rtype: Dict[str, float]
322 """
323 metrics: Dict[str, float] = {}
324 if (
325 self.evaluation_metrics is None
326 or "loss" not in self.evaluation_metrics
327 or targets is None
328 or self.loss_functions is None
329 ):
330 return metrics
331 scores = scores.view(len(query_ids), -1)
332 for loss_function, _ in self.loss_functions:
333 # NOTE skip in-batch losses because they can use a lot of memory
334 if isinstance(loss_function, InBatchLossFunction):
335 continue
336 metrics[f"validation-{loss_function.__class__.__name__}"] = loss_function.compute_loss(
337 scores, targets
338 ).item()
339 return metrics
340
[docs]
def on_validation_epoch_end(self) -> None:
    """Logs, for each metric, the mean over all dataloaders.

    Callback metrics whose last ``/``-separated key segment contains ``dataloader_idx`` are grouped by the
    segment before it; the mean of each group is logged under that segment (not sent to the logger). Nothing
    happens when no trainer is available.
    """
    try:
        trainer = self.trainer
    except RuntimeError:
        return
    if trainer is None:
        return

    per_metric = defaultdict(list)
    for metric_key, metric_value in trainer.callback_metrics.items():
        segments = metric_key.split("/")
        if "dataloader_idx" not in segments[-1]:
            continue
        per_metric[segments[-2]].append(metric_value)

    for metric_name, values in per_metric.items():
        self.log(metric_name, torch.stack(values).mean(), logger=False)
356
[docs]
def on_test_epoch_end(self) -> None:
    """Logs the accumulated metrics for each dataloader by reusing the validation epoch-end hook."""
    self.on_validation_epoch_end()
    return None
360
[docs]
361 def save_pretrained(self, save_path: str | Path) -> None:
362 """Saves the model and tokenizer to the save path.
363
364 :param save_path: Path to save the model and tokenizer
365 :type save_path: str | Path
366 """
367 self.model.save_pretrained(save_path)
368 self.tokenizer.save_pretrained(save_path)
369
[docs]
def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
    """Saves the model and tokenizer to the trainer's log directory.

    Only the process with global rank 0 writes; it stores the current global step in ``config.save_step`` and
    saves to ``<log_dir>/huggingface_checkpoint``. Nothing is saved when there is no trainer or no log
    directory.

    :param checkpoint: Checkpoint dictionary passed to the hook (not modified here)
    :type checkpoint: Dict[str, Any]
    """
    trainer = self.trainer
    if trainer is None:
        return
    # Read log_dir exactly once, and on every rank before the rank check.
    # NOTE(review): Lightning's Trainer.log_dir appears to go through the strategy's broadcast, so a second
    # read on rank 0 only would be an unmatched collective call - confirm against the installed version.
    log_dir = trainer.log_dir
    if log_dir is None or trainer.global_rank != 0:
        return
    self.config.save_step = trainer.global_step
    self.save_pretrained(Path(log_dir) / "huggingface_checkpoint")