Source code for lightning_ir.loss.loss

from abc import ABC, abstractmethod
from typing import Literal, Tuple

import torch

class LossFunction(ABC):
    @abstractmethod
    def compute_loss(self, *args, **kwargs) -> torch.Tensor: ...

class ScoringLossFunction(LossFunction):
    @abstractmethod
    def compute_loss(
        self,
        scores: torch.Tensor,
        targets: torch.Tensor,
    ) -> torch.Tensor: ...

    def process_targets(self, scores: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        # collapse an extra target dimension by taking the maximum label
        if targets.ndim > scores.ndim:
            return targets.max(-1).values
        return targets

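# Illustrative usage sketch (not part of the original module): how
# process_targets collapses an extra target dimension. The 3-d target shape
# below is an assumption for demonstration; the concrete subclass exists only
# to make the abstract class instantiable.
def _process_targets_sketch() -> torch.Tensor:
    class _Scoring(ScoringLossFunction):
        def compute_loss(self, scores: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
            return torch.zeros(())

    scores = torch.tensor([[0.3, 0.7]])  # (queries=1, docs=2)
    targets = torch.tensor([[[0.0, 1.0], [0.5, 0.2]]])  # one extra dim
    # the extra dimension is reduced with max -> tensor([[1.0, 0.5]])
    return _Scoring().process_targets(scores, targets)
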
class EmbeddingLossFunction(LossFunction):
    @abstractmethod
    def compute_loss(
        self,
        query_embeddings: torch.Tensor,
        doc_embeddings: torch.Tensor,
    ) -> torch.Tensor: ...

class PairwiseLossFunction(ScoringLossFunction):
    def get_pairwise_idcs(self, targets: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        # positive items are items whose label is greater than another item's label in the same sample
        return torch.nonzero(targets[..., None] > targets[:, None], as_tuple=True)

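# Illustrative usage sketch (not part of the original module): for targets
# [[1, 0, 0]], get_pairwise_idcs pairs item 0 (label 1) against items 1 and 2
# (label 0) within query 0. The concrete subclass is only there so the
# abstract class can be instantiated.
def _pairwise_idcs_sketch() -> Tuple[torch.Tensor, ...]:
    class _Pairwise(PairwiseLossFunction):
        def compute_loss(self, scores: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
            return torch.zeros(())

    targets = torch.tensor([[1, 0, 0]])
    # -> (tensor([0, 0]), tensor([0, 0]), tensor([1, 2]))
    return _Pairwise().get_pairwise_idcs(targets)
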
class ListwiseLossFunction(ScoringLossFunction):
    pass

class MarginMSE(PairwiseLossFunction):
    def __init__(self, margin: float | Literal["scores"] = 1.0):
        self.margin = margin

    def compute_loss(self, scores: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        targets = self.process_targets(scores, targets)
        query_idcs, pos_idcs, neg_idcs = self.get_pairwise_idcs(targets)
        pos = scores[query_idcs, pos_idcs]
        neg = scores[query_idcs, neg_idcs]
        margin = pos - neg
        if isinstance(self.margin, float):
            target_margin = torch.tensor(self.margin, device=scores.device)
        elif self.margin == "scores":
            # use the difference of the target scores as the target margin
            target_margin = targets[query_idcs, pos_idcs] - targets[query_idcs, neg_idcs]
        else:
            raise ValueError("invalid margin type")
        loss = torch.nn.functional.mse_loss(margin, target_margin.clamp(min=0))
        return loss

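# Illustrative usage sketch (not part of the original module): with a constant
# margin of 1.0, MarginMSE penalizes positive-negative score differences that
# deviate from 1. The scores and targets below are assumed toy values.
def _margin_mse_sketch() -> torch.Tensor:
    scores = torch.tensor([[2.0, 1.5, 0.0]])  # (queries=1, docs=3)
    targets = torch.tensor([[1.0, 0.0, 0.0]])  # doc 0 is relevant
    # pairwise margins are [0.5, 2.0]; MSE against 1.0 gives tensor(0.6250)
    return MarginMSE(margin=1.0).compute_loss(scores, targets)
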
class ConstantMarginMSE(MarginMSE):
    def __init__(self, margin: float = 1.0):
        super().__init__(margin)

class SupervisedMarginMSE(MarginMSE):
    def __init__(self):
        super().__init__("scores")

class RankNet(PairwiseLossFunction):
    def compute_loss(self, scores: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        targets = self.process_targets(scores, targets)
        query_idcs, pos_idcs, neg_idcs = self.get_pairwise_idcs(targets)
        pos = scores[query_idcs, pos_idcs]
        neg = scores[query_idcs, neg_idcs]
        margin = pos - neg
        loss = torch.nn.functional.binary_cross_entropy_with_logits(margin, torch.ones_like(margin))
        return loss

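# Illustrative usage sketch (not part of the original module): RankNet treats
# each positive-negative score difference as a logit and pushes it toward
# "positive wins" with binary cross-entropy. Toy values below are assumptions.
def _ranknet_sketch() -> torch.Tensor:
    scores = torch.tensor([[1.2, 0.3, -0.5]])
    targets = torch.tensor([[1.0, 0.0, 0.0]])
    return RankNet().compute_loss(scores, targets)
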
class KLDivergence(ListwiseLossFunction):
    def compute_loss(self, scores: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        targets = self.process_targets(scores, targets)
        scores = torch.nn.functional.log_softmax(scores, dim=-1)
        targets = torch.nn.functional.log_softmax(targets.to(scores), dim=-1)
        loss = torch.nn.functional.kl_div(scores, targets, log_target=True, reduction="batchmean")
        return loss

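# Illustrative usage sketch (not part of the original module): KLDivergence
# interprets both scores and graded targets as logits over the candidate list
# and matches the two distributions. Toy values below are assumptions.
def _kl_divergence_sketch() -> torch.Tensor:
    scores = torch.tensor([[2.0, 0.5, 0.1]])
    targets = torch.tensor([[1.0, 0.0, 0.0]])
    return KLDivergence().compute_loss(scores, targets)
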
class LocalizedContrastiveEstimation(ListwiseLossFunction):
    def compute_loss(self, scores: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        targets = self.process_targets(scores, targets)
        targets = targets.argmax(dim=1)
        loss = torch.nn.functional.cross_entropy(scores, targets)
        return loss

class ApproxLossFunction(ListwiseLossFunction):
    def __init__(self, temperature: float = 1) -> None:
        super().__init__()
        self.temperature = temperature

    @staticmethod
    def get_approx_ranks(scores: torch.Tensor, temperature: float) -> torch.Tensor:
        score_diff = scores[:, None] - scores[..., None]
        normalized_score_diff = torch.sigmoid(score_diff / temperature)
        # set diagonal to 0
        normalized_score_diff = normalized_score_diff * (1 - torch.eye(scores.shape[1], device=scores.device))
        approx_ranks = normalized_score_diff.sum(-1) + 1
        return approx_ranks

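# Illustrative usage sketch (not part of the original module): at a small
# temperature the differentiable ranks approach the true ranks; for descending
# scores the result is close to [1, 2, 3]. Values below are assumptions.
def _approx_ranks_sketch() -> torch.Tensor:
    scores = torch.tensor([[3.0, 2.0, 1.0]])
    # -> approximately tensor([[1.0, 2.0, 3.0]])
    return ApproxLossFunction.get_approx_ranks(scores, temperature=0.01)
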
class ApproxNDCG(ApproxLossFunction):
    def __init__(self, temperature: float = 1, scale_gains: bool = True):
        super().__init__(temperature)
        self.scale_gains = scale_gains

    @staticmethod
    def get_dcg(
        ranks: torch.Tensor,
        targets: torch.Tensor,
        k: int | None = None,
        scale_gains: bool = True,
    ) -> torch.Tensor:
        log_ranks = torch.log2(1 + ranks)
        discounts = 1 / log_ranks
        if scale_gains:
            gains = 2**targets - 1
        else:
            gains = targets
        dcgs = gains * discounts
        if k is not None:
            dcgs = dcgs.masked_fill(ranks > k, 0)
        return dcgs.sum(dim=-1)

    @staticmethod
    def get_ndcg(
        ranks: torch.Tensor,
        targets: torch.Tensor,
        k: int | None = None,
        scale_gains: bool = True,
        optimal_targets: torch.Tensor | None = None,
    ) -> torch.Tensor:
        targets = targets.clamp(min=0)
        if optimal_targets is None:
            optimal_targets = targets
        optimal_ranks = torch.argsort(torch.argsort(optimal_targets, descending=True))
        optimal_ranks = optimal_ranks + 1
        dcg = ApproxNDCG.get_dcg(ranks, targets, k, scale_gains)
        idcg = ApproxNDCG.get_dcg(optimal_ranks, optimal_targets, k, scale_gains)
        ndcg = dcg / (idcg.clamp(min=1e-12))
        return ndcg

    def compute_loss(self, scores: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        targets = self.process_targets(scores, targets)
        approx_ranks = self.get_approx_ranks(scores, self.temperature)
        ndcg = self.get_ndcg(approx_ranks, targets, k=None, scale_gains=self.scale_gains)
        loss = 1 - ndcg
        return loss.mean()

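# Illustrative usage sketch (not part of the original module): ApproxNDCG
# returns 1 - NDCG computed from differentiable ranks, so scores that already
# order the graded targets correctly give a loss near 0. Toy values below are
# assumptions.
def _approx_ndcg_sketch() -> torch.Tensor:
    scores = torch.tensor([[3.0, 2.0, 1.0]])
    targets = torch.tensor([[2.0, 1.0, 0.0]])  # graded relevance
    return ApproxNDCG(temperature=0.1).compute_loss(scores, targets)
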
class ApproxMRR(ApproxLossFunction):
    def __init__(self, temperature: float = 1):
        super().__init__(temperature)

    @staticmethod
    def get_mrr(ranks: torch.Tensor, targets: torch.Tensor, k: int | None = None) -> torch.Tensor:
        targets = targets.clamp(None, 1)
        reciprocal_ranks = 1 / ranks
        mrr = reciprocal_ranks * targets
        if k is not None:
            mrr = mrr.masked_fill(ranks > k, 0)
        mrr = mrr.max(dim=-1)[0]
        return mrr

    def compute_loss(self, scores: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        targets = self.process_targets(scores, targets)
        approx_ranks = self.get_approx_ranks(scores, self.temperature)
        mrr = self.get_mrr(approx_ranks, targets, k=None)
        loss = 1 - mrr
        return loss.mean()

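# Illustrative usage sketch (not part of the original module): targets are
# clamped to at most 1, so ApproxMRR only rewards the best-ranked relevant
# document; here the relevant document is ranked first and the loss is near 0.
def _approx_mrr_sketch() -> torch.Tensor:
    scores = torch.tensor([[0.2, 3.0, 1.0]])
    targets = torch.tensor([[0.0, 1.0, 0.0]])
    return ApproxMRR(temperature=0.1).compute_loss(scores, targets)
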
class ApproxRankMSE(ApproxLossFunction):
    def __init__(
        self,
        temperature: float = 1,
        discount: Literal["log2", "reciprocal"] | None = None,
    ):
        super().__init__(temperature)
        self.discount = discount

    def compute_loss(self, scores: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        targets = self.process_targets(scores, targets)
        approx_ranks = self.get_approx_ranks(scores, self.temperature)
        ranks = torch.argsort(torch.argsort(targets, descending=True)) + 1
        loss = torch.nn.functional.mse_loss(approx_ranks, ranks.to(approx_ranks), reduction="none")
        if self.discount == "log2":
            weight = 1 / torch.log2(ranks + 1)
        elif self.discount == "reciprocal":
            weight = 1 / ranks
        else:
            weight = 1
        loss = loss * weight
        loss = loss.mean()
        return loss

class NeuralLossFunction(ListwiseLossFunction):
    # TODO add neural loss functions

    def __init__(self, temperature: float = 1, tol: float = 1e-5, max_iter: int = 50) -> None:
        super().__init__()
        self.temperature = temperature
        self.tol = tol
        self.max_iter = max_iter

    def neural_sort(self, scores: torch.Tensor) -> torch.Tensor:
        # https://github.com/ermongroup/neuralsort/blob/master/pytorch/neuralsort.py
        scores = scores.unsqueeze(-1)
        dim = scores.shape[1]
        one = torch.ones((dim, 1), device=scores.device)

        A_scores = torch.abs(scores - scores.permute(0, 2, 1))
        B = torch.matmul(A_scores, torch.matmul(one, torch.transpose(one, 0, 1)))
        scaling = dim + 1 - 2 * (torch.arange(dim, device=scores.device) + 1)
        C = torch.matmul(scores, scaling.to(scores).unsqueeze(0))

        P_max = (C - B).permute(0, 2, 1)
        P_hat = torch.nn.functional.softmax(P_max / self.temperature, dim=-1)

        P_hat = self.sinkhorn_scaling(P_hat)

        return P_hat

    def sinkhorn_scaling(self, mat: torch.Tensor) -> torch.Tensor:
        # https://github.com/allegro/allRank/blob/master/allrank/models/losses/loss_utils.py#L8
        idx = 0
        while True:
            if (
                torch.max(torch.abs(mat.sum(dim=2) - 1.0)) < self.tol
                and torch.max(torch.abs(mat.sum(dim=1) - 1.0)) < self.tol
            ) or idx > self.max_iter:
                break
            mat = mat / mat.sum(dim=1, keepdim=True).clamp(min=1e-12)
            mat = mat / mat.sum(dim=2, keepdim=True).clamp(min=1e-12)
            idx += 1

        return mat

    def get_sorted_targets(self, scores: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        permutation_matrix = self.neural_sort(scores)
        pred_sorted_targets = torch.matmul(permutation_matrix, targets[..., None].to(permutation_matrix)).squeeze(-1)
        return pred_sorted_targets

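# Illustrative usage sketch (not part of the original module): neural_sort
# produces a relaxed permutation matrix that approaches the hard
# sort-by-score permutation at low temperature; sinkhorn_scaling makes it
# approximately doubly stochastic. The concrete subclass exists only to make
# the abstract class instantiable.
def _neural_sort_sketch() -> torch.Tensor:
    class _Neural(NeuralLossFunction):
        def compute_loss(self, scores: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
            return torch.zeros(())

    scores = torch.tensor([[0.1, 2.0, 1.0]])
    # returns a (1, 3, 3) matrix with rows and columns summing to ~1
    return _Neural(temperature=0.1).neural_sort(scores)
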
class InBatchLossFunction(ScoringLossFunction):
    def __init__(
        self,
        pos_sampling_technique: Literal["all", "first"] = "all",
        neg_sampling_technique: Literal["all", "first"] = "all",
        max_num_neg_samples: int | None = None,
    ):
        super().__init__()
        self.pos_sampling_technique = pos_sampling_technique
        self.neg_sampling_technique = neg_sampling_technique
        self.max_num_neg_samples = max_num_neg_samples

    def get_ib_idcs(self, num_queries: int, num_docs: int) -> Tuple[torch.Tensor, torch.Tensor]:
        # each query's documents occupy the flat index range [min_idx, max_idx)
        min_idx = torch.arange(num_queries)[:, None] * num_docs
        max_idx = min_idx + num_docs
        if self.pos_sampling_technique == "all":
            pos_mask = torch.arange(num_queries * num_docs)[None].greater_equal(min_idx) & torch.arange(
                num_queries * num_docs
            )[None].less(max_idx)
        elif self.pos_sampling_technique == "first":
            pos_mask = torch.arange(num_queries * num_docs)[None].eq(min_idx)
        else:
            raise ValueError("invalid pos sampling technique")
        pos_idcs = pos_mask.nonzero(as_tuple=True)[1]
        if self.neg_sampling_technique == "all":
            neg_mask = torch.arange(num_queries * num_docs)[None].less(min_idx) | torch.arange(num_queries * num_docs)[
                None
            ].greater_equal(max_idx)
        elif self.neg_sampling_technique == "first":
            neg_mask = torch.arange(num_queries * num_docs)[None, None].eq(min_idx).any(1) & torch.arange(
                num_queries * num_docs
            )[None].ne(min_idx)
        else:
            raise ValueError("invalid neg sampling technique")
        neg_idcs = neg_mask.nonzero(as_tuple=True)[1]
        if self.max_num_neg_samples is not None:
            neg_idcs = neg_idcs.view(num_queries, -1)
            if neg_idcs.shape[-1] > 1:
                neg_idcs = neg_idcs[:, torch.randperm(neg_idcs.shape[-1])]
            neg_idcs = neg_idcs[:, : self.max_num_neg_samples]
            neg_idcs = neg_idcs.reshape(-1)
        return pos_idcs, neg_idcs

    def compute_loss(self, scores: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError("InBatchLossFunction.compute_loss must be implemented by subclasses")

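# Illustrative usage sketch (not part of the original module): with 2 queries
# and 2 documents each (flattened to 4 rows of scores), "first" positive
# sampling picks flat indices 0 and 2, while "all" negative sampling picks
# every document that belongs to the other query.
def _in_batch_idcs_sketch() -> Tuple[torch.Tensor, torch.Tensor]:
    loss_fn = InBatchLossFunction(pos_sampling_technique="first", neg_sampling_technique="all")
    # -> (tensor([0, 2]), tensor([2, 3, 0, 1]))
    return loss_fn.get_ib_idcs(num_queries=2, num_docs=2)
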
class InBatchCrossEntropy(InBatchLossFunction):
    def compute_loss(self, scores: torch.Tensor) -> torch.Tensor:
        # the positive document is assumed to sit at index 0 of every row
        targets = torch.zeros(scores.shape[0], dtype=torch.long, device=scores.device)
        loss = torch.nn.functional.cross_entropy(scores, targets)
        return loss

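# Illustrative usage sketch (not part of the original module): the score rows
# are assumed to place each query's positive document in column 0, matching
# the all-zero cross-entropy targets.
def _in_batch_cross_entropy_sketch() -> torch.Tensor:
    scores = torch.tensor([[5.0, 1.0, 0.5], [4.0, 0.2, 0.1]])
    return InBatchCrossEntropy().compute_loss(scores)
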
class RegularizationLossFunction(EmbeddingLossFunction):
    def __init__(self, query_weight: float = 1e-4, doc_weight: float = 1e-4) -> None:
        self.query_weight = query_weight
        self.doc_weight = doc_weight

class L2Regularization(RegularizationLossFunction):
    def compute_loss(
        self,
        query_embeddings: torch.Tensor,
        doc_embeddings: torch.Tensor,
    ) -> torch.Tensor:
        query_loss = self.query_weight * query_embeddings.norm(dim=-1).mean()
        doc_loss = self.doc_weight * doc_embeddings.norm(dim=-1).mean()
        loss = query_loss + doc_loss
        return loss

class L1Regularization(RegularizationLossFunction):
    def compute_loss(
        self,
        query_embeddings: torch.Tensor,
        doc_embeddings: torch.Tensor,
    ) -> torch.Tensor:
        query_loss = self.query_weight * query_embeddings.norm(p=1, dim=-1).mean()
        doc_loss = self.doc_weight * doc_embeddings.norm(p=1, dim=-1).mean()
        loss = query_loss + doc_loss
        return loss

class FLOPSRegularization(RegularizationLossFunction):
    def compute_loss(
        self,
        query_embeddings: torch.Tensor,
        doc_embeddings: torch.Tensor,
    ) -> torch.Tensor:
        query_loss = torch.sum(torch.mean(torch.abs(query_embeddings), dim=0) ** 2)
        doc_loss = torch.sum(torch.mean(torch.abs(doc_embeddings), dim=0) ** 2)
        # the anti-zero term penalizes embeddings that collapse to all zeros
        anti_zero = 1 / (torch.sum(query_embeddings) ** 2) + 1 / (torch.sum(doc_embeddings) ** 2)
        loss = self.query_weight * query_loss + self.doc_weight * doc_loss + anti_zero
        return loss
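
# Illustrative usage sketch (not part of the original module): FLOPS
# regularization penalizes the squared mean activation per embedding
# dimension, encouraging sparse embeddings; the random shapes below are
# assumptions.
def _flops_regularization_sketch() -> torch.Tensor:
    query_embeddings = torch.rand(8, 32)  # (num_queries, embedding_dim)
    doc_embeddings = torch.rand(16, 32)  # (num_docs, embedding_dim)
    return FLOPSRegularization(query_weight=1e-3, doc_weight=1e-3).compute_loss(query_embeddings, doc_embeddings)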