Source code for lightning_ir.cross_encoder.model

 1from dataclasses import dataclass
 2from typing import Type
 3
 4import torch
 5from transformers import BatchEncoding
 6
 7from ..base import LightningIRModel, LightningIROutput
 8from ..base.model import _batch_encoding
 9from . import CrossEncoderConfig
10
11
@dataclass
class CrossEncoderOutput(LightningIROutput):
    """Output container returned by :class:`CrossEncoderModel`.

    Extends ``LightningIROutput`` with one optional field for the embeddings
    the relevance scores were computed from.
    """

    # Pooled embeddings fed to the scoring head; ``None`` when not provided.
    embeddings: torch.Tensor | None = None
15 16
class CrossEncoderModel(LightningIRModel):
    """Cross-encoder that turns an encoded input batch into one score per row.

    The backbone's last hidden states are pooled according to
    ``config.pooling_strategy`` and projected to a single value by a linear
    scoring head.
    """

    # Configuration class associated with this model type.
    config_class: Type[CrossEncoderConfig] = CrossEncoderConfig

    def __init__(self, config: CrossEncoderConfig, *args, **kwargs):
        """Set up the base model and attach the linear scoring head.

        Args:
            config: Cross-encoder configuration; ``hidden_size`` and
                ``linear_bias`` determine the shape of the scoring head.
            *args: Forwarded unchanged to the parent constructor.
            **kwargs: Forwarded unchanged to the parent constructor.
        """
        super().__init__(config, *args, **kwargs)
        # Bare annotation only: narrows the attribute type for static checkers.
        self.config: CrossEncoderConfig
        # Scoring head: hidden_size -> 1, bias toggled by the config.
        self.linear = torch.nn.Linear(
            in_features=config.hidden_size,
            out_features=1,
            bias=config.linear_bias,
        )

    @_batch_encoding
    def forward(self, encoding: BatchEncoding) -> CrossEncoderOutput:
        """Score the encoded inputs.

        Args:
            encoding: Tokenized inputs; unpacked as keyword arguments for the
                backbone. ``attention_mask`` is passed to pooling if present.

        Returns:
            A ``CrossEncoderOutput`` holding the flattened scores and the
            pooled embeddings they were computed from.
        """
        hidden_states = self._backbone_forward(**encoding).last_hidden_state
        # The mask is optional; pooling receives None when it is missing.
        attention_mask = encoding.get("attention_mask", None)
        pooled = self._pooling(
            hidden_states,
            attention_mask,
            pooling_strategy=self.config.pooling_strategy,
        )
        # The head emits a trailing dimension of size 1; flatten to 1-D.
        relevance = self.linear(pooled).view(-1)
        return CrossEncoderOutput(scores=relevance, embeddings=pooled)