[numpy] Add `torch.xlogy` (#48777)

Summary:
Adds `torch.xlogy` (function, `Tensor.xlogy` method, `Tensor.xlogy_` in-place, and `out=` variants), including CPU and CUDA kernels, autograd formulas, documentation, and tests.

Reference: https://github.com/pytorch/pytorch/issues/38349
Fixes https://github.com/pytorch/pytorch/issues/22656

TODO:
* [x] Add docs
* [x] Add tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/48777

Reviewed By: ngimel

Differential Revision: D25681346

Pulled By: mruberry

fbshipit-source-id: 369e0a29ac8a2c44de95eec115bf75943fe1aa45
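
A minimal usage sketch of the surface this patch adds (one call per overload family registered in native_functions.yaml below); shapes and values are illustrative only:

    import torch

    x = torch.tensor([1., 2., 3.])
    y = torch.tensor([3., 2., 1.])

    torch.xlogy(x, y)           # Tensor-Tensor: elementwise x * log(y)
    x.xlogy(y)                  # method variant
    torch.xlogy(2.0, y)         # Scalar-Tensor overload
    torch.xlogy(x, 4.0)         # Tensor-Scalar overload

    out = torch.empty(3)
    torch.xlogy(x, y, out=out)  # out= variant

    x.xlogy_(y)                 # in-place variant (floating-point self only)
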
diff --git a/aten/src/ATen/core/aten_interned_strings.h b/aten/src/ATen/core/aten_interned_strings.h
index 7b0759c..644d75c 100644
--- a/aten/src/ATen/core/aten_interned_strings.h
+++ b/aten/src/ATen/core/aten_interned_strings.h
@@ -436,6 +436,7 @@
 _(aten, logit) \
 _(aten, logspace) \
 _(aten, logsumexp) \
+_(aten, xlogy) \
 _(aten, lstm) \
 _(aten, lstm_cell) \
 _(aten, lstsq) \
diff --git a/aten/src/ATen/native/BinaryOps.cpp b/aten/src/ATen/native/BinaryOps.cpp
index e8751be..9103eaf 100644
--- a/aten/src/ATen/native/BinaryOps.cpp
+++ b/aten/src/ATen/native/BinaryOps.cpp
@@ -62,6 +62,7 @@
 DEFINE_DISPATCH(nextafter_stub);
 DEFINE_DISPATCH(heaviside_stub);
 DEFINE_DISPATCH(copysign_stub);
+DEFINE_DISPATCH(xlogy_stub);
 
 static Tensor wrapped_scalar_tensor(Scalar scalar) {
   auto tensor = scalar_to_tensor(scalar);
@@ -1101,5 +1102,42 @@
   return at::ldexp_out(self, self, other);
 }
 
+Tensor& xlogy_out(Tensor& result, const Tensor& self, const Tensor& other) {
+  auto iter = TensorIterator::binary_float_op(result, self, other);
+  xlogy_stub(iter.device_type(), iter);
+  return result;
+}
+
+Tensor& xlogy_out(Tensor& result, Scalar self, const Tensor& other) {
+  return at::xlogy_out(result, c10::scalar_to_tensor(self, other.device()), other);
+}
+
+Tensor& xlogy_out(Tensor& result, const Tensor& self, Scalar other) {
+  return at::xlogy_out(result, self, c10::scalar_to_tensor(other, self.device()));
+}
+
+Tensor xlogy(const Tensor& x, const Tensor& y) {
+  Tensor result;
+  auto iter = TensorIterator::binary_float_op(result, x, y);
+  xlogy_stub(iter.device_type(), iter);
+  return iter.output();
+}
+
+Tensor xlogy(Scalar x, const Tensor& y) {
+  return at::xlogy(c10::scalar_to_tensor(x, y.device()), y);
+}
+
+Tensor xlogy(const Tensor& x, Scalar y) {
+  return at::xlogy(x, c10::scalar_to_tensor(y, x.device()));
+}
+
+Tensor& xlogy_(Tensor& x, const Tensor& y) {
+  return at::xlogy_out(x, x, y);
+}
+
+Tensor& xlogy_(Tensor& x, Scalar y) {
+  return at::xlogy_out(x, x, c10::scalar_to_tensor(y, x.device()));
+}
+
 } // namespace native
 } // namespace at
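
Since the functional overloads above go through TensorIterator::binary_float_op, integral and bool inputs produce a floating-point result (the same property that promotes_integers_to_float=True asserts in the OpInfo entry at the end of this patch). A small sketch of the expected behaviour, assuming the default dtype is torch.float32:

    import torch

    a = torch.tensor([1, 2, 3])   # int64 inputs
    b = torch.tensor([3, 2, 1])
    r = torch.xlogy(a, b)
    print(r.dtype)                # torch.float32: integral inputs promote to float
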
diff --git a/aten/src/ATen/native/BinaryOps.h b/aten/src/ATen/native/BinaryOps.h
index 1fdb805..1916118 100644
--- a/aten/src/ATen/native/BinaryOps.h
+++ b/aten/src/ATen/native/BinaryOps.h
@@ -74,5 +74,6 @@
 DECLARE_DISPATCH(binary_fn, nextafter_stub);
 DECLARE_DISPATCH(binary_fn, heaviside_stub);
 DECLARE_DISPATCH(binary_fn, copysign_stub);
+DECLARE_DISPATCH(binary_fn, xlogy_stub);
 
 }} // namespace at::native
diff --git a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
index ddfa8a2..3dfe130 100644
--- a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
+++ b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
@@ -818,6 +818,20 @@
   });
 }
 
+void xlogy_kernel(TensorIterator& iter) {
+  AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, iter.common_dtype(), "xlogy_cpu", [&]() {
+    cpu_kernel(iter, [](scalar_t x, scalar_t y) -> scalar_t {
+      if (at::_isnan(y)){
+        return NAN;
+      }
+      if (x == 0){
+        return 0;
+      }
+      return x * std::log(y);
+    });
+  });
+}
+
 } // namespace
 
 REGISTER_DISPATCH(add_stub, &add_kernel);
@@ -859,6 +873,7 @@
 REGISTER_DISPATCH(nextafter_stub, &nextafter_kernel);
 REGISTER_DISPATCH(heaviside_stub, &heaviside_kernel);
 REGISTER_DISPATCH(copysign_stub, &copysign_kernel);
+REGISTER_DISPATCH(xlogy_stub, &xlogy_kernel);
 
 } // namespace native
 } // namespace at
diff --git a/aten/src/ATen/native/cuda/BinaryMiscOpsKernels.cu b/aten/src/ATen/native/cuda/BinaryMiscOpsKernels.cu
index c0efde1..2379877 100644
--- a/aten/src/ATen/native/cuda/BinaryMiscOpsKernels.cu
+++ b/aten/src/ATen/native/cuda/BinaryMiscOpsKernels.cu
@@ -3,6 +3,7 @@
 #include <ATen/native/cuda/Loops.cuh>
 #include <ATen/native/TensorIterator.h>
 #include <ATen/native/BinaryOps.h>
+#include <ATen/NumericUtils.h>
 
 // NOTE: CUDA on Windows requires that the enclosing function
 // of a __device__ lambda not have internal linkage.
@@ -29,8 +30,23 @@
   });
 }
 
+void xlogy_kernel_cuda(TensorIterator& iter) {
+  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.common_dtype(), "xlogy_cuda", [&]() {
+    gpu_kernel(iter, []GPU_LAMBDA(scalar_t x, scalar_t y) -> scalar_t {
+      if (at::_isnan(y)){
+        return NAN;
+      }
+      if (x == 0){
+        return 0;
+      }
+      return x * std::log(y);
+    });
+  });
+}
+
 REGISTER_DISPATCH(smooth_l1_stub, &smooth_l1_kernel_cuda);
 REGISTER_DISPATCH(mse_stub, &mse_kernel_cuda);
+REGISTER_DISPATCH(xlogy_stub, &xlogy_kernel_cuda);
 
 // DO NOT ADD ANY NEW KERNELS HERE
 // CUDA compilation times grow quickly.  It's perfectly acceptable to have a file per kernel.
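
The CPU and CUDA kernels above share one scalar rule, with the NaN check on y taking precedence over the x == 0 shortcut. A pure-Python sketch of that rule for a single element (for intuition only; it is not how the vectorized kernels are written):

    import math

    def xlogy_ref(x: float, y: float) -> float:
        # Mirrors the scalar body of xlogy_kernel / xlogy_kernel_cuda.
        if math.isnan(y):
            return math.nan
        if x == 0:
            return 0.0
        # (math.log raises for y < 0 where std::log returns NaN; not relevant to
        #  the special values below, where the x == 0 branch fires first.)
        return x * math.log(y)

    # Consistent with the docstring example later in the patch:
    # xlogy_ref(0.0, float('inf')) == 0.0, xlogy_ref(0.0, float('nan')) is nan
    # xlogy_ref(1.0, 3.0) == math.log(3.0)
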
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 78ad112..9c0053f 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -2560,6 +2560,56 @@
   dispatch:
     DefaultBackend: logaddexp2
 
+- func: xlogy.Tensor(Tensor self, Tensor other) -> Tensor
+  use_c10_dispatcher: full
+  variants: function, method
+  dispatch:
+    CPU, CUDA: xlogy
+
+- func: xlogy.Scalar_Self(Scalar self, Tensor other) -> Tensor
+  use_c10_dispatcher: full
+  variants: function
+  dispatch:
+    CPU, CUDA: xlogy
+
+- func: xlogy.Scalar_Other(Tensor self, Scalar other) -> Tensor
+  use_c10_dispatcher: full
+  variants: function, method
+  dispatch:
+    CPU, CUDA: xlogy
+
+# xlogy: inplace variant
+- func: xlogy_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+  use_c10_dispatcher: full
+  variants: function, method
+  dispatch:
+    CPU, CUDA: xlogy_
+
+- func: xlogy_.Scalar_Other(Tensor(a!) self, Scalar other) -> Tensor(a!)
+  use_c10_dispatcher: full
+  variants: function, method
+  dispatch:
+    CPU, CUDA: xlogy_
+
+# xlogy: out variant
+- func: xlogy.OutTensor(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+  use_c10_dispatcher: hacky_wrapper_for_legacy_signatures
+  variants: function
+  dispatch:
+    CPU, CUDA: xlogy_out
+
+- func: xlogy.OutScalar_Self(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+  use_c10_dispatcher: hacky_wrapper_for_legacy_signatures
+  variants: function
+  dispatch:
+    CPU, CUDA: xlogy_out
+
+- func: xlogy.OutScalar_Other(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+  use_c10_dispatcher: hacky_wrapper_for_legacy_signatures
+  variants: function
+  dispatch:
+    CPU, CUDA: xlogy_out
+
 - func: logdet(Tensor self) -> Tensor
   use_c10_dispatcher: full
   variants: function, method
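
Because the in-place and out= overloads above reuse the same float-producing kernel, an integral self cannot receive the result in place. A sketch of the expected behaviour (the error substring is the one checked by inplace_variant_helper in test_binary_ufuncs.py below):

    import torch

    y = torch.tensor([3., 2., 1.])

    torch.tensor([1., 2., 3.]).xlogy_(y)    # fine: floating-point self

    try:
        torch.tensor([1, 2, 3]).xlogy_(y)   # int64 self cannot hold a float result
    except RuntimeError as e:
        assert "can't be cast to the desired output type" in str(e)
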
diff --git a/docs/source/tensors.rst b/docs/source/tensors.rst
index f737537..315cc9d 100644
--- a/docs/source/tensors.rst
+++ b/docs/source/tensors.rst
@@ -645,6 +645,8 @@
    .. automethod:: view
    .. automethod:: view_as
    .. automethod:: where
+   .. automethod:: xlogy
+   .. automethod:: xlogy_
    .. automethod:: zero_
 
 .. class:: BoolTensor()
diff --git a/docs/source/torch.rst b/docs/source/torch.rst
index c82035e..3057339 100644
--- a/docs/source/torch.rst
+++ b/docs/source/torch.rst
@@ -350,6 +350,7 @@
     tanh
     true_divide
     trunc
+    xlogy
 
 Reduction Ops
 ~~~~~~~~~~~~~~~~~~~~~~
diff --git a/test/test_autograd.py b/test/test_autograd.py
index 3d29529..a8a1305 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -10,7 +10,7 @@
 import warnings
 from copy import deepcopy
 from collections import OrderedDict
-from itertools import product
+from itertools import product, permutations
 from operator import mul
 from functools import reduce
 import torch
@@ -7396,6 +7396,54 @@
         self._test_atleast(device, torch.atleast_2d)
         self._test_atleast(device, torch.atleast_3d)
 
+    def test_xlogy(self, device):
+
+        def _tensor_tensor_helper(x, y):
+            gradcheck(lambda x, y: torch.xlogy(x, y), (x, y))
+            gradgradcheck(lambda x, y: torch.xlogy(x, y), (x, y))
+
+            with torch.no_grad():
+                x = x.clone()
+                x[torch.rand_like(x) > 0.5] = 0
+
+            gradcheck(lambda y: torch.xlogy(x, y), (y))
+            gradgradcheck(lambda y: torch.xlogy(x, y), (y))
+
+        shapes = ((4,), (1, 4), (1, 1, 4), (1, 1, 1, 4))
+
+        # For broadcastable shapes and scalars.
+        for x_shape, y_shape in permutations(shapes, 2):
+            x = torch.rand(*x_shape, dtype=torch.double, device=device, requires_grad=True)
+            y = torch.rand(*y_shape, dtype=torch.double, device=device, requires_grad=True)
+
+            _tensor_tensor_helper(x, y)
+            _tensor_tensor_helper(y, x)
+
+            gradcheck(lambda y: torch.xlogy(0, y), (y))
+            gradgradcheck(lambda y: torch.xlogy(0, y), (y))
+
+            gradcheck(lambda y: torch.xlogy(2, y), (y))
+            gradgradcheck(lambda y: torch.xlogy(2, y), (y))
+            gradcheck(lambda y: torch.xlogy(y, 2), (y))
+            gradgradcheck(lambda y: torch.xlogy(y, 2), (y))
+
+        # Different shape
+        x = torch.rand(2, 3, 4, 5, dtype=torch.double, device=device, requires_grad=True)
+        y = torch.rand(4, 5, dtype=torch.double, device=device, requires_grad=True)
+        _tensor_tensor_helper(x, y)
+        _tensor_tensor_helper(y, x)
+        _tensor_tensor_helper(x, x)
+        _tensor_tensor_helper(y, y)
+
+        # Same shape
+        x = torch.rand(4, 5, dtype=torch.double, device=device, requires_grad=True)
+        y = torch.rand(4, 5, dtype=torch.double, device=device, requires_grad=True)
+        _tensor_tensor_helper(x, y)
+        _tensor_tensor_helper(y, x)
+        _tensor_tensor_helper(x, x)
+        _tensor_tensor_helper(y, y)
+
+
 class TestMultithreadAutograd(TestCase):
     def _run_py_multithread_fn(self, fn, args=(), num_threads=10, kwargs=None):
         threads = []
diff --git a/test/test_binary_ufuncs.py b/test/test_binary_ufuncs.py
index 9888c29..5739fb5 100644
--- a/test/test_binary_ufuncs.py
+++ b/test/test_binary_ufuncs.py
@@ -8,15 +8,19 @@
 import unittest
 import warnings
 import operator
+from functools import partial
 
 from torch._six import inf, nan
 from torch.testing._internal.common_utils import (
     TestCase, iter_indices, TEST_WITH_ASAN, run_tests,
-    torch_to_numpy_dtype_dict, make_tensor)
+    torch_to_numpy_dtype_dict, make_tensor, TEST_SCIPY)
 from torch.testing._internal.common_device_type import (
     instantiate_device_type_tests, onlyCUDA, onlyCPU, dtypes, dtypesIfCUDA,
     dtypesIfCPU, deviceCountAtLeast, precisionOverride, onlyOnCPUAndCUDA,
-    skipCUDAIfRocm)
+    skipCUDAIfRocm, skipIf)
+
+if TEST_SCIPY:
+    import scipy.special
 
 # TODO: remove this
 def _generate_input(shape, dtype, device, with_extremal):
@@ -2488,6 +2492,103 @@
                     with self.assertRaisesRegex(RuntimeError, "is not the desired type"):
                         torch.Tensor.float_power_(base.clone(), exp)
 
+    @skipIf(not TEST_SCIPY, "SciPy is required for this test.")
+    @dtypes(*product(torch.testing.get_all_dtypes(include_complex=False, include_bfloat16=False),
+                     torch.testing.get_all_dtypes(include_complex=False, include_bfloat16=False)))
+    def test_xlogy(self, device, dtypes):
+        def out_variant_helper(torch_fn, x, y):
+            expected = torch_fn(x, y)
+            out = torch.empty_like(expected)
+            torch_fn(x, y, out=out)
+            self.assertEqual(expected, out)
+
+        def inplace_variant_helper(x, y):
+            if x.dtype in torch.testing.get_all_int_dtypes() + [torch.bool]:
+                with self.assertRaisesRegex(RuntimeError,
+                                            "can't be cast to the desired output type"):
+                    x.clone().xlogy_(y)
+            else:
+                expected = torch.empty_like(x)
+                torch.xlogy(x, y, out=expected)
+                inplace_out = x.clone().xlogy_(y)
+                self.assertEqual(expected, inplace_out)
+
+        x_dtype, y_dtype = dtypes
+
+        # Tensor-Tensor Test (tensor of same and different shape)
+        x = make_tensor((3, 2, 4, 5), device, x_dtype, low=0.5, high=1000)
+        y = make_tensor((3, 2, 4, 5), device, y_dtype, low=0.5, high=1000)
+        z = make_tensor((4, 5), device, y_dtype, low=0.5, high=1000)
+
+        torch_fn = partial(torch.xlogy, x)
+        reference_fn = partial(scipy.special.xlogy, x.cpu().numpy())
+
+        self.compare_with_numpy(torch_fn, reference_fn, x, exact_dtype=False)
+        self.compare_with_numpy(torch_fn, reference_fn, y, exact_dtype=False)
+        self.compare_with_numpy(torch_fn, reference_fn, z, exact_dtype=False)
+        out_variant_helper(torch.xlogy, x, x)
+        out_variant_helper(torch.xlogy, x, y)
+        out_variant_helper(torch.xlogy, x, z)
+        inplace_variant_helper(x, x)
+        inplace_variant_helper(x, y)
+        inplace_variant_helper(x, z)
+
+        # Scalar-Tensor Test
+        torch_fn = partial(torch.xlogy, 3.14)
+        reference_fn = partial(scipy.special.xlogy, 3.14)
+
+        self.compare_with_numpy(torch_fn, reference_fn, x, exact_dtype=False)
+        self.compare_with_numpy(torch_fn, reference_fn, y, exact_dtype=False)
+        self.compare_with_numpy(torch_fn, reference_fn, z, exact_dtype=False)
+        out_variant_helper(torch.xlogy, 3.14, x)
+        out_variant_helper(torch.xlogy, 3.14, y)
+        out_variant_helper(torch.xlogy, 3.14, z)
+
+        # Special Values Tensor-Tensor
+        t = torch.tensor([0., 1., 2., float('inf'), -float('inf'), float('nan')], device=device)
+        zeros = torch.zeros(6, dtype=y_dtype, device=device)
+
+        torch_fn = partial(torch.xlogy, zeros)
+        reference_fn = partial(scipy.special.xlogy, zeros.cpu().numpy())
+        self.compare_with_numpy(torch_fn, reference_fn, t, exact_dtype=False)
+        out_variant_helper(torch.xlogy, zeros, t)
+        inplace_variant_helper(zeros, t)
+
+        # Special Values Scalar-Tensor
+        torch_fn = partial(torch.xlogy, 0)
+        reference_fn = partial(scipy.special.xlogy, 0)
+        self.compare_with_numpy(torch_fn, reference_fn, t, exact_dtype=False)
+        out_variant_helper(torch.xlogy, 0, t)
+
+    @skipIf(not TEST_SCIPY, "SciPy is required for this test.")
+    def test_xlogy_bfloat16(self, device):
+        def _compare_helper(x, y):
+            x_np = x if isinstance(x, float) else x.cpu().to(torch.float).numpy()
+            y_np = y if isinstance(y, float) else y.cpu().to(torch.float).numpy()
+            expected = torch.from_numpy(scipy.special.xlogy(x_np, y_np))
+            actual = torch.xlogy(x, y)
+            self.assertEqual(expected, actual, exact_dtype=False)
+
+        x_dtype, y_dtype = torch.bfloat16, torch.bfloat16
+
+        # Tensor-Tensor Test (tensor of same and different shape)
+        x = make_tensor((3, 2, 4, 5), device, x_dtype, low=0.5, high=1000)
+        y = make_tensor((3, 2, 4, 5), device, y_dtype, low=0.5, high=1000)
+        z = make_tensor((4, 5), device, y_dtype, low=0.5, high=1000)
+
+        _compare_helper(x, x)
+        _compare_helper(x, y)
+        _compare_helper(x, z)
+
+        _compare_helper(x, 3.14)
+        _compare_helper(y, 3.14)
+        _compare_helper(z, 3.14)
+
+        # Special Values Tensor-Tensor
+        t = torch.tensor([0., 1., 2., float('inf'), -float('inf'), float('nan')], device=device)
+        y_scalar = torch.tensor(5, dtype=y_dtype, device=device)
+        _compare_helper(t, y_scalar)
+        _compare_helper(t, 0.)
 
 tensor_binary_ops = [
     '__lt__', '__le__',
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index 7a619b9..9f68622 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -647,6 +647,16 @@
   self: grad / (1 + pow(2, other - self))
   other: grad / (1 + pow(2, self - other))
 
+- name: xlogy.Tensor(Tensor self, Tensor other) -> Tensor
+  self: grad * at::xlogy((self != 0), other)
+  other: grad * self / other
+
+- name: xlogy.Scalar_Self(Scalar self, Tensor other) -> Tensor
+  other: grad * self / other
+
+- name: xlogy.Scalar_Other(Tensor self, Scalar other) -> Tensor
+  self: grad * at::xlogy((self != 0), other)
+
 - name: logdet(Tensor self) -> Tensor
   self: logdet_backward(grad, self, result)
 
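
The entries above encode d(xlogy)/dself = log(other) wherever self != 0 (and 0 where self == 0, which is what grad * xlogy(self != 0, other) produces) and d(xlogy)/dother = self / other. A quick numerical sketch, assuming the formulas land as written:

    import torch

    x = torch.tensor([0.0, 2.0, 3.0], requires_grad=True)
    y = torch.tensor([5.0, 7.0, 11.0], requires_grad=True)

    torch.xlogy(x, y).sum().backward()

    expected_dx = torch.where(x.detach() != 0, y.detach().log(), torch.zeros(()))
    assert torch.allclose(x.grad, expected_dx)              # log(y), masked at x == 0
    assert torch.allclose(y.grad, x.detach() / y.detach())  # x / y
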
diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py
index f081b59..e944320 100644
--- a/torch/_tensor_docs.py
+++ b/torch/_tensor_docs.py
@@ -4472,6 +4472,20 @@
 Out-of-place version of :meth:`torch.Tensor.masked_scatter_`
 """)
 
+add_docstr_all('xlogy',
+               r"""
+xlogy(other) -> Tensor
+
+See :func:`torch.xlogy`
+""")
+
+add_docstr_all('xlogy_',
+               r"""
+xlogy_(other) -> Tensor
+
+In-place version of :meth:`~Tensor.xlogy`
+""")
+
 add_docstr_all('masked_fill',
                r"""
 masked_fill(mask, value) -> Tensor
diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py
index 91da41b..0294942 100644
--- a/torch/_torch_docs.py
+++ b/torch/_torch_docs.py
@@ -4371,6 +4371,48 @@
     {out}
 """.format(**common_args))
 
+add_docstr(torch.xlogy,
+           r"""
+xlogy(input, other, *, out=None) -> Tensor
+
+Computes ``input * log(other)`` with the following cases:
+
+.. math::
+    \text{out}_{i} = \begin{cases}
+        \text{NaN} & \text{if } \text{other}_{i} = \text{NaN} \\
+        0 & \text{if } \text{input}_{i} = 0.0 \\
+        \text{input}_{i} * \log{(\text{other}_{i})} & \text{otherwise}
+    \end{cases}
+
+Similar to SciPy's ``scipy.special.xlogy``.
+
+""" + r"""
+
+Args:
+    input (Number or Tensor): the multiplier
+    other (Number or Tensor): the argument of the logarithm
+
+.. note:: At least one of :attr:`input` or :attr:`other` must be a tensor.
+
+Keyword args:
+    {out}
+
+Example::
+
+    >>> x = torch.zeros(5,)
+    >>> y = torch.tensor([-1, 0, 1, float('inf'), float('nan')])
+    >>> torch.xlogy(x, y)
+    tensor([0., 0., 0., 0., nan])
+    >>> x = torch.tensor([1, 2, 3])
+    >>> y = torch.tensor([3, 2, 1])
+    >>> torch.xlogy(x, y)
+    tensor([1.0986, 1.3863, 0.0000])
+    >>> torch.xlogy(x, 4)
+    tensor([1.3863, 2.7726, 4.1589])
+    >>> torch.xlogy(2, y)
+    tensor([2.1972, 1.3863, 0.0000])
+""".format(**common_args))
+
 add_docstr(torch.logical_and,
            r"""
 logical_and(input, other, *, out=None) -> Tensor
diff --git a/torch/overrides.py b/torch/overrides.py
index c0e3463..d23e348 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -495,6 +495,7 @@
         torch.logaddexp: lambda input, other, out=None: -1,
         torch.logaddexp2: lambda input, other, out=None: -1,
         torch.logdet: lambda input: -1,
+        torch.xlogy: lambda x, y: -1,
         torch.logical_and: lambda input, other, out=None: -1,
         torch.logical_not: lambda input, out=None: -1,
         torch.logical_or: lambda input, other, out=None: -1,
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index ba29c42..55b97b3 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -291,12 +291,21 @@
     return (SampleInput((make_tensor((S, S), device, dtype,
                                      low=None, high=None,
                                      requires_grad=requires_grad),
-                        make_tensor((S, S), device, dtype,
-                                    low=None, high=None,
-                                    requires_grad=requires_grad),
-                        make_tensor((S, S), device, dtype,
-                                    low=None, high=None,
-                                    requires_grad=False))),)
+                         make_tensor((S, S), device, dtype,
+                                     low=None, high=None,
+                                     requires_grad=requires_grad),
+                         make_tensor((S, S), device, dtype,
+                                     low=None, high=None,
+                                     requires_grad=False))),)
+
+
+def sample_inputs_xlogy(self, device, dtype, requires_grad):
+    return (SampleInput((make_tensor((S, S), device, dtype,
+                                     low=None, high=None,
+                                     requires_grad=requires_grad),
+                         make_tensor((S, S), device, dtype,
+                                     low=0, high=None,
+                                     requires_grad=requires_grad))),)
 
 def np_sinc_with_fp16_as_fp32(x):
     # Wraps numpy's sinc function so that fp16 values are promoted to fp32
@@ -1084,6 +1093,14 @@
                                     dtypes=[torch.bfloat16]),),
                        assert_autodiffed=True,
                        promotes_integers_to_float=True),
+        OpInfo('xlogy',
+               dtypes=all_types_and(torch.bool),
+               dtypesIfCPU=all_types_and(torch.bool, torch.half, torch.bfloat16),
+               dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
+               test_inplace_grad=True,
+               supports_tensor_out=True,
+               promotes_integers_to_float=True,
+               sample_inputs_func=sample_inputs_xlogy),
     ]
     op_db = op_db + op_db_scipy_reference