In `dependency.py`, add a new update function to `DependencyGraph`:
def _update_glu_index_mapping(self, glu_node: Node):
    """Attach a ``_GLUIndexMapping`` to every dependency crossing a GLU node.

    ``nn.GLU`` splits its input in half along the gated dimension and
    multiplies one half by the sigmoid of the other, so the output has half
    as many channels as the input. Pruning indices therefore need to be
    re-mapped whenever a dependency is propagated across a GLU node.

    Args:
        glu_node: graph node to process; ignored unless its type is
            ``ops.OPTYPE.GLU``.
    """
    if glu_node.type != ops.OPTYPE.GLU:
        return
    # GLU halves the number of channels by applying sigmoid gating.
    input_node = glu_node.inputs[0]
    in_channels = self.get_out_channels(input_node)
    out_channels = in_channels // 2

    # Re-map dependencies entering the GLU node from its input side.
    # TODO: Need to check — index_mapping[0] is overwritten unconditionally;
    # confirm GLU nodes never carry a pre-existing mapping in slot 0.
    for in_node in glu_node.inputs:
        for dep in in_node.dependencies:
            if dep.target == glu_node:
                dep.index_mapping[0] = _helpers._GLUIndexMapping(out_channels)

    # Same re-mapping for dependencies that point back at the GLU node from
    # its output side.
    # TODO: Need to check
    for out_node in glu_node.outputs:
        for dep in out_node.dependencies:
            if dep.target == glu_node:
                dep.index_mapping[0] = _helpers._GLUIndexMapping(out_channels)
Call it whenever `update_index_mapping` is executed:
def update_index_mapping(self):
    # Rebuild per-dependency index mappings for every node in the graph.
    for module, node in self.module2node.items():
        ...
        # GLU halves the channel count, so its dependencies need a
        # dedicated index mapping.
        if node.type == ops.OPTYPE.GLU:
            self._update_glu_index_mapping(node)
Add GLU support in `ops.py`:
class GLUPruner(DummyPruner):
    """No-op pruner for ``nn.GLU``; inherits ``DummyPruner`` behavior unchanged."""
# Standard Modules
TORCH_CONV = nn.modules.conv._ConvNd
...
# Register nn.GLU alongside the other standard torch module aliases.
TORCH_GLU = nn.GLU
class OPTYPE(IntEnum):
    """Enumeration of operator types recognized by the dependency graph."""
    CONV = 0
    ...
    GLU = 18  # nn.GLU
def module2type(module):
    # Map an nn.Module instance to its OPTYPE enum value.
    ...
    elif isinstance(module, TORCH_GLU):
        return OPTYPE.GLU
def type2class(op_type):
    # Map an OPTYPE enum value back to its torch module class.
    ...
    elif op_type == OPTYPE.GLU:
        return TORCH_GLU
Add the GLU index mapping in `_helpers.py`:
class _GLUIndexMapping(object):
    """Maps pruning indices across an ``nn.GLU`` node.

    GLU halves the channel dimension, so an index on the input side is
    folded back into the ``[0, out_channels)`` range via a modulo.
    """

    def __init__(self, out_channels):
        # Channel count on the GLU output side (input channels // 2).
        self.out_channels = out_channels

    def __call__(self, idxs: _HybridIndex):
        # TODO: Update this — verify the modulo mapping against real runs.
        mapped = []
        for hybrid_idx in idxs:
            mapped.append(
                _HybridIndex(
                    idx=hybrid_idx.idx % self.out_channels,
                    root_idx=hybrid_idx.root_idx,
                )
            )
        return mapped