Skip to content

Instantly share code, notes, and snippets.

@amohant4
Created July 17, 2019 06:01
Show Gist options
  • Save amohant4/767511778c45e38457ab7e660bd37f3c to your computer and use it in GitHub Desktop.
Save amohant4/767511778c45e38457ab7e660bd37f3c to your computer and use it in GitHub Desktop.
Example of creating a network using the nn.Block class
from mxnet.gluon import nn


class myLeNet(nn.Block):
    """LeNet-style convolutional network built by subclassing ``nn.Block``.

    The network is two convolution + max-pooling stages followed by three
    fully connected layers:

    * ``conv1``: 6 channels, 5x5 kernel, ReLU
    * ``pool1``: 2x2 max pooling, stride 2
    * ``conv2``: 16 channels, 3x3 kernel, ReLU
    * ``pool2``: 2x2 max pooling, stride 2
    * ``fc1``: 120 units, ReLU
    * ``fc2``: 84 units, ReLU
    * ``fc3``: 10 units, no activation

    Child layers are assigned as attributes so that ``nn.Block`` registers
    them and tracks their parameters.
    """

    def __init__(self, **kwargs):
        """Create the layers.

        Args:
            **kwargs: Forwarded unchanged to ``nn.Block.__init__``.
        """
        super(myLeNet, self).__init__(**kwargs)
        self.conv1 = nn.Conv2D(channels=6, kernel_size=5, activation='relu')
        self.pool1 = nn.MaxPool2D(pool_size=2, strides=2)
        self.conv2 = nn.Conv2D(channels=16, kernel_size=3, activation='relu')
        # Registered under its own name; assigning to ``pool1`` again would
        # replace the first pooling layer instead of adding a second one.
        self.pool2 = nn.MaxPool2D(pool_size=2, strides=2)
        self.fc1 = nn.Dense(120, activation='relu')
        self.fc2 = nn.Dense(84, activation='relu')
        self.fc3 = nn.Dense(10)

    def forward(self, x):
        """Run the forward pass.

        Args:
            x: Input batch. Presumably a 4-D array in Conv2D's default
                layout (batch, channels, height, width) -- confirm against
                the caller.

        Returns:
            The output of ``fc3`` (10 values per sample, no activation
            applied).
        """
        # Feature extractor: two conv -> pool stages.
        x = self.pool1(self.conv1(x))
        x = self.pool2(self.conv2(x))
        # Classifier head. No explicit flatten is done here; this relies on
        # nn.Dense flattening its input by default.
        x = self.fc2(self.fc1(x))
        return self.fc3(x)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment