1import warnings
2from collections import defaultdict
3from dataclasses import dataclass
4from functools import partial, wraps
5from pathlib import Path
6from typing import Any, Callable, Literal, Mapping, Sequence, Type, TypeVar
7
8import torch
9from transformers import MODEL_MAPPING, BatchEncoding, BertModel
10from transformers.modeling_outputs import ModelOutput
11
12from ..flash import FLASH_ATTENTION_MAP
13from .class_factory import LightningIRModelClassFactory
14from .config import LightningIRConfig
15from .external_model_hub import CHECKPOINT_MAPPING, POST_LOAD_CALLBACKS, STATE_DICT_KEY_MAPPING
16
17
@dataclass
class LightningIROutput(ModelOutput):
    """Container for the outputs of a LightningIR model, extending transformers.ModelOutput_.

    .. _transformers.ModelOutput: https://huggingface.co/transformers/main_classes/output.html#transformers.ModelOutput

    :param scores: Relevance scores for the query--document pairs, defaults to None
    :type scores: torch.Tensor | None, optional
    """

    scores: torch.Tensor | None = None
29
30
class LightningIRModel:
    """Base class for LightningIR models. Derived classes implement the forward method for handling query
    and document embeddings. It acts as mixin for a transformers.PreTrainedModel_ backbone model.

    .. _transformers.PreTrainedModel: \
https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel
    """

    config_class: Type[LightningIRConfig] = LightningIRConfig
    """Configuration class for the model."""

    ALLOW_SUB_BATCHING = True
    """Flag to allow mini batches of documents for a single query. Set to false for listwise models to ensure
    correctness."""

    def __init__(self, config: LightningIRConfig, *args, **kwargs) -> None:
        """Initializes the model.

        :param config: Configuration class for the model
        :type config: LightningIRConfig
        """
        super().__init__(config, *args, **kwargs)
        self.config = config

        # Sub-batch size used by the _batch_encoding decorator; set lazily when a CUDA OOM occurs.
        self._sub_batch_size: int | None = None

        # Monkey-patch a registered flash-attention forward onto the backbone's self-attention
        # modules, matched by module-name suffix.
        if self.config.backbone_model_type is not None:
            flash_attn = FLASH_ATTENTION_MAP.get(self.config.backbone_model_type, None)
            if flash_attn is not None:
                flash_attn_forward, self_attn_pattern = flash_attn
                for name, module in self.named_modules():
                    if name.endswith(self_attn_pattern):
                        module.forward = partial(flash_attn_forward, module)

    def _backbone_forward(self, *args, **kwargs):
        """Forward pass of the backbone model. Must be implemented by the derived class."""
        raise NotImplementedError

    def forward(self, *args, **kwargs) -> LightningIROutput:
        """Forward method of the model. Must be implemented by the derived class."""
        raise NotImplementedError

    def _sparsification(
        self, embeddings: torch.Tensor, sparsification_strategy: Literal["relu", "relu_log"] | None = None
    ) -> torch.Tensor:
        """Helper method to apply sparsification to the embeddings.

        :param embeddings: Query or document embeddings
        :type embeddings: torch.Tensor
        :param sparsification_strategy: The sparsification strategy. No sparsification is applied if None,
            defaults to None
        :type sparsification_strategy: Literal["relu", "relu_log"] | None, optional
        :raises ValueError: If an unknown sparsification strategy is passed
        :return: (Optionally) sparsified embeddings
        :rtype: torch.Tensor
        """
        if sparsification_strategy is None:
            return embeddings
        if sparsification_strategy == "relu":
            return torch.relu(embeddings)
        if sparsification_strategy == "relu_log":
            return torch.log1p(torch.relu(embeddings))
        raise ValueError(f"Unknown sparsification strategy: {sparsification_strategy}")

    def _pooling(
        self,
        embeddings: torch.Tensor,
        attention_mask: torch.Tensor | None,
        pooling_strategy: Literal["first", "mean", "max", "sum"] | None,
    ) -> torch.Tensor:
        """Helper method to apply pooling to the embeddings.

        :param embeddings: Query or document embeddings
        :type embeddings: torch.Tensor
        :param attention_mask: Query or document attention mask
        :type attention_mask: torch.Tensor | None
        :param pooling_strategy: The pooling strategy. No pooling is applied if None.
        :type pooling_strategy: Literal["first", "mean", "max", "sum"] | None
        :raises ValueError: If an unknown pooling strategy is passed
        :return: (Optionally) pooled embeddings
        :rtype: torch.Tensor
        """
        if pooling_strategy is None:
            return embeddings
        if pooling_strategy == "first":
            # [0] (not 0) keeps the sequence dimension, returning shape (batch, 1, dim).
            return embeddings[:, [0]]
        if pooling_strategy in ("sum", "mean"):
            if attention_mask is not None:
                embeddings = embeddings * attention_mask.unsqueeze(-1)
            embeddings = embeddings.sum(dim=1, keepdim=True)
            if pooling_strategy == "mean":
                if attention_mask is not None:
                    # NOTE(review): assumes every sequence has at least one unmasked token,
                    # otherwise this divides by zero.
                    embeddings = embeddings / attention_mask.sum(dim=1, keepdim=True).unsqueeze(-1)
            return embeddings
        if pooling_strategy == "max":
            if attention_mask is not None:
                embeddings = embeddings.masked_fill(~attention_mask.bool().unsqueeze(-1), -1e9)
            return embeddings.max(dim=1, keepdim=True).values
        # BUG FIX: previously referenced the non-existent attribute ``self.pooling_strategy``,
        # which raised AttributeError instead of the intended ValueError.
        raise ValueError(f"Unknown pooling strategy: {pooling_strategy}")

    @classmethod
    def _load_pretrained_model(
        cls, model, state_dict, loaded_keys, resolved_archive_file, pretrained_model_name_or_path, *args, **kwargs
    ):
        """Remaps state-dict keys for registered external checkpoints before deferring to the
        transformers loading logic, then applies any registered post-load callback."""
        if pretrained_model_name_or_path in STATE_DICT_KEY_MAPPING:
            map_keys = STATE_DICT_KEY_MAPPING[pretrained_model_name_or_path]
            for orig_key, new_key in map_keys:
                if orig_key is not None:
                    state_dict[new_key] = state_dict.pop(orig_key)
                    loaded_keys[loaded_keys.index(orig_key)] = new_key
                else:
                    # No original key: the new key is expected to be supplied elsewhere (e.g. by
                    # the model's own initialization), so only record it as loaded.
                    loaded_keys.append(new_key)
        model, *out = super()._load_pretrained_model(
            model, state_dict, loaded_keys, resolved_archive_file, pretrained_model_name_or_path, *args, **kwargs
        )
        if pretrained_model_name_or_path in POST_LOAD_CALLBACKS:
            model = POST_LOAD_CALLBACKS[pretrained_model_name_or_path](model)
        return (model, *out)

    @classmethod
    def from_pretrained(cls, model_name_or_path: str | Path, *args, **kwargs) -> "LightningIRModel":
        """Loads a pretrained model. Wraps the transformers.PreTrainedModel.from_pretrained_ method and to return a
        derived LightningIRModel. See :class:`LightningIRModelClassFactory` for more details.

        .. _transformers.PreTrainedModel.from_pretrained: https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained # noqa

        :param model_name_or_path: Name or path of the pretrained model
        :type model_name_or_path: str | Path
        :raises ValueError: If called on the abstract class :class:`LightningIRModel` and no config is passed
        :return: A derived LightningIRModel consisting of a backbone model and a LightningIRModel mixin
        :rtype: LightningIRModel

        .. ::doctest
        .. highlight:: python
        .. code-block:: python

            >>> # Loading using model class and backbone checkpoint
            >>> type(CrossEncoderModel.from_pretrained("bert-base-uncased"))
            <class 'lightning_ir.base.class_factory.CrossEncoderBertModel'>
            >>> # Loading using base class and backbone checkpoint
            >>> type(LightningIRModel.from_pretrained("bert-base-uncased", config=CrossEncoderConfig()))
            <class 'lightning_ir.base.class_factory.CrossEncoderBertModel'>
        """
        # provides AutoModel.from_pretrained support
        config = kwargs.get("config", None)
        if cls is LightningIRModel or all(issubclass(base, LightningIRModel) for base in cls.__bases__):
            # no backbone models found, create derived lightning-ir model based on backbone model
            if model_name_or_path in CHECKPOINT_MAPPING:
                _config = CHECKPOINT_MAPPING[model_name_or_path]
                config_class = _config.__class__
                if config is not None:
                    warnings.warn(f"{model_name_or_path} is a registered checkpoint. The provided config is ignored.")
                config = _config
            elif config is not None:
                config_class = config.__class__
            elif cls is not LightningIRModel:
                config_class = cls.config_class
            else:
                config_class = LightningIRModelClassFactory.get_lightning_ir_config(model_name_or_path)
                if config_class is None:
                    raise ValueError("Pass a config to `from_pretrained`.")
            BackboneConfig = LightningIRModelClassFactory.get_backbone_config(model_name_or_path)
            BackboneModel = MODEL_MAPPING[BackboneConfig]
            # Build the concrete class (backbone + LightningIR mixin) and re-enter from_pretrained on it.
            cls = LightningIRModelClassFactory(config_class).from_backbone_class(BackboneModel)
            if config is not None and all(issubclass(base, LightningIRConfig) for base in config.__class__.__bases__):
                derived_config = cls.config_class.from_pretrained(model_name_or_path, config=config)
                derived_config.update(config.to_dict())
                kwargs["config"] = derived_config
            return cls.from_pretrained(model_name_or_path, *args, **kwargs)
        if issubclass(cls, BertModel):
            # The BERT pooling head is unused by LightningIR models; skip loading it.
            kwargs["add_pooling_layer"] = False
        return super(LightningIRModel, cls).from_pretrained(model_name_or_path, *args, **kwargs)
202
203
204T = TypeVar("T")
205
206
207def _cat_outputs(
208 outputs: Sequence[Mapping] | Sequence[torch.Tensor] | Sequence[None], OutputClass: Type[T] | None
209) -> torch.Tensor | T | None:
210 if len(outputs) == 1:
211 return outputs[0]
212 if len(outputs) == 0 or outputs[0] is None or OutputClass is None:
213 return None
214 if isinstance(outputs[0], torch.Tensor):
215 return torch.cat(outputs, dim=0)
216 agg = defaultdict(list)
217 types = {}
218 for output in outputs:
219 for key, value in output.items():
220 agg[key].append(value)
221 types[key] = type(value)
222 return OutputClass(**{key: _cat_outputs(value, types[key]) for key, value in agg.items()})
223
224
def _batch_encoding(func: Callable[..., Any]) -> Callable[..., Any]:
    """Decorator adding automatic sub-batching with CUDA OOM recovery to a model forward method.

    On a CUDA out-of-memory error the sub-batch size is halved and the remaining inputs are
    retried; per-sub-batch outputs are concatenated via ``_cat_outputs``. The reduced size is
    remembered on the model (``self._sub_batch_size``) for subsequent calls.

    FIX: the annotations previously used ``Callable[[LightningIRModel, BatchEncoding, ...], Any]``;
    a plain ``...`` inside a ``Callable`` parameter-type list is invalid typing syntax (Ellipsis is
    only valid as the entire argument list), so ``Callable[..., Any]`` is used instead.

    :param func: Bound-style method taking the model instance and a BatchEncoding (plus extra args)
    :type func: Callable[..., Any]
    :return: Wrapped method with sub-batching support
    :rtype: Callable[..., Any]
    """

    @wraps(func)
    def wrapper(self, encoding: BatchEncoding, *args, **kwargs) -> Any:
        # Listwise models set ALLOW_SUB_BATCHING = False; splitting their batches would be incorrect.
        if not self.ALLOW_SUB_BATCHING:
            return func(self, encoding, *args, **kwargs)
        # Start from the last known-good sub-batch size, or the full batch on the first call.
        sub_batch_size = self._sub_batch_size or encoding.input_ids.shape[0]
        sub_encoding = encoding
        remaining_encoding = encoding
        OutputClass = None
        outputs = []
        while True:
            try:
                # ceil division
                num_batches = -(remaining_encoding.input_ids.shape[0] // -sub_batch_size)
                for _ in range(num_batches):
                    sub_encoding = BatchEncoding(
                        {key: value[:sub_batch_size] for key, value in remaining_encoding.items()}
                    )
                    output = func(self, sub_encoding, *args, **kwargs)
                    OutputClass = output.__class__
                    outputs.append(output)
                    # Drop the processed slice so an OOM retry resumes exactly where it left off.
                    remaining_encoding = BatchEncoding(
                        {key: value[sub_batch_size:] for key, value in remaining_encoding.items()}
                    )
                break
            except RuntimeError as e:
                if "CUDA out of memory" in str(e) or "CUDACachingAllocator.cpp" in str(e):
                    # Halve the sub-batch size and persist it on the model for future calls.
                    self._sub_batch_size = sub_batch_size = sub_batch_size // 2
                    if sub_batch_size == 0:
                        # Even a single sample does not fit in memory; give up.
                        raise e
                else:
                    raise e
        if OutputClass is None:
            raise ValueError("No output was generated.")
        return _cat_outputs(outputs, OutputClass)

    return wrapper