@mark-andrews
Created December 10, 2020 17:45
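import torch
import torchvision

# Convert images to tensors, then normalize with the commonly used MNIST mean (0.1307) and std (0.3081).
tt = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,)),
])

# Download the MNIST training and test sets into ./files/ if they are not already there.
mnist_train = torchvision.datasets.MNIST('./files/', train=True, download=True, transform=tt)
mnist_test = torchvision.datasets.MNIST('./files/', train=False, download=True, transform=tt)

# Batch the data: 20 images per training batch, 1000 per test batch, both shuffled.
train_loader = torch.utils.data.DataLoader(mnist_train, batch_size=20, shuffle=True)
test_loader = torch.utils.data.DataLoader(mnist_test, batch_size=1000, shuffle=True)

# Pull one training batch to inspect; example_data has shape (20, 1, 28, 28).
examples = enumerate(train_loader)
batch_idx, (example_data, example_targets) = next(examples)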