Implement NumPy-like function torch.float_power() (#44937)
Summary:
- Related to https://github.com/pytorch/pytorch/issues/38349
- Implements the NumPy-like function `torch.float_power()`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/44937
Reviewed By: ngimel
Differential Revision: D25192119
Pulled By: mruberry
fbshipit-source-id: 2e446b8e0c2825f045fe057e30c9419335557a05
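
For orientation before the diff: `torch.float_power()` always computes in double precision, promoting its inputs to torch.float64 (or torch.complex128 when either operand is complex) before deferring to `pow`. A minimal usage sketch, taken from the docstring added in this patch:

    >>> import torch
    >>> a = torch.arange(1, 5)
    >>> exp = torch.tensor([2, -3, 4, -5])
    >>> torch.float_power(a, exp)
    tensor([1.0000e+00, 1.2500e-01, 8.1000e+01, 9.7656e-04], dtype=torch.float64)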
diff --git a/aten/src/ATen/core/aten_interned_strings.h b/aten/src/ATen/core/aten_interned_strings.h
index a1eee28..4a34b27 100644
--- a/aten/src/ATen/core/aten_interned_strings.h
+++ b/aten/src/ATen/core/aten_interned_strings.h
@@ -552,6 +552,7 @@
_(aten, poisson) \
_(aten, polygamma) \
_(aten, pow) \
+_(aten, float_power) \
_(aten, prelu) \
_(aten, prelu_backward) \
_(aten, prod) \
diff --git a/aten/src/ATen/native/Pow.cpp b/aten/src/ATen/native/Pow.cpp
index ca5d184..bfc5f91 100644
--- a/aten/src/ATen/native/Pow.cpp
+++ b/aten/src/ATen/native/Pow.cpp
@@ -81,6 +81,48 @@
return native::pow_out(result, base, exp);
}
+Tensor& float_power_out(Tensor& result, const Tensor& base, const Tensor& exp) {
+ auto dtype = (at::isComplexType(base.scalar_type()) || at::isComplexType(exp.scalar_type())) ?
+ at::kComplexDouble : at::kDouble;
+ TORCH_CHECK(result.scalar_type() == dtype,
+ "output type ", result.scalar_type(), "is not the desired output type ", dtype);
+
+ return at::pow_out(result, base.to(dtype), exp.to(dtype));
+}
+
+Tensor& float_power_out(Tensor& result, const Tensor& base, Scalar exp) {
+ return at::float_power_out(result, base, c10::scalar_to_tensor(exp, base.device()));
+}
+
+Tensor& float_power_out(Tensor& result, Scalar base, const Tensor& exp) {
+ return at::float_power_out(result, c10::scalar_to_tensor(base, exp.device()), exp);
+}
+
+Tensor float_power(const Tensor& base, const Tensor& exp) {
+ auto dtype = (at::isComplexType(base.scalar_type()) || at::isComplexType(exp.scalar_type())) ? at::kComplexDouble : at::kDouble;
+ return at::pow(base.to(dtype), exp.to(dtype));
+}
+
+Tensor float_power(const Tensor& base, Scalar exp) {
+ return at::float_power(base, c10::scalar_to_tensor(exp, base.device()));
+}
+
+Tensor float_power(Scalar base, const Tensor& exp) {
+ return at::float_power(c10::scalar_to_tensor(base, exp.device()), exp);
+}
+
+Tensor& float_power_(Tensor& base, const Tensor& exp) {
+ auto dtype = (at::isComplexType(base.scalar_type()) || at::isComplexType(exp.scalar_type())) ? at::kComplexDouble : at::kDouble;
+ TORCH_CHECK(base.scalar_type() == dtype,
+ "self tensor type ", base.scalar_type(), "is not the desired type ", dtype);
+
+ return base.pow_(exp.to(dtype));
+}
+
+Tensor& float_power_(Tensor& base, Scalar exp) {
+ return base.float_power_(c10::scalar_to_tensor(exp, base.device()));
+}
+
} // namespace native
} // namespace at
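
The C++ entry points above reduce `float_power` to `pow` after casting both operands to double precision. Below is a rough, illustrative Python equivalent of that promotion rule; it is a sketch only (the helper name is made up here), not the real dispatcher path:

    import torch

    def float_power_sketch(base, exp):
        # Pick complex128 if either operand is complex, float64 otherwise,
        # mirroring the dtype selection in float_power_out/float_power above.
        def is_cplx(x):
            return x.is_complex() if isinstance(x, torch.Tensor) else isinstance(x, complex)

        dtype = torch.complex128 if (is_cplx(base) or is_cplx(exp)) else torch.float64

        # Scalars are wrapped into tensors, then both operands are cast and
        # handed to the ordinary pow kernel.
        def to_t(x):
            return x.to(dtype) if isinstance(x, torch.Tensor) else torch.tensor(x, dtype=dtype)

        return torch.pow(to_t(base), to_t(exp))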
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 9a289cb..2d6e570 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -6872,6 +6872,47 @@
CPU, CUDA: pow
SparseCPU, SparseCUDA: pow_sparse_scalar
+- func: float_power.Tensor_Tensor_out(Tensor self, Tensor exponent, *, Tensor(a!) out) -> Tensor(a!)
+ dispatch:
+ Math: float_power_out
+
+- func: float_power.Tensor_Tensor(Tensor self, Tensor exponent) -> Tensor
+ use_c10_dispatcher: full
+ variants: function, method
+ dispatch:
+ Math: float_power
+
+- func: float_power.Scalar_out(Scalar self, Tensor exponent, *, Tensor(a!) out) -> Tensor(a!)
+ dispatch:
+ Math: float_power_out
+
+- func: float_power.Scalar(Scalar self, Tensor exponent) -> Tensor
+ use_c10_dispatcher: full
+ dispatch:
+ Math: float_power
+
+- func: float_power.Tensor_Scalar_out(Tensor self, Scalar exponent, *, Tensor(a!) out) -> Tensor(a!)
+ dispatch:
+ Math: float_power_out
+
+- func: float_power.Tensor_Scalar(Tensor self, Scalar exponent) -> Tensor
+ use_c10_dispatcher: full
+ variants: function, method
+ dispatch:
+ Math: float_power
+
+- func: float_power_.Scalar(Tensor(a!) self, Scalar exponent) -> Tensor(a!)
+ use_c10_dispatcher: full
+ variants: method
+ dispatch:
+ Math: float_power_
+
+- func: float_power_.Tensor(Tensor(a!) self, Tensor exponent) -> Tensor(a!)
+ use_c10_dispatcher: full
+ variants: method
+ dispatch:
+ Math: float_power_
+
- func: normal_(Tensor(a!) self, float mean=0, float std=1, *, Generator? generator=None) -> Tensor(a!)
variants: method
dispatch:
diff --git a/docs/source/tensors.rst b/docs/source/tensors.rst
index 730a385..d7b0af7 100644
--- a/docs/source/tensors.rst
+++ b/docs/source/tensors.rst
@@ -322,6 +322,8 @@
.. automethod:: fliplr
.. automethod:: flipud
.. automethod:: float
+ .. automethod:: float_power
+ .. automethod:: float_power_
.. automethod:: floor
.. automethod:: floor_
.. automethod:: floor_divide
diff --git a/docs/source/torch.rst b/docs/source/torch.rst
index aab84fc..4399e63 100644
--- a/docs/source/torch.rst
+++ b/docs/source/torch.rst
@@ -296,6 +296,7 @@
exp2
expm1
fix
+ float_power
floor
floor_divide
fmod
diff --git a/test/test_torch.py b/test/test_torch.py
index f417d10..86dd202 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -6205,6 +6205,102 @@
torch.pow(m1, 1, out=out)
self.assertEqual(out, m1)
+ @dtypes(*list(product(torch.testing.get_all_dtypes(include_bool=False),
+ torch.testing.get_all_dtypes(include_bool=False))))
+ def test_float_power(self, device, dtypes):
+ def to_np(value):
+ if isinstance(value, torch.Tensor) and value.dtype == torch.bfloat16:
+ return value.to(torch.float).cpu().numpy()
+ return value.cpu().numpy() if isinstance(value, torch.Tensor) else value
+
+ base_dtype = dtypes[0]
+ exp_dtype = dtypes[1]
+ out_dtype = torch.complex128 if base_dtype.is_complex or exp_dtype.is_complex else torch.float64
+
+ base = make_tensor((30,), device, base_dtype, low=1, high=100)
+ # Complex and real results do not agree between PyTorch and NumPy when computing negative and zero powers of 0
+ # Related: https://github.com/pytorch/pytorch/issues/48000
+ # base[0] = base[3] = base[7] = 0
+ exp = make_tensor((30,), device, exp_dtype, low=-2, high=2)
+ exp[0] = exp[4] = exp[6] = 0
+
+ expected = torch.from_numpy(np.float_power(to_np(base), to_np(exp)))
+
+ exponents = [-2.8, -2, -1, -0.5, 0.5, 1, 2]
+ complex_exponents = exponents + [-2.5j, -1.0j, 1.0j, 2.5j, 1.0 + 1.0j, -1.0 - 1.5j, 3.3j]
+
+ for op in (torch.float_power, torch.Tensor.float_power, torch.Tensor.float_power_):
+
+ # Case of Tensor x Tensor
+ if op is torch.Tensor.float_power_ and base_dtype != out_dtype:
+ with self.assertRaisesRegex(RuntimeError, "is not the desired type"):
+ op(base.clone(), exp)
+ else:
+ result = op(base.clone(), exp)
+ self.assertEqual(expected, result)
+
+ if op is torch.float_power:
+ out = torch.empty_like(base).to(device=device, dtype=out_dtype)
+ op(base, exp, out=out)
+ self.assertEqual(expected, out)
+
+ # Case of Tensor x Scalar
+ for i in complex_exponents if exp_dtype.is_complex else exponents:
+ out_dtype_scalar_exp = torch.complex128 if base_dtype.is_complex or type(i) == complex else torch.float64
+ expected_scalar_exp = torch.from_numpy(np.float_power(to_np(base), i))
+
+ if op is torch.Tensor.float_power_ and base_dtype != out_dtype_scalar_exp:
+ with self.assertRaisesRegex(RuntimeError, "is not the desired type"):
+ op(base.clone(), i)
+ else:
+ result = op(base.clone(), i)
+ self.assertEqual(expected_scalar_exp, result)
+
+ if op is torch.float_power:
+ out = torch.empty_like(base).to(device=device, dtype=out_dtype_scalar_exp)
+ op(base, i, out=out)
+ self.assertEqual(expected_scalar_exp, out)
+
+ # Case of Scalar x Tensor
+ for i in complex_exponents if base_dtype.is_complex else exponents:
+ out_dtype_scalar_base = torch.complex128 if exp_dtype.is_complex or type(i) == complex else torch.float64
+ expected_scalar_base = torch.from_numpy(np.float_power(i, to_np(exp)))
+
+ result = torch.float_power(i, exp)
+ self.assertEqual(expected_scalar_base, result)
+
+ out = torch.empty_like(exp).to(device=device, dtype=out_dtype_scalar_base)
+ torch.float_power(i, exp, out=out)
+ self.assertEqual(expected_scalar_base, out)
+
+ def test_float_power_exceptions(self, device):
+ def _promo_helper(x, y):
+ for i in (x, y):
+ if type(i) == complex:
+ return torch.complex128
+ elif type(i) == torch.Tensor and i.is_complex():
+ return torch.complex128
+ return torch.double
+
+ test_cases = ((torch.tensor([-2, -1, 0, 1, 2], device=device), -.25),
+ (torch.tensor([-1.0j, 0j, 1.0j, 1.0 + 1.0j, -1.0 - 1.5j], device=device), 2.))
+ for base, exp in test_cases:
+ for out_dtype in (torch.long, torch.float, torch.double, torch.cdouble):
+ out = torch.empty(1, device=device, dtype=out_dtype)
+ required_dtype = _promo_helper(base, exp)
+
+ if out.dtype == required_dtype:
+ torch.float_power(base, exp, out=out)
+ else:
+ with self.assertRaisesRegex(RuntimeError, "is not the desired output type"):
+ torch.float_power(base, exp, out=out)
+
+ if base.dtype == required_dtype:
+ torch.Tensor.float_power_(base.clone(), exp)
+ else:
+ with self.assertRaisesRegex(RuntimeError, "is not the desired type"):
+ torch.Tensor.float_power_(base.clone(), exp)
+
@unittest.skipIf(not TEST_NUMPY, 'NumPy not found')
@onlyOnCPUAndCUDA
@dtypes(torch.int8, torch.int16, torch.int32, torch.int64)
diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py
index c11a5b4..2d8533b 100644
--- a/torch/_tensor_docs.py
+++ b/torch/_tensor_docs.py
@@ -2671,6 +2671,20 @@
In-place version of :meth:`~Tensor.pow`
""")
+add_docstr_all('float_power',
+ r"""
+float_power(exponent) -> Tensor
+
+See :func:`torch.float_power`
+""")
+
+add_docstr_all('float_power_',
+ r"""
+float_power_(exponent) -> Tensor
+
+In-place version of :meth:`~Tensor.float_power`
+""")
+
add_docstr_all('prod',
r"""
prod(dim=None, keepdim=False, dtype=None) -> Tensor
diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py
index 545578c..3b6ee12 100644
--- a/torch/_torch_docs.py
+++ b/torch/_torch_docs.py
@@ -6367,6 +6367,46 @@
tensor([ 2., 4., 8., 16.])
""".format(**common_args))
+add_docstr(torch.float_power,
+ r"""
+float_power(input, exponent, *, out=None) -> Tensor
+
+Raises :attr:`input` to the power of :attr:`exponent`, elementwise, in double precision.
+If neither input is complex, returns a ``torch.float64`` tensor,
+and if one or more inputs is complex, returns a ``torch.complex128`` tensor.
+
+.. note::
+ This function always computes in double precision, unlike :func:`torch.pow`,
+ which implements more typical :ref:`type promotion <type-promotion-doc>`.
+ This is useful when the computation needs to be performed in a wider or more precise dtype,
+ or when the results of the computation may contain fractional values not representable in the input dtypes,
+ like when an integer base is raised to a negative integer exponent.
+
+Args:
+ input (Tensor or Number): the base value(s)
+ exponent (Tensor or Number): the exponent value(s)
+
+Keyword args:
+ {out}
+
+Example::
+
+ >>> a = torch.randint(10, (4,))
+ >>> a
+ tensor([6, 4, 7, 1])
+ >>> torch.float_power(a, 2)
+ tensor([36., 16., 49., 1.], dtype=torch.float64)
+
+ >>> a = torch.arange(1, 5)
+ >>> a
+ tensor([ 1, 2, 3, 4])
+ >>> exp = torch.tensor([2, -3, 4, -5])
+ >>> exp
+ tensor([ 2, -3, 4, -5])
+ >>> torch.float_power(a, exp)
+ tensor([1.0000e+00, 1.2500e-01, 8.1000e+01, 9.7656e-04], dtype=torch.float64)
+""".format(**common_args))
+
add_docstr(torch.prod,
r"""
prod(input, *, dtype=None) -> Tensor
diff --git a/torch/overrides.py b/torch/overrides.py
index 2c48e71..eb863c7 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -366,6 +366,7 @@
torch.frobenius_norm: lambda input, dim=None, keepdim=False, out=None: -1,
torch.floor: lambda input, out=None: -1,
torch.floor_divide: lambda input, other: -1,
+ torch.float_power: lambda input, exponent, out=None: -1,
torch.fmod: lambda input, other, out=None: -1,
torch.frac: lambda input, out=None: -1,
torch.full_like: lambda input, fill_value, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False: -1,
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 36e7fb6..8850c3f 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -694,6 +694,14 @@
('pow', uniform_scalar(1e-3 * (1 + 1j), requires_grad=True), (3.14,), 'complex_scalar_constant', (True,)),
('pow', uniform_scalar(1e-3 * (1 + 1j), requires_grad=True), (3.14j,), 'complex_imaginary_exponent', (True,)),
('__rpow__', uniform_scalar(1e-3, requires_grad=True), (3.14,), 'scalar_constant', (True, 'aten::pow')),
+ ('float_power', torch.rand(S, S, S) + 1e-3, (torch.rand(S, S, S) + 0.1,), ''),
+ ('float_power', torch.rand(S, S, S) + 1e-3, (torch.rand(1,) + 0.1,), 'broadcast_rhs'),
+ ('float_power', torch.rand(1,) + 1e-3, (torch.rand(S, S, S) + 0.1,), 'broadcast_lhs'),
+ ('float_power', torch.rand(S, 1, S) + 1e-3, (torch.rand(1, S, 1) + 0.1,), 'broadcast_all'),
+ ('float_power', uniform_scalar(1e-3, requires_grad=True), (uniform_scalar(0.1),), 'scalar'),
+ ('float_power', torch.rand(S, S, S) + 1e-3, (uniform_scalar(0.1),), 'scalar_broadcast_rhs'),
+ ('float_power', uniform_scalar(1e-3, requires_grad=True), (torch.rand(S, S, S) + 0.1,), 'scalar_broadcast_lhs'),
+ ('float_power', torch.rand(S, S, S) + 1e-3, (3.14,), 'constant'),
('transpose', (1, 2, 3), (1, 2), 'dim', (False,), [0, 1]),
('transpose', (), (0, 0), 'scalar', (False,)),
('transpose', (1,), (0, 0), '1d', (False,)),