Source code for lightning_ir.cross_encoder.config

 1from typing import Literal
 2
 3from ..base import LightningIRConfig
 4
 5
class CrossEncoderConfig(LightningIRConfig):
    """Configuration class for cross-encoder models.

    Extends :class:`LightningIRConfig` with two extra settings,
    ``pooling_strategy`` and ``linear_bias``, and registers both names in
    ``ADDED_ARGS`` alongside the names inherited from the base configuration.
    """

    model_type = "cross-encoder"

    # Base-class argument names plus the two options introduced by this class.
    ADDED_ARGS = LightningIRConfig.ADDED_ARGS.union({"pooling_strategy", "linear_bias"})

    def __init__(
        self,
        query_length: int = 32,
        doc_length: int = 512,
        pooling_strategy: Literal["first", "mean", "max", "sum"] = "first",
        linear_bias: bool = False,
        **kwargs,
    ) -> None:
        """Store the cross-encoder settings on top of the base configuration.

        Args:
            query_length: Forwarded to ``LightningIRConfig.__init__``
                (presumably the maximum query length in tokens -- confirm
                against the base class). Defaults to 32.
            doc_length: Forwarded to ``LightningIRConfig.__init__``
                (presumably the maximum document length in tokens -- confirm
                against the base class). Defaults to 512.
            pooling_strategy: One of ``"first"``, ``"mean"``, ``"max"`` or
                ``"sum"``; stored unchanged on the instance. Defaults to
                ``"first"``. NOTE(review): the value is not validated here.
            linear_bias: Boolean flag stored unchanged on the instance
                (presumably toggles a bias term in a linear scoring layer --
                verify against the model implementation). Defaults to False.
            **kwargs: Any further keyword arguments, passed through to the
                base-class initializer untouched.
        """
        # The base class handles the shared length settings and all extra kwargs.
        super().__init__(query_length=query_length, doc_length=doc_length, **kwargs)

        # Cross-encoder specific options are kept as plain instance attributes.
        self.pooling_strategy = pooling_strategy
        self.linear_bias = linear_bias