# blob: eba0ffb697dbb03708e4ae51837317271b15d033 [file] [log] [blame]
from __future__ import absolute_import, division, print_function, unicode_literals
import torch
import torch.nn._intrinsic.modules.fused as torch_fused
def fuse_conv_bn(conv, bn):
    r"""Given the conv and bn modules, fuses them and returns the fused module

    When the modules are in training mode they are wrapped in a ``ConvBn2d``
    fused module; otherwise they are handed to ``fuse_conv_bn_eval`` and its
    result is returned.

    Args:
        conv: Module instance of type conv2d
        bn: Spatial BN instance that needs to be fused with the conv

    Returns:
        The fused module.

    Raises:
        AssertionError: if ``conv`` and ``bn`` are not in the same mode.

    Examples::

        >>> m1 = nn.Conv2d(10, 20, 3)
        >>> b1 = nn.BatchNorm2d(20)
        >>> m2 = fuse_conv_bn(m1, b1)
    """
    assert(conv.training == bn.training),\
        "Conv and BN both must be in the same mode (train or eval)."

    if conv.training:
        # Use the module-level ``torch_fused`` alias rather than relying on the
        # ``torch.nn._intrinsic`` attribute being populated as an import side effect.
        return torch_fused.ConvBn2d(conv, bn)
    return torch.nn.utils.fusion.fuse_conv_bn_eval(conv, bn)
def fuse_conv_bn_relu(conv, bn, relu):
    r"""Given the conv, bn and relu modules, fuses them and returns the fused module

    When the modules are in training mode they are wrapped in a
    ``ConvBnReLU2d`` fused module; otherwise the conv and bn are first passed
    through ``fuse_conv_bn_eval`` and the result is wrapped together with
    ``relu`` in a ``ConvReLU2d`` fused module.

    Args:
        conv: Module instance of type conv2d
        bn: Spatial BN instance that needs to be fused with the conv
        relu: ReLU instance that needs to be fused with the conv and bn

    Returns:
        The fused module.

    Raises:
        AssertionError: if ``conv``, ``bn`` and ``relu`` are not all in the
            same mode.

    Examples::

        >>> m1 = nn.Conv2d(10, 20, 3)
        >>> b1 = nn.BatchNorm2d(20)
        >>> r1 = nn.ReLU()
        >>> m2 = fuse_conv_bn_relu(m1, b1, r1)
    """
    assert(conv.training == bn.training == relu.training),\
        "Conv, BN and ReLU all must be in the same mode (train or eval)."

    if conv.training:
        return torch_fused.ConvBnReLU2d(conv, bn, relu)
    return torch_fused.ConvReLU2d(
        torch.nn.utils.fusion.fuse_conv_bn_eval(conv, bn), relu)
def _fuse_modules(model, named_module_dict, modules_to_fuse, fuser_func=None):
    r"""Fuses one group of named sub-modules of ``model`` in place.

    The first named module is replaced by the fused module (unless the fuser
    returned that very same object) and every remaining named module is
    replaced by ``torch.nn.Identity()``.

    Args:
        model: Root module; used as the parent whenever a parent name is not
            present in ``named_module_dict``.
        named_module_dict: Mapping from dotted module name to module.
        modules_to_fuse: Sequence of 2 or 3 dotted module names to fuse.
        fuser_func: Optional callable invoked as ``fuser_func(*modules)`` that
            returns the fused module. When ``None``, a built-in fuser is
            selected from the exact types of the named modules.

    Raises:
        AssertionError: if ``modules_to_fuse`` does not hold 2 or 3 names.
        KeyError: if a name is missing from ``named_module_dict``.
        NotImplementedError: if ``fuser_func`` is ``None`` and no built-in
            fuser exists for the sequence of module types.
    """
    assert(len(modules_to_fuse) == 2 or len(modules_to_fuse) == 3),\
        "Can fuse only 2 or 3 modules."

    mod = []
    # (parent module, attribute name) under which each entry of ``mod`` lives.
    slots = []
    for full_name in modules_to_fuse:
        # 'a.b.c' -> parent 'a.b', attribute 'c'; 'c' -> parent '', attribute 'c'.
        parent_name, _, attr_name = full_name.rpartition('.')
        mod.append(named_module_dict[full_name])
        slots.append((named_module_dict.get(parent_name, model), attr_name))

    if fuser_func is None:
        # The default dispatch table is only needed when no fuser was supplied.
        OP_LIST_TO_FUSER_FUNC = {
            (torch.nn.Conv2d, torch.nn.BatchNorm2d): fuse_conv_bn,
            (torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.ReLU): fuse_conv_bn_relu,
        }
        mod_types = tuple(type(m) for m in mod)
        fuser_func = OP_LIST_TO_FUSER_FUNC.get(mod_types)
        if fuser_func is None:
            raise NotImplementedError("Cannot fuse modules: {}".format(mod_types))
    new_mod = fuser_func(*mod)

    # Assign new_mod to the first slot and set the remaining modules to identity
    first_parent, first_attr = slots[0]
    if new_mod is not mod[0]:
        setattr(first_parent, first_attr, new_mod)
    for parent, attr_name in slots[1:]:
        setattr(parent, attr_name, torch.nn.Identity())
def fuse_modules(model, modules_to_fuse):
    r"""Fuses a list of modules into a single module

    Fuses only the following sequence of modules:

    * conv, bn
    * conv, bn, relu

    Any other sequence of module types raises ``NotImplementedError``.
    For the supported sequences, replaces the first item in the list
    with the fused module, replacing the rest of the modules
    with identity.

    Arguments:
        model: Model containing the modules to be fused
        modules_to_fuse: list of list of module names to fuse. Each inner
            list must contain 2 or 3 names.

    Returns:
        None. Modifies the model in place.

    Examples::

        >>> m = myModel()
        >>> # m is a module containing the sub-modules below
        >>> modules_to_fuse = [['conv1', 'bn1', 'relu1'], ['submodule.conv', 'submodule.bn']]
        >>> fuse_modules(m, modules_to_fuse)
        >>> output = m(input)
    """
    # Snapshot of name -> module taken once, before any group is fused.
    named_module_dict = dict(model.named_modules())
    for module_list in modules_to_fuse:
        _fuse_modules(model, named_module_dict, module_list)