Created July 31, 2019 02:48
Save ay27/d83a1e0e9aa2aca312dbaf08caf16a29 to your computer and use it in GitHub Desktop.
[Fancy Logger] my logger
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import datetime
import io
import os
import time
from collections import deque

try:
    from StringIO import StringIO  # Python 2.7
except ImportError:
    from io import BytesIO  # Python 3.x

import numpy as np
import tensorflow as tf
from tensorflow.contrib.tensorboard.plugins import projector

try:
    import scipy.misc
except ImportError:
    scipy = None
def embedding_logger(tensor, save_path, meta_data=None):
    """Save ``tensor`` as a TensorBoard projector embedding under ``save_path``.

    Writes an event file, a projector config, a checkpoint (prefix
    ``embedding.ckpt``, global step 1) and, if labels are given, a
    ``meta.csv`` file with one label per line.

    NOTE: a new ``tf.Variable`` is added to the default graph on every call.

    Args:
        tensor: Initial value for the embedding variable; presumably one row
            per point (TODO confirm against callers).
        save_path: Output directory; created if it does not exist.
        meta_data: Optional iterable of labels (list, numpy array, generator,
            ...). Each item is written with ``str.format`` on its own line.
            Nothing is written, and no metadata path is configured, when it
            is None or yields no items.
    """
    embedding_var = tf.Variable(tensor)
    os.makedirs(save_path, exist_ok=True)

    meta_path = None
    # `is not None` instead of truthiness: `if array:` raises for numpy arrays.
    if meta_data is not None:
        # Format every label up front so non-str items (ints, floats) work.
        lines = ['{}\n'.format(w) for w in meta_data]
        if lines:
            meta_path = os.path.join(save_path, 'meta.csv')
            with open(meta_path, 'w', encoding='utf-8') as f:
                f.writelines(lines)

    with tf.Session() as sess:
        writer = tf.summary.FileWriter(save_path, sess.graph)
        try:
            sess.run(embedding_var.initializer)
            config = projector.ProjectorConfig()
            embedding = config.embeddings.add()
            embedding.tensor_name = embedding_var.name
            # Only set the string field when there is a metadata file to point at.
            if meta_path is not None:
                embedding.metadata_path = meta_path
            projector.visualize_embeddings(writer, config)
            saver_embed = tf.train.Saver([embedding_var])
            saver_embed.save(sess, os.path.join(save_path, 'embedding.ckpt'), 1)
        finally:
            # Close even if saving fails so the event file is not left open.
            writer.close()
class BoardLogger(object):
    """Write scalar, image and histogram summaries to a TensorBoard event file."""

    def __init__(self, log_dir):
        """Create a summary writer logging to log_dir."""
        self.writer = tf.summary.FileWriter(log_dir)

    def scalar_summary(self, tag, value, step):
        """Log a scalar variable."""
        summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)])
        self.writer.add_summary(summary, step)

    def image_summary(self, tag, images, step):
        """Log a list of images, one summary value per image tagged ``tag/<index>``.

        Skipped (with a message) when ``scipy.misc.toimage`` is unavailable:
        either scipy is missing or it is a release where ``toimage`` was removed.
        """
        if scipy is None or not hasattr(scipy.misc, 'toimage'):
            print('scipy.misc.toimage not found, skip image_summary')
            return
        img_summaries = []
        for i, img in enumerate(images):
            # PNG data is binary, so encode into a bytes buffer.
            buf = io.BytesIO()
            scipy.misc.toimage(img).save(buf, format="png")
            img_sum = tf.Summary.Image(encoded_image_string=buf.getvalue(),
                                       height=img.shape[0],
                                       width=img.shape[1])
            img_summaries.append(tf.Summary.Value(tag='%s/%d' % (tag, i), image=img_sum))
        summary = tf.Summary(value=img_summaries)
        self.writer.add_summary(summary, step)

    def histo_summary(self, tag, values, step, bins=1000):
        """Log a histogram of ``values`` (array-like) and flush the writer.

        Args:
            tag: Summary tag.
            values: Array-like of numbers; lists and tuples are accepted.
            step: Global step recorded with the summary.
            bins: Passed to ``np.histogram``.
        """
        values = np.asarray(values)
        counts, bin_edges = np.histogram(values, bins=bins)

        hist = tf.HistogramProto()
        hist.min = float(np.min(values))
        hist.max = float(np.max(values))
        hist.num = int(values.size)
        hist.sum = float(np.sum(values))
        # Square in float64: `values ** 2` wraps around for small int dtypes (e.g. uint8).
        hist.sum_squares = float(np.sum(np.square(values.astype(np.float64))))

        # Each bucket is described by its right edge, so drop the first (leftmost) edge.
        hist.bucket_limit.extend(bin_edges[1:].tolist())
        hist.bucket.extend(counts.tolist())

        summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)])
        self.writer.add_summary(summary, step)
        self.writer.flush()

    def close(self):
        """Flush pending summaries and close the underlying writer."""
        self.writer.close()
class Target(object):
    """Base class for objects a Schedule drives through ``_update``."""

    def _update(self):
        """Perform one scheduled update; subclasses must override."""
        raise NotImplementedError


class Schedule(object):
    """Fire registered targets on a tick counter.

    Each target has an interval: an int N fires it whenever the tick count is
    a multiple of N; a list fires it exactly on the listed tick numbers.
    """

    def __init__(self):
        self._step = 0
        self._t = []
        self._intervals = []

    def add_schedule(self, target, interval=1):
        """Register a target, or a list of targets sharing one interval.

        Args:
            target: A Target (its ``_update`` is called), any zero-argument
                callable, or a list of these.
            interval: An int period, or a list of tick numbers.

        Returns:
            ``self``, so calls can be chained.
        """
        targets = target if isinstance(target, list) else [target]
        self._t.extend(targets)
        self._intervals.extend([interval] * len(targets))
        return self

    def ticktock(self):
        """Advance the tick counter by one and fire every target that is due."""
        self._step += 1
        for t, val in zip(self._t, self._intervals):
            if self._is_due(val):
                if isinstance(t, Target):
                    t._update()
                else:
                    t()

    def _is_due(self, interval):
        """Return whether a target with ``interval`` fires on the current tick."""
        # A list interval is membership only; falling through to `step % list`
        # on a miss raised TypeError.
        if isinstance(interval, list):
            return self._step in interval
        return self._step % interval == 0
class ValueTarget(Target):
    """Target that carries a ``value`` and folds it in on each scheduled update.

    Subclasses implement ``_update_func(old_value, new_value)`` and
    ``_reset(value)``. The first ``_update`` only records ``value``; every
    later one hands (recorded value, current value) to ``_update_func``.
    """

    def __init__(self, value):
        super().__init__()
        self.value = value
        self._old_value = None  # nothing recorded until the first _update

    def _update(self):
        recorded = self._old_value
        if recorded is not None:
            self._update_func(recorded, self.value)
        # Record self.value as it stands after _update_func, which may have replaced it.
        self._old_value = self.value

    def reset(self, value=0.0):
        """Set ``value`` and let the subclass reinitialise via ``_reset``.

        The value recorded by the previous ``_update`` is left untouched.
        """
        self.value = value
        self._reset(value)

    def _update_func(self, old_value, new_value):
        """Fold ``new_value`` into subclass state; must be overridden."""
        raise NotImplementedError

    def _reset(self, value):
        """Reinitialise subclass state after ``reset``; must be overridden."""
        raise NotImplementedError
class MovingMean(ValueTarget):
    """Mean of the most recent ``mean_steps`` samples assigned to ``value``.

    Assign the newest sample to ``value`` before a tick. After each
    ``_update_func`` call, ``value`` holds the mean over the window.
    """

    def __init__(self, mean_steps=100):
        super().__init__(0)
        self._mean_steps = float(mean_steps)
        # deque gives O(1) removal of the oldest sample (list.pop(0) is O(window)).
        self._fifo = deque()
        self._sum = 0.0

    def _update_func(self, old_value, new_value):
        if len(self._fifo) == self._mean_steps:
            # Window is full: slide it by dropping the oldest sample first.
            self._sum -= self._fifo.popleft()
        self._fifo.append(new_value)
        self._sum += new_value
        self.value = self._sum / float(len(self._fifo))

    def _reset(self, value):
        # reset() already stores `value` as the reported mean. The window
        # restarts empty, so the running sum must restart at zero. Seeding it
        # with `value` left a permanent offset that was never popped out.
        self._sum = 0.0
        self._fifo = deque()
