[special] migrate xlogy (#60641)

Summary:
Reference: https://github.com/pytorch/pytorch/issues/50345

Pull Request resolved: https://github.com/pytorch/pytorch/pull/60641

Reviewed By: gchanan

Differential Revision: D29709306

Pulled By: mruberry

fbshipit-source-id: e8a5f64009a895a25618637de40b55cf36b8f794
diff --git a/aten/src/ATen/core/aten_interned_strings.h b/aten/src/ATen/core/aten_interned_strings.h
index 0c89f03..f30dede 100644
--- a/aten/src/ATen/core/aten_interned_strings.h
+++ b/aten/src/ATen/core/aten_interned_strings.h
@@ -445,8 +445,6 @@
 _(aten, logcumsumexp) \
 _(aten, logdet) \
 _(aten, logspace) \
-_(aten, xlogy) \
-_(aten, special_xlog1py) \
 _(aten, lstm) \
 _(aten, lstm_cell) \
 _(aten, lstsq) \
diff --git a/aten/src/ATen/core/interned_strings.h b/aten/src/ATen/core/interned_strings.h
index 7dad347..3a18b4c 100644
--- a/aten/src/ATen/core/interned_strings.h
+++ b/aten/src/ATen/core/interned_strings.h
@@ -354,6 +354,9 @@
   _(aten, special_i0e)               \
   _(aten, special_i1)                \
   _(aten, special_i1e)               \
+  _(aten, xlogy)                     \
+  _(aten, special_xlogy)             \
+  _(aten, special_xlog1py)           \
   _(aten, log_softmax)               \
   _(aten, special_log_softmax)       \
   _(aten, special_zeta)              \
diff --git a/aten/src/ATen/native/BinaryOps.cpp b/aten/src/ATen/native/BinaryOps.cpp
index 8a39ecb..bbceb5b 100644
--- a/aten/src/ATen/native/BinaryOps.cpp
+++ b/aten/src/ATen/native/BinaryOps.cpp
@@ -1204,5 +1204,29 @@
   return at::xlogy_out(x, x, wrapped_scalar_tensor(y));
 }
 
+Tensor& special_xlogy_out(const Tensor& self, const Tensor& other, Tensor& result) {
+  return at::xlogy_out(result, self, other);
+}
+
+Tensor& special_xlogy_out(const Scalar& self, const Tensor& other, Tensor& result) {
+  return at::xlogy_out(result, self, other);
+}
+
+Tensor& special_xlogy_out(const Tensor& self, const Scalar& other, Tensor& result) {
+  return at::xlogy_out(result, self, other);
+}
+
+Tensor special_xlogy(const Tensor& x, const Tensor& y) {
+  return at::xlogy(x, y);
+}
+
+Tensor special_xlogy(const Scalar& x, const Tensor& y) {
+  return at::xlogy(x, y);
+}
+
+Tensor special_xlogy(const Tensor& x, const Scalar& y) {
+  return at::xlogy(x, y);
+}
+
 } // namespace native
 } // namespace at
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 4fd94b9..78511135 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -9761,6 +9761,36 @@
   dispatch:
     CompositeExplicitAutograd: special_xlog1py_out
 
+- func: special_xlogy(Tensor self, Tensor other) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  python_module: special
+  variants: function
+
+- func: special_xlogy.self_scalar(Scalar self, Tensor other) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  python_module: special
+  variants: function
+
+- func: special_xlogy.other_scalar(Tensor self, Scalar other) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  python_module: special
+  variants: function
+
+- func: special_xlogy.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  python_module: special
+  variants: function
+
+- func: special_xlogy.self_scalar_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  python_module: special
+  variants: function
+
+- func: special_xlogy.other_scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  python_module: special
+  variants: function
+
 - func: special_zeta(Tensor self, Tensor other) -> Tensor
   device_check: NoCheck   # TensorIterator
   python_module: special
diff --git a/docs/source/special.rst b/docs/source/special.rst
index db77298..c53eae4 100644
--- a/docs/source/special.rst
+++ b/docs/source/special.rst
@@ -42,4 +42,5 @@
 .. autofunction:: round
 .. autofunction:: sinc
 .. autofunction:: xlog1py
+.. autofunction:: xlogy
 .. autofunction:: zeta
diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py
index 2955bba..02ee378 100644
--- a/torch/_torch_docs.py
+++ b/torch/_torch_docs.py
@@ -5106,43 +5106,8 @@
            r"""
 xlogy(input, other, *, out=None) -> Tensor
 
-Computes ``input * log(other)`` with the following cases.
-
-.. math::
-    \text{out}_{i} = \begin{cases}
-        \text{NaN} & \text{if } \text{other}_{i} = \text{NaN} \\
-        0 & \text{if } \text{input}_{i} = 0.0 \\
-        \text{input}_{i} * \log{(\text{other}_{i})} & \text{otherwise}
-    \end{cases}
-
-Similar to SciPy's `scipy.special.xlogy`.
-
-""" + r"""
-
-Args:
-    input (Number or Tensor) : Multiplier
-    other (Number or Tensor) : Argument
-
-.. note:: At least one of :attr:`input` or :attr:`other` must be a tensor.
-
-Keyword args:
-    {out}
-
-Example::
-
-    >>> x = torch.zeros(5,)
-    >>> y = torch.tensor([-1, 0, 1, float('inf'), float('nan')])
-    >>> torch.xlogy(x, y)
-    tensor([0., 0., 0., 0., nan])
-    >>> x = torch.tensor([1, 2, 3])
-    >>> y = torch.tensor([3, 2, 1])
-    >>> torch.xlogy(x, y)
-    tensor([1.0986, 1.3863, 0.0000])
-    >>> torch.xlogy(x, 4)
-    tensor([1.3863, 2.7726, 4.1589])
-    >>> torch.xlogy(2, y)
-    tensor([2.1972, 1.3863, 0.0000])
-""".format(**common_args))
+Alias for :func:`torch.special.xlogy`.
+""")
 
 add_docstr(torch.logical_and,
            r"""
diff --git a/torch/csrc/api/include/torch/special.h b/torch/csrc/api/include/torch/special.h
index 55153a5..387a031 100644
--- a/torch/csrc/api/include/torch/special.h
+++ b/torch/csrc/api/include/torch/special.h
@@ -221,6 +221,39 @@
   return torch::special_expm1_out(result, self);
 }
 
+/// Computes x * log(y) for inputs, elementwise
+/// See https://pytorch.org/docs/master/special.html#torch.special.xlogy.
+///
+/// Example:
+/// ```
+/// auto x = torch::randn(128, dtype=kDouble);
+/// auto y = torch::randn(128, dtype=kDouble);
+/// torch::special::xlogy(x, y);
+/// ```
+inline Tensor xlogy(const Tensor& self, const Tensor& other) {
+  return torch::special_xlogy(self, other);
+}
+
+inline Tensor xlogy(const Scalar& self, const Tensor& other) {
+  return torch::special_xlogy(self, other);
+}
+
+inline Tensor xlogy(const Tensor& self, const Scalar& other) {
+  return torch::special_xlogy(self, other);
+}
+
+inline Tensor& xlogy_out(Tensor& result, const Tensor& self, const Tensor& other) {
+  return torch::special_xlogy_out(result, self, other);
+}
+
+inline Tensor& xlogy_out(Tensor& result, const Scalar& self, const Tensor& other) {
+  return torch::special_xlogy_out(result, self, other);
+}
+
+inline Tensor& xlogy_out(Tensor& result, const Tensor& self, const Scalar& other) {
+  return torch::special_xlogy_out(result, self, other);
+}
+
 /// Computes x * log1p(y) for inputs, elementwise
 /// See https://pytorch.org/docs/master/special.html#torch.special.xlog1py.
 ///
diff --git a/torch/csrc/jit/passes/normalize_ops.cpp b/torch/csrc/jit/passes/normalize_ops.cpp
index e9c462f..e976cd6 100644
--- a/torch/csrc/jit/passes/normalize_ops.cpp
+++ b/torch/csrc/jit/passes/normalize_ops.cpp
@@ -124,6 +124,7 @@
       {aten::special_digamma, aten::digamma},
       {aten::special_psi, aten::digamma},
       {aten::special_i0, aten::i0},
+      {aten::special_xlogy, aten::xlogy},
       {aten::special_log_softmax, aten::log_softmax},
       {aten::orgqr, aten::linalg_householder_product},
       {aten::special_gammaln, aten::lgamma}};
diff --git a/torch/overrides.py b/torch/overrides.py
index aaf94c4..986050f 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -572,7 +572,7 @@
         torch.logaddexp: lambda input, other, out=None: -1,
         torch.logaddexp2: lambda input, other, out=None: -1,
         torch.logdet: lambda input: -1,
-        torch.xlogy: lambda x, y: -1,
+        torch.xlogy: lambda x, y, out=None: -1,
         torch.logical_and: lambda input, other, out=None: -1,
         torch.logical_not: lambda input, out=None: -1,
         torch.logical_or: lambda input, other, out=None: -1,
@@ -925,6 +925,7 @@
         torch.special.sinc: lambda input: -1,
         torch.special.ndtri: lambda input: -1,
         torch.special.ndtr: lambda input: -1,
+        torch.special.xlogy: lambda input, other, out=None: -1,
         torch.special.xlog1py: lambda input, other, out=None: -1,
         torch.special.zeta: lambda self, other, out=None: -1,
         torch.t: lambda input: -1,
diff --git a/torch/special/__init__.py b/torch/special/__init__.py
index 99993a3..1ba0c50 100644
--- a/torch/special/__init__.py
+++ b/torch/special/__init__.py
@@ -338,6 +338,48 @@
     tensor([2.7726, 2.1972, 1.3863])
 """.format(**common_args))
 
+xlogy = _add_docstr(_special.special_xlogy,
+                    r"""
+xlogy(input, other, *, out=None) -> Tensor
+
+Computes ``input * log(other)`` with the following cases.
+
+.. math::
+    \text{out}_{i} = \begin{cases}
+        \text{NaN} & \text{if } \text{other}_{i} = \text{NaN} \\
+        0 & \text{if } \text{input}_{i} = 0.0 \\
+        \text{input}_{i} * \log{(\text{other}_{i})} & \text{otherwise}
+    \end{cases}
+
+Similar to SciPy's `scipy.special.xlogy`.
+
+""" + r"""
+
+Args:
+    input (Number or Tensor) : Multiplier
+    other (Number or Tensor) : Argument
+
+.. note:: At least one of :attr:`input` or :attr:`other` must be a tensor.
+
+Keyword args:
+    {out}
+
+Example::
+
+    >>> x = torch.zeros(5,)
+    >>> y = torch.tensor([-1, 0, 1, float('inf'), float('nan')])
+    >>> torch.special.xlogy(x, y)
+    tensor([0., 0., 0., 0., nan])
+    >>> x = torch.tensor([1, 2, 3])
+    >>> y = torch.tensor([3, 2, 1])
+    >>> torch.special.xlogy(x, y)
+    tensor([1.0986, 1.3863, 0.0000])
+    >>> torch.special.xlogy(x, 4)
+    tensor([1.3863, 2.7726, 4.1589])
+    >>> torch.special.xlogy(2, y)
+    tensor([2.1972, 1.3863, 0.0000])
+""".format(**common_args))
+
 i0 = _add_docstr(_special.special_i0,
                  r"""
 i0(input, *, out=None) -> Tensor
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 887a367..7ef8c95 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -7428,10 +7428,8 @@
            assert_autodiffed=True,
            ),
     OpInfo('xlogy',
-           dtypes=all_types_and(torch.bool),
-           dtypesIfCPU=all_types_and(torch.bool, torch.half, torch.bfloat16),
-           dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
-           supports_inplace_autograd=True,
+           aliases=('special.xlogy',),
+           dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
            supports_forward_ad=True,
            safe_casts_outputs=True,
            sample_inputs_func=sample_inputs_xlogy),