from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import numpy as np
import torch
import torch.nn.quantized.functional as qF
from hypothesis import assume, given
from hypothesis import strategies as st
from hypothesis_utils import qtensors_conv
from common_utils import TestCase, run_tests


class FunctionalAPITest(TestCase):
    def _conv_output_shape(self, input_size, kernel_size, padding, stride,
                           dilation):
        """Computes the output shape given convolution parameters."""
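        # This is the standard output-size formula for a dilated convolution,
        # equivalent to floor((in + 2 * pad - dilation * (kernel - 1) - 1) / stride) + 1.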
        return np.floor((input_size + 2 * padding - kernel_size
                         - (kernel_size - 1) * (dilation - 1)) / stride) + 1

    @given(Q=qtensors_conv(min_batch=1, max_batch=3,
                           min_in_channels=1, max_in_channels=7,
                           min_out_channels=1, max_out_channels=7,
                           H_range=(6, 12), W_range=(6, 12),
                           kH_range=(3, 5), kW_range=(3, 5),
                           dtypes=((torch.quint8, np.uint8, 0),),
                           max_groups=4),
           padH=st.integers(1, 3), padW=st.integers(1, 3),
           sH=st.integers(1, 3), sW=st.integers(1, 3),
           dH=st.integers(1, 2), dW=st.integers(1, 2),
           prepacked=st.booleans())
    def test_conv_api(self, Q, padH, padW, sH, sW, dH, dW, prepacked):
        """Tests the correctness of the conv functional.

        Correctness is defined as matching the behavior of the
        `torch.ops.quantized` reference implementation.
        """
        # Random inputs
        X, (scale, zero_point), (qmin, qmax), (torch_type, np_type) = Q
        (inputs, filters, bias, groups) = X

        iC, oC = inputs.shape[1], filters.shape[0]
        iH, iW = inputs.shape[2:]
        kH, kW = filters.shape[2:]
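
        # Discard generated cases where the padding exceeds half the kernel
        # size or where the convolution would produce an empty output.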
        assume(kH // 2 >= padH)
        assume(kW // 2 >= padW)
        oH = self._conv_output_shape(iH, kH, padH, sH, dH)
        assume(oH > 0)
        oW = self._conv_output_shape(iW, kW, padW, sW, dW)
        assume(oW > 0)
        inputs = torch.from_numpy(inputs).to(torch.float)
        filters = torch.from_numpy(filters).to(torch.float)
        bias = torch.from_numpy(bias).to(torch.float)

        kernel_size = (kH, kW)
        stride = (sH, sW)
        i_padding = (padH, padW)
        dilation = (dH, dW)

        # Quantized inputs
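        # The activations are permuted from NCHW to NHWC and the filters from
        # OIHW to a channels-last (RSCK) layout before quantization, since the
        # fbgemm-based reference ops used here work on channels-last data.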
        i_NHWC = inputs.permute([0, 2, 3, 1]).contiguous()
        w_RSCK = filters.permute([0, 2, 3, 1]).contiguous()
        q_inputs = torch.quantize_linear(i_NHWC, scale, zero_point, torch.quint8)
        q_filters = torch.quantize_linear(w_RSCK, scale, zero_point, torch.qint8)
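        # The reference op consumes weights that were packed ahead of time by
        # `fbgemm_conv_prepack`; the bias is quantized to qint32.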
        q_filters_ref = torch.ops.quantized.fbgemm_conv_prepack(q_filters,
                                                                groups)
        q_bias = torch.quantize_linear(bias, scale, zero_point, torch.qint32)

        # Reference op
        ref_op = torch.ops.quantized.fbgemm_conv2d

        # Results check
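        # If the reference op rejects this configuration, the functional API is
        # expected to raise the same error; otherwise the two results must
        # match on their integer representations.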
        try:
            ref_result = ref_op(q_inputs, q_filters_ref, q_bias, stride,
                                i_padding, dilation,
                                groups, scale, zero_point)
        except RuntimeError as e:
            e_msg = str(e).split("\n")[0].split("(")[0].strip()
            np.testing.assert_raises_regex(
                type(e), e_msg, qF.conv2d,
                q_inputs, q_filters_ref, bias=q_bias,
                scale=scale, zero_point=zero_point,
                stride=stride, padding=i_padding, dilation=dilation,
                groups=groups, prepacked=True, dtype=torch_type)
        else:
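            # When testing the prepacked path, pack the weights the same way
            # the reference op does before calling the functional API.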
            if prepacked:
                q_filters = torch.ops.quantized.fbgemm_conv_prepack(q_filters,
                                                                    groups)
            q_result = qF.conv2d(q_inputs, q_filters, bias=q_bias,
                                 scale=scale, zero_point=zero_point,
                                 stride=stride, padding=i_padding,
                                 dilation=dilation, groups=groups,
                                 prepacked=prepacked, dtype=torch_type)
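            # Compare the raw integer representations of the two quantized
            # outputs; both were produced with the same scale and zero point.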
            np.testing.assert_equal(ref_result.int_repr().numpy(),
                                    q_result.int_repr().numpy())

if __name__ == "__main__":
    run_tests()