Dynamic dispatch for optimized quantized op kernels (#25545)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25545

This reuses the infrastructure from ATen/native/cpu, which compiles kernels multiple times for different instruction sets and dispatches dynamically based on the CPU's capability flags at runtime. This ensures we use the best available quantized kernel for the given machine.
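For reference, here is a minimal sketch of the dispatch pattern this change adopts, condensed from the files touched in this diff (quantized_ops.h, qrelu.cpp, kernels/QuantizedOpKernels.cpp). In the actual change the three snippets live in separate files; the comments and ellipses are illustrative only:

```cpp
// quantized_ops.h -- declares a per-op function-pointer "stub" (compiled once).
#include <ATen/ATen.h>
#include <ATen/native/DispatchStub.h>

namespace at { namespace native {
using qrelu_fn = void (*)(const at::Tensor& /*qx*/, at::Tensor& /*qy*/);
DECLARE_DISPATCH(qrelu_fn, qrelu_stub);
}} // namespace at::native

// qrelu.cpp -- defines the stub and calls through it; at runtime the stub
// resolves to the best kernel registered for the host CPU (DEFAULT/AVX/AVX2).
namespace at { namespace native {
DEFINE_DISPATCH(qrelu_stub);

Tensor quantized_relu(const Tensor& qx) {
  Tensor qy;
  qrelu_stub(qx.device().type(), qx, qy);
  return qy;
}
}} // namespace at::native

// kernels/QuantizedOpKernels.cpp -- matched by the cpu_kernel_cpp_in glob in
// cmake/Codegen.cmake, so it is compiled once per CPU capability; each variant
// registers its own copy of the kernel. Kernel bodies must stay in an
// anonymous namespace to avoid ODR violations (see kernels/README.md).
namespace at { namespace native { namespace {
void qrelu_kernel(const Tensor& qx, Tensor& qy) {
  // ... vectorized implementation, see the full file in this diff ...
}
} // anonymous namespace

REGISTER_DISPATCH(qrelu_stub, &qrelu_kernel);
}} // namespace at::native
```

Only the kernel bodies live under kernels/, so the code that is recompiled per capability stays small, as the new README.md in that directory asks.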

Test Plan: Imported from OSS

Differential Revision: D17166369

Pulled By: jamesr66a

fbshipit-source-id: 8c8393f99365e1408819bbaf254c1b5734a34b70
diff --git a/aten/src/ATen/cpu/vec256/vec256_qint.h b/aten/src/ATen/cpu/vec256/vec256_qint.h
index b64c611..6a21479 100644
--- a/aten/src/ATen/cpu/vec256/vec256_qint.h
+++ b/aten/src/ATen/cpu/vec256/vec256_qint.h
@@ -294,8 +294,8 @@
         Vec256<c10::quint8> zero_point,
         Vec256<c10::quint8> q_six) {
 #ifdef __AVX2__
-      return _mm256_min_epi8(
-          _mm256_max_epi8(vals, zero_point.vals), q_six.vals);
+      return _mm256_min_epu8(
+          _mm256_max_epu8(vals, zero_point.vals), q_six.vals);
 #else
       // Pray the compiler can autovectorize this
       uint8_t int_vals[size()];
@@ -405,8 +405,8 @@
         Vec256<c10::qint32> zero_point,
         Vec256<c10::qint32> q_six) {
 #ifdef __AVX2__
-      return _mm256_min_epi8(
-          _mm256_max_epi8(vals, zero_point.vals), q_six.vals);
+      return _mm256_min_epi32(
+          _mm256_max_epi32(vals, zero_point.vals), q_six.vals);
 #else
       // Pray the compiler can autovectorize this
       int32_t int_vals[size()];
diff --git a/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp b/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp
new file mode 100644
index 0000000..00a1585
--- /dev/null
+++ b/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp
@@ -0,0 +1,128 @@
+#include <ATen/ATen.h>
+#include <ATen/Dispatch.h>
+#include <ATen/native/TensorIterator.h>
+#include <ATen/native/cpu/Loops.h>
+#include <ATen/native/quantized/cpu/quantized_ops.h>
+
+namespace at {
+namespace native {
+namespace {
+
+// ****************** HEY YOU! YES YOU! Read this! ********************
+//
+// Please read the README.md in this directory before editing this file
+
+void qrelu_kernel(const Tensor& qx, Tensor& qy) {
+  const auto zero_point = qx.q_zero_point();
+  AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "qrelu", [&]() {
+    qy = at::_empty_affine_quantized(
+        qx.sizes(),
+        at::device(kCPU).dtype(SCALAR_TYPE),
+        qx.q_scale(),
+        qx.q_zero_point(),
+        qx.suggest_memory_format());
+    using Vec = Vec256<scalar_t>;
+    auto zero_point_vec = Vec(scalar_t(zero_point));
+    auto iter = TensorIterator::unary_op(qy, qx);
+    cpu_kernel_vec(
+        iter,
+        [&](scalar_t value) -> scalar_t {
+          return scalar_t(std::max<underlying_t>(value.val_, zero_point));
+        },
+        [&](Vec value) -> Vec { return value.relu(zero_point_vec); });
+  });
+}
+
+void qrelu6_kernel(const Tensor& qx, Tensor& qy) {
+  const auto zero_point = qx.q_zero_point();
+  AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "qrelu6", [&]() {
+    qy = at::_empty_affine_quantized(
+        qx.sizes(),
+        at::device(kCPU).dtype(SCALAR_TYPE),
+        qx.q_scale(),
+        qx.q_zero_point(),
+        qx.suggest_memory_format());
+    using Vec = Vec256<scalar_t>;
+    auto iter = TensorIterator::unary_op(qy, qx);
+    scalar_t six =
+        at::quantize_val<scalar_t>(qx.q_scale(), qx.q_zero_point(), 6.0);
+    auto zero_point_vec = Vec(scalar_t(zero_point));
+    auto six_vec = Vec(six);
+    cpu_kernel_vec(
+        iter,
+        [&](scalar_t value) -> scalar_t {
+          underlying_t relu_val =
+              std::max<underlying_t>(value.val_, zero_point);
+          return scalar_t(std::min<underlying_t>(relu_val, six.val_));
+        },
+        [&](Vec val) -> Vec { return val.relu6(zero_point_vec, six_vec); });
+  });
+}
+
+// Note: out is assumed to be the same size as self and other.
+// Note: Addition is only supported when self, other, out are of the same dtype.
+template <bool ReLUFused = false>
+void qadd_kernel(Tensor& out, const Tensor& self, const Tensor& other) {
+  int64_t zero_point = out.q_zero_point();
+  double scale = out.q_scale();
+  int64_t self_zero_point = self.q_zero_point();
+  double self_scale = self.q_scale();
+  int64_t other_zero_point = other.q_zero_point();
+  double other_scale = other.q_scale();
+
+  // Broadcast out the parameters here to amortize out that cost across
+  // loop iterations.
+  // TODO: we can optimize dequantization by doing a premultiplication
+  // of the zero point by scale and doing FMA on scale*x_q - (scale*zero_point)
+  auto self_zero_point_vec = Vec256<float>((float)self_zero_point);
+  auto self_scale_vec = Vec256<float>(self_scale);
+  auto other_zero_point_vec = Vec256<float>((float)other_zero_point);
+  auto other_scale_vec = Vec256<float>(other_scale);
+
+  auto iter = TensorIterator::binary_op(out, self, other);
+
+  AT_DISPATCH_QINT_TYPES(out.scalar_type(), "qadd", [&]() {
+    using Vec = Vec256<scalar_t>;
+    cpu_kernel_vec(
+        iter,
+        [&](scalar_t a, scalar_t b) -> scalar_t {
+          const auto da = at::dequantize_val(self_scale, self_zero_point, a);
+          const auto db = at::dequantize_val(other_scale, other_zero_point, b);
+          float c = da + db;
+          if (ReLUFused) {
+            c = std::max<float>(c, 0.0);
+          }
+          return at::quantize_val<scalar_t>(scale, zero_point, c);
+        },
+        [&](Vec a, Vec b) -> Vec {
+          const auto da = a.dequantize(self_scale_vec, self_zero_point_vec);
+          const auto db = b.dequantize(other_scale_vec, other_zero_point_vec);
+          Vec::float_vec_return_type retvals;
+          for (int i = 0; i < Vec::float_num_vecs(); ++i) {
+            auto c = da[i] + db[i];
+            if (ReLUFused) {
+              c = vec256::maximum(c, Vec256<float>(0.0f));
+            }
+            retvals[i] = c;
+          }
+          // TODO: fbgemm::Quantize doesn't support taking in the
+          // pre-broadcasted parameters. We might be able to save some cycles by
+          // enabling that in the API.
+          // TODO: specialize fbgemm::Quantize for a single vector and make it
+          // inlineable. This could help with interleaving as suggested by the
+          // TensorIterator implementations
+          auto rv = Vec::quantize(retvals, scale, zero_point);
+          return rv;
+        });
+  });
+}
+
+} // namespace
+
+REGISTER_DISPATCH(qrelu_stub, &qrelu_kernel);
+REGISTER_DISPATCH(qrelu6_stub, &qrelu6_kernel);
+REGISTER_DISPATCH(qadd_relu_stub, &qadd_kernel<true>);
+REGISTER_DISPATCH(qadd_stub, &qadd_kernel<false>);
+
+} // namespace native
+} // namespace at
\ No newline at end of file
diff --git a/aten/src/ATen/native/quantized/cpu/kernels/README.md b/aten/src/ATen/native/quantized/cpu/kernels/README.md
new file mode 100644
index 0000000..dc3f6cf
--- /dev/null
+++ b/aten/src/ATen/native/quantized/cpu/kernels/README.md
@@ -0,0 +1,16 @@
+ The files in this directory are compiled multiple times for different CPU vector instruction
+ sets (e.g. AVX, AVX2). The purpose of putting code in this directory is to make
+ sure we can generate the optimal code for a given processor's vector
+ capabilities. Much of this is done via preprocessor guards in vec256_qint.h.
+
+ The considerations for code written in this directory include:
+  - Keep code in this directory to a minimum, since we're compiling it several
+    times.
+  - All code in this directory should go through the DECLARE_DISPATCH,
+    DEFINE_DISPATCH, and REGISTER_DISPATCH mechanism to ensure the correct
+    runtime dispatch occurs.
+  - THE CODE MUST RESIDE IN THE ANONYMOUS NAMESPACE. FAILURE TO ENSURE THIS
+    IS THE CASE CAN LEAD TO HARD-TO-DEBUG ODR VIOLATIONS.
+  - **Make sure different variants of the code (AVX, AVX2) are tested!**
+    There are build variants that do things like have NO AVX and NO AVX2 in
+    CI. Make sure they work!
\ No newline at end of file
diff --git a/aten/src/ATen/native/quantized/cpu/qadd.cpp b/aten/src/ATen/native/quantized/cpu/qadd.cpp
index a910c8c..41d752a 100644
--- a/aten/src/ATen/native/quantized/cpu/qadd.cpp
+++ b/aten/src/ATen/native/quantized/cpu/qadd.cpp
@@ -4,11 +4,16 @@
 #include <ATen/native/TensorIterator.h>
 #include <ATen/native/cpu/Loops.h>
 #include <ATen/quantized/Quantizer.h>
+#include <ATen/native/quantized/cpu/quantized_ops.h>
 
 #include <algorithm>
 
 namespace at {
 namespace native {
+
+DEFINE_DISPATCH(qadd_relu_stub);
+DEFINE_DISPATCH(qadd_stub);
+
 namespace {
 
 inline void check_inputs(const Tensor& qa, const Tensor& qb) {
@@ -28,56 +33,11 @@
 // Note: Addition is only supported when self, other, out are of the same dtype.
 template <bool ReLUFused = false>
 Tensor _add_out(Tensor& out, const Tensor& self, const Tensor& other) {
-  int64_t zero_point = out.q_zero_point();
-  double scale = out.q_scale();
-  int64_t self_zero_point = self.q_zero_point();
-  double self_scale = self.q_scale();
-  int64_t other_zero_point = other.q_zero_point();
-  double other_scale = other.q_scale();
-
-  // Broadcast out the parameters here to amortize out that cost across
-  // loop iterations.
-  // TODO: we can optimize dequantization by doing a premultiplication
-  // of the zero point by scale and doing FMA on scale*x_q - (scale*zero_point)
-  auto self_zero_point_vec = Vec256<float>((float)self_zero_point);
-  auto self_scale_vec = Vec256<float>(self_scale);
-  auto other_zero_point_vec = Vec256<float>((float)other_zero_point);
-  auto other_scale_vec = Vec256<float>(other_scale);
-
-  auto iter = TensorIterator::binary_op(out, self, other);
-
-  AT_DISPATCH_QINT_TYPES(out.scalar_type(), "qadd", [&]() {
-    using Vec = Vec256<scalar_t>;
-    cpu_kernel_vec(iter, [&](scalar_t a, scalar_t b) -> scalar_t {
-      const auto da = at::dequantize_val(self_scale, self_zero_point, a);
-      const auto db = at::dequantize_val(other_scale, other_zero_point, b);
-      float c = da + db;
-      if (ReLUFused) {
-        c = std::max<float>(c, 0.0);
-      }
-      return at::quantize_val<scalar_t>(scale, zero_point, c);
-    },
-    [&](Vec a, Vec b) -> Vec {
-      const auto da = a.dequantize(self_scale_vec, self_zero_point_vec);
-      const auto db = b.dequantize(other_scale_vec, other_zero_point_vec);
-      Vec::float_vec_return_type retvals;
-      for (int i = 0; i < Vec::float_num_vecs(); ++i) {
-        auto c = da[i] + db[i];
-        if (ReLUFused) {
-          c = vec256::maximum(c, Vec256<float>(0.0f));
-        }
-        retvals[i] = c;
-      }
-      // TODO: fbgemm::Quantize doesn't support taking in the pre-broadcasted
-      // parameters. We might be able to save some cycles by enabling that
-      // in the API.
-      // TODO: specialize fbgemm::Quantize for a single vector and make it
-      // inlineable. This could help with interleaving as suggested by the
-      // TensorIterator implementations
-      auto rv = Vec::quantize(retvals, scale, zero_point);
-      return rv;
-    });
-  });
+  if (ReLUFused) {
+    qadd_relu_stub(self.device().type(), out, self, other);
+  } else {
+    qadd_stub(self.device().type(), out, self, other);
+  }
   return out;
 }
 
diff --git a/aten/src/ATen/native/quantized/cpu/qrelu.cpp b/aten/src/ATen/native/quantized/cpu/qrelu.cpp
index aa0a0c6..bc7c601 100644
--- a/aten/src/ATen/native/quantized/cpu/qrelu.cpp
+++ b/aten/src/ATen/native/quantized/cpu/qrelu.cpp
@@ -4,31 +4,19 @@
 #include <ATen/native/TensorIterator.h>
 #include <ATen/native/cpu/Loops.h>
 #include <ATen/quantized/Quantizer.h>
+#include <ATen/native/quantized/cpu/quantized_ops.h>
 
 #include <algorithm>
 
 namespace at {
 namespace native {
+
+DEFINE_DISPATCH(qrelu_stub);
+DEFINE_DISPATCH(qrelu6_stub);
+
 Tensor quantized_relu(const Tensor& qx) {
   Tensor qy;
-  const auto zero_point = qx.q_zero_point();
-  AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "qrelu", [&]() {
-    qy = at::_empty_affine_quantized(
-        qx.sizes(),
-        at::device(kCPU).dtype(SCALAR_TYPE),
-        qx.q_scale(),
-        qx.q_zero_point(),
-        qx.suggest_memory_format());
-    using Vec = Vec256<scalar_t>;
-    auto iter = TensorIterator::unary_op(qy, qx);
-    auto zero_point_vec = Vec(scalar_t(zero_point));
-    cpu_kernel_vec(
-        iter,
-        [&](scalar_t value) -> scalar_t {
-          return scalar_t(std::max<underlying_t>(value.val_, zero_point));
-        },
-        [&](Vec value) -> Vec { return value.relu(zero_point_vec); });
-  });
+  qrelu_stub(qx.device().type(), qx, qy);
   return qy;
 }
 Tensor& quantized_relu_(Tensor& qx) {
@@ -50,33 +38,10 @@
 namespace {
 Tensor quantized_relu6(const Tensor& qx) {
   Tensor qy;
-  const auto zero_point = qx.q_zero_point();
-  AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "qrelu", [&]() {
-    qy = at::_empty_affine_quantized(
-        qx.sizes(),
-        at::device(kCPU).dtype(SCALAR_TYPE),
-        qx.q_scale(),
-        qx.q_zero_point(),
-        qx.suggest_memory_format());
-    using Vec = Vec256<scalar_t>;
-    auto iter = TensorIterator::unary_op(qy, qx);
-    scalar_t six = at::quantize_val<scalar_t>(qx.q_scale(), qx.q_zero_point(),
-                                              6.0);
-    auto zero_point_vec = Vec(scalar_t(zero_point));
-    auto six_vec = Vec(six);
-    cpu_kernel_vec(
-        iter,
-        [&](scalar_t value) -> scalar_t {
-          underlying_t relu_val =
-              std::max<underlying_t>(value.val_, zero_point);
-          return scalar_t(std::min<underlying_t>(relu_val, six.val_));
-        },
-        [&](Vec val) { return val.relu6(zero_point_vec, six_vec); });
-  });
+  qrelu6_stub(qx.device().type(), qx, qy);
   return qy;
 }
 
-
 class QRelu6 final : public c10::OperatorKernel {
  public:
   Tensor operator()(Tensor qx) {
diff --git a/aten/src/ATen/native/quantized/cpu/quantized_ops.h b/aten/src/ATen/native/quantized/cpu/quantized_ops.h
new file mode 100644
index 0000000..9475e60
--- /dev/null
+++ b/aten/src/ATen/native/quantized/cpu/quantized_ops.h
@@ -0,0 +1,18 @@
+#include <ATen/ATen.h>
+#include <ATen/native/DispatchStub.h>
+#include <ATen/native/TensorIterator.h>
+
+namespace at {
+namespace native {
+
+using qrelu_fn = void (*)(const at::Tensor& /*qx*/, at::Tensor& /*qy*/);
+using qadd_fn =
+    void (*)(Tensor& /*out*/, const Tensor& /*self*/, const Tensor& /*other*/);
+
+DECLARE_DISPATCH(qrelu_fn, qrelu_stub);
+DECLARE_DISPATCH(qrelu_fn, qrelu6_stub);
+DECLARE_DISPATCH(qadd_fn, qadd_stub);
+DECLARE_DISPATCH(qadd_fn, qadd_relu_stub);
+
+} // namespace native
+} // namespace at
\ No newline at end of file
diff --git a/cmake/Codegen.cmake b/cmake/Codegen.cmake
index e929847..3efaafe 100644
--- a/cmake/Codegen.cmake
+++ b/cmake/Codegen.cmake
@@ -84,7 +84,7 @@
     SET_SOURCE_FILES_PROPERTIES(${CMAKE_CURRENT_LIST_DIR}/../aten/src/TH/THAllocator.cpp PROPERTIES COMPILE_FLAGS "-fno-openmp")
   ENDIF()
 
-  FILE(GLOB cpu_kernel_cpp_in "${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cpu/*.cpp")
+  FILE(GLOB cpu_kernel_cpp_in "${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cpu/*.cpp" "${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/quantized/cpu/kernels/*.cpp")
 
   LIST(APPEND CPU_CAPABILITY_NAMES "DEFAULT")
   LIST(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG}")
diff --git a/test/test_quantized.py b/test/test_quantized.py
index f0f1f91..a88a3ec 100644
--- a/test/test_quantized.py
+++ b/test/test_quantized.py
@@ -80,13 +80,10 @@
 class TestQuantizedOps(TestCase):
 
     """Tests the correctness of the quantized::relu op."""
-    @given(qparams=hu.qparams())
-    def test_qrelu(self, qparams):
-        X = np.array([[-3, -2, 1, 2],
-                      [0, 0, 0, 0],
-                      [-5, -4, -3, -2],
-                      [1, 2, 3, 4]], dtype=np.float32)
-        scale, zero_point, torch_type = qparams
+    @given(X=hu.tensor(shapes=hu.array_shapes(1, 5, 1, 5),
+                       qparams=hu.qparams()))
+    def test_qrelu(self, X):
+        X, (scale, zero_point, torch_type) = X
 
         Y = X.copy()
         Y[Y < 0] = 0