@sudarshan85
Last active October 30, 2020 13:18
Plot metrics logged by PyTorch Lightning's CSVLogger
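import pandas as pd
import matplotlib.pyplot as plt

# Path to the metrics.csv written by CSVLogger (placeholder; adjust to your log directory)
csv = 'metrics.csv'
df = pd.read_csv(csv)
df.drop(columns=['step', 'epoch'], inplace=True)
# Train and val metrics land on separate rows; fill the gaps forward, then backward
# (fillna(method=...) is deprecated in recent pandas, so use ffill()/bfill())
df = df.ffill().bfill()
df.drop_duplicates(inplace=True)
df.reset_index(inplace=True, drop=True)
# Keep every other row (a heuristic that depends on the order train/val rows were logged)
df = df.iloc[::2, :].reset_index(drop=True)
fig, ax = plt.subplots(1, 2, figsize=(15, 5))
df[['train_loss', 'val_loss']].plot(ax=ax[0])
ax[0].set_xlabel('Epoch')
ax[0].set_ylabel('Loss')
df[['train_accuracy', 'val_accuracy']].plot(ax=ax[1])
ax[1].set_xlabel('Epoch')
ax[1].set_ylabel('Accuracy')
plt.show()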
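The `iloc[::2, :]` step is a heuristic: it only yields one row per epoch if the train and val rows were logged in a particular order. A sketch of an alternative that groups on the `epoch` column instead is shown here. It assumes the CSV keeps its `epoch` column and that each metric is logged at most once per epoch; the helper names `load_epoch_metrics` and `plot_metrics` and the sample values are hypothetical, made up for illustration.

```python
import io

import matplotlib.pyplot as plt
import pandas as pd


def load_epoch_metrics(csv_path):
    """Collapse a CSVLogger metrics file to one row per epoch (hypothetical helper).

    Assumes an 'epoch' column exists and each metric is logged at most once per epoch.
    """
    df = pd.read_csv(csv_path)
    # 'step' is not needed once rows are grouped by epoch
    df = df.drop(columns=['step'], errors='ignore')
    # GroupBy.last() takes the last non-null value of each column within an epoch,
    # merging the separate train and val rows without relying on row order
    return df.groupby('epoch').last()


def plot_metrics(df):
    """Plot loss and accuracy side by side, as in the snippet above (hypothetical helper)."""
    fig, ax = plt.subplots(1, 2, figsize=(15, 5))
    df[['train_loss', 'val_loss']].plot(ax=ax[0])
    ax[0].set_xlabel('Epoch')
    ax[0].set_ylabel('Loss')
    df[['train_accuracy', 'val_accuracy']].plot(ax=ax[1])
    ax[1].set_xlabel('Epoch')
    ax[1].set_ylabel('Accuracy')
    return fig


# Synthetic in-memory CSV with train and val metrics on separate rows (made-up values)
sample = io.StringIO(
    "epoch,step,train_loss,train_accuracy,val_loss,val_accuracy\n"
    "0,9,,,0.9,0.6\n"
    "0,9,1.0,0.5,,\n"
    "1,19,,,0.7,0.7\n"
    "1,19,0.8,0.65,,\n"
)
metrics = load_epoch_metrics(sample)
fig = plot_metrics(metrics)
```

With a real run, pass the path to the logger's `metrics.csv` instead of the in-memory sample. Grouping by epoch also leaves the x-axis indexed by the actual epoch number rather than by row position.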