# Gist by @karino2, last active July 8, 2019 15:06.
# Imports assumed for this snippet (TensorFlow 1.x with tf.keras); the original gist omits them.
import tensorflow as tf
from tensorflow.keras import regularizers
from tensorflow.keras.layers import Activation, Conv1D, Dense, Dropout, GlobalMaxPooling1D, GRU, MaxPooling1D
DROPOUT_RATE = 0.5
L2_REGULARIZATION_RATE = 0.1
EXTRACTED_FETURE_DIM = 256
FEATURE_EXTRACTER_KERNEL_SIZE = 7
GRU_HIDDEN = 256
# VOCAB_SIZE (used in create_model below) is assumed to be defined elsewhere in the project.
# Note: tpu_model_fn also has to be modified so that the batch-norm moving averages
# are updated during training (see the sketch after create_model below).
def fe_conv1d(filternum, kernelsize, x):
    """1-D convolution with L2 regularization on kernel, bias and activations."""
    return Conv1D(filternum, kernelsize,
                  kernel_regularizer=regularizers.l2(L2_REGULARIZATION_RATE),
                  bias_regularizer=regularizers.l2(L2_REGULARIZATION_RATE),
                  activity_regularizer=regularizers.l2(L2_REGULARIZATION_RATE))(x)
def feature_extractor(input_stroke_t, is_training):
    """input_stroke_t shape (batch, MAX_STROKE_NUM, MAX_ONE_STROKE_LEN, INPUT_TYPE_DIM)"""
    with tf.variable_scope("feature_extractor"):
        inpshape = input_stroke_t.shape
        x = tf.reshape(input_stroke_t, [-1, inpshape[2], inpshape[3]])
        # (batch*MAX_STROKE_NUM, MAX_ONE_STROKE_LEN, INPUT_TYPE_DIM)
        x = fe_conv1d(32, FEATURE_EXTRACTER_KERNEL_SIZE, x)
        x = tf.layers.BatchNormalization()(x, training=is_training)
        x = Activation('relu')(x)
        # (batch*MAX_STROKE_NUM, MAX_ONE_STROKE_LEN, 32)
        x = MaxPooling1D(pool_size=2)(x)
        x = Dropout(DROPOUT_RATE)(x, training=is_training)
        # (batch*MAX_STROKE_NUM, MAX_ONE_STROKE_LEN/2, 32)
        x = fe_conv1d(64, FEATURE_EXTRACTER_KERNEL_SIZE, x)
        x = tf.layers.BatchNormalization()(x, training=is_training)
        x = Activation('relu')(x)
        x = MaxPooling1D(pool_size=2)(x)
        x = Dropout(DROPOUT_RATE)(x, training=is_training)
        # (batch*MAX_STROKE_NUM, MAX_ONE_STROKE_LEN/4, 64)
        x = fe_conv1d(EXTRACTED_FETURE_DIM, FEATURE_EXTRACTER_KERNEL_SIZE, x)
        x = tf.layers.BatchNormalization()(x, training=is_training)
        x = Activation('relu')(x)
        x = Dropout(DROPOUT_RATE)(x, training=is_training)
        x = GlobalMaxPooling1D()(x)
        x = tf.reshape(x, [-1, inpshape[1], EXTRACTED_FETURE_DIM])
        return x
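# --- Added sketch (not in the original gist) --------------------------------
# A quick shape check for feature_extractor: each input sample goes in as
# (batch, MAX_STROKE_NUM, MAX_ONE_STROKE_LEN, INPUT_TYPE_DIM) and comes out as
# one EXTRACTED_FETURE_DIM vector per stroke.  The three sizes below are
# placeholder values chosen only for illustration; the real ones are defined
# elsewhere in the project.
def _demo_feature_extractor_shapes():
    max_stroke_num = 16       # assumed value
    max_one_stroke_len = 64   # assumed value
    input_type_dim = 3        # assumed value
    strokes = tf.placeholder(
        tf.float32, [None, max_stroke_num, max_one_stroke_len, input_type_dim])
    feats = feature_extractor(strokes, is_training=False)
    print(feats.shape)  # expected: (?, 16, 256)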
def create_model(input_stroke_t, is_training):
    features = feature_extractor(input_stroke_t, is_training)
    last = GRU(GRU_HIDDEN, dropout=DROPOUT_RATE, recurrent_dropout=DROPOUT_RATE,
               kernel_regularizer=regularizers.l2(L2_REGULARIZATION_RATE),
               bias_regularizer=regularizers.l2(L2_REGULARIZATION_RATE),
               activity_regularizer=regularizers.l2(L2_REGULARIZATION_RATE),
               recurrent_regularizer=regularizers.l2(L2_REGULARIZATION_RATE))(features, training=is_training)
    if L2_REGULARIZATION_RATE == 0.0:
        logit = Dense(VOCAB_SIZE)(last)
    else:
        logit = Dense(VOCAB_SIZE,
                      kernel_regularizer=regularizers.l2(L2_REGULARIZATION_RATE),
                      bias_regularizer=regularizers.l2(L2_REGULARIZATION_RATE),
                      activity_regularizer=regularizers.l2(L2_REGULARIZATION_RATE))(last)
    return logit
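# --- Added sketch (not in the original gist) --------------------------------
# The note near the top says tpu_model_fn also has to be modified to update the
# batch-norm moving averages.  In TF 1.x, tf.layers.BatchNormalization only
# records those updates in the tf.GraphKeys.UPDATE_OPS collection, so the
# training op has to run them explicitly.  The model_fn below is a minimal
# illustration of that wiring with a plain Estimator, an assumed Adam optimizer
# and an assumed sparse cross-entropy loss; it handles only TRAIN/EVAL and
# omits regularization-loss handling.  The actual tpu_model_fn (not shown in
# this gist) would additionally wrap the optimizer for TPU training.
def example_model_fn(features, labels, mode, params):
    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
    logits = create_model(features, is_training)
    loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits))
    if is_training:
        optimizer = tf.train.AdamOptimizer()  # assumed optimizer
        # Run the batch-norm moving-average updates together with the train step.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
        return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
    return tf.estimator.EstimatorSpec(mode, loss=loss)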