| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| from __future__ import unicode_literals |
| |
| import unittest |
| import numpy as np |
| import torch |
| import torch.nn.quantized.functional as qF |
| |
| from hypothesis import assume, given |
| from hypothesis import strategies as st |
| import hypothesis_utils as hu |
| |
| from common_quantized import _conv_output_shape |
| from common_utils import TestCase, run_tests |
| |
| @unittest.skipIf( |
| not torch.fbgemm_is_cpu_supported(), |
| " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs" |
| " with instruction set support avx2 or newer.", |
| ) |
| class QuantizedConvTest(TestCase): |
| @given(X=hu.tensor_conv2d(min_batch=1, max_batch=3, |
| min_in_channels=1, max_in_channels=7, |
| min_out_channels=1, max_out_channels=7, |
| H_range=(6, 12), W_range=(6, 12), |
| kH_range=(3, 5), kW_range=(3, 5), |
| max_groups=4, |
| qparams=[hu.qparams(dtypes=torch.quint8, |
| zero_point_min=0, |
| zero_point_max=0), |
| hu.qparams(dtypes=torch.qint8, |
| zero_point_min=0, |
| zero_point_max=0), |
| hu.qparams(dtypes=torch.qint32, |
| zero_point_min=0, |
| zero_point_max=0)]), |
| padH=st.integers(1, 3), padW=st.integers(1, 3), |
| sH=st.integers(1, 3), sW=st.integers(1, 3), |
| dH=st.integers(1, 2), dW=st.integers(1, 2)) |
| def test_conv_api(self, X, padH, padW, sH, sW, dH, dW): |
| """Tests the correctness of the conv functional. |
| |
| The correctness is defined by the behavior being similar to the |
| `quantized._ops` implementation. |
| """ |
| # Random inputs |
| # X, (scale, zero_point, torch_type) = X |
| (inputs, filters, bias, groups) = X |
| inputs, (inputs_scale, inputs_zero_point, inputs_qtype) = inputs |
| filters, (filters_scale, filters_zero_point, filters_qtype) = filters |
| bias, (bias_scale, bias_zero_point, bias_qtype) = bias |
| |
| scale, zero_point = inputs_scale, inputs_zero_point |
| torch_type = inputs_qtype |
| |
| iC, oC = inputs.shape[1], filters.shape[0] |
| |
| iH, iW = inputs.shape[2:] |
| kH, kW = filters.shape[2:] |
| assume(kH // 2 >= padH) |
| assume(kW // 2 >= padW) |
| oH = _conv_output_shape(iH, kH, padH, sH, dH) |
| assume(oH > 0) |
| oW = _conv_output_shape(iW, kW, padW, sW, dW) |
| assume(oW > 0) |
| |
| inputs = torch.from_numpy(inputs).to(torch.float) |
| filters = torch.from_numpy(filters).to(torch.float) |
| bias = torch.from_numpy(bias).to(torch.float) |
| |
| kernel_size = (kH, kW) |
| stride = (sH, sW) |
| i_padding = (padH, padW) |
| dilation = (dH, dW) |
| |
| # Quantized inputs |
| q_inputs = torch.quantize_linear(inputs, inputs_scale, |
| inputs_zero_point, inputs_qtype) |
| q_filters = torch.quantize_linear(filters, filters_scale, |
| filters_zero_point, filters_qtype) |
| q_bias = torch.quantize_linear(bias, bias_scale, bias_zero_point, |
| bias_qtype) |
| |
| # Reference op |
| ref_op = torch.ops.quantized.fbgemm_conv2d |
| |
| # Results check |
| try: |
| q_filters_ref = torch.ops.quantized.fbgemm_conv_prepack(q_filters.permute([0, 2, 3, 1]), |
| stride, |
| i_padding, |
| dilation, |
| groups) |
| ref_result = ref_op(q_inputs.permute([0, 2, 3, 1]), q_filters_ref, |
| q_bias, stride, |
| i_padding, dilation, |
| groups, scale, zero_point).permute([0, 3, 1, 2]) |
| except RuntimeError as e: |
| e_msg = str(e).split("\n")[0].split("(")[0].strip() |
| np.testing.assert_raises_regex( |
| type(e), e_msg, qF.conv2d, |
| q_inputs, q_filters, bias=q_bias, |
| scale=scale, zero_point=zero_point, |
| stride=stride, padding=i_padding, dilation=dilation, |
| groups=groups, dtype=torch_type) |
| else: |
| q_result = qF.conv2d(q_inputs, |
| q_filters, |
| bias=q_bias, scale=scale, |
| zero_point=zero_point, |
| stride=stride, padding=i_padding, |
| dilation=dilation, groups=groups, |
| dtype=torch_type) |
| |
| self.assertEqual(ref_result, q_result) |
| |
| if __name__ == "__main__": |
| run_tests() |