1from abc import ABC, abstractmethod
2from typing import Literal, Tuple
3
4import torch
5
6
[docs]
class LossFunction(ABC):
    """Abstract root of every loss in this module.

    Concrete families narrow the signature of :meth:`compute_loss` (score based,
    embedding based, ...), which is why it accepts arbitrary arguments here.
    """

    @abstractmethod
    def compute_loss(self, *args, **kwargs) -> torch.Tensor:
        """Compute and return the loss as a tensor."""
10
11
[docs]
class ScoringLossFunction(LossFunction):
    """Base class for losses computed from model scores and relevance targets."""

    @abstractmethod
    def compute_loss(
        self,
        scores: torch.Tensor,
        targets: torch.Tensor,
    ) -> torch.Tensor:
        """Return the loss of ``scores`` with respect to ``targets``."""

    def process_targets(self, scores: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Bring ``targets`` to the rank of ``scores``.

        If ``targets`` has more dimensions than ``scores``, its last dimension is
        reduced with ``max``; otherwise ``targets`` is returned untouched.
        """
        if targets.ndim <= scores.ndim:
            return targets
        reduced = torch.max(targets, dim=-1)
        return reduced.values
24
25
[docs]
class EmbeddingLossFunction(LossFunction):
    """Base class for losses computed directly from query and document embeddings."""

    @abstractmethod
    def compute_loss(
        self,
        query_embeddings: torch.Tensor,
        doc_embeddings: torch.Tensor,
    ) -> torch.Tensor:
        """Return the loss for the given query and document embeddings."""
33
34
[docs]
class PairwiseLossFunction(ScoringLossFunction):
    """Base class for losses defined over (higher-target, lower-target) item pairs."""

    def get_pairwise_idcs(self, targets: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        """Find every ordered pair whose first item has a strictly larger target.

        ``targets`` is presumably (queries, items); the broadcast below compares
        item ``i`` against item ``j`` within the same row — TODO confirm for other ranks.

        Returns:
            The ``torch.nonzero(..., as_tuple=True)`` indices of the comparison mask,
            i.e. ``(query_idcs, pos_idcs, neg_idcs)`` for 2-D targets.
        """
        # is_higher[q, i, j] is True when item i outranks item j for query q
        is_higher = targets.unsqueeze(-1) > targets.unsqueeze(1)
        return torch.nonzero(is_higher, as_tuple=True)
39
40
[docs]
class ListwiseLossFunction(ScoringLossFunction):
    """Marker base class for losses that consider an entire ranked list at once."""
43
44
[docs]
class MarginMSE(PairwiseLossFunction):
    """Mean squared error between predicted and target score margins of item pairs.

    Pairs are all (pos, neg) item combinations within a query where ``pos`` has a
    strictly higher target than ``neg``.

    Args:
        margin: A constant target margin (any real number, ``int`` or ``float``) or
            ``"scores"`` to use each pair's target difference as its target margin.
    """

    def __init__(self, margin: float | Literal["scores"] = 1.0):
        self.margin = margin

    def compute_loss(self, scores: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Compute the MSE between ``score[pos] - score[neg]`` and the target margin.

        Args:
            scores: Model scores indexed as ``scores[query, item]``.
            targets: Relevance targets (reduced via ``process_targets`` if needed).

        Returns:
            Scalar mean squared error over all pairs.

        Raises:
            ValueError: If ``margin`` is neither a number nor ``"scores"``.
        """
        targets = self.process_targets(scores, targets)
        query_idcs, pos_idcs, neg_idcs = self.get_pairwise_idcs(targets)
        margin = scores[query_idcs, pos_idcs] - scores[query_idcs, neg_idcs]
        # Accept ints as constant margins too (``MarginMSE(2)`` used to be rejected);
        # bool is an int subclass but never a meaningful margin, so it stays invalid.
        if isinstance(self.margin, (int, float)) and not isinstance(self.margin, bool):
            # Match the prediction's shape/dtype/device so mse_loss does not have to
            # broadcast a 0-d target (which it warns about).
            target_margin = torch.full_like(margin, float(self.margin))
        elif self.margin == "scores":
            # Cast so integer or float64 labels don't hit a dtype mismatch in mse_loss.
            target_margin = (targets[query_idcs, pos_idcs] - targets[query_idcs, neg_idcs]).to(margin)
        else:
            raise ValueError("invalid margin type")
        loss = torch.nn.functional.mse_loss(margin, target_margin.clamp(min=0))
        return loss
63
64
[docs]
class ConstantMarginMSE(MarginMSE):
    """:class:`MarginMSE` that uses the same fixed target margin for every pair.

    Args:
        margin: Target margin between the higher- and lower-target item of a pair.
    """

    def __init__(self, margin: float = 1.0):
        # Store the margin as a float so it is treated as a constant margin even when
        # an int such as ``2`` is passed (the parent selects the constant branch by type).
        super().__init__(float(margin))
68
69
[docs]
class SupervisedMarginMSE(MarginMSE):
    """:class:`MarginMSE` whose target margin is the target difference of each pair."""

    def __init__(self):
        super().__init__(margin="scores")
73
74
[docs]
class RankNet(PairwiseLossFunction):
    """RankNet loss: logistic loss on the score difference of each ordered pair."""

    def compute_loss(self, scores: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Binary cross-entropy (with logits) of ``score[pos] - score[neg]`` against 1.

        Args:
            scores: Model scores indexed as ``scores[query, item]``.
            targets: Relevance targets (reduced via ``process_targets`` if needed).

        Returns:
            Scalar mean loss over all pairs.
        """
        targets = self.process_targets(scores, targets)
        query_idcs, pos_idcs, neg_idcs = self.get_pairwise_idcs(targets)
        score_diff = scores[query_idcs, pos_idcs] - scores[query_idcs, neg_idcs]
        # Pairs are always ordered (pos outranks neg), so every label is 1.
        labels = torch.ones_like(score_diff)
        return torch.nn.functional.binary_cross_entropy_with_logits(score_diff, labels)
84
85
[docs]
class KLDivergence(ListwiseLossFunction):
    """KL divergence between the softmax distributions of targets and scores."""

    def compute_loss(self, scores: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Compute ``kl_div`` of the score log-softmax against the target log-softmax.

        Both softmaxes are taken over the last dimension; the result uses
        ``reduction="batchmean"``.
        """
        targets = self.process_targets(scores, targets)
        log_probs = torch.nn.functional.log_softmax(scores, dim=-1)
        # Align the targets' dtype/device with the score log-probabilities.
        target_log_probs = torch.nn.functional.log_softmax(targets.to(log_probs), dim=-1)
        return torch.nn.functional.kl_div(
            log_probs,
            target_log_probs,
            log_target=True,
            reduction="batchmean",
        )
93
94
[docs]
class LocalizedContrastiveEstimation(ListwiseLossFunction):
    """Cross-entropy loss treating the highest-target item of each row as the positive."""

    def compute_loss(self, scores: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Cross-entropy of ``scores`` against the per-row argmax of ``targets`` (dim 1)."""
        targets = self.process_targets(scores, targets)
        positive_idx = torch.argmax(targets, dim=1)
        return torch.nn.functional.cross_entropy(scores, positive_idx)
101
102
[docs]
class ApproxLossFunction(ListwiseLossFunction):
    """Base for listwise losses that rely on a sigmoid-based approximation of ranks.

    Args:
        temperature: Divisor applied to pairwise score differences before the sigmoid.
    """

    def __init__(self, temperature: float = 1) -> None:
        super().__init__()
        self.temperature = temperature

    @staticmethod
    def get_approx_ranks(scores: torch.Tensor, temperature: float) -> torch.Tensor:
        """Differentiable 1-based ranks.

        The rank of item ``i`` is ``1 + sum_{j != i} sigmoid((s_j - s_i) / temperature)``.
        ``scores`` is presumably (queries, items); ``shape[1]`` sizes the diagonal mask.
        """
        num_items = scores.shape[1]
        # soft_greater[b, i, j] = sigmoid((s[b, j] - s[b, i]) / temperature)
        soft_greater = torch.sigmoid((scores[:, None] - scores[..., None]) / temperature)
        # An item must not count against itself, so zero the diagonal.
        off_diagonal = 1 - torch.eye(num_items, device=scores.device)
        return (soft_greater * off_diagonal).sum(-1) + 1
116
117
[docs]
class ApproxNDCG(ApproxLossFunction):
    """ApproxNDCG loss: ``1 - NDCG`` evaluated with approximate (differentiable) ranks.

    Args:
        temperature: Temperature of the rank approximation.
        scale_gains: If True gains are ``2**target - 1``, otherwise the raw targets.
    """

    def __init__(self, temperature: float = 1, scale_gains: bool = True):
        super().__init__(temperature)
        self.scale_gains = scale_gains

    @staticmethod
    def get_dcg(
        ranks: torch.Tensor,
        targets: torch.Tensor,
        k: int | None = None,
        scale_gains: bool = True,
    ) -> torch.Tensor:
        """Discounted cumulative gain summed over the last dimension.

        Each item contributes ``gain / log2(rank + 1)``; with ``k`` set, items whose
        rank exceeds ``k`` contribute 0.
        """
        discounts = 1 / torch.log2(ranks + 1)
        gains = 2**targets - 1 if scale_gains else targets
        contributions = gains * discounts
        if k is not None:
            contributions = contributions.masked_fill(ranks > k, 0)
        return contributions.sum(dim=-1)

    @staticmethod
    def get_ndcg(
        ranks: torch.Tensor,
        targets: torch.Tensor,
        k: int | None = None,
        scale_gains: bool = True,
        optimal_targets: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """NDCG of ``ranks`` given ``targets``.

        Negative targets are clamped to 0. The ideal DCG uses ``optimal_targets``
        (defaulting to the clamped targets) placed at their descending-sort ranks.
        The denominator is clamped to ``1e-12`` to avoid division by zero.
        """
        targets = targets.clamp(min=0)
        ideal_targets = targets if optimal_targets is None else optimal_targets
        # 1-based position of each item when sorted by ideal target, descending.
        ideal_ranks = torch.argsort(torch.argsort(ideal_targets, descending=True)) + 1
        dcg = ApproxNDCG.get_dcg(ranks, targets, k, scale_gains)
        ideal_dcg = ApproxNDCG.get_dcg(ideal_ranks, ideal_targets, k, scale_gains)
        return dcg / ideal_dcg.clamp(min=1e-12)

    def compute_loss(self, scores: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Mean of ``1 - NDCG`` over queries, using approximate ranks from ``scores``."""
        targets = self.process_targets(scores, targets)
        soft_ranks = self.get_approx_ranks(scores, self.temperature)
        ndcg = self.get_ndcg(soft_ranks, targets, k=None, scale_gains=self.scale_gains)
        return (1 - ndcg).mean()
165
166
[docs]
class ApproxMRR(ApproxLossFunction):
    """ApproxMRR loss: ``1 - MRR`` evaluated with approximate (differentiable) ranks.

    Args:
        temperature: Temperature of the rank approximation.
    """

    def __init__(self, temperature: float = 1):
        super().__init__(temperature)

    @staticmethod
    def get_mrr(ranks: torch.Tensor, targets: torch.Tensor, k: int | None = None) -> torch.Tensor:
        """Reciprocal rank of the best-placed relevant item, per row.

        Targets are clipped to ``[0, 1]``. Labels above 1 count as fully relevant.
        Negative labels count as non-relevant instead of producing negative
        reciprocal ranks; previously only the upper bound was clipped. With ``k``
        set, items ranked beyond ``k`` contribute 0.

        Returns:
            Max over the last dimension of ``targets / ranks``.
        """
        targets = targets.clamp(min=0, max=1)
        reciprocal_ranks = 1 / ranks
        mrr = reciprocal_ranks * targets
        if k is not None:
            mrr = mrr.masked_fill(ranks > k, 0)
        mrr = mrr.max(dim=-1)[0]
        return mrr

    def compute_loss(self, scores: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Mean of ``1 - MRR`` over queries, using approximate ranks from ``scores``."""
        targets = self.process_targets(scores, targets)
        approx_ranks = self.get_approx_ranks(scores, self.temperature)
        mrr = self.get_mrr(approx_ranks, targets, k=None)
        loss = 1 - mrr
        return loss.mean()
187
188
[docs]
class ApproxRankMSE(ApproxLossFunction):
    """MSE between approximate ranks (from scores) and true ranks (from targets).

    Args:
        temperature: Temperature of the rank approximation.
        discount: Optional per-item weighting by true rank: ``"log2"`` uses
            ``1 / log2(rank + 1)``, ``"reciprocal"`` uses ``1 / rank``, None uses 1.
    """

    def __init__(
        self,
        temperature: float = 1,
        discount: Literal["log2", "reciprocal"] | None = None,
    ):
        super().__init__(temperature)
        self.discount = discount

    def compute_loss(self, scores: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Mean of the (optionally discounted) squared rank errors."""
        targets = self.process_targets(scores, targets)
        soft_ranks = self.get_approx_ranks(scores, self.temperature)
        # 1-based position of each item when sorted by target, descending.
        true_ranks = torch.argsort(torch.argsort(targets, descending=True)) + 1
        squared_error = torch.nn.functional.mse_loss(soft_ranks, true_ranks.to(soft_ranks), reduction="none")
        if self.discount == "log2":
            weight = 1 / torch.log2(true_ranks + 1)
        elif self.discount == "reciprocal":
            weight = 1 / true_ranks
        else:
            weight = 1
        return (squared_error * weight).mean()
212
213
[docs]
class NeuralLossFunction(ListwiseLossFunction):
    """Base for listwise losses built on NeuralSort relaxed permutation matrices.

    Args:
        temperature: Softmax temperature of the NeuralSort relaxation.
        tol: Sinkhorn scaling stops once every row and column sum is within ``tol`` of 1.
        max_iter: Iteration cap for Sinkhorn scaling. The check is ``idx > max_iter``,
            mirroring the referenced allRank implementation, so up to ``max_iter + 1``
            normalisation rounds run.
    """

    # TODO add neural loss functions

    def __init__(self, temperature: float = 1, tol: float = 1e-5, max_iter: int = 50) -> None:
        super().__init__()
        self.temperature = temperature
        self.tol = tol
        self.max_iter = max_iter

    def neural_sort(self, scores: torch.Tensor) -> torch.Tensor:
        """Return a Sinkhorn-normalised soft permutation matrix for ``scores``.

        ``scores`` must be 2-D (batch, items); the result is (batch, items, items).
        """
        # https://github.com/ermongroup/neuralsort/blob/master/pytorch/neuralsort.py
        scores = scores.unsqueeze(-1)
        dim = scores.shape[1]
        # Create the ones vector in the scores' dtype: a default float32 tensor makes the
        # matmul below fail for half/bfloat16 scores (``scaling`` is cast the same way).
        one = torch.ones((dim, 1), device=scores.device, dtype=scores.dtype)

        A_scores = torch.abs(scores - scores.permute(0, 2, 1))
        B = torch.matmul(A_scores, torch.matmul(one, torch.transpose(one, 0, 1)))
        scaling = dim + 1 - 2 * (torch.arange(dim, device=scores.device) + 1)
        C = torch.matmul(scores, scaling.to(scores).unsqueeze(0))

        P_max = (C - B).permute(0, 2, 1)
        P_hat = torch.nn.functional.softmax(P_max / self.temperature, dim=-1)

        P_hat = self.sinkhorn_scaling(P_hat)

        return P_hat

    def sinkhorn_scaling(self, mat: torch.Tensor) -> torch.Tensor:
        """Alternately normalise columns and rows of ``mat`` towards doubly-stochastic.

        Stops when all row and column sums are within ``self.tol`` of 1, or once the
        iteration counter exceeds ``self.max_iter``.
        """
        # https://github.com/allegro/allRank/blob/master/allrank/models/losses/loss_utils.py#L8
        idx = 0
        while True:
            if (
                torch.max(torch.abs(mat.sum(dim=2) - 1.0)) < self.tol
                and torch.max(torch.abs(mat.sum(dim=1) - 1.0)) < self.tol
            ) or idx > self.max_iter:
                break
            # The clamp avoids division by zero for all-zero columns/rows.
            mat = mat / mat.sum(dim=1, keepdim=True).clamp(min=1e-12)
            mat = mat / mat.sum(dim=2, keepdim=True).clamp(min=1e-12)
            idx += 1

        return mat

    def get_sorted_targets(self, scores: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Apply the soft permutation of ``scores`` to ``targets``.

        Returns a (batch, items) tensor of targets reordered by the relaxed sort,
        presumably descending predicted score.
        """
        permutation_matrix = self.neural_sort(scores)
        pred_sorted_targets = torch.matmul(permutation_matrix, targets[..., None].to(permutation_matrix)).squeeze(-1)
        return pred_sorted_targets
260
261
[docs]
class InBatchLossFunction(ScoringLossFunction):
    """Base for losses that score each query against documents of the whole batch.

    Documents are assumed to be laid out in contiguous per-query groups of
    ``num_docs`` (query ``q`` owns flat indices ``[q * num_docs, (q + 1) * num_docs)``).

    Args:
        pos_sampling_technique: ``"all"`` uses every own document as a positive,
            ``"first"`` only the first one of the group.
        neg_sampling_technique: ``"all"`` uses every document of other queries as a
            negative, ``"first"`` only the first document of every other query.
        max_num_neg_samples: If set, keep at most this many negatives per query,
            chosen by one random permutation shared across queries.
    """

    def __init__(
        self,
        pos_sampling_technique: Literal["all", "first"] = "all",
        neg_sampling_technique: Literal["all", "first"] = "all",
        max_num_neg_samples: int | None = None,
    ):
        super().__init__()
        self.pos_sampling_technique = pos_sampling_technique
        self.neg_sampling_technique = neg_sampling_technique
        self.max_num_neg_samples = max_num_neg_samples

    def get_ib_idcs(self, num_queries: int, num_docs: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return flat (positive, negative) document indices, grouped by query.

        Raises:
            ValueError: If either sampling technique is not ``"all"`` or ``"first"``.
        """
        columns = torch.arange(num_queries * num_docs)[None]  # (1, Q*D)
        group_start = torch.arange(num_queries)[:, None] * num_docs  # (Q, 1)
        group_end = group_start + num_docs
        # own_block[q, c]: document c belongs to query q; own_first[q, c]: it is q's first
        own_block = (columns >= group_start) & (columns < group_end)
        own_first = columns == group_start

        if self.pos_sampling_technique == "all":
            pos_mask = own_block
        elif self.pos_sampling_technique == "first":
            pos_mask = own_first
        else:
            raise ValueError("invalid pos sampling technique")
        pos_idcs = pos_mask.nonzero(as_tuple=True)[1]

        if self.neg_sampling_technique == "all":
            neg_mask = ~own_block
        elif self.neg_sampling_technique == "first":
            # first document of any query, excluding the query's own first document
            neg_mask = own_first.any(0, keepdim=True) & ~own_first
        else:
            raise ValueError("invalid neg sampling technique")
        neg_idcs = neg_mask.nonzero(as_tuple=True)[1]

        if self.max_num_neg_samples is not None:
            per_query = neg_idcs.view(num_queries, -1)
            if per_query.shape[-1] > 1:
                per_query = per_query[:, torch.randperm(per_query.shape[-1])]
            neg_idcs = per_query[:, : self.max_num_neg_samples].reshape(-1)
        return pos_idcs, neg_idcs

    def compute_loss(self, scores: torch.Tensor) -> torch.Tensor:
        """Abstract hook; concrete in-batch losses override this."""
        raise NotImplementedError("InBatchLossFunction.compute_loss must be implemented by subclasses")
307
308
[docs]
class InBatchCrossEntropy(InBatchLossFunction):
    """Cross-entropy over in-batch scores where column 0 of every row is the positive."""

    def compute_loss(self, scores: torch.Tensor) -> torch.Tensor:
        """Cross-entropy of ``scores`` against class index 0 for every row."""
        positive_column = scores.new_zeros(scores.shape[0], dtype=torch.long)
        return torch.nn.functional.cross_entropy(scores, positive_column)
314
315
[docs]
class RegularizationLossFunction(EmbeddingLossFunction):
    """Base for regularizers on query/document embeddings with separate weights.

    Args:
        query_weight: Multiplier for the query-embedding term.
        doc_weight: Multiplier for the document-embedding term.
    """

    def __init__(self, query_weight: float = 1e-4, doc_weight: float = 1e-4) -> None:
        self.query_weight, self.doc_weight = query_weight, doc_weight
320
321
[docs]
class L2Regularization(RegularizationLossFunction):
    """Weighted mean L2 norm (over the last dimension) of query and doc embeddings."""

    def compute_loss(
        self,
        query_embeddings: torch.Tensor,
        doc_embeddings: torch.Tensor,
    ) -> torch.Tensor:
        """``query_weight * mean ||q||_2 + doc_weight * mean ||d||_2``."""
        mean_query_norm = torch.norm(query_embeddings, dim=-1).mean()
        mean_doc_norm = torch.norm(doc_embeddings, dim=-1).mean()
        return self.query_weight * mean_query_norm + self.doc_weight * mean_doc_norm
332
333
[docs]
class L1Regularization(RegularizationLossFunction):
    """Weighted mean L1 norm (over the last dimension) of query and doc embeddings."""

    def compute_loss(
        self,
        query_embeddings: torch.Tensor,
        doc_embeddings: torch.Tensor,
    ) -> torch.Tensor:
        """``query_weight * mean ||q||_1 + doc_weight * mean ||d||_1``."""
        mean_query_norm = torch.norm(query_embeddings, p=1, dim=-1).mean()
        mean_doc_norm = torch.norm(doc_embeddings, p=1, dim=-1).mean()
        return self.query_weight * mean_query_norm + self.doc_weight * mean_doc_norm
344
345
[docs]
class FLOPSRegularization(RegularizationLossFunction):
    """FLOPS regularization plus an (unweighted) anti-zero term.

    The FLOPS term of each side is ``sum_j (mean_i |x_ij|)**2`` with the mean over
    dimension 0, weighted by ``query_weight`` / ``doc_weight``.
    """

    def compute_loss(
        self,
        query_embeddings: torch.Tensor,
        doc_embeddings: torch.Tensor,
    ) -> torch.Tensor:
        """Return the weighted FLOPS terms plus the anti-zero term.

        The anti-zero term is ``1 / (sum |q|)**2 + 1 / (sum |d|)**2``. It grows as the
        total activation mass shrinks, discouraging all-zero embeddings.
        """
        query_loss = torch.sum(torch.mean(torch.abs(query_embeddings), dim=0) ** 2)
        doc_loss = torch.sum(torch.mean(torch.abs(doc_embeddings), dim=0) ** 2)
        # Absolute values stop positive and negative entries from cancelling; they are
        # identical to the raw sum for non-negative embeddings. The denominators are
        # clamped to the dtype's smallest normal value so an all-zero batch yields a
        # large finite penalty instead of inf.
        query_mass = torch.sum(torch.abs(query_embeddings)) ** 2
        doc_mass = torch.sum(torch.abs(doc_embeddings)) ** 2
        anti_zero = 1 / query_mass.clamp(min=torch.finfo(query_mass.dtype).tiny) + 1 / doc_mass.clamp(
            min=torch.finfo(doc_mass.dtype).tiny
        )
        loss = self.query_weight * query_loss + self.doc_weight * doc_loss + anti_zero
        return loss