From dc186d9603942d52e59c710ea920a1321aa5fef3 Mon Sep 17 00:00:00 2001 From: zyf Date: Thu, 14 Sep 2023 10:46:36 +0800 Subject: [PATCH] Fix gen_attack's error on GPU. --- advertorch/attacks/blackbox/gen_attack.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/advertorch/attacks/blackbox/gen_attack.py b/advertorch/attacks/blackbox/gen_attack.py index d0006df..afad100 100644 --- a/advertorch/attacks/blackbox/gen_attack.py +++ b/advertorch/attacks/blackbox/gen_attack.py @@ -65,6 +65,7 @@ def crossover(p1, p2, probs): Select from p1 with the probabilties in probs. """ u = torch.rand(*p1.shape) + u = u.to(p1.device) return torch.where(probs[:, :, None] > u, p1, p2) @@ -87,6 +88,8 @@ def selection(pop_t, fitness, tau): # sample parent 1 from pop_t according to probs (multinomial) # sample parent 2 from pop_t according to probs (multinomial) u1, u2 = torch.rand(2, n_batch, nb_samples) + u1 = u1.to(pop_t.device) + u2 = u2.to(pop_t.device) # out of the original N samples, we draw another N samples # this requires us to compute the following broadcasted comparison @@ -119,9 +122,10 @@ def mutation(pop_t, alpha, rho, eps): """ # alpha and eps both have shape [B] perturb_noise = (2 * torch.rand(*pop_t.shape) - 1) + perturb_noise = perturb_noise.to(eps.device) perturb_noise = perturb_noise * alpha[:, None, None] * eps[:, None, None] - mask = (torch.rand(*pop_t.shape) > rho[:, None, None]).float() + mask = (torch.rand(*pop_t.shape).to(eps.device) > rho[:, None, None]).float() return pop_t + mask * perturb_noise @@ -218,8 +222,8 @@ def gen_attack( # shape: [B, N, F] pop_t = 2 * torch.rand(n_batch, nb_samples, n_dim) - 1 # Sample from Uniform(-eps, eps) - pop_t = eps[:, None, None] * pop_t pop_t = pop_t.to(x.device) + pop_t = eps[:, None, None] * pop_t else: pop_t = pop_init.clone()