Source code for lightning_ir.base.class_factory

  1from __future__ import annotations
  2
  3from abc import ABC, abstractmethod
  4from pathlib import Path
  5from typing import TYPE_CHECKING, Any, Tuple, Type
  6
  7from transformers import (
  8    CONFIG_MAPPING,
  9    MODEL_MAPPING,
 10    TOKENIZER_MAPPING,
 11    AutoConfig,
 12    AutoModel,
 13    AutoTokenizer,
 14    PretrainedConfig,
 15    PreTrainedModel,
 16    PreTrainedTokenizerBase,
 17)
 18from transformers.models.auto.tokenization_auto import get_tokenizer_config, tokenizer_class_from_name
 19
 20if TYPE_CHECKING:
 21    from . import LightningIRConfig, LightningIRModel, LightningIRTokenizer
 22
 23
class LightningIRClassFactory(ABC):
    """Abstract factory that derives LightningIR classes from HuggingFace classes."""

    def __init__(self, MixinConfig: Type[LightningIRConfig]) -> None:
        """Initializes the factory with a LightningIRConfig mixin class.

        If the passed class has a ``backbone_model_type`` attribute that is not None, its first base class is
        stored instead of the class itself.

        :param MixinConfig: LightningIRConfig mixin class
        :type MixinConfig: Type[LightningIRConfig]
        """
        already_derived = getattr(MixinConfig, "backbone_model_type", None) is not None
        self.MixinConfig = MixinConfig.__bases__[0] if already_derived else MixinConfig

    @staticmethod
    def get_backbone_config(model_name_or_path: str | Path) -> Type[PretrainedConfig]:
        """Looks up the backbone configuration class for a pretrained HuggingFace checkpoint.

        :param model_name_or_path: Path to the model or its name
        :type model_name_or_path: str | Path
        :return: Configuration class registered in ``CONFIG_MAPPING`` for the backbone model type
        :rtype: Type[PretrainedConfig]
        """
        return CONFIG_MAPPING[LightningIRClassFactory.get_backbone_model_type(model_name_or_path)]

    @staticmethod
    def get_lightning_ir_config(model_name_or_path: str | Path) -> Type[LightningIRConfig] | None:
        """Looks up the LightningIR configuration class for a pretrained Lightning IR checkpoint.

        :param model_name_or_path: Path to the model or its name
        :type model_name_or_path: str | Path
        :return: Configuration class registered in ``CONFIG_MAPPING`` for the Lightning IR model type, or None if
            the checkpoint has no Lightning IR model type
        :rtype: Type[LightningIRConfig] | None
        """
        lir_model_type = LightningIRClassFactory.get_lightning_ir_model_type(model_name_or_path)
        return None if lir_model_type is None else CONFIG_MAPPING[lir_model_type]

    @staticmethod
    def get_backbone_model_type(model_name_or_path: str | Path, *args, **kwargs) -> str:
        """Reads the backbone model type from the configuration of a pretrained HuggingFace checkpoint.

        A truthy ``backbone_model_type`` entry takes precedence over the ``model_type`` entry. Additional positional
        and keyword arguments are forwarded to ``PretrainedConfig.get_config_dict``.

        :param model_name_or_path: Path to the model or its name
        :type model_name_or_path: str | Path
        :return: Model type of the backbone model
        :rtype: str
        """
        config_dict, _unused_kwargs = PretrainedConfig.get_config_dict(model_name_or_path, *args, **kwargs)
        explicit_backbone_type = config_dict.get("backbone_model_type", None)
        if explicit_backbone_type:
            return explicit_backbone_type
        return config_dict.get("model_type")

    @staticmethod
    def get_lightning_ir_model_type(model_name_or_path: str | Path) -> str | None:
        """Reads the Lightning IR model type from the configuration of a pretrained checkpoint.

        :param model_name_or_path: Path to the model or its name
        :type model_name_or_path: str | Path
        :return: The ``model_type`` entry if the configuration contains a ``backbone_model_type`` key, else None
        :rtype: str | None
        """
        config_dict, _unused_kwargs = PretrainedConfig.get_config_dict(model_name_or_path)
        if "backbone_model_type" in config_dict:
            return config_dict.get("model_type", None)
        return None

    @property
    def cc_lir_model_type(self) -> str:
        """Camel case model type of the LightningIR model (hyphen-separated parts title-cased and joined)."""
        hyphenated_parts = self.MixinConfig.model_type.split("-")
        return "".join(map(str.title, hyphenated_parts))

    @abstractmethod
    def from_pretrained(self, model_name_or_path: str | Path, *args, **kwargs) -> Any:
        """Loads a derived LightningIR class from a pretrained HuggingFace model. Subclasses must implement this.

        :param model_name_or_path: Path to the model or its name
        :type model_name_or_path: str | Path
        :return: Derived LightningIR class
        :rtype: Any
        """
        ...

    @abstractmethod
    def from_backbone_class(self, BackboneClass: Type) -> Type:
        """Creates a derived LightningIR class from a backbone HuggingFace class. Subclasses must implement this.

        :param BackboneClass: Backbone class
        :type BackboneClass: Type
        :return: Derived LightningIR class
        :rtype: Type
        """
        ...
116 117
class LightningIRConfigClassFactory(LightningIRClassFactory):
    """Factory that derives LightningIRConfig classes from HuggingFace configuration classes."""

    def from_pretrained(self, model_name_or_path: str | Path, *args, **kwargs) -> Type[LightningIRConfig]:
        """Derives a LightningIRConfig class from the checkpoint of a pretrained HuggingFace model.

        :param model_name_or_path: Path to the model or its name
        :type model_name_or_path: str | Path
        :return: Derived LightningIRConfig
        :rtype: Type[LightningIRConfig]
        """
        return self.from_backbone_class(self.get_backbone_config(model_name_or_path))

    def from_backbone_class(self, BackboneClass: Type[PretrainedConfig]) -> Type[LightningIRConfig]:
        """Derives a LightningIRConfig class from a transformers.PretrainedConfig_ backbone configuration class. A
        backbone configuration class that already has a non-None ``backbone_model_type`` is returned as is.

        The derived class is registered with ``AutoConfig`` under its ``model_type``.

        .. _transformers.PretrainedConfig: \
https://huggingface.co/docs/transformers/main_classes/configuration#transformers.PretrainedConfig

        :param BackboneClass: Backbone configuration class
        :type BackboneClass: Type[PretrainedConfig]
        :return: Derived LightningIRConfig
        :rtype: Type[LightningIRConfig]
        """
        # Already derived: nothing to build.
        if getattr(BackboneClass, "backbone_model_type", None) is not None:
            return BackboneClass

        LightningIRConfigMixin: Type[LightningIRConfig] = CONFIG_MAPPING[self.MixinConfig.model_type]

        derived_name = f"{self.cc_lir_model_type}{BackboneClass.__name__}"
        derived_bases = (LightningIRConfigMixin, BackboneClass)
        derived_attributes = {
            "model_type": f"{BackboneClass.model_type}-{self.MixinConfig.model_type}",
            "backbone_model_type": BackboneClass.model_type,
        }
        DerivedLightningIRConfig = type(derived_name, derived_bases, derived_attributes)

        AutoConfig.register(DerivedLightningIRConfig.model_type, DerivedLightningIRConfig, exist_ok=True)
        return DerivedLightningIRConfig
