quant docs: add and clean up ELU (#40377)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/40377
Cleans up the docstring for quantized ELU and adds it to the quantization docs.
Test Plan:
* build the docs on Mac OS and inspect them
Differential Revision: D22162834
Pulled By: vkuzo
fbshipit-source-id: e548fd4dc8d67db27ed19cac4dbdf2a942586759
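For reviewers, a minimal usage sketch of the functional op this change documents (the tensor values, scale, and zero point below are illustrative assumptions, not taken from the PR):

```python
import torch
import torch.nn.quantized.functional as qF

# Quantize an fp32 tensor; scale/zero_point here are arbitrary example values.
x_fp32 = torch.randn(2, 3)
x_q = torch.quantize_per_tensor(x_fp32, scale=0.05, zero_point=128,
                                dtype=torch.quint8)

# Quantized ELU: scale and zero_point describe the quantization of the
# *output* tensor, as the updated docstring states.
y_q = qF.elu(x_q, scale=0.05, zero_point=128, alpha=1.0)
print(y_q.dequantize())
```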
diff --git a/docs/source/quantization.rst b/docs/source/quantization.rst
index db94bca..870f19a 100644
--- a/docs/source/quantization.rst
+++ b/docs/source/quantization.rst
@@ -189,6 +189,7 @@
* :meth:`~torch.nn.functional.relu` — Rectified linear unit (copy)
* :meth:`~torch.nn.functional.relu_` — Rectified linear unit (inplace)
+* :meth:`~torch.nn.functional.elu` - ELU
* :meth:`~torch.nn.functional.max_pool2d` - Maximum pooling
* :meth:`~torch.nn.functional.adaptive_avg_pool2d` - Adaptive average pooling
* :meth:`~torch.nn.functional.avg_pool2d` - Average pooling
@@ -353,6 +354,7 @@
* :class:`~torch.nn.quantized.ReLU` — Rectified linear unit
* :class:`~torch.nn.quantized.ReLU6` — Rectified linear unit with cut-off at
quantized representation of 6
+* :class:`~torch.nn.quantized.ELU` — ELU
* :class:`~torch.nn.quantized.Hardswish` — Hardswish
* :class:`~torch.nn.quantized.BatchNorm2d` — BatchNorm2d. *Note: this module is usually fused with Conv or Linear. Performance on ARM is not optimized*.
* :class:`~torch.nn.quantized.BatchNorm3d` — BatchNorm3d. *Note: this module is usually fused with Conv or Linear. Performance on ARM is not optimized*.
@@ -386,6 +388,7 @@
* :func:`~torch.nn.quantized.functional.linear` — Linear (fully-connected) op
* :func:`~torch.nn.quantized.functional.max_pool2d` — 2D max pooling
* :func:`~torch.nn.quantized.functional.relu` — Rectified linear unit
+* :func:`~torch.nn.quantized.functional.elu` — ELU
* :func:`~torch.nn.quantized.functional.hardsigmoid` — Hardsigmoid
* :func:`~torch.nn.quantized.functional.hardswish` — Hardswish
* :func:`~torch.nn.quantized.functional.hardtanh` — Hardtanh
@@ -749,6 +752,11 @@
.. autoclass:: ReLU6
:members:

+ELU
+~~~~~~~~~~~~~~~
+.. autoclass:: ELU
+ :members:
+
Hardswish
~~~~~~~~~~~~~~~
.. autoclass:: Hardswish
diff --git a/torch/nn/quantized/functional.py b/torch/nn/quantized/functional.py
index be66a98..ac93e49 100644
--- a/torch/nn/quantized/functional.py
+++ b/torch/nn/quantized/functional.py
@@ -455,16 +455,13 @@

def elu(input, scale, zero_point, alpha=1.):
# type: (Tensor, float, int, float) -> Tensor
- r"""
- Applies the quantized ELU function element-wise:
-
- .. math::
- \text{ELU}(x) = \max(0,x) + \min(0, \alpha * (\exp(x) - 1))
+ r"""This is the quantized version of :func:`~torch.nn.functional.elu`.
Args:
input: quantized input
- scale, zero_point: Scale and zero point of the output tensor.
- alpha: the :math:`\alpha` value for the ELU formulation. Default: 1.0
+ scale: quantization scale of the output tensor
+ zero_point: quantization zero point of the output tensor
+ alpha: the alpha constant for the ELU formulation. Default: 1.0
"""
if not input.is_quantized:
raise ValueError("Input to 'quantized.elu' must be quantized!")
diff --git a/torch/nn/quantized/modules/activation.py b/torch/nn/quantized/modules/activation.py
index e0fac03..1df7c06 100644
--- a/torch/nn/quantized/modules/activation.py
+++ b/torch/nn/quantized/modules/activation.py
@@ -106,7 +106,12 @@
return Hardswish(float(scale), int(zero_point))

class ELU(torch.nn.ELU):
- r"""This is the quantized equivalent of :class:`torch.nn.ELU`.
+ r"""This is the quantized equivalent of :class:`~torch.nn.ELU`.
+
+ Args:
+ scale: quantization scale of the output tensor
+ zero_point: quantization zero point of the output tensor
+ alpha: the alpha constant for the ELU formulation. Default: 1.0
"""
def __init__(self, scale, zero_point, alpha=1.):
super(ELU, self).__init__(alpha)
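A corresponding sketch for the module form added to the quantization docs above, again assuming illustrative quantization parameters:

```python
import torch
import torch.nn.quantized as nnq

# Illustrative input; scale/zero_point are arbitrary example values.
x_q = torch.quantize_per_tensor(torch.randn(2, 3),
                                scale=0.05, zero_point=128, dtype=torch.quint8)

# The module mirrors torch.nn.ELU but additionally carries the output
# tensor's quantization scale and zero point.
m = nnq.ELU(scale=0.05, zero_point=128, alpha=1.0)
y_q = m(x_q)
print(y_q.int_repr())  # raw quint8 values of the quantized result
```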