1from __future__ import annotations
2
3from abc import ABC, abstractmethod
4from pathlib import Path
5from typing import TYPE_CHECKING, Any, Tuple, Type
6
7from transformers import (
8 CONFIG_MAPPING,
9 MODEL_MAPPING,
10 TOKENIZER_MAPPING,
11 AutoConfig,
12 AutoModel,
13 AutoTokenizer,
14 PretrainedConfig,
15 PreTrainedModel,
16 PreTrainedTokenizerBase,
17)
18from transformers.models.auto.tokenization_auto import get_tokenizer_config, tokenizer_class_from_name
19
20if TYPE_CHECKING:
21 from . import LightningIRConfig, LightningIRModel, LightningIRTokenizer
22
23
[docs]
class LightningIRClassFactory(ABC):
    """Abstract base for factories that derive Lightning IR classes from Hugging Face backbone classes."""

    def __init__(self, MixinConfig: Type[LightningIRConfig]) -> None:
        """Creates a new LightningIRClassFactory.

        If ``MixinConfig`` already carries a non-``None`` ``backbone_model_type`` (i.e. it is itself a derived
        configuration class), its first base class is stored as the mixin instead.

        :param MixinConfig: LightningIRConfig mixin class
        :type MixinConfig: Type[LightningIRConfig]
        """
        is_derived = getattr(MixinConfig, "backbone_model_type", None) is not None
        self.MixinConfig = MixinConfig.__bases__[0] if is_derived else MixinConfig

    @staticmethod
    def get_backbone_config(model_name_or_path: str | Path) -> Type[PretrainedConfig]:
        """Looks up the configuration class of the backbone model of a pretrained checkpoint.

        :param model_name_or_path: Path to the model or its name
        :type model_name_or_path: str | Path
        :return: Configuration class of the backbone model (a class, not an instance)
        :rtype: Type[PretrainedConfig]
        """
        return CONFIG_MAPPING[LightningIRClassFactory.get_backbone_model_type(model_name_or_path)]

    @staticmethod
    def get_lightning_ir_config(model_name_or_path: str | Path) -> Type[LightningIRConfig] | None:
        """Looks up the Lightning IR configuration class of a pretrained Lightning IR checkpoint.

        :param model_name_or_path: Path to the model or its name
        :type model_name_or_path: str | Path
        :return: Configuration class of the Lightning IR model, or None if the checkpoint is not a Lightning IR model
        :rtype: Type[LightningIRConfig] | None
        """
        lir_model_type = LightningIRClassFactory.get_lightning_ir_model_type(model_name_or_path)
        return None if lir_model_type is None else CONFIG_MAPPING[lir_model_type]

    @staticmethod
    def get_backbone_model_type(model_name_or_path: str | Path, *args, **kwargs) -> str:
        """Reads the backbone model type from the configuration of a pretrained checkpoint.

        A truthy ``backbone_model_type`` entry takes precedence; otherwise the ``model_type`` entry is returned
        (which is None if that key is missing as well). Extra arguments are forwarded to
        ``PretrainedConfig.get_config_dict``.

        :param model_name_or_path: Path to the model or its name
        :type model_name_or_path: str | Path
        :return: Model type of the backbone model
        :rtype: str
        """
        config_dict = PretrainedConfig.get_config_dict(model_name_or_path, *args, **kwargs)[0]
        return config_dict.get("backbone_model_type") or config_dict.get("model_type")

    @staticmethod
    def get_lightning_ir_model_type(model_name_or_path: str | Path) -> str | None:
        """Reads the Lightning IR model type from the configuration of a pretrained checkpoint.

        Only configurations that contain a ``backbone_model_type`` key are treated as Lightning IR models.

        :param model_name_or_path: Path to the model or its name
        :type model_name_or_path: str | Path
        :return: Model type of the Lightning IR model, or None if the configuration has no ``backbone_model_type``
        :rtype: str | None
        """
        config_dict = PretrainedConfig.get_config_dict(model_name_or_path)[0]
        return config_dict.get("model_type") if "backbone_model_type" in config_dict else None

    @property
    def cc_lir_model_type(self) -> str:
        """Camel case model type of the LightningIR model, e.g. ``"cross-encoder"`` becomes ``"CrossEncoder"``."""
        return "".join(map(str.title, self.MixinConfig.model_type.split("-")))

    @abstractmethod
    def from_pretrained(self, model_name_or_path: str | Path, *args, **kwargs) -> Any:
        """Loads a derived LightningIR class from a pretrained HuggingFace model. Must be implemented by subclasses.

        :param model_name_or_path: Path to the model or its name
        :type model_name_or_path: str | Path
        :return: Derived LightningIR class
        :rtype: Any
        """
        ...

    @abstractmethod
    def from_backbone_class(self, BackboneClass: Type) -> Type:
        """Creates a derived LightningIR class from a backbone HuggingFace class. Must be implemented by subclasses.

        :param BackboneClass: Backbone class
        :type BackboneClass: Type
        :return: Derived LightningIR class
        :rtype: Type
        """
        ...
116
117
[docs]
class LightningIRConfigClassFactory(LightningIRClassFactory):
    """Class factory for creating derived LightningIRConfig classes from HuggingFace configuration classes."""

    def from_pretrained(self, model_name_or_path: str | Path, *args, **kwargs) -> Type[LightningIRConfig]:
        """Loads a derived LightningIRConfig from a pretrained HuggingFace model.

        The backbone configuration class is resolved from the checkpoint and passed to ``from_backbone_class``.
        Additional positional and keyword arguments are accepted for interface compatibility but are not used.

        :param model_name_or_path: Path to the model or its name
        :type model_name_or_path: str | Path
        :return: Derived LightningIRConfig
        :rtype: Type[LightningIRConfig]
        """
        BackboneConfig = self.get_backbone_config(model_name_or_path)
        DerivedLightningIRConfig = self.from_backbone_class(BackboneConfig)
        return DerivedLightningIRConfig

    def from_backbone_class(self, BackboneClass: Type[PretrainedConfig]) -> Type[LightningIRConfig]:
        """Creates a derived LightningIRConfig from a transformers.PretrainedConfig_ backbone configuration class. If
        the backbone configuration class is already a derived LightningIRConfig, it is returned as is.

        The derived class inherits from the LightningIRConfig mixin registered under the mixin's model type and from
        the backbone configuration class. Its model type is ``"<backbone model type>-<mixin model type>"`` and it is
        registered with ``AutoConfig``. If a class with exactly these two bases is already registered under that
        model type, it is reused, so repeated calls return the same class object.

        .. _transformers.PretrainedConfig: \
https://huggingface.co/docs/transformers/main_classes/configuration#transformers.PretrainedConfig

        :param BackboneClass: Backbone configuration class
        :type BackboneClass: Type[PretrainedConfig]
        :return: Derived LightningIRConfig
        :rtype: Type[LightningIRConfig]
        """
        if getattr(BackboneClass, "backbone_model_type", None) is not None:
            return BackboneClass
        LightningIRConfigMixin: Type[LightningIRConfig] = CONFIG_MAPPING[self.MixinConfig.model_type]
        derived_model_type = f"{BackboneClass.model_type}-{self.MixinConfig.model_type}"

        # Reuse a class derived earlier from the same (mixin, backbone) pair. Building a fresh class on every call
        # re-registers a new object under the same model type (exist_ok=True below), so registrations made with the
        # earlier, identically named class object would no longer match the class now registered for this type.
        if derived_model_type in CONFIG_MAPPING:
            RegisteredConfig = CONFIG_MAPPING[derived_model_type]
            if getattr(RegisteredConfig, "__bases__", None) == (LightningIRConfigMixin, BackboneClass):
                return RegisteredConfig

        DerivedLightningIRConfig = type(
            f"{self.cc_lir_model_type}{BackboneClass.__name__}",
            (LightningIRConfigMixin, BackboneClass),
            {
                "model_type": derived_model_type,
                "backbone_model_type": BackboneClass.model_type,
            },
        )

        AutoConfig.register(DerivedLightningIRConfig.model_type, DerivedLightningIRConfig, exist_ok=True)

        return DerivedLightningIRConfig
161
162
[docs]
class LightningIRModelClassFactory(LightningIRClassFactory):
    """Class factory that derives LightningIRModel classes from Hugging Face model classes."""

    def from_pretrained(self, model_name_or_path: str | Path, *args, **kwargs) -> Type[LightningIRModel]:
        """Loads a derived LightningIRModel from a pretrained HuggingFace model.

        The backbone configuration class is resolved from the checkpoint, mapped to its backbone model class via
        ``MODEL_MAPPING`` and passed to ``from_backbone_class``. Additional positional and keyword arguments are
        accepted for interface compatibility but are not used.

        :param model_name_or_path: Path to the model or its name
        :type model_name_or_path: str | Path
        :return: Derived LightningIRModel
        :rtype: Type[LightningIRModel]
        """
        return self.from_backbone_class(MODEL_MAPPING[self.get_backbone_config(model_name_or_path)])

    def from_backbone_class(self, BackboneClass: Type[PreTrainedModel]) -> Type[LightningIRModel]:
        """Creates a derived LightningIRModel from a transformers.PreTrainedModel_ backbone model. If the backbone model
        is already a LightningIRModel, it is returned as is.

        The derived class inherits from the LightningIRModel mixin registered for the mixin configuration and from
        the backbone model class. It uses the matching derived LightningIRConfig as ``config_class``, keeps the
        backbone's ``forward`` as ``_backbone_forward`` and is registered with ``AutoModel``.

        .. _transformers.PreTrainedModel: \
https://huggingface.co/transformers/main_classes/model#transformers.PreTrainedModel

        :param BackboneClass: Backbone model
        :type BackboneClass: Type[PreTrainedModel]
        :raises ValueError: If the backbone model has no ``config_class``.
        :raises KeyError: If no LightningIRModel mixin is registered in ``MODEL_MAPPING`` for the mixin configuration
            (assumed behaviour of the Hugging Face mapping lookup).
        :return: The derived LightningIRModel
        :rtype: Type[LightningIRModel]
        """
        BackboneConfig = BackboneClass.config_class
        if getattr(BackboneConfig, "backbone_model_type", None) is not None:
            # The backbone is already a derived Lightning IR model.
            return BackboneClass
        if BackboneConfig is None:
            raise ValueError(
                f"Model {BackboneClass} is not a valid backbone model because it is missing a `config_class`."
            )

        ModelMixin: Type[LightningIRModel] = MODEL_MAPPING[self.MixinConfig]
        DerivedConfig = LightningIRConfigClassFactory(self.MixinConfig).from_backbone_class(BackboneConfig)

        class_attributes = {"config_class": DerivedConfig, "_backbone_forward": BackboneClass.forward}
        DerivedModel = type(
            f"{self.cc_lir_model_type}{BackboneClass.__name__}", (ModelMixin, BackboneClass), class_attributes
        )

        AutoModel.register(DerivedConfig, DerivedModel, exist_ok=True)
        return DerivedModel
215
216
[docs]
class LightningIRTokenizerClassFactory(LightningIRClassFactory):
    """Class factory for creating derived LightningIRTokenizer classes from HuggingFace tokenizer classes."""

    @staticmethod
    def get_backbone_config(model_name_or_path: str | Path) -> Type[PretrainedConfig]:
        """Grabs the backbone configuration class from a checkpoint of a pretrained HuggingFace tokenizer.

        :param model_name_or_path: Path to the tokenizer or its name
        :type model_name_or_path: str | Path
        :return: Configuration class of the backbone tokenizer (a class, not an instance)
        :rtype: Type[PretrainedConfig]
        """
        backbone_model_type = LightningIRTokenizerClassFactory.get_backbone_model_type(model_name_or_path)
        return CONFIG_MAPPING[backbone_model_type]

    @staticmethod
    def get_backbone_model_type(model_name_or_path: str | Path, *args, **kwargs) -> str:
        """Grabs the model type from a checkpoint of a pretrained HuggingFace tokenizer.

        The model configuration is tried first. If loading it raises an ``OSError`` (e.g. a tokenizer-only
        checkpoint), the model type is guessed from the ``tokenizer_class`` named in the tokenizer configuration by
        searching ``TOKENIZER_MAPPING`` for an entry that contains that tokenizer class.

        :param model_name_or_path: Path to the tokenizer or its name
        :type model_name_or_path: str | Path
        :raises ValueError: If the model configuration cannot be loaded and the tokenizer configuration names no
            tokenizer class that can be resolved and found in ``TOKENIZER_MAPPING``
        :return: Model type of the backbone tokenizer
        :rtype: str
        """
        try:
            return LightningIRClassFactory.get_backbone_model_type(model_name_or_path, *args, **kwargs)
        except OSError as err:
            # best guess at model type
            tokenizer_class_name = get_tokenizer_config(model_name_or_path).get("tokenizer_class")
            Tokenizer = None if tokenizer_class_name is None else tokenizer_class_from_name(tokenizer_class_name)
            # An unresolved tokenizer (None) must not be searched for: mapping entries use None for a missing slow or
            # fast tokenizer, so `None in tokenizers` would match an unrelated backbone.
            if Tokenizer is not None:
                for config, tokenizers in TOKENIZER_MAPPING.items():
                    if Tokenizer in tokenizers:
                        return getattr(config, "backbone_model_type", None) or getattr(config, "model_type")
            raise ValueError("No backbone model found in the configuration") from err

    def from_pretrained(
        self, model_name_or_path: str | Path, *args, use_fast: bool = True, **kwargs
    ) -> Type[LightningIRTokenizer]:
        """Loads a derived LightningIRTokenizer from a pretrained HuggingFace tokenizer.

        Additional positional and keyword arguments are accepted for interface compatibility but are not used.

        :param model_name_or_path: Path to the tokenizer or its name
        :type model_name_or_path: str | Path
        :param use_fast: Whether to use the fast or slow tokenizer, defaults to True
        :type use_fast: bool, optional
        :raises ValueError: If use_fast is True and no fast tokenizer is found
        :raises ValueError: If use_fast is False and no slow tokenizer is found
        :return: Derived LightningIRTokenizer
        :rtype: Type[LightningIRTokenizer]
        """
        BackboneConfig = self.get_backbone_config(model_name_or_path)
        BackboneTokenizers = TOKENIZER_MAPPING[BackboneConfig]
        DerivedLightningIRTokenizers = self.from_backbone_classes(BackboneTokenizers, BackboneConfig)
        if use_fast:
            DerivedLightningIRTokenizer = DerivedLightningIRTokenizers[1]
            if DerivedLightningIRTokenizer is None:
                raise ValueError("No fast tokenizer found.")
        else:
            DerivedLightningIRTokenizer = DerivedLightningIRTokenizers[0]
            if DerivedLightningIRTokenizer is None:
                raise ValueError("No slow tokenizer found.")
        return DerivedLightningIRTokenizer

    def from_backbone_classes(
        self,
        BackboneClasses: Tuple[Type[PreTrainedTokenizerBase] | None, Type[PreTrainedTokenizerBase] | None],
        BackboneConfig: Type[PretrainedConfig] | None = None,
    ) -> Tuple[Type[LightningIRTokenizer] | None, Type[LightningIRTokenizer] | None]:
        """Creates derived slow and fast LightningIRTokenizers from a tuple of backbone HuggingFace tokenizer classes.

        The derived fast tokenizer's ``slow_tokenizer_class`` is set to the derived slow tokenizer. When a backbone
        configuration class is given, the derived tokenizers are registered with ``AutoTokenizer`` under the
        corresponding derived LightningIRConfig. Without one, nothing is registered and the classes are only returned.

        :param BackboneClasses: Slow and fast backbone tokenizer classes
        :type BackboneClasses: Tuple[Type[PreTrainedTokenizerBase] | None, Type[PreTrainedTokenizerBase] | None]
        :param BackboneConfig: Backbone configuration class, defaults to None
        :type BackboneConfig: Type[PretrainedConfig], optional
        :return: Slow and fast derived LightningIRTokenizers
        :rtype: Tuple[Type[LightningIRTokenizer] | None, Type[LightningIRTokenizer] | None]
        """
        DerivedLightningIRTokenizers = tuple(
            None if BackboneClass is None else self.from_backbone_class(BackboneClass)
            for BackboneClass in BackboneClasses
        )
        if DerivedLightningIRTokenizers[1] is not None:
            DerivedLightningIRTokenizers[1].slow_tokenizer_class = DerivedLightningIRTokenizers[0]

        # Registration is keyed by a configuration class; deriving one from None would fail, so skip it.
        if BackboneConfig is not None:
            DerivedLightningIRConfig = LightningIRConfigClassFactory(self.MixinConfig).from_backbone_class(
                BackboneConfig
            )
            AutoTokenizer.register(
                DerivedLightningIRConfig, DerivedLightningIRTokenizers[0], DerivedLightningIRTokenizers[1]
            )

        return DerivedLightningIRTokenizers

    def from_backbone_class(self, BackboneClass: Type[PreTrainedTokenizerBase]) -> Type[LightningIRTokenizer]:
        """Creates a derived LightningIRTokenizer from a transformers.PreTrainedTokenizerBase_ backbone tokenizer. If
        the backbone tokenizer is already a LightningIRTokenizer, it is returned as is.

        A tokenizer class that has a ``config_class`` attribute is treated as already derived.

        .. _transformers.PreTrainedTokenizerBase: \
https://huggingface.co/transformers/main_classes/tokenizer.html#transformers.PreTrainedTokenizerBase

        :param BackboneClass: Backbone tokenizer class
        :type BackboneClass: Type[PreTrainedTokenizerBase]
        :return: Derived LightningIRTokenizer
        :rtype: Type[LightningIRTokenizer]
        """
        if hasattr(BackboneClass, "config_class"):
            return BackboneClass
        # The mixin tokenizer is stored in the first (slow) slot of the mixin's mapping entry.
        LightningIRTokenizerMixin = TOKENIZER_MAPPING[self.MixinConfig][0]

        DerivedLightningIRTokenizer = type(
            f"{self.cc_lir_model_type}{BackboneClass.__name__}", (LightningIRTokenizerMixin, BackboneClass), {}
        )

        return DerivedLightningIRTokenizer
308 """Creates a derived LightningIRTokenizer from a transformers.PreTrainedTokenizerBase_ backbone tokenizer. If
309 the backbone tokenizer is already a LightningIRTokenizer, it is returned as is.
310
311 .. _transformers.PreTrainedTokenizerBase: \
312https://huggingface.co/transformers/main_classes/tokenizer.html#transformers.PreTrainedTokenizerBase
313
314 :param BackboneClass: Backbone tokenizer class
315 :type BackboneClass: Type[PreTrainedTokenizerBase]
316 :return: Derived LightningIRTokenizer
317 :rtype: Type[LightningIRTokenizer]
318 """
319 if hasattr(BackboneClass, "config_class"):
320 return BackboneClass
321 LightningIRTokenizerMixin = TOKENIZER_MAPPING[self.MixinConfig][0]
322
323 DerivedLightningIRTokenizer = type(
324 f"{self.cc_lir_model_type}{BackboneClass.__name__}", (LightningIRTokenizerMixin, BackboneClass), {}
325 )
326
327 return DerivedLightningIRTokenizer