Skip to content

Instantly share code, notes, and snippets.

@nuric
Last active August 24, 2018 17:20
Show Gist options
  • Save nuric/0c1fde80f0d1d4e703485a48f9c375e6 to your computer and use it in GitHub Desktop.
Save nuric/0c1fde80f0d1d4e703485a48f9c375e6 to your computer and use it in GitHub Desktop.
Stateful Checkpoint for Keras
import json
import os
import socket

from keras.callbacks import ModelCheckpoint
class StatefulCheckpoint(ModelCheckpoint):
    """Save extra checkpoint data to resume training.

    Besides the weights handled by ``ModelCheckpoint``, this callback writes a
    JSON state file after each epoch. The file holds the next epoch index, the
    best monitored value, the hostname, the epoch's logs and the training
    params. It is read back on construction, so ``get_last_epoch()`` can be
    passed to ``fit(initial_epoch=...)`` and ``self.best`` is restored.
    """

    def __init__(self, weight_file, state_file=None, **kwargs):
        """Save the state (epoch etc.) along side weights.

        Args:
            weight_file: weights path, forwarded to ``ModelCheckpoint``.
            state_file: path of the JSON state file. If falsy, no state is
                loaded or saved.
            **kwargs: forwarded unchanged to ``ModelCheckpoint``.
        """
        super().__init__(weight_file, **kwargs)
        self.state_f = state_file
        self.hostname = socket.gethostname()
        self.state = dict()
        if self.state_f:
            self._load_state()

    def _load_state(self):
        """Load ``self.state_f`` into ``self.state`` and restore ``best``."""
        try:
            with open(self.state_f, 'r', encoding='utf-8') as f:
                state = json.load(f)
        except (OSError, ValueError) as e:
            # Missing (first run), unreadable or corrupt file: start fresh.
            # ValueError also covers json.JSONDecodeError and decode errors.
            print("Skipping last state:", e)
            return
        if not isinstance(state, dict):
            print("Skipping last state:", "not a JSON object")
            return
        self.state = state
        # A partial state file may lack 'best'. Keep ModelCheckpoint's default
        # then, rather than dropping the epoch information as well.
        if 'best' in state:
            self.best = state['best']

    def on_train_begin(self, logs=None):
        """Report whether training starts fresh or resumes from saved state."""
        super().on_train_begin(logs)
        prefix = "Resuming" if self.state else "Starting"
        print("{} training on {}".format(prefix, self.hostname))

    def on_epoch_end(self, epoch, logs=None):
        """Saves training state as well as weights."""
        super().on_epoch_end(epoch, logs)
        if not self.state_f:
            return
        state = dict(logs or {})  # Keras may pass logs=None
        state.update(getattr(self, 'params', None) or {})
        # Core resume keys go last so log/param entries can never clobber them.
        state.update({'epoch': epoch + 1, 'best': self.best,
                      'hostname': self.hostname})
        self._write_state(state)
        # Keep the in-memory view current so get_last_epoch() stays accurate
        # if this callback is reused across several fit() calls.
        self.state = state

    def _write_state(self, state):
        """Write ``state`` as JSON to ``self.state_f`` via an atomic replace."""
        tmp_path = str(self.state_f) + '.tmp'
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(state, f, default=self._json_default)
        # Swap in one step so an interruption never leaves a truncated state
        # file, which the loader would otherwise discard on the next run.
        os.replace(tmp_path, self.state_f)

    @staticmethod
    def _json_default(obj):
        """Convert array-like values (e.g. NumPy scalars) to plain Python.

        ``logs`` and ``self.best`` may hold NumPy scalars such as
        ``np.float32``, which ``json`` cannot serialize on its own.
        ``tolist()`` turns NumPy scalars into Python numbers and arrays
        into nested lists.
        """
        tolist = getattr(obj, 'tolist', None)
        if callable(tolist):
            return tolist()
        raise TypeError("Object of type {} is not JSON serializable".format(
            type(obj).__name__))

    def get_last_epoch(self, initial_epoch=0):
        """Return last saved epoch if any, or return default argument."""
        return self.state.get('epoch', initial_epoch)

    def on_train_end(self, logs=None):
        """Report that training on this host has finished."""
        super().on_train_end(logs)
        print("Training ending on {}".format(self.hostname))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment