Source code for lightning_ir.base.tokenizer

import warnings
from typing import Dict, Sequence, Type

from transformers import TOKENIZER_MAPPING, BatchEncoding

from .class_factory import LightningIRTokenizerClassFactory
from .config import LightningIRConfig
from .external_model_hub import CHECKPOINT_MAPPING


class LightningIRTokenizer:
    """Base class for LightningIR tokenizers. Derived classes implement the tokenize method for handling query
    and document tokenization. It acts as a mixin for a transformers.PreTrainedTokenizer_ backbone tokenizer.

    .. _transformers.PreTrainedTokenizer: \
https://huggingface.co/transformers/main_classes/tokenizer.html#transformers.PreTrainedTokenizer
    """

    config_class: Type[LightningIRConfig] = LightningIRConfig
    """Configuration class for the tokenizer."""
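
    # The tokenizer classes returned by :meth:`from_pretrained` are composed by
    # :class:`.LightningIRTokenizerClassFactory`. Conceptually (a sketch of the
    # idea, not the factory's literal implementation), a generated class looks
    # like:
    #
    #     class BiEncoderBertTokenizerFast(BiEncoderTokenizer, BertTokenizerFast):
    #         config_class = BiEncoderConfig
    #
    # i.e. the LightningIRTokenizer mixin sits before the backbone tokenizer in
    # the method resolution order.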

    def __init__(self, *args, query_length: int = 32, doc_length: int = 512, **kwargs):
        """Initializes the tokenizer.

        :param query_length: Maximum number of tokens per query, defaults to 32
        :type query_length: int, optional
        :param doc_length: Maximum number of tokens per document, defaults to 512
        :type doc_length: int, optional
        """
        super().__init__(*args, **kwargs)
        self.query_length = query_length
        self.doc_length = doc_length

    def tokenize(
        self, queries: str | Sequence[str] | None = None, docs: str | Sequence[str] | None = None, **kwargs
    ) -> Dict[str, BatchEncoding]:
        """Tokenizes queries and documents.

        :param queries: Queries to tokenize, defaults to None
        :type queries: str | Sequence[str] | None, optional
        :param docs: Documents to tokenize, defaults to None
        :type docs: str | Sequence[str] | None, optional
        :raises NotImplementedError: Must be implemented by the derived class
        :return: Dictionary of tokenized queries and documents
        :rtype: Dict[str, BatchEncoding]
        """
        raise NotImplementedError
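
    # A minimal sketch of how a derived class might implement :meth:`tokenize`
    # using ``query_length`` and ``doc_length``. This is hypothetical, not the
    # library's actual implementation; the encoding key names and the truncation
    # behavior are assumptions:
    #
    #     def tokenize(self, queries=None, docs=None, **kwargs):
    #         encodings = {}
    #         if queries is not None:
    #             encodings["query_encoding"] = self(
    #                 queries, truncation=True, max_length=self.query_length, **kwargs
    #             )
    #         if docs is not None:
    #             encodings["doc_encoding"] = self(
    #                 docs, truncation=True, max_length=self.doc_length, **kwargs
    #             )
    #         return encodings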

    @classmethod
    def from_pretrained(cls, model_name_or_path: str, *args, **kwargs) -> "LightningIRTokenizer":
        """Loads a pretrained tokenizer. Wraps the transformers.PreTrainedTokenizer.from_pretrained_ method to
        return a derived LightningIRTokenizer class. See :class:`.LightningIRTokenizerClassFactory` for more details.

        .. _transformers.PreTrainedTokenizer.from_pretrained: \
https://huggingface.co/docs/transformers/main_classes/tokenizer.html#transformers.PreTrainedTokenizer.from_pretrained

        .. highlight:: python
        .. code-block:: python

            >>> # Loading using model class and backbone checkpoint
            >>> type(BiEncoderTokenizer.from_pretrained("bert-base-uncased"))
            <class 'lightning_ir.base.class_factory.BiEncoderBertTokenizerFast'>
            >>> # Loading using base class and backbone checkpoint
            >>> type(LightningIRTokenizer.from_pretrained("bert-base-uncased", config=BiEncoderConfig()))
            <class 'lightning_ir.base.class_factory.BiEncoderBertTokenizerFast'>

        :param model_name_or_path: Name or path of the pretrained tokenizer
        :type model_name_or_path: str
        :raises ValueError: If called on the abstract class :class:`LightningIRTokenizer` and no config is passed
        :return: A derived LightningIRTokenizer consisting of a backbone tokenizer and a LightningIRTokenizer mixin
        :rtype: LightningIRTokenizer
        """
        config = kwargs.pop("config", None)
        if config is not None:
            kwargs.update(config.to_tokenizer_dict())
        if cls is LightningIRTokenizer or all(issubclass(base, LightningIRTokenizer) for base in cls.__bases__):
            # no backbone tokenizer class is bound yet; derive a lightning-ir tokenizer from the backbone model
            if model_name_or_path in CHECKPOINT_MAPPING:
                # registered external checkpoints carry their own config, which overrides any passed-in config
                _config = CHECKPOINT_MAPPING[model_name_or_path]
                Config = _config.__class__
                if config is not None:
                    warnings.warn(f"{model_name_or_path} is a registered checkpoint. The provided config is ignored.")
                kwargs.update(_config.to_tokenizer_dict())
            elif config is not None:
                Config = config.__class__
            elif cls is not LightningIRTokenizer and hasattr(cls, "config_class"):
                Config = cls.config_class
            else:
                Config = LightningIRTokenizerClassFactory.get_lightning_ir_config(model_name_or_path)
                if Config is None:
                    raise ValueError("Pass a config to `from_pretrained`.")
            BackboneConfig = LightningIRTokenizerClassFactory.get_backbone_config(model_name_or_path)
            BackboneTokenizers = TOKENIZER_MAPPING[BackboneConfig]
            if kwargs.get("use_fast", True):
                BackboneTokenizer = BackboneTokenizers[1]  # fast (Rust-backed) tokenizer class
            else:
                BackboneTokenizer = BackboneTokenizers[0]  # slow (pure-Python) tokenizer class
            cls = LightningIRTokenizerClassFactory(Config).from_backbone_class(BackboneTokenizer)
            return cls.from_pretrained(model_name_or_path, *args, **kwargs)
        return super(LightningIRTokenizer, cls).from_pretrained(model_name_or_path, *args, **kwargs)
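
A minimal usage sketch (not part of the module): loading a derived tokenizer from a
backbone checkpoint and tokenizing a query and a document. It assumes ``BiEncoderConfig``
and ``LightningIRTokenizer`` are importable from the top-level ``lightning_ir`` package,
as in the docstring example above; the encoding keys in the final comment are illustrative.

    from lightning_ir import BiEncoderConfig, LightningIRTokenizer

    # Passing a config lets the abstract base class derive and instantiate the
    # concrete tokenizer (a bi-encoder variant of BertTokenizerFast).
    tokenizer = LightningIRTokenizer.from_pretrained("bert-base-uncased", config=BiEncoderConfig())

    encodings = tokenizer.tokenize(
        queries=["what is information retrieval?"],
        docs=["Information retrieval is the task of finding relevant documents."],
    )
    print(encodings.keys())  # e.g. dict_keys(['query_encoding', 'doc_encoding'])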