Code for the MNIST chicken experiments. Note that this code was written a long time ago and is no longer maintained.
Last active
September 12, 2020 18:44
-
-
Save EmilienDupont/99c7127dedb921a5a1f96d37d23c0d4b to your computer and use it in GitHub Desktop.
MNIST chicken code
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from torch.utils.data import DataLoader | |
from torchvision import datasets, transforms | |
def get_mnist_dataloaders(batch_size=128, data_dir='../ml-sandbox/data'):
    """MNIST dataloaders yielding (32, 32) images.

    Images are resized with ``transforms.Resize(32)`` (smaller edge -> 32
    pixels, i.e. 32x32 for square images) and converted to tensors with
    ``transforms.ToTensor()``.

    Args:
        batch_size: Number of samples per batch for both loaders.
        data_dir: Directory the dataset is read from / downloaded into.
            Defaults to the path this function has always used.

    Returns:
        A ``(train_loader, test_loader)`` tuple of ``DataLoader`` objects,
        both created with ``shuffle=True``.
    """
    all_transforms = transforms.Compose([
        transforms.Resize(32),
        transforms.ToTensor(),
    ])
    train_data = datasets.MNIST(data_dir, train=True, download=True,
                                transform=all_transforms)
    # download=True on the test split as well, so building it never depends
    # on the train split having fetched the files first.
    test_data = datasets.MNIST(data_dir, train=False, download=True,
                               transform=all_transforms)
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=True)
    return train_loader, test_loader
def get_fashion_mnist_dataloaders(batch_size=128,
                                  data_dir='../ml-sandbox/fashion_data'):
    """FashionMNIST dataloaders yielding (32, 32) images.

    Images are resized with ``transforms.Resize(32)`` (smaller edge -> 32
    pixels, i.e. 32x32 for square images) and converted to tensors with
    ``transforms.ToTensor()``.

    Args:
        batch_size: Number of samples per batch for both loaders.
        data_dir: Directory the dataset is read from / downloaded into.
            Defaults to the path this function has always used.

    Returns:
        A ``(train_loader, test_loader)`` tuple of ``DataLoader`` objects,
        both created with ``shuffle=True``.
    """
    all_transforms = transforms.Compose([
        transforms.Resize(32),
        transforms.ToTensor(),
    ])
    train_data = datasets.FashionMNIST(data_dir, train=True, download=True,
                                       transform=all_transforms)
    # download=True on the test split as well, so building it never depends
    # on the train split having fetched the files first.
    test_data = datasets.FashionMNIST(data_dir, train=False, download=True,
                                      transform=all_transforms)
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=True)
    return train_loader, test_loader
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch.nn as nn | |
class CNN(nn.Module):
    """Small convolutional classifier that outputs class probabilities.

    Three convolutions with kernel 4, stride 2 and padding 1 each halve the
    spatial size (32 -> 16 -> 8 -> 4), so the final linear layer expects a
    64-channel 4x4 feature map (``64 * 4 * 4`` inputs). Inputs must
    therefore be ``(batch, in_channels, 32, 32)``.

    The output has already been passed through a softmax, so it holds
    probabilities, not logits. ``nn.CrossEntropyLoss`` applies its own
    log-softmax; give it ``log(probs)`` rather than ``probs``.

    Args:
        in_channels: Number of channels in the input images (default 1).
        num_classes: Number of output classes (default 10).
    """

    def __init__(self, in_channels=1, num_classes=10):
        super(CNN, self).__init__()
        self.img_to_features = nn.Sequential(
            nn.Conv2d(in_channels, 16, (4, 4), stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, (4, 4), stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, (4, 4), stride=2, padding=1),
            nn.ReLU(),
        )
        self.features_to_probs = nn.Sequential(
            nn.Linear(64 * 4 * 4, num_classes),
            # Normalize over the class axis of the (batch, classes) scores.
            # Without an explicit dim, nn.Softmax emits a deprecation warning.
            nn.Softmax(dim=1),
        )

    def forward(self, x):
        """Map an image batch to probabilities of shape (batch, num_classes)."""
        features = self.img_to_features(x)
        # Flatten each sample's feature map for the linear layer.
        flat = features.view(features.size(0), -1)
        return self.features_to_probs(flat)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
from dataloaders import get_mnist_dataloaders | |
from models import CNN | |
from torch.autograd import Variable | |
def eval_model(model, data):
    """Return the classification accuracy of ``model`` over ``data``.

    The predicted class for each sample is the argmax of the model output
    along dim 1. Inputs are moved to the device of the model's first
    parameter. Evaluation runs in ``eval()`` mode without gradient tracking,
    and the model's previous train/eval mode is restored afterwards.

    Args:
        model: Module mapping an image batch to per-class scores, presumably
            of shape (batch, num_classes).
        data: Iterable of ``(imgs, labels)`` batches, e.g. a DataLoader.

    Returns:
        Fraction of correctly classified samples as a float, or ``nan`` if
        ``data`` yields no samples.
    """
    device = next(model.parameters()).device
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    try:
        # No autograd graphs are needed for evaluation; this saves memory.
        with torch.no_grad():
            for imgs, labels in data:
                probs = model(imgs.to(device))
                _, predicted = torch.max(probs, 1)
                total += labels.size(0)
                # Compare on the CPU, where the labels live.
                correct += int((predicted.cpu() == labels).sum())
    finally:
        model.train(was_training)
    if total == 0:
        return float('nan')
    return float(correct) / total
def train_epoch(model, data, opt=None, loss_fn=None, log_every=100):
    """Train ``model`` for one pass over ``data`` and return the mean loss.

    Args:
        model: Module mapping an image batch to per-class probabilities
            (rows summing to 1), e.g. a network whose last layer is a softmax.
        data: Iterable of ``(imgs, labels)`` batches, e.g. a DataLoader.
        opt: Optimizer over ``model``'s parameters. Defaults to the
            module-level ``optimizer``.
        loss_fn: Loss called as ``loss_fn(log_probs, labels)``. Defaults to
            the module-level ``criterion``. ``nn.CrossEntropyLoss`` and
            ``nn.NLLLoss`` both yield the negative log-likelihood here.
        log_every: Print the current loss every ``log_every`` iterations
            (positive int).

    Returns:
        Mean per-batch loss over the epoch as a float, or ``nan`` if
        ``data`` is empty.
    """
    if opt is None:
        opt = optimizer
    if loss_fn is None:
        loss_fn = criterion
    device = next(model.parameters()).device
    model.train()
    total_loss = 0.0
    num_batches = 0
    for i, (imgs, labels) in enumerate(data):
        imgs = imgs.to(device)
        labels = labels.to(device)
        opt.zero_grad()
        probs = model(imgs)
        # The model already applies softmax. nn.CrossEntropyLoss applies
        # log_softmax internally, so feeding it raw probabilities applies
        # softmax twice, which flattens the loss and its gradients. Pass
        # log-probabilities instead: because the probs sum to 1,
        # log_softmax(log p) == log p, so the loss is the true negative
        # log-likelihood. The clamp keeps log() finite if a probability
        # underflows to zero.
        log_probs = probs.clamp(min=1e-12).log()
        loss = loss_fn(log_probs, labels)
        loss.backward()
        opt.step()
        # .item() replaces loss.data[0], which raises on 0-dim tensors in
        # current PyTorch.
        total_loss += loss.item()
        num_batches += 1
        if (i + 1) % log_every == 0:
            print("Iteration: {}, Loss: {}".format(i + 1, loss.item()))
    if num_batches == 0:
        return float('nan')
    return total_loss / num_batches
if __name__ == '__main__':
    # Only train when run as a script, not when this module is imported.
    # Names assigned inside this guard are still module globals, which
    # train_epoch reads (criterion, optimizer).

    # Get datasets
    train_loader, test_loader = get_mnist_dataloaders(batch_size=100)
    # Create model; place it on the GPU when available, otherwise the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    cnn = CNN()
    cnn.to(device)
    # Loss and Optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(cnn.parameters(), lr=1e-3)
    # Train the Model
    for epoch in range(10):
        train_epoch(cnn, train_loader)
        test_acc = eval_model(cnn, test_loader)
        print("Epoch {}, Test Accuracy: {}\n".format(epoch + 1, test_acc))
    # Save the Trained Model
    torch.save(cnn.state_dict(), 'mnist_cnn.pt')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment