Source code for lightning_ir.cross_encoder.tokenizer
from typing import Dict, List, Sequence, Tuple, Type

from transformers import BatchEncoding

from ..base import LightningIRTokenizer
from .config import CrossEncoderConfig


class CrossEncoderTokenizer(LightningIRTokenizer):
    """Tokenizer for cross-encoder models. Encodes a query and a document jointly into a single input sequence."""

    config_class: Type[CrossEncoderConfig] = CrossEncoderConfig
    """Configuration class for the tokenizer."""

    def __init__(self, *args, query_length: int = 32, doc_length: int = 512, **kwargs):
        """Initializes the tokenizer.

        :param query_length: Maximum number of query tokens, defaults to 32
        :param doc_length: Maximum number of document tokens, defaults to 512
        """
        super().__init__(*args, query_length=query_length, doc_length=doc_length, **kwargs)

    def truncate(self, text: Sequence[str], max_length: int) -> List[str]:
        """Truncates texts to a maximum number of tokens by encoding without special
        tokens, truncating the token ids, and decoding back to strings."""
        return self.batch_decode(
            self(
                text,
                add_special_tokens=False,
                truncation=True,
                max_length=max_length,
                return_attention_mask=False,
                return_token_type_ids=False,
            ).input_ids
        )

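    # For illustration (a hypothetical call, not part of the original source):
    #   truncate(["some long text"], 4) returns each input decoded back from
    #   its first 4 tokens; the exact string depends on the backbone tokenizer.
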
    def expand_queries(self, queries: Sequence[str], num_docs: Sequence[int]) -> List[str]:
        """Repeats each query once per document it is paired with."""
        return [query for query_idx, query in enumerate(queries) for _ in range(num_docs[query_idx])]

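    # For illustration (hypothetical call):
    #   expand_queries(["q1", "q2"], num_docs=[2, 1]) -> ["q1", "q1", "q2"]
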
    def preprocess(
        self,
        queries: str | Sequence[str] | None,
        docs: str | Sequence[str] | None,
        num_docs: Sequence[int] | None,
    ) -> Tuple[str | Sequence[str], str | Sequence[str]]:
        """Validates and truncates queries and documents and aligns them one-to-one for joint encoding."""
        if queries is None or docs is None:
            raise ValueError("Both queries and docs must be provided.")
        queries_is_string = isinstance(queries, str)
        docs_is_string = isinstance(docs, str)
        if queries_is_string != docs_is_string:
            raise ValueError("Queries and docs must be both lists or both strings.")
        if queries_is_string and docs_is_string:
            queries = [queries]
            docs = [docs]
        truncated_queries = self.truncate(queries, self.query_length)
        truncated_docs = self.truncate(docs, self.doc_length)
        if not queries_is_string:
            if num_docs is None:
                # Assume the documents are evenly distributed across queries when
                # num_docs is not given.
                num_docs = [len(docs) // len(queries) for _ in range(len(queries))]
            expanded_queries = self.expand_queries(truncated_queries, num_docs)
            docs = truncated_docs
        else:
            expanded_queries = truncated_queries[0]
            docs = truncated_docs[0]
        return expanded_queries, docs

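    # For illustration (hypothetical call): preprocess(["q"], ["d1", "d2"], num_docs=[2])
    # truncates all texts and returns (["q", "q"], ["d1", "d2"]).
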
    def tokenize(
        self,
        queries: str | Sequence[str] | None = None,
        docs: str | Sequence[str] | None = None,
        num_docs: Sequence[int] | None = None,
        **kwargs,
    ) -> Dict[str, BatchEncoding]:
        """Tokenizes queries and documents into a single joint encoding for cross-encoder models."""
        expanded_queries, docs = self.preprocess(queries, docs, num_docs)
        return_tensors = kwargs.get("return_tensors", None)
        if return_tensors is not None:
            # Pad to a multiple of 8 so tensor shapes are friendly to tensor cores.
            kwargs["pad_to_multiple_of"] = 8
        return {"encoding": self(expanded_queries, docs, **kwargs)}
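
A minimal usage sketch (not part of the original source). It assumes CrossEncoderTokenizer is importable from the lightning_ir package and that the checkpoint name below is a stand-in for any Hugging Face checkpoint with a compatible tokenizer:

from lightning_ir import CrossEncoderTokenizer

# Hypothetical checkpoint; substitute any compatible model name.
tokenizer = CrossEncoderTokenizer.from_pretrained(
    "bert-base-uncased", query_length=32, doc_length=512
)

encoding = tokenizer.tokenize(
    queries=["what is information retrieval"],
    docs=[
        "Information retrieval finds relevant documents for a query.",
        "Cross-encoders score query-document pairs jointly.",
    ],
    num_docs=[2],
    padding=True,
    return_tensors="pt",
)["encoding"]

# The single query is repeated once per document, so there is one row per
# query-document pair; with return_tensors set, rows are padded to a multiple of 8.
print(encoding.input_ids.shape)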