Created
June 4, 2020 13:58
-
-
Save Cospel/2cf1b3d6323b9763002bc9e7e23a16c9 to your computer and use it in GitHub Desktop.
BatchRenormCallback.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class BatchRenormCallback(tf.keras.callbacks.Callback):
    """Update the ``renorm_clipping`` limits of batch-renorm layers on a schedule.

    At the start of every epoch that appears as a key in ``bvalues``, the
    clipping dict scheduled for that epoch is assigned to every
    ``BatchNormalization`` layer reachable from the model. This includes
    layers inside nested ``tf.keras.Model`` instances at any depth. The new
    limits are also written as scalar summaries to ``log_dir``.

    Args:
        log_dir: Directory for the ``tf.summary`` file writer.
        bvalues: Mapping ``epoch -> renorm_clipping dict``. Each dict is
            assigned to the layers unchanged. Whichever of the keys
            ``"rmax"``, ``"rmin"`` and ``"dmax"`` it contains are logged as
            ``renorm_<key>`` at that epoch.

    NOTE(review): the new values are set as plain Python attributes on the
    layers. If the training step was already traced by ``tf.function``
    (e.g. ``model.fit`` without ``run_eagerly=True``), the compiled graph may
    keep the previously captured values. Confirm the change takes effect,
    e.g. by training eagerly or forcing a retrace.
    """

    # Clipping keys that are logged when present, in a fixed order.
    _LOGGED_KEYS = ("rmax", "rmin", "dmax")

    def __init__(self, log_dir, bvalues):
        super().__init__()
        self.writer = tf.summary.create_file_writer(log_dir)
        self.bvalues = bvalues

    def on_epoch_begin(self, epoch, logs=None):
        """Apply the scheduled clipping values for ``epoch``, if any."""
        if epoch in self.bvalues:
            self.change_renorm(epoch)

    def change_renorm(self, epoch):
        """Log the clipping dict scheduled for ``epoch`` and apply it to all BN layers."""
        print(f"Changing renorm clipping values epoch {epoch}", self.bvalues[epoch])
        renorm_clipping = self.bvalues[epoch]
        with self.writer.as_default():
            for key in self._LOGGED_KEYS:
                # Every clipping key is optional for Keras' BatchNormalization,
                # so skip missing keys instead of raising KeyError.
                if key in renorm_clipping:
                    tf.summary.scalar(f"renorm_{key}", renorm_clipping[key], epoch)
        # Flush so the new values appear in TensorBoard right away instead of
        # waiting for the writer's periodic flush.
        self.writer.flush()
        for layer in self._iter_batchnorm_layers(self.model):
            layer.renorm_clipping = renorm_clipping

    @classmethod
    def _iter_batchnorm_layers(cls, model):
        """Yield every BatchNormalization layer of ``model``, recursing into nested Models."""
        for layer in model.layers:
            # Descend into sub-models at any depth. The original code only
            # looked one level deep.
            if isinstance(layer, tf.keras.Model):
                yield from cls._iter_batchnorm_layers(layer)
            if isinstance(layer, BatchNormalization):
                yield layer
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment