from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from functools import reduce

import torch
from torch.nn import BatchNorm2d, Conv2d, Parameter, ReLU
from torch.nn._intrinsic.qat import ConvBn2d, ConvBnReLU2d
from torch.quantization.QConfig import default_qat_qconfig
from torch.utils.mkldnn import disable_mkldnn_conv

from hypothesis import given
from hypothesis import strategies as st

from common_quantization import no_deadline
from common_utils import TestCase, run_tests


class IntrinsicQATModuleTest(TestCase):
    # NOTE: Tests in this class are decorated with no_deadline
    # to prevent spurious failures due to cuda runtime initialization.
    @no_deadline
    @given(batch_size=st.integers(1, 3),
           input_channels_per_group=st.sampled_from([2, 4, 5, 8, 16, 32]),
           height=st.integers(10, 16),
           width=st.integers(7, 14),
           output_channels_per_group=st.sampled_from([2, 4, 5, 8, 16, 32]),
           groups=st.integers(1, 3),
           kernel_h=st.integers(1, 7),
           kernel_w=st.integers(1, 7),
           stride_h=st.integers(1, 2),
           stride_w=st.integers(1, 2),
           pad_h=st.integers(0, 2),
           pad_w=st.integers(0, 2),
           dilation=st.integers(1, 1),
           padding_mode=st.sampled_from(['zeros', 'circular']),
           use_relu=st.booleans(),
           eps=st.sampled_from([1e-5, 1e-4, 1e-3, 0.01, 0.1]),
           momentum=st.sampled_from([0.1, 0.2, 0.3]),
           freeze_bn=st.booleans())
    def test_conv_bn_relu(
            self,
            batch_size,
            input_channels_per_group,
            height,
            width,
            output_channels_per_group,
            groups,
            kernel_h,
            kernel_w,
            stride_h,
            stride_w,
            pad_h,
            pad_w,
            dilation,
            padding_mode,
            use_relu,
            eps,
            momentum,
            freeze_bn
    ):
        with disable_mkldnn_conv():
            input_channels = input_channels_per_group * groups
            output_channels = output_channels_per_group * groups
            dilation_h = dilation_w = dilation

            conv_op = Conv2d(
                input_channels,
                output_channels,
                (kernel_h, kernel_w),
                (stride_h, stride_w),
                (pad_h, pad_w),
                (dilation_h, dilation_w),
                groups,
                False,  # No bias
                padding_mode
            ).to(dtype=torch.float)
            bn_op = BatchNorm2d(output_channels, eps, momentum).to(dtype=torch.float)
            relu_op = ReLU()

            cls = ConvBnReLU2d if use_relu else ConvBn2d
            qat_op = cls(
                input_channels,
                output_channels,
                (kernel_h, kernel_w),
                (stride_h, stride_w),
                (pad_h, pad_w),
                (dilation_h, dilation_w),
                groups,
                padding_mode,
                eps,
                momentum,
                freeze_bn,
                default_qat_qconfig
            ).to(dtype=torch.float).disable_fake_quant()
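            # Fake quantization is disabled above so the QAT module computes in
            # plain float and its output can be compared directly against the
            # Conv2d/BatchNorm2d reference built below.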
            # align inputs and internal parameters
            input = torch.randn(batch_size, input_channels, height, width, dtype=torch.float)
            input.requires_grad_()
            conv_op.weight = Parameter(qat_op.weight)
            bn_op.running_mean = qat_op.running_mean
            bn_op.running_var = qat_op.running_var
            bn_op.weight = qat_op.gamma
            bn_op.bias = qat_op.beta
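            # These assignments share the underlying tensors, so the reference
            # modules and the fused QAT module see identical weights, running
            # statistics, and affine parameters throughout the test.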
            def compose(functions):
                # functions are reversed for natural reading order
                return reduce(lambda f, g: lambda x: f(g(x)), functions[::-1], lambda x: x)
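            # e.g. compose([f, g, h])(x) evaluates as h(g(f(x))), i.e. f is
            # applied first.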

            if not use_relu:
                def relu_op(x):
                    return x
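            # When BN is frozen, the reference folds the running statistics
            # into a per-channel affine transform:
            #   y = (x - running_mean) * gamma / sqrt(running_var + eps) + beta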
            if freeze_bn:
                def ref_op(x):
                    x = conv_op(x)
                    x = (x - bn_op.running_mean.reshape([1, -1, 1, 1])) * \
                        (bn_op.weight / torch.sqrt(bn_op.running_var + bn_op.eps)).reshape([1, -1, 1, 1]) + \
                        bn_op.bias.reshape([1, -1, 1, 1])
                    x = relu_op(x)
                    return x
            else:
                ref_op = compose([conv_op, bn_op, relu_op])
            result_ref = ref_op(input)
            result_actual = qat_op(input)
            self.assertEqual(result_ref, result_actual)

            # backward: two backward passes through the QAT op should produce
            # the same input gradient.  Since .grad accumulates in place and
            # .cpu() is a no-op for CPU tensors, the gradient must be cloned
            # and reset between the passes, or the assert compares a tensor
            # with itself.
            dout = torch.randn(result_ref.size(), dtype=torch.float)
            result_actual.backward(dout, retain_graph=True)
            grad_ref = input.grad.clone()
            input.grad = None
            result_actual.backward(dout)
            grad_actual = input.grad.clone()
            self.assertEqual(grad_ref, grad_actual)


if __name__ == '__main__':
    run_tests()