# Gist by @kwakseonghun, created August 10, 2017 12:17
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import random
tf.set_random_seed(777)
mnist = input_data.read_data_sets("/tmp/data/", one_hot = True)
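# NOTE: input_data is the TF1 tutorial helper (tensorflow.examples.tutorials.mnist);
# it downloads MNIST to /tmp/data/ on first run and is not available in TF2.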
# constants
training_epochs = 12
learning_rate = 0.01
batch_size = 100
display_step = 1
depth_step = 50  # used only by the original nn_depth formula; unused after the fix below
init_call = 784  # input dimensionality (28x28 pixel images)
out_call = 10    # output dimensionality (10 digit classes)
init1 = 784
init = 0
keep_prob = tf.placeholder(tf.float32)
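# keep_prob is fed during training below, but no dropout layer in this graph consumes it.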
# Data placeholders
X = tf.placeholder("float", [None,784])
Y = tf.placeholder("float", [None,10])
W={}
b={}
# K = 3
step = 150  # each hidden layer is 150 units narrower than the previous one
# The original computed nn_depth = int(init1/depth_step - 1) = 14, which with
# step = 150 drives the layer widths negative. Three hidden layers
# (784 -> 634 -> 484 -> 334) match the commented K = 3 and the 334-unit input
# of the original output layer, so that depth is assumed here.
nn_depth = 3
# Build nn_depth fully connected sigmoid layers, each `step` units narrower
# than the one before.
for i in range(nn_depth):
    init = init1
    init1 -= step
    W[i] = tf.Variable(tf.random_normal([init, init1]))
    b[i] = tf.Variable(tf.random_normal([init1]))
    if i == 0:
        L = tf.nn.sigmoid(tf.matmul(X, W[0]) + b[0])
    else:
        L = tf.nn.sigmoid(tf.matmul(L, W[i]) + b[i])
# Output layer: init1 now holds the width of the last hidden layer
# (334 with the settings above), so it replaces the original hard-coded 334.
W_out = tf.Variable(tf.random_normal([init1, out_call]))
b_out = tf.Variable(tf.random_normal([out_call]))
hypothesis = tf.matmul(L, W_out) + b_out
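# hypothesis is left as raw logits; softmax is applied inside
# softmax_cross_entropy_with_logits below.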
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=hypothesis, labels=Y))
optimizer = tf.train.AdamOptimizer(learning_rate).minimize(cost)
correct_prediction = tf.equal(tf.argmax(hypothesis, 1), tf.argmax(Y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for epoch in range(training_epochs):
        avg_cost = 0
        total_batch = int(mnist.train.num_examples / batch_size)
        # total_batch = 10
        for i in range(total_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            c, o = sess.run([cost, optimizer],
                            feed_dict={X: batch_xs, Y: batch_ys, keep_prob: 0.7})
            avg_cost += c / total_batch
            print(avg_cost)  # debug print: running average of this epoch's cost so far
        if epoch % display_step == 0:
            print("Epoch:", '%04d' % (epoch + 1), "cost=", "{:.9f}".format(avg_cost))
    print("Model accuracy:", accuracy.eval({X: mnist.test.images, Y: mnist.test.labels}))