blob: f9f02e1d782b06c923d744e4d63baa4854893d36 [file] [log] [blame]
from __future__ import absolute_import, division, print_function, unicode_literals
import copy
import torch
def fuse_conv_bn_eval(conv, bn):
    """Fold an eval-mode batch-norm layer into the preceding convolution.

    Returns a deep copy of ``conv`` whose weight and bias are rescaled so that
    ``fused(x)`` matches ``bn(conv(x))`` when both modules are in eval mode
    (i.e. ``bn`` normalizes with its running statistics):

        scale = bn.weight / sqrt(bn.running_var + bn.eps)
        W'    = W * scale                      (per output channel, dim 0)
        b'    = (b - bn.running_mean) * scale + bn.bias

    The per-channel scale is reshaped to match ``conv.weight.dim()``, so 1-D,
    2-D and 3-D convolutions are all handled.  A ``bn`` without learnable
    parameters (``bn.weight``/``bn.bias`` is ``None``) is treated as having
    weight 1 and bias 0.  If ``conv`` has no bias, a zero bias is assumed
    before folding, and the fused module always ends up with a bias.

    Args:
        conv: convolution module.  Assumes the output-channel axis of
            ``conv.weight`` is dim 0 (true for regular ConvNd; transposed
            convolutions lay their weights out differently and are not
            handled here).
        bn: batch-norm module over ``conv``'s output channels.  Its
            ``running_mean`` / ``running_var`` buffers must be present.

    Returns:
        A new convolution module; ``conv`` and ``bn`` are not modified.

    Raises:
        AssertionError: if ``conv`` or ``bn`` is in training mode.
    """
    # Explicit raise instead of ``assert`` so the check is not stripped under
    # ``python -O``; AssertionError is kept for callers that already catch it.
    if conv.training or bn.training:
        raise AssertionError("Fusion only for eval!")

    fused_conv = copy.deepcopy(conv)

    w_conv = fused_conv.weight
    b_conv = fused_conv.bias

    bn_mean = bn.running_mean
    bn_var_sqrt = torch.sqrt(bn.running_var + bn.eps)
    # BatchNorm built with affine=False has no learnable weight/bias;
    # fall back to the identity affine transform.
    bn_weight = bn.weight if bn.weight is not None else torch.ones_like(bn_mean)
    bn_bias = bn.bias if bn.bias is not None else torch.zeros_like(bn_mean)

    if b_conv is None:
        b_conv = bn_mean.new_zeros(bn_mean.shape)

    # Shape the per-channel scale as (C_out, 1, ..., 1) with the same rank as
    # the weight, so it broadcasts along dim 0 for any conv dimensionality.
    scale_shape = [-1] + [1] * (w_conv.dim() - 1)
    w_conv = w_conv * (bn_weight / bn_var_sqrt).reshape(scale_shape)
    b_conv = (b_conv - bn_mean) / bn_var_sqrt * bn_weight + bn_bias

    fused_conv.weight = torch.nn.Parameter(w_conv)
    fused_conv.bias = torch.nn.Parameter(b_conv)

    return fused_conv