# Source: GitHub gist poutyface/c1818e43e96d550deea0ceef48091eff ("m2").
# Save it to your computer and use it in GitHub Desktop.
# This file contains bidirectional Unicode text that may be interpreted or compiled
# differently than what appears below. To review, open the file in an editor that
# reveals hidden Unicode characters. Learn more about bidirectional Unicode characters.
#!/usr/bin/env python
"""Conditional VAE modelling p(x|y,z) on MNIST, trained with Chainer.

Loads MNIST through the project-local ``dataset`` module, scales pixels to
[0, 1], one-hot encodes the labels and splits the data into
train (first 50000), validation (next 10000) and test (the rest).
"""
import os
import numpy as np
import chainer
import chainer.functions as F
import chainer.links as L
from chainer import cuda
from chainer import Chain
from chainer import optimizers
from chainer import serializers
import cv2
import dataset

print("VAE p(x|y,z)")

data_dir = "./dataset/"
mnist = dataset.load_mnist_data(data_dir)

# Pixels scaled to [0, 1] so they are comparable with the decoder's sigmoid output.
all_x = np.array(mnist['data'], dtype=np.float32) / 255.0

# Labels must be integers: they are used as array indices and to size the one-hot
# matrix (float shapes/indices are rejected by modern NumPy).
all_y_tmp = np.asarray(mnist['target']).astype(np.int32)
n_classes = int(all_y_tmp.max()) + 1

# One-hot encode in a single vectorised assignment: row i gets a 1 at column label[i].
all_y = np.zeros((all_x.shape[0], n_classes), dtype=np.float32)
all_y[np.arange(all_y_tmp.shape[0]), all_y_tmp] = 1.0

train_x = all_x[:50000]
train_y = all_y[:50000]
valid_x = all_x[50000:60000]
valid_y = all_y[50000:60000]
test_x = all_x[60000:]
test_y = all_y[60000:]