161 162
class LightningIRModelClassFactory(LightningIRClassFactory):
    """Factory that derives LightningIRModel classes from HuggingFace model classes."""

    def from_pretrained(self, model_name_or_path: str | Path, *args, **kwargs) -> Type[LightningIRModel]:
        """Derives a LightningIRModel class from the checkpoint of a pretrained HuggingFace model.

        :param model_name_or_path: Path to the model or its name
        :type model_name_or_path: str | Path
        :return: Derived LightningIRModel
        :rtype: Type[LightningIRModel]
        """
        BackboneModel = MODEL_MAPPING[self.get_backbone_config(model_name_or_path)]
        return self.from_backbone_class(BackboneModel)

    def from_backbone_class(self, BackboneClass: Type[PreTrainedModel]) -> Type[LightningIRModel]:
        """Derives a LightningIRModel class from a transformers.PreTrainedModel_ backbone model class. A backbone
        model whose ``config_class`` already has a non-None ``backbone_model_type`` is returned as is.

        The derived class is registered with ``AutoModel`` for the derived configuration class.

        .. _transformers.PreTrainedModel: \
https://huggingface.co/transformers/main_classes/model#transformers.PreTrainedModel

        :param BackboneClass: Backbone model
        :type BackboneClass: Type[PreTrainedModel]
        :raises ValueError: If the backbone model has no ``config_class`` (it is None).
        :return: The derived LightningIRModel
        :rtype: Type[LightningIRModel]
        """
        BackboneConfig = BackboneClass.config_class
        # Already derived: nothing to build.
        if getattr(BackboneConfig, "backbone_model_type", None) is not None:
            return BackboneClass
        if BackboneConfig is None:
            raise ValueError(
                f"Model {BackboneClass} is not a valid backbone model because it is missing a `config_class`."
            )

        LightningIRModelMixin: Type[LightningIRModel] = MODEL_MAPPING[self.MixinConfig]
        config_factory = LightningIRConfigClassFactory(self.MixinConfig)
        DerivedLightningIRConfig = config_factory.from_backbone_class(BackboneConfig)

        derived_name = f"{self.cc_lir_model_type}{BackboneClass.__name__}"
        derived_attributes = {"config_class": DerivedLightningIRConfig, "_backbone_forward": BackboneClass.forward}
        DerivedLightningIRModel = type(derived_name, (LightningIRModelMixin, BackboneClass), derived_attributes)

        AutoModel.register(DerivedLightningIRConfig, DerivedLightningIRModel, exist_ok=True)
        return DerivedLightningIRModel
215 216
class LightningIRTokenizerClassFactory(LightningIRClassFactory):
    """Class factory for creating derived LightningIRTokenizer classes from HuggingFace tokenizer classes."""

    @staticmethod
    def get_backbone_config(model_name_or_path: str | Path) -> Type[PretrainedConfig]:
        """Grabs the configuration class from a checkpoint of a pretrained HuggingFace tokenizer.

        :param model_name_or_path: Path to the tokenizer or its name
        :type model_name_or_path: str | Path
        :return: Configuration class of the backbone tokenizer
        :rtype: Type[PretrainedConfig]
        """
        backbone_model_type = LightningIRTokenizerClassFactory.get_backbone_model_type(model_name_or_path)
        return CONFIG_MAPPING[backbone_model_type]

    @staticmethod
    def get_backbone_model_type(model_name_or_path: str | Path, *args, **kwargs) -> str:
        """Grabs the model type from a checkpoint of a pretrained HuggingFace tokenizer.

        The model configuration is tried first. If reading it raises an ``OSError``, the model type is guessed from
        the ``tokenizer_class`` entry of the tokenizer configuration by searching ``TOKENIZER_MAPPING`` for a
        configuration that lists that tokenizer class.

        :param model_name_or_path: Path to the tokenizer or its name
        :type model_name_or_path: str | Path
        :raises ValueError: If the model type can be determined neither from the model configuration nor from the
            tokenizer configuration
        :return: Model type of the backbone tokenizer
        :rtype: str
        """
        try:
            return LightningIRClassFactory.get_backbone_model_type(model_name_or_path, *args, **kwargs)
        except OSError as err:
            # best guess at model type
            config_dict = get_tokenizer_config(model_name_or_path)
            tokenizer_class_name = config_dict.get("tokenizer_class")
            Tokenizer = None if tokenizer_class_name is None else tokenizer_class_from_name(tokenizer_class_name)
            # An unresolved tokenizer class (None) must not be searched for: mapping entries hold None for a
            # missing slow or fast tokenizer, so `None in tokenizers` would match an unrelated configuration.
            if Tokenizer is not None:
                for config, tokenizers in TOKENIZER_MAPPING.items():
                    if Tokenizer in tokenizers:
                        return getattr(config, "backbone_model_type", None) or getattr(config, "model_type")
            raise ValueError("No backbone model found in the configuration") from err

    def from_pretrained(
        self, model_name_or_path: str | Path, *args, use_fast: bool = True, **kwargs
    ) -> Type[LightningIRTokenizer]:
        """Loads a derived LightningIRTokenizer from a pretrained HuggingFace tokenizer.

        :param model_name_or_path: Path to the tokenizer or its name
        :type model_name_or_path: str | Path
        :param use_fast: Whether to use the fast or slow tokenizer, defaults to True
        :type use_fast: bool, optional
        :raises ValueError: If use_fast is True and no fast tokenizer is found
        :raises ValueError: If use_fast is False and no slow tokenizer is found
        :return: Derived LightningIRTokenizer
        :rtype: Type[LightningIRTokenizer]
        """
        BackboneConfig = self.get_backbone_config(model_name_or_path)
        BackboneTokenizers = TOKENIZER_MAPPING[BackboneConfig]
        DerivedLightningIRTokenizers = self.from_backbone_classes(BackboneTokenizers, BackboneConfig)
        # Index 0 holds the slow tokenizer class, index 1 the fast one.
        if use_fast:
            DerivedLightningIRTokenizer = DerivedLightningIRTokenizers[1]
            if DerivedLightningIRTokenizer is None:
                raise ValueError("No fast tokenizer found.")
        else:
            DerivedLightningIRTokenizer = DerivedLightningIRTokenizers[0]
            if DerivedLightningIRTokenizer is None:
                raise ValueError("No slow tokenizer found.")
        return DerivedLightningIRTokenizer

    def from_backbone_classes(
        self,
        BackboneClasses: Tuple[Type[PreTrainedTokenizerBase] | None, Type[PreTrainedTokenizerBase] | None],
        BackboneConfig: Type[PretrainedConfig] | None = None,
    ) -> Tuple[Type[LightningIRTokenizer] | None, Type[LightningIRTokenizer] | None]:
        """Creates derived slow and fast LightningIRTokenizers from a tuple of backbone HuggingFace tokenizer classes.

        If a backbone configuration class is given, the derived tokenizers are also registered with
        ``AutoTokenizer`` for the corresponding derived LightningIRConfig. Without one, registration is skipped.

        :param BackboneClasses: Slow and fast backbone tokenizer classes
        :type BackboneClasses: Tuple[Type[PreTrainedTokenizerBase] | None, Type[PreTrainedTokenizerBase] | None]
        :param BackboneConfig: Backbone configuration class, defaults to None
        :type BackboneConfig: Type[PretrainedConfig], optional
        :return: Slow and fast derived LightningIRTokenizers
        :rtype: Tuple[Type[LightningIRTokenizer] | None, Type[LightningIRTokenizer] | None]
        """
        DerivedLightningIRTokenizers = tuple(
            None if BackboneClass is None else self.from_backbone_class(BackboneClass)
            for BackboneClass in BackboneClasses
        )
        if DerivedLightningIRTokenizers[1] is not None:
            DerivedLightningIRTokenizers[1].slow_tokenizer_class = DerivedLightningIRTokenizers[0]

        # Deriving a config from None would fail on `None.__name__`; with no backbone config there is no key to
        # register the tokenizers under, so only the derived classes are returned.
        if BackboneConfig is not None:
            DerivedLightningIRConfig = LightningIRConfigClassFactory(self.MixinConfig).from_backbone_class(
                BackboneConfig
            )
            # exist_ok=True matches the AutoConfig / AutoModel registrations done by the sibling factories.
            AutoTokenizer.register(
                DerivedLightningIRConfig,
                DerivedLightningIRTokenizers[0],
                DerivedLightningIRTokenizers[1],
                exist_ok=True,
            )

        return DerivedLightningIRTokenizers

    def from_backbone_class(self, BackboneClass: Type[PreTrainedTokenizerBase]) -> Type[LightningIRTokenizer]:
        """Creates a derived LightningIRTokenizer from a transformers.PreTrainedTokenizerBase_ backbone tokenizer. If
        the backbone tokenizer class already has a ``config_class`` attribute, it is returned as is.

        .. _transformers.PreTrainedTokenizerBase: \
https://huggingface.co/transformers/main_classes/tokenizer.html#transformers.PreTrainedTokenizerBase

        :param BackboneClass: Backbone tokenizer class
        :type BackboneClass: Type[PreTrainedTokenizerBase]
        :return: Derived LightningIRTokenizer
        :rtype: Type[LightningIRTokenizer]
        """
        if hasattr(BackboneClass, "config_class"):
            return BackboneClass
        # The first entry registered for the mixin config is used as the tokenizer mixin.
        LightningIRTokenizerMixin = TOKENIZER_MAPPING[self.MixinConfig][0]

        DerivedLightningIRTokenizer = type(
            f"{self.cc_lir_model_type}{BackboneClass.__name__}", (LightningIRTokenizerMixin, BackboneClass), {}
        )

        return DerivedLightningIRTokenizer