Source code for lightning_ir.base.tokenizer
import warnings
from typing import Dict, Sequence, Type

from transformers import TOKENIZER_MAPPING, BatchEncoding

from .class_factory import LightningIRTokenizerClassFactory
from .config import LightningIRConfig
from .external_model_hub import CHECKPOINT_MAPPING

class LightningIRTokenizer:
    """Base class for LightningIR tokenizers. Derived classes implement the tokenize method for handling
    query and document tokenization. It acts as a mixin for a transformers.PreTrainedTokenizer_ backbone
    tokenizer.

    .. _transformers.PreTrainedTokenizer: \
https://huggingface.co/docs/transformers/main_classes/tokenizer#transformers.PreTrainedTokenizer
    """

    config_class: Type[LightningIRConfig] = LightningIRConfig
    """Configuration class for the tokenizer."""

    def __init__(self, *args, query_length: int = 32, doc_length: int = 512, **kwargs):
        """Initializes the tokenizer and records the query/document token limits.

        :param query_length: Maximum number of tokens per query, defaults to 32
        :type query_length: int, optional
        :param doc_length: Maximum number of tokens per document, defaults to 512
        :type doc_length: int, optional
        """
        # Remaining positional/keyword arguments are forwarded to the backbone
        # tokenizer's constructor via the MRO.
        super().__init__(*args, **kwargs)
        self.query_length = query_length
        self.doc_length = doc_length

    def tokenize(
        self, queries: str | Sequence[str] | None = None, docs: str | Sequence[str] | None = None, **kwargs
    ) -> Dict[str, BatchEncoding]:
        """Tokenizes queries and documents.

        :param queries: Queries to tokenize, defaults to None
        :type queries: str | Sequence[str] | None, optional
        :param docs: Documents to tokenize, defaults to None
        :type docs: str | Sequence[str] | None, optional
        :raises NotImplementedError: Must be implemented by the derived class
        :return: Dictionary of tokenized queries and documents
        :rtype: Dict[str, BatchEncoding]
        """
        raise NotImplementedError

    @classmethod
    def from_pretrained(cls, model_name_or_path: str, *args, **kwargs) -> "LightningIRTokenizer":
        """Loads a pretrained tokenizer. Wraps the transformers.PreTrainedTokenizer.from_pretrained_ method
        to return a derived LightningIRTokenizer class. See :class:`.LightningIRTokenizerClassFactory` for
        more details.

        .. _transformers.PreTrainedTokenizer.from_pretrained: \
https://huggingface.co/docs/transformers/main_classes/tokenizer.html#transformers.PreTrainedTokenizer.from_pretrained

        .. highlight:: python
        .. code-block:: python

            >>> # Loading using model class and backbone checkpoint
            >>> type(BiEncoderTokenizer.from_pretrained("bert-base-uncased"))
            ...
            <class 'lightning_ir.base.class_factory.BiEncoderBertTokenizerFast'>
            >>> # Loading using base class and backbone checkpoint
            >>> type(LightningIRTokenizer.from_pretrained("bert-base-uncased", config=BiEncoderConfig()))
            ...
            <class 'lightning_ir.base.class_factory.BiEncoderBertTokenizerFast'>

        :param model_name_or_path: Name or path of the pretrained tokenizer
        :type model_name_or_path: str
        :raises ValueError: If called on the abstract class :class:`LightningIRTokenizer` and no config is passed
        :return: A derived LightningIRTokenizer consisting of a backbone tokenizer and a LightningIRTokenizer mixin
        :rtype: LightningIRTokenizer
        """
        config = kwargs.pop("config", None)
        if config is not None:
            kwargs.update(config.to_tokenizer_dict())
        # The class is still "abstract" when it is the base class itself, or when every base is a
        # lightning-ir mixin (i.e. no backbone tokenizer has been mixed in yet).
        is_abstract = cls is LightningIRTokenizer or all(
            issubclass(base, LightningIRTokenizer) for base in cls.__bases__
        )
        if not is_abstract:
            # A concrete derived class already exists; defer to the backbone loader directly.
            return super(LightningIRTokenizer, cls).from_pretrained(model_name_or_path, *args, **kwargs)
        # No backbone found: derive a lightning-ir tokenizer class from the checkpoint's backbone tokenizer.
        if model_name_or_path in CHECKPOINT_MAPPING:
            # Registered external checkpoints carry their own config, which takes precedence.
            registered_config = CHECKPOINT_MAPPING[model_name_or_path]
            ConfigCls = registered_config.__class__
            if config is not None:
                warnings.warn(f"{model_name_or_path} is a registered checkpoint. The provided config is ignored.")
            kwargs.update(registered_config.to_tokenizer_dict())
        elif config is not None:
            ConfigCls = config.__class__
        elif cls is not LightningIRTokenizer and hasattr(cls, "config_class"):
            ConfigCls = cls.config_class
        else:
            ConfigCls = LightningIRTokenizerClassFactory.get_lightning_ir_config(model_name_or_path)
            if ConfigCls is None:
                raise ValueError("Pass a config to `from_pretrained`.")
        backbone_config_cls = LightningIRTokenizerClassFactory.get_backbone_config(model_name_or_path)
        # TOKENIZER_MAPPING yields a (slow, fast) pair for the backbone config class.
        tokenizer_pair = TOKENIZER_MAPPING[backbone_config_cls]
        backbone_tokenizer_cls = tokenizer_pair[1] if kwargs.get("use_fast", True) else tokenizer_pair[0]
        derived_cls = LightningIRTokenizerClassFactory(ConfigCls).from_backbone_class(backbone_tokenizer_cls)
        return derived_cls.from_pretrained(model_name_or_path, *args, **kwargs)