Skip to content

Instantly share code, notes, and snippets.

@tanzhenyu
Created July 19, 2018 16:35
Show Gist options
  • Save tanzhenyu/48d705ba49acb0ff166da2db986745ef to your computer and use it in GitHub Desktop.
Save tanzhenyu/48d705ba49acb0ff166da2db986745ef to your computer and use it in GitHub Desktop.
nasnet model for model_to_estimator
# Download the training images from https://www.kaggle.com/c/dogs-vs-cats/data and sort them into /train/dogs and /train/cats.
import numpy as np
import tensorflow as tf
from keras.preprocessing.image import ImageDataGenerator
from keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img
from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D
from keras.layers import Activation, Dropout, Flatten, Dense
def test_nasnet(self):
    """Train a NASNetMobile cats-vs-dogs classifier through ``model_to_estimator``.

    Loads ``steps`` full batches of augmented images from ``train_data_dir``
    (expected layout: one sub-directory per class, e.g. ``/train/dogs`` and
    ``/train/cats``) into memory. It then wraps a freshly initialised
    (``weights=None``) NASNetMobile in an Estimator and alternates a 2-step
    ``train`` call with a 1-step ``evaluate`` call ``epochs`` times.

    Args:
      self: Unused; presumably kept so this can serve as a test-case method.

    Raises:
      ValueError: If ``train_data_dir`` contains fewer images than one batch,
        since no full batch could ever be produced.
    """
    img_width, img_height = 150, 150
    train_samples = 1600
    train_data_dir = '/train'
    epochs = 50
    batch_size = 16
    num_classes = 2
    # Number of batches held in memory; train_input_fn samples from these.
    steps = train_samples // batch_size
    # Channels-last layout is (rows, cols, channels) == (height, width, 3).
    input_shape = (img_height, img_width, 3)

    train_datagen = ImageDataGenerator(
        rescale=1. / 255,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True)
    train_generator = train_datagen.flow_from_directory(
        train_data_dir,
        target_size=(img_height, img_width),  # Keras expects (height, width).
        batch_size=batch_size,
        class_mode='binary')
    if train_generator.samples < batch_size:
        raise ValueError(
            'Need at least %d images under %r, found %d.'
            % (batch_size, train_data_dir, train_generator.samples))

    # float32 matches the model's dtype and halves memory compared with the
    # float64 default of np.zeros (~430 MB instead of ~860 MB here).
    train_data = np.zeros((steps,) + (batch_size,) + input_shape, np.float32)
    # One-hot targets as float32: categorical_crossentropy combines them with
    # float predictions, so an integer dtype does not mix.
    train_target = np.zeros((steps, batch_size, num_classes), np.float32)
    filled = 0
    while filled < steps:
        images, labels = next(train_generator)
        if len(images) != batch_size:
            # The iterator can emit a short final batch when the image count
            # is not a multiple of batch_size; skip it so every slot is full.
            continue
        train_data[filled] = images
        train_target[filled] = tf.keras.utils.to_categorical(
            labels, num_classes=num_classes)
        filled += 1

    def train_input_fn():
        # Estimator calls input_fn while building the graph for each
        # train()/evaluate() call, so one random batch is drawn per call.
        i = np.random.randint(len(train_data))
        return (train_data[i], train_target[i])

    model = tf.keras.applications.NASNetMobile(
        input_shape=input_shape, weights=None, classes=num_classes)
    model.compile(
        optimizer='sgd',
        loss='categorical_crossentropy',
        metrics=['categorical_accuracy'])

    # Write summaries, checkpoints and step-count logs on every step.
    config = tf.estimator.RunConfig(
        save_summary_steps=1, save_checkpoints_steps=1, log_step_count_steps=1)
    nasnet_estimator = tf.keras.estimator.model_to_estimator(
        keras_model=model,
        model_dir='/result',
        config=config)

    # Each round is 2 training steps plus a 1-step evaluation on a training
    # batch -- not a full pass over the data, despite the name `epochs`.
    for _ in range(epochs):
        nasnet_estimator.train(input_fn=train_input_fn, steps=2)
        nasnet_estimator.evaluate(input_fn=train_input_fn, steps=1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment