Source code for lightning_ir.cross_encoder.model
1from dataclasses import dataclass
2from typing import Type
3
4import torch
5from transformers import BatchEncoding
6
7from ..base import LightningIRModel, LightningIROutput
8from ..base.model import _batch_encoding
9from . import CrossEncoderConfig
10
11
[docs]
@dataclass
class CrossEncoderOutput(LightningIROutput):
    """Output container returned by :class:`CrossEncoderModel`.

    ``scores`` is accepted through the :class:`LightningIROutput` base; this
    subclass additionally carries the pooled representation that the scores
    were computed from.
    """

    # Pooled backbone representation fed to the scoring head; ``None`` unless
    # explicitly provided. NOTE(review): presumably shaped (batch, hidden_size)
    # — confirm against the pooling implementation.
    embeddings: torch.Tensor | None = None
15
16
[docs]
class CrossEncoderModel(LightningIRModel):
    """Cross-encoder that maps a pooled backbone representation to one score.

    The backbone's last hidden states are pooled according to
    ``config.pooling_strategy`` and passed through a single linear head
    (``hidden_size -> 1``), yielding one score per pooled vector.
    NOTE(review): inputs are presumably jointly encoded query-document
    pairs — confirm with the tokenizer/collator used upstream.
    """

    # Config type this model is associated with; presumably consumed by the
    # base class when constructing/loading models — see LightningIRModel.
    config_class: Type[CrossEncoderConfig] = CrossEncoderConfig

    def __init__(self, config: CrossEncoderConfig, *args, **kwargs):
        """Initialise the base model, then attach the linear scoring head.

        Args:
            config: Cross-encoder configuration. ``hidden_size`` sets the
                head's input width and ``linear_bias`` toggles its bias term.
            *args: Forwarded unchanged to :class:`LightningIRModel`.
            **kwargs: Forwarded unchanged to :class:`LightningIRModel`.
        """
        super().__init__(config, *args, **kwargs)
        # Narrows the static type of ``self.config`` for type checkers only;
        # a bare annotation on an attribute target assigns nothing at runtime.
        self.config: CrossEncoderConfig
        self.linear = torch.nn.Linear(
            in_features=config.hidden_size,
            out_features=1,
            bias=config.linear_bias,
        )

    # NOTE(review): ``_batch_encoding`` is defined in ..base.model; its exact
    # input normalisation is not visible here.
    @_batch_encoding
    def forward(self, encoding: BatchEncoding) -> CrossEncoderOutput:
        """Score a tokenized batch.

        Args:
            encoding: Tokenized inputs, unpacked as keyword arguments into the
                backbone. ``attention_mask`` is optional; when present it is
                passed to the pooling step.

        Returns:
            CrossEncoderOutput whose ``scores`` is the linear head's output
            flattened to 1-D and whose ``embeddings`` is the pooled
            representation.
        """
        hidden_states = self._backbone_forward(**encoding).last_hidden_state
        mask = encoding.get("attention_mask", None)
        pooled = self._pooling(hidden_states, mask, pooling_strategy=self.config.pooling_strategy)
        # The head emits one value per pooled vector; flatten into a 1-D score
        # vector (presumably shape (batch,) — confirm against _pooling).
        head_out = self.linear(pooled)
        return CrossEncoderOutput(scores=head_out.view(-1), embeddings=pooled)