Reland hardtanh ref (again) (#78914)
Fixes a land race between commit https://github.com/pytorch/pytorch/commit/823ddb6e87e8111c9b5a99523503172e5bf62c49 and Ed's stack.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78914
Approved by: https://github.com/wanchaol
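
For context, hardtanh is an elementwise clamp to [min_val, max_val]; the relanded ref below lowers it to torch.clamp. A minimal sketch of the semantics (plain eager calls, not patch-specific):

    import torch

    x = torch.tensor([-2.0, -0.5, 0.5, 2.0])
    # hardtanh clamps elementwise to [min_val, max_val]; defaults are [-1, 1].
    out = torch.nn.functional.hardtanh(x)
    # expected: tensor([-1.0000, -0.5000,  0.5000,  1.0000])
    assert torch.equal(out, torch.clamp(x, -1.0, 1.0))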
diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py
index e9c1e7a..df548f7 100644
--- a/torch/_decomp/decompositions.py
+++ b/torch/_decomp/decompositions.py
@@ -136,12 +136,6 @@
)
-@register_decomposition(aten.hardtanh)
-@pw_cast_for_opmath
-def hardtanh(self: Tensor, min_val: float = -1, max_val: float = 1) -> Tensor:
- return torch.clamp(self, min_val, max_val)
-
-
@register_decomposition(aten.hardtanh_backward)
@pw_cast_for_opmath
def hardtanh_backward(
diff --git a/torch/_refs/nn/functional/__init__.py b/torch/_refs/nn/functional/__init__.py
index 61faf2e..7dd7fa5 100644
--- a/torch/_refs/nn/functional/__init__.py
+++ b/torch/_refs/nn/functional/__init__.py
@@ -26,6 +26,7 @@
"dropout",
"elu",
"relu",
+ "hardtanh",
"hinge_embedding_loss",
"margin_ranking_loss",
"mish",
@@ -311,3 +312,38 @@
"Expected a tensor input for an elementwise unary operation!"
)
return refs.sub(a, refs.tanh(a))
+
+
+@register_decomposition(torch.ops.aten.hardtanh)
+@elementwise_unary_scalar_wrapper
+@elementwise_type_promotion_wrapper(
+ type_promoting_args=("a"),
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+)
+def hardtanh(
+ a: TensorLikeType,
+ min_val: NumberType = -1,
+ max_val: NumberType = 1,
+ inplace: bool = False,
+) -> TensorLikeType:
+ """
+ Reference implementation of torch.nn.functional.hardtanh
+ """
+ if inplace:
+ raise NotImplementedError
+ if not isinstance(a, TensorLike):
+ raise RuntimeError(
+ "Expected a tensor input for an elementwise unary operation!"
+ )
+ if utils.is_boolean_dtype(a.dtype):
+ raise RuntimeError("Bool inputs not supported for hardtanh")
+
+ # preserve legacy behavior of boundaries not causing type promotion
+ if utils.is_integer_dtype(a.dtype):
+ min_val = int(min_val) # type: ignore[arg-type]
+ max_val = int(max_val) # type: ignore[arg-type]
+ if not (a.dtype != torch.uint8 or (min_val >= 0 and max_val >= 0)):
+ raise RuntimeError(
+ "Cannot do hardtanh on an unsigned type with negative limits"
+ )
+ return torch.clamp(a, min_val, max_val) # type: ignore[arg-type]
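
The integer and uint8 branches above preserve eager-mode quirks: boundary values are truncated for integer inputs instead of promoting the result, and unsigned inputs with negative limits are rejected. A short sketch of the behavior the ref matches (assuming this patch is applied; commented outputs are expected values, not captured logs):

    import torch
    import torch._refs.nn.functional as refs_F

    # Integer input: the -1.5/1.5 bounds are truncated to -1/1, and the
    # result keeps the integer dtype rather than promoting to float.
    x = torch.tensor([-3, 0, 3], dtype=torch.int32)
    y = refs_F.hardtanh(x, min_val=-1.5, max_val=1.5)
    # expected: tensor([-1, 0, 1], dtype=torch.int32)

    # uint8 input with the default min_val of -1 raises rather than wrapping.
    u = torch.arange(4, dtype=torch.uint8)
    try:
        refs_F.hardtanh(u)
    except RuntimeError as e:
        print(e)  # "Cannot do hardtanh on an unsigned type with negative limits"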
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 79695aa..e65b0ce 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -19735,6 +19735,10 @@
torch_opinfo_name="nn.functional.elu",
),
PythonRefInfo(
+ "_refs.nn.functional.hardtanh",
+ torch_opinfo_name="nn.functional.hardtanh",
+ ),
+ PythonRefInfo(
"_refs.nn.functional.leaky_relu",
torch_opinfo_name="nn.functional.leaky_relu",
),