Created February 20, 2017 20:56
devnag/5aea1b240dba463781aa15b6ac5a79bb
Changes to make gan_pytorch.py use only squared diffs rather than squared diffs + original data
diff --git a/gan_pytorch.py b/gan_pytorch.py
index 0bff38c..802d6cb 100755
--- a/gan_pytorch.py
+++ b/gan_pytorch.py
@@ -31,7 +31,7 @@ g_steps = 1
 # ### Uncomment only one of these
 #(name, preprocess, d_input_func) = ("Raw data", lambda data: data, lambda x: x)
-(name, preprocess, d_input_func) = ("Data and variances", lambda data: decorate_with_diffs(data, 2.0), lambda x: x * 2)
+(name, preprocess, d_input_func) = ("Data and variances", lambda data: decorate_with_diffs(data, 2.0), lambda x: x)
 print "Using data [%s]" % (name)
@@ -79,7 +79,8 @@ def decorate_with_diffs(data, exponent):
     mean = torch.mean(data.data, 1)
     mean_broadcast = torch.mul(torch.ones(data.size()), mean.tolist()[0][0])
     diffs = torch.pow(data - Variable(mean_broadcast), exponent)
-    return torch.cat([data, diffs], 1)
+    #return torch.cat([data, diffs], 1)
+    return diffs
 d_sampler = get_distribution_sampler(data_mean, data_stddev)
 gi_sampler = get_generator_input_sampler()
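A framework-free sketch of the two variants, using plain Python lists in place of PyTorch tensors (an assumption made for portability; the script itself uses `Variable`s). It assumes the decorated input is a single row of samples, as in the script's minibatch. It also illustrates why the first hunk changes `d_input_func` from `lambda x: x * 2` to `lambda x: x`: presumably `d_input_func` maps the raw sample count to the discriminator's input width, and concatenating the data doubles the number of features while the diffs alone keep it unchanged. `keep_data` and `d_input_width` are hypothetical names used only for illustration.

```python
# Sketch only: plain lists stand in for tensors; keep_data and d_input_width are hypothetical.

def decorate_with_diffs(row, exponent, keep_data=True):
    """Return discriminator input features for one row of samples.

    keep_data=True mirrors the old `torch.cat([data, diffs], 1)`;
    keep_data=False mirrors the new `return diffs`.
    """
    mean = sum(row) / len(row)                      # row mean, like torch.mean(data.data, 1)
    diffs = [(x - mean) ** exponent for x in row]   # deviations raised to `exponent`
    return row + diffs if keep_data else diffs      # list concat plays the role of torch.cat


def d_input_width(d_input_size, keep_data=True):
    # Mirrors the two d_input_func lambdas: x * 2 with the data concatenated, x without.
    return d_input_size * 2 if keep_data else d_input_size


row = [1.0, 2.0, 3.0, 4.0]
old = decorate_with_diffs(row, 2.0, keep_data=True)
new = decorate_with_diffs(row, 2.0, keep_data=False)
print(len(old), len(new))  # 8 4
print(new)                 # [2.25, 0.25, 0.25, 2.25]
```

The feature count returned by each variant matches the width given by the corresponding lambda, which is why the two lines have to change together.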
My own understanding is that with both the data and the diffs, G can generate much better fake data than with the diffs alone. It is like extracting features from images: the more feature dimensions there are, the more accurate the neural network.
Hi, I have the same question about the purpose of decorate_with_diffs: I don't understand why you concatenate the data and the diffs. I ran the script with both methods. After training, I found that without concatenating the data and the diffs, the fake data had a more negative mean than with the first method. Can you explain this? Thanks very much for your help.
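One plausible explanation (a hedged reading, not the gist author's answer): squared deviations from the sample mean do not change when every sample is shifted by the same constant, so a discriminator fed only the diffs sees the spread of a batch but not its location. Nothing in its input then penalizes the generator for producing the wrong mean, which would be consistent with a fake-data mean that drifts, for example toward negative values. Concatenating the raw data restores the location information. The sketch below uses hypothetical helper names and illustrative values centered on 4 and -4, with plain lists in place of tensors.

```python
# Sketch: squared diffs around the sample mean discard the mean itself.
# squared_diffs and with_data are hypothetical helpers; lists stand in for tensors.

def squared_diffs(row):
    mean = sum(row) / len(row)
    return [(x - mean) ** 2 for x in row]   # like the new diffs-only return value


def with_data(row):
    return row + squared_diffs(row)         # like the old torch.cat([data, diffs], 1)


centered_on_4 = [3.0, 4.0, 5.0]
centered_on_minus_4 = [x - 8.0 for x in centered_on_4]  # same spread, mean shifted to -4

print(squared_diffs(centered_on_4) == squared_diffs(centered_on_minus_4))  # True
print(with_data(centered_on_4) == with_data(centered_on_minus_4))          # False
```

With diffs only, the two batches produce identical discriminator inputs, so the discriminator cannot tell them apart; with the data concatenated, they differ.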