Skip to content

Instantly share code, notes, and snippets.

@deli4iled
Created May 11, 2020 13:02
Show Gist options
  • Save deli4iled/0e3d913786966f11d84cedcfe9431d85 to your computer and use it in GitHub Desktop.
Save deli4iled/0e3d913786966f11d84cedcfe9431d85 to your computer and use it in GitHub Desktop.
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
# Run on the GPU when one is available; otherwise fall back to the CPU so the
# code still works on machines without CUDA (a hard-coded 'cuda' fails there).
device = 'cuda' if torch.cuda.is_available() else 'cpu'
class SimpleCNN(nn.Module):
    """Small two-convolution image classifier.

    Layers: conv3x3(32) -> ReLU -> conv3x3(32) -> ReLU -> 2x2 max-pool
    -> dropout(0.25) -> linear(128) -> ReLU -> dropout(0.5)
    -> linear(num_classes).

    ``fc1`` is sized for 28x28 inputs: both convolutions keep the spatial size
    (padding=1) and the pool halves it to 14x14 with 32 channels. ``forward``
    therefore expects 28x28 images, e.g. a ``(batch, num_channels, 28, 28)``
    tensor.

    Args:
        num_channels: number of input image channels.
        num_classes: number of output logits per sample.
    """

    def __init__(self, num_channels=1, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(num_channels, 32, 3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 32, 3, stride=1, padding=1)
        self.pool1 = nn.MaxPool2d(2)
        self.drop1 = nn.Dropout(0.25)
        self.fc1 = nn.Linear(14 * 14 * 32, 128)
        self.drop2 = nn.Dropout(0.5)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, X):
        """Return raw class scores (logits), one row of ``num_classes`` per sample.

        Raises:
            ValueError: if the pooled feature map does not have the
                ``14*14*32`` elements per sample that ``fc1`` expects
                (i.e. the input images are not 28x28).
        """
        X = F.relu(self.conv1(X))
        X = F.relu(self.conv2(X))
        X = self.pool1(X)
        X = self.drop1(X)
        # Check the per-sample feature count explicitly: reshape(-1, n) on its
        # own would silently fold a wrongly sized image into extra "samples"
        # whenever the total element count happens to divide evenly.
        n_features = self.fc1.in_features
        if X.shape[-3] * X.shape[-2] * X.shape[-1] != n_features:
            raise ValueError(
                'expected {} features per sample after pooling, got shape {}'
                .format(n_features, tuple(X.shape[-3:]))
            )
        X = X.reshape(-1, n_features)
        X = F.relu(self.fc1(X))
        X = self.drop2(X)
        X = self.fc2(X)
        return X  # logits
def save_checkpoint(optimizer, model, epoch, filename):
    """Write the current training state to *filename* via ``torch.save``.

    The stored dict has the keys ``'optimizer'`` and ``'model'`` (each the
    object's ``state_dict()``) and ``'epoch'`` (the *epoch* argument as given),
    which are the keys ``load_checkpoint`` reads back.
    """
    state = dict(
        optimizer=optimizer.state_dict(),
        model=model.state_dict(),
        epoch=epoch,
    )
    torch.save(state, filename)
def load_checkpoint(optimizer, model, filename, map_location='cpu'):
    """Restore model (and optionally optimizer) state from a checkpoint file.

    Args:
        optimizer: optimizer whose state is restored, or ``None`` to restore
            only the model.
        model: module whose ``state_dict`` is replaced by the checkpoint's
            ``'model'`` entry.
        filename: path to a ``torch.save``-d dict with ``'model'``,
            ``'optimizer'`` and ``'epoch'`` keys.
        map_location: forwarded to ``torch.load``. The ``'cpu'`` default lets
            a checkpoint written on a GPU be opened on a CPU-only machine;
            ``Module.load_state_dict`` copies the values into the model's
            existing parameters, and ``Optimizer.load_state_dict`` casts its
            state to the parameters' device.

    Returns:
        The epoch number stored in the checkpoint.
    """
    checkpoint_dict = torch.load(filename, map_location=map_location)
    # Read the epoch first so a malformed file fails before any state changes.
    epoch = checkpoint_dict['epoch']
    model.load_state_dict(checkpoint_dict['model'])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint_dict['optimizer'])
    return epoch
# Create the directory that train() writes its per-epoch checkpoints into.
# os.makedirs replaces the IPython-only `!mkdir -p` shell escape, which is a
# syntax error when this code is run as a plain Python script.
os.makedirs('checkpoints', exist_ok=True)
def train(optimizer, model, num_epochs=10, first_epoch=1):
    """Train *model* with cross-entropy, validating and checkpointing each epoch.

    Uses module-level names from elsewhere in this file/notebook:
    ``train_set``, ``train_loader``, ``valid_loader``, ``device``,
    ``save_checkpoint`` and the helpers ``ProgressMonitor``, ``MovingAverage``
    and ``RunningAverage``. The three helpers are not defined in the code shown
    here; presumably they come from another cell or module.

    Args:
        optimizer: optimizer stepping ``model``'s parameters.
        model: network returning class logits.
        num_epochs: number of epochs to run.
        first_epoch: number given to the first epoch (e.g. to resume after
            ``load_checkpoint``); epochs run
            ``first_epoch .. first_epoch + num_epochs - 1``.

    Returns:
        ``(train_losses, valid_losses, y_pred)``:
        - ``train_losses`` and ``valid_losses``: the ``.value`` of the training
          and validation loss averagers, one entry per epoch.
        - ``y_pred``: an int64 tensor with the predicted class of every
          validation sample from the last epoch, in ``valid_loader`` order.
          It is empty if ``num_epochs`` is 0.

    Side effects:
        Writes ``checkpoints/mnist-{epoch:03d}.pkl`` after every epoch.
    """
    criterion = nn.CrossEntropyLoss()
    train_losses = []
    valid_losses = []
    # Bound up front so num_epochs=0 returns an empty tensor instead of
    # raising UnboundLocalError at the return statement.
    y_pred = torch.empty(0, dtype=torch.int64)
    for epoch in range(first_epoch, first_epoch + num_epochs):
        print('Epoch', epoch)

        # ---- train phase ----
        model.train()  # enables dropout
        progress = ProgressMonitor(length=len(train_set))
        train_loss = MovingAverage()
        for batch, targets in train_loader:
            batch = batch.to(device)
            targets = targets.to(device)
            optimizer.zero_grad()  # clear gradients from the previous step
            predictions = model(batch)
            loss = criterion(predictions, targets)
            loss.backward()
            optimizer.step()
            # Give the averager a detached tensor so it cannot keep this
            # step's autograd graph alive (or grow it) across iterations.
            train_loss.update(loss.detach())
            progress.update(batch.shape[0], train_loss)
        print('Training loss:', train_loss)
        train_losses.append(train_loss.value)

        # ---- validation phase ----
        model.eval()  # disables dropout
        valid_loss = RunningAverage()
        label_chunks = []
        correct = 0
        total = 0
        # No gradients are needed for validation; no_grad saves memory.
        with torch.no_grad():
            for batch, targets in valid_loader:
                batch = batch.to(device)
                targets = targets.to(device)
                predictions = model(batch)
                loss = criterion(predictions, targets)
                valid_loss.update(loss)
                labels = predictions.argmax(dim=1)
                # Score against this batch's own targets, so accuracy does not
                # depend on valid_loader preserving dataset order.
                correct += (labels == targets).sum().item()
                total += targets.shape[0]
                label_chunks.append(labels.cpu())
        print('Validation loss:', valid_loss)
        valid_losses.append(valid_loss.value)

        if label_chunks:
            y_pred = torch.cat(label_chunks).to(torch.int64)
        else:
            y_pred = torch.empty(0, dtype=torch.int64)
        # NaN for an empty validation set (mean of nothing).
        accuracy = correct / total if total else float('nan')
        print('Validation accuracy: {:.4f}%'.format(float(accuracy) * 100))

        checkpoint_filename = 'checkpoints/mnist-{:03d}.pkl'.format(epoch)
        save_checkpoint(optimizer, model, epoch, checkpoint_filename)
    return train_losses, valid_losses, y_pred
# Both splits share one preprocessing pipeline: convert to a tensor, then
# normalize (0.1307 / 0.3081 are presumably the MNIST mean / std).
train_transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize([0.1307], [0.3081])]
)
valid_transform = train_transform

# MNIST train / test splits, fetched into ./data/mnist on first use.
train_set = MNIST(
    './data/mnist', train=True, download=True, transform=train_transform
)
valid_set = MNIST(
    './data/mnist', train=False, download=True, transform=valid_transform
)
print(train_set.data.shape)
print(valid_set.data.shape)

# Shuffle only the training split; a fixed validation order keeps the
# y_pred returned by train() aligned with the dataset's sample order.
train_loader = DataLoader(train_set, batch_size=256, shuffle=True, num_workers=0)
valid_loader = DataLoader(valid_set, batch_size=512, shuffle=False, num_workers=0)

# nn.Module.to moves parameters in place and returns the module itself.
model = SimpleCNN().to(device)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, nesterov=True)
train_losses, valid_losses, y_pred = train(optimizer, model, num_epochs=10)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment