@zshn25
Last active July 31, 2023 14:05
Conv2d and BatchNorm2d fusion
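from torch import nn
from torch.nn.utils.fusion import fuse_conv_bn_eval


def fuse_all_conv_bn(model):
    """Fuse every Conv2d that is immediately followed by a BatchNorm2d, in place.

    Assumes children are registered in the same order they run in forward().
    Call model.eval() first: fuse_conv_bn_eval asserts both modules are in eval mode.
    """
    prev_name, prev_module = None, None
    for name, module in model.named_children():  # immediate children only
        if list(module.named_children()):  # not a leaf: recurse into it
            fuse_all_conv_bn(module)
        if isinstance(module, nn.BatchNorm2d) and isinstance(prev_module, nn.Conv2d):
            # Replace the conv with the fused conv and the BN with a no-op
            setattr(model, prev_name, fuse_conv_bn_eval(prev_module, module))
            setattr(model, name, nn.Identity())
            prev_name, prev_module = None, None  # this conv has been consumed
        else:
            prev_name, prev_module = name, module
    return model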
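fuse_conv_bn_eval works by folding the eval-mode BatchNorm affine transform into the convolution's weights and bias. With BN scale gamma, shift beta, running mean mu, running variance var and epsilon eps, each output channel's kernel is multiplied by gamma / sqrt(var + eps), and the bias becomes (b - mu) * gamma / sqrt(var + eps) + beta (b = 0 if the conv has no bias). The pure-Python sketch below checks that algebra on a single output value; the helper names and the numbers are hypothetical and chosen only for illustration.

```python
import math


def fuse_conv_bn_params(w, b, gamma, beta, mean, var, eps=1e-5):
    """Fold eval-mode BatchNorm into one output channel of a convolution.

    w: flattened kernel weights for this output channel
    b: conv bias for this channel (0.0 if the conv has bias=False)
    gamma, beta, mean, var: BN weight, bias, running mean, running variance
    """
    scale = gamma / math.sqrt(var + eps)
    fused_w = [wi * scale for wi in w]
    fused_b = (b - mean) * scale + beta
    return fused_w, fused_b


def conv_then_bn(x, w, b, gamma, beta, mean, var, eps=1e-5):
    # One output pixel: dot product over the receptive field, then eval-mode BN
    y = sum(xi * wi for xi, wi in zip(x, w)) + b
    return gamma * (y - mean) / math.sqrt(var + eps) + beta


def fused_conv(x, fused_w, fused_b):
    # The same output pixel computed by the single fused convolution
    return sum(xi * wi for xi, wi in zip(x, fused_w)) + fused_b


x = [0.5, -1.0, 2.0]  # hypothetical flattened receptive field
w, b = [0.2, 0.4, -0.1], 0.3
gamma, beta, mean, var = 1.5, -0.2, 0.1, 0.8

fw, fb = fuse_conv_bn_params(w, b, gamma, beta, mean, var)
print(conv_then_bn(x, w, b, gamma, beta, mean, var))
print(fused_conv(x, fw, fb))
```

Both prints agree up to floating-point rounding, because gamma * (w . x + b - mu) / s + beta expands to (gamma / s) * (w . x) + (b - mu) * gamma / s + beta with s = sqrt(var + eps).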
Requirements: torch, torchvision
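With torch installed, a quick sanity check is to compare a Conv2d followed by BatchNorm2d against the single module returned by fuse_conv_bn_eval. This is a minimal sketch, assuming a recent PyTorch; the layer sizes, seed, batch shapes and tolerance are arbitrary choices, and the BN statistics are warmed up on random data only so that the fusion is not trivially the identity.

```python
import torch
from torch import nn
from torch.nn.utils.fusion import fuse_conv_bn_eval

torch.manual_seed(0)

conv = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False)
bn = nn.BatchNorm2d(8)

# Give BN non-trivial running statistics by running a few training-mode batches
bn.train()
with torch.no_grad():
    for _ in range(5):
        bn(conv(torch.randn(4, 3, 16, 16)))

# fuse_conv_bn_eval asserts that both modules are in eval mode
conv.eval()
bn.eval()
fused = fuse_conv_bn_eval(conv, bn)

x = torch.randn(2, 3, 16, 16)
with torch.no_grad():
    reference = bn(conv(x))
    out = fused(x)

print(torch.allclose(reference, out, atol=1e-5))
# The fused conv gains a bias even though the original conv had bias=False
print(fused.bias is not None)
```

The same check can be run on a whole network after fuse_all_conv_bn by comparing outputs of an eval-mode copy before and after fusion.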