from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import torch
from torch.nn import Conv2d, BatchNorm2d, ReLU
from torch.nn import Parameter
from torch.nn._intrinsic.qat import ConvBn2d, ConvBnReLU2d
from torch.quantization.QConfig import default_qat_qconfig
from torch.utils.mkldnn import disable_mkldnn_conv
from common_quantization import no_deadline
from common_utils import TestCase, run_tests
from hypothesis import given
from hypothesis import strategies as st
from functools import reduce


class IntrinsicQATModuleTest(TestCase):
    # NOTE: Tests in this class are decorated with no_deadline
    # to prevent spurious failures due to cuda runtime initialization.

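    # Each test case draws a random Conv/BN configuration via hypothesis,
    # builds an unfused float Conv2d -> BatchNorm2d (-> ReLU) reference that
    # shares the fused QAT module's parameters, and checks that forward
    # outputs and input gradients agree.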
    @no_deadline
    @given(batch_size=st.integers(1, 3),
           input_channels_per_group=st.sampled_from([2, 4, 5, 8, 16, 32]),
           height=st.integers(10, 16),
           width=st.integers(7, 14),
           output_channels_per_group=st.sampled_from([2, 4, 5, 8, 16, 32]),
           groups=st.integers(1, 3),
           kernel_h=st.integers(1, 7),
           kernel_w=st.integers(1, 7),
           stride_h=st.integers(1, 2),
           stride_w=st.integers(1, 2),
           pad_h=st.integers(0, 2),
           pad_w=st.integers(0, 2),
           dilation=st.integers(1, 1),
           padding_mode=st.sampled_from(['zeros', 'circular']),
           use_relu=st.booleans(),
           eps=st.sampled_from([1e-5, 1e-4, 1e-3, 0.01, 0.1]),
           momentum=st.sampled_from([0.1, 0.2, 0.3]),
           freeze_bn=st.booleans())
    def test_conv_bn_relu(
            self,
            batch_size,
            input_channels_per_group,
            height,
            width,
            output_channels_per_group,
            groups,
            kernel_h,
            kernel_w,
            stride_h,
            stride_w,
            pad_h,
            pad_w,
            dilation,
            padding_mode,
            use_relu,
            eps,
            momentum,
            freeze_bn
    ):
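        # mkldnn conv is disabled below, presumably so that the fused op and
        # the float reference run through the same (non-MKLDNN) conv kernels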
        with disable_mkldnn_conv():
            input_channels = input_channels_per_group * groups
            output_channels = output_channels_per_group * groups
            dilation_h = dilation_w = dilation

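            # float reference: unfused Conv2d -> BatchNorm2d (-> ReLU) modules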
            conv_op = Conv2d(
                input_channels,
                output_channels,
                (kernel_h, kernel_w),
                (stride_h, stride_w),
                (pad_h, pad_w),
                (dilation_h, dilation_w),
                groups,
                False,  # No bias
                padding_mode
            ).to(dtype=torch.float)
            bn_op = BatchNorm2d(output_channels, eps, momentum).to(dtype=torch.float)
            relu_op = ReLU()

            cls = ConvBnReLU2d if use_relu else ConvBn2d
            qat_op = cls(
                input_channels,
                output_channels,
                (kernel_h, kernel_w),
                (stride_h, stride_w),
                (pad_h, pad_w),
                (dilation_h, dilation_w),
                groups,
                padding_mode,
                eps,
                momentum,
                freeze_bn,
                default_qat_qconfig
            ).to(dtype=torch.float).disable_fake_quant()
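            # fake quant is disabled above so the fused module's float
            # numerics can be compared directly against the reference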

            # align inputs and internal parameters
            input = torch.randn(batch_size, input_channels, height, width, dtype=torch.float)
            input.requires_grad_()
            conv_op.weight = Parameter(qat_op.weight)
            bn_op.running_mean = qat_op.running_mean
            bn_op.running_var = qat_op.running_var
            bn_op.weight = qat_op.gamma
            bn_op.bias = qat_op.beta
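            # Parameter(qat_op.weight) and the buffer assignments share the
            # underlying tensors, so both paths compute from identical state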

            def compose(functions):
                # apply left-to-right: compose([f, g, h])(x) == h(g(f(x)))
                return reduce(lambda f, g: lambda x: f(g(x)), functions[::-1], lambda x: x)

            if not use_relu:
                # identity stand-in keeps the reference pipeline uniform
                def relu_op(x):
                    return x

            if freeze_bn:
                # frozen BN normalizes with the tracked running statistics
                # (eval-style): y = (x - mean) * gamma / sqrt(var + eps) + beta
                def ref_op(x):
                    x = conv_op(x)
                    x = (x - bn_op.running_mean.reshape([1, -1, 1, 1])) * \
                        (bn_op.weight / torch.sqrt(bn_op.running_var + bn_op.eps)) \
                        .reshape([1, -1, 1, 1]) + bn_op.bias.reshape([1, -1, 1, 1])
                    x = relu_op(x)
                    return x
            else:
                # unfrozen BN: ordinary training-mode composition
                ref_op = compose([conv_op, bn_op, relu_op])

            # forward: the fused op and the reference must agree exactly
            result_ref = ref_op(input)
            result_actual = qat_op(input)
            self.assertEqual(result_ref, result_actual)

            # backward: both paths must also produce the same input gradient;
            # zero the accumulated grad between the two backward passes
            dout = torch.randn(result_ref.size(), dtype=torch.float)
            result_ref.backward(dout)
            grad_ref = input.grad.clone()
            input.grad.zero_()
            result_actual.backward(dout)
            grad_actual = input.grad.clone()
            self.assertEqual(grad_ref, grad_actual)


if __name__ == '__main__':
    run_tests()