Created
October 3, 2016 19:44
-
-
Save EderSantana/9b0d5fb309d775b995d5236c32238349 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import division | |
import os | |
import time | |
from glob import glob | |
import tensorflow as tf | |
from six.moves import xrange | |
from scipy.misc import imresize | |
from subpixel import PS | |
from ops import * | |
from utils import * | |
def doresize(x, shape): | |
x = np.copy((x+1.)*127.5).astype("uint8") | |
y = imresize(x, shape) | |
return y | |
class DCGAN(object): | |
def __init__(self, sess, image_size=128, is_crop=True, | |
batch_size=64, sample_size = 64, image_shape=[128, 128, 3], | |
y_dim=None, z_dim=100, gf_dim=64, df_dim=64, | |
gfc_dim=1024, dfc_dim=1024, c_dim=3, dataset_name='default', | |
checkpoint_dir=None): | |
""" | |
Args: | |
sess: TensorFlow session | |
batch_size: The size of batch. Should be specified before training. | |
y_dim: (optional) Dimension of dim for y. [None] | |
z_dim: (optional) Dimension of dim for Z. [100] | |
gf_dim: (optional) Dimension of gen filters in first conv layer. [64] | |
df_dim: (optional) Dimension of discrim filters in first conv layer. [64] | |
gfc_dim: (optional) Dimension of gen untis for for fully connected layer. [1024] | |
dfc_dim: (optional) Dimension of discrim units for fully connected layer. [1024] | |
c_dim: (optional) Dimension of image color. [3] | |
""" | |
self.sess = sess | |
self.is_crop = is_crop | |
self.batch_size = batch_size | |
self.image_size = image_size | |
self.input_size = 32 | |
self.sample_size = sample_size | |
self.image_shape = image_shape | |
self.y_dim = y_dim | |
self.z_dim = z_dim | |
self.gf_dim = gf_dim | |
self.df_dim = df_dim | |
self.gfc_dim = gfc_dim | |
self.dfc_dim = dfc_dim | |
self.c_dim = 3 | |
self.dataset_name = dataset_name | |
self.checkpoint_dir = checkpoint_dir | |
self.build_model() | |
def build_model(self): | |
self.inputs = tf.placeholder(tf.float32, [self.batch_size, self.input_size, self.input_size, 3], | |
name='real_images') | |
self.up_inputs = tf.image.resize_images(self.inputs, [self.image_shape[0], self.image_shape[1]], tf.image.ResizeMethod.NEAREST_NEIGHBOR) | |
self.images = tf.placeholder(tf.float32, [self.batch_size] + self.image_shape, | |
name='real_images') | |
self.sample_images= tf.placeholder(tf.float32, [self.sample_size] + self.image_shape, | |
name='sample_images') | |
self.G = self.generator(self.inputs) | |
self.G_sum = tf.image_summary("G", self.G) | |
self.g_loss = tf.reduce_mean(tf.square(self.images-self.G)) | |
self.g_loss_sum = tf.scalar_summary("g_loss", self.g_loss) | |
t_vars = tf.trainable_variables() | |
self.g_vars = [var for var in t_vars if 'g_' in var.name] | |
self.saver = tf.train.Saver() | |
def train(self, config): | |
"""Train DCGAN""" | |
data = glob(os.path.join("./data", config.dataset, "*.jpg")) | |
#np.random.shuffle(data) | |
g_optim = tf.train.AdamOptimizer(config.learning_rate, beta1=config.beta1) \ | |
.minimize(self.g_loss, var_list=self.g_vars) | |
tf.initialize_all_variables().run() | |
self.saver = tf.train.Saver() | |
self.g_sum = tf.merge_summary([self.G_sum, self.g_loss_sum]) | |
self.writer = tf.train.SummaryWriter("./logs", self.sess.graph) | |
sample_files = data[0:self.sample_size] | |
sample = [get_image(sample_file, self.image_size, is_crop=self.is_crop) for sample_file in sample_files] | |
sample_inputs = [doresize(xx, [self.input_size,]*2) for xx in sample] | |
sample_images = np.array(sample).astype(np.float32) | |
sample_input_images = np.array(sample_inputs).astype(np.float32) | |
counter = 1 | |
start_time = time.time() | |
if self.load(self.checkpoint_dir): | |
print(" [*] Load SUCCESS") | |
else: | |
print(" [!] Load failed...") | |
for epoch in xrange(config.epoch): | |
data = glob(os.path.join("./data", config.dataset, "*.jpg")) | |
batch_idxs = min(len(data), config.train_size) // config.batch_size | |
for idx in xrange(0, batch_idxs): | |
batch_files = data[idx*config.batch_size:(idx+1)*config.batch_size] | |
batch = [get_image(batch_file, self.image_size, is_crop=self.is_crop) for batch_file in batch_files] | |
input_batch = [doresize(xx, [self.input_size,]*2) for xx in batch] | |
batch_images = np.array(batch).astype(np.float32) | |
batch_inputs = np.array(input_batch).astype(np.float32) | |
# Update G network | |
_, summary_str, errG = self.sess.run([g_optim, self.g_sum, self.g_loss], | |
feed_dict={ self.inputs: batch_inputs, self.images: batch_images }) | |
self.writer.add_summary(summary_str, counter) | |
counter += 1 | |
print("Epoch: [%2d] [%4d/%4d] time: %4.4f, g_loss: %.8f" \ | |
% (epoch, idx, batch_idxs, | |
time.time() - start_time, errG)) | |
if np.mod(counter, 100) == 1: | |
samples, g_loss, up_inputs = self.sess.run( | |
[self.G, self.g_loss, self.up_inputs], | |
feed_dict={self.inputs: sample_input_images, self.images: sample_images} | |
) | |
save_images(samples, [8, 8], | |
'./samples/train_%s_%s.png' % (epoch, idx)) | |
save_images(up_inputs, [8, 8], | |
'./samples/inputs_%s_%s.png' % (epoch, idx)) | |
print("[Sample] g_loss: %.8f" % (g_loss)) | |
if np.mod(counter, 500) == 2: | |
self.save(config.checkpoint_dir, counter) | |
def generator(self, z): | |
# project `z` and reshape | |
self.h0, self.h0_w, self.h0_b = deconv2d(z, [self.batch_size, 32, 32, self.gf_dim], k_h=1, k_w=1, d_h=1, d_w=1, name='g_h0', with_w=True) | |
h0 = lrelu(self.h0) | |
self.h1, self.h1_w, self.h1_b = deconv2d(h0, [self.batch_size, 32, 32, self.gf_dim], name='g_h1', d_h=1, d_w=1, with_w=True) | |
h1 = lrelu(self.h1) | |
h2, self.h2_w, self.h2_b = deconv2d(h1, [self.batch_size, 32, 32, 3*16], d_h=1, d_w=1, name='g_h2', with_w=True) | |
h2 = PS(h2, 4, color=True) | |
return tf.nn.tanh(h2) | |
def save(self, checkpoint_dir, step): | |
model_name = "DCGAN.model" | |
model_dir = "%s_%s" % (self.dataset_name, self.batch_size) | |
checkpoint_dir = os.path.join(checkpoint_dir, model_dir) | |
if not os.path.exists(checkpoint_dir): | |
os.makedirs(checkpoint_dir) | |
self.saver.save(self.sess, | |
os.path.join(checkpoint_dir, model_name), | |
global_step=step) | |
def load(self, checkpoint_dir): | |
print(" [*] Reading checkpoints...") | |
model_dir = "%s_%s" % (self.dataset_name, self.batch_size) | |
checkpoint_dir = os.path.join(checkpoint_dir, model_dir) | |
ckpt = tf.train.get_checkpoint_state(checkpoint_dir) | |
if ckpt and ckpt.model_checkpoint_path: | |
ckpt_name = os.path.basename(ckpt.model_checkpoint_path) | |
self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name)) | |
return True | |
else: | |
return False |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment