import warnings
from typing import Dict, Sequence

from tokenizers.processors import TemplateProcessing
from transformers import BatchEncoding, BertTokenizer, BertTokenizerFast

from ..base import LightningIRTokenizer
from .config import BiEncoderConfig


class BiEncoderTokenizer(LightningIRTokenizer):
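    """Tokenizer for bi-encoder models. Optionally prepends query/document marker tokens
    and expands queries or documents to a fixed length by padding with mask tokens
    (as in ColBERT-style expansion)."""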

    config_class = BiEncoderConfig

    def __init__(
        self,
        *args,
        query_token: str = "[QUE]",
        doc_token: str = "[DOC]",
        query_expansion: bool = False,
        query_length: int = 32,
        attend_to_query_expanded_tokens: bool = False,
        doc_expansion: bool = False,
        doc_length: int = 512,
        attend_to_doc_expanded_tokens: bool = False,
        add_marker_tokens: bool = True,
        **kwargs,
    ):
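        """Initialize a BiEncoderTokenizer.

        :param query_token: Marker token prepended to queries, defaults to "[QUE]"
        :param doc_token: Marker token prepended to documents, defaults to "[DOC]"
        :param query_expansion: Whether to pad queries to query_length with mask tokens,
            defaults to False
        :param query_length: Maximum number of query tokens, defaults to 32
        :param attend_to_query_expanded_tokens: Whether the attention mask should cover
            expanded query tokens, defaults to False
        :param doc_expansion: Whether to pad documents to doc_length with mask tokens,
            defaults to False
        :param doc_length: Maximum number of document tokens, defaults to 512
        :param attend_to_doc_expanded_tokens: Whether the attention mask should cover
            expanded document tokens, defaults to False
        :param add_marker_tokens: Whether to prepend marker tokens to queries and
            documents, defaults to True
        :raises ValueError: If add_marker_tokens is True and the tokenizer is not a
            BERT tokenizer
        """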
        super().__init__(
            *args,
            query_token=query_token,
            doc_token=doc_token,
            query_expansion=query_expansion,
            query_length=query_length,
            attend_to_query_expanded_tokens=attend_to_query_expanded_tokens,
            doc_expansion=doc_expansion,
            doc_length=doc_length,
            attend_to_doc_expanded_tokens=attend_to_doc_expanded_tokens,
            add_marker_tokens=add_marker_tokens,
            **kwargs,
        )
        self.query_expansion = query_expansion
        self.query_length = query_length
        self.attend_to_query_expanded_tokens = attend_to_query_expanded_tokens
        self.doc_expansion = doc_expansion
        self.doc_length = doc_length
        self.attend_to_doc_expanded_tokens = attend_to_doc_expanded_tokens
        self.add_marker_tokens = add_marker_tokens

        self._query_token = query_token
        self._doc_token = doc_token

        self.query_post_processor: TemplateProcessing | None = None
        self.doc_post_processor: TemplateProcessing | None = None
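        # Register the marker tokens and build post-processors that insert the marker
        # between [CLS] and the text (and before each segment for text pairs).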
        if add_marker_tokens:
            # TODO support other tokenizers
            if not isinstance(self, (BertTokenizer, BertTokenizerFast)):
                raise ValueError("Adding marker tokens is only supported for BERT tokenizers.")
            self.add_tokens([query_token, doc_token], special_tokens=True)
            self.query_post_processor = TemplateProcessing(
                single=f"[CLS] {self.query_token} $0 [SEP]",
                pair=f"[CLS] {self.query_token} $A [SEP] {self.doc_token} $B:1 [SEP]:1",
                special_tokens=[
                    ("[CLS]", self.cls_token_id),
                    ("[SEP]", self.sep_token_id),
                    (self.query_token, self.query_token_id),
                    (self.doc_token, self.doc_token_id),
                ],
            )
            self.doc_post_processor = TemplateProcessing(
                single=f"[CLS] {self.doc_token} $0 [SEP]",
                pair=f"[CLS] {self.query_token} $A [SEP] {self.doc_token} $B:1 [SEP]:1",
                special_tokens=[
                    ("[CLS]", self.cls_token_id),
                    ("[SEP]", self.sep_token_id),
                    (self.query_token, self.query_token_id),
                    (self.doc_token, self.doc_token_id),
                ],
            )

    @property
    def query_token(self) -> str:
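        """Marker token prepended to queries."""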
        return self._query_token

    @property
    def doc_token(self) -> str:
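        """Marker token prepended to documents."""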
        return self._doc_token

    @property
    def query_token_id(self) -> int | None:
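        """ID of the query marker token, or None if marker tokens were not added."""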
        if self.query_token in self.added_tokens_encoder:
            return self.added_tokens_encoder[self.query_token]
        return None

    @property
    def doc_token_id(self) -> int | None:
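        """ID of the document marker token, or None if marker tokens were not added."""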
        if self.doc_token in self.added_tokens_encoder:
            return self.added_tokens_encoder[self.doc_token]
        return None

    def __call__(self, *args, warn: bool = True, **kwargs) -> BatchEncoding:
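        """Tokenize text directly, bypassing query/document-specific processing.
        Pass warn=False to suppress the warning."""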
        if warn:
            warnings.warn(
                "BiEncoderTokenizer is being called directly. Use tokenize_query or "
                "tokenize_doc to ensure marker tokens and query/document expansion "
                "are applied."
            )
        return super().__call__(*args, **kwargs)

    def _encode(
        self,
        text: str | Sequence[str],
        *args,
        post_processor: TemplateProcessing | None = None,
        **kwargs,
    ) -> BatchEncoding:
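        """Encode text, temporarily swapping in the given post-processor. When tensors
        are requested, pads to a multiple of 8 (typically more GPU-friendly shapes)."""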
        orig_post_processor = self._tokenizer.post_processor
        if post_processor is not None:
            self._tokenizer.post_processor = post_processor
        if kwargs.get("return_tensors") is not None:
            kwargs["pad_to_multiple_of"] = 8
        encoding = self(text, *args, warn=False, **kwargs)
        self._tokenizer.post_processor = orig_post_processor
        return encoding

    def _expand(self, encoding: BatchEncoding, attend_to_expanded_tokens: bool) -> BatchEncoding:
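        """Replace padding tokens with mask tokens so the expanded positions receive
        embeddings (ColBERT-style expansion), optionally attending to them."""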
        input_ids = encoding["input_ids"]
        input_ids[input_ids == self.pad_token_id] = self.mask_token_id
        encoding["input_ids"] = input_ids
        if attend_to_expanded_tokens:
            encoding["attention_mask"].fill_(1)
        return encoding

    def tokenize_query(self, queries: Sequence[str] | str, *args, **kwargs) -> BatchEncoding:
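        """Tokenize queries, truncating to query_length. With query expansion enabled,
        queries are instead padded to query_length and padding is replaced by mask tokens."""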
        kwargs["max_length"] = self.query_length
        if self.query_expansion:
            kwargs["padding"] = "max_length"
        else:
            kwargs["truncation"] = True
        encoding = self._encode(queries, *args, post_processor=self.query_post_processor, **kwargs)
        if self.query_expansion:
            self._expand(encoding, self.attend_to_query_expanded_tokens)
        return encoding

    def tokenize_doc(self, docs: Sequence[str] | str, *args, **kwargs) -> BatchEncoding:
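        """Tokenize documents, truncating to doc_length. With document expansion enabled,
        documents are instead padded to doc_length and padding is replaced by mask tokens."""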
        kwargs["max_length"] = self.doc_length
        if self.doc_expansion:
            kwargs["padding"] = "max_length"
        else:
            kwargs["truncation"] = True
        encoding = self._encode(docs, *args, post_processor=self.doc_post_processor, **kwargs)
        if self.doc_expansion:
            self._expand(encoding, self.attend_to_doc_expanded_tokens)
        return encoding

    def tokenize(
        self,
        queries: str | Sequence[str] | None = None,
        docs: str | Sequence[str] | None = None,
        **kwargs,
    ) -> Dict[str, BatchEncoding]:
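        """Tokenize queries and/or documents, returning a dict with "query_encoding"
        and/or "doc_encoding" entries."""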
        encodings = {}
        kwargs.pop("num_docs", None)
        if queries is not None:
            encodings["query_encoding"] = self.tokenize_query(queries, **kwargs)
        if docs is not None:
            encodings["doc_encoding"] = self.tokenize_doc(docs, **kwargs)
        return encodings
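

# Minimal usage sketch (not part of the class). Assumes a BERT checkpoint such as
# "bert-base-uncased" is available and that `from_pretrained` is inherited from the
# Hugging Face tokenizer base class via LightningIRTokenizer:
#
#     tokenizer = BiEncoderTokenizer.from_pretrained("bert-base-uncased")
#     # Queries are truncated to query_length (32) and prefixed with [QUE].
#     query_encoding = tokenizer.tokenize_query("what is information retrieval?", return_tensors="pt")
#     # Documents are truncated to doc_length (512) and prefixed with [DOC].
#     doc_encoding = tokenizer.tokenize_doc(["IR systems rank documents by relevance."], return_tensors="pt")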