Implementing NumPy-like function torch.heaviside() (#42523)

Summary:
- Related with https://github.com/pytorch/pytorch/issues/38349
- Implementing the NumPy-like function `torch.heaviside()` .

Pull Request resolved: https://github.com/pytorch/pytorch/pull/42523

Reviewed By: glaringlee

Differential Revision: D23391941

Pulled By: mruberry

fbshipit-source-id: 7b942321a62567a5fc0a3679a289f4c4c19e6134
diff --git a/aten/src/ATen/core/aten_interned_strings.h b/aten/src/ATen/core/aten_interned_strings.h
index e2d2cd3..9715a61 100644
--- a/aten/src/ATen/core/aten_interned_strings.h
+++ b/aten/src/ATen/core/aten_interned_strings.h
@@ -367,6 +367,7 @@
 _(aten, hardtanh) \
 _(aten, hardtanh_backward) \
 _(aten, hardtanh_forward) \
+_(aten, heaviside) \
 _(aten, hinge_embedding_loss) \
 _(aten, histc) \
 _(aten, hspmm) \
diff --git a/aten/src/ATen/native/BinaryOps.cpp b/aten/src/ATen/native/BinaryOps.cpp
index ea12fd8..9472cd6 100644
--- a/aten/src/ATen/native/BinaryOps.cpp
+++ b/aten/src/ATen/native/BinaryOps.cpp
@@ -47,6 +47,7 @@
 DEFINE_DISPATCH(lcm_stub);
 DEFINE_DISPATCH(hypot_stub);
 DEFINE_DISPATCH(nextafter_stub);
+DEFINE_DISPATCH(heaviside_stub);
 
 Tensor& add_out(Tensor& result, const Tensor& self, const Tensor& other, Scalar alpha) {
   auto iter = TensorIterator::binary_op(result, self, other);
@@ -931,6 +932,33 @@
   return self - (other * alpha);
 }
 
+Tensor& heaviside_out(Tensor& result, const Tensor& self, const Tensor& values) {
+  TORCH_CHECK(!self.is_complex() && !result.is_complex() && !values.is_complex(),
+              "heaviside is not yet implemented for complex tensors.");
+  TORCH_CHECK(self.dtype() == values.dtype() &&  result.dtype() == self.dtype(),
+              "heaviside is not yet implemented for tensors with different dtypes.");
+
+  auto iter = TensorIterator::binary_op(result, self, values, /*check_mem_overlap=*/true);
+  heaviside_stub(iter.device_type(), iter);
+  return result;
+}
+
+Tensor heaviside(const Tensor& self, const Tensor& values) {
+  TORCH_CHECK(!self.is_complex() && !values.is_complex(),
+              "heaviside is not yet implemented for complex tensors.");
+  TORCH_CHECK(self.dtype() == values.dtype(),
+              "heaviside is not yet implemented for tensors with different dtypes.");
+
+  Tensor result;
+  auto iter = TensorIterator::binary_op(result, self, values);
+  heaviside_stub(iter.device_type(), iter);
+  return iter.output();
+}
+
+Tensor& heaviside_(Tensor& self, const Tensor& values) {
+  return at::heaviside_out(self, self, values);
+}
+
 // TODO: Deduplicate this with the TensorIterator logic.  This would
 // also fix the TODOs below.
 Tensor binary_op_meta(const Tensor& self, const Tensor& other) {
diff --git a/aten/src/ATen/native/BinaryOps.h b/aten/src/ATen/native/BinaryOps.h
index d9da3a0..e2dad35 100644
--- a/aten/src/ATen/native/BinaryOps.h
+++ b/aten/src/ATen/native/BinaryOps.h
@@ -67,5 +67,6 @@
 DECLARE_DISPATCH(binary_fn, lcm_stub);
 DECLARE_DISPATCH(binary_fn, hypot_stub);
 DECLARE_DISPATCH(binary_fn, nextafter_stub);
+DECLARE_DISPATCH(binary_fn, heaviside_stub);
 
 }} // namespace at::native
diff --git a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
index 946c0a7..1bafcb7 100644
--- a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
+++ b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
@@ -764,6 +764,14 @@
   });
 }
 
+void heaviside_kernel(TensorIterator& iter) {
+  AT_DISPATCH_ALL_TYPES_AND3(kHalf, kBool, kBFloat16, iter.dtype(), "heaviside_cpu", [&]() {
+    cpu_kernel(iter, [](scalar_t a, scalar_t b) -> scalar_t {
+        return a == 0 ? b : static_cast<scalar_t>(a > 0);
+    });
+  });
+}
+
 } // namespace
 
 REGISTER_DISPATCH(add_stub, &add_kernel);
@@ -802,6 +810,7 @@
 REGISTER_DISPATCH(lcm_stub, &lcm_kernel);
 REGISTER_DISPATCH(hypot_stub, &hypot_kernel);
 REGISTER_DISPATCH(nextafter_stub, &nextafter_kernel);
+REGISTER_DISPATCH(heaviside_stub, &heaviside_kernel);
 
 } // namespace native
 } // namespace at
diff --git a/aten/src/ATen/native/cuda/BinaryMiscOpsKernels.cu b/aten/src/ATen/native/cuda/BinaryMiscOpsKernels.cu
index 2f1b554..4514083 100644
--- a/aten/src/ATen/native/cuda/BinaryMiscOpsKernels.cu
+++ b/aten/src/ATen/native/cuda/BinaryMiscOpsKernels.cu
@@ -101,6 +101,14 @@
   });
 }
 
+void heaviside_kernel_cuda(TensorIterator& iter) {
+  AT_DISPATCH_ALL_TYPES_AND3(kHalf, kBool, kBFloat16, iter.dtype(), "heaviside_cuda", [&]() {
+    gpu_kernel(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
+      return a == 0 ? b : static_cast<scalar_t>(a > 0);
+    });
+  });
+}
+
 REGISTER_DISPATCH(atan2_stub, &atan2_kernel_cuda);
 REGISTER_DISPATCH(smooth_l1_stub, &smooth_l1_kernel_cuda);
 REGISTER_DISPATCH(mse_stub, &mse_kernel_cuda);
@@ -110,5 +118,6 @@
 REGISTER_DISPATCH(lcm_stub, &lcm_kernel_cuda);
 REGISTER_DISPATCH(hypot_stub, &hypot_kernel_cuda);
 REGISTER_DISPATCH(nextafter_stub, &nextafter_kernel_cuda);
+REGISTER_DISPATCH(heaviside_stub, &heaviside_kernel_cuda);
 
 }} // namespace at::native
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 9fba236..5020fb3 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -3613,6 +3613,17 @@
   use_c10_dispatcher: full
   variants: function
 
+- func: heaviside.out(Tensor self, Tensor values, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CPU, CUDA: heaviside_out
+
+- func: heaviside(Tensor self, Tensor values) -> Tensor
+  use_c10_dispatcher: full
+  variants: function, method
+
+- func: heaviside_(Tensor(a!) self, Tensor values) -> Tensor(a!)
+  variants: method
+
 # For C++ only, until we have conversion from C++ numbers to Tensor
 - func: rsub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor
   use_c10_dispatcher: full
diff --git a/docs/source/tensors.rst b/docs/source/tensors.rst
index b4ad275..8693c07 100644
--- a/docs/source/tensors.rst
+++ b/docs/source/tensors.rst
@@ -332,6 +332,7 @@
    .. automethod:: gt_
    .. automethod:: half
    .. automethod:: hardshrink
+   .. automethod:: heaviside
    .. automethod:: histc
    .. automethod:: hypot
    .. automethod:: hypot_
diff --git a/docs/source/torch.rst b/docs/source/torch.rst
index fa03b90..9abcde9 100644
--- a/docs/source/torch.rst
+++ b/docs/source/torch.rst
@@ -75,6 +75,7 @@
     dequantize
     complex
     polar
+    heaviside
 
 Indexing, Slicing, Joining, Mutating Ops
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
diff --git a/test/test_torch.py b/test/test_torch.py
index 5d3cb7f..d62cf74 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -6261,6 +6261,71 @@
                          torch.bitwise_xor(torch.tensor([True, True, False], device=device),
                                            torch.tensor([False, True, False], device=device)))
 
+    @onlyOnCPUAndCUDA
+    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
+    @dtypes(*list(product(torch.testing.get_all_dtypes(include_complex=False),
+                          torch.testing.get_all_dtypes(include_complex=False))))
+    def test_heaviside(self, device, dtypes):
+        input_dtype = dtypes[0]
+        values_dtype = dtypes[1]
+
+        rng = np.random.default_rng()
+        input = np.array(rng.integers(-10, 10, size=10),
+                         dtype=torch_to_numpy_dtype_dict[input_dtype if (input_dtype != torch.bfloat16) else torch.float64])
+        input[0] = input[3] = input[7] = 0
+        values = np.array(rng.integers(-10, 10, size=10),
+                          dtype=torch_to_numpy_dtype_dict[values_dtype if (values_dtype != torch.bfloat16) else torch.float64])
+        np_result = torch.from_numpy(np.heaviside(input, values)).to(device=device, dtype=input_dtype)
+
+        input = torch.from_numpy(input).to(device=device, dtype=input_dtype)
+        values = torch.from_numpy(values).to(device=device, dtype=values_dtype)
+        out = torch.empty_like(input)
+
+        if input_dtype == values_dtype:
+            torch_result = torch.heaviside(input, values)
+            self.assertEqual(np_result, torch_result)
+
+            torch_result = input.heaviside(values)
+            self.assertEqual(np_result, torch_result)
+
+            torch.heaviside(input, values, out=out)
+            self.assertEqual(np_result, out)
+
+            input.heaviside_(values)
+            self.assertEqual(np_result, input)
+        else:
+            with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for tensors with different dtypes.'):
+                torch.heaviside(input, values)
+            with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for tensors with different dtypes.'):
+                input.heaviside(values)
+            with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for tensors with different dtypes.'):
+                torch.heaviside(input, values, out=out)
+            with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for tensors with different dtypes.'):
+                input.heaviside_(values)
+
+
+    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
+    @dtypes(*list(product(torch.testing.get_all_complex_dtypes(),
+                          torch.testing.get_all_complex_dtypes())))
+    def test_heaviside_complex(self, device, dtypes):
+        input_dtype = dtypes[0]
+        values_dtype = dtypes[1]
+
+        data = (complex(0, -6), complex(-1, 3), complex(1, 1))
+        input = torch.tensor(data, device=device, dtype=input_dtype)
+        values = torch.tensor(data, device=device, dtype=values_dtype)
+        out = torch.empty_like(input)
+        real = input.real
+
+        with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for complex tensors.'):
+            torch.heaviside(input, real)
+        with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for complex tensors.'):
+            real.heaviside(values)
+        with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for complex tensors.'):
+            input.heaviside_(values)
+        with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for complex tensors.'):
+            torch.heaviside(real, real, out=out)
+
     @unittest.skipIf(not TEST_NUMPY, 'Numpy not found')
     @dtypes(*torch.testing.get_all_dtypes())
     def test_logical_not(self, device, dtype):
diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py
index 18cd382..e747127 100644
--- a/torch/_tensor_docs.py
+++ b/torch/_tensor_docs.py
@@ -1535,6 +1535,20 @@
 See :func:`torch.nn.functional.hardshrink`
 """)
 
+add_docstr_all('heaviside',
+               r"""
+heaviside(values) -> Tensor
+
+See :func:`torch.heaviside`
+""")
+
+add_docstr_all('heaviside_',
+               r"""
+heaviside_(values) -> Tensor
+
+In-place version of :meth:`~Tensor.heaviside`
+""")
+
 add_docstr_all('histc',
                r"""
 histc(bins=100, min=0, max=0) -> Tensor
diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py
index 13c2be2..157ef73 100644
--- a/torch/_torch_docs.py
+++ b/torch/_torch_docs.py
@@ -5591,6 +5591,40 @@
 
 """.format(**common_args))
 
+add_docstr(torch.heaviside,
+           r"""
+heaviside(input, values, *, out=None) -> Tensor
+
+Computes the Heaviside step function for each element in :attr:`input`.
+The Heaviside step function is defined as:
+
+.. math::
+    \text{{heaviside}}(input, values) = \begin{cases}
+        \0, & \text{if input < 0}\\
+        \values, & \text{if input == 0}\\
+        \1, & \text{if input > 0}
+    \end{cases}
+""" + r"""
+
+Args:
+    {input}
+    values (Tensor): The values to use where :attr:`input` is zero.
+
+Keyword arguments:
+    {out}
+
+Example::
+
+    >>> input = torch.tensor([-1.5, 0, 2.0])
+    >>> values = torch.tensor([0.5])
+    >>> torch.heaviside(input, values)
+    tensor([0.0000, 0.5000, 1.0000])
+    >>> values = torch.tensor([1.2, -2.0, 3.5])
+    >>> torch.heaviside(input, values)
+    tensor([0., -2., 1.])
+
+""".format(**common_args))
+
 add_docstr(torch.rand,
            r"""
 rand(*size, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor
diff --git a/torch/overrides.py b/torch/overrides.py
index f0e6599..ea9956a 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -377,6 +377,7 @@
         torch.gru_cell: lambda input, hx, w_ih, w_hh, b_ih=None, b_hh=None: -1,
         torch.gt: lambda input, other, out=None: -1,
         torch.hardshrink: lambda input, lambd=0.5: -1,
+        torch.heaviside: lambda input, values, out=None: -1,
         torch.hinge_embedding_loss: lambda input, target, margin=1.0, size_average=None, reduce=None, reduction='mean': -1,
         torch.histc: lambda input, bins=100, min=0, max=0, out=None: -1,
         torch.hspmm: lambda mat1, mat2, out=None: -1,