
Commit 409a2b4

Authored Oct 22, 2024
Add entropy penalty custom loss
1 parent 1cc9332

File tree

1 file changed: +40 −0

@@ -0,0 +1,40 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class ContrastiveLoss(nn.Module):
    def __init__(self, temperature=0.07, smoothing=0.1, lambda_entropy=0.02):
        super().__init__()
        self.temperature = temperature
        self.smoothing = smoothing
        self.lambda_entropy = lambda_entropy

    def forward(self, logits_per_image, logits_per_text):
        # Despite the parameter names, the inputs are treated as raw feature
        # embeddings: L2-normalize them so the dot products below are bounded
        # cosine similarities
        logits_per_image = F.normalize(logits_per_image, p=2, dim=1)
        logits_per_text = F.normalize(logits_per_text, p=2, dim=1)

        # Temperature-scaled similarity logits; matching image/text pairs
        # sit on the diagonal, so the targets are simply 0..N-1
        logits = torch.matmul(logits_per_image, logits_per_text.t()) / self.temperature
        labels = torch.arange(logits.size(0), device=logits.device)

        # Label smoothing: put 1 - smoothing on the true pair and spread the
        # remaining mass uniformly over the N - 1 negatives (assumes N > 1)
        N = logits.size(0)
        smoothed_labels = torch.full_like(logits, self.smoothing / (N - 1))
        smoothed_labels.scatter_(1, labels.unsqueeze(1), 1.0 - self.smoothing)

        # Cross-entropy against the smoothed targets, image -> text direction
        log_probs = F.log_softmax(logits, dim=1)
        loss_img = -(smoothed_labels * log_probs).sum(dim=1).mean()

        # Text -> image direction; smoothed_labels is symmetric, so it is reused
        log_probs = F.log_softmax(logits.t(), dim=1)
        loss_txt = -(smoothed_labels * log_probs).sum(dim=1).mean()

        # Entropy of the predictions in each direction, used below as a
        # confidence regularizer (the 1e-8 guards against log(0))
        probs_img = F.softmax(logits, dim=1)
        entropy_img = -torch.sum(probs_img * torch.log(probs_img + 1e-8), dim=1).mean()

        probs_txt = F.softmax(logits.t(), dim=1)
        entropy_txt = -torch.sum(probs_txt * torch.log(probs_txt + 1e-8), dim=1).mean()

        # Symmetric cross-entropy minus the entropy term: subtracting entropy
        # rewards higher-entropy (less over-confident) predictions
        entropy_penalty = (entropy_img + entropy_txt) / 2
        total_loss = (loss_img + loss_txt) / 2 - self.lambda_entropy * entropy_penalty

        return total_loss
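
A minimal usage sketch (not part of the commit): the batch size, embedding dimension, and tensor names below are illustrative assumptions; the loss expects two same-sized batches of embeddings whose matching pairs share a row index.

import torch

# Hypothetical shapes: 8 image/text pairs with 512-dim embeddings
image_embeds = torch.randn(8, 512, requires_grad=True)
text_embeds = torch.randn(8, 512, requires_grad=True)

criterion = ContrastiveLoss(temperature=0.07, smoothing=0.1, lambda_entropy=0.02)
loss = criterion(image_embeds, text_embeds)
loss.backward()  # scalar loss, differentiable w.r.t. both embedding batches

Note the sign of the regularizer: because the mean entropy is subtracted, minimizing the total loss pushes entropy up, i.e. it penalizes over-confident softmax distributions rather than encouraging them; lambda_entropy controls the strength of that confidence penalty.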