class Average(ValueTarget):
    """Cumulative mean of the samples passed to ``_update_func`` since the last reset.

    Assign each new sample to ``value`` before a tick. After each
    ``_update_func`` call, ``value`` holds the running mean.
    """

    def __init__(self):
        super().__init__(0)
        self._reset(0)

    def _update_func(self, old_value, new_value):
        self._sum += new_value
        self._iters += 1
        self.value = self._sum / float(self._iters)

    def _reset(self, value):
        # The sample count restarts at zero, so the sum must too. Seeding the
        # sum with `value` (without counting it) inflated every later mean.
        self._sum = 0.0
        self._iters = 0
class TimeStamp(ValueTarget):
    """Target whose ``value`` is a ``time.time()`` reading.

    Initialised at construction and refreshed whenever ``_update_func`` or
    ``_reset`` runs.
    """

    def __init__(self):
        super().__init__(time.time())

    def _update_func(self, old_value, new_value):
        self._touch()

    def _reset(self, value):
        self._touch()

    def _touch(self):
        """Overwrite ``value`` with the current wall-clock time."""
        self.value = time.time()
class CsvLogger(Target):
    """Buffer one CSV row of ValueTarget values per update and write them in batches.

    Args:
        output_file: Path of the CSV file; missing parent directories are created.
        name_target_pairs: Mapping of column name -> ValueTarget. Columns follow
            the mapping's iteration order at construction time.
        flush_interval: Number of buffered rows that triggers a write.
        append_time: If true, a "d<day>-H<hour>-M<min>-S<sec>" stamp is appended
            to ``output_file`` before the file is opened.
    """

    def __init__(self, output_file, name_target_pairs, flush_interval=100, append_time=False):
        # Set first so close()/__del__ are safe even if __init__ fails below.
        self._file = None
        super().__init__()
        if append_time:
            time_now = datetime.datetime.now().strftime("d%d-H%H-M%M-S%S")
            output_file = output_file + time_now
        self._output_file = output_file

        out_dir = os.path.dirname(output_file)
        if out_dir:
            # makedirs handles nested paths; os.mkdir only created one level.
            os.makedirs(out_dir, exist_ok=True)

        self._name_target_pairs = name_target_pairs
        for t in self._name_target_pairs.values():
            assert isinstance(t, ValueTarget)
        self._flush_interval = flush_interval
        self._cache = []
        # Snapshot the column order so later dict mutation cannot desync rows from the header.
        self._header = list(self._name_target_pairs.keys())
        self._file = open(self._output_file, 'w')
        self._file.write(','.join(self._header) + '\n')

    def _update(self):
        row = ','.join(str(self._name_target_pairs[h].value) for h in self._header)
        self._cache.append(row + '\n')
        if len(self._cache) >= self._flush_interval:
            self.flush()

    def flush(self):
        """Write buffered rows and flush the file object.

        Raises:
            ValueError: if the logger has been closed.
        """
        if self._file is None:
            raise ValueError('I/O operation on closed CsvLogger')
        if self._cache:
            self._file.writelines(self._cache)
            self._cache = []
        self._file.flush()

    def close(self):
        """Write any buffered rows and close the file; safe to call repeatedly."""
        if getattr(self, '_file', None) is None:
            return
        self.flush()
        self._file.close()
        self._file = None

    def __del__(self):
        # Best-effort fallback; call close() explicitly for a deterministic write.
        self.close()
if __name__ == '__main__':
    # Demo: each tick logs a timestamp, a running average and a 10-sample
    # moving mean of random numbers as a row of /tmp/test.
    schedule = Schedule()
    avg = Average()
    moving = MovingMean(10)
    ts = TimeStamp()
    log = CsvLogger('/tmp/test', {'ts': ts, 'avg': avg, 'moving': moving}, flush_interval=10)
    schedule.add_schedule([ts, avg, moving, log])

    samples = np.random.rand(1000)
    for step, sample in enumerate(samples[:103]):
        print(step)
        avg.value = sample
        moving.value = sample
        schedule.ticktock()

    for target in (avg, moving):
        target.reset()
    schedule.ticktock()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.