Source code for lightning_ir.models.xtr.model

from pathlib import Path

import torch
from huggingface_hub import hf_hub_download
from transformers.modeling_utils import load_state_dict

from ...base import LightningIRModelClassFactory
from ...bi_encoder.model import BiEncoderEmbedding, ScoringFunction
from ..col import ColModel
from .config import XTRConfig


class XTRScoringFunction(ScoringFunction):
    def __init__(self, config: XTRConfig) -> None:
        super().__init__(config)
        self.config: XTRConfig

    def compute_similarity(
        self, query_embeddings: BiEncoderEmbedding, doc_embeddings: BiEncoderEmbedding
    ) -> torch.Tensor:
        similarity = super().compute_similarity(query_embeddings, doc_embeddings)

        if self.training and self.xtr_token_retrieval_k is not None:
            pass
            # TODO implement simulated token retrieval
            # if not torch.all(num_docs == num_docs[0]):
            #     raise ValueError("XTR token retrieval does not support variable number of documents.")
            # query_embeddings = query_embeddings[:: num_docs[0]]
            # doc_embeddings = doc_embeddings.view(1, 1, -1, doc_embeddings.shape[-1])
            # ib_similarity = super().compute_similarity(
            #     query_embeddings,
            #     doc_embeddings,
            #     query_scoring_mask[:: num_docs[0]],
            #     doc_scoring_mask.view(1, -1),
            #     num_docs,
            # )
            # top_k_similarity = ib_similarity.topk(self.xtr_token_retrieval_k, dim=-1)
            # cut_off_similarity = top_k_similarity.values[..., [-1]].repeat_interleave(num_docs, dim=0)
            # if self.fill_strategy == "min":
            #     fill = cut_off_similarity.expand_as(similarity)[similarity < cut_off_similarity]
            # elif self.fill_strategy == "zero":
            #     fill = 0
            # similarity[similarity < cut_off_similarity] = fill
        return similarity

    # def aggregate(
    #     self,
    #     scores: torch.Tensor,
    #     mask: torch.Tensor,
    #     query_aggregation_function: Literal["max", "sum", "mean", "harmonic_mean"],
    # ) -> torch.Tensor:
    #     if self.training and self.normalization == "Z":
    #         # Z-normalization
    #         mask = mask & (scores != 0)
    #     return super().aggregate(scores, mask, query_aggregation_function)

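# --- Illustrative sketch, not part of the library source ----------------------
# The commented-out draft above outlines XTR's simulated in-batch token
# retrieval: during training, only query-token/doc-token similarities that
# would survive a top-k retrieval over all in-batch document tokens are kept,
# and the rest are imputed. A minimal standalone version of that idea,
# assuming `similarity` has shape (num_queries * num_docs, query_len, doc_len)
# with a uniform number of documents per query; the `k` and `fill_strategy`
# parameters mirror the draft's `xtr_token_retrieval_k` and `fill_strategy`.
def simulated_token_retrieval(
    similarity: torch.Tensor, num_docs: int, k: int, fill_strategy: str = "zero"
) -> torch.Tensor:
    batch, query_len, doc_len = similarity.shape
    num_queries = batch // num_docs
    # Collect each query token's similarities over all of its in-batch documents.
    ib_similarity = similarity.view(num_queries, num_docs, query_len, doc_len)
    ib_similarity = ib_similarity.permute(0, 2, 1, 3).reshape(num_queries, query_len, -1)
    # The k-th highest in-batch similarity is the retrieval cut-off per query token.
    cut_off = ib_similarity.topk(k, dim=-1).values[..., -1:]
    cut_off = cut_off.repeat_interleave(num_docs, dim=0)  # back to (batch, query_len, 1)
    below = similarity < cut_off
    if fill_strategy == "zero":
        return similarity.masked_fill(below, 0.0)
    if fill_strategy == "min":
        # Impute similarities below the cut-off with the cut-off value itself.
        return torch.where(below, cut_off.expand_as(similarity), similarity)
    raise ValueError(f"unknown fill_strategy: {fill_strategy}")

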
class XTRModel(ColModel):
    config_class = XTRConfig

    def __init__(self, config: XTRConfig, *args, **kwargs) -> None:
        super().__init__(config)
        self.scoring_function = XTRScoringFunction(config)
        self.config: XTRConfig

    @classmethod
    def from_pretrained(cls, model_name_or_path: str | Path, *args, **kwargs) -> "XTRModel":
        # Probe for the extra linear projection that original XTR checkpoints store
        # under 2_Dense/; if it is present the checkpoint needs conversion, otherwise
        # fall back to the regular loading path.
        try:
            hf_hub_download(repo_id=str(model_name_or_path), filename="2_Dense/pytorch_model.bin")
        except Exception:
            return super().from_pretrained(model_name_or_path, *args, **kwargs)
        else:
            # Runs only when the probe succeeded, so the fallback return above stands.
            return cls.from_xtr_checkpoint(model_name_or_path)

    @classmethod
    def from_xtr_checkpoint(cls, model_name_or_path: Path | str) -> "XTRModel":
        from transformers import T5EncoderModel

        # Derive a Lightning IR model class from the T5 encoder backbone and configure
        # it to match the original XTR setup: dot-product similarity, sum aggregation,
        # normalized 128-dimensional embeddings, and a bias-free linear projection.
        cls = LightningIRModelClassFactory(XTRConfig).from_backbone_class(T5EncoderModel)
        config = cls.config_class.from_pretrained(model_name_or_path)
        config.update(
            {
                "name_or_path": str(model_name_or_path),
                "similarity_function": "dot",
                "query_aggregation_function": "sum",
                "query_expansion": False,
                "doc_expansion": False,
                "doc_pooling_strategy": None,
                "doc_mask_scoring_tokens": None,
                "normalize": True,
                "sparsification": None,
                "add_marker_tokens": False,
                "embedding_dim": 128,
                "projection": "linear_no_bias",
            }
        )
        # Merge the backbone weights with the separately stored projection weights,
        # renaming checkpoint keys to the names this model expects.
        state_dict_path = hf_hub_download(repo_id=str(model_name_or_path), filename="model.safetensors")
        state_dict = load_state_dict(state_dict_path)
        linear_state_dict_path = hf_hub_download(
            repo_id=str(model_name_or_path), filename="2_Dense/pytorch_model.bin"
        )
        linear_state_dict = load_state_dict(linear_state_dict_path)
        linear_state_dict["projection.weight"] = linear_state_dict.pop("linear.weight")
        state_dict["encoder.embed_tokens.weight"] = state_dict["shared.weight"]
        state_dict.update(linear_state_dict)
        model = cls(config=config)
        model.load_state_dict(state_dict)
        return model

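# --- Illustrative usage, not part of the library source -----------------------
# A minimal sketch of how `from_pretrained` dispatches. The repository id is an
# assumption: a checkpoint that ships `2_Dense/pytorch_model.bin` is routed
# through `from_xtr_checkpoint`, while any other checkpoint falls back to the
# regular `ColModel` loading path.
if __name__ == "__main__":
    model = XTRModel.from_pretrained("google/xtr-base-en")  # assumed checkpoint id
    model.eval()
    print(type(model).__name__, model.config.embedding_dim)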