Last active April 3, 2020 11:04
Gist lukoshkin/579426755c7a5a8b164b08d23534fbc5
Improving WGAN with the gradient penalty term
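For reference, here is the critic loss from Gulrajani et al. (2017), "Improved Training of Wasserstein GANs". `calc_grad_penalty` computes the expectation in the last term, with `alpha` playing the role of \epsilon. The symbols follow the paper's usual notation rather than anything defined in this gist, and \lambda = 10 is the paper's default.

```latex
L_D = \mathbb{E}_{\tilde{x} \sim \mathbb{P}_g}\big[D(\tilde{x})\big]
    - \mathbb{E}_{x \sim \mathbb{P}_r}\big[D(x)\big]
    + \lambda \, \mathbb{E}_{\hat{x}}\Big[\big(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1\big)^2\Big],
\qquad
\hat{x} = \epsilon x + (1 - \epsilon)\,\tilde{x}, \quad \epsilon \sim U[0, 1]
```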
import torch


def calc_grad_penalty(real_samples, fake_samples, net_D):
    """
    Evaluates the WGAN-GP gradient penalty for `net_D`. The gradient is
    taken with `create_graph=True`, so the penalty itself can be
    backpropagated to the critic's parameters.

    Args:
        real_samples - a batch of real data (presumably not requiring grad)
        fake_samples - a tensor of the same shape as `real_samples`;
                       it is detached, so no gradient reaches the generator
        net_D - a critic that takes inputs of the same shape
                as `real_samples`

    Returns:
        A scalar tensor: the batch mean of (||d net_D / d inputs||_2 - 1) ** 2
    """
    # One interpolation weight per sample, broadcast over the other dims
    alpha = torch.rand(
        real_samples.size(0),
        *([1] * (real_samples.dim() - 1)),
        device=real_samples.device,
        dtype=real_samples.dtype,
    )
    # Random points on the straight lines between real and fake samples
    inputs = alpha * real_samples + (1 - alpha) * fake_samples.detach()
    inputs.requires_grad_(True)
    outputs = net_D(inputs)
    jacobian = torch.autograd.grad(
        outputs=outputs,
        inputs=inputs,
        grad_outputs=torch.ones_like(outputs),
        create_graph=True,
    )[0]
    # Flatten each sample's gradient and take its L2 norm
    jacobian = jacobian.reshape(jacobian.size(0), -1)
    return ((jacobian.norm(dim=1) - 1) ** 2).mean()
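A minimal sketch of how the penalty might be used, assuming PyTorch with `torch.nn`. So that it runs on its own, it restates `calc_grad_penalty` compactly as `gradient_penalty` (a hypothetical name). It first checks the penalty on a linear critic, whose input gradient is its weight vector everywhere, so the result must equal (||w||_2 - 1)^2 for any draw of alpha. It then performs one critic update with lambda = 10 and Adam(lr=1e-4, betas=(0.0, 0.9)), the settings reported in the WGAN-GP paper. The small MLP critic and the random `fake` batch are placeholders for a real critic and generator output.

```python
import torch
import torch.nn as nn


def gradient_penalty(real, fake, critic):
    # Compact restatement of calc_grad_penalty so this snippet is self-contained
    alpha = torch.rand(real.size(0), *([1] * (real.dim() - 1)),
                       device=real.device, dtype=real.dtype)
    x_hat = (alpha * real + (1 - alpha) * fake.detach()).requires_grad_(True)
    out = critic(x_hat)
    grad, = torch.autograd.grad(out, x_hat, torch.ones_like(out),
                                create_graph=True)
    return ((grad.reshape(grad.size(0), -1).norm(dim=1) - 1) ** 2).mean()


torch.manual_seed(0)

# Sanity check: a linear critic D(x) = x.w + b has gradient w at every point,
# so the penalty equals (||w||_2 - 1)^2 regardless of the sampled alpha.
linear = nn.Linear(4, 1)
real, fake = torch.randn(8, 4), torch.randn(8, 4)
gp = gradient_penalty(real, fake, linear)
expected = (linear.weight.norm() - 1) ** 2
print(torch.allclose(gp, expected))

# One critic update with the WGAN-GP loss (placeholder critic and data).
critic = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 1))
opt = torch.optim.Adam(critic.parameters(), lr=1e-4, betas=(0.0, 0.9))
lambda_gp = 10.0

loss_D = (critic(fake).mean() - critic(real).mean()
          + lambda_gp * gradient_penalty(real, fake, critic))
opt.zero_grad()
loss_D.backward()  # works because the penalty was built with create_graph=True
opt.step()
```

In a full training loop this critic step would typically run several times per generator update (the paper uses 5), and the generator would minimize `-critic(generator(z)).mean()` without the penalty term, where `generator` and `z` are again placeholders.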