from collections.abc import Iterable  # Python 3.3+; typing.Iterable is a deprecated alias (3.9+), collections.Iterable was removed in 3.10
def flatten(items):
    """Yield items from any nested iterable;
    Source: https://stackoverflow.com/a/40857703
    Usage:
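The snippet above breaks off inside its docstring. Below is a minimal sketch of a recursive generator that behaves as that docstring describes. It is not the gist's own body. Treating `str` and `bytes` as atoms is an assumption, and the sample input is made up for illustration.

```python
from collections.abc import Iterable
from typing import Any, Iterator


def flatten(items: Iterable[Any]) -> Iterator[Any]:
    """Yield the leaf items of an arbitrarily nested iterable, depth first.

    Assumption: str and bytes are treated as atoms. Without that check,
    iterating a one-character string yields the same string again and the
    recursion never terminates.
    """
    for x in items:
        if isinstance(x, Iterable) and not isinstance(x, (str, bytes)):
            # Recurse into nested containers (lists, tuples, ranges, generators, ...).
            yield from flatten(x)
        else:
            yield x


if __name__ == "__main__":
    nested = [1, [2, [3, (4, 5)]], "ab", range(6, 8)]
    print(list(flatten(nested)))  # [1, 2, 3, 4, 5, 'ab', 6, 7]
```

Because it is a generator, callers can consume deeply nested data lazily and only materialize a list when they need one.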
from torch import nn
from torch.nn.utils.fusion import fuse_conv_bn_eval
def fuse_all_conv_bn(model):
    stack = []
    for name, module in model.named_children():  # immediate children
        if list(module.named_children()):  # has children, i.e. not a leaf
            fuse_all_conv_bn(module)
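`fuse_conv_bn_eval` folds an eval-mode batch norm into the convolution that precedes it, producing a single conv. The arithmetic behind that fold can be sketched without torch. In the sketch below, one conv output position is modelled as a per-channel dot product over a flattened receptive field. The helper names (`fold_bn_into_conv`, `conv_at`, `bn_eval`) and the sample values are hypothetical. They exist only to check that the folded weights reproduce "conv, then batch norm".

```python
import math
from typing import List, Sequence, Tuple


def fold_bn_into_conv(
    weight: Sequence[Sequence[float]],  # weight[c] = flattened kernel of output channel c
    bias: Sequence[float],
    gamma: Sequence[float],
    beta: Sequence[float],
    running_mean: Sequence[float],
    running_var: Sequence[float],
    eps: float = 1e-5,
) -> Tuple[List[List[float]], List[float]]:
    """Return (weight', bias') such that conv'(x) == bn(conv(x)) in eval mode.

    Per output channel c, with s_c = gamma_c / sqrt(var_c + eps):
        weight'_c = weight_c * s_c
        bias'_c   = (bias_c - mean_c) * s_c + beta_c
    """
    fused_w: List[List[float]] = []
    fused_b: List[float] = []
    for w_c, b_c, g, bt, m, v in zip(weight, bias, gamma, beta, running_mean, running_var):
        s = g / math.sqrt(v + eps)  # per-channel BN scale
        fused_w.append([w * s for w in w_c])
        fused_b.append((b_c - m) * s + bt)
    return fused_w, fused_b


def conv_at(x: Sequence[float], weight, bias) -> List[float]:
    """One conv output position: a dot product per output channel, plus bias."""
    return [sum(w * xi for w, xi in zip(w_c, x)) + b_c for w_c, b_c in zip(weight, bias)]


def bn_eval(y, gamma, beta, running_mean, running_var, eps=1e-5) -> List[float]:
    """Eval-mode batch norm: normalize with running statistics, then scale and shift."""
    return [g * (y_c - m) / math.sqrt(v + eps) + bt
            for y_c, g, bt, m, v in zip(y, gamma, beta, running_mean, running_var)]


if __name__ == "__main__":
    x = [0.5, -1.0, 2.0]  # flattened receptive field (hypothetical values)
    weight = [[1.0, 2.0, 3.0], [-0.5, 0.0, 4.0]]
    bias = [0.1, -0.2]
    gamma, beta = [1.5, 0.8], [0.3, -0.1]
    mean, var = [0.2, 1.0], [4.0, 0.25]

    reference = bn_eval(conv_at(x, weight, bias), gamma, beta, mean, var)
    fw, fb = fold_bn_into_conv(weight, bias, gamma, beta, mean, var)
    print(reference, conv_at(x, fw, fb))  # the two lists agree up to rounding
```

The fold is valid only while the running statistics are frozen. This is why it targets eval mode: in training mode, batch norm normalizes with per-batch statistics instead.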
torch