Implement NumPy-like functions isposinf and isneginf (#41588)

Summary:
Related: https://github.com/pytorch/pytorch/issues/38349

The NumPy-like functions `isposinf` and `isneginf` are implemented for CPU and CUDA, with function, method, and `out=` variants, plus documentation and tests.
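
A brief usage sketch, mirroring the examples in the added docstrings and tests (requires a build that includes this PR):

```python
import torch

a = torch.tensor([-float('inf'), float('inf'), 1.2])

# Elementwise checks for +inf / -inf; both return a boolean tensor.
torch.isposinf(a)   # tensor([False,  True, False])
torch.isneginf(a)   # tensor([ True, False, False])

# The out= variant requires a boolean output tensor; complex inputs raise a RuntimeError.
out = torch.empty_like(a, dtype=torch.bool)
torch.isposinf(a, out=out)

# Method variants are also registered.
a.isneginf()        # tensor([ True, False, False])
```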

Test-Plan:
- pytest test/test_torch.py -k "test_isposinf_isneginf"

Pull Request resolved: https://github.com/pytorch/pytorch/pull/41588

Reviewed By: ngimel

Differential Revision: D22770732

Pulled By: mruberry

fbshipit-source-id: 7448653e8fb8df6b9cd4604a4739fe18a1135578
diff --git a/aten/src/ATen/core/aten_interned_strings.h b/aten/src/ATen/core/aten_interned_strings.h
index 99c8bb4..40cfc23 100644
--- a/aten/src/ATen/core/aten_interned_strings.h
+++ b/aten/src/ATen/core/aten_interned_strings.h
@@ -405,6 +405,8 @@
 _(aten, isclose) \
 _(aten, isreal) \
 _(aten, istft) \
+_(aten, isposinf) \
+_(aten, isneginf) \
 _(aten, kl_div) \
 _(aten, kl_div_backward) \
 _(aten, kthvalue) \
diff --git a/aten/src/ATen/native/TensorCompare.cpp b/aten/src/ATen/native/TensorCompare.cpp
index 2aecf12..901ff55 100644
--- a/aten/src/ATen/native/TensorCompare.cpp
+++ b/aten/src/ATen/native/TensorCompare.cpp
@@ -13,6 +13,8 @@
 DEFINE_DISPATCH(where_kernel);
 DEFINE_DISPATCH(max_stub);
 DEFINE_DISPATCH(min_stub);
+DEFINE_DISPATCH(isposinf_stub);
+DEFINE_DISPATCH(isneginf_stub);
 
 bool allclose(const Tensor& self, const Tensor& other, double rtol, double atol, bool equal_nan) {
   return at::isclose(self, other, rtol, atol, equal_nan).all().item<uint8_t>();
@@ -106,6 +108,56 @@
   });
 }
 
+Tensor isposinf(const Tensor &self) {
+  Tensor result = at::empty_like(self, at::kBool, at::MemoryFormat::Preserve);
+  at::isposinf_out(result, self);
+  return result;
+}
+
+Tensor& isposinf_out(Tensor& result, const Tensor& self) {
+  TORCH_CHECK(!self.is_complex(), "isposinf does not support complex inputs.");
+  TORCH_CHECK(result.scalar_type() == at::kBool, "isposinf does not support non-boolean outputs.");
+  result.resize_(self.sizes());
+
+  if (c10::isIntegralType(self.scalar_type(), /*include_bool=*/true)) {
+    result.fill_(false);
+  } else {
+    auto iter = TensorIteratorConfig()
+      .check_all_same_dtype(false)
+      .set_check_mem_overlap(true)
+      .add_output(result)
+      .add_input(self)
+      .build();
+    isposinf_stub(iter.device_type(), iter);
+  }
+  return result;
+}
+
+Tensor isneginf(const Tensor &self) {
+  Tensor result = at::empty_like(self, at::kBool, at::MemoryFormat::Preserve);
+  at::isneginf_out(result, self);
+  return result;
+}
+
+Tensor& isneginf_out(Tensor& result, const Tensor& self) {
+  TORCH_CHECK(!self.is_complex(), "isneginf does not support complex inputs.");
+  TORCH_CHECK(result.scalar_type() == at::kBool, "isneginf does not support non-boolean outputs.");
+  result.resize_(self.sizes());
+
+  if (c10::isIntegralType(self.scalar_type(), /*include_bool=*/true)) {
+    result.fill_(false);
+  } else {
+    auto iter = TensorIteratorConfig()
+      .check_all_same_dtype(false)
+      .set_check_mem_overlap(true)
+      .add_output(result)
+      .add_input(self)
+      .build();
+    isneginf_stub(iter.device_type(), iter);
+  }
+  return result;
+}
+
 Tensor isfinite(const Tensor& self) {
   // Note: Integral tensor values are always finite
   if (c10::isIntegralType(self.scalar_type(), /*include_bool=*/true)) {
diff --git a/aten/src/ATen/native/TensorCompare.h b/aten/src/ATen/native/TensorCompare.h
index 145837e..2193cb9 100644
--- a/aten/src/ATen/native/TensorCompare.h
+++ b/aten/src/ATen/native/TensorCompare.h
@@ -16,4 +16,7 @@
 using where_fn = void (*)(TensorIterator &, ScalarType);
 DECLARE_DISPATCH(where_fn, where_kernel);
 
+using is_infinity_op_fn = void (*)(TensorIterator &);
+DECLARE_DISPATCH(is_infinity_op_fn, isposinf_stub);
+DECLARE_DISPATCH(is_infinity_op_fn, isneginf_stub);
 }} // namespace at::native
diff --git a/aten/src/ATen/native/cpu/TensorCompareKernel.cpp b/aten/src/ATen/native/cpu/TensorCompareKernel.cpp
index 3c2cd28..3cbb64b 100644
--- a/aten/src/ATen/native/cpu/TensorCompareKernel.cpp
+++ b/aten/src/ATen/native/cpu/TensorCompareKernel.cpp
@@ -161,10 +161,24 @@
   });
 }
 
+static void isposinf_kernel_impl(TensorIterator& iter) {
+  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.input_dtype(), "isposinf_cpu", [&]() {
+    cpu_kernel(iter, [](scalar_t a) -> bool { return a == std::numeric_limits<scalar_t>::infinity(); });
+  });
+}
+
+static void isneginf_kernel_impl(TensorIterator& iter) {
+  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.input_dtype(), "isneginf_cpu", [&]() {
+    cpu_kernel(iter, [](scalar_t a) -> bool { return a == -std::numeric_limits<scalar_t>::infinity(); });
+  });
+}
+
 } // anonymous namespace
 
 REGISTER_DISPATCH(max_stub, &max_kernel_impl);
 REGISTER_DISPATCH(min_stub, &min_kernel_impl);
 REGISTER_DISPATCH(where_kernel, &where_kernel_impl);
+REGISTER_DISPATCH(isposinf_stub, &isposinf_kernel_impl);
+REGISTER_DISPATCH(isneginf_stub, &isneginf_kernel_impl);
 
 }} // namespace at::native
diff --git a/aten/src/ATen/native/cuda/TensorCompare.cu b/aten/src/ATen/native/cuda/TensorCompare.cu
index 0dbedcf..443bea3 100644
--- a/aten/src/ATen/native/cuda/TensorCompare.cu
+++ b/aten/src/ATen/native/cuda/TensorCompare.cu
@@ -10,6 +10,10 @@
 using where_fn = void (*)(TensorIterator &, ScalarType);
 DECLARE_DISPATCH(where_fn, where_kernel);
 
+using is_infinity_op_fn = void (*)(TensorIterator &);
+DECLARE_DISPATCH(is_infinity_op_fn, isposinf_stub);
+DECLARE_DISPATCH(is_infinity_op_fn, isneginf_stub);
+
 namespace {
 
 void where_kernel_impl(TensorIterator &iter, ScalarType condition_type) {
@@ -30,9 +34,29 @@
   });
 }
 
+void isposinf_kernel_impl(TensorIterator &iter) {
+  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.input_dtype(), "isposinf_cuda", [&]() {
+    gpu_kernel(
+      iter,
+      [] GPU_LAMBDA (scalar_t a) -> bool { return a == std::numeric_limits<scalar_t>::infinity(); }
+    );
+  });
+}
+
+void isneginf_kernel_impl(TensorIterator &iter) {
+  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.input_dtype(), "isneginf_cuda", [&]() {
+    gpu_kernel(
+      iter,
+      [] GPU_LAMBDA (scalar_t a) -> bool { return a == -std::numeric_limits<scalar_t>::infinity(); }
+    );
+  });
+}
+
 } // anonymous namespace
 
 
 REGISTER_DISPATCH(where_kernel, &where_kernel_impl);
+REGISTER_DISPATCH(isposinf_stub, &isposinf_kernel_impl);
+REGISTER_DISPATCH(isneginf_stub, &isneginf_kernel_impl);
 
 }} // namespace at::native
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index e104c43..4c148c6 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -6759,6 +6759,22 @@
   variants: function, method
   device_guard: False
 
+- func: isposinf(Tensor self) -> Tensor
+  use_c10_dispatcher: full
+  variants: function, method
+
+- func: isposinf.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CPU, CUDA: isposinf_out
+
+- func: isneginf(Tensor self) -> Tensor
+  use_c10_dispatcher: full
+  variants: function, method
+
+- func: isneginf.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CPU, CUDA: isneginf_out
+
 # NOTE [_add_batch_dim and _remove_batch_dim]
 # _add_batch_dim and _remove_batch_dim are meant to be used in the implementation
 # of the vmap frontend API (see torch/_vmap_internals.py). They are not
diff --git a/docs/source/tensors.rst b/docs/source/tensors.rst
index 11575b0..8fcca7a 100644
--- a/docs/source/tensors.rst
+++ b/docs/source/tensors.rst
@@ -339,6 +339,8 @@
    .. automethod:: isclose
    .. automethod:: isfinite
    .. automethod:: isinf
+   .. automethod:: isposinf
+   .. automethod:: isneginf
    .. automethod:: isnan
    .. automethod:: is_contiguous
    .. automethod:: is_complex
diff --git a/docs/source/torch.rst b/docs/source/torch.rst
index 94a4e5b..313136b 100644
--- a/docs/source/torch.rst
+++ b/docs/source/torch.rst
@@ -351,6 +351,8 @@
     isclose
     isfinite
     isinf
+    isposinf
+    isneginf
     isnan
     isreal
     kthvalue
diff --git a/test/test_torch.py b/test/test_torch.py
index 101b197..2e2207b 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -6453,6 +6453,72 @@
         self.compare_with_numpy(torch.isnan, np.isnan, vals, device, dtype)
 
     @unittest.skipIf(not TEST_NUMPY, 'NumPy not found')
+    @dtypes(*(torch.testing.get_all_fp_dtypes()))
+    def test_isposinf_isneginf_float(self, device, dtype):
+        ops = ((torch.isposinf, np.isposinf), (torch.isneginf, np.isneginf))
+        vals = (-float('inf'), float('inf'), float('nan'), -1, 0, 1)
+
+        for torch_op, numpy_op in ops:
+            if torch_op == torch.isposinf:
+                target_vals = (0, 1, 0, 0, 0, 0)
+            else:
+                target_vals = (1, 0, 0, 0, 0, 0)
+
+            t = torch.tensor(vals, device=device, dtype=dtype)
+            # Manual check here as numpy does not support bfloat16
+            if dtype == torch.bfloat16:
+                self.assertEqual(torch_op(t),
+                                 torch.tensor(target_vals, device=device, dtype=torch.bool))
+            else:
+                self.compare_with_numpy(torch_op, numpy_op, vals, device, dtype)
+
+            # test the boolean tensor as the `out=` parameter
+            out = torch.empty_like(t, dtype=torch.bool)
+            t_target = torch.tensor(target_vals, device=device, dtype=torch.bool)
+            torch_op(t, out=out)
+            self.assertEqual(out, t_target)
+
+    @unittest.skipIf(not TEST_NUMPY, 'NumPy not found')
+    @dtypes(*(torch.testing.get_all_int_dtypes() + [torch.bool]))
+    def test_isposinf_isneginf_int_and_bool(self, device, dtype):
+        ops = ((torch.isposinf, np.isposinf), (torch.isneginf, np.isneginf))
+        vals = (-1, 0, 1)
+
+        for torch_op, numpy_op in ops:
+            self.compare_with_numpy(torch_op, numpy_op, vals, device, dtype)
+
+            # test the boolean tensor as the `out=` parameter
+            t = torch.tensor(vals, device=device, dtype=dtype)
+            out = torch.empty_like(t, dtype=torch.bool)
+            t_target = torch.zeros_like(t, dtype=torch.bool)
+            torch_op(t, out=out)
+            self.assertEqual(out, t_target)
+
+    @dtypes(torch.complex64, torch.complex128)
+    def test_isposinf_isneginf_complex(self, device, dtype):
+        torch_ops = (torch.isposinf, torch.isneginf)
+        vals = (complex(0, float('inf')), complex(1, -float('inf')))
+        t = torch.tensor(vals, device=device, dtype=dtype)
+        out = torch.empty_like(t)
+
+        for torch_op in torch_ops:
+            with self.assertRaisesRegex(RuntimeError, 'does not support complex inputs'):
+                torch_op(t)
+            with self.assertRaisesRegex(RuntimeError, 'does not support complex inputs'):
+                torch_op(t, out=out)
+
+    @dtypes(*(torch.testing.get_all_dtypes(include_bool=False)))
+    def test_isposinf_isneginf_non_boolean_output(self, device, dtype):
+        # test non-boolean tensors as the `out=` parameters
+        # boolean outputs are tested in the above testcases
+        vals = (float('inf'), -float('inf'), 1.2)
+        t = torch.tensor(vals, device=device)
+        for torch_op in (torch.isposinf, torch.isneginf):
+            out = torch.empty_like(t, dtype=dtype)
+            with self.assertRaisesRegex(RuntimeError, 'does not support non-boolean outputs'):
+                torch_op(t, out=out)
+
+    @unittest.skipIf(not TEST_NUMPY, 'NumPy not found')
     @dtypes(torch.complex64)
     def test_isfinite_isinf_isnan_complex(self, device, dtype):
         vals = (
diff --git a/torch/_overrides.py b/torch/_overrides.py
index 80137f7..a796baa 100644
--- a/torch/_overrides.py
+++ b/torch/_overrides.py
@@ -348,6 +348,8 @@
         torch.isfinite: lambda tensor: -1,
         torch.isinf: lambda tensor: -1,
         torch.isreal: lambda tensor: -1,
+        torch.isposinf: lambda input, out=None: -1,
+        torch.isneginf: lambda input, out=None: -1,
         torch.instance_norm: (lambda input, running_mean, running_var, weight, bias, use_input_stats, momentum, eps,
                               cudnn_enabled: -1),
         torch.int_repr: lambda input: -1,
diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py
index fb31e16..2c2e947 100644
--- a/torch/_tensor_docs.py
+++ b/torch/_tensor_docs.py
@@ -1637,6 +1637,20 @@
 See :func:`torch.isinf`
 """)
 
+add_docstr_all('isposinf',
+               r"""
+isposinf() -> Tensor
+
+See :func:`torch.isposinf`
+""")
+
+add_docstr_all('isneginf',
+               r"""
+isneginf() -> Tensor
+
+See :func:`torch.isneginf`
+""")
+
 add_docstr_all('isfinite',
                r"""
 isfinite() -> Tensor
diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py
index 779f0e3..eec97ec 100644
--- a/torch/_torch_docs.py
+++ b/torch/_torch_docs.py
@@ -2875,6 +2875,40 @@
         tensor([False,  True,  False,  True,  False])
 """)
 
+add_docstr(torch.isposinf,
+           r"""
+isposinf(input, *, out=None) -> Tensor
+Tests if each element of :attr:`input` is positive infinity or not.
+
+Args:
+  {input}
+
+Keyword args:
+  {out}
+
+Example::
+    >>> a = torch.tensor([-float('inf'), float('inf'), 1.2])
+    >>> torch.isposinf(a)
+    tensor([False,  True, False])
+""".format(**common_args))
+
+add_docstr(torch.isneginf,
+           r"""
+isneginf(input, *, out=None) -> Tensor
+Tests if each element of :attr:`input` is negative infinity or not.
+
+Args:
+  {input}
+
+Keyword args:
+  {out}
+
+Example::
+    >>> a = torch.tensor([-float('inf'), float('inf'), 1.2])
+    >>> torch.isneginf(a)
+    tensor([ True, False, False])
+""".format(**common_args))
+
 add_docstr(torch.isclose, r"""
 isclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False) -> Tensor