Quantized relu to native_functions (#22316)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/22316
Adds the quantized ReLU to native_functions.yaml, as it has the same signature as the non-quantized relu. With the dispatch entry in place, the generic relu entry points route to the quantized kernel for QuantizedCPU tensors.
Reviewed By: jerryzh168
Differential Revision: D16038441
fbshipit-source-id: 1cfbb594eb9bca1b7ec49ca486defcf1908b0d26
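For context, a minimal usage sketch of what this change enables (illustrative only, not part of the PR; it assumes the torch.quantize_linear API used in the tests below):

import torch

# Quantize a float tensor with affine (scale, zero_point) parameters.
x = torch.randn(4)
qx = torch.quantize_linear(x, scale=0.1, zero_point=128, dtype=torch.quint8)

# After this change all three entry points reach the same quantized
# kernel (see the updated test_qrelu below):
qy1 = torch.ops.quantized.relu(qx)
qy2 = torch.relu(qx)
qy3 = torch.nn.functional.relu(qx)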
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index b4f3e6f..941b3d8 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -1699,6 +1699,7 @@
CPU: relu
CUDA: relu
MkldnnCPU: mkldnn_relu
+ QuantizedCPU: quantized_relu
- func: relu_(Tensor(a!) self) -> Tensor(a!)
named_guard: False
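The new QuantizedCPU: quantized_relu entry above routes at::relu to the quantized kernel whenever the input tensor lives on the quantized CPU backend; the CPU, CUDA, and MkldnnCPU backends keep their existing kernels.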
diff --git a/aten/src/ATen/native/quantized/cpu/qrelu.cpp b/aten/src/ATen/native/quantized/cpu/qrelu.cpp
index fb1e065..914a93c 100644
--- a/aten/src/ATen/native/quantized/cpu/qrelu.cpp
+++ b/aten/src/ATen/native/quantized/cpu/qrelu.cpp
@@ -2,29 +2,32 @@
#include <ATen/core/op_registration/op_registration.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cpu/Loops.h>
-#include <ATen/quantized/Quantizer.h>
+#include <ATen/NativeFunctions.h>
#include <algorithm>
namespace at { namespace native {
-namespace {
+Tensor quantized_relu(const Tensor& qx) {
+ Tensor qy;
+ const auto zero_point = qx.q_zero_point();
+ AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "qrelu", [&]() {
+ qy = at::_empty_affine_quantized(qx.sizes(),
+ at::device(kCPU).dtype(SCALAR_TYPE),
+ qx.q_scale(),
+ qx.q_zero_point());
+ auto iter = TensorIterator::unary_op(qy, qx);
+ cpu_kernel(*iter, [&](scalar_t value) -> scalar_t {
+ return scalar_t(std::max<underlying_t>(value.val_, zero_point));
+ });
+ });
+ return qy;
+}
+namespace {
class QRelu final : public c10::OperatorKernel {
public:
Tensor operator()(Tensor qx) {
- Tensor qy;
- const auto zero_point = qx.q_zero_point();
- AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "qrelu", [&]() {
- qy = at::_empty_affine_quantized(qx.sizes(),
- at::device(kCPU).dtype(SCALAR_TYPE),
- qx.q_scale(),
- qx.q_zero_point());
- auto iter = TensorIterator::unary_op(qy, qx);
- cpu_kernel(*iter, [&](scalar_t value) -> scalar_t {
- return scalar_t(std::max<underlying_t>(value.val_, zero_point));
- });
- });
- return qy;
+ return at::relu(qx);
}
};
@@ -33,4 +36,5 @@
c10::RegisterOperators::options()
.kernel<QRelu>(QuantizedCPUTensorId()));
} // namespace
+
}} // namespace at::native
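Why the kernel clamps at zero_point rather than zero: under affine quantization the real value is r = scale * (q - zero_point), so r <= 0 exactly when q <= zero_point, and ReLU reduces to max(q, zero_point) in the integer domain. A quick NumPy sanity check of that identity (illustrative only, not part of the PR):

import numpy as np

scale, zero_point = 0.25, 10
q = np.arange(0, 256, dtype=np.uint8)  # every quint8 value

# Reference path: dequantize, apply ReLU in float, re-quantize with
# the same (scale, zero_point) parameters.
r = scale * (q.astype(np.float32) - zero_point)
q_ref = (np.round(np.maximum(r, 0.0) / scale) + zero_point).astype(np.uint8)

# The kernel's integer-domain shortcut: clamp at the zero point.
q_fast = np.maximum(q, zero_point).astype(np.uint8)

assert np.array_equal(q_ref, q_fast)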
diff --git a/test/test_quantized.py b/test/test_quantized.py
index 42e394c..1e29f22 100644
--- a/test/test_quantized.py
+++ b/test/test_quantized.py
@@ -83,19 +83,24 @@
qparams=hu.qparams()))
def test_qrelu(self, X):
X, (scale, zero_point, torch_type) = X
- relu = torch.ops.quantized.relu
Y = X.copy()
+ Y[Y < 0] = 0
+ qY = torch.quantize_linear(torch.from_numpy(Y), scale=scale,
+ zero_point=zero_point, dtype=torch_type)
X = torch.from_numpy(X)
-
qX = torch.quantize_linear(X, scale=scale, zero_point=zero_point,
dtype=torch_type)
- qY_hat = relu(qX)
- Y[Y < 0] = 0
- qY_ref = torch.quantize_linear(torch.from_numpy(Y), scale=scale,
- zero_point=zero_point, dtype=torch_type)
- self.assertEqual(qY_ref, qY_hat)
+ ops_under_test = {
+ 'ops.quantized': torch.ops.quantized.relu,
+ 'native': torch.relu,
+ 'nn.functional': torch.nn.functional.relu
+ }
+
+ for name, op in ops_under_test.items():
+ qY_hat = op(qX)
+ self.assertEqual(qY, qY_hat, "{} relu failed".format(name))
"""Tests the correctness of the add and add_relu op."""
def test_qadd_relu_same_qparams(self):