import torch


def projected_gradient_descent(model, x, y, loss_fn, num_steps, step_size, step_norm, eps, eps_norm,
                               clamp=(0, 1), y_target=None):
    """Performs the projected gradient descent attack on a batch of images."""
    x_adv = x.clone().detach().requires_grad_(True).to(x.device)
    targeted = y_target is not None

    for i in range(num_steps):
        _x_adv = x_adv.clone().detach().requires_grad_(True)

        prediction = model(_x_adv)
        loss = loss_fn(prediction, y_target if targeted else y)
        loss.backward()

        with torch.no_grad():
            # Force the gradient step to be a fixed size in the given norm
            if step_norm == 'inf':
                gradients = _x_adv.grad.sign() * step_size
            else:
                # Note: .view() assumes batched image data as a 4D tensor
                gradients = _x_adv.grad * step_size / _x_adv.grad.view(_x_adv.shape[0], -1)\
                    .norm(step_norm, dim=-1)\
                    .view(-1, 1, 1, 1)

            if targeted:
                # Targeted: gradient descent on the loss of the (incorrect)
                # target label w.r.t. the image data
                x_adv -= gradients
            else:
                # Untargeted: gradient ascent on the loss of the correct label
                # w.r.t. the image data
                x_adv += gradients

        # Project back into the l_norm ball and the valid pixel range
        if eps_norm == 'inf':
            # Workaround as PyTorch doesn't have an elementwise clip
            x_adv = torch.max(torch.min(x_adv, x + eps), x - eps)
        else:
            delta = x_adv - x

            # Assume x and x_adv are batched tensors where the first
            # dimension is a batch dimension
            mask = delta.view(delta.shape[0], -1).norm(eps_norm, dim=1) <= eps

            scaling_factor = delta.view(delta.shape[0], -1).norm(eps_norm, dim=1)
            scaling_factor[mask] = eps

            # .view() assumes batched images as a 4D tensor
            delta *= eps / scaling_factor.view(-1, 1, 1, 1)

            x_adv = x + delta

        x_adv = x_adv.clamp(*clamp)

    return x_adv.detach()
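A minimal usage sketch; the toy model, data and hyperparameter values below are illustrative assumptions, not part of the gist:

import torch
import torch.nn as nn

# Hypothetical toy classifier and batch; only the call signature matters here
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
x = torch.rand(8, 1, 28, 28)    # batch of images with pixel values in [0, 1]
y = torch.randint(0, 10, (8,))  # ground-truth labels

# Untargeted l_inf PGD, roughly the MNIST settings from Madry et al.
x_adv = projected_gradient_descent(
    model, x, y,
    loss_fn=nn.CrossEntropyLoss(),
    num_steps=40, step_size=0.01,
    step_norm='inf', eps=0.3, eps_norm='inf',
)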
I'm currently doing something very similar, although my implementation is more closely based on Madry's. I get good results on normalized MNIST (mean = 0.1307, std = 0.3081); of course, my clamp is different, with (min=-0.4242, max=2.8214). I wonder whether my good results on normalized MNIST come from eps=0.3 needing to be rescaled (a sketch of what I mean is below). I am also using Adam and a different network architecture, and I think it might be due to the vanishing gradient problem, since I am using sigmoid activations.
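The rescaling I have in mind, assuming the standard torchvision MNIST statistics:

mean, std = 0.1307, 0.3081

# Normalization maps x to (x - mean) / std, so a perturbation of size eps
# in raw pixel space corresponds to eps / std in normalized space
eps_raw = 0.3
eps_normalized = eps_raw / std  # ~0.974

# The valid pixel range [0, 1] maps to the clamp bounds mentioned above
clamp = ((0 - mean) / std, (1 - mean) / std)  # (-0.4242, 2.8215)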
Could you tell me after how many epochs you start going over 50%?
Oh, I just wanted to point out one thing: this code doesn't work if eps_norm is an integer value. In the last else branch you normalized the deltas using the variable norm, which isn't defined anywhere. I think you meant eps_norm here. @oscarknagg Have you tested this extensively? The corrected branch is sketched below.
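For reference, the branch with eps_norm substituted for the undefined norm:

mask = delta.view(delta.shape[0], -1).norm(eps_norm, dim=1) <= eps
scaling_factor = delta.view(delta.shape[0], -1).norm(eps_norm, dim=1)
scaling_factor[mask] = eps
delta *= eps / scaling_factor.view(-1, 1, 1, 1)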
x_adv = x_adv.clamp(*clamp)
Can someone explain the need for this clamping?
It's equivalent to x_adv = x_adv.clamp(0, 1): the * unpacks the clamp tuple into the min and max arguments. The clamping keeps the adversarial example inside the valid pixel range, so it remains a legitimate image.
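A quick illustration of the unpacking, on a hypothetical tensor:

import torch

x_adv = torch.tensor([-0.5, 0.5, 1.5])
clamp = (0, 1)
assert torch.equal(x_adv.clamp(*clamp), x_adv.clamp(0, 1))
assert torch.equal(x_adv.clamp(*clamp), x_adv.clamp(min=0, max=1))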
Does the norm here mean eps_norm? (L45)
mask = delta.view(delta.shape[0], -1).norm(norm, dim=1) <= eps
What is y_target here?