Skip to content

Instantly share code, notes, and snippets.

@amohant4
Last active May 8, 2021 18:03
Show Gist options
  • Save amohant4/f10e4f4f8a3f37f58e79be09a9ef8f87 to your computer and use it in GitHub Desktop.
Save amohant4/f10e4f4f8a3f37f58e79be09a9ef8f87 to your computer and use it in GitHub Desktop.
Example usage of Gluon in MXNet, with a LeNet test case on the Fashion-MNIST dataset
import mxnet as mx
from mxnet.gluon import nn
from mxnet.gluon.data.vision import datasets, transforms
from mxnet import init, gluon
import time
def create_lenet_using_sequential():
    """
    Build a LeNet-style network with nn.Sequential from MXNet.

    nn.Sequential is a subclass of nn.Block; the layers are added to it
    in the order they should be applied.

    Arguments: None
    Returns: An instance of type Sequential.
    """
    lenet = nn.Sequential()
    layers = [
        # Two convolution + max-pooling stages.
        nn.Conv2D(channels=6, kernel_size=5, activation='relu'),
        nn.MaxPool2D(pool_size=2, strides=2),
        nn.Conv2D(channels=16, kernel_size=3, activation='relu'),
        nn.MaxPool2D(pool_size=2, strides=2),
        # Three dense layers; the last one has 10 units and no activation.
        nn.Dense(120, activation='relu'),
        nn.Dense(84, activation='relu'),
        nn.Dense(10),
    ]
    lenet.add(*layers)
    return lenet
class myLeNet(nn.Block):
    """
    Custom class implementing LeNet.

    This implementation is very flexible.
    Things to remember:
    - this class is a subclass of nn.Block
    - you need to call __init__ of nn.Block in the constructor
    - __init__ defines all the nodes in the graph
    - forward defines the forward function of the network
    """

    def __init__(self, **kwargs):
        """Create the layers; **kwargs are forwarded to nn.Block.__init__."""
        super(myLeNet, self).__init__(**kwargs)
        self.conv1 = nn.Conv2D(channels=6, kernel_size=5, activation='relu')
        self.pool1 = nn.MaxPool2D(pool_size=2, strides=2)
        self.conv2 = nn.Conv2D(channels=16, kernel_size=3, activation='relu')
        # Each pooling stage gets its own attribute; previously `pool1` was
        # assigned twice, silently replacing the first pooling layer.
        self.pool2 = nn.MaxPool2D(pool_size=2, strides=2)
        self.fc1 = nn.Dense(120, activation='relu')
        self.fc2 = nn.Dense(84, activation='relu')
        self.fc3 = nn.Dense(10)

    def forward(self, x):
        """
        Run the forward pass.

        Arguments:
            x: input batch fed to the first convolution
        Returns:
            output of the last dense layer (fc3)
        """
        x = self.pool1(self.conv1(x))
        x = self.pool2(self.conv2(x))
        x = self.fc1(x)
        x = self.fc2(x)
        return self.fc3(x)
# ~~~ Training dataset ~~~
mnist_train = datasets.FashionMNIST(train=True)

# Per-image preprocessing: convert to tensor, then normalize.
# Any augmentation would also be added to this pipeline.
transformer = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(0.13, 0.31),
])

# Apply the transformation to the first element (the image) of each sample.
mnist_train = mnist_train.transform_first(transformer)

# Data loader that batches and shuffles the training data.
# num_workers is set to 4; more workers are needed for complicated
# transforms and bigger batch sizes.
batch_size = 256
train_data = gluon.data.DataLoader(
    mnist_train,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
)

# ~~~ Validation dataset ~~~
mnist_valid = gluon.data.vision.FashionMNIST(train=False)
valid_data = gluon.data.DataLoader(
    mnist_valid.transform_first(transformer),
    batch_size=batch_size,
    num_workers=4,
)

# ~~~ Network instance and other necessities for training ~~~
net = myLeNet()
# Initialize the parameters using Xavier initialization.
net.initialize(init=init.Xavier())

# Loss function.
softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss()

# Trainer using SGD with a learning rate of 0.1.
trainer = gluon.Trainer(
    net.collect_params(),
    'sgd',
    {'learning_rate': 0.1},
)
# Metric function ~~~
def acc(output, label):
    """
    Compute the accuracy of the network for a batch of outputs and labels.

    Arguments:
        output: output from the network (in this case the last fc layer)
        label: golden output as per the dataset
    Returns:
        accuracy (scalar): fraction of samples whose predicted class
        (argmax along axis 1) equals the label
    """
    predicted = output.argmax(axis=1)
    # Cast the labels to float32 before comparing them with the predictions.
    expected = label.astype('float32')
    matches = predicted == expected
    return matches.mean().asscalar()
# Training Loop ~~~
for epoch in range(10):
    train_loss, train_acc, valid_acc = 0., 0., 0.
    tic = time.time()
    for data, label in train_data:  # Iterate through the training dataset
        with mx.autograd.record():  # Record the computation for autograd
            output = net(data)  # Forward pass
            loss = softmax_cross_entropy(output, label)  # Get loss
        loss.backward()  # Back propagate
        # Normalize the gradient by the number of samples actually in this
        # batch: the last batch of an epoch can be smaller than batch_size.
        trainer.step(data.shape[0])
        train_loss += loss.mean().asscalar()
        train_acc += acc(output, label)
    for data, label in valid_data:  # Validation Loop ~~
        valid_acc += acc(net(data), label)
    # Log metrics to std output ~~
    # Metrics are averaged over the number of batches in each loader.
    print("Epoch %d: loss %.3f, train acc %.3f, test acc %.3f, in %.1f sec" % (
        epoch, train_loss/len(train_data), train_acc/len(train_data),
        valid_acc/len(valid_data), time.time()-tic))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment