Source code for lightning_ir.base.config

from pathlib import Path
from typing import Any, Dict, Set

from transformers import CONFIG_MAPPING

from .class_factory import LightningIRConfigClassFactory
from .external_model_hub import CHECKPOINT_MAPPING


class LightningIRConfig:
    """The configuration class to instantiate a LightningIR model. Acts as a mixin for the
    transformers.PretrainedConfig_ class.

    .. _transformers.PretrainedConfig: \
https://huggingface.co/transformers/main_classes/configuration.html#transformers.PretrainedConfig
    """

    model_type = "lightning-ir"
    """Model type for the configuration."""
    backbone_model_type: str | None = None
    """Backbone model type for the configuration. Set by :func:`LightningIRModelClassFactory`."""

    TOKENIZER_ARGS: Set[str] = {"query_length", "doc_length"}
    """Arguments for the tokenizer."""
    ADDED_ARGS: Set[str] = TOKENIZER_ARGS
    """Arguments added to the configuration."""

    def __init__(self, *args, query_length: int = 32, doc_length: int = 512, **kwargs):
        """Initializes the configuration.

        :param query_length: Maximum query length, defaults to 32
        :type query_length: int, optional
        :param doc_length: Maximum document length, defaults to 512
        :type doc_length: int, optional
        """
        super().__init__(*args, **kwargs)
        self.query_length = query_length
        self.doc_length = doc_length

    def to_added_args_dict(self) -> Dict[str, Any]:
        """Outputs a dictionary of the added arguments.

        :return: Added arguments
        :rtype: Dict[str, Any]
        """
        return {arg: getattr(self, arg) for arg in self.ADDED_ARGS if hasattr(self, arg)}

    def to_tokenizer_dict(self) -> Dict[str, Any]:
        """Outputs a dictionary of the tokenizer arguments.

        :return: Tokenizer arguments
        :rtype: Dict[str, Any]
        """
        return {arg: getattr(self, arg) for arg in self.TOKENIZER_ARGS}

    def to_dict(self) -> Dict[str, Any]:
        """Overrides the transformers.PretrainedConfig.to_dict_ method to include the added arguments and the
        backbone model type.

        .. _transformers.PretrainedConfig.to_dict: \
https://huggingface.co/docs/transformers/main_classes/configuration.html#transformers.PretrainedConfig.to_dict

        :return: Configuration dictionary
        :rtype: Dict[str, Any]
        """
        if hasattr(super(), "to_dict"):
            output = getattr(super(), "to_dict")()
        else:
            output = self.to_added_args_dict()
        if self.backbone_model_type is not None:
            output["backbone_model_type"] = self.backbone_model_type
        return output

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str | Path, *args, **kwargs) -> "LightningIRConfig":
        """Loads the configuration from a pretrained model. Wraps the transformers.PretrainedConfig.from_pretrained_
        method to return a derived LightningIRConfig class. See :class:`.LightningIRConfigClassFactory` for more
        details.

        .. _transformers.PretrainedConfig.from_pretrained: \
https://huggingface.co/docs/transformers/main_classes/configuration.html#transformers.PretrainedConfig.from_pretrained

        :param pretrained_model_name_or_path: Pretrained model name or path
        :type pretrained_model_name_or_path: str | Path
        :raises ValueError: If no LightningIR configuration class can be resolved for the checkpoint
        :return: Derived LightningIRConfig class
        :rtype: LightningIRConfig
        """
        # Only resolve a derived class when called on LightningIRConfig itself or on a non-derived
        # subclass; factory-derived classes fall through to transformers.PretrainedConfig.from_pretrained.
        if cls is LightningIRConfig or all(issubclass(base, LightningIRConfig) for base in cls.__bases__):
            config = None
            if pretrained_model_name_or_path in CHECKPOINT_MAPPING:
                config = CHECKPOINT_MAPPING[pretrained_model_name_or_path]
                config_class = config.__class__
            elif cls is not LightningIRConfig:
                config_class = cls
            else:
                config_class = LightningIRConfigClassFactory.get_lightning_ir_config(pretrained_model_name_or_path)
                if config_class is None:
                    raise ValueError("Pass a config to `from_pretrained`.")
            BackboneConfig = LightningIRConfigClassFactory.get_backbone_config(pretrained_model_name_or_path)
            cls = LightningIRConfigClassFactory(config_class).from_backbone_class(BackboneConfig)
            if config is not None and all(issubclass(base, LightningIRConfig) for base in config.__class__.__bases__):
                derived_config = cls.from_pretrained(pretrained_model_name_or_path, config=config)
                derived_config.update(config.to_dict())
                return derived_config
            return cls.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
        return super(LightningIRConfig, cls).from_pretrained(pretrained_model_name_or_path, *args, **kwargs)

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any], *args, **kwargs) -> "LightningIRConfig":
        """Loads the configuration from a dictionary. Wraps the transformers.PretrainedConfig.from_dict_ method to
        return a derived LightningIRConfig class. See :class:`.LightningIRConfigClassFactory` for more details.

        .. _transformers.PretrainedConfig.from_dict: \
https://huggingface.co/docs/transformers/main_classes/configuration.html#transformers.PretrainedConfig.from_dict

        :param config_dict: Configuration dictionary
        :type config_dict: Dict[str, Any]
        :raises ValueError: If the model type does not match the configuration model type
        :return: Derived LightningIRConfig class
        :rtype: LightningIRConfig
        """
        if all(issubclass(base, LightningIRConfig) for base in cls.__bases__) or cls is LightningIRConfig:
            if "backbone_model_type" in config_dict:
                backbone_model_type = config_dict["backbone_model_type"]
                model_type = config_dict["model_type"]
                if cls is not LightningIRConfig and model_type != cls.model_type:
                    raise ValueError(
                        f"Model type {model_type} does not match configuration model type {cls.model_type}"
                    )
            else:
                backbone_model_type = config_dict["model_type"]
                model_type = cls.model_type
            MixinConfig = CONFIG_MAPPING[model_type]
            BackboneConfig = CONFIG_MAPPING[backbone_model_type]
            cls = LightningIRConfigClassFactory(MixinConfig).from_backbone_class(BackboneConfig)
            return cls.from_dict(config_dict, *args, **kwargs)
        return super(LightningIRConfig, cls).from_dict(config_dict, *args, **kwargs)
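
A minimal sketch of the added-arguments mechanics, using only the mixin itself (no backbone configuration in the
MRO). The import path follows this module's location; since ``TOKENIZER_ARGS`` is a set, key order in the returned
dictionaries is not guaranteed::

    from lightning_ir.base.config import LightningIRConfig

    config = LightningIRConfig(query_length=24, doc_length=256)

    # TOKENIZER_ARGS drives to_tokenizer_dict; ADDED_ARGS drives to_added_args_dict.
    print(config.to_tokenizer_dict())   # {'query_length': 24, 'doc_length': 256}
    print(config.to_added_args_dict())  # {'query_length': 24, 'doc_length': 256}

    # Without a transformers backbone in the MRO there is no super().to_dict, so
    # to_dict falls back to to_added_args_dict; backbone_model_type is None here,
    # so no extra key is included.
    print(config.to_dict())             # {'query_length': 24, 'doc_length': 256}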
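
A hypothetical subclass illustrates how ``TOKENIZER_ARGS``/``ADDED_ARGS`` are meant to be extended; the class name
``SparseConfig``, its ``model_type``, and the ``pooling`` argument are invented for illustration::

    from typing import Set

    from lightning_ir.base.config import LightningIRConfig


    class SparseConfig(LightningIRConfig):
        model_type = "sparse-example"  # hypothetical model type
        ADDED_ARGS: Set[str] = LightningIRConfig.ADDED_ARGS.union({"pooling"})

        def __init__(self, *args, pooling: str = "max", **kwargs):
            super().__init__(*args, **kwargs)
            self.pooling = pooling


    config = SparseConfig(query_length=16, pooling="sum")
    # to_added_args_dict picks up the new argument alongside the inherited ones.
    print(config.to_added_args_dict())  # {'query_length': 16, 'doc_length': 512, 'pooling': 'sum'}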
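
The class resolution performed by ``from_dict`` can also be sketched end to end. This assumes that importing
``lightning_ir`` registers a LightningIRConfig subclass under the ``"bi-encoder"`` model type in transformers'
``CONFIG_MAPPING`` (the ``"bert"`` backbone type is registered by transformers itself)::

    import lightning_ir  # assumed to register lightning-ir model types

    from lightning_ir.base.config import LightningIRConfig

    config_dict = {
        "model_type": "bi-encoder",     # assumed model type of a LightningIRConfig subclass
        "backbone_model_type": "bert",  # backbone model type from CONFIG_MAPPING
        "query_length": 32,
        "doc_length": 512,
    }

    # from_dict looks up the mixin and backbone classes in CONFIG_MAPPING, derives a
    # combined config class via LightningIRConfigClassFactory, and instantiates it.
    config = LightningIRConfig.from_dict(config_dict)
    print(type(config).__name__, config.query_length)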