quantized elu: require observation (#40100)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/40100
ELU has an output range of (-alpha, inf), which in general differs from the
range of its input. In the original PR that added the quantized operator we
decided to reuse the quantization params of the input. However, because the
output distribution is not the same as the input's, it makes more sense to
require observation for this op so that the output gets its own scale and
zero point.
This PR changes the API to require observation: the caller now passes
output_scale and output_zero_point explicitly. Subsequent PRs in this stack
will add the eager and graph mode handling.
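A minimal usage sketch of the new API (the qparam values below are illustrative;
in eager/graph mode the output scale and zero point would come from an observer):
```
import torch
import torch.nn.functional as F

x = torch.randn(4)
scale, zero_point = 0.1, 128                 # input qparams
output_scale, output_zero_point = 0.05, 64   # observed output qparams

qx = torch.quantize_per_tensor(x, scale=scale, zero_point=zero_point,
                               dtype=torch.quint8)
# Output qparams are now passed explicitly instead of being copied from the input.
qy = torch.nn.quantized.functional.elu(qx, output_scale, output_zero_point, alpha=1.0)

# Reference computation, mirroring the updated test: dequantize, run float ELU,
# then requantize with the output qparams.
qy_ref = torch.quantize_per_tensor(F.elu(qx.dequantize(), 1.0),
                                   scale=output_scale, zero_point=output_zero_point,
                                   dtype=torch.quint8)
```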
Test Plan:
```
python test/test_quantization.py TestQuantizedOps.test_qelu
```
Imported from OSS
Differential Revision: D22075083
fbshipit-source-id: 0ea0fd05a00cc7a5f122a2b1de09144bbd586f32
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index eaf9800..54b1eeb 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -5474,18 +5474,10 @@
- func: elu.out(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1, *, Tensor(a!) out) -> Tensor(a!)
python_module: nn
- dispatch:
- CPU: elu_out
- CUDA: elu_out
- QuantizedCPU: quantized_elu_out
- func: elu(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> Tensor
use_c10_dispatcher: full
python_module: nn
- dispatch:
- CPU: elu
- CUDA: elu
- QuantizedCPU: quantized_elu
- func: elu_backward.grad_input(Tensor grad_output, Scalar alpha, Scalar scale, Scalar input_scale, Tensor output, *, Tensor(a!) grad_input) -> Tensor(a!)
python_module: nn
@@ -5499,10 +5491,6 @@
- func: elu_(Tensor(a!) self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> Tensor(a!)
python_module: nn
- dispatch:
- CPU: elu_
- CUDA: elu_
- QuantizedCPU: quantized_elu_
- func: glu.out(Tensor self, int dim=-1, *, Tensor(a!) out) -> Tensor(a!)
python_module: nn
diff --git a/aten/src/ATen/native/quantized/cpu/qelu.cpp b/aten/src/ATen/native/quantized/cpu/qelu.cpp
index 24c264f..4967768 100644
--- a/aten/src/ATen/native/quantized/cpu/qelu.cpp
+++ b/aten/src/ATen/native/quantized/cpu/qelu.cpp
@@ -9,28 +9,16 @@
DEFINE_DISPATCH(qelu_stub);
-Tensor& quantized_elu_out(Tensor& result, const Tensor& self, Scalar alpha,
- Scalar scale, Scalar input_scale) {
- qelu_stub(self.device().type(), self, alpha, result);
- return result;
-}
-
-Tensor& quantized_elu_(Tensor& self, Scalar alpha, Scalar scale,
- Scalar input_scale) {
- Tensor qy = at::_empty_affine_quantized(self.sizes(), self.options(),
- self.q_scale(), self.q_zero_point());
- qelu_stub(self.device().type(), self, alpha, qy);
- // This can be optimized in a later PR if necessary.
- self.copy_(qy);
- return self;
-}
-
Tensor quantized_elu(
- const Tensor& qx, Scalar alpha, Scalar scale, Scalar input_scale) {
+ const Tensor& qx, double output_scale, int64_t output_zero_point, Scalar alpha, Scalar scale, Scalar input_scale) {
Tensor qy = at::_empty_affine_quantized(qx.sizes(), qx.options(),
- qx.q_scale(), qx.q_zero_point());
+ output_scale, output_zero_point);
qelu_stub(qx.device().type(), qx, alpha, qy);
return qy;
}
+TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) {
+ m.impl("elu", quantized_elu);
+}
+
}} // namespace at::native
diff --git a/aten/src/ATen/native/quantized/library.cpp b/aten/src/ATen/native/quantized/library.cpp
index fc54e00..d01b1b8 100644
--- a/aten/src/ATen/native/quantized/library.cpp
+++ b/aten/src/ATen/native/quantized/library.cpp
@@ -76,6 +76,7 @@
m.def("conv3d_padding(__torch__.torch.classes.quantized.Conv3dPackedParamsBase packed_weights) -> int[]");
m.def("conv3d_dilation(__torch__.torch.classes.quantized.Conv3dPackedParamsBase packed_weights) -> int[]");
m.def("conv3d_groups(__torch__.torch.classes.quantized.Conv3dPackedParamsBase packed_weights) -> int");
+ m.def("elu(Tensor self, float output_scale, int output_zero_point, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> Tensor");
m.def("hardswish(Tensor input, float output_scale, int output_zero_point) -> Tensor");
m.def("group_norm(Tensor input, int num_groups, Tensor? weight, Tensor? bias, float eps, float output_scale, int output_zero_point) -> Tensor");
m.def("instance_norm(Tensor input, Tensor? weight, Tensor? bias, float eps, float output_scale, int output_zero_point) -> Tensor");
diff --git a/test/quantization/test_quantize_jit.py b/test/quantization/test_quantize_jit.py
index dcefd96..48c7a1f 100644
--- a/test/quantization/test_quantize_jit.py
+++ b/test/quantization/test_quantize_jit.py
@@ -2349,7 +2349,6 @@
self.adaptive_avg_pool1d = torch.nn.AdaptiveAvgPool1d((1))
self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1))
self.adaptive_avg_pool3d = torch.nn.AdaptiveAvgPool3d((1, 1, 1))
- self.elu = torch.nn.ELU()
self.leaky_relu = torch.nn.LeakyReLU()
self.hardsigmoid = torch.nn.Hardsigmoid()
self.sigmoid = torch.nn.Sigmoid()
@@ -2379,9 +2378,6 @@
x = F.upsample_nearest(x, (32, 32)) # interpolate node
x = F.interpolate(x, 4, mode='linear') # common node
x = F.upsample_bilinear(x, (32, 32)) # common node
- x = self.elu(x)
- x = F.elu(x)
- x.elu_()
x = self.leaky_relu(x)
x = F.leaky_relu(x)
x.leaky_relu_()
@@ -2424,7 +2420,7 @@
# mapping from number of quant for the op to the number of these ops
# for example, for `3` in the key means for this type of op
# we'll have 3 quantize_per_tensor
- num_op_by_num_quant = {1: 35, 2: 2, 3: 3}
+ num_op_by_num_quant = {1: 32, 2: 2, 3: 3}
num_quantize_per_tensor = 1 # for output
for num_quant, num_op in num_op_by_num_quant.items():
num_quantize_per_tensor += num_op * num_quant
diff --git a/test/quantization/test_quantized_op.py b/test/quantization/test_quantized_op.py
index a3fc275..eaed08f 100644
--- a/test/quantization/test_quantized_op.py
+++ b/test/quantization/test_quantized_op.py
@@ -306,35 +306,24 @@
alpha=st.floats(0.01, 10.0, allow_nan=False, allow_infinity=False))
def test_qelu(self, X, alpha):
X, (scale, zero_point, torch_type) = X
+ output_scale = 0.5
+ output_zero_point = 1
X = torch.from_numpy(X)
qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point,
dtype=torch_type)
- op = torch.nn.quantized.functional.elu
# calculate ELU(dqX) and quantize
dqX = qX.dequantize()
dqY_hat = dqX.clone()
- dqY_hat[dqX < 0] = alpha * (torch.exp(dqY_hat[dqX < 0]) - 1.)
- qY_hat = torch.quantize_per_tensor(dqY_hat, scale=scale, zero_point=zero_point,
+ dqY_hat = torch.nn.functional.elu(dqX, alpha)
+ qY_hat = torch.quantize_per_tensor(dqY_hat, scale=output_scale, zero_point=output_zero_point,
dtype=torch_type)
- # test regular
- qY = op(qX, alpha=alpha)
+ qY = torch.nn.quantized.functional.elu(qX, output_scale, output_zero_point, alpha=alpha)
self.assertEqual(qY, qY_hat,
msg="F.elu failed ({} vs {})".format(qY, qY_hat))
- # test inplace
- qXcopy = qX.clone()
- op(qXcopy, alpha=alpha, inplace=True)
- self.assertEqual(qXcopy, qY_hat,
- msg="F.elu_ failed ({} vs {})".format(qXcopy, qY_hat))
-
- # test explicit scale and zp
- qYout = op(qX, alpha=alpha, scale=scale, zero_point=zero_point)
- self.assertEqual(qYout, qY_hat,
- msg="F.elu.out failed ({} vs {})".format(qY, qY_hat))
-
"""Tests the correctness of the quantized::qlayer_norm op."""
@skipIfNoFBGEMM
def test_qlayer_norm(self):
diff --git a/torch/csrc/jit/passes/quantization/helper.cpp b/torch/csrc/jit/passes/quantization/helper.cpp
index 0508d6b..c419c41 100644
--- a/torch/csrc/jit/passes/quantization/helper.cpp
+++ b/torch/csrc/jit/passes/quantization/helper.cpp
@@ -112,7 +112,6 @@
"upsample_bilinear",
"upsample_nearest",
"hardtanh",
- "elu",
"leaky_relu",
};
@@ -139,8 +138,6 @@
// "clamp_", // Enable when quantized `clamp_` is ready
"hardtanh",
"hardtanh_",
- "elu",
- "elu_",
"leaky_relu",
"leaky_relu_",
};
diff --git a/torch/csrc/jit/passes/quantization/quantization_patterns.h b/torch/csrc/jit/passes/quantization/quantization_patterns.h
index 37e74cc..694e625 100644
--- a/torch/csrc/jit/passes/quantization/quantization_patterns.h
+++ b/torch/csrc/jit/passes/quantization/quantization_patterns.h
@@ -805,12 +805,6 @@
auto hardtanh_ = getClampOpFusionInfo("aten::hardtanh_", {"%min", "%max"});
- auto elu = getInputTensorQParamOpFusionInfo(
- "aten::elu", {"%alpha", "%scale", "%input_scale"});
-
- auto elu_ = getInputTensorQParamOpFusionInfo(
- "aten::elu_", {"%alpha", "%scale", "%input_scale"});
-
auto leaky_relu =
getInputTensorQParamOpFusionInfo("aten::leaky_relu", {"%negative_slope"});
@@ -974,8 +968,6 @@
clamp,
hardtanh,
hardtanh_,
- elu,
- elu_,
leaky_relu,
leaky_relu_,
// fixed qparam ops
diff --git a/torch/nn/quantized/functional.py b/torch/nn/quantized/functional.py
index 602475b..19cc1cd 100644
--- a/torch/nn/quantized/functional.py
+++ b/torch/nn/quantized/functional.py
@@ -472,8 +472,8 @@
raise ValueError("Input to 'value' must be specified!")
return torch._ops.ops.quantized.threshold(input, threshold, value)
-def elu(input, alpha=1., inplace=False, scale=None, zero_point=None):
- # type: (Tensor, Optional[float], bool, Optional[float], Optional[int]) -> Tensor
+def elu(input, scale, zero_point, alpha=1.):
+ # type: (Tensor, float, int, float) -> Tensor
r"""
Applies the quantized ELU function element-wise:
@@ -482,25 +482,12 @@
Args:
input: quantized input
- alpha: the :math:`\alpha` value for the ELU formulation. Default: 1.0
- inplace: Inplace modification of the input tensor
scale, zero_point: Scale and zero point of the output tensor.
+ alpha: the :math:`\alpha` value for the ELU formulation. Default: 1.0
"""
if not input.is_quantized:
raise ValueError("Input to 'quantized.elu' must be quantized!")
- if (scale is not None) != (zero_point is not None):
- raise ValueError("Either both or none of (scale, zero_point) must be specified!")
-
- if scale is not None and zero_point is not None:
- assert not inplace, "Cannot rescale with `inplace`"
- output = torch._empty_affine_quantized(
- input.shape, scale=scale, zero_point=int(zero_point), dtype=input.dtype)
- torch._C._nn.elu(input, alpha, out=output)
- return output
- elif inplace:
- return torch._C._nn.elu_(input, alpha)
- else:
- return torch._C._nn.elu(input, alpha)
+ return torch.ops.quantized.elu(input, scale, zero_point, alpha)
def hardsigmoid(input):
# type: (Tensor) -> Tensor