Last active
December 28, 2018 22:10
-
-
Save mrdrozdov/3164cb3c77f8c81ebbba61c48ddcb540 to your computer and use it in GitHub Desktop.
Logging in TensorFlow
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Example of using TFLogger to save train & dev statistics.

To visualize the results in TensorBoard, point it at the summary root used
below (the train/ and eval/ subdirectories appear as separate runs):

    tensorboard --logdir ./summaries

This code depends on TensorFlow, but your model does not have to be built
with TensorFlow. For instance, you could build a model in Chainer and then
log the loss and accuracy from your Chainer model using TFLogger.

`model`, `batch_iterator`, `eval_step` and `evaluate` are placeholders for
your own training objects; they are not defined here.
"""
import os

from tf_logger import TFLogger

# One logger per split so train and eval curves are written to separate runs.
train_tf_logger = TFLogger(os.path.join('.', 'summaries', 'train'))
eval_tf_logger = TFLogger(os.path.join('.', 'summaries', 'eval'))

for step, (x_batch, y_batch) in enumerate(batch_iterator):
    acc, loss = model.train(x_batch, y_batch)
    train_tf_logger.log(step=step, accuracy=acc, loss=loss)

    # Evaluate periodically; logging with the same `step` lines the eval
    # points up with the training curve in TensorBoard.
    if step % eval_step == 0:
        acc, loss = evaluate(model)
        eval_tf_logger.log(step=step, accuracy=acc, loss=loss)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf

# The graph-mode API used below lives under tf.compat.v1 in TF >= 1.13
# (including 2.x); older 1.x releases expose the same names at top level.
_tf1 = getattr(getattr(tf, "compat", None), "v1", tf)


class TFLogger(object):
    """Write scalar loss/accuracy summaries that TensorBoard can display.

    Each logger owns a private TensorFlow graph and session that contain only
    two scalar inputs and their summary ops. The model being trained therefore
    does not need to be built with TensorFlow.
    """

    def __init__(self, summary_dir):
        """Create the summary graph and an event writer for ``summary_dir``.

        Args:
            summary_dir: Directory the TensorBoard event file is written to.
        """
        super(TFLogger, self).__init__()
        self.summary_dir = summary_dir
        self.__initialize()

    def __initialize(self):
        # Use a dedicated graph rather than the process-wide default graph.
        # Otherwise every TFLogger instance adds its ops to the same graph, TF
        # uniquifies the names ("loss_1", ...), and each writer serializes the
        # other loggers' ops too.
        graph = tf.Graph()
        with graph.as_default():
            # Placeholders (not Variables) are the supported way to inject
            # Python values; the names differ from the summary names so the
            # summary tags stay exactly "loss" and "accuracy".
            loss = _tf1.placeholder(tf.float32, shape=[], name="loss_input")
            acc = _tf1.placeholder(tf.float32, shape=[], name="accuracy_input")
            loss_summary = _tf1.summary.scalar("loss", loss)
            acc_summary = _tf1.summary.scalar("accuracy", acc)
            summary_op = _tf1.summary.merge([loss_summary, acc_summary])
            # Created inside the graph context so it also works under TF 2.x,
            # where the v1 FileWriter refuses to run in eager mode.
            summary_writer = _tf1.summary.FileWriter(self.summary_dir, graph)

        self.graph = graph
        self.sess = _tf1.Session(graph=graph)
        self.summary_op = summary_op
        self.summary_writer = summary_writer
        self.loss = loss
        self.acc = acc

    def log(self, step, loss, accuracy):
        """Record ``loss`` and ``accuracy`` as scalar summaries at ``step``.

        Args:
            step: Global step (x-axis value in TensorBoard).
            loss: Scalar loss value to record.
            accuracy: Scalar accuracy value to record.
        """
        feed_dict = {
            self.loss: loss,
            self.acc: accuracy,
        }
        summaries = self.sess.run(self.summary_op, feed_dict=feed_dict)
        self.summary_writer.add_summary(summaries, step)

    def flush(self):
        """Write any buffered events to disk."""
        self.summary_writer.flush()

    def close(self):
        """Flush pending events, then release the writer and the session."""
        self.summary_writer.close()
        self.sess.close()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment