implement numpy-like functionality isposinf, isneginf (#41588)
Summary:
Related: https://github.com/pytorch/pytorch/issues/38349
This PR implements the NumPy-like functions `isposinf` and `isneginf` as both free functions and tensor methods, with `out=` variants dispatching to CPU and CUDA kernels.
Test-Plan:
- pytest test/test_torch.py -k "test_isposinf_isneginf"
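Example usage (mirrors the docstring examples added in torch/_torch_docs.py below; `out=` must be a boolean tensor):

    >>> import torch
    >>> a = torch.tensor([-float('inf'), float('inf'), 1.2])
    >>> torch.isposinf(a)
    tensor([False,  True, False])
    >>> torch.isneginf(a)
    tensor([ True, False, False])
    >>> out = torch.empty(3, dtype=torch.bool)
    >>> torch.isposinf(a, out=out)
    tensor([False,  True, False])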
Pull Request resolved: https://github.com/pytorch/pytorch/pull/41588
Reviewed By: ngimel
Differential Revision: D22770732
Pulled By: mruberry
fbshipit-source-id: 7448653e8fb8df6b9cd4604a4739fe18a1135578
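For reviewers skimming the patch, a rough Python sketch of the semantics the new ops implement (not part of the patch; `reference_isposinf` is a hypothetical name used only for illustration). Complex inputs are rejected, integral and bool inputs can never be infinite, and floating-point inputs are compared elementwise against positive or negative infinity:

    import torch

    def reference_isposinf(t, negate=False):
        # Mirrors the TORCH_CHECK in isposinf_out/isneginf_out: complex inputs raise.
        if t.is_complex():
            raise RuntimeError('isposinf does not support complex inputs.')
        # Integral (and bool) tensors cannot hold infinities, so the result is
        # all False, matching the result.fill_(false) fast path in TensorCompare.cpp.
        if not t.is_floating_point():
            return torch.zeros_like(t, dtype=torch.bool)
        # Float path: elementwise comparison against +/-inf, as in the CPU/CUDA kernels.
        inf = float('-inf') if negate else float('inf')
        return t == inf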
diff --git a/aten/src/ATen/core/aten_interned_strings.h b/aten/src/ATen/core/aten_interned_strings.h
index 99c8bb4..40cfc23 100644
--- a/aten/src/ATen/core/aten_interned_strings.h
+++ b/aten/src/ATen/core/aten_interned_strings.h
@@ -405,6 +405,8 @@
_(aten, isclose) \
_(aten, isreal) \
_(aten, istft) \
+_(aten, isposinf) \
+_(aten, isneginf) \
_(aten, kl_div) \
_(aten, kl_div_backward) \
_(aten, kthvalue) \
diff --git a/aten/src/ATen/native/TensorCompare.cpp b/aten/src/ATen/native/TensorCompare.cpp
index 2aecf12..901ff55 100644
--- a/aten/src/ATen/native/TensorCompare.cpp
+++ b/aten/src/ATen/native/TensorCompare.cpp
@@ -13,6 +13,8 @@
DEFINE_DISPATCH(where_kernel);
DEFINE_DISPATCH(max_stub);
DEFINE_DISPATCH(min_stub);
+DEFINE_DISPATCH(isposinf_stub);
+DEFINE_DISPATCH(isneginf_stub);
bool allclose(const Tensor& self, const Tensor& other, double rtol, double atol, bool equal_nan) {
return at::isclose(self, other, rtol, atol, equal_nan).all().item<uint8_t>();
@@ -106,6 +108,56 @@
});
}
+Tensor isposinf(const Tensor &self) {
+  Tensor result = at::empty_like(self, at::kBool, at::MemoryFormat::Preserve);
+  at::isposinf_out(result, self);
+  return result;
+}
+
+Tensor& isposinf_out(Tensor& result, const Tensor& self) {
+  TORCH_CHECK(!self.is_complex(), "isposinf does not support complex inputs.");
+  TORCH_CHECK(result.scalar_type() == at::kBool, "isposinf does not support non-boolean outputs.");
+  result.resize_(self.sizes());
+
+  if (c10::isIntegralType(self.scalar_type(), /*include_bool=*/true)) {
+    result.fill_(false);
+  } else {
+    auto iter = TensorIteratorConfig()
+      .check_all_same_dtype(false)
+      .set_check_mem_overlap(true)
+      .add_output(result)
+      .add_input(self)
+      .build();
+    isposinf_stub(iter.device_type(), iter);
+  }
+  return result;
+}
+
+Tensor isneginf(const Tensor &self) {
+  Tensor result = at::empty_like(self, at::kBool, at::MemoryFormat::Preserve);
+  at::isneginf_out(result, self);
+  return result;
+}
+
+Tensor& isneginf_out(Tensor& result, const Tensor& self) {
+  TORCH_CHECK(!self.is_complex(), "isneginf does not support complex inputs.");
+  TORCH_CHECK(result.scalar_type() == at::kBool, "isneginf does not support non-boolean outputs.");
+  result.resize_(self.sizes());
+
+  if (c10::isIntegralType(self.scalar_type(), /*include_bool=*/true)) {
+    result.fill_(false);
+  } else {
+    auto iter = TensorIteratorConfig()
+      .check_all_same_dtype(false)
+      .set_check_mem_overlap(true)
+      .add_output(result)
+      .add_input(self)
+      .build();
+    isneginf_stub(iter.device_type(), iter);
+  }
+  return result;
+}
+
Tensor isfinite(const Tensor& self) {
// Note: Integral tensor values are always finite
if (c10::isIntegralType(self.scalar_type(), /*include_bool=*/true)) {
diff --git a/aten/src/ATen/native/TensorCompare.h b/aten/src/ATen/native/TensorCompare.h
index 145837e..2193cb9 100644
--- a/aten/src/ATen/native/TensorCompare.h
+++ b/aten/src/ATen/native/TensorCompare.h
@@ -16,4 +16,7 @@
using where_fn = void (*)(TensorIterator &, ScalarType);
DECLARE_DISPATCH(where_fn, where_kernel);
+using is_infinity_op_fn = void (*)(TensorIterator &);
+DECLARE_DISPATCH(is_infinity_op_fn, isposinf_stub);
+DECLARE_DISPATCH(is_infinity_op_fn, isneginf_stub);
}} // namespace at::native
diff --git a/aten/src/ATen/native/cpu/TensorCompareKernel.cpp b/aten/src/ATen/native/cpu/TensorCompareKernel.cpp
index 3c2cd28..3cbb64b 100644
--- a/aten/src/ATen/native/cpu/TensorCompareKernel.cpp
+++ b/aten/src/ATen/native/cpu/TensorCompareKernel.cpp
@@ -161,10 +161,24 @@
});
}
+static void isposinf_kernel_impl(TensorIterator& iter) {
+  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.input_dtype(), "isposinf_cpu", [&]() {
+    cpu_kernel(iter, [](scalar_t a) -> bool { return a == std::numeric_limits<scalar_t>::infinity(); });
+  });
+}
+
+static void isneginf_kernel_impl(TensorIterator& iter) {
+  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.input_dtype(), "isneginf_cpu", [&]() {
+    cpu_kernel(iter, [](scalar_t a) -> bool { return a == -std::numeric_limits<scalar_t>::infinity(); });
+  });
+}
+
} // anonymous namespace
REGISTER_DISPATCH(max_stub, &max_kernel_impl);
REGISTER_DISPATCH(min_stub, &min_kernel_impl);
REGISTER_DISPATCH(where_kernel, &where_kernel_impl);
+REGISTER_DISPATCH(isposinf_stub, &isposinf_kernel_impl);
+REGISTER_DISPATCH(isneginf_stub, &isneginf_kernel_impl);
}} // namespace at::native
diff --git a/aten/src/ATen/native/cuda/TensorCompare.cu b/aten/src/ATen/native/cuda/TensorCompare.cu
index 0dbedcf..443bea3 100644
--- a/aten/src/ATen/native/cuda/TensorCompare.cu
+++ b/aten/src/ATen/native/cuda/TensorCompare.cu
@@ -10,6 +10,10 @@
using where_fn = void (*)(TensorIterator &, ScalarType);
DECLARE_DISPATCH(where_fn, where_kernel);
+using is_infinity_op_fn = void (*)(TensorIterator &);
+DECLARE_DISPATCH(is_infinity_op_fn, isposinf_stub);
+DECLARE_DISPATCH(is_infinity_op_fn, isneginf_stub);
+
namespace {
void where_kernel_impl(TensorIterator &iter, ScalarType condition_type) {
@@ -30,9 +34,29 @@
});
}
+void isposinf_kernel_impl(TensorIterator &iter) {
+  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.input_dtype(), "isposinf_cuda", [&]() {
+    gpu_kernel(
+      iter,
+      [] GPU_LAMBDA (scalar_t a) -> bool { return a == std::numeric_limits<scalar_t>::infinity(); }
+    );
+  });
+}
+
+void isneginf_kernel_impl(TensorIterator &iter) {
+  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.input_dtype(), "isneginf_cuda", [&]() {
+    gpu_kernel(
+      iter,
+      [] GPU_LAMBDA (scalar_t a) -> bool { return a == -std::numeric_limits<scalar_t>::infinity(); }
+    );
+  });
+}
+
} // anonymous namespace
REGISTER_DISPATCH(where_kernel, &where_kernel_impl);
+REGISTER_DISPATCH(isposinf_stub, &isposinf_kernel_impl);
+REGISTER_DISPATCH(isneginf_stub, &isneginf_kernel_impl);
}} // namespace at::native
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index e104c43..4c148c6 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -6759,6 +6759,22 @@
variants: function, method
device_guard: False
+- func: isposinf(Tensor self) -> Tensor
+  use_c10_dispatcher: full
+  variants: function, method
+
+- func: isposinf.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CPU, CUDA: isposinf_out
+
+- func: isneginf(Tensor self) -> Tensor
+  use_c10_dispatcher: full
+  variants: function, method
+
+- func: isneginf.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CPU, CUDA: isneginf_out
+
# NOTE [_add_batch_dim and _remove_batch_dim]
# _add_batch_dim and _remove_batch_dim are meant to be used in the implementation
# of the vmap frontend API (see torch/_vmap_internals.py). They are not
diff --git a/docs/source/tensors.rst b/docs/source/tensors.rst
index 11575b0..8fcca7a 100644
--- a/docs/source/tensors.rst
+++ b/docs/source/tensors.rst
@@ -339,6 +339,8 @@
.. automethod:: isclose
.. automethod:: isfinite
.. automethod:: isinf
+   .. automethod:: isposinf
+   .. automethod:: isneginf
.. automethod:: isnan
.. automethod:: is_contiguous
.. automethod:: is_complex
diff --git a/docs/source/torch.rst b/docs/source/torch.rst
index 94a4e5b..313136b 100644
--- a/docs/source/torch.rst
+++ b/docs/source/torch.rst
@@ -351,6 +351,8 @@
isclose
isfinite
isinf
+    isposinf
+    isneginf
isnan
isreal
kthvalue
diff --git a/test/test_torch.py b/test/test_torch.py
index 101b197..2e2207b 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -6453,6 +6453,72 @@
self.compare_with_numpy(torch.isnan, np.isnan, vals, device, dtype)
@unittest.skipIf(not TEST_NUMPY, 'NumPy not found')
+    @dtypes(*(torch.testing.get_all_fp_dtypes()))
+    def test_isposinf_isneginf_float(self, device, dtype):
+        ops = ((torch.isposinf, np.isposinf), (torch.isneginf, np.isneginf))
+        vals = (-float('inf'), float('inf'), float('nan'), -1, 0, 1)
+
+        for torch_op, numpy_op in ops:
+            if torch_op == torch.isposinf:
+                target_vals = (0, 1, 0, 0, 0, 0)
+            else:
+                target_vals = (1, 0, 0, 0, 0, 0)
+
+            t = torch.tensor(vals, device=device, dtype=dtype)
+            # Manual check here as numpy does not support bfloat16
+            if dtype == torch.bfloat16:
+                self.assertEqual(torch_op(t),
+                                 torch.tensor(target_vals, device=device, dtype=torch.bool))
+            else:
+                self.compare_with_numpy(torch_op, numpy_op, vals, device, dtype)
+
+            # test the boolean tensor as the `out=` parameter
+            out = torch.empty_like(t, dtype=torch.bool)
+            t_target = torch.tensor(target_vals, device=device, dtype=torch.bool)
+            torch_op(t, out=out)
+            self.assertEqual(out, t_target)
+
+    @unittest.skipIf(not TEST_NUMPY, 'NumPy not found')
+    @dtypes(*(torch.testing.get_all_int_dtypes() + [torch.bool]))
+    def test_isposinf_isneginf_int_and_bool(self, device, dtype):
+        ops = ((torch.isposinf, np.isposinf), (torch.isneginf, np.isneginf))
+        vals = (-1, 0, 1)
+
+        for torch_op, numpy_op in ops:
+            self.compare_with_numpy(torch_op, numpy_op, vals, device, dtype)
+
+            # test the boolean tensor as the `out=` parameter
+            t = torch.tensor(vals, device=device, dtype=dtype)
+            out = torch.empty_like(t, dtype=torch.bool)
+            t_target = torch.zeros_like(t, dtype=torch.bool)
+            torch_op(t, out=out)
+            self.assertEqual(out, t_target)
+
+    @dtypes(torch.complex64, torch.complex128)
+    def test_isposinf_isneginf_complex(self, device, dtype):
+        torch_ops = (torch.isposinf, torch.isneginf)
+        vals = (complex(0, float('inf')), complex(1, -float('inf')))
+        t = torch.tensor(vals, device=device, dtype=dtype)
+        out = torch.empty_like(t)
+
+        for torch_op in torch_ops:
+            with self.assertRaisesRegex(RuntimeError, 'does not support complex inputs'):
+                torch_op(t)
+            with self.assertRaisesRegex(RuntimeError, 'does not support complex inputs'):
+                torch_op(t, out=out)
+
+    @dtypes(*(torch.testing.get_all_dtypes(include_bool=False)))
+    def test_isposinf_isneginf_non_boolean_output(self, device, dtype):
+        # test non-boolean tensors as the `out=` parameters
+        # boolean outputs are tested in the above testcases
+        vals = (float('inf'), -float('inf'), 1.2)
+        t = torch.tensor(vals, device=device)
+        for torch_op in (torch.isposinf, torch.isneginf):
+            out = torch.empty_like(t, dtype=dtype)
+            with self.assertRaisesRegex(RuntimeError, 'does not support non-boolean outputs'):
+                torch_op(t, out=out)
+
+    @unittest.skipIf(not TEST_NUMPY, 'NumPy not found')
@dtypes(torch.complex64)
def test_isfinite_isinf_isnan_complex(self, device, dtype):
vals = (
diff --git a/torch/_overrides.py b/torch/_overrides.py
index 80137f7..a796baa 100644
--- a/torch/_overrides.py
+++ b/torch/_overrides.py
@@ -348,6 +348,8 @@
torch.isfinite: lambda tensor: -1,
torch.isinf: lambda tensor: -1,
torch.isreal: lambda tensor: -1,
+    torch.isposinf: lambda input, out=None: -1,
+    torch.isneginf: lambda input, out=None: -1,
torch.instance_norm: (lambda input, running_mean, running_var, weight, bias, use_input_stats, momentum, eps,
cudnn_enabled: -1),
torch.int_repr: lambda input: -1,
diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py
index fb31e16..2c2e947 100644
--- a/torch/_tensor_docs.py
+++ b/torch/_tensor_docs.py
@@ -1637,6 +1637,20 @@
See :func:`torch.isinf`
""")
+add_docstr_all('isposinf',
+               r"""
+isposinf() -> Tensor
+
+See :func:`torch.isposinf`
+""")
+
+add_docstr_all('isneginf',
+               r"""
+isneginf() -> Tensor
+
+See :func:`torch.isneginf`
+""")
+
add_docstr_all('isfinite',
r"""
isfinite() -> Tensor
diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py
index 779f0e3..eec97ec 100644
--- a/torch/_torch_docs.py
+++ b/torch/_torch_docs.py
@@ -2875,6 +2875,40 @@
tensor([False, True, False, True, False])
""")
+add_docstr(torch.isposinf,
+           r"""
+isposinf(input, *, out=None) -> Tensor
+Tests if each element of :attr:`input` is positive infinity or not.
+
+Args:
+    {input}
+
+Keyword args:
+    {out}
+
+Example::
+
+    >>> a = torch.tensor([-float('inf'), float('inf'), 1.2])
+    >>> torch.isposinf(a)
+    tensor([False,  True, False])
+""".format(**common_args))
+
+add_docstr(torch.isneginf,
+           r"""
+isneginf(input, *, out=None) -> Tensor
+Tests if each element of :attr:`input` is negative infinity or not.
+
+Args:
+    {input}
+
+Keyword args:
+    {out}
+
+Example::
+
+    >>> a = torch.tensor([-float('inf'), float('inf'), 1.2])
+    >>> torch.isneginf(a)
+    tensor([ True, False, False])
+""".format(**common_args))
+
add_docstr(torch.isclose, r"""
isclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False) -> Tensor