Implement NumPy-like function torch.float_power() (#44937)

Summary:
- Related to https://github.com/pytorch/pytorch/issues/38349
- Implements the NumPy-like function `torch.float_power()` (a short usage sketch follows below).
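A minimal usage sketch of the intended semantics (illustrative only, assuming this patch is applied; it mirrors the docstring example added below and checks parity with `np.float_power`):

```python
# Illustrative sketch, not part of the patch: float_power always computes in
# double precision (float64 / complex128), unlike torch.pow, which follows the
# usual type-promotion rules.
import numpy as np
import torch

base = torch.tensor([1, 2, 3, 4])   # integer inputs
exp = torch.tensor([2, -3, 4, -5])

out = torch.float_power(base, exp)   # result has dtype torch.float64
ref = np.float_power(base.numpy(), exp.numpy())

print(out.dtype)                                    # torch.float64
print(torch.allclose(out, torch.from_numpy(ref)))   # True
```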

Pull Request resolved: https://github.com/pytorch/pytorch/pull/44937

Reviewed By: ngimel

Differential Revision: D25192119

Pulled By: mruberry

fbshipit-source-id: 2e446b8e0c2825f045fe057e30c9419335557a05
diff --git a/aten/src/ATen/core/aten_interned_strings.h b/aten/src/ATen/core/aten_interned_strings.h
index a1eee28..4a34b27 100644
--- a/aten/src/ATen/core/aten_interned_strings.h
+++ b/aten/src/ATen/core/aten_interned_strings.h
@@ -552,6 +552,7 @@
 _(aten, poisson) \
 _(aten, polygamma) \
 _(aten, pow) \
+_(aten, float_power) \
 _(aten, prelu) \
 _(aten, prelu_backward) \
 _(aten, prod) \
diff --git a/aten/src/ATen/native/Pow.cpp b/aten/src/ATen/native/Pow.cpp
index ca5d184..bfc5f91 100644
--- a/aten/src/ATen/native/Pow.cpp
+++ b/aten/src/ATen/native/Pow.cpp
@@ -81,6 +81,48 @@
   return native::pow_out(result, base, exp);
 }
 
+Tensor& float_power_out(Tensor& result, const Tensor& base, const Tensor& exp) {
+  auto dtype = (at::isComplexType(base.scalar_type()) || at::isComplexType(exp.scalar_type())) ?
+                at::kComplexDouble : at::kDouble;
+  TORCH_CHECK(result.scalar_type() == dtype,
+              "output type ", result.scalar_type(), "is not the desired output type ", dtype);
+
+  return at::pow_out(result, base.to(dtype), exp.to(dtype));
+}
+
+Tensor& float_power_out(Tensor& result, const Tensor& base, Scalar exp) {
+  return at::float_power_out(result, base, c10::scalar_to_tensor(exp, base.device()));
+}
+
+Tensor& float_power_out(Tensor& result, Scalar base, const Tensor& exp) {
+  return at::float_power_out(result, c10::scalar_to_tensor(base, exp.device()), exp);
+}
+
+Tensor float_power(const Tensor& base, const Tensor& exp) {
+  auto dtype = (at::isComplexType(base.scalar_type()) || at::isComplexType(exp.scalar_type())) ? at::kComplexDouble : at::kDouble;
+  return at::pow(base.to(dtype), exp.to(dtype));
+}
+
+Tensor float_power(const Tensor& base, Scalar exp) {
+  return at::float_power(base, c10::scalar_to_tensor(exp, base.device()));
+}
+
+Tensor float_power(Scalar base, const Tensor& exp) {
+  return at::float_power(c10::scalar_to_tensor(base, exp.device()), exp);
+}
+
+Tensor& float_power_(Tensor& base, const Tensor& exp) {
+  auto dtype = (at::isComplexType(base.scalar_type()) || at::isComplexType(exp.scalar_type())) ? at::kComplexDouble : at::kDouble;
+  TORCH_CHECK(base.scalar_type() == dtype,
+              "self tensor type ", base.scalar_type(), "is not the desired type ", dtype);
+
+  return base.pow_(exp.to(dtype));
+}
+
+Tensor& float_power_(Tensor& base, Scalar exp) {
+  return base.float_power_(c10::scalar_to_tensor(exp, base.device()));
+}
+
 } // namespace native
 
 } // namespace at
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 9a289cb..2d6e570 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -6872,6 +6872,47 @@
     CPU, CUDA: pow
     SparseCPU, SparseCUDA: pow_sparse_scalar
 
+- func: float_power.Tensor_Tensor_out(Tensor self, Tensor exponent, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    Math: float_power_out
+
+- func: float_power.Tensor_Tensor(Tensor self, Tensor exponent) -> Tensor
+  use_c10_dispatcher: full
+  variants: function, method
+  dispatch:
+    Math: float_power
+
+- func: float_power.Scalar_out(Scalar self, Tensor exponent, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    Math: float_power_out
+
+- func: float_power.Scalar(Scalar self, Tensor exponent) -> Tensor
+  use_c10_dispatcher: full
+  dispatch:
+    Math: float_power
+
+- func: float_power.Tensor_Scalar_out(Tensor self, Scalar exponent, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    Math: float_power_out
+
+- func: float_power.Tensor_Scalar(Tensor self, Scalar exponent) -> Tensor
+  use_c10_dispatcher: full
+  variants: function, method
+  dispatch:
+    Math: float_power
+
+- func: float_power_.Scalar(Tensor(a!) self, Scalar exponent) -> Tensor(a!)
+  use_c10_dispatcher: full
+  variants: method
+  dispatch:
+    Math: float_power_
+
+- func: float_power_.Tensor(Tensor(a!) self, Tensor exponent) -> Tensor(a!)
+  use_c10_dispatcher: full
+  variants: method
+  dispatch:
+    Math: float_power_
+
 - func: normal_(Tensor(a!) self, float mean=0, float std=1, *, Generator? generator=None) -> Tensor(a!)
   variants: method
   dispatch:
diff --git a/docs/source/tensors.rst b/docs/source/tensors.rst
index 730a385..d7b0af7 100644
--- a/docs/source/tensors.rst
+++ b/docs/source/tensors.rst
@@ -322,6 +322,8 @@
    .. automethod:: fliplr
    .. automethod:: flipud
    .. automethod:: float
+   .. automethod:: float_power
+   .. automethod:: float_power_
    .. automethod:: floor
    .. automethod:: floor_
    .. automethod:: floor_divide
diff --git a/docs/source/torch.rst b/docs/source/torch.rst
index aab84fc..4399e63 100644
--- a/docs/source/torch.rst
+++ b/docs/source/torch.rst
@@ -296,6 +296,7 @@
     exp2
     expm1
     fix
+    float_power
     floor
     floor_divide
     fmod
diff --git a/test/test_torch.py b/test/test_torch.py
index f417d10..86dd202 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -6205,6 +6205,102 @@
             torch.pow(m1, 1, out=out)
             self.assertEqual(out, m1)
 
+    @dtypes(*list(product(torch.testing.get_all_dtypes(include_bool=False),
+                          torch.testing.get_all_dtypes(include_bool=False))))
+    def test_float_power(self, device, dtypes):
+        def to_np(value):
+            if isinstance(value, torch.Tensor) and value.dtype == torch.bfloat16:
+                return value.to(torch.float).cpu().numpy()
+            return value.cpu().numpy() if isinstance(value, torch.Tensor) else value
+
+        base_dtype = dtypes[0]
+        exp_dtype = dtypes[1]
+        out_dtype = torch.complex128 if base_dtype.is_complex or exp_dtype.is_complex else torch.float64
+
+        base = make_tensor((30,), device, base_dtype, low=1, high=100)
+        # Complex and real results do not agree between PyTorch and NumPy when computing negative and zero power of 0
+        # Related: https://github.com/pytorch/pytorch/issues/48000
+        # base[0] = base[3] = base[7] = 0
+        exp = make_tensor((30,), device, exp_dtype, low=-2, high=2)
+        exp[0] = exp[4] = exp[6] = 0
+
+        expected = torch.from_numpy(np.float_power(to_np(base), to_np(exp)))
+
+        exponents = [-2.8, -2, -1, -0.5, 0.5, 1, 2]
+        complex_exponents = exponents + [-2.5j, -1.0j, 1.0j, 2.5j, 1.0 + 1.0j, -1.0 - 1.5j, 3.3j]
+
+        for op in (torch.float_power, torch.Tensor.float_power, torch.Tensor.float_power_):
+
+            # Case of Tensor x Tensor
+            if op is torch.Tensor.float_power_ and base_dtype != out_dtype:
+                with self.assertRaisesRegex(RuntimeError, "is not the desired type"):
+                    op(base.clone(), exp)
+            else:
+                result = op(base.clone(), exp)
+                self.assertEqual(expected, result)
+
+            if op is torch.float_power:
+                out = torch.empty_like(base).to(device=device, dtype=out_dtype)
+                op(base, exp, out=out)
+                self.assertEqual(expected, out)
+
+            # Case of Tensor x Scalar
+            for i in complex_exponents if exp_dtype.is_complex else exponents:
+                out_dtype_scalar_exp = torch.complex128 if base_dtype.is_complex or type(i) == complex else torch.float64
+                expected_scalar_exp = torch.from_numpy(np.float_power(to_np(base), i))
+
+                if op is torch.Tensor.float_power_ and base_dtype != out_dtype_scalar_exp:
+                    with self.assertRaisesRegex(RuntimeError, "is not the desired type"):
+                        op(base.clone(), i)
+                else:
+                    result = op(base.clone(), i)
+                    self.assertEqual(expected_scalar_exp, result)
+
+                if op is torch.float_power:
+                    out = torch.empty_like(base).to(device=device, dtype=out_dtype_scalar_exp)
+                    op(base, i, out=out)
+                    self.assertEqual(expected_scalar_exp, out)
+
+        # Case of Scalar x Tensor
+        for i in complex_exponents if base_dtype.is_complex else exponents:
+            out_dtype_scalar_base = torch.complex128 if exp_dtype.is_complex or type(i) == complex else torch.float64
+            expected_scalar_base = torch.from_numpy(np.float_power(i, to_np(exp)))
+
+            result = torch.float_power(i, exp)
+            self.assertEqual(expected_scalar_base, result)
+
+            out = torch.empty_like(exp).to(device=device, dtype=out_dtype_scalar_base)
+            torch.float_power(i, exp, out=out)
+            self.assertEqual(expected_scalar_base, out)
+
+    def test_float_power_exceptions(self, device):
+        def _promo_helper(x, y):
+            for i in (x, y):
+                if type(i) == complex:
+                    return torch.complex128
+                elif type(i) == torch.Tensor and i.is_complex():
+                    return torch.complex128
+            return torch.double
+
+        test_cases = ((torch.tensor([-2, -1, 0, 1, 2], device=device), -.25),
+                      (torch.tensor([-1.0j, 0j, 1.0j, 1.0 + 1.0j, -1.0 - 1.5j], device=device), 2.))
+        for base, exp in test_cases:
+            for out_dtype in (torch.long, torch.float, torch.double, torch.cdouble):
+                out = torch.empty(1, device=device, dtype=out_dtype)
+                required_dtype = _promo_helper(base, exp)
+
+                if out.dtype == required_dtype:
+                    torch.float_power(base, exp, out=out)
+                else:
+                    with self.assertRaisesRegex(RuntimeError, "is not the desired output type"):
+                        torch.float_power(base, exp, out=out)
+
+                if base.dtype == required_dtype:
+                    torch.Tensor.float_power_(base.clone(), exp)
+                else:
+                    with self.assertRaisesRegex(RuntimeError, "is not the desired type"):
+                        torch.Tensor.float_power_(base.clone(), exp)
+
     @unittest.skipIf(not TEST_NUMPY, 'NumPy not found')
     @onlyOnCPUAndCUDA
     @dtypes(torch.int8, torch.int16, torch.int32, torch.int64)
diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py
index c11a5b4..2d8533b 100644
--- a/torch/_tensor_docs.py
+++ b/torch/_tensor_docs.py
@@ -2671,6 +2671,20 @@
 In-place version of :meth:`~Tensor.pow`
 """)
 
+add_docstr_all('float_power',
+               r"""
+float_power(exponent) -> Tensor
+
+See :func:`torch.float_power`
+""")
+
+add_docstr_all('float_power_',
+               r"""
+float_power_(exponent) -> Tensor
+
+In-place version of :meth:`~Tensor.float_power`
+""")
+
 add_docstr_all('prod',
                r"""
 prod(dim=None, keepdim=False, dtype=None) -> Tensor
diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py
index 545578c..3b6ee12 100644
--- a/torch/_torch_docs.py
+++ b/torch/_torch_docs.py
@@ -6367,6 +6367,46 @@
     tensor([  2.,   4.,   8.,  16.])
 """.format(**common_args))
 
+add_docstr(torch.float_power,
+           r"""
+float_power(input, exponent, *, out=None) -> Tensor
+
+Raises :attr:`input` to the power of :attr:`exponent`, elementwise, in double precision.
+If neither input is complex, returns a ``torch.float64`` tensor,
+and if one or more inputs is complex, returns a ``torch.complex128`` tensor.
+
+.. note::
+    This function always computes in double precision, unlike :func:`torch.pow`,
+    which implements more typical :ref:`type promotion <type-promotion-doc>`.
+    This is useful when the computation needs to be performed in a wider or more precise dtype,
+    or when the results of the computation may contain fractional values not representable
+    in the input dtypes, like when an integer base is raised to a negative integer exponent.
+
+Args:
+    input (Tensor or Number): the base value(s)
+    exponent (Tensor or Number): the exponent value(s)
+
+Keyword args:
+    {out}
+
+Example::
+
+    >>> a = torch.randint(10, (4,))
+    >>> a
+    tensor([6, 4, 7, 1])
+    >>> torch.float_power(a, 2)
+    tensor([36., 16., 49.,  1.], dtype=torch.float64)
+
+    >>> a = torch.arange(1, 5)
+    >>> a
+    tensor([ 1,  2,  3,  4])
+    >>> exp = torch.tensor([2, -3, 4, -5])
+    >>> exp
+    tensor([ 2, -3,  4, -5])
+    >>> torch.float_power(a, exp)
+    tensor([1.0000e+00, 1.2500e-01, 8.1000e+01, 9.7656e-04], dtype=torch.float64)
+""".format(**common_args))
+
 add_docstr(torch.prod,
            r"""
 prod(input, *, dtype=None) -> Tensor
diff --git a/torch/overrides.py b/torch/overrides.py
index 2c48e71..eb863c7 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -366,6 +366,7 @@
         torch.frobenius_norm: lambda input, dim=None, keepdim=False, out=None: -1,
         torch.floor: lambda input, out=None: -1,
         torch.floor_divide: lambda input, other: -1,
+        torch.float_power: lambda input, exponent, out=None: -1,
         torch.fmod: lambda input, other, out=None: -1,
         torch.frac: lambda input, out=None: -1,
         torch.full_like: lambda input, fill_value, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False: -1,
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 36e7fb6..8850c3f 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -694,6 +694,14 @@
         ('pow', uniform_scalar(1e-3 * (1 + 1j), requires_grad=True), (3.14,), 'complex_scalar_constant', (True,)),
         ('pow', uniform_scalar(1e-3 * (1 + 1j), requires_grad=True), (3.14j,), 'complex_imaginary_exponent', (True,)),
         ('__rpow__', uniform_scalar(1e-3, requires_grad=True), (3.14,), 'scalar_constant', (True, 'aten::pow')),
+        ('float_power', torch.rand(S, S, S) + 1e-3, (torch.rand(S, S, S) + 0.1,), ''),
+        ('float_power', torch.rand(S, S, S) + 1e-3, (torch.rand(1,) + 0.1,), 'broadcast_rhs'),
+        ('float_power', torch.rand(1,) + 1e-3, (torch.rand(S, S, S) + 0.1,), 'broadcast_lhs'),
+        ('float_power', torch.rand(S, 1, S) + 1e-3, (torch.rand(1, S, 1) + 0.1,), 'broadcast_all'),
+        ('float_power', uniform_scalar(1e-3, requires_grad=True), (uniform_scalar(0.1),), 'scalar'),
+        ('float_power', torch.rand(S, S, S) + 1e-3, (uniform_scalar(0.1),), 'scalar_broadcast_rhs'),
+        ('float_power', uniform_scalar(1e-3, requires_grad=True), (torch.rand(S, S, S) + 0.1,), 'scalar_broadcast_lhs'),
+        ('float_power', torch.rand(S, S, S) + 1e-3, (3.14,), 'constant'),
         ('transpose', (1, 2, 3), (1, 2), 'dim', (False,), [0, 1]),
         ('transpose', (), (0, 0), 'scalar', (False,)),
         ('transpose', (1,), (0, 0), '1d', (False,)),