Skip to content

Instantly share code, notes, and snippets.

@ij96
Created October 30, 2020 12:31
Show Gist options
  • Save ij96/981096a49fa214b784ba916e8741775d to your computer and use it in GitHub Desktop.
Save ij96/981096a49fa214b784ba916e8741775d to your computer and use it in GitHub Desktop.
def get_number_of_weights(model):
    '''
    Calculate the number of weights / params in a Keras model.

    Iterates over ``model.layers`` and sums the element count (product of
    the shape's dimensions) of every variable in each layer's ``weights``.
    A variable object that is referenced by more than one layer (e.g. tied /
    shared weights) is counted only once.

    Args:
        model: A Keras model -- anything exposing ``layers``, where each layer
            exposes ``weights`` and each weight exposes an iterable ``shape``.

    Returns:
        int: Total number of scalar parameters (0 for a model with no weights).

    Raises:
        TypeError: If a weight has an unknown (``None``) dimension.
    '''
    seen_ids = set()
    n_weight = 0
    for layer in model.layers:
        for w in layer.weights:
            # The same variable object may be attached to several layers;
            # count it only the first time it is encountered.
            if id(w) in seen_ids:
                continue
            seen_ids.add(id(w))
            n = 1
            for dim in w.shape:
                # int() normalises dimension objects (e.g. TF1 ``Dimension``,
                # NumPy integers) so the result is always a plain Python int.
                n *= int(dim)
            n_weight += n
    return n_weight
def get_number_of_trainable_weights(model):
    '''
    Calculate the number of trainable weights / params in a Keras model.

    Iterates over ``model.layers`` and sums the element count (product of
    the shape's dimensions) of every variable in each layer's
    ``trainable_weights``. A variable object that is referenced by more than
    one layer (e.g. tied / shared weights) is counted only once.

    Args:
        model: A Keras model -- anything exposing ``layers``, where each layer
            exposes ``trainable_weights`` and each weight exposes an iterable
            ``shape``.

    Returns:
        int: Total number of trainable scalar parameters (0 if there are none).

    Raises:
        TypeError: If a weight has an unknown (``None``) dimension.
    '''
    seen_ids = set()
    n_weight_trainable = 0
    for layer in model.layers:
        for w in layer.trainable_weights:
            # The same variable object may be attached to several layers;
            # count it only the first time it is encountered.
            if id(w) in seen_ids:
                continue
            seen_ids.add(id(w))
            n = 1
            for dim in w.shape:
                # int() normalises dimension objects (e.g. TF1 ``Dimension``,
                # NumPy integers) so the result is always a plain Python int.
                n *= int(dim)
            n_weight_trainable += n
    return n_weight_trainable
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment