-
-
Save Seanny123/0419e916f7f50cd77811c2b556d5ddeb to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from keras.layers import Input, Masking, LSTM, Dense
from keras.models import Model
import numpy as np
# Case 1: model with return_sequences=True (output_shape = (1, 10, 1))
##############################################################
# Fixed batch shape: 1 sample per batch, 10 timesteps, 16 features.
input1 = Input(batch_shape=(1, 10, 16))
# Timesteps whose features all equal the mask value (2.) are marked as masked.
mask1 = Masking(mask_value=2.)(input1)
# One 16-unit output per timestep.
lstm1 = LSTM(16, return_sequences=True)(mask1)
dense_layer = Dense(1, activation='sigmoid')
# Flag the Dense layer as mask-aware so it can be stacked on a masked sequence.
dense_layer.supports_masking = True
dense1 = dense_layer(lstm1)
model1 = Model(input1, dense1)
model1.compile(optimizer='adam', loss='binary_crossentropy')
# Case 2: model with return_sequences=False (output_shape = (1, 1))
###############################################################
# Reuses the Case 1 input/masking tensors; only the LSTM differs: it emits a
# single 16-unit vector for the whole sequence instead of one per timestep.
lstm2 = LSTM(16, return_sequences=False)(mask1)
# Same Dense layer object as Case 1 is applied here.
dense2 = dense_layer(lstm2)
model2 = Model(input1, dense2)
model2.compile(optimizer='adam', loss='binary_crossentropy')
# initialize train data and labels
###############################################################
# Two classes of sequences (10 timesteps x 16 features each):
#   3 all-zero sequences labelled 1, and 2 all-one sequences labelled 0.
data = np.zeros((3, 10, 16))
data2 = np.ones((2, 10, 16))
# Per-timestep labels for the return_sequences=True network.
labels_net1 = np.ones((3, 10, 1))
labels2_net1 = np.zeros((2, 10, 1))
# Per-sequence labels for the return_sequences=False network.
labels_net2 = np.ones((3, 1))
labels2_net2 = np.zeros((2, 1))
train_data1 = np.concatenate([data, data2], axis=0)
train_labels1 = np.concatenate([labels_net1, labels2_net1], axis=0)
train_labels2 = np.concatenate([labels_net2, labels2_net2], axis=0)
# add 'masked' data to train_data
################################################################
# Work on a copy so train_data1 stays unmasked for the later retraining run.
masked_train_data = np.copy(train_data1)
# Overwrite one timestep (sample 1, step 1) with the mask value 2.
masked_train_data[1, 1, :] = 2
# train models
#################################################################
# NOTE(review): `nb_epoch` is the Keras 1.x keyword; Keras 2 renamed it to
# `epochs` -- confirm against the installed Keras version.
model1.fit(masked_train_data, train_labels1, nb_epoch=1000, batch_size=1)
model2.fit(masked_train_data, train_labels2, nb_epoch=1000, batch_size=1)
# Want to retrain the first network without masked data.
# NOTE(review): model3 is built from the same input1/dense1 tensors as model1,
# so it presumably reuses model1's already-trained layers instead of starting
# from fresh weights -- confirm that this is the intended comparison.
model3 = Model(input1, dense1)
model3.compile(optimizer='adam', loss='binary_crossentropy')
model3.fit(train_data1, train_labels1, nb_epoch=1000, batch_size=1)
# create test data
##################################################################
# One sequence per class: all ones (trained label 0) and all zeros (trained label 1).
test_data1 = np.ones((1, 10, 16))
test_data2 = np.zeros((1, 10, 16))
# add 'mask' to test data
# Timestep index 3 (the fourth step) is set to the mask value 2 in both sequences.
test_data1[0, 3, :] = 2
test_data2[0, 3, :] = 2
# predictions
##################################################################
# Run both test sequences through the per-timestep (model1) and
# per-sequence (model2) networks; model3 is only checked on test_data1.
model1_predictions1 = model1.predict(test_data1)
model1_predictions2 = model1.predict(test_data2)
model2_predictions1 = model2.predict(test_data1)
model2_predictions2 = model2.predict(test_data2)
model3_predictions1 = model3.predict(test_data1)
# Print in the same order as the recorded printouts in this file.
print(model1_predictions1)
print(model1_predictions2)
print(model2_predictions1)
print(model2_predictions2)
print(model3_predictions1)
# Glorious printouts
#####################################################################
# model1_predictions1, y_true = [0., 0., ..., 0.]
#[[[ 2.14141060e-08]
# [ 7.29542982e-10]
# [ 3.53702262e-10]
# [ 3.53702262e-10] <-- this is the masked line, output is same as previous
# [ 2.82663781e-10]
# [ 2.61021177e-10]
# [ 2.52067062e-10]
# [ 2.47708437e-10]
# [ 2.45296505e-10]
# [ 2.43848081e-10]]]
# model1_predictions2, y_true = [1., 1., ..., 1.]
#[[[ 0.99999619]
# [ 1. ]
# [ 1. ]
# [ 1. ] <-- masked line
# [ 1. ]
# [ 1. ]
# [ 1. ]
# [ 1. ]
# [ 1. ]
# [ 1. ]]]
# model2_predictions1, y_true = 0.
#[[ 1.09495701e-08]] <-- runs data with mask
# model2_predictions2, y_true = 1.
#[[ 1.]] <-- runs with mask
# model3_predictions1, y_true = [0., 0., ..., 0.]
#[[[ 2.14141060e-08]
# [ 7.29542982e-10]
# [ 3.53702262e-10]
# [ 3.53702262e-10] <-- this is the masked line, output is same as previous
# [ 2.82663781e-10]
# [ 2.61021177e-10]
# [ 2.52067062e-10]
# [ 2.47708437e-10]
# [ 2.45296505e-10]
# [ 2.43848081e-10]]]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment