Skip to content

Instantly share code, notes, and snippets.

@markloyman
Created September 5, 2018 12:17
Show Gist options
  • Save markloyman/4f96b6e2cde32d927eab2385d6b0c70d to your computer and use it in GitHub Desktop.
Save markloyman/4f96b6e2cde32d927eab2385d6b0c70d to your computer and use it in GitHub Desktop.
Use TensorBoard's runtime statistics with a Keras model
import numpy as np
import tensorflow as tf
# TF 1.x graph-mode API: tf.Session (like tf.placeholder / tf.train used
# below) is not a top-level name in TF 2.x, where it lives under tf.compat.v1.
# A single session is created here and registered with Keras via
# K.set_session, so the Keras layers and the script's own sess.run calls
# work against the same session.
sess = tf.Session()
from keras import backend as K
K.set_session(sess)
# NOTE(review): `keras.objectives` is the legacy module name; newer Keras
# releases provide mean_squared_error from `keras.losses` — confirm against
# the installed version.
from keras.objectives import mean_squared_error
from keras.layers import Dense
def load_dummy_data(n=1000):
    """Generate a random regression dataset for smoke-testing the network.

    Args:
        n: Number of samples to generate.

    Returns:
        A tuple ``(x, y)``. ``x`` has shape ``(n, 100)`` with entries drawn
        uniformly from [0, 1). ``y`` has shape ``(n, 1)`` and holds the
        per-row mean of ``x``, so every target lies in [0, 1].
    """
    x = np.random.rand(n, 100)
    # Use the row mean rather than the row sum. A sum of 100 uniform features
    # is ~50, far outside the (0, 1) range of the sigmoid output produced by
    # load_keras_network, so the model could never fit those targets.
    y = np.mean(x, axis=1, keepdims=True)
    return x, y
def load_keras_network(input_tensor, hidden_units=100, num_hidden=3):
    """Stack a small fully connected Keras network on top of ``input_tensor``.

    Args:
        input_tensor: Tensor fed into the first ``Dense`` layer.
        hidden_units: Width of each hidden ``Dense`` layer (default 100).
        num_hidden: Number of ReLU hidden layers (default 3).

    Returns:
        The output tensor of a final single-unit ``Dense`` layer with sigmoid
        activation, so each output value lies in (0, 1).
    """
    x = input_tensor
    for _ in range(num_hidden):
        x = Dense(hidden_units, activation='relu')(x)
    return Dense(1, activation='sigmoid')(x)
# Build the Keras model directly on TF placeholders so the result is an
# ordinary tf.Tensor that can be run (and traced) with a plain sess.run call.
inputs = tf.placeholder(tf.float32, shape=(None, 100))  # passed as input to our keras layers
labels = tf.placeholder(tf.float32, shape=(None, 1))
net = load_keras_network(inputs)  # type(net) == tf.Tensor
loss = tf.reduce_mean(mean_squared_error(labels, net))
opt = tf.train.GradientDescentOptimizer(0.1).minimize(loss)
# Create the init op before the FileWriter snapshots the graph, so the
# logged graph contains every op the script runs.
init = tf.global_variables_initializer()

writer = tf.summary.FileWriter(r'./logs', sess.graph)
try:
    sess.run(init)
    with sess.as_default():
        x, y = load_dummy_data(64)
        # FULL_TRACE asks TF to record per-op step statistics for this
        # single training step into run_metadata.
        run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
        run_metadata = tf.RunMetadata()
        sess.run([opt],
                 feed_dict={inputs: x, labels: y},
                 options=run_options,
                 run_metadata=run_metadata)
        writer.add_run_metadata(run_metadata, 'runtime-stats')
finally:
    # Close (and thereby flush) the event file even if the traced run fails.
    writer.close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment