blob: 4144c07441045494ed169f830eec54ee15098489 [file] [log] [blame]
import math
import torch
import torch.nn as nn
from torch.nn import Conv2d, BatchNorm2d, ReLU, init
from torch.nn.intrinsic.qat import ConvBn2d, ConvBnReLU2d
from torch.nn.modules.utils import _pair
from torch.quantization.qconfig import default_qat_qconfig
import torch.backends.mkldnn
from torch.testing._internal.common_utils import TestCase
from hypothesis import given
from hypothesis import strategies as st
import torch.testing._internal.hypothesis_utils as hu
hu.assert_deadline_disabled()
from functools import reduce
class _ReferenceConvBnNd(torch.nn.Conv2d, torch.nn.modules.conv._ConvNd):
"""
Conv-BN fusion implemented with explicit folding. Useful
to verify numerical equivalency with non-folded version.
"""
def __init__(self,
# ConvNd args
in_channels, out_channels, kernel_size, stride,
padding, dilation, transposed, output_padding,
groups,
bias,
padding_mode,
# BatchNormNd args
# num_features: out_channels
eps=1e-05, momentum=0.1,
# affine: True
# track_running_stats: True
# Args for this module
freeze_bn=False,
qconfig=None):
nn.modules.conv._ConvNd.__init__(self, in_channels, out_channels, kernel_size,
stride, padding, dilation, transposed,
output_padding, groups, False, padding_mode)
assert qconfig, 'qconfig must be provided for QAT module'
self.qconfig = qconfig
self.eps = eps
self.momentum = momentum
self.freeze_bn = freeze_bn if self.training else True
self.num_features = out_channels
self.gamma = nn.Parameter(torch.Tensor(out_channels))
self.beta = nn.Parameter(torch.Tensor(out_channels))
self.affine = True
self.track_running_stats = True
self.register_buffer('running_mean', torch.zeros(out_channels))
self.register_buffer('running_var', torch.ones(out_channels))
self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
self.activation_post_process = self.qconfig.activation()
self.weight_fake_quant = self.qconfig.weight()
if bias:
self.bias = nn.Parameter(torch.Tensor(out_channels))
else:
self.register_parameter('bias', None)
self.reset_bn_parameters()
def reset_running_stats(self):
self.running_mean.zero_()
self.running_var.fill_(1)
self.num_batches_tracked.zero_()
def reset_bn_parameters(self):
self.reset_running_stats()
init.uniform_(self.gamma)
init.zeros_(self.beta)
if self.bias is not None:
fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
bound = 1 / math.sqrt(fan_in)
init.uniform_(self.bias, -bound, bound)
def reset_parameters(self):
super(_ReferenceConvBnNd, self).reset_parameters()
# A hack to avoid resetting on undefined parameters
if hasattr(self, 'gamma'):
self.reset_bn_parameters()
def update_bn_stats(self):
self.freeze_bn = False
return self
def freeze_bn_stats(self):
self.freeze_bn = True
return self
def _forward(self, input):
# exponential_average_factor is self.momentum set to
# (when it is available) only so that if gets updated
# in ONNX graph when this node is exported to ONNX.
if self.momentum is None:
exponential_average_factor = 0.0
else:
exponential_average_factor = self.momentum
if self.training and not self.freeze_bn and self.track_running_stats:
# TODO: if statement only here to tell the jit to skip emitting this when it is None
if self.num_batches_tracked is not None:
self.num_batches_tracked += 1
if self.momentum is None: # use cumulative moving average
exponential_average_factor = 1.0 / float(self.num_batches_tracked)
else: # use exponential moving average
exponential_average_factor = self.momentum
# we use running statistics from the previous batch, so this is an
# approximation of the approach mentioned in the whitepaper, but we only
# need to do one convolution in this case instead of two
running_std = torch.sqrt(self.running_var + self.eps)
scale_factor = self.gamma / running_std
scaled_weight = self.weight * scale_factor.reshape([-1, 1, 1, 1])
conv = self._conv_forward(input, self.weight_fake_quant(scaled_weight))
if self.training and not self.freeze_bn:
# recovering original conv to get original batch_mean and batch_var
if self.bias is not None:
conv_orig = conv / scale_factor.reshape([1, -1, 1, 1]) + self.bias.reshape([1, -1, 1, 1])
else:
conv_orig = conv / scale_factor.reshape([1, -1, 1, 1])
batch_mean = torch.mean(conv_orig, dim=[0, 2, 3])
batch_var = torch.var(conv_orig, dim=[0, 2, 3], unbiased=False)
n = float(conv_orig.numel() / conv_orig.size()[1])
unbiased_batch_var = batch_var * (n / (n - 1))
batch_rstd = torch.ones_like(batch_var, memory_format=torch.contiguous_format) / torch.sqrt(batch_var + self.eps)
conv = (self.gamma * batch_rstd).reshape([1, -1, 1, 1]) * conv_orig + \
(self.beta - self.gamma * batch_rstd * batch_mean).reshape([1, -1, 1, 1])
self.running_mean = exponential_average_factor * batch_mean.detach() + \
(1 - exponential_average_factor) * self.running_mean
self.running_var = exponential_average_factor * unbiased_batch_var.detach() + \
(1 - exponential_average_factor) * self.running_var
else:
if self.bias is None:
conv = conv + (self.beta - self.gamma * self.running_mean /
running_std).reshape([1, -1, 1, 1])
else:
conv = conv + (self.gamma * (self.bias - self.running_mean) / running_std + self.beta).reshape([1, -1, 1, 1])
return conv
def extra_repr(self):
# TODO(jerryzh): extend
return super(_ReferenceConvBnNd, self).extra_repr()
def forward(self, input):
return self.activation_post_process(self._forward(input))
@classmethod
def from_float(cls, mod, qconfig=None):
r"""Create a qat module from a float module or qparams_dict
Args: `mod` a float module, either produced by torch.quantization utilities
or directly from user
"""
assert type(mod) == cls._FLOAT_MODULE, 'qat.' + cls.__name__ + '.from_float only works for ' + \
cls._FLOAT_MODULE.__name__
if not qconfig:
assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
assert mod.qconfig, 'Input float module must have a valid qconfig'
qconfig = mod.qconfig
conv, bn = mod[0], mod[1]
qat_convbn = cls(conv.in_channels, conv.out_channels, conv.kernel_size,
conv.stride, conv.padding, conv.dilation,
conv.groups, conv.bias is not None,
conv.padding_mode,
bn.eps, bn.momentum,
False,
qconfig)
qat_convbn.weight = conv.weight
qat_convbn.bias = conv.bias
qat_convbn.gamma = bn.weight
qat_convbn.beta = bn.bias
qat_convbn.running_mean = bn.running_mean
qat_convbn.running_var = bn.running_var
qat_convbn.num_batches_tracked = bn.num_batches_tracked
return qat_convbn
class _ReferenceConvBn2d(_ReferenceConvBnNd, nn.Conv2d):
_FLOAT_MODULE = torch.nn.intrinsic.ConvBn2d
def __init__(self,
# ConvNd args
in_channels, out_channels, kernel_size, stride=1,
padding=0, dilation=1, groups=1,
bias=None,
padding_mode='zeros',
# BatchNorm2d args
# num_features: out_channels
eps=1e-05, momentum=0.1,
# affine: True
# track_running_stats: True
# Args for this module
freeze_bn=False,
qconfig=None):
kernel_size = _pair(kernel_size)
stride = _pair(stride)
padding = _pair(padding)
dilation = _pair(dilation)
_ReferenceConvBnNd.__init__(self, in_channels, out_channels, kernel_size, stride,
padding, dilation, False, _pair(0), groups, bias, padding_mode,
eps, momentum, freeze_bn, qconfig)
class TestQATModule(TestCase):
@given(batch_size=st.integers(2, 4),
input_channels_per_group=st.sampled_from([2, 3, 4]),
height=st.integers(5, 10),
width=st.integers(5, 10),
output_channels_per_group=st.sampled_from([2, 3]),
groups=st.integers(1, 3),
kernel_h=st.integers(1, 3),
kernel_w=st.integers(1, 3),
stride_h=st.integers(1, 2),
stride_w=st.integers(1, 2),
pad_h=st.integers(0, 2),
pad_w=st.integers(0, 2),
dilation=st.integers(1, 1),
padding_mode=st.sampled_from(['zeros', 'circular']),
use_relu=st.booleans(),
eps=st.sampled_from([1e-5, 1e-4, 1e-3]),
momentum=st.sampled_from([0.1, 0.2, 0.3]),
freeze_bn=st.booleans())
def test_conv_bn_relu(
self,
batch_size,
input_channels_per_group,
height,
width,
output_channels_per_group,
groups,
kernel_h,
kernel_w,
stride_h,
stride_w,
pad_h,
pad_w,
dilation,
padding_mode,
use_relu,
eps,
momentum,
freeze_bn
):
# **** WARNING: This is used to temporarily disable MKL-DNN convolution due
# to a bug: https://github.com/pytorch/pytorch/issues/23825
# Once this bug is fixed, this context manager as well as its callsites
# should be removed!
with torch.backends.mkldnn.flags(enabled=False):
input_channels = input_channels_per_group * groups
output_channels = output_channels_per_group * groups
dilation_h = dilation_w = dilation
conv_op = Conv2d(
input_channels,
output_channels,
(kernel_h, kernel_w),
(stride_h, stride_w),
(pad_h, pad_w),
(dilation_h, dilation_w),
groups,
False, # No bias
padding_mode
).to(dtype=torch.double)
bn_op = BatchNorm2d(output_channels, eps, momentum).to(dtype=torch.double)
relu_op = ReLU()
cls = ConvBnReLU2d if use_relu else ConvBn2d
qat_op = cls(
input_channels,
output_channels,
(kernel_h, kernel_w),
(stride_h, stride_w),
(pad_h, pad_w),
(dilation_h, dilation_w),
groups,
None, # bias
padding_mode,
eps,
momentum,
freeze_bn=True,
qconfig=default_qat_qconfig
).to(dtype=torch.double)
qat_op.apply(torch.quantization.disable_fake_quant)
if freeze_bn:
qat_op.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
else:
qat_op.apply(torch.nn.intrinsic.qat.update_bn_stats)
# align inputs and internal parameters
input = torch.randn(batch_size, input_channels, height, width, dtype=torch.double, requires_grad=True)
conv_op.weight = torch.nn.Parameter(qat_op.weight.detach())
bn_op.running_mean = qat_op.bn.running_mean.clone()
bn_op.running_var = qat_op.bn.running_var.clone()
bn_op.weight = torch.nn.Parameter(qat_op.bn.weight.detach())
bn_op.bias = torch.nn.Parameter(qat_op.bn.bias.detach())
def compose(functions):
# functions are reversed for natural reading order
return reduce(lambda f, g: lambda x: f(g(x)), functions[::-1], lambda x: x)
if not use_relu:
def relu_op(x):
return x
if freeze_bn:
def ref_op(x):
x = conv_op(x)
x = (x - bn_op.running_mean.reshape([1, -1, 1, 1])) * \
(bn_op.weight / torch.sqrt(bn_op.running_var + bn_op.eps)) \
.reshape([1, -1, 1, 1]) + bn_op.bias.reshape([1, -1, 1, 1])
x = relu_op(x)
return x
else:
ref_op = compose([conv_op, bn_op, relu_op])
input_clone = input.clone().detach().requires_grad_()
for i in range(2):
result_ref = ref_op(input)
result_actual = qat_op(input_clone)
self.assertEqual(result_ref, result_actual)
# backward
dout = torch.randn(result_ref.size(), dtype=torch.double)
loss = (result_ref - dout).sum()
loss.backward()
input_grad_ref = input.grad.cpu()
weight_grad_ref = conv_op.weight.grad.cpu()
gamma_grad_ref = bn_op.weight.grad.cpu()
beta_grad_ref = bn_op.bias.grad.cpu()
running_mean_ref = bn_op.running_mean
running_var_ref = bn_op.running_var
num_batches_tracked_ref = bn_op.num_batches_tracked
loss = (result_actual - dout).sum()
loss.backward()
input_grad_actual = input_clone.grad.cpu()
weight_grad_actual = qat_op.weight.grad.cpu()
gamma_grad_actual = qat_op.bn.weight.grad.cpu()
beta_grad_actual = qat_op.bn.bias.grad.cpu()
running_mean_actual = qat_op.bn.running_mean
running_var_actual = qat_op.bn.running_var
num_batches_tracked_actual = qat_op.bn.num_batches_tracked
precision = 1e-10
self.assertEqual(input_grad_ref, input_grad_actual, atol=precision, rtol=0)
self.assertEqual(weight_grad_ref, weight_grad_actual, atol=precision, rtol=0)
self.assertEqual(gamma_grad_ref, gamma_grad_actual, atol=precision, rtol=0)
self.assertEqual(beta_grad_ref, beta_grad_actual, atol=precision, rtol=0)
self.assertEqual(num_batches_tracked_ref, num_batches_tracked_actual, atol=precision, rtol=0)
self.assertEqual(running_mean_ref, running_mean_actual, atol=precision, rtol=0)
self.assertEqual(running_var_ref, running_var_actual, atol=precision, rtol=0)
@given(batch_size=st.integers(2, 4),
input_channels_per_group=st.sampled_from([2, 3, 4]),
height=st.integers(5, 10),
width=st.integers(5, 10),
output_channels_per_group=st.sampled_from([2, 3]),
groups=st.integers(1, 3),
kernel_h=st.integers(1, 3),
kernel_w=st.integers(1, 3),
stride_h=st.integers(1, 2),
stride_w=st.integers(1, 2),
pad_h=st.integers(0, 2),
pad_w=st.integers(0, 2),
dilation=st.integers(1, 1),
padding_mode=st.sampled_from(['zeros', 'circular']),
eps=st.sampled_from([1e-5, 1e-4, 1e-3]),
momentum=st.sampled_from([0.1, 0.2, 0.3]),
freeze_bn=st.booleans(),
bias=st.booleans())
def test_conv_bn_folded_vs_unfolded(
self,
batch_size,
input_channels_per_group,
height,
width,
output_channels_per_group,
groups,
kernel_h,
kernel_w,
stride_h,
stride_w,
pad_h,
pad_w,
dilation,
padding_mode,
eps,
momentum,
freeze_bn,
bias,
):
# **** WARNING: This is used to temporarily disable MKL-DNN convolution due
# to a bug: https://github.com/pytorch/pytorch/issues/23825
# Once this bug is fixed, this context manager as well as its callsites
# should be removed!
with torch.backends.mkldnn.flags(enabled=False):
input_channels = input_channels_per_group * groups
output_channels = output_channels_per_group * groups
dilation_h = dilation_w = dilation
qat_op = ConvBn2d(
input_channels,
output_channels,
(kernel_h, kernel_w),
(stride_h, stride_w),
(pad_h, pad_w),
(dilation_h, dilation_w),
groups,
bias, # bias
padding_mode,
eps,
momentum,
freeze_bn=freeze_bn,
qconfig=default_qat_qconfig
).to(dtype=torch.double)
qat_ref_op = _ReferenceConvBn2d(
input_channels,
output_channels,
(kernel_h, kernel_w),
(stride_h, stride_w),
(pad_h, pad_w),
(dilation_h, dilation_w),
groups,
bias, # bias
padding_mode,
eps,
momentum,
freeze_bn=freeze_bn,
qconfig=default_qat_qconfig
).to(dtype=torch.double)
qat_op.apply(torch.quantization.disable_fake_quant)
qat_ref_op.apply(torch.quantization.disable_fake_quant)
# align inputs and internal parameters
qat_ref_op.weight = torch.nn.Parameter(qat_op.weight.detach().clone())
qat_ref_op.running_mean = qat_op.bn.running_mean.clone()
qat_ref_op.running_var = qat_op.bn.running_var.clone()
qat_ref_op.gamma = torch.nn.Parameter(qat_op.bn.weight.detach().clone())
qat_ref_op.beta = torch.nn.Parameter(qat_op.bn.bias.detach().clone())
if qat_op.bias is not None:
qat_ref_op.bias = torch.nn.Parameter(qat_op.bias.detach().clone())
lr = 0.01
qat_op_optim = torch.optim.SGD(qat_op.parameters(), lr=lr)
qat_ref_op_optim = torch.optim.SGD(qat_ref_op.parameters(), lr=lr)
for i in range(5):
# make sure that calling model.train() does not override the
# bn freeze setting
qat_op.train()
qat_ref_op.train()
qat_op_optim.zero_grad()
qat_ref_op_optim.zero_grad()
input = torch.randn(batch_size, input_channels, height, width, dtype=torch.double, requires_grad=True)
input_clone = input.clone().detach().requires_grad_()
if i > 2:
qat_op.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
qat_ref_op.freeze_bn_stats()
if i > 3:
qat_op.apply(torch.quantization.disable_observer)
qat_ref_op.apply(torch.quantization.disable_observer)
result_ref = qat_ref_op(input)
result_actual = qat_op(input_clone)
self.assertEqual(result_ref, result_actual)
# backward
dout = torch.randn(result_ref.size(), dtype=torch.double) + 10.0
loss = (result_ref - dout).sum()
loss.backward()
input_grad_ref = input.grad.cpu()
weight_grad_ref = qat_ref_op.weight.grad.cpu()
gamma_grad_ref = qat_ref_op.gamma.grad.cpu()
beta_grad_ref = qat_ref_op.beta.grad.cpu()
running_mean_ref = qat_ref_op.running_mean
running_var_ref = qat_ref_op.running_var
num_batches_tracked_ref = qat_ref_op.num_batches_tracked
loss = (result_actual - dout).sum()
loss.backward()
input_grad_actual = input_clone.grad.cpu()
weight_grad_actual = qat_op.weight.grad.cpu()
gamma_grad_actual = qat_op.bn.weight.grad.cpu()
beta_grad_actual = qat_op.bn.bias.grad.cpu()
running_mean_actual = qat_op.bn.running_mean
running_var_actual = qat_op.bn.running_var
num_batches_tracked_actual = qat_op.bn.num_batches_tracked
precision = 1e-5
self.assertEqual(input_grad_ref, input_grad_actual, atol=precision, rtol=0)
self.assertEqual(weight_grad_ref, weight_grad_actual, atol=precision, rtol=0)
self.assertEqual(gamma_grad_ref, gamma_grad_actual, atol=precision, rtol=0)
self.assertEqual(beta_grad_ref, beta_grad_actual, atol=precision, rtol=0)
self.assertEqual(num_batches_tracked_ref, num_batches_tracked_actual, atol=precision, rtol=0)
self.assertEqual(running_mean_ref, running_mean_actual, atol=precision, rtol=0)
self.assertEqual(running_var_ref, running_var_actual, atol=precision, rtol=0)
qat_op_optim.step()
qat_ref_op_optim.step()