Last active
July 31, 2023 14:05
-
-
Save zshn25/0b7fdab97c3fa06c0bfd1e528c861041 to your computer and use it in GitHub Desktop.
Conv2d and BatchNorm2d fusion
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from torch import nn | |
from torch.nn.utils.fusion import fuse_conv_bn_eval | |
def fuse_all_conv_bn(model: nn.Module) -> None:
    """Fold each ``BatchNorm2d`` that directly follows a ``Conv2d`` into that conv.

    The module tree is walked recursively and modified in place. Within each
    container, a ``BatchNorm2d`` child whose immediately preceding sibling is a
    ``Conv2d`` is merged into it with ``fuse_conv_bn_eval``; the conv attribute
    is replaced by the fused conv and the batch-norm attribute by
    ``nn.Identity()`` so attribute names and child positions stay stable.

    Args:
        model: Root module whose children are inspected and rewritten.

    Returns:
        None. ``model`` is mutated in place.

    Note:
        ``fuse_conv_bn_eval`` only accepts modules in eval mode, so call
        ``model.eval()`` first. Sibling adjacency is taken from registration
        order (``named_children``), which is assumed to match the order the
        modules are applied in ``forward`` -- true for ``nn.Sequential``, but
        worth confirming for hand-written ``forward`` methods.
    """
    # Only the immediately preceding sibling can be a fusion target.
    prev_name = None
    prev_module = None

    # Snapshot the children: attributes are reassigned while walking them.
    for name, module in list(model.named_children()):
        if list(module.named_children()):  # not a leaf: fuse inside it first
            fuse_all_conv_bn(module)

        if isinstance(module, nn.BatchNorm2d):
            # isinstance(None, ...) is False, so a leading BatchNorm2d (no
            # previous sibling) is simply left untouched instead of crashing.
            if isinstance(prev_module, nn.Conv2d):
                setattr(model, prev_name, fuse_conv_bn_eval(prev_module, module))
                setattr(model, name, nn.Identity())
            # Forget the previous sibling after any BatchNorm2d so that a
            # second consecutive BatchNorm2d is never folded into the stale,
            # already-fused conv (which would discard the first fusion).
            prev_name = None
            prev_module = None
        else:
            prev_name, prev_module = name, module
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
torch | |
torchvision |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment