@stefanthaler
Created February 3, 2017 14:38
A simple example to demonstrate how to link embedding metadata to word embeddings in tensorflow / tensorboard
"""
Simple example to demonstrate the embedding visualization for word embeddings in tensorflow / tensorboard
https://www.tensorflow.org/how_tos/embedding_viz/
"""
import tensorflow as tf
import os
assert tf.__version__ == '1.0.0-rc0' # if code breaks, check tensorflow version
from tensorflow.contrib.tensorboard.plugins import projector
"""
Hyperparameters
"""
checkpoint_path = "checkpoints"
if not os.path.exists(checkpoint_path): os.mkdir(checkpoint_path)
vocabulary_size = 20 # we have 20 words in our vocabulary
word_embedding_dim = 10 # each word is represented by a 10-dimensional row vector
"""
Create Word Embedding
"""
# create word embeddings, fill randomly
word_embeddings = tf.Variable(tf.random_uniform([vocabulary_size, word_embedding_dim], -1.0, 1.0), name='word_embeddings')
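# For illustration only (not needed for the projector): word ids index rows of the
# embedding matrix. A minimal sketch; `example_word_ids` is a made-up name, and the
# resulting tensor is never evaluated in this script.
example_word_ids = tf.constant([0, 5, 19])  # three word ids from the vocabulary
example_vectors = tf.nn.embedding_lookup(word_embeddings, example_word_ids)  # shape [3, 10]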
"""
Save Word embedding checkpoint
"""
# Saver
saver = tf.train.Saver(tf.global_variables())
# Start session
session = tf.Session()
summary_writer = tf.summary.FileWriter(checkpoint_path, graph=session.graph)
session.run([tf.global_variables_initializer()]) # init variables
#... do stuff with session
# save checkpoints periodically
chkpoint_out_filename = os.path.join(checkpoint_path, "word_embedding_sample")
saver.save(session, chkpoint_out_filename, global_step=1)
print("\nword_embeddings checkpoint saved")
"""
Write metadata file
"""
tsv_row_template = "{}\t{}\t{}\n"
with open(os.path.join(checkpoint_path, 'word_embeddings.tsv'), "w") as f:
    header_row = tsv_row_template.format("Name", "Category", "Type")
    f.write(header_row)
    for w_id in range(vocabulary_size):
        # get metadata for each word
        word = "word %0.2d" % w_id
        category = w_id % 5
        word_type = "type %i" % (w_id % 3)
        data_row = tsv_row_template.format(word, category, word_type)
        f.write(data_row)
print("word_embeddings.tsv written.")
"""
Link metadata tsv file to embedding
"""
config = projector.ProjectorConfig()
embedding = config.embeddings.add() # one embedding config per tensor; more embeddings could be added to the same config
embedding.tensor_name = word_embeddings.name
embedding.metadata_path = os.path.join(checkpoint_path, 'word_embeddings.tsv')
projector.visualize_embeddings(summary_writer, config)
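# visualize_embeddings writes a projector_config.pbtxt into the FileWriter's logdir
# (checkpoints/ here); TensorBoard's Embeddings tab reads it together with the saved
# checkpoint and the metadata tsv.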
print("Metadata linked to checkpoint\n")
print("run: tensorboard --logdir checkpoints/")