Add shortcuts for refs.pow (#80219)

As per title
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80219
Approved by: https://github.com/mruberry
diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py
index fcd62e0..3ee173d 100644
--- a/torch/_refs/__init__.py
+++ b/torch/_refs/__init__.py
@@ -925,6 +925,30 @@
     supports_lhs_python_scalar=False,
 )
 
+
+def _pow(
+    a: Union[TensorLikeType, NumberType],
+    b: Union[TensorLikeType, NumberType],
+) -> TensorLikeType:
+    assert isinstance(a, TensorLikeType) or isinstance(b, TensorLikeType)
+
+    if isinstance(b, Number):
+        if b == 1.0:
+            return a.clone()  # type: ignore[return-value,union-attr]
+        elif b == 2.0:
+            return a * a  # type: ignore[return-value]
+        elif b == 0.5:
+            return torch.sqrt(a)  # type: ignore[arg-type]
+    return prims.pow(a, b)
+
+
+# TODO: add docstring
+pow = _make_elementwise_binary_reference(
+    _pow,
+    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG,
+    aten_op=torch.ops.aten.pow,
+)
+
 # TODO: add docstring
 # Float power has its own implementation because it has unique type promotion.
 # NB: aten_op not registered because CompositeExplicitAutograd
@@ -955,7 +979,7 @@
         b = prims.to_dtype(b, dtype)
 
     a, b = _maybe_broadcast(a, b)
-    return prims.pow(a, b)
+    return pow(a, b)
 
 
 # >>> a = torch.tensor(-0.2500, dtype=torch.float64)
@@ -1306,12 +1330,6 @@
 )
 
 # TODO: add docstring
-pow = _make_elementwise_binary_reference(
-    prims.pow,
-    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG,
-)
-
-# TODO: add docstring
 remainder = _make_elementwise_binary_reference(
     prims.remainder,
     type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 23205de..5b7f3c1 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -20774,6 +20774,7 @@
     ElementwiseBinaryPythonRefInfo(
         "_refs.pow",
         torch_opinfo_name="pow",
+        supports_nvfuser=False,  # clone default
         skips=(
             # Reference result was farther (inf) from the precise
             # computation than the torch result was (nan)!