Source code for lightning_ir.base.model

import warnings
from collections import defaultdict
from dataclasses import dataclass
from functools import partial, wraps
from pathlib import Path
from typing import Any, Callable, Literal, Mapping, Sequence, Type, TypeVar

import torch
from transformers import MODEL_MAPPING, BatchEncoding, BertModel
from transformers.modeling_outputs import ModelOutput

from ..flash import FLASH_ATTENTION_MAP
from .class_factory import LightningIRModelClassFactory
from .config import LightningIRConfig
from .external_model_hub import CHECKPOINT_MAPPING, POST_LOAD_CALLBACKS, STATE_DICT_KEY_MAPPING

@dataclass
class LightningIROutput(ModelOutput):
    """Base class for the output of the LightningIR model. It is a subclass of transformers.ModelOutput_.

    .. _transformers.ModelOutput: https://huggingface.co/transformers/main_classes/output.html#transformers.ModelOutput

    :param scores: Output relevance scores for query--document pairs, defaults to None
    :type scores: torch.Tensor | None, optional
    """

    scores: torch.Tensor | None = None

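# Usage sketch (illustrative, not part of the module): LightningIROutput behaves like any
# transformers.ModelOutput, so the scores are reachable by attribute, by key, or by index.
#
#   output = LightningIROutput(scores=torch.tensor([0.8, 0.1]))
#   output.scores, output["scores"], output[0]  # all yield the same tensor
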
class LightningIRModel:
    """Base class for LightningIR models. Derived classes implement the forward method for handling query
    and document embeddings. It acts as a mixin for a transformers.PreTrainedModel_ backbone model.

    .. _transformers.PreTrainedModel: \
https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel
    """

    config_class: Type[LightningIRConfig] = LightningIRConfig
    """Configuration class for the model."""

    ALLOW_SUB_BATCHING = True
    """Flag to allow mini batches of documents for a single query. Set to False for listwise models to ensure
    correctness."""
    def __init__(self, config: LightningIRConfig, *args, **kwargs) -> None:
        """Initializes the model.

        :param config: Configuration for the model
        :type config: LightningIRConfig
        """
        super().__init__(config, *args, **kwargs)
        self.config = config

        self._sub_batch_size: int | None = None

        if self.config.backbone_model_type is not None:
            flash_attn = FLASH_ATTENTION_MAP.get(self.config.backbone_model_type, None)
            if flash_attn is not None:
                flash_attn_forward, self_attn_pattern = flash_attn
                for name, module in self.named_modules():
                    if name.endswith(self_attn_pattern):
                        module.forward = partial(flash_attn_forward, module)
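    # A minimal sketch of the monkey-patching performed above (the "bert" key is
    # hypothetical; the actual forward and submodule-name pattern come from
    # FLASH_ATTENTION_MAP in lightning_ir.flash):
    #
    #   flash_attn_forward, self_attn_pattern = FLASH_ATTENTION_MAP["bert"]
    #   for name, module in model.named_modules():
    #       if name.endswith(self_attn_pattern):  # matches the self-attention submodules
    #           module.forward = partial(flash_attn_forward, module)  # bind module as first arg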
    def _backbone_forward(self, *args, **kwargs):
        raise NotImplementedError
    def forward(self, *args, **kwargs) -> LightningIROutput:
        """Forward method of the model. Must be implemented by the derived class."""
        raise NotImplementedError
    def _sparsification(
        self, embeddings: torch.Tensor, sparsification_strategy: Literal["relu", "relu_log"] | None = None
    ) -> torch.Tensor:
        """Helper method to apply sparsification to the embeddings.

        :param embeddings: Query or document embeddings
        :type embeddings: torch.Tensor
        :param sparsification_strategy: The sparsification strategy. No sparsification is applied if None,
            defaults to None
        :type sparsification_strategy: Literal["relu", "relu_log"] | None, optional
        :raises ValueError: If an unknown sparsification strategy is passed
        :return: (Optionally) sparsified embeddings
        :rtype: torch.Tensor
        """
        if sparsification_strategy is None:
            return embeddings
        if sparsification_strategy == "relu":
            return torch.relu(embeddings)
        if sparsification_strategy == "relu_log":
            return torch.log1p(torch.relu(embeddings))
        raise ValueError(f"Unknown sparsification strategy: {sparsification_strategy}")

    def _pooling(
        self,
        embeddings: torch.Tensor,
        attention_mask: torch.Tensor | None,
        pooling_strategy: Literal["first", "mean", "max", "sum"] | None,
    ) -> torch.Tensor:
        """Helper method to apply pooling to the embeddings.

        :param embeddings: Query or document embeddings
        :type embeddings: torch.Tensor
        :param attention_mask: Query or document attention mask
        :type attention_mask: torch.Tensor | None
        :param pooling_strategy: The pooling strategy. No pooling is applied if None.
        :type pooling_strategy: Literal["first", "mean", "max", "sum"] | None
        :raises ValueError: If an unknown pooling strategy is passed
        :return: (Optionally) pooled embeddings
        :rtype: torch.Tensor
        """
        if pooling_strategy is None:
            return embeddings
        if pooling_strategy == "first":
            return embeddings[:, [0]]
        if pooling_strategy in ("sum", "mean"):
            if attention_mask is not None:
                embeddings = embeddings * attention_mask.unsqueeze(-1)
            embeddings = embeddings.sum(dim=1, keepdim=True)
            if pooling_strategy == "mean":
                if attention_mask is not None:
                    embeddings = embeddings / attention_mask.sum(dim=1, keepdim=True).unsqueeze(-1)
            return embeddings
        if pooling_strategy == "max":
            if attention_mask is not None:
                embeddings = embeddings.masked_fill(~attention_mask.bool().unsqueeze(-1), -1e9)
            return embeddings.max(dim=1, keepdim=True).values
        raise ValueError(f"Unknown pooling strategy: {pooling_strategy}")

    @classmethod
    def _load_pretrained_model(
        cls, model, state_dict, loaded_keys, resolved_archive_file, pretrained_model_name_or_path, *args, **kwargs
    ):
        if pretrained_model_name_or_path in STATE_DICT_KEY_MAPPING:
            map_keys = STATE_DICT_KEY_MAPPING[pretrained_model_name_or_path]
            for orig_key, new_key in map_keys:
                if orig_key is not None:
                    state_dict[new_key] = state_dict.pop(orig_key)
                    loaded_keys[loaded_keys.index(orig_key)] = new_key
                else:
                    loaded_keys.append(new_key)
        model, *out = super()._load_pretrained_model(
            model, state_dict, loaded_keys, resolved_archive_file, pretrained_model_name_or_path, *args, **kwargs
        )
        if pretrained_model_name_or_path in POST_LOAD_CALLBACKS:
            model = POST_LOAD_CALLBACKS[pretrained_model_name_or_path](model)
        return (model, *out)
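    # Pooling and sparsification sketch (illustrative): for embeddings of shape
    # (batch, seq_len, dim), every pooling strategy reduces the sequence dimension to 1,
    # i.e. returns shape (batch, 1, dim); sparsification preserves the input shape:
    #
    #   emb = torch.randn(2, 4, 8)
    #   mask = torch.tensor([[1, 1, 0, 0], [1, 1, 1, 1]])
    #   model._pooling(emb, mask, "mean").shape   # torch.Size([2, 1, 8])
    #   model._pooling(emb, None, "first").shape  # torch.Size([2, 1, 8])
    #   model._sparsification(emb, "relu_log")    # log1p(relu(emb)), same shape as emb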
    @classmethod
    def from_pretrained(cls, model_name_or_path: str | Path, *args, **kwargs) -> "LightningIRModel":
        """Loads a pretrained model. Wraps the transformers.PreTrainedModel.from_pretrained_ method to return a
        derived LightningIRModel. See :class:`LightningIRModelClassFactory` for more details.

        .. _transformers.PreTrainedModel.from_pretrained: https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained # noqa

        :param model_name_or_path: Name or path of the pretrained model
        :type model_name_or_path: str | Path
        :raises ValueError: If called on the abstract class :class:`LightningIRModel` and no config is passed
        :return: A derived LightningIRModel consisting of a backbone model and a LightningIRModel mixin
        :rtype: LightningIRModel

        .. ::doctest
        .. highlight:: python
        .. code-block:: python

            >>> # Loading using model class and backbone checkpoint
            >>> type(CrossEncoderModel.from_pretrained("bert-base-uncased"))
            <class 'lightning_ir.base.class_factory.CrossEncoderBertModel'>
            >>> # Loading using base class and backbone checkpoint
            >>> type(LightningIRModel.from_pretrained("bert-base-uncased", config=CrossEncoderConfig()))
            <class 'lightning_ir.base.class_factory.CrossEncoderBertModel'>
        """
        # provides AutoModel.from_pretrained support
        config = kwargs.get("config", None)
        if cls is LightningIRModel or all(issubclass(base, LightningIRModel) for base in cls.__bases__):
            # no backbone models found, create derived lightning-ir model based on backbone model
            if model_name_or_path in CHECKPOINT_MAPPING:
                _config = CHECKPOINT_MAPPING[model_name_or_path]
                config_class = _config.__class__
                if config is not None:
                    warnings.warn(f"{model_name_or_path} is a registered checkpoint. The provided config is ignored.")
                config = _config
            elif config is not None:
                config_class = config.__class__
            elif cls is not LightningIRModel:
                config_class = cls.config_class
            else:
                config_class = LightningIRModelClassFactory.get_lightning_ir_config(model_name_or_path)
                if config_class is None:
                    raise ValueError("Pass a config to `from_pretrained`.")
            BackboneConfig = LightningIRModelClassFactory.get_backbone_config(model_name_or_path)
            BackboneModel = MODEL_MAPPING[BackboneConfig]
            cls = LightningIRModelClassFactory(config_class).from_backbone_class(BackboneModel)
            if config is not None and all(issubclass(base, LightningIRConfig) for base in config.__class__.__bases__):
                derived_config = cls.config_class.from_pretrained(model_name_or_path, config=config)
                derived_config.update(config.to_dict())
                kwargs["config"] = derived_config
            return cls.from_pretrained(model_name_or_path, *args, **kwargs)
        if issubclass(cls, BertModel):
            kwargs["add_pooling_layer"] = False
        return super(LightningIRModel, cls).from_pretrained(model_name_or_path, *args, **kwargs)
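# Loading sketch (illustrative), mirroring the doctest above: the class factory derives a
# combined class from the backbone architecture, so a cross-encoder can be built directly
# from a plain BERT checkpoint:
#
#   from lightning_ir import CrossEncoderConfig, CrossEncoderModel
#   model = CrossEncoderModel.from_pretrained("bert-base-uncased")
#   model = LightningIRModel.from_pretrained("bert-base-uncased", config=CrossEncoderConfig())
#   LightningIRModel.from_pretrained("bert-base-uncased")  # raises ValueError: no config passed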

T = TypeVar("T")


def _cat_outputs(
    outputs: Sequence[Mapping] | Sequence[torch.Tensor] | Sequence[None], OutputClass: Type[T] | None
) -> torch.Tensor | T | None:
    # concatenate per-sub-batch outputs: tensors along the batch dimension,
    # ModelOutput-like mappings field by field (recursively)
    if len(outputs) == 1:
        return outputs[0]
    if len(outputs) == 0 or outputs[0] is None or OutputClass is None:
        return None
    if isinstance(outputs[0], torch.Tensor):
        return torch.cat(outputs, dim=0)
    agg = defaultdict(list)
    types = {}
    for output in outputs:
        for key, value in output.items():
            agg[key].append(value)
            types[key] = type(value)
    return OutputClass(**{key: _cat_outputs(value, types[key]) for key, value in agg.items()})


def _batch_encoding(
    func: Callable[[LightningIRModel, BatchEncoding, ...], Any]
) -> Callable[[LightningIRModel, BatchEncoding, ...], Any]:
    # decorator that splits a batch into sub-batches when the wrapped forward
    # runs out of GPU memory and concatenates the partial outputs afterwards

    @wraps(func)
    def wrapper(self, encoding: BatchEncoding, *args, **kwargs) -> Any:
        if not self.ALLOW_SUB_BATCHING:
            return func(self, encoding, *args, **kwargs)
        sub_batch_size = self._sub_batch_size or encoding.input_ids.shape[0]
        sub_encoding = encoding
        remaining_encoding = encoding
        OutputClass = None
        outputs = []
        while True:
            try:
                # ceil division
                num_batches = -(remaining_encoding.input_ids.shape[0] // -sub_batch_size)
                for _ in range(num_batches):
                    sub_encoding = BatchEncoding(
                        {key: value[:sub_batch_size] for key, value in remaining_encoding.items()}
                    )
                    output = func(self, sub_encoding, *args, **kwargs)
                    OutputClass = output.__class__
                    outputs.append(output)
                    remaining_encoding = BatchEncoding(
                        {key: value[sub_batch_size:] for key, value in remaining_encoding.items()}
                    )
                break
            except RuntimeError as e:
                if "CUDA out of memory" in str(e) or "CUDACachingAllocator.cpp" in str(e):
                    # halve the sub-batch size on CUDA OOM and retry from the remaining encoding
                    self._sub_batch_size = sub_batch_size = sub_batch_size // 2
                    if sub_batch_size == 0:
                        raise e
                else:
                    raise e
        if OutputClass is None:
            raise ValueError("No output was generated.")
        return _cat_outputs(outputs, OutputClass)

    return wrapper
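
# Sub-batching sketch (illustrative): a forward method wrapped with @_batch_encoding
# transparently recovers from CUDA out-of-memory errors by halving the sub-batch size
# and resuming, then concatenates the partial outputs via _cat_outputs:
#
#   class MyModel(LightningIRModel):
#       @_batch_encoding
#       def forward(self, encoding: BatchEncoding) -> LightningIROutput:
#           ...
#
#   # a batch of 32 that OOMs is retried as sub-batches of 16, then 8, ...; if the
#   # sub-batch size reaches 0, the original RuntimeError is re-raised.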