Source code for lightning_ir.bi_encoder.tokenizer

import warnings
from typing import Dict, Sequence

from tokenizers.processors import TemplateProcessing
from transformers import BatchEncoding, BertTokenizer, BertTokenizerFast

from ..base import LightningIRTokenizer
from .config import BiEncoderConfig


class BiEncoderTokenizer(LightningIRTokenizer):

    config_class = BiEncoderConfig

    def __init__(
        self,
        *args,
        query_token: str = "[QUE]",
        doc_token: str = "[DOC]",
        query_expansion: bool = False,
        query_length: int = 32,
        attend_to_query_expanded_tokens: bool = False,
        doc_expansion: bool = False,
        doc_length: int = 512,
        attend_to_doc_expanded_tokens: bool = False,
        add_marker_tokens: bool = True,
        **kwargs,
    ):
        super().__init__(
            *args,
            query_token=query_token,
            doc_token=doc_token,
            query_expansion=query_expansion,
            query_length=query_length,
            attend_to_query_expanded_tokens=attend_to_query_expanded_tokens,
            doc_expansion=doc_expansion,
            doc_length=doc_length,
            attend_to_doc_expanded_tokens=attend_to_doc_expanded_tokens,
            add_marker_tokens=add_marker_tokens,
            **kwargs,
        )
        self.query_expansion = query_expansion
        self.query_length = query_length
        self.attend_to_query_expanded_tokens = attend_to_query_expanded_tokens
        self.doc_expansion = doc_expansion
        self.doc_length = doc_length
        self.attend_to_doc_expanded_tokens = attend_to_doc_expanded_tokens
        self.add_marker_tokens = add_marker_tokens

        self._query_token = query_token
        self._doc_token = doc_token

        self.query_post_processor: TemplateProcessing | None = None
        self.doc_post_processor: TemplateProcessing | None = None
        if add_marker_tokens:
            # Register the marker tokens and build post-processors that insert
            # them directly after [CLS] for queries and documents, respectively.
            # TODO support other tokenizers
            if not isinstance(self, (BertTokenizer, BertTokenizerFast)):
                raise ValueError("Adding marker tokens is only supported for BertTokenizer.")
            self.add_tokens([query_token, doc_token], special_tokens=True)
            self.query_post_processor = TemplateProcessing(
                single=f"[CLS] {self.query_token} $0 [SEP]",
                pair=f"[CLS] {self.query_token} $A [SEP] {self.doc_token} $B:1 [SEP]:1",
                special_tokens=[
                    ("[CLS]", self.cls_token_id),
                    ("[SEP]", self.sep_token_id),
                    (self.query_token, self.query_token_id),
                    (self.doc_token, self.doc_token_id),
                ],
            )
            self.doc_post_processor = TemplateProcessing(
                single=f"[CLS] {self.doc_token} $0 [SEP]",
                pair=f"[CLS] {self.query_token} $A [SEP] {self.doc_token} $B:1 [SEP]:1",
                special_tokens=[
                    ("[CLS]", self.cls_token_id),
                    ("[SEP]", self.sep_token_id),
                    (self.query_token, self.query_token_id),
                    (self.doc_token, self.doc_token_id),
                ],
            )

    @property
    def query_token(self) -> str:
        return self._query_token

    @property
    def doc_token(self) -> str:
        return self._doc_token

    @property
    def query_token_id(self) -> int | None:
        if self.query_token in self.added_tokens_encoder:
            return self.added_tokens_encoder[self.query_token]
        return None

    @property
    def doc_token_id(self) -> int | None:
        if self.doc_token in self.added_tokens_encoder:
            return self.added_tokens_encoder[self.doc_token]
        return None

    def __call__(self, *args, warn: bool = True, **kwargs) -> BatchEncoding:
        if warn:
            warnings.warn(
                "BiEncoderTokenizer is being directly called. Use tokenize_query and "
                "tokenize_doc to make sure marker_tokens and query/doc expansion is "
                "applied."
            )
        return super().__call__(*args, **kwargs)

    def _encode(
        self,
        text: str | Sequence[str],
        *args,
        post_processor: TemplateProcessing | None = None,
        **kwargs,
    ) -> BatchEncoding:
        # Temporarily swap in the query/document-specific post-processor,
        # encode, then restore the tokenizer's original post-processor.
        orig_post_processor = self._tokenizer.post_processor
        if post_processor is not None:
            self._tokenizer.post_processor = post_processor
        if kwargs.get("return_tensors", None) is not None:
            kwargs["pad_to_multiple_of"] = 8
        encoding = self(text, *args, warn=False, **kwargs)
        self._tokenizer.post_processor = orig_post_processor
        return encoding

    def _expand(self, encoding: BatchEncoding, attend_to_expanded_tokens: bool) -> BatchEncoding:
        # Replace padding tokens with mask tokens to fill the sequence up to
        # max_length; optionally attend to these expanded positions.
        input_ids = encoding["input_ids"]
        input_ids[input_ids == self.pad_token_id] = self.mask_token_id
        encoding["input_ids"] = input_ids
        if attend_to_expanded_tokens:
            encoding["attention_mask"].fill_(1)
        return encoding

    def tokenize_query(self, queries: Sequence[str] | str, *args, **kwargs) -> BatchEncoding:
        kwargs["max_length"] = self.query_length
        if self.query_expansion:
            kwargs["padding"] = "max_length"
        else:
            kwargs["truncation"] = True
        encoding = self._encode(queries, *args, post_processor=self.query_post_processor, **kwargs)
        if self.query_expansion:
            self._expand(encoding, self.attend_to_query_expanded_tokens)
        return encoding

    def tokenize_doc(self, docs: Sequence[str] | str, *args, **kwargs) -> BatchEncoding:
        kwargs["max_length"] = self.doc_length
        if self.doc_expansion:
            kwargs["padding"] = "max_length"
        else:
            kwargs["truncation"] = True
        encoding = self._encode(docs, *args, post_processor=self.doc_post_processor, **kwargs)
        if self.doc_expansion:
            self._expand(encoding, self.attend_to_doc_expanded_tokens)
        return encoding

    def tokenize(
        self,
        queries: str | Sequence[str] | None = None,
        docs: str | Sequence[str] | None = None,
        **kwargs,
    ) -> Dict[str, BatchEncoding]:
        encodings = {}
        kwargs.pop("num_docs", None)
        if queries is not None:
            encodings["query_encoding"] = self.tokenize_query(queries, **kwargs)
        if docs is not None:
            encodings["doc_encoding"] = self.tokenize_doc(docs, **kwargs)
        return encodings
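
A minimal usage sketch of the tokenizer above (not part of the module source). It assumes the standard Hugging Face from_pretrained interface inherited from LightningIRTokenizer with a BERT backbone; the checkpoint name and example strings are placeholders.

# Usage sketch (assumption: BiEncoderTokenizer can be loaded from a BERT
# checkpoint via the standard Hugging Face from_pretrained interface;
# "bert-base-uncased" and the example texts are placeholders).
from lightning_ir.bi_encoder.tokenizer import BiEncoderTokenizer

tokenizer = BiEncoderTokenizer.from_pretrained(
    "bert-base-uncased",
    query_length=32,
    doc_length=512,
    add_marker_tokens=True,  # insert [QUE] / [DOC] after [CLS]
)

# Queries and documents are tokenized separately so that the matching
# post-processor, length limit, and optional expansion are applied.
query_encoding = tokenizer.tokenize_query(
    ["what is information retrieval?"], padding=True, return_tensors="pt"
)
doc_encoding = tokenizer.tokenize_doc(
    ["Information retrieval is the task of finding relevant documents."],
    padding=True,
    return_tensors="pt",
)

# tokenize() bundles both into a dict with "query_encoding" and "doc_encoding".
encodings = tokenizer.tokenize(
    queries=["what is information retrieval?"],
    docs=["Information retrieval is the task of finding relevant documents."],
    padding=True,
    return_tensors="pt",
)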