Quantized Conv2d operator (#20772)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/20772

Copy of D15178352

A conflicting commit landed at the same time as D15178352 that removed support for registering kernels using IntArrayRef; hence, D15178352 was reverted. This diff uses std::vector instead.
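
For context, a minimal sketch of what the change means for kernel signatures (illustrative only; ToyKernel is hypothetical and not part of this diff; the real kernels are in qconv.cpp below):

    #include <cstdint>
    #include <vector>

    struct ToyKernel {
      // Previously declared as: int64_t operator()(at::IntArrayRef stride);
      int64_t operator()(const std::vector<int64_t>& stride) {
        // Element access is the same as it was with IntArrayRef.
        return stride.at(0);
      }
    };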

Reviewed By: zafartahirov

Differential Revision: D15437237

fbshipit-source-id: cd2f1caebcc720352b48ce25d716cb1ca49a5197
diff --git a/aten/src/ATen/native/quantized/cpu/fbgemm_utils.h b/aten/src/ATen/native/quantized/cpu/fbgemm_utils.h
index 0569296..498ce8a 100644
--- a/aten/src/ATen/native/quantized/cpu/fbgemm_utils.h
+++ b/aten/src/ATen/native/quantized/cpu/fbgemm_utils.h
@@ -20,6 +20,14 @@
   int w_zp;
 };
 
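+// Prepacked weight matrix plus the metadata needed at conv time: column
+// offsets for requantization, the spatial kernel shape {kernel_h, kernel_w},
+// and the weight quantization parameters (scale and zero point).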
+struct FBGEMM_API PackedConvWeight {
+  std::unique_ptr<fbgemm::PackBMatrix<int8_t>> w;
+  std::vector<int32_t> col_offsets;
+  std::vector<int32_t> kernel;
+  float w_scale;
+  int32_t w_zp;
+};
+
 // Convert the weight from uint8 to int8.
 static void convert_uint8_int8(
     int K,
diff --git a/aten/src/ATen/native/quantized/cpu/qconv.cpp b/aten/src/ATen/native/quantized/cpu/qconv.cpp
new file mode 100644
index 0000000..7675a1c
--- /dev/null
+++ b/aten/src/ATen/native/quantized/cpu/qconv.cpp
@@ -0,0 +1,156 @@
+#include <ATen/ATen.h>
+#include <ATen/core/Type.h>
+#include <ATen/core/op_registration/op_registration.h>
+#include <ATen/cpp_custom_type_hack.h>
+#include <ATen/native/quantized/cpu/fbgemm_utils.h>
+#include <ATen/quantized/Quantizer.h>
+
+namespace at {
+namespace native {
+namespace {
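+// Quantized 2D convolution kernel built on fbgemm's im2col + GEMM path:
+// uint8 activations and prepacked int8 weights produce a uint8 output,
+// requantized to the given output scale and zero point.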
+class QConv2dInt8 final : public c10::OperatorKernel {
+ public:
+#ifdef USE_FBGEMM
+  Tensor operator()(
+      Tensor act,
+      Tensor packed_weight,
+      Tensor bias,
+      const std::vector<int64_t>& stride,
+      const std::vector<int64_t>& padding,
+      const std::vector<int64_t>& dilation,
+      const std::vector<int64_t>& output_padding,
+      int64_t groups,
+      double output_scale,
+      int64_t output_zero_point) {
+    TORCH_CHECK(
+        fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM.");
+    TORCH_CHECK(
+        act.ndimension() == 4,
+        "Activations are supposed to have 4 dimensions.");
+    TORCH_CHECK(stride.size() == 2, "2D convolution only");
+    TORCH_CHECK(padding.size() == 2, "2D convolution only");
+    TORCH_CHECK(dilation.size() == 2, "2D convolution only");
+    TORCH_CHECK(output_padding.size() == 2, "2D convolution only");
+    TORCH_CHECK(
+        (dilation[0] == 1 && dilation[1] == 1),
+        "Currently dilation should be 1");
+    TORCH_CHECK(
+        (output_padding[0] == 0 && output_padding[1] == 0),
+        "Currently output padding should be 0");
+
+    // inputs are in NHWC format
+    int N = act.size(0);
+    int H = act.size(1);
+    int W = act.size(2);
+    int C = act.size(3);
+    int K = bias.size(0);
+
+    Tensor act_contig = act.contiguous();
+    const uint8_t* act_ptr =
+        reinterpret_cast<uint8_t*>(act_contig.data<c10::quint8>());
+
+    PackedConvWeight& pack_ptr =
+        cpp_custom_type_hack::cast<PackedConvWeight>(packed_weight);
+    auto packB = pack_ptr.w.get();
+    auto& col_offsets = pack_ptr.col_offsets;
+    auto& kernel = pack_ptr.kernel;
+
+    std::vector<int32_t> row_offset_buf(
+        fbgemm::PackAWithIm2Col<uint8_t>::rowOffsetBufferSize());
+
+    int pad_t = padding[0];
+    int pad_l = padding[1];
+    int stride_h = stride[0];
+    int stride_w = stride[1];
+    int kernel_h = kernel[0];
+    int kernel_w = kernel[1];
+
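+    // fbgemm's conv_param_t expects padding as {top, left, bottom, right}.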
+    fbgemm::conv_param_t<> conv_p(
+        N, // Batch size
+        C, // Number of input channels
+        K, // Number of output channels
+        {H, W},
+        groups,
+        {kernel_h, kernel_w},
+        {stride_h, stride_w},
+        {pad_t, pad_l, pad_t, pad_l});
+
+    fbgemm::PackAWithIm2Col<uint8_t> packA(
+        conv_p,
+        act_ptr,
+        nullptr,
+        act.q_zero_point().toInt(),
+        row_offset_buf.data());
+
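+    // DoNothing is the identity "next op" in the output pipeline; the actual
+    // requantization is done by the ReQuantizeOutput stage below.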
+    fbgemm::DoNothing<> NoOpObj{};
+
+    auto bias_contig = bias.contiguous();
+
+    float act_scale = act.q_scale().toFloat();
+    int32_t act_zero_point = act.q_zero_point().toInt();
+
+    float weight_scale_float = pack_ptr.w_scale;
+    int32_t weight_zero_point_int32 = pack_ptr.w_zp;
+
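+    // Requantization scales the int32 accumulator back into uint8 range:
+    // multiplier = act_scale * weight_scale / output_scale, and the output
+    // zero point is added after rounding.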
+    float output_multiplier_float =
+        (act_scale * weight_scale_float) / static_cast<float>(output_scale);
+
+    fbgemm::ReQuantizeOutput<false> outputProcObj(
+        NoOpObj,
+        &output_multiplier_float,
+        output_zero_point,
+        act_zero_point,
+        &weight_zero_point_int32,
+        packA.getRowOffsetBuffer(),
+        col_offsets.data(),
+        bias_contig.data<int32_t>(),
+        K,
+        groups);
+
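+    // Allocate the NHWC output; the output spatial dims come from conv_p,
+    // which computes them from the input size, kernel, stride, and padding.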
+    Tensor output = _empty_affine_quantized(
+        {N, conv_p.OUT_DIM[0], conv_p.OUT_DIM[1], K},
+        device(kCPU).dtype(kQUInt8),
+        output_scale,
+        output_zero_point);
+    auto buffer = at::zeros_like(output, output.options().dtype(at::kInt));
+
+    // Do the GEMM
+    fbgemm::fbgemmPacked(
+        packA,
+        *packB,
+        reinterpret_cast<uint8_t*>(output.data<c10::quint8>()),
+        buffer.data<int32_t>(),
+        K,
+        outputProcObj,
+        0 /* thread_id*/,
+        1 /* num_threads */);
+
+    return output;
+  }
+#else // USE_FBGEMM
+  Tensor operator()(
+      Tensor /* activation */,
+      Tensor /* packed_weight */,
+      Tensor /* bias */,
+      const std::vector<int64_t>& /* stride */,
+      const std::vector<int64_t>& /* padding */,
+      const std::vector<int64_t>& /* dilation */,
+      const std::vector<int64_t>& /* output_padding */,
+      int64_t /* groups */,
+      double /* output_scale */,
+      int64_t /* output_zero_point */) {
+    TORCH_CHECK(
+        false, "This PyTorch installation was not built with FBGEMM operators");
+  }
+#endif // USE_FBGEMM
+};
+
+static auto registry = c10::RegisterOperators().op(
+    "quantized::fbgemm_conv2d",
+    c10::RegisterOperators::options().kernel<QConv2dInt8>().dispatchKey(
+        QuantizedCPUTensorId()));
+
+} // namespace
+} // namespace native
+} // namespace at
diff --git a/aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp b/aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp
new file mode 100644
index 0000000..f009a6a
--- /dev/null
+++ b/aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp
@@ -0,0 +1,76 @@
+#include <ATen/ATen.h>
+#include <ATen/core/Type.h>
+#include <ATen/core/op_registration/op_registration.h>
+#include <ATen/cpp_custom_type_hack.h>
+#include <ATen/native/quantized/cpu/fbgemm_utils.h>
+#include <ATen/quantized/Quantizer.h>
+
+namespace caffe2 {
+#ifdef USE_FBGEMM
+// Required for cpp_custom_type_hack to work
+CAFFE_KNOWN_TYPE(PackedConvWeight);
+#endif
+} // namespace caffe2
+
+namespace at {
+namespace native {
+namespace {
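+// Prepacks quantized conv weights into fbgemm's PackBMatrix layout so that
+// the packing cost is paid once and amortized across conv calls.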
+class QConvPackWeightInt8 final : public c10::OperatorKernel {
+ public:
+#ifdef USE_FBGEMM
+  Tensor operator()(Tensor weight, int64_t groups) {
+    TORCH_CHECK(
+        weight.ndimension() == 4, "Weights are expected to have 4 dimensions");
+    TORCH_CHECK(groups == 1, "Groupwise convolutions are not supported yet");
+    // Weights are in RS(C/G)K layout. NDim and KDim are the B-matrix
+    // dimensions after im2col.
+    int NDim = weight.size(3) / groups;
+    int KDim = weight.size(0) * weight.size(1) * groups * weight.size(2);
+    auto weight_contig = weight.contiguous();
+    int weight_zero_point_int32 = weight.q_zero_point().toInt();
+    TORCH_CHECK(
+        weight_zero_point_int32 == 0,
+        "Only symmetric quantization is supported for weights yet");
+    const int8_t* weight_ptr_int8 =
+        reinterpret_cast<int8_t*>(weight_contig.data<c10::quint8>());
+
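+    // Note: col_offsets is zero-initialized here; the column offsets only
+    // influence the result when the activation zero point is nonzero.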
+    std::vector<int32_t> col_offsets(NDim * groups);
+    std::vector<int32_t> kernel{static_cast<int>(weight.size(0)),
+                                static_cast<int>(weight.size(1))};
+    auto ret_ptr = guts::make_unique<PackedConvWeight>(
+        PackedConvWeight{guts::make_unique<fbgemm::PackBMatrix<int8_t>>(
+                             fbgemm::matrix_op_t::NoTranspose,
+                             KDim,
+                             NDim,
+                             weight_ptr_int8,
+                             NDim,
+                             nullptr, // PackBMatrix manages ownership of pmat
+                             groups),
+                         col_offsets,
+                         kernel,
+                         weight.q_scale().toFloat(),
+                         weight_zero_point_int32});
+    // TODO: Replace this with TorchScript classes at a later point.
+    return cpp_custom_type_hack::create(std::move(ret_ptr), weight.options());
+  }
+#else // USE_FBGEMM
+  Tensor operator()(
+      Tensor, /* weight */
+      int64_t /* groups */
+  ) {
+    TORCH_CHECK(
+        false, "This PyTorch installation was not built with FBGEMM operators");
+  }
+#endif // USE_FBGEMM
+};
+
+static auto registry = c10::RegisterOperators().op(
+    "quantized::fbgemm_conv_prepack",
+    c10::RegisterOperators::options().kernel<QConvPackWeightInt8>().dispatchKey(
+        QuantizedCPUTensorId()));
+
+} // namespace
+} // namespace native
+} // namespace at
diff --git a/test/test_quantized.py b/test/test_quantized.py
index 1d27521..8553265 100644
--- a/test/test_quantized.py
+++ b/test/test_quantized.py
@@ -25,6 +25,14 @@
     return x
 
 
+def _requantize(x, multiplier, zero_point, qmin=0, qmax=255, qtype=np.uint8):
+    """Requantizes a numpy array, i.e., intermediate int32 or int16 values are
+    converted back to given type"""
+    qx = (x * multiplier).round() + zero_point
+    qx = np.clip(qx, qmin, qmax).astype(qtype)
+    return qx
+
+
 # Make sure we won't have overflows from vpmaddubsw instruction used in FBGEMM.
 # On the current Intel x86 architecture, we need to utilize vpmaddubsw instruction
 # for the 8-bit int multiplication. This instruction vertically multiplies each
@@ -369,5 +377,119 @@
         np.testing.assert_equal(Y_q_ref2.int_repr().numpy(), Y_q.int_repr().numpy())
 
 
+@unittest.skipIf(
+    TEST_WITH_UBSAN or not torch.fbgemm_is_cpu_supported(),
+    " Quantized convolution requires FBGEMM. FBGEMM does not play"
+    " well with UBSAN at the moment, so we skip the test if"
+    " we are in a UBSAN environment.",
+)
+class TestQuantizedConv(unittest.TestCase):
+    """Tests the correctness of quantized convolution op."""
+    def test_qconv(self):
+
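+        # Strategy: run a float Conv2d as a reference, quantize the inputs
+        # and weights, run the quantized op, then requantize the float
+        # reference output and compare the integer representations.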
+        qconv = torch.ops.quantized.fbgemm_conv2d
+        qconv_prepack = torch.ops.quantized.fbgemm_conv_prepack
+
+        # N
+        batch_size = 1
+        # C
+        input_channels = 16
+        # H, W
+        height = width = 24
+        # K
+        output_channels = 8
+
+        kernel_h = kernel_w = 3
+        stride_h = stride_w = 1
+        padding_h = padding_w = 1
+        dilation_h = dilation_w = 1
+        groups = 1
+
+        # We use small values to avoid overflow.
+        W_value_min = 0
+        W_value_max = 5
+
+        # The operator expects weights in the layout
+        # (output_channels, input_channels / groups, kernel_h, kernel_w).
+        W_init = torch.randint(
+            W_value_min,
+            W_value_max,
+            (output_channels, input_channels // groups, kernel_h, kernel_w),
+        )
+
+        b_init = torch.randint(0, 10, (output_channels,))
+
+        # Reference floating-point conv operator
+        conv_op = torch.nn.Conv2d(
+            input_channels,
+            output_channels,
+            (kernel_h, kernel_w),
+            (stride_h, stride_w),
+            (padding_h, padding_w),
+            (dilation_h, dilation_w),
+            groups,
+        )
+
+        # Assign the weights and bias
+        conv_op.weight = torch.nn.Parameter(
+            W_init.to(dtype=torch.float), requires_grad=False
+        )
+        conv_op.bias = torch.nn.Parameter(
+            b_init.to(dtype=torch.float), requires_grad=False
+        )
+
+        X_value_min = 0
+        X_value_max = 4
+        X_init = torch.randint(
+            X_value_min, X_value_max, (batch_size, input_channels, height, width)
+        )
+
+        # Run the float conv on an input tensor to get the reference result
+        result_ref = conv_op(X_init.to(dtype=torch.float))
+
+        # Reformat X_init and W_init into the layouts the conv operator requires
+        # NCHW -> NHWC
+        X_NHWC = X_init.permute([0, 2, 3, 1]).contiguous()
+        # KCRS -> RSCK
+        W_RSCK = W_init.permute([2, 3, 1, 0]).contiguous()
+
+        X_scale = 1.5
+        # Currently, only a zero point of 0 is supported.
+        X_zero_point = 0
+        X = X_scale * (X_NHWC - X_zero_point).to(dtype=torch.float)
+
+        W_scale = 2.5
+        W_zero_point = 0
+        W = W_scale * (W_RSCK - W_zero_point).to(dtype=torch.float)
+
+        X_q = X.quantize_linear(scale=X_scale, zero_point=X_zero_point, dtype=torch.quint8)
+        W_q = W.quantize_linear(scale=W_scale, zero_point=W_zero_point, dtype=torch.quint8)
+        b_q = b_init.to(dtype=torch.int32)
+
+        W_prepack = qconv_prepack(W_q, groups)
+        Y_scale = 7.3
+        Y_zero_point = 5
+
+        Y_q = qconv(
+            X_q,
+            W_prepack,
+            b_q,
+            [stride_h, stride_w],  # stride
+            [padding_h, padding_w],  # padding
+            [dilation_h, dilation_w],  # dilation
+            [0, 0],  # output_padding
+            groups,  # groups
+            Y_scale,
+            Y_zero_point,
+        )
+
+        result_NHWK = result_ref.permute([0, 2, 3, 1])
+        result_q = _requantize(
+            result_NHWK.numpy(), X_scale * W_scale / Y_scale, Y_zero_point
+        )
+
+        # Make sure the results match
+        np.testing.assert_equal(result_q, Y_q.int_repr().numpy())
+
+
 if __name__ == "__main__":
     run_tests()