Source code for lightning_ir.base.config
from pathlib import Path
from typing import Any, Dict, Set

from transformers import CONFIG_MAPPING

from .class_factory import LightningIRConfigClassFactory
from .external_model_hub import CHECKPOINT_MAPPING


class LightningIRConfig:
    """The configuration class to instantiate a LightningIR model. Acts as a mixin for the
    transformers.PretrainedConfig_ class.

    .. _transformers.PretrainedConfig: \
https://huggingface.co/transformers/main_classes/configuration.html#transformers.PretrainedConfig
    """

    model_type = "lightning-ir"
    """Model type for the configuration."""
    backbone_model_type: str | None = None
    """Backbone model type for the configuration. Set by :func:`LightningIRModelClassFactory`."""

    TOKENIZER_ARGS: Set[str] = {"query_length", "doc_length"}
    """Arguments passed on to the tokenizer."""
    ADDED_ARGS: Set[str] = TOKENIZER_ARGS
    """Arguments added to the configuration."""
    def __init__(self, *args, query_length: int = 32, doc_length: int = 512, **kwargs):
        """Initializes the configuration.

        :param query_length: Maximum query length, defaults to 32
        :type query_length: int, optional
        :param doc_length: Maximum document length, defaults to 512
        :type doc_length: int, optional
        """
        super().__init__(*args, **kwargs)
        self.query_length = query_length
        self.doc_length = doc_length

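    # A minimal usage sketch for the constructor; the values below are
    # hypothetical and chosen only for illustration:
    #
    #   config = LightningIRConfig(query_length=16, doc_length=256)
    #   config.query_length  # -> 16
    #   config.doc_length    # -> 256
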
    def to_added_args_dict(self) -> Dict[str, Any]:
        """Outputs a dictionary of the added arguments.

        :return: Added arguments
        :rtype: Dict[str, Any]
        """
        return {arg: getattr(self, arg) for arg in self.ADDED_ARGS if hasattr(self, arg)}

    def to_tokenizer_dict(self) -> Dict[str, Any]:
        """Outputs a dictionary of the tokenizer arguments.

        :return: Tokenizer arguments
        :rtype: Dict[str, Any]
        """
        return {arg: getattr(self, arg) for arg in self.TOKENIZER_ARGS}

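    # Sketch of the two dict helpers above, assuming the default arguments:
    #
    #   config = LightningIRConfig()
    #   config.to_tokenizer_dict()   # -> {"query_length": 32, "doc_length": 512}
    #   config.to_added_args_dict()  # same keys here, since ADDED_ARGS == TOKENIZER_ARGS
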
    def to_dict(self) -> Dict[str, Any]:
        """Overrides the transformers.PretrainedConfig.to_dict_ method to include the added arguments and the
        backbone model type.

        .. _transformers.PretrainedConfig.to_dict: \
https://huggingface.co/docs/transformers/main_classes/configuration.html#transformers.PretrainedConfig.to_dict

        :return: Configuration dictionary
        :rtype: Dict[str, Any]
        """
        if hasattr(super(), "to_dict"):
            output = getattr(super(), "to_dict")()
        else:
            output = self.to_added_args_dict()
        if self.backbone_model_type is not None:
            output["backbone_model_type"] = self.backbone_model_type
        return output

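    # Note on the hasattr check above: a bare LightningIRConfig has no
    # PretrainedConfig in its MRO, so super() lacks to_dict and only the added
    # arguments are serialized. Once the class factory has mixed this class into
    # a backbone config, the backbone's to_dict supplies the full dictionary.
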
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str | Path, *args, **kwargs) -> "LightningIRConfig":
        """Loads the configuration from a pretrained model or checkpoint. Wraps the
        transformers.PretrainedConfig.from_pretrained_ method to return a derived LightningIRConfig class. See
        :class:`.LightningIRConfigClassFactory` for more details.

        .. _transformers.PretrainedConfig.from_pretrained: \
https://huggingface.co/docs/transformers/main_classes/configuration.html#transformers.PretrainedConfig.from_pretrained

        :param pretrained_model_name_or_path: Name or path of the pretrained model
        :type pretrained_model_name_or_path: str | Path
        :raises ValueError: If no Lightning IR config can be resolved for the checkpoint and no config is passed
        :return: Derived LightningIRConfig class
        :rtype: LightningIRConfig
        """
        if cls is LightningIRConfig or all(issubclass(base, LightningIRConfig) for base in cls.__bases__):
            # Resolve which Lightning IR config class to use for this checkpoint.
            config = None
            if pretrained_model_name_or_path in CHECKPOINT_MAPPING:
                config = CHECKPOINT_MAPPING[pretrained_model_name_or_path]
                config_class = config.__class__
            elif cls is not LightningIRConfig:
                config_class = cls
            else:
                config_class = LightningIRConfigClassFactory.get_lightning_ir_config(pretrained_model_name_or_path)
                if config_class is None:
                    raise ValueError("Pass a config to `from_pretrained`.")
            # Derive a concrete class mixing the resolved config into the backbone config.
            BackboneConfig = LightningIRConfigClassFactory.get_backbone_config(pretrained_model_name_or_path)
            cls = LightningIRConfigClassFactory(config_class).from_backbone_class(BackboneConfig)
            if config is not None and all(issubclass(base, LightningIRConfig) for base in config.__class__.__bases__):
                derived_config = cls.from_pretrained(pretrained_model_name_or_path, config=config)
                # Overlay the hub-registered config values on the loaded configuration.
                derived_config.update(config.to_dict())
                return derived_config
            return cls.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
        return super(LightningIRConfig, cls).from_pretrained(pretrained_model_name_or_path, *args, **kwargs)

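    # A hedged usage sketch; it assumes a concrete subclass such as
    # lightning_ir.CrossEncoderConfig and uses a placeholder checkpoint name:
    #
    #   config = CrossEncoderConfig.from_pretrained("bert-base-uncased", query_length=16)
    #
    # The class factory derives a class that mixes the Lightning IR config into the
    # backbone's PretrainedConfig before delegating to transformers' from_pretrained.
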
    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any], *args, **kwargs) -> "LightningIRConfig":
        """Loads the configuration from a dictionary. Wraps the transformers.PretrainedConfig.from_dict_ method to
        return a derived LightningIRConfig class. See :class:`.LightningIRConfigClassFactory` for more details.

        .. _transformers.PretrainedConfig.from_dict: \
https://huggingface.co/docs/transformers/main_classes/configuration.html#transformers.PretrainedConfig.from_dict

        :param config_dict: Configuration dictionary
        :type config_dict: Dict[str, Any]
        :raises ValueError: If the model type does not match the configuration model type
        :return: Derived LightningIRConfig class
        :rtype: LightningIRConfig
        """
        if all(issubclass(base, LightningIRConfig) for base in cls.__bases__) or cls is LightningIRConfig:
            if "backbone_model_type" in config_dict:
                backbone_model_type = config_dict["backbone_model_type"]
                model_type = config_dict["model_type"]
                if cls is not LightningIRConfig and model_type != cls.model_type:
                    raise ValueError(
                        f"Model type {model_type} does not match configuration model type {cls.model_type}"
                    )
            else:
                backbone_model_type = config_dict["model_type"]
                model_type = cls.model_type
            MixinConfig = CONFIG_MAPPING[model_type]
            BackboneConfig = CONFIG_MAPPING[backbone_model_type]
            cls = LightningIRConfigClassFactory(MixinConfig).from_backbone_class(BackboneConfig)
            return cls.from_dict(config_dict, *args, **kwargs)
        return super(LightningIRConfig, cls).from_dict(config_dict, *args, **kwargs)
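

# A hedged round-trip sketch for from_dict; it assumes the "lightning-ir" model
# type (or a subclass's model type) has been registered in transformers'
# CONFIG_MAPPING, and mirrors the default argument values above:
#
#   config_dict = {
#       "model_type": "lightning-ir",
#       "backbone_model_type": "bert",
#       "query_length": 32,
#       "doc_length": 512,
#   }
#   config = LightningIRConfig.from_dict(config_dict)
#   config.to_dict()["backbone_model_type"]  # -> "bert"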