Created
July 13, 2022 04:48
-
-
Save bhatiaabhinav/1b67de77cc5b2551185bcfaf1b4c0148 to your computer and use it in GitHub Desktop.
Julia implementation of WGAN-GP using Flux and Zygote
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using Flux
using Flux: update!
using StatsBase
using Zygote
""" | |
WGAN with gradient penalty. See algorithm 1 in https://proceedings.neurips.cc/paper/2017/file/892c3b1c6dccd52936e27cbd0ff683d6-Paper.pdf. The following code is almost line by line identical. | |
""" | |
function train_WGAN_GP(πΊ, π·, π::Array{Float32, N}, latent_size, num_iters, device_fn; m=32, Ξ»=10f0, ncritic=5, Ξ±=0.0001, Ξ²β=0, Ξ²β=0.9) where N | |
n = size(π)[end] # length of dataset | |
πΊ, π· = device_fn(deepcopy(πΊ)), device_fn(deepcopy(π·)) | |
ΞΈ, π€ = params(πΊ), params(π·) | |
adamΞΈ, adamπ€ = ADAM(Ξ±, (Ξ²β, Ξ²β)), ADAM(Ξ±, (Ξ²β, Ξ²β)) | |
for iter in 1:num_iters | |
for t in 1:ncritic | |
π±, π³, π = π[repeat([:], N-1)..., rand(1:n, m)], randn(Float32, latent_size..., m), rand(Float32, repeat([1], N-1)..., m) # Sample batch of real data x, latent variables z, random numbers Ο΅ βΌ U[0, 1]. | |
π±, π³, π = device_fn(π±), device_fn(π³), device_fn(π) | |
π±Μ = πΊ(π³) | |
π±Μ = π .* π± + (1f0 .- π) .* π±Μ | |
βπ€L = gradient(π€) do | |
βπ±Μπ·, = gradient(π±Μ -> sum(π·(π±Μ)), π±Μ) | |
L = mean(π·(π±Μ)) - mean(π·(π±)) + Ξ» * mean((sqrt.(sum(βπ±Μπ·.^2, dims=1) .+ 1f-12) .- 1f0).^2) | |
end | |
update!(adamπ€, π€, βπ€L) | |
end | |
π³ = device_fn(randn(Float32, latent_size..., m)) | |
βΞΈπ· = gradient(ΞΈ) do | |
-mean(π·(πΊ(π³))) | |
end | |
update!(adamΞΈ, ΞΈ, βΞΈπ·) | |
end | |
return πΊ, π· | |
end | |
# Demo: train for one iteration on random dummy data with small MLP generator and critic.
𝐗 = rand(Float32, 50, 10000)    # dummy data; last dimension indexes the samples
z = 16    # latent size
𝐺 = Chain(Dense(z, 32, leakyrelu), Dense(32, 50))    # Generator
𝐷 = Chain(Dense(50, 32, leakyrelu), Dense(32, 1))    # Critic
𝐺, 𝐷 = train_WGAN_GP(𝐺, 𝐷, 𝐗, (z,), 1, cpu)    # works
# 𝐺, 𝐷 = train_WGAN_GP(𝐺, 𝐷, 𝐗, (z,), 1, gpu)    # fails. Doesn't work on GPU yet.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment