Source code for lightning_ir.cross_encoder.config
1from typing import Literal
2
3from ..base import LightningIRConfig
4
5
[docs]
class CrossEncoderConfig(LightningIRConfig):
    """Configuration class for cross-encoder models.

    Extends :class:`LightningIRConfig` with two cross-encoder specific options,
    ``pooling_strategy`` and ``linear_bias``. Both names are also registered in
    :attr:`ADDED_ARGS`, which otherwise inherits the parent's entries.
    """

    model_type = "cross-encoder"

    # The parent's added-argument names plus the two options introduced here.
    ADDED_ARGS = LightningIRConfig.ADDED_ARGS.union({"pooling_strategy", "linear_bias"})

    def __init__(
        self,
        query_length: int = 32,
        doc_length: int = 512,
        pooling_strategy: Literal["first", "mean", "max", "sum"] = "first",
        linear_bias: bool = False,
        **kwargs,
    ) -> None:
        """Create a cross-encoder configuration.

        :param query_length: Forwarded unchanged to :class:`LightningIRConfig`
            (default 32); presumably a maximum query length in tokens -- confirm
            against the base config.
        :param doc_length: Forwarded unchanged to :class:`LightningIRConfig`
            (default 512); presumably a maximum document length in tokens.
        :param pooling_strategy: One of ``"first"``, ``"mean"``, ``"max"`` or
            ``"sum"`` (default ``"first"``). Stored as-is, without runtime
            validation. Presumably selects how encoder outputs are pooled --
            confirm in the model implementation.
        :param linear_bias: Stored as ``self.linear_bias`` (default ``False``).
            Presumably toggles a bias term in the scoring layer -- confirm.
        :param kwargs: Any remaining keyword arguments, passed through to the
            parent constructor.
        """
        # Collect the length settings first, followed by any pass-through kwargs,
        # and hand everything to the base configuration in one call.
        parent_kwargs = {"query_length": query_length, "doc_length": doc_length, **kwargs}
        super().__init__(**parent_kwargs)

        # Cross-encoder specific settings are stored after the parent has been
        # initialised, so they are not overwritten by the base constructor.
        self.pooling_strategy = pooling_strategy
        self.linear_bias = linear_bias