Created September 11, 2017 22:21
-
-
Save lantiga/15ba60f6dbdbc99873f0af94761e9630 to your computer and use it in GitHub Desktop.
PyTorch namespaces tests
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
from torch.autograd import Variable | |
from graphviz import Digraph | |
def name(node, annotation=None):
    """Return a display label for a JIT graph node.

    Every label is prefixed with ``node.blockName()``.

    * ``PythonOp`` / ``CppOp`` nodes: ``blockName + node.name()``; the
      ``annotation`` argument is ignored for them.
    * Any other node: ``blockName + '<annotation>:<type>'``. Here
      ``<annotation>`` defaults to the node kind; an empty string also falls
      back to the kind. ``<type>`` is ``str(node.type())``, or for a
      ``Return`` node the comma-joined types of its inputs.
      If looking up the type raises, the ``:<type>`` suffix is omitted.

    Args:
        node: a JIT IR node; must provide kind(), blockName(), name(),
            type() and inputs().
        annotation: optional label used instead of the node kind.

    Returns:
        The label string.
    """
    kind = node.kind()
    prefix = node.blockName()
    if kind in ('PythonOp', 'CppOp'):
        return prefix + node.name()
    ann = annotation or kind
    try:
        if kind == 'Return':
            content = ','.join(str(n.type()) for n in node.inputs())
        else:
            content = str(node.type())
    except Exception:
        # Some nodes cannot report a single type (NOTE(review): presumably
        # nodes with several outputs -- confirm); label them without it.
        # Catch Exception rather than everything so Ctrl-C / SystemExit
        # still propagate.
        return prefix + ann
    return prefix + '%s:%s' % (ann, content)
def searchOpsRecFwd(ops, visited, nodes):
    """Collect consumer nodes into `ops`, looking through Select nodes.

    Each node from `nodes` that is not already in `visited` is marked
    visited. A Select node is not collected; its users are explored
    instead. Any other node is added to `ops`. Nodes are processed in the
    same depth-first, left-to-right order a recursive walk would use.

    Args:
        ops: set that receives the collected nodes (mutated in place).
        visited: set of nodes already processed (mutated in place).
        nodes: iterable of starting nodes.
    """
    stack = list(nodes)[::-1]
    while stack:
        current = stack.pop()
        if current in visited:
            continue
        visited.add(current)
        if current.kind() == 'Select':
            # Push users reversed so the first user is popped first.
            stack.extend(reversed([use.user for use in current.uses()]))
        else:
            ops.add(current)
def searchDownstreamOps(node):
    """Return the set of nodes that consume `node`'s outputs.

    Select nodes sitting between `node` and its consumers are looked
    through by searchOpsRecFwd rather than returned.
    """
    found, seen_nodes = set(), set()
    direct_users = [use.user for use in node.uses()]
    searchOpsRecFwd(found, seen_nodes, direct_users)
    return found
def addNode(dot, node, node_name=''):
    """Declare `node` in the graphviz graph `dot`, keyed by its Python id.

    The label comes from name(node, node_name); `node_name` is forwarded
    as name()'s annotation.
    """
    node_key = str(id(node))
    label = name(node, node_name)
    dot.node(node_key, label)
def addEdge(dot, n1, n2):
    """Add a directed edge n1 -> n2 to `dot`, keyed by the nodes' Python ids."""
    tail, head = str(id(n1)), str(id(n2))
    dot.edge(tail, head)
def make_dot(g, input_names, show_params=False):
    """Render a traced JIT graph as a graphviz ``Digraph``.

    Only ``PythonOp`` / ``CppOp`` nodes are drawn as operators. Each
    operator gets an edge to every node returned by searchDownstreamOps
    (its consumers, looking through Select nodes). It also gets an edge
    from each of its inputs that is one of:

    * one of the last ``len(input_names)`` graph inputs, labelled with the
      annotation ``'Input$<name>'``;
    * any non-Select node, when ``show_params`` is true;
    * a ``Constant`` node, when ``show_params`` is false.

    Args:
        g: the traced graph (e.g. from ``torch._C._jit_get_graph``).
        input_names: names for the trailing graph inputs, in order.
            NOTE(review): the trailing-position mapping presumably reflects
            that leading inputs are parameters -- confirm.
        show_params: also draw every non-Select operator input.

    Returns:
        The populated graphviz ``Digraph``.
    """
    node_attr = dict(style='filled',
                     shape='box',
                     align='left',
                     fontsize='12',
                     ranksep='0.1',
                     height='0.2')
    dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))

    ops = [node for node in g.nodes() if node.kind() in ('PythonOp', 'CppOp')]
    for node in ops:
        addNode(dot, node)

    # Pair the trailing graph inputs with the user-supplied names. list()
    # makes the slice work even if inputs() is not a list (NOTE(review):
    # return type unverified for the torch version in use). With no names,
    # zip() produces an empty mapping.
    graph_inputs = list(g.inputs())
    input_dict = dict(zip(graph_inputs[-len(input_names):], input_names))

    for node in ops:
        for consumer in searchDownstreamOps(node):
            addNode(dot, consumer)
            addEdge(dot, node, consumer)
        for n in node.inputs():
            if n in input_dict:
                addNode(dot, n, 'Input$' + input_dict[n])
                addEdge(dot, n, node)
            elif show_params:
                if n.kind() != 'Select':
                    addNode(dot, n)
                    addEdge(dot, n, node)
            elif n.kind() == 'Constant':
                addNode(dot, n)
                addEdge(dot, n, node)
    return dot
def test1():
    """Trace a two-module net containing a named block and render its graph.

    MyNet runs ``module1``, then an addition recorded between
    ``push_block('Foo')`` and ``pop_block()``, then ``module2``. The traced
    graph is printed and opened with ``view()``.
    """
    class MyModule(nn.Module):
        def forward(self, x):
            t = x + 1
            r = t * 2
            return r

    class MyNet(nn.Module):
        def __init__(self):
            super(MyNet, self).__init__()
            self.module1 = MyModule()
            self.module2 = MyModule()

        def forward(self, x):
            y = self.module1(x)
            torch._tracing_state.push_block('Foo')
            try:
                t = y + x
            finally:
                # Keep push_block/pop_block paired even if the op raises.
                torch._tracing_state.pop_block()
            return self.module2(t)

    def doit(x):
        a = MyNet()
        return a(x)

    t = Variable(torch.ones(1), requires_grad=True)
    traced, _ = torch.jit.record_trace(doit, t)
    g = torch._C._jit_get_graph(traced)
    print(g)
    d = make_dot(g, ['t'], show_params=True)
    d.view()
def test2():
    """Trace a single Linear(2, 2) + ReLU network and render its graph."""
    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            linear = nn.Linear(2, 2)
            self.layer1 = nn.Sequential(linear, nn.ReLU())

        def forward(self, x):
            return self.layer1(x)

    model = Net()
    sample = Variable(torch.ones(2), requires_grad=True)
    traced, _ = torch.jit.record_trace(model, sample)
    graph = torch._C._jit_get_graph(traced)
    print(graph)
    rendered = make_dot(graph, ['t'], show_params=True)
    rendered.view()
def test3():
    """Trace torchvision's resnet18 on a random (1, 3, 224, 224) tensor and render it."""
    from torchvision import models

    batch = Variable(torch.randn(1, 3, 224, 224))
    net = models.resnet18()
    traced, _ = torch.jit.record_trace(net, batch)
    graph = torch._C._jit_get_graph(traced)
    print(graph)
    make_dot(graph, ['inputs']).view()
def test4():
    """Trace sigmoid(tanh(x * (y + z))) inside block 'Foo' and render it.

    ``z`` is created inside the traced function rather than passed to
    ``record_trace``, so only ``x`` and ``y`` are given input names.
    """
    x = Variable(torch.Tensor([0.4]), requires_grad=True)
    y = Variable(torch.Tensor([0.7]), requires_grad=True)

    def doit(x, y):
        torch._tracing_state.push_block('Foo')
        try:
            z = Variable(torch.Tensor([0.7]), requires_grad=True)
            out = torch.sigmoid(torch.tanh(x * (y + z)))
        finally:
            # Keep push_block/pop_block paired even if an op raises.
            torch._tracing_state.pop_block()
        return out

    traced, _ = torch.jit.record_trace(doit, x, y)
    g = torch._C._jit_get_graph(traced)
    print(g)
    d = make_dot(g, ['x', 'y'])
    d.view()
# Run every demo in order; each prints its traced graph and opens a rendering.
for _demo in (test1, test2, test3, test4):
    _demo()
del _demo
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.
Test1
Test2
Test3
Test4