-
Notifications
You must be signed in to change notification settings - Fork 197
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added Carlini Linfinity attack #70
Open
samuelemarro
wants to merge
5
commits into
BorealisAI:master
Choose a base branch
from
samuelemarro:master
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
5a7a14d
Added Carlini LInf attack.
samuelemarro 4c23ade
Implemented Carlini Wagner LInf.
samuelemarro 4983c85
Fixed Carlini Linf formatting.
samuelemarro 60d5496
Added return_best.
samuelemarro 1a6db7f
Added assertions for replace_active.
samuelemarro File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,6 +20,7 @@ | |
from advertorch.utils import clamp | ||
from advertorch.utils import to_one_hot | ||
from advertorch.utils import replicate_input | ||
from advertorch.utils import replace_active | ||
|
||
from .base import Attack | ||
from .base import LabelMixin | ||
|
@@ -34,27 +35,34 @@ | |
UPPER_CHECK = 1e9 | ||
PREV_LOSS_INIT = 1e6 | ||
TARGET_MULT = 10000.0 | ||
EPS = 1e-6 | ||
NUM_CHECKS = 10 | ||
|
||
try: | ||
boolean_type = torch.bool | ||
except AttributeError: | ||
# Old version, use torch.uint8 | ||
boolean_type = torch.uint8 | ||
|
||
|
||
class CarliniWagnerL2Attack(Attack, LabelMixin): | ||
""" | ||
The Carlini and Wagner L2 Attack, https://arxiv.org/abs/1608.04644 | ||
|
||
:param predict: forward pass function. | ||
:param num_classes: number of clasess. | ||
:param num_classes: number of classes. | ||
:param confidence: confidence of the adversarial examples. | ||
:param targeted: if the attack is targeted. | ||
:param learning_rate: the learning rate for the attack algorithm | ||
:param learning_rate: the learning rate for the attack algorithm. | ||
:param binary_search_steps: number of binary search times to find the | ||
optimum | ||
:param max_iterations: the maximum number of iterations | ||
optimum. | ||
:param max_iterations: the maximum number of iterations. | ||
:param abort_early: if set to true, abort early if getting stuck in local | ||
min | ||
:param initial_const: initial value of the constant c | ||
min. | ||
:param initial_const: initial value of the constant c. | ||
:param clip_min: mininum value per input dimension. | ||
:param clip_max: maximum value per input dimension. | ||
:param loss_fn: loss function | ||
:param loss_fn: loss function. | ||
""" | ||
|
||
def __init__(self, predict, num_classes, confidence=0, | ||
|
@@ -66,7 +74,7 @@ def __init__(self, predict, num_classes, confidence=0, | |
if loss_fn is not None: | ||
import warnings | ||
warnings.warn( | ||
"This Attack currently do not support a different loss" | ||
"This Attack currently does not support a different loss" | ||
" function other than the default. Setting loss_fn manually" | ||
" is not effective." | ||
) | ||
|
@@ -126,7 +134,6 @@ def _is_successful(self, output, label, is_logits): | |
|
||
return is_successful(pred, label, self.targeted) | ||
|
||
|
||
def _forward_and_update_delta( | ||
self, optimizer, x_atanh, delta, y_onehot, loss_coeffs): | ||
|
||
|
@@ -141,7 +148,6 @@ def _forward_and_update_delta( | |
|
||
return loss.item(), l2distsq.data, output.data, adv.data | ||
|
||
|
||
def _get_arctanh_x(self, x): | ||
result = clamp((x - self.clip_min) / (self.clip_max - self.clip_min), | ||
min=0., max=1.) * 2 - 1 | ||
|
@@ -192,7 +198,6 @@ def _update_loss_coeffs( | |
else: | ||
loss_coeffs[ii] *= 10 | ||
|
||
|
||
def perturb(self, x, y=None): | ||
x, y = self._verify_and_process_inputs(x, y) | ||
|
||
|
@@ -245,3 +250,281 @@ def perturb(self, x, y=None): | |
loss_coeffs, coeff_upper_bound, coeff_lower_bound) | ||
|
||
return final_advs | ||
|
||
|
||
class CarliniWagnerLinfAttack(Attack, LabelMixin): | ||
""" | ||
The Carlini and Wagner LInfinity Attack, https://arxiv.org/abs/1608.04644 | ||
|
||
:param predict: forward pass function (pre-softmax). | ||
:param num_classes: number of classes. | ||
:param min_tau: the minimum value of tau. | ||
:param initial_tau: the initial value of tau. | ||
:param tau_factor: the decay rate of tau (between 0 and 1) | ||
:param initial_const: initial value of the constant c. | ||
:param max_const: the maximum value of the constant c. | ||
:param const_factor: the rate of growth of the constant c. | ||
:param reduce_const: if True, the inital value of c is halved every | ||
time tau is reduced. | ||
:param warm_start: if True, use the previous adversarials as starting point | ||
for the next iteration. | ||
:param targeted: if the attack is targeted. | ||
:param learning_rate: the learning rate for the attack algorithm. | ||
:param max_iterations: the maximum number of iterations. | ||
:param abort_early: if set to true, abort early if getting stuck in local | ||
min. | ||
:param clip_min: mininum value per input dimension. | ||
:param clip_max: maximum value per input dimension. | ||
:param loss_fn: loss function | ||
:param return_best: if True, return the best adversarial found, else | ||
return the the last adversarial found. | ||
""" | ||
|
||
def __init__(self, predict, num_classes, min_tau=1 / 256, | ||
initial_tau=1, tau_factor=0.9, initial_const=1e-5, | ||
max_const=20, const_factor=2, reduce_const=False, | ||
warm_start=True, targeted=False, learning_rate=5e-3, | ||
max_iterations=1000, abort_early=True, clip_min=0., | ||
clip_max=1., loss_fn=None, return_best=True): | ||
"""Carlini Wagner LInfinity Attack implementation in pytorch.""" | ||
if loss_fn is not None: | ||
import warnings | ||
warnings.warn( | ||
"This Attack currently does not support a different loss" | ||
" function other than the default. Setting loss_fn manually" | ||
" is not effective." | ||
) | ||
|
||
loss_fn = None | ||
|
||
super(CarliniWagnerLinfAttack, self).__init__( | ||
predict, loss_fn, clip_min, clip_max) | ||
|
||
self.predict = predict | ||
self.num_classes = num_classes | ||
self.min_tau = min_tau | ||
self.initial_tau = initial_tau | ||
self.tau_factor = tau_factor | ||
self.initial_const = initial_const | ||
self.max_const = max_const | ||
self.const_factor = const_factor | ||
self.reduce_const = reduce_const | ||
self.warm_start = warm_start | ||
self.targeted = targeted | ||
self.learning_rate = learning_rate | ||
self.max_iterations = max_iterations | ||
self.abort_early = abort_early | ||
self.clip_min = clip_min | ||
self.clip_max = clip_max | ||
self.return_best = return_best | ||
|
||
def _get_arctanh_x(self, x): | ||
result = clamp((x - self.clip_min) / (self.clip_max - self.clip_min), | ||
min=0., max=1.) * 2 - 1 | ||
return torch_arctanh(result * ONE_MINUS_EPS) | ||
|
||
def _outputs_and_loss(self, x, modifiers, starting_atanh, y, const, taus): | ||
adversarials = tanh_rescale( | ||
starting_atanh + modifiers, self.clip_min, self.clip_max) | ||
|
||
outputs = self.predict(adversarials) | ||
y_onehot = to_one_hot(y, self.num_classes).float() | ||
|
||
real = (y_onehot * outputs).sum(dim=1) | ||
|
||
other = ((1.0 - y_onehot) * outputs - (y_onehot * TARGET_MULT) | ||
).max(dim=1)[0] | ||
# - (y_onehot * TARGET_MULT) is for the true label not to be selected | ||
|
||
if self.targeted: | ||
loss1 = torch.clamp(other - real, min=0.) | ||
else: | ||
loss1 = torch.clamp(real - other, min=0.) | ||
|
||
loss1 = const * loss1 | ||
|
||
image_dimensions = tuple(range(1, len(x.shape))) | ||
taus_shape = (-1,) + (1,) * (len(x.shape) - 1) | ||
|
||
penalties = torch.clamp( | ||
torch.abs(x - adversarials) - taus.view(taus_shape), min=0) | ||
loss2 = torch.sum(penalties, dim=image_dimensions) | ||
|
||
assert loss1.shape == loss2.shape | ||
|
||
loss = loss1 + loss2 | ||
return outputs.detach(), loss | ||
|
||
def _successful(self, outputs, y): | ||
adversarial_labels = torch.argmax(outputs, dim=1) | ||
|
||
if self.targeted: | ||
return torch.eq(adversarial_labels, y) | ||
else: | ||
return ~torch.eq(adversarial_labels, y) | ||
|
||
def _run_attack(self, x, y, initial_const, taus, prev_adversarials): | ||
assert len(x) == len(taus) | ||
batch_size = len(x) | ||
best_adversarials = x.clone().detach() | ||
best_distances = torch.ones((batch_size,), | ||
device=x.device) * float("inf") | ||
|
||
if self.warm_start: | ||
starting_atanh = self._get_arctanh_x(prev_adversarials.clone()) | ||
else: | ||
starting_atanh = self._get_arctanh_x(x.clone()) | ||
|
||
modifiers = torch.nn.Parameter(torch.zeros_like(starting_atanh)) | ||
|
||
# An array of booleans that stores which samples have not converged | ||
# yet | ||
active = torch.ones((batch_size,), dtype=boolean_type, device=x.device) | ||
optimizer = optim.Adam([modifiers], lr=self.learning_rate) | ||
|
||
const = initial_const | ||
|
||
while torch.any(active) and const < self.max_const: | ||
for _ in range(self.max_iterations): | ||
optimizer.zero_grad() | ||
outputs, loss = self._outputs_and_loss( | ||
x[active], | ||
modifiers[active], | ||
starting_atanh[active], | ||
y[active], | ||
const, | ||
taus[active]) | ||
|
||
adversarials = tanh_rescale( | ||
starting_atanh + modifiers, | ||
self.clip_min, | ||
self.clip_max).detach() | ||
|
||
successful = self._successful(outputs, y[active]) | ||
|
||
if self.return_best: | ||
distances = torch.max( | ||
torch.abs( | ||
x[active] - adversarials[active] | ||
).flatten(1), | ||
dim=1)[0] | ||
better_distance = distances < best_distances[active] | ||
|
||
replace_active(adversarials[active], | ||
best_adversarials, | ||
active, | ||
successful & better_distance) | ||
replace_active(distances, | ||
best_distances, | ||
active, | ||
successful & better_distance) | ||
else: | ||
best_adversarials[active] = adversarials[active] | ||
|
||
# If early aborting is enabled, drop successful | ||
# samples with a small loss (the current adversarials | ||
# are saved regardless of whether they are dropped) | ||
if self.abort_early: | ||
small_loss = loss < 0.0001 * const | ||
|
||
drop = successful & small_loss | ||
|
||
# This workaround avoids modifying "active" | ||
# in-place, which would mess with | ||
# gradient computation in backwards() | ||
active_clone = active.clone() | ||
active_clone[active] = ~drop | ||
active = active_clone | ||
|
||
if not active.any(): | ||
break | ||
|
||
# Update the modifiers | ||
total_loss = torch.sum(loss) | ||
total_loss.backward() | ||
optimizer.step() | ||
|
||
# Give more weight to the output loss | ||
const *= self.const_factor | ||
|
||
return best_adversarials | ||
|
||
def perturb(self, x, y=None): | ||
x, y = self._verify_and_process_inputs(x, y) | ||
|
||
# Initialization | ||
if y is None: | ||
y = self._get_predicted_label(x) | ||
|
||
x = replicate_input(x) | ||
batch_size = len(x) | ||
best_adversarials = x.clone() | ||
best_distances = torch.ones((batch_size,), | ||
device=x.device) * float("inf") | ||
|
||
# An array of booleans that stores which samples have not converged | ||
# yet | ||
active = torch.ones((batch_size,), dtype=boolean_type, device=x.device) | ||
|
||
initial_const = self.initial_const | ||
taus = torch.ones((batch_size,), device=x.device) * self.initial_tau | ||
|
||
# The previous adversarials. This is used to perform a "warm start" | ||
# during optimisation | ||
prev_adversarials = x.clone() | ||
|
||
while torch.any(active): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this while needed here? there already exists a loop in _run_attack. This loop may not stop and it seems to cause an infinite loop in the benchmarking test. |
||
adversarials = self._run_attack( | ||
x[active], | ||
y[active], | ||
initial_const, | ||
taus[active], | ||
prev_adversarials[active].clone()) | ||
|
||
# Store the adversarials for the next iteration, | ||
# even if they failed | ||
prev_adversarials[active] = adversarials | ||
|
||
adversarial_outputs = self.predict(adversarials) | ||
successful = self._successful(adversarial_outputs, y[active]) | ||
|
||
# If the Linf distance is lower than tau and the adversarial | ||
# is successful, use it as the new tau | ||
linf_distances = torch.max( | ||
torch.abs(adversarials - x[active]).flatten(1), | ||
dim=1)[0] | ||
linf_lower = linf_distances < taus[active] | ||
|
||
replace_active(linf_distances, | ||
taus, | ||
active, | ||
linf_lower & successful) | ||
|
||
# Save the remaining adversarials | ||
if self.return_best: | ||
better_distance = linf_distances < best_distances[active] | ||
replace_active(adversarials, | ||
best_adversarials, | ||
active, | ||
successful & better_distance) | ||
replace_active(linf_distances, | ||
best_distances, | ||
active, | ||
successful & better_distance) | ||
else: | ||
replace_active(adversarials, | ||
best_adversarials, | ||
active, | ||
successful) | ||
|
||
taus *= self.tau_factor | ||
|
||
if self.reduce_const: | ||
initial_const /= 2 | ||
|
||
# Drop failed samples or with a low tau | ||
low_tau = taus[active] <= self.min_tau | ||
drop = low_tau | (~successful) | ||
active[active] = ~drop | ||
|
||
return best_adversarials |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This line doesn't work for torch 1.1 (works for torch 1.4). As in torch 1.1, torch.any only takes byte tensor, not bool tensor.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No need to fix right now, but please mention this in the comment. We'll implement testing under multiple torch versions later, and it can be fixed with other problems all together.