1
class ContrastiveLoss(nn.Module):
    """Symmetric contrastive loss with label smoothing and an entropy bonus.

    Both inputs are L2-normalized row-wise and multiplied together to form a
    temperature-scaled similarity matrix. Row ``i`` of the first input is
    treated as the positive match for row ``i`` of the second input. A
    soft-target cross-entropy is computed in both directions (rows and
    columns of the similarity matrix), averaged, and the mean prediction
    entropy scaled by ``lambda_entropy`` is subtracted, so higher-entropy
    (less confident) predictions lower the loss.

    Args:
        temperature: Divisor applied to the cosine-similarity matrix.
        smoothing: Total probability mass moved from the positive pair and
            spread evenly over the other ``N - 1`` candidates.
        lambda_entropy: Weight of the entropy term subtracted from the loss.
    """

    def __init__(self, temperature=0.07, smoothing=0.1, lambda_entropy=0.02):
        super().__init__()
        self.temperature = temperature
        self.smoothing = smoothing
        self.lambda_entropy = lambda_entropy

    @staticmethod
    def _direction_terms(logits, targets):
        """Return (soft-target cross-entropy, mean row entropy) for ``logits``."""
        log_probs = F.log_softmax(logits, dim=1)
        cross_entropy = -(targets * log_probs).sum(dim=1).mean()

        probs = F.softmax(logits, dim=1)
        # The small epsilon keeps log() finite if a probability underflows to 0.
        entropy = -torch.sum(probs * torch.log(probs + 1e-8), dim=1).mean()
        return cross_entropy, entropy

    def forward(self, logits_per_image, logits_per_text):
        """Compute the regularized symmetric contrastive loss.

        NOTE(review): despite the ``logits_per_*`` names, both arguments are
        used as embedding matrices here -- each is L2-normalized along dim 1
        and the similarity matrix is built from their product. Confirm that
        callers pass features rather than precomputed logits.

        Args:
            logits_per_image: 2-D tensor; row ``i`` pairs with row ``i`` of
                ``logits_per_text``.
            logits_per_text: 2-D tensor with the same number of rows and the
                same dim-1 size as ``logits_per_image``.

        Returns:
            Scalar tensor: the mean of the two directional cross-entropies
            minus ``lambda_entropy`` times the mean of the two directional
            entropies.
        """
        # Unit-length rows bound every similarity to [-1, 1] before scaling.
        image_feats = F.normalize(logits_per_image, p=2, dim=1)
        text_feats = F.normalize(logits_per_text, p=2, dim=1)

        logits = torch.matmul(image_feats, text_feats.t()) / self.temperature
        n = logits.size(0)
        labels = torch.arange(n, device=logits.device)

        if n > 1:
            off_value = self.smoothing / (n - 1)
            on_value = 1.0 - self.smoothing
        else:
            # A single pair has no negatives to receive the smoothing mass;
            # dividing by (n - 1) would raise ZeroDivisionError, so fall back
            # to a plain one-hot target.
            off_value, on_value = 0.0, 1.0

        smoothed_labels = torch.full_like(logits, off_value)
        smoothed_labels.scatter_(1, labels.unsqueeze(1), on_value)

        # The target matrix is symmetric (constant off-diagonal, constant
        # diagonal), so it serves both the row-wise and column-wise directions.
        loss_img, entropy_img = self._direction_terms(logits, smoothed_labels)
        loss_txt, entropy_txt = self._direction_terms(logits.t(), smoothed_labels)

        # Subtracting the entropy rewards less peaked predictions
        # (confidence regularization).
        entropy_penalty = (entropy_img + entropy_txt) / 2
        total_loss = (loss_img + loss_txt) / 2 - self.lambda_entropy * entropy_penalty

        return total_loss
0 commit comments