Skip to content

Instantly share code, notes, and snippets.

@yashkumaratri
Created May 31, 2020 18:33
Show Gist options
  • Save yashkumaratri/b2780c76c6b5dae7305248b9ace6cc5e to your computer and use it in GitHub Desktop.
Save yashkumaratri/b2780c76c6b5dae7305248b9ace6cc5e to your computer and use it in GitHub Desktop.
heatmap
def make_heatmap(text, values, save=None, polarity=1, *, n_limit=74,
                 cell_height=.325, cell_width=.15):
    """Render per-token scores as a text-annotated heatmap.

    The tokens are laid out left to right in rows of ``n_limit`` cells;
    each cell shows its token and is coloured by the matching score on a
    fixed ``[-1, 1]`` red-yellow-green scale. The last row is padded with
    masked (blank) cells.

    Relies on the module-level names ``np``, ``plt`` and ``sns``
    (presumably numpy, matplotlib.pyplot and seaborn -- confirm against
    the file's imports).

    Args:
        text: Iterable of strings, one per cell. A plain ``str`` is
            treated as a sequence of single characters. Newlines are
            shown as a literal ``\\n`` so they cannot break the layout.
        values: One-dimensional sequence of numbers, same length as
            ``text``. Lists, tuples and arrays are accepted; the input is
            never modified.
        save: Optional destination handed to ``Figure.savefig``. Nothing
            is written when ``None``.
        polarity: Factor applied to every score before plotting (for
            example ``-1`` to flip the colour scale).
        n_limit: Number of cells per row.
        cell_height: Height of one cell, in figure-size units.
        cell_width: Width of one cell, in figure-size units.

    Returns:
        The Axes object returned by ``sns.heatmap``. Its figure has been
        closed in pyplot, so it will not leak into later plots.

    Raises:
        ValueError: If ``n_limit`` is not a positive integer, or if
            ``values`` is not one-dimensional with one entry per token.
    """
    if not isinstance(n_limit, int) or n_limit < 1:
        raise ValueError("n_limit must be a positive integer, got %r" % (n_limit,))

    # All validation and array layout happens before any figure is created,
    # so bad input cannot leave a half-built figure behind.
    grid_values, grid_labels, grid_mask = _build_heatmap_grid(
        text, values, polarity, n_limit)
    num_rows = grid_values.shape[0]

    fig = plt.figure(figsize=(cell_width * n_limit, cell_height * num_rows))
    try:
        hmap = sns.heatmap(grid_values, annot=grid_labels, mask=grid_mask,
                           fmt='', vmin=-1, vmax=1, cmap='RdYlGn',
                           xticklabels=False, yticklabels=False, cbar=False)
        fig.tight_layout()
        if save is not None:
            fig.savefig(save)
    finally:
        # Close (not merely clear) the figure: every call opens a new one,
        # so clearing alone would leave them accumulating in pyplot.
        plt.close(fig)
    return hmap


def _build_heatmap_grid(text, values, polarity, n_limit):
    """Pad the tokens and scores and fold them into ``n_limit``-wide rows.

    Returns ``(values, labels, mask)`` as 2-D arrays of identical shape;
    ``mask`` is True exactly on the padding cells.
    """
    labels = [token.replace('\n', '\\n') for token in text]
    num_cells = len(labels)

    # Convert to float up front so a fractional polarity cannot hit an
    # integer-casting error, and so tuples/arrays work as well as lists.
    scores = np.asarray(values, dtype=float)
    if scores.shape != (num_cells,):
        raise ValueError(
            "values must be one-dimensional with one entry per token: "
            "got shape %r for %d tokens" % (scores.shape, num_cells))

    # Ceiling division; always keep at least one row so that empty input
    # still produces a figure with a positive height.
    num_rows = max(1, -(-num_cells // n_limit))
    total_cells = num_rows * n_limit
    pad = total_cells - num_cells

    # `scores * polarity` allocates a new array, leaving the caller's data
    # untouched.
    scores = np.concatenate([scores * polarity, np.zeros(pad)])
    labels = np.array(labels + [' '] * pad)
    mask = np.arange(total_cells) >= num_cells

    return (scores.reshape(num_rows, n_limit),
            labels.reshape(num_rows, n_limit),
            mask.reshape(num_rows, n_limit))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment