Last active
October 17, 2016 16:28
-
-
Save gaphex/6b2062b995939280efd19c6f91edd8a5 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import random | |
import matplotlib.pyplot as plt | |
from IPython import display | |
""" | |
IPython Display rc0 | |
Try: | |
dsp = IDisplay() | |
dsp.test() | |
""" | |
class IDisplay(object):
    """Redrawable line-plot display for IPython/Jupyter notebook cells.

    Each call to ``display`` clears the calling cell's output and draws one
    line per label, so calling it once per epoch yields an updating curve.
    A color is picked at random the first time a label is seen and is then
    reused, so a label keeps its color from one redraw to the next.

    Try:
        dsp = IDisplay()
        dsp.test()
    """

    def __init__(self,
                 vsize=8, hsize=12,
                 title='Training in progress',
                 xlabel='epoch',
                 ylabel='score'):
        """Store the plot parameters.

        vsize:  vertical plot size (second element of matplotlib's figsize)
        hsize:  horizontal plot size (first element of matplotlib's figsize)
        title:  plot title
        xlabel: x axis label
        ylabel: y axis label
        """
        # label -> {'r': float, 'g': float, 'b': float}, filled lazily.
        self.colors = {}
        self.xlabel = xlabel
        self.ylabel = ylabel
        self.title = title
        self.hsize = hsize
        self.vsize = vsize

    def _generate_color(self, label):
        """Pick a random RGB color for ``label`` and store it in self.colors.

        Every channel is a float in [0.0, 1.0]. Any color previously stored
        for the same label is overwritten.
        """
        # randrange's upper bound is exclusive: 256 is needed so that a
        # channel can actually reach 255 / 255 == 1.0.
        self.colors[label] = {
            channel: random.randrange(256) / 255.0
            for channel in ('r', 'g', 'b')
        }

    def regenerate_colors(self):
        """Re-roll the color of every label seen so far.

        Use it if you do not like the current colors.
        """
        # Only existing keys are reassigned, so the dict does not change size
        # while it is being iterated.
        for label in self.colors:
            self._generate_color(label)

    def display(self, data_dict):
        """Visualise 1-D data with line plots.

        Every call to this method rewrites the output of the calling cell.
        Accepts a data dictionary of the following structure:

            {'label_1': [],
             'label_2': [],
             ...}

        Labels are drawn in sorted order; a label that has no color yet gets
        a random one, which is remembered for later calls.
        """
        # `display` here is the IPython module imported at file level, not
        # this method: method names are not in scope inside the method body.
        display.clear_output(wait=True)

        plt.figure(figsize=(self.hsize, self.vsize))
        plt.xlabel(self.xlabel)
        plt.ylabel(self.ylabel)
        plt.title(self.title)

        for label in sorted(data_dict):
            if label not in self.colors:
                self._generate_color(label)
            rgb = self.colors[label]
            plt.plot(data_dict[label], label=label,
                     linewidth=4,
                     color=(rgb['r'], rgb['g'], rgb['b']))

        # With nothing plotted there are no labelled artists, and asking for
        # a legend would only make matplotlib emit a warning.
        if data_dict:
            plt.legend(loc='upper left', shadow=True)
        plt.show()

    def test(self, epochs=16):
        """Perform a dummy run with randomly generated data.

        Appends one random point per series and redraws, ``epochs`` times.
        """
        data = {'popugai_1': [], 'popugai_2': []}
        for _ in range(epochs):
            # Sorted so that each series always gets the same value range,
            # independent of dict ordering.
            for i, key in enumerate(sorted(data)):
                data[key].append(random.randint((i + 1) * 4, (i + 1) * 16))
            self.display(data)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment