Add shortcuts for refs.pow (#80219)
Adds fast paths to the refs.pow reference for common scalar exponents: 1.0 returns a clone, 2.0 returns a * a, and 0.5 returns torch.sqrt(a); any other exponent falls through to prims.pow. The pow registration moves next to the new _pow helper and now records aten_op=torch.ops.aten.pow, float_power routes through pow so it picks up the same shortcuts, and the _refs.pow OpInfo sets supports_nvfuser=False since the clone path is not supported by nvfuser.
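
For illustration, a minimal sketch of what the new fast paths do (not part of
this patch; assumes a floating-point input tensor):

    import torch
    import torch._refs as refs

    a = torch.rand(4)
    refs.pow(a, 1.0)  # shortcut: returns a.clone()
    refs.pow(a, 2.0)  # shortcut: returns a * a
    refs.pow(a, 0.5)  # shortcut: returns torch.sqrt(a)
    refs.pow(a, 3.0)  # no shortcut: falls through to prims.pow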
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80219
Approved by: https://github.com/mruberry
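
A quick numerical sanity check of the shortcuts against eager torch.pow,
illustrative only and not part of the patch:

    import torch
    import torch._refs as refs

    a = torch.rand(8, dtype=torch.float64) + 0.1  # strictly positive, so exponent 0.5 is well defined
    for exp in (0.5, 1.0, 2.0, 3.0):
        torch.testing.assert_close(refs.pow(a, exp), torch.pow(a, exp))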
diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py
index fcd62e0..3ee173d 100644
--- a/torch/_refs/__init__.py
+++ b/torch/_refs/__init__.py
@@ -925,6 +925,30 @@
     supports_lhs_python_scalar=False,
 )
+
+def _pow(
+    a: Union[TensorLikeType, NumberType],
+    b: Union[TensorLikeType, NumberType],
+) -> TensorLikeType:
+    assert isinstance(a, TensorLikeType) or isinstance(b, TensorLikeType)
+
+    if isinstance(b, Number):
+        if b == 1.0:
+            return a.clone()  # type: ignore[return-value,union-attr]
+        elif b == 2.0:
+            return a * a  # type: ignore[return-value]
+        elif b == 0.5:
+            return torch.sqrt(a)  # type: ignore[arg-type]
+    return prims.pow(a, b)
+
+
+# TODO: add docstring
+pow = _make_elementwise_binary_reference(
+    _pow,
+    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG,
+    aten_op=torch.ops.aten.pow,
+)
+
 # TODO: add docstring
 # Float power has its own implementation because it has unique type promotion.
 # NB: aten_op not registered because CompositeExplicitAutograd
@@ -955,7 +979,7 @@
     b = prims.to_dtype(b, dtype)
     a, b = _maybe_broadcast(a, b)
-    return prims.pow(a, b)
+    return pow(a, b)
 
 
 # >>> a = torch.tensor(-0.2500, dtype=torch.float64)
@@ -1306,12 +1330,6 @@
 )
 
 # TODO: add docstring
-pow = _make_elementwise_binary_reference(
-    prims.pow,
-    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG,
-)
-
-# TODO: add docstring
 remainder = _make_elementwise_binary_reference(
     prims.remainder,
     type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 23205de..5b7f3c1 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -20774,6 +20774,7 @@
     ElementwiseBinaryPythonRefInfo(
         "_refs.pow",
         torch_opinfo_name="pow",
+        supports_nvfuser=False,  # the b == 1.0 shortcut calls clone, which nvfuser does not support
         skips=(
             # Reference result was farther (inf) from the precise
             # computation than the torch result was (nan)!