Skip to content

Commit b6d5185

Browse files
author
Minseon Kim
authored
Update attack_lib.py
1 parent 986b34c commit b6d5185

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

src/attack_lib.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ def get_loss(self, original_images, target, optimizer, weight, random_start=True
149149
elif self.loss_type == 'l1':
150150
loss = F.l1_loss(self.projector(self.model(x)), self.projector(self.model(target)))
151151
elif self.loss_type =='cos':
152-
loss = -F.cosine_similarity(self.projector(self.model(x)), self.projector(self.model(target))).mean()
152+
loss = 1-F.cosine_similarity(self.projector(self.model(x)), self.projector(self.model(target))).mean()
153153

154154
grads = torch.autograd.grad(loss, x, grad_outputs=None, only_inputs=True, retain_graph=False)[0]
155155

@@ -178,6 +178,6 @@ def get_loss(self, original_images, target, optimizer, weight, random_start=True
178178
elif self.loss_type == 'l1':
179179
loss = F.l1_loss(self.projector(self.model(x)), self.projector(self.model(target))) * (1.0/batch_size)
180180
elif self.loss_type == 'cos':
181-
loss = -F.cosine_similarity(self.projector(self.model(x)), self.projector(self.model(target))).sum() * (1.0/batch_size)
181+
loss = 1-F.cosine_similarity(self.projector(self.model(x)), self.projector(self.model(target))).sum() * (1.0/batch_size)
182182

183183
return x.detach(), loss

0 commit comments

Comments
 (0)