Gist "dcgan tpu" — created April 12, 2020, 16:36.
Save ptrcarta/7fe1424649462df2b7b84560bd0a1d67 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import tensorflow as tf | |
from tensorflow import keras | |
from timeit import default_timer as timer | |
import matplotlib.pyplot as plt | |
import numpy as np | |
# Keras image tensors are laid out NHWC ('channels_last') throughout.
IMAGE_DATA_FORMAT = 'channels_last'
keras.backend.set_image_data_format(IMAGE_DATA_FORMAT)
# BatchNormalization must normalize over the channel axis: 3 for NHWC, 1 for NCHW.
batchnorm_axis = 3 if IMAGE_DATA_FORMAT == 'channels_last' else 1

# Select a tf.distribute strategy from the environment: a TPU when one is
# attached (the runtime sets TPU_NAME), MirroredStrategy on a Colab GPU
# runtime, otherwise a single-device CPU fallback.
if 'TPU_NAME' in os.environ:
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.experimental.TPUStrategy(resolver)
elif 'COLAB_GPU' in os.environ and int(os.environ['COLAB_GPU']) > 0:
    strategy = tf.distribute.MirroredStrategy()
else:
    strategy = tf.distribute.OneDeviceStrategy(device='/cpu:0')

# Training hyperparameters.
BATCH_SIZE = 16  # per-replica batch size
GLOBAL_BATCH_SIZE = BATCH_SIZE*strategy.num_replicas_in_sync
EPOCHS = 10
# Adam settings (lr=2e-4, beta1=0.5 — the values popularized by DCGAN).
ADAM_LR=0.0002
ADAM_BETA1 = 0.5
ADAM_BETA2 = 0.999
nz = 256  # dimensionality of the generator's latent vector z
def D():
    """Build the discriminator: a stack of strided Conv2D/ReLU/BatchNorm
    stages followed by a single-logit Conv2D head (no activation — the
    training loss uses from_logits=True)."""
    # (filters, kernel_size, strides) for each Conv-ReLU-BN stage.
    conv_specs = [
        (16, 4, 4),
        (128, 3, 1),
        (256, 3, 1),
        (256, 3, 3),
    ]
    layers = []
    for filters, kernel, stride in conv_specs:
        layers.append(keras.layers.Conv2D(filters, kernel, strides=stride))
        layers.append(keras.layers.ReLU())
        layers.append(keras.layers.BatchNormalization(axis=batchnorm_axis))
    # Final 7x6 convolution collapses the remaining spatial extent to one logit.
    layers.append(keras.layers.Conv2D(1, (7, 6), strides=1))
    return keras.Sequential(layers)
def G():
    """Build the generator: reshape the latent vector to a 1x1 feature map,
    upsample through seven stride-2 Conv2DTranspose stages, refine with two
    stride-1 stages, project to 3 channels with a sigmoid, and crop to the
    target image size."""
    latent_shape = (1, 1, nz) if IMAGE_DATA_FORMAT == 'channels_last' else (nz, 1, 1)
    layers = [keras.layers.Reshape(latent_shape)]
    # Each stride-2 transpose doubles the spatial resolution.
    for filters in (2048, 1024, 512, 256, 128, 64, 32):
        layers.append(keras.layers.Conv2DTranspose(filters, 2, strides=2))
        layers.append(keras.layers.LeakyReLU())
        layers.append(keras.layers.BatchNormalization(axis=batchnorm_axis))
    # Two refinement stages at full resolution.
    for _ in range(2):
        layers.append(keras.layers.Conv2DTranspose(32, 3, strides=1))
        layers.append(keras.layers.LeakyReLU())
        layers.append(keras.layers.BatchNormalization(axis=batchnorm_axis))
    layers.extend([
        keras.layers.Conv2D(3, 1, strides=1),
        keras.layers.Activation(tf.sigmoid),
        # Trim the overshoot from the power-of-two upsampling path.
        keras.layers.Cropping2D(((0, 23), (0, 43))),
    ])
    return keras.Sequential(layers)
class TimedInputs:
    """Context manager that measures and periodically reports a processing
    rate (items per second).

    Call ``count(num)`` after each batch of work; once more than ``secs``
    seconds have elapsed since the window started, the current rate is
    printed and the measurement window is reset.  A final rate for the last
    (possibly partial) window is printed on exit.
    """

    def __init__(self, secs):
        # secs: minimum length, in seconds, of each measurement window.
        self._count = 0
        self._start_time = None
        self._last_time = None
        self._secs = secs

    def __enter__(self):
        self._start_time = timer()
        return self

    def _print_rate(self):
        elapsed = self._last_time - self._start_time
        # Guard the degenerate zero-length window (exit immediately after a
        # reset, or a coarse-resolution clock) instead of dividing by zero.
        if elapsed <= 0:
            return
        print(f'rate: {self._count/elapsed:.2f}')

    def count(self, num):
        """Record ``num`` more processed items; print rate if window is full."""
        self._last_time = timer()
        self._count += num
        if self._last_time - self._start_time > self._secs:
            self._print_rate()
            # Start a fresh measurement window.
            self._start_time = self._last_time
            self._count = 0

    def __exit__(self, typ, value, tb):
        # Report the final window; exceptions are not suppressed.
        self._last_time = timer()
        self._print_rate()
# Schema of the serialized tf.train.Example records written by
# make_tfrecords(): both features are raw byte strings — 'name' is the
# source filename, 'img' the undecoded JPEG contents.
feature_description = {
    'name': tf.io.FixedLenFeature([], tf.string, default_value=''),
    'img': tf.io.FixedLenFeature([], tf.string, default_value='')
}
def make_tfrecords(src_glob='data/img_align_celeba/*jpg',
                   dst_pattern='gs://pietor-euw4/celeba/celeba_{:03d}.tfr',
                   imgs_per_file=10000):
    """Shard the images matched by ``src_glob`` into TFRecord files.

    Each output shard holds up to ``imgs_per_file`` examples with two bytes
    features: 'name' (source filename) and 'img' (raw JPEG contents) — the
    schema described by ``feature_description``.

    Args:
        src_glob: glob pattern of the input image files.
        dst_pattern: ``str.format`` pattern for output shards; receives the
            zero-based shard index.
        imgs_per_file: number of examples per output shard.
    """
    def name_and_content(fn):
        return fn, tf.io.read_file(fn)

    ds = tf.data.Dataset.list_files(src_glob, shuffle=False).map(name_and_content)
    w = None
    try:
        for i, (fn, img) in enumerate(ds):
            if i % imgs_per_file == 0:
                # Roll over to a new shard.  Closing the previous writer here
                # (and the last one in ``finally``) removes the original
                # hard-coded dependence on the total image count (202599) and
                # guarantees the writer is closed even on error.
                if w is not None:
                    w.close()
                print('img ', i)
                w = tf.io.TFRecordWriter(dst_pattern.format(i // imgs_per_file))
            name_f = tf.train.Feature(bytes_list=tf.train.BytesList(value=[fn.numpy()]))
            img_f = tf.train.Feature(bytes_list=tf.train.BytesList(value=[img.numpy()]))
            features = tf.train.Features(feature={'name': name_f, 'img': img_f})
            w.write(tf.train.Example(features=features).SerializeToString())
    finally:
        if w is not None:
            w.close()
    print('written')
def parse_examples(ex_serialized):
    """Deserialize one tf.train.Example into its {'name', 'img'} feature dict."""
    parsed = tf.io.parse_single_example(
        serialized=ex_serialized, features=feature_description)
    return parsed
def decode_image(image):
    """Decode JPEG bytes, downsample 2x in each spatial dimension, and scale
    pixel values to [0, 1]; transpose to CHW when channels_first is active."""
    decoded = tf.io.decode_image(image, expand_animations=False)
    halved = decoded[::2, ::2, :]  # cheap 2x spatial downsampling
    scaled = tf.cast(halved, tf.float32) / 255.
    if IMAGE_DATA_FORMAT == 'channels_first':
        scaled = tf.transpose(scaled, [2, 0, 1])
    return scaled
def get_dataset_local():
    """Dataset of decoded images read straight from local JPEG files."""
    filenames = tf.data.Dataset.list_files(
        'data/img_align_celeba/*jpg', shuffle=False)
    contents = filenames.map(tf.io.read_file)
    return contents.map(decode_image)
def drop_names(d):
    """Keep only the image bytes from a parsed example; discard the filename."""
    image_bytes = d['img']
    return image_bytes
def get_dataset_gcs():
    """Dataset of decoded images streamed from the TFRecord shards on GCS."""
    shard_files = tf.io.matching_files('gs://pietor-euw4/celeba/celeba_*')
    records = tf.data.TFRecordDataset(shard_files)
    records = records.prefetch(1024)
    examples = records.map(parse_examples)
    images = examples.map(drop_names)
    return images.map(decode_image)
if __name__ == '__main__':
    # NOTE(review): the original indentation was lost in transit; it is
    # reconstructed here.  The GradientTape block must span both the D and G
    # updates, since loss_g depends on dg_out_fake computed after D's update.
    with strategy.scope():
        # Models and optimizers are created under the strategy scope so their
        # variables are distributed across replicas.
        d = D()
        g = G()
        optD = keras.optimizers.Adam(ADAM_LR, ADAM_BETA1, ADAM_BETA2)
        optG = keras.optimizers.Adam(ADAM_LR, ADAM_BETA1, ADAM_BETA2)

        @tf.function
        def train_step(real_inputs):
            # One per-replica latent batch (GLOBAL_BATCH_SIZE splits evenly
            # across replicas).
            zs = tf.random.normal(
                [GLOBAL_BATCH_SIZE//strategy.num_replicas_in_sync, nz])
            with tf.GradientTape() as d_tape, tf.GradientTape() as g_tape:
                fake_inputs = g(zs)
                ## Train D
                # D sees real images labeled 1 and generated images labeled 0;
                # outputs are raw logits (from_logits=True).
                d_out_real = d(real_inputs)
                d_out_fake = d(fake_inputs)
                loss_d_real = keras.losses.binary_crossentropy(
                    tf.ones_like(d_out_real), d_out_real, from_logits=True)
                loss_d_fake = keras.losses.binary_crossentropy(
                    tf.zeros_like(d_out_fake), d_out_fake, from_logits=True)
                # NOTE(review): per-replica losses are not scaled by
                # GLOBAL_BATCH_SIZE (cf. tf.nn.compute_average_loss), so the
                # effective learning rate grows with the replica count — confirm
                # this is intended.
                loss_d = loss_d_real + loss_d_fake
                grad_d = d_tape.gradient(loss_d, d.trainable_variables)
                optD.apply_gradients(zip(grad_d, d.trainable_variables))
                ddd = tf.linalg.norm(list(tf.nest.map_structure(tf.linalg.norm, grad_d)))
                tf.print('grad_d norm', ddd)
                ## Train G
                # D was updated above, so G trains against the freshly updated
                # discriminator using the non-saturating loss (label fakes as 1).
                dg_out_fake = d(fake_inputs)
                loss_g = keras.losses.binary_crossentropy(
                    tf.ones_like(dg_out_fake),
                    dg_out_fake, from_logits=True)
                # loss_g = -keras.losses.binary_crossentropy(
                #     tf.zeros_like(dg_out_fake), dg_out_fake, from_logits=True)
                grad_g = g_tape.gradient(loss_g, g.trainable_variables)
                optG.apply_gradients(zip(grad_g, g.trainable_variables))
                ddd = tf.linalg.norm(list(tf.nest.map_structure(tf.linalg.norm, grad_g)))
                tf.print('grad_g norm', ddd)

        # Epoch
        for epoch in range(EPOCHS):
            print(f'epoch: {epoch}')
            # drop_remainder=True keeps a static per-replica batch shape,
            # required for TPU execution.
            dataset = get_dataset_gcs().batch(GLOBAL_BATCH_SIZE, drop_remainder=True)
            dataset = strategy.experimental_distribute_dataset(dataset)  # iterations
            with TimedInputs(3) as t:
                for inputs in dataset:
                    # NOTE(review): experimental_run_v2 is the TF 2.1-era API
                    # (later renamed strategy.run).
                    strategy.experimental_run_v2(train_step, (inputs,))
                    t.count(GLOBAL_BATCH_SIZE)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.