image_size = 28
nx = image_size * image_size  # flattened image length
nbatch = 10                   # mini-batch size
nz = 300                      # latent dimensionality
ny = all_y.shape[1]           # label width, kept in sync with the one-hot encoding (10 for MNIST)
class VAE(Chain):
    """Conditional variational auto-encoder.

    Encoder q(z|x,y) and decoder p(x|y,z), both built from 500-unit ReLU MLPs
    whose x/z and y towers are merged by addition. Layer widths come from the
    module-level constants ``nx`` (flattened image size), ``ny`` (one-hot label
    width) and ``nz`` (latent size).
    """

    def __init__(self):
        super(VAE, self).__init__(
            # encoder q(z|x,y)
            recog_x1=L.Linear(nx, 500, nobias=True),
            recog_x2=L.Linear(500, 500),
            recog_y1=L.Linear(ny, 500),
            recog_y2=L.Linear(500, 500),
            recog_x_y=L.Linear(500, 500),
            recog_mean=L.Linear(500, nz),
            recog_log_sigma=L.Linear(500, nz),
            # decoder p(x|y,z)
            gen_y1=L.Linear(ny, 500),
            gen_y2=L.Linear(500, 500),
            gen_z1=L.Linear(nz, 500, nobias=True),
            gen_z2=L.Linear(500, 500),
            gen_z_y=L.Linear(500, 500),
            gen=L.Linear(500, nx),
        )

    def generate_z(self, x, y):
        """Sample z ~ q(z|x,y) with the reparameterisation trick.

        Args:
            x: batch of flattened images (chainer Variable).
            y: batch of one-hot labels (chainer Variable).

        Returns:
            ``(z, recog_mean, recog_log_sigma)``. ``recog_log_sigma`` is the log
            standard deviation (log sigma), i.e. half of the log variance.
        """
        hx = F.relu(self.recog_x1(x))
        hx = F.relu(self.recog_x2(hx))
        hy = F.relu(self.recog_y1(y))
        hy = F.relu(self.recog_y2(hy))
        hq = F.relu(self.recog_x_y(hx + hy))
        recog_mean = self.recog_mean(hq)
        # The linear head is read as log variance; halving it yields log sigma.
        recog_log_sigma = 0.5 * self.recog_log_sigma(hq)
        # Draw the noise with numpy or cupy, matching where the encoder output lives.
        xp = cuda.get_array_module(recog_mean.data)
        eps = xp.random.normal(0, 1, recog_mean.data.shape).astype(np.float32)
        eps = chainer.Variable(eps)
        # z = mu + sigma * epsilon
        z = recog_mean + F.exp(recog_log_sigma) * eps
        return z, recog_mean, recog_log_sigma

    def generate_x(self, z, y):
        """Decode p(x|y,z) and return per-pixel intensities in (0, 1) via a sigmoid."""
        hy = F.relu(self.gen_y1(y))
        hy = F.relu(self.gen_y2(hy))
        hz = F.relu(self.gen_z1(z))
        hz = F.relu(self.gen_z2(hz))
        hp = F.relu(self.gen_z_y(hy + hz))
        return F.sigmoid(self.gen(hp))

    def generate(self, x, y):
        """Encode (x, y) once, then decode the sampled z under every label.

        Returns:
            numpy array of shape (n_labels, n_pixels); row ``label`` is the
            decoding with one-hot label ``label``.

        NOTE(review): the label batch built here has a single row, so this
        presumably requires ``x``/``y`` to hold exactly one example (the
        training script calls it that way).
        """
        z, _, _ = self.generate_z(x, y)
        n_labels = y.data.shape[1]
        xp = cuda.get_array_module(z.data)
        output = np.zeros((n_labels, x.data.shape[1]), dtype=np.float32)
        for label in range(n_labels):
            sample_y = xp.zeros((1, n_labels), dtype=np.float32)
            sample_y[0, label] = 1.0
            out = self.generate_x(z, chainer.Variable(sample_y))
            # Bring the decoding back to host memory before storing it.
            output[label] = cuda.to_cpu(out.data)
        return output

    def __call__(self, x, y):
        """Return ``(reconstruction_loss, kl_divergence, reconstruction)`` for a batch.

        Both loss terms are normalised by batch_size * n_pixels so they share a
        scale (``F.mean_squared_error`` averages over every element).
        """
        z, recog_mean, recog_log_sigma = self.generate_z(x, y)
        output = self.generate_x(z, y)
        loss = F.mean_squared_error(output, x)
        # Closed-form KL(q(z|x,y) || N(0, I)) = -0.5 * sum(1 + ln_var - mu^2 - exp(ln_var)).
        # generate_z returns log sigma, so the log variance is twice that value.
        ln_var = 2 * recog_log_sigma
        kld = -0.5 * F.sum(1 + ln_var - recog_mean ** 2 - F.exp(ln_var)) / (x.data.shape[0] * x.data.shape[1])
        return loss, kld, output
class Disc(Chain):
    """DCGAN-style convolutional discriminator (not used by the training loop in this script).

    Four 4x4, stride-2, pad-1 convolutions each halve an even spatial size, so
    ``out_size`` must equal input_size / 16 (6 for 96x96 inputs) for the final
    linear layer's input width to match.

    Args:
        ndf: channel count of the first convolution; doubled at each later layer.
        nc: number of input image channels.
        out_size: spatial size of the last feature map.

    NOTE(review): ndf=64 and nc=1 are assumed defaults (a common DCGAN width and
    single-channel input); previously these were unresolved globals — confirm.
    """

    def __init__(self, ndf=64, nc=1, out_size=6):
        super(Disc, self).__init__(
            bn1=L.BatchNormalization(ndf * 2),
            bn2=L.BatchNormalization(ndf * 4),
            bn3=L.BatchNormalization(ndf * 8),
            c1=L.Convolution2D(nc, ndf, ksize=4, stride=2, pad=1),
            c2=L.Convolution2D(ndf, ndf * 2, ksize=4, stride=2, pad=1),
            c3=L.Convolution2D(ndf * 2, ndf * 4, ksize=4, stride=2, pad=1),
            c4=L.Convolution2D(ndf * 4, ndf * 8, ksize=4, stride=2, pad=1),
            l1=L.Linear(ndf * 8 * out_size * out_size, 1),
        )

    def __call__(self, x, test=False, verbose=False):
        """Return one unnormalised real/fake score per input, shape (batch, 1).

        Args:
            x: image batch (chainer Variable).
            test: passed to the batch-normalisation layers (inference mode when True).
            verbose: print the shape of the input and every intermediate feature map.
        """
        h1 = F.leaky_relu(self.c1(x))
        h2 = F.leaky_relu(self.bn1(self.c2(h1), test=test))
        h3 = F.leaky_relu(self.bn2(self.c3(h2), test=test))
        h4 = F.leaky_relu(self.bn3(self.c4(h3), test=test))
        h5 = self.l1(h4)
        if verbose:
            # Debug aid for checking that out_size matches the input resolution.
            for h in (x, h1, h2, h3, h4, h5):
                print(h.data.shape)
        return h5
#image_path = "./lfwcrop_grey/faces" | |
#fs = os.listdir(image_path) | |
#print len(fs) | |
#dataset = [] | |
#for fn in fs: | |
# read as grey | |
# img = cv2.imread('%s/%s'%(image_path, fn), 0) | |
# img = cv2.resize(img, (image_size,image_size)) | |
# img = img.astype(np.float32) | |
# img = img / 255 | |
# img = img.reshape(image_size*image_size) | |
# dataset.append(img) | |
def _to_uint8_image(flat):
    """Convert a flat [0, 1] float vector into an image_size x image_size uint8 image."""
    return (flat * 255).reshape(image_size, image_size).astype(np.uint8)


vae = VAE()
opt = optimizers.Adam(alpha=0.0002, beta1=0.5)
opt.setup(vae)

for epoch in range(500000):
    print("epoch: %d" % epoch)
    # Reshuffle each epoch so the mini-batch composition changes between passes.
    indexes = np.random.permutation(train_x.shape[0])
    for i in range(0, train_x.shape[0], nbatch):
        x_batch = train_x[indexes[i:i + nbatch]]
        y_batch = train_y[indexes[i:i + nbatch]]
        # VAE step: reconstruction loss + KL term
        recog_loss, kld_loss, output = vae(chainer.Variable(x_batch), chainer.Variable(y_batch))
        loss = recog_loss + kld_loss
        print("%s %s %s" % (loss.data, recog_loss.data, kld_loss.data))
        vae.zerograds()
        loss.backward()
        opt.update()

    # Preview: reconstruct training example 1 under every label and show them with OpenCV.
    # NOTE(review): the script's original indentation was lost; running this preview
    # once per epoch is a reconstruction — confirm the intended schedule.
    x_batch = train_x[1:2]
    y_batch = train_y[1:2]
    output = vae.generate(chainer.Variable(x_batch), chainer.Variable(y_batch))
    cv2.imshow("input", _to_uint8_image(train_x[1]))
    for label, recon in enumerate(output):
        cv2.imshow("%d" % label, _to_uint8_image(recon))
    cv2.waitKey(1)
## for j in range(0, 3): | |
# img = output.data[j] | |
# img = img * 255 | |
# img = img.reshape(image_size, image_size) | |
# img = img.astype(np.uint8) | |
# cv2.imshow("%d"%j, img) | |
# cv2.waitKey(1) | |
# if epoch % 1000 == 0: | |
# for j in range(0, nbatch): | |
# img = output.data[j] | |
# img = img * 255 | |
# img = img.reshape(image_size, image_size) | |
# img = img.astype(np.uint8) | |
# cv2.imwrite("out_images/%d_%d.jpg"%(epoch, j), img) | |
# serializers.save_hdf5("out_models/model_%d.h5"%(epoch), vae) |
# Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.