Source code for lightning_ir.cross_encoder.module

 1from typing import List, Sequence, Tuple
 2
 3import torch
 4
 5from ..base.module import LightningIRModule
 6from ..data import RankBatch, SearchBatch, TrainBatch
 7from ..loss.loss import InBatchLossFunction, LossFunction
 8from .config import CrossEncoderConfig
 9from .model import CrossEncoderModel, CrossEncoderOutput
10from .tokenizer import CrossEncoderTokenizer
11
12
class CrossEncoderModule(LightningIRModule):
    """Lightning IR module wrapping a cross-encoder model.

    Builds on ``LightningIRModule`` and supplies the cross-encoder specific
    forward pass and loss computation.
    """

    def __init__(
        self,
        model_name_or_path: str | None = None,
        config: CrossEncoderConfig | None = None,
        model: CrossEncoderModel | None = None,
        loss_functions: Sequence[LossFunction | Tuple[LossFunction, float]] | None = None,
        evaluation_metrics: Sequence[str] | None = None,
    ):
        """Initialize the module by delegating to ``LightningIRModule``.

        All arguments are passed to the parent constructor unchanged and in
        the same order.

        Args:
            model_name_or_path: Forwarded to the parent constructor.
            config: Forwarded to the parent constructor.
            model: Forwarded to the parent constructor.
            loss_functions: Forwarded to the parent constructor.
            evaluation_metrics: Forwarded to the parent constructor.
        """
        super().__init__(
            model_name_or_path,
            config,
            model,
            loss_functions,
            evaluation_metrics,
        )
        # Annotation-only declarations: nothing is assigned here, they only
        # narrow the attribute types for static type checkers.
        self.model: CrossEncoderModel
        self.config: CrossEncoderConfig
        self.tokenizer: CrossEncoderTokenizer

    def forward(self, batch: RankBatch | TrainBatch | SearchBatch) -> CrossEncoderOutput:
        """Run the cross-encoder on the queries and documents of ``batch``.

        Args:
            batch: Batch providing ``queries`` and per-query ``docs``.

        Returns:
            The output of ``self.model.forward``.

        Raises:
            NotImplementedError: If ``batch`` is a ``SearchBatch``.
        """
        if isinstance(batch, SearchBatch):
            raise NotImplementedError("Searching is not available for cross-encoders")

        queries = batch.queries
        # Flatten the per-query document lists into one list, and keep the
        # per-query counts, which are handed to ``prepare_input`` alongside.
        flat_docs = [doc for query_docs in batch.docs for doc in query_docs]
        docs_per_query = [len(query_docs) for query_docs in batch.docs]

        prepared = self.prepare_input(queries, flat_docs, docs_per_query)
        return self.model.forward(prepared["encoding"])

    def compute_losses(self, batch: TrainBatch) -> List[torch.Tensor]:
        """Compute one loss per configured loss function for ``batch``.

        Args:
            batch: Training batch providing ``query_ids`` and ``targets``.

        Returns:
            A list with the result of ``compute_loss`` for every entry of
            ``self.loss_functions``, in order. The second element of each
            entry is not used here.

        Raises:
            ValueError: If ``self.loss_functions`` is ``None``, or if the
                model output has no scores or the batch has no targets.
            NotImplementedError: If one of the loss functions is an
                ``InBatchLossFunction``.
        """
        if self.loss_functions is None:
            raise ValueError("loss_functions must be set in the module")

        scores = self.forward(batch).scores
        if scores is None or batch.targets is None:
            raise ValueError("scores and targets must be set in the output and batch")

        # One row of scores per query; targets keep the same two leading
        # dimensions plus one trailing dimension.
        num_queries = len(batch.query_ids)
        scores = scores.view(num_queries, -1)
        targets = batch.targets.view(*scores.shape, -1)

        computed: List[torch.Tensor] = []
        for loss_function, _ in self.loss_functions:
            if isinstance(loss_function, InBatchLossFunction):
                raise NotImplementedError("InBatchLossFunction not implemented for cross-encoders")
            loss = loss_function.compute_loss(scores, targets)
            computed.append(loss)
        return computed