Source code for lightning_ir.models.t5.tokenizer

from typing import Dict, Literal, Sequence, Type

from transformers import BatchEncoding

from ...cross_encoder.tokenizer import CrossEncoderTokenizer
from .config import T5CrossEncoderConfig


class T5CrossEncoderTokenizer(CrossEncoderTokenizer):

    config_class: Type[T5CrossEncoderConfig] = T5CrossEncoderConfig

    def __init__(
        self,
        *args,
        query_length: int = 32,
        doc_length: int = 512,
        decoder_strategy: Literal["mono", "rank"] = "mono",
        **kwargs,
    ):
        super().__init__(
            *args, query_length=query_length, doc_length=doc_length, decoder_strategy=decoder_strategy, **kwargs
        )
        self.decoder_strategy = decoder_strategy

    def tokenize(
        self,
        queries: str | Sequence[str] | None = None,
        docs: str | Sequence[str] | None = None,
        num_docs: Sequence[int] | None = None,
        **kwargs,
    ) -> Dict[str, BatchEncoding]:
        expanded_queries, docs = self.preprocess(queries, docs, num_docs)
        # Select the prompt template for the configured decoder strategy.
        if self.decoder_strategy == "mono":
            pattern = "Query: {query} Document: {doc} Relevant:"
        elif self.decoder_strategy == "rank":
            pattern = "Query: {query} Document: {doc}"
        else:
            raise ValueError(f"Unknown decoder strategy: {self.decoder_strategy}")
        input_texts = [pattern.format(query=query, doc=doc) for query, doc in zip(expanded_queries, docs)]

        return_tensors = kwargs.get("return_tensors", None)
        if return_tensors is not None:
            # When returning tensors, pad each batch to a multiple of 8
            # (a common choice for better fp16 tensor-core utilization).
            kwargs["pad_to_multiple_of"] = 8
        return {"encoding": self(input_texts, **kwargs)}
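A minimal usage sketch, not part of the module itself: it assumes the tokenizer can be loaded with the inherited from_pretrained and that a monoT5-style checkpoint such as castorini/monot5-base-msmarco is compatible with it; both the checkpoint name and the keyword arguments are illustrative assumptions, not guarantees from this source file.

    # Illustrative sketch only; checkpoint name and compatibility are assumptions.
    from lightning_ir.models.t5.tokenizer import T5CrossEncoderTokenizer

    tokenizer = T5CrossEncoderTokenizer.from_pretrained(
        "castorini/monot5-base-msmarco",  # assumed compatible monoT5-style checkpoint
        decoder_strategy="mono",
    )

    # One query paired with two documents; num_docs tells preprocess how many
    # documents belong to each query so the queries can be expanded to match.
    batch = tokenizer.tokenize(
        queries=["how do rainbows form"],
        docs=["Light refracts inside water droplets.", "Rainbows are an optical effect."],
        num_docs=[2],
        return_tensors="pt",  # triggers pad_to_multiple_of=8 in tokenize
    )
    print(batch["encoding"].input_ids.shape)

With decoder_strategy="mono", each query-document pair is rendered as "Query: how do rainbows form Document: Light refracts inside water droplets. Relevant:" before encoding; the "rank" strategy drops the trailing "Relevant:" cue.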