Source code for lightning_ir.models.t5.model

 1import torch
 2from transformers import BatchEncoding
 3
 4from ...cross_encoder.model import CrossEncoderModel, CrossEncoderOutput
 5from .config import T5CrossEncoderConfig
 6
 7
class ScaleLinear(torch.nn.Linear):
    """Linear layer that rescales its input before applying the projection.

    The input is multiplied by ``d ** -0.5``, where ``d`` is the size of the
    input's last dimension, and is then passed to ``torch.nn.Linear.forward``.
    """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Scale ``input`` by the inverse square root of its last dimension and project it.

        Args:
            input: Tensor to project. The parameter keeps the name used by
                ``torch.nn.Linear.forward`` so keyword calls continue to work.

        Returns:
            The output of ``torch.nn.Linear.forward`` for the rescaled input.
        """
        # Rescale output before projecting on vocab
        # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586  # noqa
        input = input * (input.shape[-1] ** -0.5)
        return super().forward(input)
15 16
class T5CrossEncoderModel(CrossEncoderModel):
    """Cross-encoder model configured by ``T5CrossEncoderConfig``.

    Scores are produced through a ``ScaleLinear`` head stored in ``self.linear``.
    The head has two outputs when ``config.decoder_strategy == "mono"`` and a
    single output for every other strategy.
    """

    config_class = T5CrossEncoderConfig

    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "linear.weight"]

    def __init__(self, config: T5CrossEncoderConfig, *args, **kwargs):
        """Initialize the base cross-encoder and attach the scoring head.

        Args:
            config: Model configuration. ``decoder_strategy``, ``hidden_size``
                and ``linear_bias`` are read here.
            *args: Forwarded unchanged to ``CrossEncoderModel.__init__``.
            **kwargs: Forwarded unchanged to ``CrossEncoderModel.__init__``.
        """
        super().__init__(config, *args, **kwargs)
        self.config: T5CrossEncoderConfig
        # "mono" needs two logits (they are log-softmaxed in ``forward``);
        # every other strategy uses a single logit as the score.
        num_outputs = 2 if self.config.decoder_strategy == "mono" else 1
        self.linear = ScaleLinear(config.hidden_size, num_outputs, bias=config.linear_bias)

    # TODO tying of weights does not work when setting linear to only use slice of lm head for efficiency
    # def get_output_embeddings(self):
    #     shared = self.shared
    #     if self.config.decoder_strategy == "mono":
    #         self.linear.weight.data = shared.weight.data[[1176, 6136]]
    #     elif self.config.decoder_strategy == "rank":
    #         self.linear.weight.data = shared.weight.data[[32089]]
    #     else:
    #         raise ValueError("Unknown decoder strategy")
    #     return shared

    def forward(self, encoding: BatchEncoding) -> CrossEncoderOutput:
        """Score the encoded inputs using a single decoder step.

        Args:
            encoding: Tokenized inputs; must contain ``"input_ids"``. It is
                modified in place: a ``"decoder_input_ids"`` entry is added.

        Returns:
            The output of ``CrossEncoderModel.forward``. For the ``"mono"``
            strategy its ``scores`` are replaced by a flat tensor holding the
            log-softmax value of the first logit of each pair.

        Raises:
            ValueError: If the base model returns no scores.
        """
        input_ids = encoding["input_ids"]
        # One decoder token with id 0 per batch row.
        decoder_input_ids = torch.zeros((input_ids.shape[0], 1), device=input_ids.device, dtype=torch.long)
        encoding["decoder_input_ids"] = decoder_input_ids
        output = super().forward(encoding)
        if output.scores is None:
            raise ValueError("Scores are None")
        if self.config.decoder_strategy == "mono":
            # Group the logits into pairs, normalize each pair and keep the
            # log-probability of the first entry as the score.
            scores = output.scores.view(-1, 2)
            scores = torch.nn.functional.log_softmax(scores, dim=-1)[:, 0]
            output.scores = scores.view(-1)
        return output