Created
May 27, 2016 08:59
-
-
Save asmith26/d339c542cf3c55ecdc4eef5ab08b2edd to your computer and use it in GitHub Desktop.
model_from_json fails with "TypeError: arg 5 (closure) must be tuple"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
import numpy as np | |
np.random.seed(2 ** 10) | |
# Prevent reaching to maximum recursion depth in `theano.tensor.grad` | |
# import sys | |
# sys.setrecursionlimit(2 ** 20) | |
from six.moves import range | |
from keras.datasets import cifar10 | |
from keras.layers import Input, Dense, Layer, merge, Activation, Flatten, Lambda | |
from keras.layers.convolutional import Convolution2D, AveragePooling2D | |
from keras.layers.normalization import BatchNormalization | |
from keras.models import Model | |
from keras.optimizers import SGD | |
from keras.regularizers import l2 | |
from keras.callbacks import Callback, LearningRateScheduler | |
from keras.preprocessing.image import ImageDataGenerator | |
from keras.utils import np_utils | |
import keras.backend as K | |
# --- Training hyper-parameters -------------------------------------------
batch_size, nb_classes, nb_epoch = 64, 10, 500
# Residual blocks per stage (each of the three stages is built from N blocks).
N = 18
# L2 penalty factor passed to l2() for the conv and dense kernels.
weight_decay = 1e-4
# Learning-rate drop points, expressed as fractions of nb_epoch.
lr_schedule = [0.5, 0.75]

# --- Stochastic depth ----------------------------------------------------
# "lin_decay": death rate grows linearly with block depth up to death_rate;
# "uniform": every block uses death_rate.
death_mode = "lin_decay"
death_rate = 0.5

# --- Input geometry (channels-first) -------------------------------------
img_rows = img_cols = 32
img_channels = 3

# One {"death_rate": ..., "gate": ...} entry per residual block, filled in
# by residual_drop() as the network is built.
add_tables = []
# Stem: a 16-filter 3x3 convolution, then batch norm over the channel axis
# (axis 1, since inputs are channels-first), then ReLU.
inputs = Input(shape=(img_channels, img_rows, img_cols))
net = Convolution2D(
    16, 3, 3,
    W_regularizer=l2(weight_decay),
    border_mode="same",
)(inputs)
net = BatchNormalization(axis=1)(net)
net = Activation("relu")(net)
def residual_drop(x, input_shape, output_shape, strides=(1, 1)):
    """Build one stochastic-depth residual block on top of Keras tensor ``x``.

    Residual branch: conv3x3 -> BN -> ReLU -> conv3x3 -> BN.
    Shortcut: ``x``, average-pooled when ``strides[0] >= 2`` and zero-padded
    along the channel axis (axis 1) when the block widens.
    In the test phase the branch output is multiplied by ``1 - death_rate``.
    It is then summed with the shortcut and passed through a ReLU. A per-block
    ``gate`` variable finally selects either that result (gate == 1) or the
    bare shortcut (gate == 0).

    Args:
        x: input Keras tensor.
        input_shape: ``(channels, rows, cols)`` of ``x`` before any pooling.
        output_shape: ``(channels, rows, cols)`` of the block output.
        strides: subsampling of the first convolution. A value ``>= 2`` also
            average-pools the shortcut.

    Returns:
        The block's output Keras tensor.

    Side effects:
        Appends ``{"death_rate": <variable>, "gate": <variable>}`` to the
        module-level ``add_tables`` list.

    NOTE(review): the Lambda functions below still close over per-block
    backend variables (``_death_rate``, ``gate``). If the installed Keras
    serializes Lambda functions as bare code objects, those closures cannot
    be rebuilt on load. That is presumably the source of the
    "arg 5 (closure) must be tuple" error this script reproduces. Restoring
    the architecture from JSON/YAML likely needs a custom Layer instead.
    Confirm against the Keras version in use.
    """
    nb_filter = output_shape[0]

    conv = Convolution2D(nb_filter, 3, 3, subsample=strides,
                         border_mode="same", W_regularizer=l2(weight_decay))(x)
    conv = BatchNormalization(axis=1)(conv)
    conv = Activation("relu")(conv)
    conv = Convolution2D(nb_filter, 3, 3,
                         border_mode="same", W_regularizer=l2(weight_decay))(conv)
    conv = BatchNormalization(axis=1)(conv)

    if strides[0] >= 2:
        x = AveragePooling2D(strides)(x)

    extra_channels = output_shape[0] - input_shape[0]
    if extra_channels > 0:
        # Number of zero copies of the input needed to cover extra_channels
        # (ceil division). When the block exactly doubles its width this is 1.
        copies = -(-extra_channels // input_shape[0])

        def pad_channels(y):
            # The zeros are derived from the Lambda's own input ``y``, so batch
            # and spatial sizes always match it. This avoids capturing the
            # outer graph tensor and a symbolic repeat count.
            zeros = K.zeros_like(y)
            if copies > 1:
                zeros = K.concatenate([zeros] * copies, axis=1)
            return K.concatenate([y, zeros[:, :extra_channels]], axis=1)

        x = Lambda(pad_channels, output_shape=output_shape)(x)

    _death_rate = K.variable(death_rate)
    # Test phase: scale the residual branch by (1 - death rate). The scale is
    # built from the Lambda's input ``c`` rather than the outer ``conv`` tensor.
    conv = Lambda(lambda c: K.in_test_phase((K.ones_like(c) - _death_rate) * c, c),
                  output_shape=output_shape)(conv)
    out = merge([conv, x], mode="sum")
    out = Activation("relu")(out)

    # gate == 1 keeps the residual block; gate == 0 passes the shortcut through.
    gate = K.variable(1, dtype="uint8")
    # Mutating in place, so no ``global`` declaration is needed.
    add_tables.append({"death_rate": _death_rate, "gate": gate})
    return Lambda(lambda tensors: K.switch(gate, tensors[0], tensors[1]),
                  output_shape=output_shape)([out, x])
# Three stages of N residual blocks each. The first block of stages 2 and 3
# halves the spatial size (strides=(2, 2)) and doubles the channel count.
# Each row is (input_shape, output_shape, strides, number of blocks); rows
# are applied in order, so blocks are registered in add_tables shallow-to-deep.
for in_shape, out_shape, stride, repeats in (
    ((16, 32, 32), (16, 32, 32), (1, 1), N),
    ((16, 32, 32), (32, 16, 16), (2, 2), 1),
    ((32, 16, 16), (32, 16, 16), (1, 1), N - 1),
    ((32, 16, 16), (64, 8, 8), (2, 2), 1),
    ((64, 8, 8), (64, 8, 8), (1, 1), N - 1),
):
    for i in range(repeats):
        net = residual_drop(net, input_shape=in_shape,
                            output_shape=out_shape, strides=stride)
# Classifier head: global average pool over the final 8x8 feature maps
# (the last stage outputs (64, 8, 8)), flatten, then a softmax over the classes.
pool = AveragePooling2D((8, 8))(net)
flatten = Flatten()(pool)
# Use the configured class count instead of a duplicated literal 10.
predictions = Dense(nb_classes, activation="softmax",
                    W_regularizer=l2(weight_decay))(flatten)
model = Model(input=inputs, output=predictions)

# SGD with Nesterov momentum; 0.1 matches the first value returned by schedule().
sgd = SGD(lr=0.1, momentum=0.9, nesterov=True)
model.compile(optimizer=sgd, loss="categorical_crossentropy")
def open_all_gates():
    """Reopen every residual block and (re)assign each block's death rate.

    Sets every ``gate`` in ``add_tables`` to 1. It then sets each block's
    ``death_rate`` variable according to ``death_mode``:
      * ``"uniform"``   -- every block gets ``death_rate``;
      * ``"lin_decay"`` -- block ``i`` (1-based, in ``add_tables`` order) gets
        ``i / len(add_tables) * death_rate``, so the last block gets the full
        ``death_rate``.

    Raises:
        ValueError: if ``death_mode`` is neither of the above (checked only
            when ``add_tables`` is non-empty).
    """
    for table in add_tables:
        K.set_value(table["gate"], 1)

    n_tables = len(add_tables)
    for i, table in enumerate(add_tables, start=1):
        if death_mode == "uniform":
            rate = death_rate
        elif death_mode == "lin_decay":
            rate = i / n_tables * death_rate
        else:
            # A bare ``raise`` here had no active exception to re-raise and
            # surfaced as an unrelated RuntimeError/TypeError.
            raise ValueError(
                "Unknown death_mode %r; expected 'uniform' or 'lin_decay'"
                % (death_mode,))
        K.set_value(table["death_rate"], rate)
class GatesUpdate(Callback):
    """Keras callback that samples which residual blocks are dropped per batch.

    Before each batch, all gates are reopened and death rates reset via
    ``open_all_gates()``. Each block's gate is then closed independently with
    probability equal to its death rate (a uniform [0, 1) draw below the
    rate). After each batch all gates are reopened, so anything evaluated
    between batches sees the full network.
    """

    def on_batch_begin(self, batch, logs=None):
        open_all_gates()
        draws = np.random.uniform(size=len(add_tables))
        for table, draw in zip(add_tables, draws):
            if draw < K.get_value(table["death_rate"]):
                K.set_value(table["gate"], 0)

    def on_batch_end(self, batch, logs=None):
        open_all_gates()  # reopen all blocks for validation
def schedule(epoch_idx):
    """Return the learning rate for the 0-based epoch ``epoch_idx``.

    The rate is 0.1 until ``lr_schedule[0] * nb_epoch`` epochs have been
    reached, then 0.01 until ``lr_schedule[1] * nb_epoch``, then 0.001.
    """
    epochs_done = epoch_idx + 1
    if epochs_done < nb_epoch * lr_schedule[0]:
        return 0.1
    second_drop = nb_epoch * lr_schedule[1]
    return 0.01 if epochs_done < second_drop else 0.001
# Serialize the architecture before opening the output file. If to_yaml()
# raises, an existing model.yaml is not left truncated to an empty file.
# Alternative: model.to_json() for a JSON description.
model_yaml = model.to_yaml()
with open('model.yaml', 'w') as f:
    f.write(model_yaml)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment