[special] migrate xlogy (#60641)
Summary:
Reference: https://github.com/pytorch/pytorch/issues/50345
Pull Request resolved: https://github.com/pytorch/pytorch/pull/60641
Reviewed By: gchanan
Differential Revision: D29709306
Pulled By: mruberry
fbshipit-source-id: e8a5f64009a895a25618637de40b55cf36b8f794
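This moves `xlogy` into the `torch.special` namespace: `torch.special.xlogy` becomes the documented entry point, and `torch.xlogy` remains as an alias for it. A minimal sketch of the user-visible result (the values match the doctest added to torch/special/__init__.py below):

```python
import torch

x = torch.tensor([1., 2., 3.])
y = torch.tensor([3., 2., 1.])

# Both entry points dispatch to the same native kernel, so the
# results are identical.
assert torch.equal(torch.special.xlogy(x, y), torch.xlogy(x, y))
print(torch.special.xlogy(x, y))  # tensor([1.0986, 1.3863, 0.0000])
```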
diff --git a/aten/src/ATen/core/aten_interned_strings.h b/aten/src/ATen/core/aten_interned_strings.h
index 0c89f03..f30dede 100644
--- a/aten/src/ATen/core/aten_interned_strings.h
+++ b/aten/src/ATen/core/aten_interned_strings.h
@@ -445,8 +445,6 @@
_(aten, logcumsumexp) \
_(aten, logdet) \
_(aten, logspace) \
-_(aten, xlogy) \
-_(aten, special_xlog1py) \
_(aten, lstm) \
_(aten, lstm_cell) \
_(aten, lstsq) \
diff --git a/aten/src/ATen/core/interned_strings.h b/aten/src/ATen/core/interned_strings.h
index 7dad347..3a18b4c 100644
--- a/aten/src/ATen/core/interned_strings.h
+++ b/aten/src/ATen/core/interned_strings.h
@@ -354,6 +354,9 @@
_(aten, special_i0e) \
_(aten, special_i1) \
_(aten, special_i1e) \
+ _(aten, xlogy) \
+ _(aten, special_xlogy) \
+ _(aten, special_xlog1py) \
_(aten, log_softmax) \
_(aten, special_log_softmax) \
_(aten, special_zeta) \
diff --git a/aten/src/ATen/native/BinaryOps.cpp b/aten/src/ATen/native/BinaryOps.cpp
index 8a39ecb..bbceb5b 100644
--- a/aten/src/ATen/native/BinaryOps.cpp
+++ b/aten/src/ATen/native/BinaryOps.cpp
@@ -1204,5 +1204,29 @@
return at::xlogy_out(x, x, wrapped_scalar_tensor(y));
}
+Tensor& special_xlogy_out(const Tensor& self, const Tensor& other, Tensor& result) {
+ return at::xlogy_out(result, self, other);
+}
+
+Tensor& special_xlogy_out(const Scalar& self, const Tensor& other, Tensor& result) {
+ return at::xlogy_out(result, self, other);
+}
+
+Tensor& special_xlogy_out(const Tensor& self, const Scalar& other, Tensor& result) {
+ return at::xlogy_out(result, self, other);
+}
+
+Tensor special_xlogy(const Tensor& x, const Tensor& y) {
+ return at::xlogy(x, y);
+}
+
+Tensor special_xlogy(const Scalar& x, const Tensor& y) {
+ return at::xlogy(x, y);
+}
+
+Tensor special_xlogy(const Tensor& x, const Scalar& y) {
+ return at::xlogy(x, y);
+}
+
} // namespace native
} // namespace at
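The `special_xlogy` wrappers above are thin forwards to `at::xlogy` / `at::xlogy_out`, covering the tensor-tensor, scalar-tensor, and tensor-scalar overloads. A small sketch of the corresponding `out=` usage from Python; the preallocated `result` tensor is illustrative:

```python
import torch

x = torch.tensor([1., 2., 3.])
y = torch.tensor([3., 2., 1.])
result = torch.empty(3)

# The out variant writes into the provided tensor and returns it,
# forwarding to at::xlogy_out under the hood.
ret = torch.special.xlogy(x, y, out=result)
assert ret is result
```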
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 4fd94b9..78511135 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -9761,6 +9761,36 @@
dispatch:
CompositeExplicitAutograd: special_xlog1py_out
+- func: special_xlogy(Tensor self, Tensor other) -> Tensor
+ device_check: NoCheck # TensorIterator
+ python_module: special
+ variants: function
+
+- func: special_xlogy.self_scalar(Scalar self, Tensor other) -> Tensor
+ device_check: NoCheck # TensorIterator
+ python_module: special
+ variants: function
+
+- func: special_xlogy.other_scalar(Tensor self, Scalar other) -> Tensor
+ device_check: NoCheck # TensorIterator
+ python_module: special
+ variants: function
+
+- func: special_xlogy.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+ device_check: NoCheck # TensorIterator
+ python_module: special
+ variants: function
+
+- func: special_xlogy.self_scalar_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+ device_check: NoCheck # TensorIterator
+ python_module: special
+ variants: function
+
+- func: special_xlogy.other_scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+ device_check: NoCheck # TensorIterator
+ python_module: special
+ variants: function
+
- func: special_zeta(Tensor self, Tensor other) -> Tensor
device_check: NoCheck # TensorIterator
python_module: special
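Each schema entry is marked `device_check: NoCheck` because the op lowers to a TensorIterator binary kernel, which also supplies the usual broadcasting semantics. A quick illustration:

```python
import torch

# (3, 1) against (1, 4) broadcasts to (3, 4), as for any binary
# TensorIterator op.
a = torch.rand(3, 1)
b = torch.rand(1, 4)
print(torch.special.xlogy(a, b).shape)  # torch.Size([3, 4])
```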
diff --git a/docs/source/special.rst b/docs/source/special.rst
index db77298..c53eae4 100644
--- a/docs/source/special.rst
+++ b/docs/source/special.rst
@@ -42,4 +42,5 @@
.. autofunction:: round
.. autofunction:: sinc
.. autofunction:: xlog1py
+.. autofunction:: xlogy
.. autofunction:: zeta
diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py
index 2955bba..02ee378 100644
--- a/torch/_torch_docs.py
+++ b/torch/_torch_docs.py
@@ -5106,43 +5106,8 @@
r"""
xlogy(input, other, *, out=None) -> Tensor
-Computes ``input * log(other)`` with the following cases.
-
-.. math::
- \text{out}_{i} = \begin{cases}
- \text{NaN} & \text{if } \text{other}_{i} = \text{NaN} \\
- 0 & \text{if } \text{input}_{i} = 0.0 \\
- \text{input}_{i} * \log{(\text{other}_{i})} & \text{otherwise}
- \end{cases}
-
-Similar to SciPy's `scipy.special.xlogy`.
-
-""" + r"""
-
-Args:
- input (Number or Tensor) : Multiplier
- other (Number or Tensor) : Argument
-
-.. note:: At least one of :attr:`input` or :attr:`other` must be a tensor.
-
-Keyword args:
- {out}
-
-Example::
-
- >>> x = torch.zeros(5,)
- >>> y = torch.tensor([-1, 0, 1, float('inf'), float('nan')])
- >>> torch.xlogy(x, y)
- tensor([0., 0., 0., 0., nan])
- >>> x = torch.tensor([1, 2, 3])
- >>> y = torch.tensor([3, 2, 1])
- >>> torch.xlogy(x, y)
- tensor([1.0986, 1.3863, 0.0000])
- >>> torch.xlogy(x, 4)
- tensor([1.3863, 2.7726, 4.1589])
- >>> torch.xlogy(2, y)
- tensor([2.1972, 1.3863, 0.0000])
-""".format(**common_args))
+Alias for :func:`torch.special.xlogy`.
+""")
add_docstr(torch.logical_and,
r"""
diff --git a/torch/csrc/api/include/torch/special.h b/torch/csrc/api/include/torch/special.h
index 55153a5..387a031 100644
--- a/torch/csrc/api/include/torch/special.h
+++ b/torch/csrc/api/include/torch/special.h
@@ -221,6 +221,39 @@
return torch::special_expm1_out(result, self);
}
+/// Computes x * log(y) for inputs, elementwise
+/// See https://pytorch.org/docs/master/special.html#torch.special.xlogy.
+///
+/// Example:
+/// ```
+///   auto x = torch::randn(128, torch::kDouble);
+///   auto y = torch::randn(128, torch::kDouble);
+/// torch::special::xlogy(x, y);
+/// ```
+inline Tensor xlogy(const Tensor& self, const Tensor& other) {
+ return torch::special_xlogy(self, other);
+}
+
+inline Tensor xlogy(const Scalar& self, const Tensor& other) {
+ return torch::special_xlogy(self, other);
+}
+
+inline Tensor xlogy(const Tensor& self, const Scalar& other) {
+ return torch::special_xlogy(self, other);
+}
+
+inline Tensor& xlogy_out(Tensor& result, const Tensor& self, const Tensor& other) {
+ return torch::special_xlogy_out(result, self, other);
+}
+
+inline Tensor& xlogy_out(Tensor& result, const Scalar& self, const Tensor& other) {
+ return torch::special_xlogy_out(result, self, other);
+}
+
+inline Tensor& xlogy_out(Tensor& result, const Tensor& self, const Scalar& other) {
+ return torch::special_xlogy_out(result, self, other);
+}
+
/// Computes x * log1p(y) for inputs, elementwise
/// See https://pytorch.org/docs/master/special.html#torch.special.xlog1py.
///
diff --git a/torch/csrc/jit/passes/normalize_ops.cpp b/torch/csrc/jit/passes/normalize_ops.cpp
index e9c462f..e976cd6 100644
--- a/torch/csrc/jit/passes/normalize_ops.cpp
+++ b/torch/csrc/jit/passes/normalize_ops.cpp
@@ -124,6 +124,7 @@
{aten::special_digamma, aten::digamma},
{aten::special_psi, aten::digamma},
{aten::special_i0, aten::i0},
+ {aten::special_xlogy, aten::xlogy},
{aten::special_log_softmax, aten::log_softmax},
{aten::orgqr, aten::linalg_householder_product},
{aten::special_gammaln, aten::lgamma}};
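With `aten::special_xlogy` in the alias map, TorchScript graphs that call the namespaced op are normalized to `aten::xlogy`, matching the other `special_*` entries. A sketch of how this can be observed, assuming the normalization pass runs during scripting as it does for the existing aliases:

```python
import torch

@torch.jit.script
def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return torch.special.xlogy(x, y)

# After op normalization, the graph is expected to show aten::xlogy
# rather than aten::special_xlogy.
print(f.graph)
```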
diff --git a/torch/overrides.py b/torch/overrides.py
index aaf94c4..986050f 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -572,7 +572,7 @@
torch.logaddexp: lambda input, other, out=None: -1,
torch.logaddexp2: lambda input, other, out=None: -1,
torch.logdet: lambda input: -1,
- torch.xlogy: lambda x, y: -1,
+ torch.xlogy: lambda x, y, out=None: -1,
torch.logical_and: lambda input, other, out=None: -1,
torch.logical_not: lambda input, out=None: -1,
torch.logical_or: lambda input, other, out=None: -1,
@@ -925,6 +925,7 @@
torch.special.sinc: lambda input: -1,
torch.special.ndtri: lambda input: -1,
torch.special.ndtr: lambda input: -1,
+ torch.special.xlogy: lambda input, other, out=None: -1,
torch.special.xlog1py: lambda input, other, out=None: -1,
torch.special.zeta: lambda self, other, out=None: -1,
torch.t: lambda input: -1,
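The overrides table now carries the full `(input, other, out=None)` signature for both entry points, so they participate in `__torch_function__` dispatch. A minimal check using the public `torch.overrides` testing helper:

```python
import torch
from torch.overrides import get_testing_overrides

# get_testing_overrides() maps each overridable callable to a lambda
# with a matching signature; the new entry should be present.
assert torch.special.xlogy in get_testing_overrides()
```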
diff --git a/torch/special/__init__.py b/torch/special/__init__.py
index 99993a3..1ba0c50 100644
--- a/torch/special/__init__.py
+++ b/torch/special/__init__.py
@@ -338,6 +338,48 @@
tensor([2.7726, 2.1972, 1.3863])
""".format(**common_args))
+xlogy = _add_docstr(_special.special_xlogy,
+ r"""
+xlogy(input, other, *, out=None) -> Tensor
+
+Computes ``input * log(other)`` with the following special cases:
+
+.. math::
+ \text{out}_{i} = \begin{cases}
+ \text{NaN} & \text{if } \text{other}_{i} = \text{NaN} \\
+ 0 & \text{if } \text{input}_{i} = 0.0 \\
+ \text{input}_{i} * \log{(\text{other}_{i})} & \text{otherwise}
+ \end{cases}
+
+Similar to SciPy's `scipy.special.xlogy`.
+
+""" + r"""
+
+Args:
+    input (Number or Tensor): Multiplier
+    other (Number or Tensor): Argument
+
+.. note:: At least one of :attr:`input` or :attr:`other` must be a tensor.
+
+Keyword args:
+ {out}
+
+Example::
+
+ >>> x = torch.zeros(5,)
+ >>> y = torch.tensor([-1, 0, 1, float('inf'), float('nan')])
+ >>> torch.special.xlogy(x, y)
+ tensor([0., 0., 0., 0., nan])
+ >>> x = torch.tensor([1, 2, 3])
+ >>> y = torch.tensor([3, 2, 1])
+ >>> torch.special.xlogy(x, y)
+ tensor([1.0986, 1.3863, 0.0000])
+ >>> torch.special.xlogy(x, 4)
+ tensor([1.3863, 2.7726, 4.1589])
+ >>> torch.special.xlogy(2, y)
+ tensor([2.1972, 1.3863, 0.0000])
+""".format(**common_args))
+
i0 = _add_docstr(_special.special_i0,
r"""
i0(input, *, out=None) -> Tensor
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 887a367..7ef8c95 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -7428,10 +7428,8 @@
assert_autodiffed=True,
),
OpInfo('xlogy',
- dtypes=all_types_and(torch.bool),
- dtypesIfCPU=all_types_and(torch.bool, torch.half, torch.bfloat16),
- dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
- supports_inplace_autograd=True,
+ aliases=('special.xlogy',),
+ dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
supports_forward_ad=True,
safe_casts_outputs=True,
sample_inputs_func=sample_inputs_xlogy),
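Registering `special.xlogy` as an OpInfo alias makes the common test suite exercise both entry points over the consolidated dtype list (now including half and bfloat16 on all backends). A hand-rolled sketch of the consistency the alias tests assert:

```python
import torch

x = torch.tensor([0., 1., 2.], dtype=torch.bfloat16)
y = torch.tensor([1., 2., 3.], dtype=torch.bfloat16)

# The alias must agree with the original op on the same inputs.
torch.testing.assert_close(torch.special.xlogy(x, y), torch.xlogy(x, y))
```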