Source code for lightning_ir.cross_encoder.module
1from typing import List, Sequence, Tuple
2
3import torch
4
5from ..base.module import LightningIRModule
6from ..data import RankBatch, SearchBatch, TrainBatch
7from ..loss.loss import InBatchLossFunction, LossFunction
8from .config import CrossEncoderConfig
9from .model import CrossEncoderModel, CrossEncoderOutput
10from .tokenizer import CrossEncoderTokenizer
11
12
[docs]
class CrossEncoderModule(LightningIRModule):
    """LightningIRModule subclass that drives a CrossEncoderModel.

    Provides the forward pass over rank/train batches and the loss
    computation for training batches. Search batches are rejected.
    """

    def __init__(
        self,
        model_name_or_path: str | None = None,
        config: CrossEncoderConfig | None = None,
        model: CrossEncoderModel | None = None,
        loss_functions: Sequence[LossFunction | Tuple[LossFunction, float]] | None = None,
        evaluation_metrics: Sequence[str] | None = None,
    ):
        """Pass every argument through to ``LightningIRModule.__init__`` unchanged.

        The bare attribute annotations after the ``super().__init__`` call
        assign nothing; they only declare the cross-encoder specific types
        of ``model``, ``config`` and ``tokenizer`` for static type checkers.
        """
        super().__init__(model_name_or_path, config, model, loss_functions, evaluation_metrics)
        self.model: CrossEncoderModel
        self.config: CrossEncoderConfig
        self.tokenizer: CrossEncoderTokenizer

    def forward(self, batch: RankBatch | TrainBatch | SearchBatch) -> CrossEncoderOutput:
        """Flatten the batch's documents, prepare the model input and run the model.

        Args:
            batch: Batch providing ``queries`` and ``docs`` (one collection
                of documents per entry of ``docs``).

        Returns:
            The output of ``self.model.forward`` for the prepared encoding.

        Raises:
            NotImplementedError: If ``batch`` is a ``SearchBatch``.
        """
        if isinstance(batch, SearchBatch):
            raise NotImplementedError("Searching is not available for cross-encoders")
        queries = batch.queries
        # Concatenate the per-entry document groups into a single flat list ...
        flat_docs = []
        for doc_group in batch.docs:
            flat_docs.extend(doc_group)
        # ... and remember each group's size alongside it.
        group_sizes = list(map(len, batch.docs))
        prepared = self.prepare_input(queries, flat_docs, group_sizes)
        return self.model.forward(prepared["encoding"])

    def compute_losses(self, batch: TrainBatch) -> List[torch.Tensor]:
        """Compute one loss per configured loss function for a training batch.

        Args:
            batch: Training batch providing ``query_ids`` and ``targets``.

        Returns:
            The values returned by each loss function's ``compute_loss``,
            in the order of ``self.loss_functions``.

        Raises:
            ValueError: If ``self.loss_functions`` is ``None``, or if the
                model output has no scores or the batch has no targets.
            NotImplementedError: If any configured loss function is an
                ``InBatchLossFunction``.
        """
        if self.loss_functions is None:
            raise ValueError("loss_functions must be set in the module")
        flat_scores = self.forward(batch).scores
        if flat_scores is None or batch.targets is None:
            raise ValueError("scores and targets must be set in the output and batch")

        # One row of scores per query id; the remaining dimension is inferred.
        score_matrix = flat_scores.view(len(batch.query_ids), -1)
        # Targets take the same two leading dimensions plus an inferred trailing one.
        num_rows, num_cols = score_matrix.shape
        target_tensor = batch.targets.view(num_rows, num_cols, -1)

        collected: List[torch.Tensor] = []
        for loss_fn, _weight in self.loss_functions:
            if isinstance(loss_fn, InBatchLossFunction):
                raise NotImplementedError("InBatchLossFunction not implemented for cross-encoders")
            collected.append(loss_fn.compute_loss(score_matrix, target_tensor))
        return collected