# Gist aferral/b8eac5d64039c117d29702d2ce0939fc, created May 15, 2018 17:52.
# NOTE: GitHub flagged this file as containing bidirectional Unicode text that
# may be interpreted or compiled differently than it appears. Review it in an
# editor that reveals hidden Unicode characters.
from sklearn.datasets import load_digits | |
import tensorflow as tf | |
import time | |
import numpy as np | |
from tensorflow.python import debug as tf_debug | |
# Adapted from http://adventuresinmachinelearning.com/tensorflow-dataset-tutorial/
class timeit:
    """Context manager that prints the wall-clock time spent inside its block.

    Attributes:
        st: ``time.time()`` value captured when the block is entered.
        elapsed: seconds between entry and exit; set when the block exits.
    """

    def __enter__(self):
        self.st = time.time()
        # Return the instance so ``with timeit() as t`` binds a usable object
        # instead of None.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.elapsed = time.time() - self.st
        print('elapsed {0}'.format(self.elapsed))
        # Implicitly returns None, so an exception raised in the block propagates.
class b_counter:
    """Callable that hands out consecutive ``(start, end)`` slice bounds.

    Each call advances an internal cursor by ``inc``. When the cursor reaches
    ``end`` the slice is clipped to ``end``, the cursor wraps to 0 and one
    epoch is consumed.

    Args:
        inc: step between consecutive slices (the batch size).
        end: exclusive upper bound of the index range.
        epochs: number of full passes to serve before raising.
    """

    def __init__(self, inc, end, epochs):
        self.c = 0
        self.inc = inc
        self.limit = end
        self.epochs = epochs

    def __call__(self, *args, **kwargs):
        """Return the next ``(start, end)`` pair.

        Raises:
            IndexError: once every epoch has been fully served. It is a
                subclass of ``Exception``, so callers catching ``Exception``
                keep working.
        """
        # Check exhaustion *before* producing a slice, so the last slice of the
        # last epoch is still returned and ``epochs <= 0`` terminates at once.
        if self.epochs <= 0:
            raise IndexError('no batches left: all epochs consumed')

        prev = self.c
        self.c = prev + self.inc
        if self.c >= self.limit:
            # End of a pass: wrap around and clip this slice to the limit.
            self.c = 0
            self.epochs -= 1
            return prev, self.limit
        return prev, self.c
# ---- Hyper-parameters ------------------------------------------------------
debug = False      # True: build the model on placeholders instead of the iterator outputs
epochs = 100       # number of passes over the training set
reg_factor = 0.1   # scale passed to the L2 kernel regularizers
batch_size = 30
kp = 0.95          # dropout keep probability fed while training

# ---- Data ------------------------------------------------------------------
# load the data
digits = load_digits(return_X_y=True)

# split into train and validation sets (first 80% / last 20%)
split_at = int(len(digits[0]) * 0.8)
train_images = digits[0][:split_at]
train_labels = digits[1][:split_at]
valid_images = digits[0][split_at:]
valid_labels = digits[1][split_at:]

# One-hot encode with numpy so both label sets share the float64 dtype that the
# reinitializable iterator below is built from.
one_hot_train_labels = np.eye(10)[train_labels]
one_hot_valid_labels = np.eye(10)[valid_labels]

# Centering statistics come from the training split only.
mean = train_images.mean(axis=0)

# create the training datasets
dx_train = tf.data.Dataset.from_tensor_slices(train_images).map(lambda z: tf.add(z, -mean))
dy_train = tf.data.Dataset.from_tensor_slices(one_hot_train_labels)
# NOTE(review): cache() sits after shuffle()/repeat(), so it stores the already
# shuffled, repeated batches - confirm this ordering is intended.
train_dataset = tf.data.Dataset.zip((dx_train, dy_train)).shuffle(500).repeat(epochs).batch(batch_size).cache().prefetch(1000)

# create the validation datasets
dx_valid = tf.data.Dataset.from_tensor_slices(valid_images).map(lambda z: tf.add(z, -mean))
# Validation labels must come from the validation split; the previous code
# zipped the validation images with the *training* labels.
dy_valid = tf.data.Dataset.from_tensor_slices(one_hot_valid_labels)
valid_dataset = tf.data.Dataset.zip((dx_valid, dy_valid)).shuffle(500).repeat(1).batch(batch_size).cache()

# One iterator defined by structure only, so it can be re-initialised on either
# the training or the validation dataset.
iterator = tf.data.Iterator.from_structure(train_dataset.output_types,
                                           train_dataset.output_shapes)
next_element = iterator.get_next()
# model: fully connected classifier for the sklearn digits data (10 classes)
keep_p = tf.placeholder(tf.float64, name='k_prob')

# In debug mode the graph is built on placeholders that must be fed by hand;
# otherwise it reads straight from the dataset iterator.
input_l = tf.placeholder(train_dataset.output_types[0], train_dataset.output_shapes[0]) if debug else next_element[0]
targets = tf.placeholder(train_dataset.output_types[1], train_dataset.output_shapes[1]) if debug else next_element[1]

# Every dense layer uses the same initializer / regularizer configuration.
dense_kwargs = dict(
    kernel_initializer=tf.contrib.layers.xavier_initializer(),
    kernel_regularizer=tf.contrib.layers.l2_regularizer(reg_factor),
)

l1 = tf.layers.dense(input_l, 15, tf.nn.relu, **dense_kwargs)
l1_dropout = tf.contrib.layers.dropout(l1, keep_prob=keep_p)
l2 = tf.layers.dense(l1_dropout, 30, tf.nn.relu, **dense_kwargs)
l2_dropout = tf.contrib.layers.dropout(l2, keep_prob=keep_p)
l3 = tf.layers.dense(l2_dropout, 50, tf.nn.relu, **dense_kwargs)
l3_dropout = tf.contrib.layers.dropout(l3, keep_prob=keep_p)
# Output logits: no activation and no bias.
out = tf.layers.dense(l3_dropout, 10, use_bias=False, **dense_kwargs)

# Data term only; this is the value reported while training.
loss = tf.losses.softmax_cross_entropy(targets, out)
# kernel_regularizer only *collects* the penalty terms; they have to be added
# to the optimised objective explicitly, otherwise reg_factor has no effect.
total_loss = loss + tf.cast(tf.losses.get_regularization_loss(), loss.dtype)
train_step = tf.train.AdamOptimizer(learning_rate=0.01).minimize(total_loss)

# get accuracy
prediction = tf.argmax(out, 1)
equality = tf.equal(prediction, tf.argmax(targets, 1))
accuracy = tf.reduce_mean(tf.cast(equality, tf.float32))

# Init ops
init_op = tf.global_variables_initializer()
training_init_op = iterator.make_initializer(train_dataset)
validation_init_op = iterator.make_initializer(valid_dataset)
def prepare_feed(iterator):
    """Pull one batch from ``iterator`` and wrap it as a feed dict.

    Args:
        iterator: a tf.data iterator yielding ``(data, labels)`` batches.

    Returns:
        Feed dict mapping ``input_l``, ``targets`` and ``keep_p`` to the
        fetched batch and the training keep probability ``kp``.

    Uses the module-level ``sess``; an error raised by ``sess.run`` when the
    iterator is exhausted propagates to the caller.
    """
    # iterator.get_next() adds a new op to the graph on every call, so create
    # it once per iterator and reuse it. The iterator is stored alongside the
    # op to keep it alive, so its id() cannot be recycled.
    cache = prepare_feed.__dict__.setdefault('_next_ops', {})
    key = id(iterator)
    if key not in cache:
        cache[key] = (iterator, iterator.get_next())
    data, labels = cache[key][1]

    x_batch, y_batch = sess.run([data, labels])
    return {input_l: x_batch, targets: y_batch, keep_p: kp}
def prepare_feed_raw():
    """Build a feed function that slices batches directly from the numpy arrays.

    Returns:
        A zero-argument callable. Each call returns a feed dict for the next
        ``batch_size`` training rows and raises whatever ``b_counter`` raises
        once ``epochs`` passes over the training set are done.
    """
    c_index = b_counter(batch_size, train_images.shape[0], epochs)
    # Centre the images once, up front, so this path feeds the same
    # mean-subtracted inputs as the tf.data pipeline (dx_train maps z - mean).
    centered_images = train_images - mean

    def get_batch():
        st, end = c_index()
        return {input_l: centered_images[st:end],
                targets: one_hot_train_labels[st:end],
                keep_p: kp}

    return get_batch
# Choose how batches reach the graph; exactly one feed_fun should be active.
# The timings are the original author's measurements.
#feed_fun = lambda : prepare_feed(iterator) #Using tf iterator with feed (A LOT OF TIME)
feed_fun = lambda : {keep_p : kp} # Using the tf iterator elapsed 9.742583751678467
# feed_fun = prepare_feed_raw() # Just using numpy elapsed 5.372682094573975

with tf.Session() as sess:
    sess.run(init_op)
    sess.run(training_init_op)

    i = 0
    with timeit():
        # Train until the data source signals exhaustion by raising; the
        # message (if any) and the step count reached are then reported.
        try:
            while True:
                batch_loss, _, batch_acc = sess.run([loss, train_step, accuracy], feed_fun())
                if i % 50 == 0:
                    print("It: {}, loss_batch: {:.3f}, batch_accuracy: {:.2f}%".format(i, batch_loss, batch_acc * 100))
                i += 1
        except Exception as e:
            print(str(e))
            print('break at {0}'.format(i))

    # Validation pass (disabled):
    # sess.run(validation_init_op)
    # avg_acc = 0
    # c=0
    # while True:
    #     try:
    #         fd = prepare_feed(iterator) if debug else {keep_p : 1}
    #         acc = sess.run([accuracy],fd)
    #         avg_acc += acc[0]
    #         c+=1
    #     except tf.errors.OutOfRangeError:
    #         print("Average validation set accuracy over {} iterations is {:.2f}%".format(c,(avg_acc / c) * 100))
    #         break
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment