- A simple example of how const prop works
```python
import torch

def foo(x):
    a = 1 + 2
    b = a + 3
    c = b + 4
    return x + c

jit_foo = torch.jit.script(foo)

# `jit_foo.graph` is the unoptimized graph
print(jit_foo.graph)

# `jit_foo.graph_for(input)` (called under `torch.no_grad()`) is the optimized graph
with torch.no_grad():
    print(jit_foo.graph_for(torch.rand(10)))
```
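In the optimized graph, constant propagation folds the chain `a = 1 + 2`, `b = a + 3`, `c = b + 4` down to a single `prim::Constant`, so only one addition involving `x` remains. Written back as Python for clarity (an illustrative equivalent, not the literal graph dump):

```python
# What the optimized graph effectively computes after const prop:
def foo_optimized(x):
    return x + 10  # 1 + 2 + 3 + 4 folded into a single constant
```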
- A case that shows how const prop works in pyrys
First, we slightly modify `test_conv_bn_relu()` in `test_pass.py`:

```python
# Conv + BatchNorm + ReLU
def test_conv_bn_relu():
    ConvBnRelu = ScriptedConv2dBnRelu(3, 32, kernel_size=3, stride=1)
    pyrys._jit_pass_freeze_params(ConvBnRelu._c, 'forward', 'weight')
    pyrys._jit_pass_freeze_params(ConvBnRelu._c, 'forward', 'bias')
    pyrys._jit_pass_freeze_params(ConvBnRelu._c, 'forward', 'running_mean')
    pyrys._jit_pass_freeze_params(ConvBnRelu._c, 'forward', 'running_var')
    pyrys._jit_pass_freeze_flags(ConvBnRelu._c, 'forward', 'training', False)
    x = torch.rand((1, 3, 8, 8))
    with torch.no_grad():
        print('Conv2d+BatchNorm2d+Relu Graph:\n', ConvBnRelu.graph_for(x))
        ConvBnRelu(x)
```
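If you want a baseline to compare against, you can also print the unoptimized graph, mirroring the `jit_foo.graph` example above. This extra print is my addition, not part of the original test:

```python
# Before graph_for() runs the fusion passes, the forward graph should still
# contain separate conv / batch_norm / relu ops.
print('Unoptimized graph:\n', ConvBnRelu.graph)
```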
Then we can compare the graphs produced with constant propagation enabled and disabled in `fusion_pass`.
```
// disable constprop in fusion_pass
Conv2d+BatchNorm2d+Relu Graph:
graph(%self : ClassType<ScriptedConv2dBnRelu>,
      %x.1 : Float(*, *, *, *)):
  %328 : int[] = prim::Constant[value=[0, 0]]()
  %327 : int[] = prim::Constant[value=[1, 1]]()
  %3 : float = prim::Constant[value=0.001]() # /home/pinzhenxu/pytorch_llga/torch/nn/modules/batchnorm.py:79:40
  %4 : None = prim::Constant() # /home/pinzhenxu/pytorch_llga/torch/nn/modules/conv.py:339:36
  %23 : int = prim::Constant[value=1]() # /home/pinzhenxu/pytorch_llga/torch/nn/modules/conv.py:336:46
  %343.weight.1 : Float(32, 3, 3, 3) = prim::Constant[value=<Tensor>]()
  %346.running_mean : Float(32) = prim::Constant[value=<Tensor>]()
  %347.running_var : Float(32) = prim::Constant[value=<Tensor>]()
  %329 : int[] = prim::Constant[value=[1, 1]]()
  %343.weight.1.bn_folded : Float(32, 3, 3, 3) = dnnl::fold_weight(%343.weight.1, %347.running_var, %347.running_var, %3)
  %4.bn_folded : None = dnnl::fold_bias(%343.weight.1.bn_folded, %4, %347.running_var, %346.running_mean, %346.running_mean, %347.running_var, %3)
  %343 : Float(*, *, *, *) = dnnl::conv2d_relu[format_info=[1, 81, 1, 1, 1, 1, 1, 1]](%x.1, %343.weight.1.bn_folded, %4.bn_folded, %327, %328, %329, %23)
  %result.1.reorder : Tensor = dnnl::reorder[format_info=[1, 7], group_info=1](%343)
  return (%result.1.reorder)
```
```
// enable constprop in fusion_pass
Conv2d+BatchNorm2d+Relu Graph:
graph(%self : ClassType<ScriptedConv2dBnRelu>,
      %x.1 : Float(*, *, *, *)):
  %328 : int[] = prim::Constant[value=[0, 0]]()
  %327 : int[] = prim::Constant[value=[1, 1]]()
  %23 : int = prim::Constant[value=1]() # /home/pinzhenxu/pytorch_llga/torch/nn/modules/conv.py:336:46
  %329 : int[] = prim::Constant[value=[1, 1]]()
  %344 : Float(32, 3, 3, 3) = prim::Constant[value=<Tensor>]()
  %345 : Float(32) = prim::Constant[value=<Tensor>]()
  %343 : Float(*, *, *, *) = dnnl::conv2d_relu[format_info=[1, 81, 1, 1, 1, 1, 1, 1]](%x.1, %344, %345, %327, %328, %329, %23)
  %result.1.reorder : Tensor = dnnl::reorder[format_info=[1, 7], group_info=1](%343)
  return (%result.1.reorder)
```
With constant propagation disabled, the `dnnl::fold_weight` and `dnnl::fold_bias` nodes remain in the graph and run on every forward call; with it enabled, they are evaluated once at compile time, so `dnnl::conv2d_relu` receives the pre-folded constant tensors (`%344`, `%345`) directly.
- Intro to conv-bn fusion
Here we fuse conv and batchnorm into one op. Conv-BN fusion is a commonly used technique in CNN inference optimization. In most CNN networks, we use batchnorm to speed up training, but in inference-only scenarios it provides no benefit and actually hurts performance, because it is a memory-bound operator that requires a lot of I/O.
Thanks to its algebraic properties, we can fold the parameters of batchnorm (running_mean & running_var) into the preceding convolution's parameters (weight, bias), i.e.
`x -> conv -> y -> bn -> z`
becomes
`x -> conv -> z`
If you're interested in the math details, please refer to https://tehnokv.com/posts/fusing-batchnorm-and-conv/ The math part of that article is correct; however, its Python code is problematic. Check intel/webml-polyfill#240 (comment) for the correct Python code.
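To make the fold concrete, here is a minimal sketch of the math in plain PyTorch. The helper name `fold_conv_bn` and its signature are mine for illustration, not pyrys's API; if the BatchNorm has no affine parameters, take `gamma = 1` and `beta = 0`.

```python
import torch

def fold_conv_bn(conv_w, conv_b, gamma, beta, running_mean, running_var, eps=1e-5):
    """Fold a BatchNorm (eval mode) into the preceding conv's weight and bias.

    Shapes: conv_w (out_ch, in_ch, kH, kW); conv_b (out_ch,) or None;
    gamma, beta, running_mean, running_var (out_ch,).
    """
    if conv_b is None:
        conv_b = torch.zeros_like(running_mean)
    scale = gamma / torch.sqrt(running_var + eps)       # per-output-channel scale
    folded_w = conv_w * scale.reshape(-1, 1, 1, 1)      # scale each output channel's filter
    folded_b = (conv_b - running_mean) * scale + beta   # shift, scale, then add beta
    return folded_w, folded_b
```

A conv that uses `folded_w` and `folded_b` produces the same output (in eval mode) as the original conv followed by batchnorm, which is why the bn node can be dropped from the graph.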
Questions (you don't have to send me the answers):
- What do `_jit_pass_freeze_params` and `_jit_pass_freeze_flags` do? What if we do not call `_jit_pass_freeze_flags` at all?
- How do we fuse the conv and bn in pyrys? That is, how do we transform a tedious JIT graph from `x -> conv -> y -> bn -> z` to `x -> conv -> z`? You may inspect the intermediate graphs after each pass to understand the transformation process.