Quantized relu to native_functions (#22316)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/22316

Adds the quantized ReLU to native_functions.yaml, since it has the same signature as the non-quantized relu.
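
A minimal usage sketch of the new dispatch path (values for scale, zero_point, and dtype are illustrative only, not taken from this PR):

```python
import torch

# Quantize a float tensor with per-tensor affine quantization
# (torch.quantize_linear is the API used in the test below).
x = torch.randn(4)
qx = torch.quantize_linear(x, scale=0.1, zero_point=5, dtype=torch.quint8)

# torch.relu on a quantized CPU tensor now dispatches to quantized_relu
qy = torch.relu(qx)
print(qy.dequantize())
```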

Reviewed By: jerryzh168

Differential Revision: D16038441

fbshipit-source-id: 1cfbb594eb9bca1b7ec49ca486defcf1908b0d26
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index b4f3e6f..941b3d8 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -1699,6 +1699,7 @@
     CPU: relu
     CUDA: relu
     MkldnnCPU: mkldnn_relu
+    QuantizedCPU: quantized_relu
 
 - func: relu_(Tensor(a!) self) -> Tensor(a!)
   named_guard: False
diff --git a/aten/src/ATen/native/quantized/cpu/qrelu.cpp b/aten/src/ATen/native/quantized/cpu/qrelu.cpp
index fb1e065..914a93c 100644
--- a/aten/src/ATen/native/quantized/cpu/qrelu.cpp
+++ b/aten/src/ATen/native/quantized/cpu/qrelu.cpp
@@ -2,29 +2,32 @@
 #include <ATen/core/op_registration/op_registration.h>
 #include <ATen/native/TensorIterator.h>
 #include <ATen/native/cpu/Loops.h>
-#include <ATen/quantized/Quantizer.h>
+#include <ATen/NativeFunctions.h>
 
 #include <algorithm>
 
 namespace at { namespace native {
-namespace {
+Tensor quantized_relu(const Tensor& qx) {
+  Tensor qy;
+  const auto zero_point = qx.q_zero_point();
+  AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "qrelu", [&]() {
+    qy = at::_empty_affine_quantized(qx.sizes(),
+                                     at::device(kCPU).dtype(SCALAR_TYPE),
+                                     qx.q_scale(),
+                                     qx.q_zero_point());
+    auto iter = TensorIterator::unary_op(qy, qx);
+    cpu_kernel(*iter, [&](scalar_t value) -> scalar_t {
+      return scalar_t(std::max<underlying_t>(value.val_, zero_point));
+    });
+  });
+  return qy;
+}
 
+namespace {
 class QRelu final : public c10::OperatorKernel {
  public:
   Tensor operator()(Tensor qx) {
-    Tensor qy;
-    const auto zero_point = qx.q_zero_point();
-    AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "qrelu", [&]() {
-      qy = at::_empty_affine_quantized(qx.sizes(),
-                                       at::device(kCPU).dtype(SCALAR_TYPE),
-                                       qx.q_scale(),
-                                       qx.q_zero_point());
-      auto iter = TensorIterator::unary_op(qy, qx);
-      cpu_kernel(*iter, [&](scalar_t value) -> scalar_t {
-        return scalar_t(std::max<underlying_t>(value.val_, zero_point));
-      });
-    });
-    return qy;
+    return at::relu(qx);
   }
 };
 
@@ -33,4 +36,5 @@
     c10::RegisterOperators::options()
       .kernel<QRelu>(QuantizedCPUTensorId()));
 }  // namespace
+
 }}  // namespace at::native
diff --git a/test/test_quantized.py b/test/test_quantized.py
index 42e394c..1e29f22 100644
--- a/test/test_quantized.py
+++ b/test/test_quantized.py
@@ -83,19 +83,24 @@
                        qparams=hu.qparams()))
     def test_qrelu(self, X):
         X, (scale, zero_point, torch_type) = X
-        relu = torch.ops.quantized.relu
 
         Y = X.copy()
+        Y[Y < 0] = 0
+        qY = torch.quantize_linear(torch.from_numpy(Y), scale=scale,
+                                   zero_point=zero_point, dtype=torch_type)
         X = torch.from_numpy(X)
-
         qX = torch.quantize_linear(X, scale=scale, zero_point=zero_point,
                                    dtype=torch_type)
-        qY_hat = relu(qX)
 
-        Y[Y < 0] = 0
-        qY_ref = torch.quantize_linear(torch.from_numpy(Y), scale=scale,
-                                       zero_point=zero_point, dtype=torch_type)
-        self.assertEqual(qY_ref, qY_hat)
+        ops_under_test = {
+            'ops.quantized': torch.ops.quantized.relu,
+            'native': torch.relu,
+            'nn.functional': torch.nn.functional.relu
+        }
+
+        for name, op in ops_under_test.items():
+            qY_hat = op(qX)
+            self.assertEqual(qY, qY_hat, "{} relu failed".format(name))
 
     """Tests the correctness of the add and add_relu op."""
     def test_qadd_relu_same_qparams(self):