Last active
October 24, 2020 00:43
-
-
Save GuokaiLiu/ef1d6a65183a26e2a17307c0a83d680b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch import nn | |
import pytorch_lightning as pl | |
from torch.utils.data import DataLoader, random_split | |
from torch.nn import functional as F | |
from torchvision.datasets import MNIST | |
from torchvision import datasets, transforms | |
import os | |
class LightningMNISTClassifier(pl.LightningModule):
    """Three-layer fully connected MNIST classifier built on PyTorch Lightning.

    Each image is flattened and passed through two ReLU-activated linear
    layers (784 -> 128 -> 256) and a final linear layer (256 -> 10). The
    model returns log-probabilities over the 10 digit classes. Dataset
    download, normalisation and the 55k/5k train/validation split happen in
    ``prepare_data``. The ``*_dataloader`` hooks wrap the resulting datasets.
    """

    def __init__(self):
        super(LightningMNISTClassifier, self).__init__()
        # MNIST images are (1, 28, 28) = (channels, height, width).
        self.layer_1 = torch.nn.Linear(28 * 28, 128)
        self.layer_2 = torch.nn.Linear(128, 256)
        self.layer_3 = torch.nn.Linear(256, 10)

    def forward(self, x):
        """Return per-class log-probabilities for a batch of images.

        Args:
            x: image batch, e.g. (b, 1, 28, 28). Every dimension after the
                batch dimension is flattened, so it must hold 28*28 values
                per sample in total.

        Returns:
            Tensor of shape (b, 10) containing log-softmax scores.
        """
        # (b, 1, 28, 28) -> (b, 1*28*28). flatten() also handles
        # non-contiguous input, where view() would raise.
        x = torch.flatten(x, start_dim=1)
        # layer 1: (b, 1*28*28) -> (b, 128)
        x = torch.relu(self.layer_1(x))
        # layer 2: (b, 128) -> (b, 256)
        x = torch.relu(self.layer_2(x))
        # layer 3: (b, 256) -> (b, 10)
        x = self.layer_3(x)
        # Log-probability distribution over the 10 labels.
        return torch.log_softmax(x, dim=1)

    def cross_entropy_loss(self, logits, labels):
        """Mean negative log-likelihood of ``labels``.

        ``logits`` is the output of ``forward``. Despite the name, it already
        holds log-softmax values. NLL on log-softmax output therefore equals
        cross-entropy on the raw layer-3 scores.
        """
        return F.nll_loss(logits, labels)

    def training_step(self, train_batch, batch_idx):
        """Compute the training loss for one (images, labels) batch."""
        x, y = train_batch
        logits = self.forward(x)
        loss = self.cross_entropy_loss(logits, y)
        # 'loss' is what Lightning optimises. 'log' entries go to the logger.
        # NOTE(review): the dict-with-'log' return style is the pre-`self.log`
        # Lightning API. Confirm it against the installed Lightning version.
        logs = {'train_loss': loss}
        return {'loss': loss, 'log': logs}

    def validation_step(self, val_batch, batch_idx):
        """Return the validation loss for one batch under key 'val_loss'."""
        x, y = val_batch
        logits = self.forward(x)
        loss = self.cross_entropy_loss(logits, y)
        return {'val_loss': loss}

    def validation_epoch_end(self, outputs):
        """Average the per-batch validation losses at the end of an epoch.

        ``outputs`` is the list of dicts returned by ``validation_step``,
        one per batch: [{'val_loss': t0}, {'val_loss': t1}, ...].
        """
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        tensorboard_logs = {'val_loss': avg_loss}
        return {'avg_val_loss': avg_loss, 'log': tensorboard_logs}

    def test_step(self, test_batch, batch_idx):
        """Return the test loss for one batch under key 'test_loss'."""
        x, y = test_batch
        logits = self.forward(x)
        loss = self.cross_entropy_loss(logits, y)
        return {'test_loss': loss}

    def prepare_data(self):
        """Download MNIST into the working directory and build the datasets.

        Sets ``self.mnist_train`` (55,000 samples), ``self.mnist_val``
        (5,000 samples, randomly split from the 60,000-image training set)
        and ``self.mnist_test``.

        NOTE(review): Lightning runs ``prepare_data`` on a single process
        and recommends assigning dataset state in ``setup()`` instead. Move
        the dataset construction there before using multi-process (DDP)
        training.
        """
        # Convert images to tensors and normalise with the commonly used
        # MNIST mean/std.
        transform = transforms.Compose([transforms.ToTensor(),
                                        transforms.Normalize((0.1307,), (0.3081,))])
        mnist_train = MNIST(os.getcwd(), train=True,
                            download=True, transform=transform)
        # Stored on self so that test_dataloader can reach it.
        self.mnist_test = MNIST(os.getcwd(), train=False,
                                download=True, transform=transform)
        self.mnist_train, self.mnist_val = random_split(
            mnist_train, [55000, 5000])

    def train_dataloader(self):
        """Return the training loader, reshuffled every epoch."""
        return DataLoader(self.mnist_train, batch_size=256, shuffle=True)

    def val_dataloader(self):
        """Return the validation loader in a fixed order."""
        return DataLoader(self.mnist_val, batch_size=64)

    def test_dataloader(self):
        """Return the test loader over the held-out MNIST test set."""
        return DataLoader(self.mnist_test, batch_size=64)

    def configure_optimizers(self):
        """Return Adam over all model parameters with learning rate 1e-3."""
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer
# Train the classifier only when this file is executed as a script. The
# guard stops training from re-running when the module is imported, e.g.
# by worker processes that re-import the main module under the "spawn"
# start method.
if __name__ == "__main__":
    model = LightningMNISTClassifier()
    # Lightning 1.x releases parsed the string form gpus='0' differently
    # (GPU device index 0 vs. zero GPUs, i.e. CPU). An explicit device list
    # is unambiguous, and the fallback lets the script still run on CPU-only
    # machines.
    trainer = pl.Trainer(
        gpus=[0] if torch.cuda.is_available() else None,
        max_epochs=2,
    )
    trainer.fit(model)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment