Source code for lightning_ir.cross_encoder.tokenizer

from typing import Dict, List, Sequence, Tuple, Type

from transformers import BatchEncoding

from ..base import LightningIRTokenizer
from .config import CrossEncoderConfig


class CrossEncoderTokenizer(LightningIRTokenizer):

    config_class: Type[CrossEncoderConfig] = CrossEncoderConfig

    def __init__(self, *args, query_length: int = 32, doc_length: int = 512, **kwargs):
        super().__init__(*args, query_length=query_length, doc_length=doc_length, **kwargs)

    def truncate(self, text: Sequence[str], max_length: int) -> List[str]:
        # Truncate each text to max_length tokens by encoding without special
        # tokens and decoding the truncated ids back into strings.
        return self.batch_decode(
            self(
                text,
                add_special_tokens=False,
                truncation=True,
                max_length=max_length,
                return_attention_mask=False,
                return_token_type_ids=False,
            ).input_ids
        )

    def expand_queries(self, queries: Sequence[str], num_docs: Sequence[int]) -> List[str]:
        # Repeat each query once per document so that queries and docs align pairwise.
        return [query for query_idx, query in enumerate(queries) for _ in range(num_docs[query_idx])]

    def preprocess(
        self,
        queries: str | Sequence[str] | None,
        docs: str | Sequence[str] | None,
        num_docs: Sequence[int] | None,
    ) -> Tuple[str | Sequence[str], str | Sequence[str]]:
        # Validate, truncate, and align queries and docs before joint tokenization.
        if queries is None or docs is None:
            raise ValueError("Both queries and docs must be provided.")
        queries_is_string = isinstance(queries, str)
        docs_is_string = isinstance(docs, str)
        if queries_is_string != docs_is_string:
            raise ValueError("Queries and docs must be both lists or both strings.")
        if queries_is_string and docs_is_string:
            queries = [queries]
            docs = [docs]
        truncated_queries = self.truncate(queries, self.query_length)
        truncated_docs = self.truncate(docs, self.doc_length)
        if not queries_is_string:
            if num_docs is None:
                # Assume an equal number of documents per query when not specified.
                num_docs = [len(docs) // len(queries) for _ in range(len(queries))]
            expanded_queries = self.expand_queries(truncated_queries, num_docs)
            docs = truncated_docs
        else:
            expanded_queries = truncated_queries[0]
            docs = truncated_docs[0]
        return expanded_queries, docs

    def tokenize(
        self,
        queries: str | Sequence[str] | None = None,
        docs: str | Sequence[str] | None = None,
        num_docs: Sequence[int] | None = None,
        **kwargs,
    ) -> Dict[str, BatchEncoding]:
        # Jointly tokenize query-document pairs into a single encoding.
        expanded_queries, docs = self.preprocess(queries, docs, num_docs)
        return_tensors = kwargs.get("return_tensors", None)
        if return_tensors is not None:
            # When returning tensors, pad sequences to a multiple of 8.
            kwargs["pad_to_multiple_of"] = 8
        return {"encoding": self(expanded_queries, docs, **kwargs)}
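

The tokenizer truncates queries and documents separately, repeats each query once per document, and then encodes every query-document pair jointly. A minimal usage sketch follows, assuming the class is re-exported from the top-level lightning_ir package and loaded with from_pretrained as with Hugging Face tokenizers; the checkpoint name and example texts are placeholders, not taken from the source above.

# Hypothetical usage sketch; "bert-base-uncased" and the example texts
# are placeholders, and the top-level import is an assumption.
from lightning_ir import CrossEncoderTokenizer

tokenizer = CrossEncoderTokenizer.from_pretrained(
    "bert-base-uncased", query_length=32, doc_length=512
)

# One query paired with two documents; num_docs gives the number of
# documents per query so the query is repeated for each pair.
batch = tokenizer.tokenize(
    queries=["what is a cross-encoder?"],
    docs=[
        "A cross-encoder scores a query and a document jointly.",
        "A bi-encoder encodes queries and documents separately.",
    ],
    num_docs=[2],
    return_tensors="pt",
    padding=True,
)
print(batch["encoding"].input_ids.shape)  # (2, padded_sequence_length)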