Last active
May 7, 2018 07:12
-
-
Save visionNoob/35edba42777a65c98c05be2bcd5b05e5 to your computer and use it in GitHub Desktop.
Model1
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
###########Import
import keras
from keras.callbacks import ModelCheckpoint
from keras import models
from keras import layers

# A bare `keras.__version__` expression is only displayed in a notebook cell;
# print it explicitly so the version is also reported when run as a script.
print('keras version:', keras.__version__)
###########Load MNIST
from keras.datasets import mnist
from keras.utils import to_categorical

(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Flatten each 28x28 image into a 784-element vector (the Dense layers below
# expect input_shape=(28 * 28,)) and scale the pixel values by 1/255.
# The sample count comes from len(...) instead of the hard-coded
# 60000 / 10000, so this keeps working on any subset of the data.
train_images = train_images.reshape((len(train_images), 28 * 28))
train_images = train_images.astype('float32') / 255
test_images = test_images.reshape((len(test_images), 28 * 28))
test_images = test_images.astype('float32') / 255

# One-hot encode the integer labels for categorical_crossentropy. Pinning
# num_classes=10 keeps the encoding width fixed at the 10 outputs of the
# softmax layers, instead of inferring it from the largest label present.
train_labels = to_categorical(train_labels, num_classes=10)
test_labels = to_categorical(test_labels, num_classes=10)
###########Network1
# Baseline model: one Dense softmax layer mapping the 784 flattened pixels
# directly to 10 class probabilities (no hidden layers).
network1 = models.Sequential()
network1.add(layers.Dense(10, activation='softmax', input_shape=(28 * 28,)))
network1.compile(optimizer='rmsprop',
                 loss='categorical_crossentropy',
                 metrics=['accuracy'])
###########Network2
# Deeper model: two 512-unit ReLU hidden layers followed by a 10-way softmax.
network2 = models.Sequential()
# Only the first layer of a Sequential model needs input_shape; later layers
# infer their input size from the previous layer's output. The original passed
# a redundant input_shape=(512,) to the second layer, which is ignored there.
network2.add(layers.Dense(512, activation='relu', input_shape=(28 * 28,)))
network2.add(layers.Dense(512, activation='relu'))
network2.add(layers.Dense(10, activation='softmax'))
network2.compile(optimizer='rmsprop',
                 loss='categorical_crossentropy',
                 metrics=['accuracy'])
###########Fit both
# Train both networks with identical settings. validation_split=0.1 holds out
# the last 10% of the training data for per-epoch validation metrics.
# Each fit() returns a History object. Keep them in separate variables:
# previously network2's history was also assigned to result1, silently
# discarding network1's training curves.
result1 = network1.fit(train_images, train_labels,
                       epochs=100, batch_size=1024, validation_split=0.1)
result2 = network2.fit(train_images, train_labels,
                       epochs=100, batch_size=1024, validation_split=0.1)
###########Test
# Evaluate each trained network on the held-out test set. Because both were
# compiled with metrics=['accuracy'], evaluate() returns [loss, accuracy].
n1_test_loss, n1_test_acc = network1.evaluate(test_images, test_labels)
print('n1_test_acc:', n1_test_acc)
print('n1_test_loss:', n1_test_loss)

n2_test_loss, n2_test_acc = network2.evaluate(test_images, test_labels)
print('n2_test_acc:', n2_test_acc)
print('n2_test_loss:', n2_test_loss)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment