[numpy] Add `torch.xlogy` (#48777)
Summary:
Adds `torch.xlogy` (with method, in-place, and `out=` variants), computing `input * log(other)` with SciPy-compatible special cases: the result is 0 where `input == 0` and NaN where `other` is NaN.

References https://github.com/pytorch/pytorch/issues/38349
Fixes https://github.com/pytorch/pytorch/issues/22656
TODO:
* [x] Add docs
* [x] Add tests
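
A minimal usage sketch of the new operator (it mirrors the docstring examples added to `torch/_torch_docs.py` below and is not part of the patch itself):

```python
# Usage sketch mirroring the docstring examples added in this PR.
import torch

x = torch.zeros(5)
y = torch.tensor([-1., 0., 1., float('inf'), float('nan')])
print(torch.xlogy(x, y))   # tensor([0., 0., 0., 0., nan])  -- 0 * log(y) is 0, NaN propagates

a = torch.tensor([1., 2., 3.])
b = torch.tensor([3., 2., 1.])
print(torch.xlogy(a, b))   # tensor([1.0986, 1.3863, 0.0000])
print(torch.xlogy(a, 4))   # tensor([1.3863, 2.7726, 4.1589])  -- Scalar `other` overload
print(torch.xlogy(2, b))   # tensor([2.1972, 1.3863, 0.0000])  -- Scalar `self` overload

# out= and in-place variants are also registered:
out = torch.empty_like(a)
torch.xlogy(a, b, out=out)
a.xlogy_(b)
```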
Pull Request resolved: https://github.com/pytorch/pytorch/pull/48777
Reviewed By: ngimel
Differential Revision: D25681346
Pulled By: mruberry
fbshipit-source-id: 369e0a29ac8a2c44de95eec115bf75943fe1aa45
diff --git a/aten/src/ATen/core/aten_interned_strings.h b/aten/src/ATen/core/aten_interned_strings.h
index 7b0759c..644d75c 100644
--- a/aten/src/ATen/core/aten_interned_strings.h
+++ b/aten/src/ATen/core/aten_interned_strings.h
@@ -436,6 +436,7 @@
_(aten, logit) \
_(aten, logspace) \
_(aten, logsumexp) \
+_(aten, xlogy) \
_(aten, lstm) \
_(aten, lstm_cell) \
_(aten, lstsq) \
diff --git a/aten/src/ATen/native/BinaryOps.cpp b/aten/src/ATen/native/BinaryOps.cpp
index e8751be..9103eaf 100644
--- a/aten/src/ATen/native/BinaryOps.cpp
+++ b/aten/src/ATen/native/BinaryOps.cpp
@@ -62,6 +62,7 @@
DEFINE_DISPATCH(nextafter_stub);
DEFINE_DISPATCH(heaviside_stub);
DEFINE_DISPATCH(copysign_stub);
+DEFINE_DISPATCH(xlogy_stub);
static Tensor wrapped_scalar_tensor(Scalar scalar) {
auto tensor = scalar_to_tensor(scalar);
@@ -1101,5 +1102,42 @@
return at::ldexp_out(self, self, other);
}
+Tensor& xlogy_out(Tensor& result, const Tensor& self, const Tensor& other) {
+ auto iter = TensorIterator::binary_float_op(result, self, other);
+ xlogy_stub(iter.device_type(), iter);
+ return result;
+}
+
+Tensor& xlogy_out(Tensor& result, Scalar self, const Tensor& other) {
+ return at::xlogy_out(result, c10::scalar_to_tensor(self, other.device()), other);
+}
+
+Tensor& xlogy_out(Tensor& result, const Tensor& self, Scalar other) {
+ return at::xlogy_out(result, self, c10::scalar_to_tensor(other, self.device()));
+}
+
+Tensor xlogy(const Tensor& x, const Tensor& y) {
+ Tensor result;
+ auto iter = TensorIterator::binary_float_op(result, x, y);
+ xlogy_stub(iter.device_type(), iter);
+ return iter.output();
+}
+
+Tensor xlogy(Scalar x, const Tensor& y) {
+ return at::xlogy(c10::scalar_to_tensor(x, y.device()), y);
+}
+
+Tensor xlogy(const Tensor& x, Scalar y) {
+ return at::xlogy(x, c10::scalar_to_tensor(y, x.device()));
+}
+
+Tensor& xlogy_(Tensor& x, const Tensor& y) {
+ return at::xlogy_out(x, x, y);
+}
+
+Tensor& xlogy_(Tensor& x, Scalar y) {
+ return at::xlogy_out(x, x, c10::scalar_to_tensor(y, x.device()));
+}
+
} // namespace native
} // namespace at
diff --git a/aten/src/ATen/native/BinaryOps.h b/aten/src/ATen/native/BinaryOps.h
index 1fdb805..1916118 100644
--- a/aten/src/ATen/native/BinaryOps.h
+++ b/aten/src/ATen/native/BinaryOps.h
@@ -74,5 +74,6 @@
DECLARE_DISPATCH(binary_fn, nextafter_stub);
DECLARE_DISPATCH(binary_fn, heaviside_stub);
DECLARE_DISPATCH(binary_fn, copysign_stub);
+DECLARE_DISPATCH(binary_fn, xlogy_stub);
}} // namespace at::native
diff --git a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
index ddfa8a2..3dfe130 100644
--- a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
+++ b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
@@ -818,6 +818,20 @@
});
}
+void xlogy_kernel(TensorIterator& iter) {
+ AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, iter.common_dtype(), "xlogy_cpu", [&]() {
+ cpu_kernel(iter, [](scalar_t x, scalar_t y) -> scalar_t {
+ if (at::_isnan(y)){
+ return NAN;
+ }
+ if (x == 0){
+ return 0;
+ }
+ return x * std::log(y);
+ });
+ });
+}
+
} // namespace
REGISTER_DISPATCH(add_stub, &add_kernel);
@@ -859,6 +873,7 @@
REGISTER_DISPATCH(nextafter_stub, &nextafter_kernel);
REGISTER_DISPATCH(heaviside_stub, &heaviside_kernel);
REGISTER_DISPATCH(copysign_stub, &copysign_kernel);
+REGISTER_DISPATCH(xlogy_stub, &xlogy_kernel);
} // namespace native
} // namespace at
diff --git a/aten/src/ATen/native/cuda/BinaryMiscOpsKernels.cu b/aten/src/ATen/native/cuda/BinaryMiscOpsKernels.cu
index c0efde1..2379877 100644
--- a/aten/src/ATen/native/cuda/BinaryMiscOpsKernels.cu
+++ b/aten/src/ATen/native/cuda/BinaryMiscOpsKernels.cu
@@ -3,6 +3,7 @@
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/BinaryOps.h>
+#include <ATen/NumericUtils.h>
// NOTE: CUDA on Windows requires that the enclosing function
// of a __device__ lambda not have internal linkage.
@@ -29,8 +30,23 @@
});
}
+void xlogy_kernel_cuda(TensorIterator& iter) {
+ AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.common_dtype(), "xlogy_cuda", [&]() {
+ gpu_kernel(iter, []GPU_LAMBDA(scalar_t x, scalar_t y) -> scalar_t {
+ if (at::_isnan(y)){
+ return NAN;
+ }
+ if (x == 0){
+ return 0;
+ }
+ return x * std::log(y);
+ });
+ });
+}
+
REGISTER_DISPATCH(smooth_l1_stub, &smooth_l1_kernel_cuda);
REGISTER_DISPATCH(mse_stub, &mse_kernel_cuda);
+REGISTER_DISPATCH(xlogy_stub, &xlogy_kernel_cuda);
// DO NOT ADD ANY NEW KERNELS HERE
// CUDA compilation times grow quickly. It's perfectly acceptable to have a file per kernel.
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 78ad112..9c0053f 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -2560,6 +2560,56 @@
dispatch:
DefaultBackend: logaddexp2
+- func: xlogy.Tensor(Tensor self, Tensor other) -> Tensor
+ use_c10_dispatcher: full
+ variants: function, method
+ dispatch:
+ CPU, CUDA: xlogy
+
+- func: xlogy.Scalar_Self(Scalar self, Tensor other) -> Tensor
+ use_c10_dispatcher: full
+ variants: function
+ dispatch:
+ CPU, CUDA: xlogy
+
+- func: xlogy.Scalar_Other(Tensor self, Scalar other) -> Tensor
+ use_c10_dispatcher: full
+ variants: function, method
+ dispatch:
+ CPU, CUDA: xlogy
+
+# xlogy: inplace variant
+- func: xlogy_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+ use_c10_dispatcher: full
+ variants: function, method
+ dispatch:
+ CPU, CUDA: xlogy_
+
+- func: xlogy_.Scalar_Other(Tensor(a!) self, Scalar other) -> Tensor(a!)
+ use_c10_dispatcher: full
+ variants: function, method
+ dispatch:
+ CPU, CUDA: xlogy_
+
+# xlogy: out variant
+- func: xlogy.OutTensor(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+ use_c10_dispatcher: hacky_wrapper_for_legacy_signatures
+ variants: function
+ dispatch:
+ CPU, CUDA: xlogy_out
+
+- func: xlogy.OutScalar_Self(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+ use_c10_dispatcher: hacky_wrapper_for_legacy_signatures
+ variants: function
+ dispatch:
+ CPU, CUDA: xlogy_out
+
+- func: xlogy.OutScalar_Other(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+ use_c10_dispatcher: hacky_wrapper_for_legacy_signatures
+ variants: function
+ dispatch:
+ CPU, CUDA: xlogy_out
+
- func: logdet(Tensor self) -> Tensor
use_c10_dispatcher: full
variants: function, method
diff --git a/docs/source/tensors.rst b/docs/source/tensors.rst
index f737537..315cc9d 100644
--- a/docs/source/tensors.rst
+++ b/docs/source/tensors.rst
@@ -645,6 +645,8 @@
.. automethod:: view
.. automethod:: view_as
.. automethod:: where
+ .. automethod:: xlogy
+ .. automethod:: xlogy_
.. automethod:: zero_
.. class:: BoolTensor()
diff --git a/docs/source/torch.rst b/docs/source/torch.rst
index c82035e..3057339 100644
--- a/docs/source/torch.rst
+++ b/docs/source/torch.rst
@@ -350,6 +350,7 @@
tanh
true_divide
trunc
+ xlogy
Reduction Ops
~~~~~~~~~~~~~~~~~~~~~~
diff --git a/test/test_autograd.py b/test/test_autograd.py
index 3d29529..a8a1305 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -10,7 +10,7 @@
import warnings
from copy import deepcopy
from collections import OrderedDict
-from itertools import product
+from itertools import product, permutations
from operator import mul
from functools import reduce
import torch
@@ -7396,6 +7396,54 @@
self._test_atleast(device, torch.atleast_2d)
self._test_atleast(device, torch.atleast_3d)
+ def test_xlogy(self, device):
+
+ def _tensor_tensor_helper(x, y):
+ gradcheck(lambda x, y: torch.xlogy(x, y), (x, y))
+ gradgradcheck(lambda x, y: torch.xlogy(x, y), (x, y))
+
+ with torch.no_grad():
+ x = x.clone()
+ x[torch.rand_like(x) > 0.5] = 0
+
+ gradcheck(lambda y: torch.xlogy(x, y), (y))
+ gradgradcheck(lambda y: torch.xlogy(x, y), (y))
+
+ shapes = ((4,), (1, 4), (1, 1, 4), (1, 1, 1, 4))
+
+    # For broadcastable shapes and scalars.
+ for x_shape, y_shape in permutations(shapes, 2):
+ x = torch.rand(*x_shape, dtype=torch.double, device=device, requires_grad=True)
+ y = torch.rand(*y_shape, dtype=torch.double, device=device, requires_grad=True)
+
+ _tensor_tensor_helper(x, y)
+ _tensor_tensor_helper(y, x)
+
+ gradcheck(lambda y: torch.xlogy(0, y), (y))
+ gradgradcheck(lambda y: torch.xlogy(0, y), (y))
+
+ gradcheck(lambda y: torch.xlogy(2, y), (y))
+ gradgradcheck(lambda y: torch.xlogy(2, y), (y))
+ gradcheck(lambda y: torch.xlogy(y, 2), (y))
+ gradgradcheck(lambda y: torch.xlogy(y, 2), (y))
+
+ # Different shape
+ x = torch.rand(2, 3, 4, 5, dtype=torch.double, device=device, requires_grad=True)
+ y = torch.rand(4, 5, dtype=torch.double, device=device, requires_grad=True)
+ _tensor_tensor_helper(x, y)
+ _tensor_tensor_helper(y, x)
+ _tensor_tensor_helper(x, x)
+ _tensor_tensor_helper(y, y)
+
+ # Same shape
+ x = torch.rand(4, 5, dtype=torch.double, device=device, requires_grad=True)
+ y = torch.rand(4, 5, dtype=torch.double, device=device, requires_grad=True)
+ _tensor_tensor_helper(x, y)
+ _tensor_tensor_helper(y, x)
+ _tensor_tensor_helper(x, x)
+ _tensor_tensor_helper(y, y)
+
+
class TestMultithreadAutograd(TestCase):
def _run_py_multithread_fn(self, fn, args=(), num_threads=10, kwargs=None):
threads = []
diff --git a/test/test_binary_ufuncs.py b/test/test_binary_ufuncs.py
index 9888c29..5739fb5 100644
--- a/test/test_binary_ufuncs.py
+++ b/test/test_binary_ufuncs.py
@@ -8,15 +8,19 @@
import unittest
import warnings
import operator
+from functools import partial
from torch._six import inf, nan
from torch.testing._internal.common_utils import (
TestCase, iter_indices, TEST_WITH_ASAN, run_tests,
- torch_to_numpy_dtype_dict, make_tensor)
+ torch_to_numpy_dtype_dict, make_tensor, TEST_SCIPY)
from torch.testing._internal.common_device_type import (
instantiate_device_type_tests, onlyCUDA, onlyCPU, dtypes, dtypesIfCUDA,
dtypesIfCPU, deviceCountAtLeast, precisionOverride, onlyOnCPUAndCUDA,
- skipCUDAIfRocm)
+ skipCUDAIfRocm, skipIf)
+
+if TEST_SCIPY:
+ import scipy.special
# TODO: remove this
def _generate_input(shape, dtype, device, with_extremal):
@@ -2488,6 +2492,103 @@
with self.assertRaisesRegex(RuntimeError, "is not the desired type"):
torch.Tensor.float_power_(base.clone(), exp)
+ @skipIf(not TEST_SCIPY, "Scipy required for the test.")
+ @dtypes(*product(torch.testing.get_all_dtypes(include_complex=False, include_bfloat16=False),
+ torch.testing.get_all_dtypes(include_complex=False, include_bfloat16=False)))
+ def test_xlogy(self, device, dtypes):
+ def out_variant_helper(torch_fn, x, y):
+ expected = torch_fn(x, y)
+ out = torch.empty_like(expected)
+ torch_fn(x, y, out=out)
+ self.assertEqual(expected, out)
+
+ def inplace_variant_helper(x, y):
+ if x.dtype in torch.testing.get_all_int_dtypes() + [torch.bool]:
+ with self.assertRaisesRegex(RuntimeError,
+ "can't be cast to the desired output type"):
+ x.clone().xlogy_(y)
+ else:
+ expected = torch.empty_like(x)
+ torch.xlogy(x, y, out=expected)
+ inplace_out = x.clone().xlogy_(y)
+ self.assertEqual(expected, inplace_out)
+
+ x_dtype, y_dtype = dtypes
+
+ # Tensor-Tensor Test (tensor of same and different shape)
+ x = make_tensor((3, 2, 4, 5), device, x_dtype, low=0.5, high=1000)
+ y = make_tensor((3, 2, 4, 5), device, y_dtype, low=0.5, high=1000)
+ z = make_tensor((4, 5), device, y_dtype, low=0.5, high=1000)
+
+ torch_fn = partial(torch.xlogy, x)
+ reference_fn = partial(scipy.special.xlogy, x.cpu().numpy())
+
+ self.compare_with_numpy(torch_fn, reference_fn, x, exact_dtype=False)
+ self.compare_with_numpy(torch_fn, reference_fn, y, exact_dtype=False)
+ self.compare_with_numpy(torch_fn, reference_fn, z, exact_dtype=False)
+ out_variant_helper(torch.xlogy, x, x)
+ out_variant_helper(torch.xlogy, x, y)
+ out_variant_helper(torch.xlogy, x, z)
+ inplace_variant_helper(x, x)
+ inplace_variant_helper(x, y)
+ inplace_variant_helper(x, z)
+
+ # Scalar-Tensor Test
+ torch_fn = partial(torch.xlogy, 3.14)
+ reference_fn = partial(scipy.special.xlogy, 3.14)
+
+ self.compare_with_numpy(torch_fn, reference_fn, x, exact_dtype=False)
+ self.compare_with_numpy(torch_fn, reference_fn, y, exact_dtype=False)
+ self.compare_with_numpy(torch_fn, reference_fn, z, exact_dtype=False)
+ out_variant_helper(torch.xlogy, 3.14, x)
+ out_variant_helper(torch.xlogy, 3.14, y)
+ out_variant_helper(torch.xlogy, 3.14, z)
+
+ # Special Values Tensor-Tensor
+ t = torch.tensor([0., 1., 2., float('inf'), -float('inf'), float('nan')], device=device)
+ zeros = torch.zeros(6, dtype=y_dtype, device=device)
+
+ torch_fn = partial(torch.xlogy, zeros)
+ reference_fn = partial(scipy.special.xlogy, zeros.cpu().numpy())
+ self.compare_with_numpy(torch_fn, reference_fn, t, exact_dtype=False)
+ out_variant_helper(torch.xlogy, zeros, t)
+ inplace_variant_helper(zeros, t)
+
+ # Special Values Scalar-Tensor
+ torch_fn = partial(torch.xlogy, 0)
+ reference_fn = partial(scipy.special.xlogy, 0)
+ self.compare_with_numpy(torch_fn, reference_fn, t, exact_dtype=False)
+ out_variant_helper(torch.xlogy, 0, t)
+
+ @skipIf(not TEST_SCIPY, "Scipy required for the test.")
+ def test_xlogy_bfloat16(self, device):
+ def _compare_helper(x, y):
+ x_np = x if isinstance(x, float) else x.cpu().to(torch.float).numpy()
+ y_np = y if isinstance(y, float) else y.cpu().to(torch.float).numpy()
+ expected = torch.from_numpy(scipy.special.xlogy(x_np, y_np))
+ actual = torch.xlogy(x, y)
+ self.assertEqual(expected, actual, exact_dtype=False)
+
+ x_dtype, y_dtype = torch.bfloat16, torch.bfloat16
+
+ # Tensor-Tensor Test (tensor of same and different shape)
+ x = make_tensor((3, 2, 4, 5), device, x_dtype, low=0.5, high=1000)
+ y = make_tensor((3, 2, 4, 5), device, y_dtype, low=0.5, high=1000)
+ z = make_tensor((4, 5), device, y_dtype, low=0.5, high=1000)
+
+ _compare_helper(x, x)
+ _compare_helper(x, y)
+ _compare_helper(x, z)
+
+ _compare_helper(x, 3.14)
+ _compare_helper(y, 3.14)
+ _compare_helper(z, 3.14)
+
+ # Special Values Tensor-Tensor
+ t = torch.tensor([0., 1., 2., float('inf'), -float('inf'), float('nan')], device=device)
+ zeros = torch.tensor(5, dtype=y_dtype, device=device)
+ _compare_helper(t, zeros)
+ _compare_helper(t, 0.)
tensor_binary_ops = [
'__lt__', '__le__',
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index 7a619b9..9f68622 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -647,6 +647,16 @@
self: grad / (1 + pow(2, other - self))
other: grad / (1 + pow(2, self - other))
+- name: xlogy.Tensor(Tensor self, Tensor other) -> Tensor
+ self: grad * at::xlogy((self != 0), other)
+ other: grad * self / other
+
+- name: xlogy.Scalar_Self(Scalar self, Tensor other) -> Tensor
+ other: grad * self / other
+
+- name: xlogy.Scalar_Other(Tensor self, Scalar other) -> Tensor
+ self: grad * at::xlogy((self != 0), other)
+
- name: logdet(Tensor self) -> Tensor
self: logdet_backward(grad, self, result)
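
As a reading of the two `derivatives.yaml` entries added above (not text from the PR): they follow the usual calculus for z = x * log(y), except that the `self` gradient is routed back through `xlogy` itself, which makes it exactly zero where `self == 0`, so a -inf or NaN from `log(other)` at those positions cannot leak into the gradient.

```latex
% Gradients encoded by the derivatives.yaml entries above, with z = xlogy(x, y):
\[
  \frac{\partial z}{\partial x}
    = \operatorname{xlogy}\bigl(\mathbf{1}[x \neq 0],\, y\bigr)
    = \begin{cases} 0 & \text{if } x = 0 \\ \log y & \text{otherwise,} \end{cases}
  \qquad
  \frac{\partial z}{\partial y} = \frac{x}{y}
\]
```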
diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py
index f081b59..e944320 100644
--- a/torch/_tensor_docs.py
+++ b/torch/_tensor_docs.py
@@ -4472,6 +4472,20 @@
Out-of-place version of :meth:`torch.Tensor.masked_scatter_`
""")
+add_docstr_all('xlogy',
+ r"""
+xlogy(other) -> Tensor
+
+See :func:`torch.xlogy`
+""")
+
+add_docstr_all('xlogy_',
+ r"""
+xlogy_(other) -> Tensor
+
+In-place version of :meth:`~Tensor.xlogy`
+""")
+
add_docstr_all('masked_fill',
r"""
masked_fill(mask, value) -> Tensor
diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py
index 91da41b..0294942 100644
--- a/torch/_torch_docs.py
+++ b/torch/_torch_docs.py
@@ -4371,6 +4371,48 @@
{out}
""".format(**common_args))
+add_docstr(torch.xlogy,
+ r"""
+xlogy(input, other, *, out=None) -> Tensor
+
+Computes ``input * log(other)`` with the following cases.
+
+.. math::
+ \text{out}_{i} = \begin{cases}
+ \text{NaN} & \text{if } \text{other}_{i} = \text{NaN} \\
+ 0 & \text{if } \text{input}_{i} = 0.0 \\
+ \text{input}_{i} * \log{(\text{other}_{i})} & \text{otherwise}
+ \end{cases}
+
+Similar to SciPy's ``scipy.special.xlogy``.
+
+""" + r"""
+
+Args:
+ input (Number or Tensor)
+ other (Number or Tensor)
+
+.. note:: At least one of :attr:`input` or :attr:`other` must be a tensor.
+
+Keyword args:
+ {out}
+
+Example::
+
+ >>> x = torch.zeros(5,)
+ >>> y = torch.tensor([-1, 0, 1, float('inf'), float('nan')])
+ >>> torch.xlogy(x, y)
+ tensor([0., 0., 0., 0., nan])
+ >>> x = torch.tensor([1, 2, 3])
+ >>> y = torch.tensor([3, 2, 1])
+ >>> torch.xlogy(x, y)
+ tensor([1.0986, 1.3863, 0.0000])
+ >>> torch.xlogy(x, 4)
+ tensor([1.3863, 2.7726, 4.1589])
+ >>> torch.xlogy(2, y)
+ tensor([2.1972, 1.3863, 0.0000])
+""".format(**common_args))
+
add_docstr(torch.logical_and,
r"""
logical_and(input, other, *, out=None) -> Tensor
diff --git a/torch/overrides.py b/torch/overrides.py
index c0e3463..d23e348 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -495,6 +495,7 @@
torch.logaddexp: lambda input, other, out=None: -1,
torch.logaddexp2: lambda input, other, out=None: -1,
torch.logdet: lambda input: -1,
+ torch.xlogy: lambda x, y: -1,
torch.logical_and: lambda input, other, out=None: -1,
torch.logical_not: lambda input, out=None: -1,
torch.logical_or: lambda input, other, out=None: -1,
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index ba29c42..55b97b3 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -291,12 +291,21 @@
return (SampleInput((make_tensor((S, S), device, dtype,
low=None, high=None,
requires_grad=requires_grad),
- make_tensor((S, S), device, dtype,
- low=None, high=None,
- requires_grad=requires_grad),
- make_tensor((S, S), device, dtype,
- low=None, high=None,
- requires_grad=False))),)
+ make_tensor((S, S), device, dtype,
+ low=None, high=None,
+ requires_grad=requires_grad),
+ make_tensor((S, S), device, dtype,
+ low=None, high=None,
+ requires_grad=False))),)
+
+
+def sample_inputs_xlogy(self, device, dtype, requires_grad):
+ return (SampleInput((make_tensor((S, S), device, dtype,
+ low=None, high=None,
+ requires_grad=requires_grad),
+ make_tensor((S, S), device, dtype,
+ low=0, high=None,
+ requires_grad=requires_grad))),)
def np_sinc_with_fp16_as_fp32(x):
# Wraps numpy's sinc function so that fp16 values are promoted to fp32
@@ -1084,6 +1093,14 @@
dtypes=[torch.bfloat16]),),
assert_autodiffed=True,
promotes_integers_to_float=True),
+ OpInfo('xlogy',
+ dtypes=all_types_and(torch.bool),
+ dtypesIfCPU=all_types_and(torch.bool, torch.half, torch.bfloat16),
+ dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
+ test_inplace_grad=True,
+ supports_tensor_out=True,
+ promotes_integers_to_float=True,
+ sample_inputs_func=sample_inputs_xlogy),
]
op_db = op_db + op_db_scipy_reference
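
A standalone sketch of the element-wise rule the CPU and CUDA kernels implement, cross-checked against `scipy.special.xlogy` in the same spirit as the tests above (assumes SciPy is installed; illustrative only, not part of the patch):

```python
# Reference semantics of the kernels added in this PR: NaN in `other` wins,
# then `input == 0` forces 0, otherwise input * log(other).
import math
import torch
import scipy.special  # assumed available, mirroring the TEST_SCIPY-gated tests

def xlogy_reference(x, y):
    if math.isnan(y):
        return float('nan')
    if x == 0:
        return 0.0
    return x * math.log(y)

xs = [0.0, 0.0, 1.0, 2.0, 3.0]
ys = [float('nan'), 0.0, 3.0, 0.5, 2.0]

ref = torch.tensor([xlogy_reference(x, y) for x, y in zip(xs, ys)])
got = torch.xlogy(torch.tensor(xs), torch.tensor(ys))
spy = torch.from_numpy(scipy.special.xlogy(xs, ys)).float()

print(torch.allclose(got, ref, equal_nan=True))  # True
print(torch.allclose(got, spy, equal_nan=True))  # True
```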