Skip to content

Instantly share code, notes, and snippets.

@gaphex
Created February 29, 2020 17:04
Show Gist options
  • Save gaphex/f2d2e1a9c849ba9d69a3014da705968f to your computer and use it in GitHub Desktop.
Save gaphex/f2d2e1a9c849ba9d69a3014da705968f to your computer and use it in GitHub Desktop.
class PearsonrRankCallback(Callback):
    """Callback that tracks the Pearson correlation of a similarity model.

    At the start of every epoch, ``sim_model`` scores ``self.samples``. The
    Pearson correlation coefficient between those scores and ``self.labels``
    is then computed and rounded to 4 decimals. When the coefficient beats
    the best value seen so far, it is printed. If both ``savemodel`` and
    ``savepath`` are given, the weights of ``savemodel`` are also written to
    ``savepath``.

    Args:
        loader: Stored as ``self.loader``. Presumably consumed by
            ``load_datasets`` to read ``filepaths`` (not visible here --
            TODO confirm).
        filepaths: Passed to ``self.load_datasets``, which must return a
            ``(samples, labels)`` pair.
        name: Label printed alongside a new best coefficient.
        verbose: Forwarded as ``verbose`` to ``sim_model.predict``.
        sim_model: Model whose ``predict`` output is correlated with the
            labels. Required in practice despite the ``None`` default:
            ``on_epoch_begin`` calls ``sim_model.predict`` unconditionally.
        savemodel: Model whose weights are saved when the coefficient
            improves.
        savepath: Destination passed to ``savemodel.save_weights``.
        batch_size: Batch size used for ``sim_model.predict``. Defaults to
            128, the value that was previously hard-coded.

    NOTE(review): ``load_datasets`` is not defined in the visible code; a
    subclass or mixin is presumably expected to provide it -- confirm.
    """

    def __init__(self, loader, filepaths, name=None, verbose=False,
                 sim_model=None, savemodel=None, savepath=None,
                 batch_size=128):
        # Initialise the base callback first, so its setup cannot overwrite
        # any of the attributes assigned below.
        super(PearsonrRankCallback, self).__init__()
        self.savemodel = savemodel
        self.savepath = savepath
        self.sim_model = sim_model
        self.loader = loader
        self.verbose = verbose
        self.name = name
        self.batch_size = batch_size
        self.samples, self.labels = self.load_datasets(filepaths)
        # Best (rounded) coefficient seen so far. Because it starts at 0, a
        # zero or negative correlation is never treated as an improvement.
        self.best = 0

    def on_epoch_begin(self, epoch, logs=None):
        """Score the samples, report the correlation and save on improvement.

        ``epoch`` and ``logs`` belong to the callback hook signature and are
        not used. This hook runs before each epoch's training, so the first
        call measures the model as it was when training started.
        """
        pred = self.sim_model.predict(self.samples,
                                      batch_size=self.batch_size,
                                      verbose=self.verbose).reshape(-1)
        # pearsonr returns (coefficient, p-value); only the coefficient is used.
        coef, _ = pearsonr(self.labels, pred)
        # Rounding to 4 decimals means gains below that precision do not count
        # as a new best.
        coef = np.round(coef, 4)
        # A NaN coefficient (e.g. constant predictions) compares False here
        # and falls through to the else branch.
        if coef > self.best:
            self.best = coef
            print("*** New best: {} = {}".format(self.name, coef))
            if self.savemodel and self.savepath:
                self.savemodel.save_weights(self.savepath)
        else:
            # NOTE(review): the value printed is the Pearson coefficient for
            # this epoch, not a mean, despite the "mean:" label.
            print("mean: {}".format(coef))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment