Implementing NumPy-like function torch.heaviside() (#42523)
Summary:
- Related to https://github.com/pytorch/pytorch/issues/38349
- Implements the NumPy-like function `torch.heaviside()`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/42523
Reviewed By: glaringlee
Differential Revision: D23391941
Pulled By: mruberry
fbshipit-source-id: 7b942321a62567a5fc0a3679a289f4c4c19e6134
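For context, `torch.heaviside(input, values)` mirrors `np.heaviside`: it returns 0 where `input` is negative, the matching element of `values` where `input` is exactly zero, and 1 where `input` is positive. A quick usage sketch (it simply mirrors the docstring example added below):

    >>> import torch
    >>> input = torch.tensor([-1.5, 0, 2.0])
    >>> torch.heaviside(input, torch.tensor([0.5]))
    tensor([0.0000, 0.5000, 1.0000])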
diff --git a/aten/src/ATen/core/aten_interned_strings.h b/aten/src/ATen/core/aten_interned_strings.h
index e2d2cd3..9715a61 100644
--- a/aten/src/ATen/core/aten_interned_strings.h
+++ b/aten/src/ATen/core/aten_interned_strings.h
@@ -367,6 +367,7 @@
_(aten, hardtanh) \
_(aten, hardtanh_backward) \
_(aten, hardtanh_forward) \
+_(aten, heaviside) \
_(aten, hinge_embedding_loss) \
_(aten, histc) \
_(aten, hspmm) \
diff --git a/aten/src/ATen/native/BinaryOps.cpp b/aten/src/ATen/native/BinaryOps.cpp
index ea12fd8..9472cd6 100644
--- a/aten/src/ATen/native/BinaryOps.cpp
+++ b/aten/src/ATen/native/BinaryOps.cpp
@@ -47,6 +47,7 @@
DEFINE_DISPATCH(lcm_stub);
DEFINE_DISPATCH(hypot_stub);
DEFINE_DISPATCH(nextafter_stub);
+DEFINE_DISPATCH(heaviside_stub);
Tensor& add_out(Tensor& result, const Tensor& self, const Tensor& other, Scalar alpha) {
auto iter = TensorIterator::binary_op(result, self, other);
@@ -931,6 +932,33 @@
return self - (other * alpha);
}
+Tensor& heaviside_out(Tensor& result, const Tensor& self, const Tensor& values) {
+  TORCH_CHECK(!self.is_complex() && !result.is_complex() && !values.is_complex(),
+              "heaviside is not yet implemented for complex tensors.");
+  TORCH_CHECK(self.dtype() == values.dtype() && result.dtype() == self.dtype(),
+              "heaviside is not yet implemented for tensors with different dtypes.");
+
+  auto iter = TensorIterator::binary_op(result, self, values, /*check_mem_overlap=*/true);
+  heaviside_stub(iter.device_type(), iter);
+  return result;
+}
+
+Tensor heaviside(const Tensor& self, const Tensor& values) {
+  TORCH_CHECK(!self.is_complex() && !values.is_complex(),
+              "heaviside is not yet implemented for complex tensors.");
+  TORCH_CHECK(self.dtype() == values.dtype(),
+              "heaviside is not yet implemented for tensors with different dtypes.");
+
+  Tensor result;
+  auto iter = TensorIterator::binary_op(result, self, values);
+  heaviside_stub(iter.device_type(), iter);
+  return iter.output();
+}
+
+Tensor& heaviside_(Tensor& self, const Tensor& values) {
+  return at::heaviside_out(self, self, values);
+}
+
// TODO: Deduplicate this with the TensorIterator logic. This would
// also fix the TODOs below.
Tensor binary_op_meta(const Tensor& self, const Tensor& other) {
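For orientation, the three entry points above relate as follows at the Python level: the functional form allocates its result, the `out=` form writes into a preallocated tensor via `heaviside_out`, and the in-place method reuses `self` as both input and output by delegating to `at::heaviside_out(self, self, values)`. A rough sketch of that equivalence (illustrative, not part of the patch):

    import torch

    x = torch.tensor([-1.0, 0.0, 3.0])
    v = torch.tensor([0.5, 0.5, 0.5])

    y = torch.heaviside(x, v)        # functional: allocates the result
    out = torch.empty_like(x)
    torch.heaviside(x, v, out=out)   # out= variant: fills a preallocated tensor
    x.heaviside_(v)                  # in-place: routes through heaviside_out(self, self, values)
    assert torch.equal(y, out) and torch.equal(y, x)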
diff --git a/aten/src/ATen/native/BinaryOps.h b/aten/src/ATen/native/BinaryOps.h
index d9da3a0..e2dad35 100644
--- a/aten/src/ATen/native/BinaryOps.h
+++ b/aten/src/ATen/native/BinaryOps.h
@@ -67,5 +67,6 @@
DECLARE_DISPATCH(binary_fn, lcm_stub);
DECLARE_DISPATCH(binary_fn, hypot_stub);
DECLARE_DISPATCH(binary_fn, nextafter_stub);
+DECLARE_DISPATCH(binary_fn, heaviside_stub);
}} // namespace at::native
diff --git a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
index 946c0a7..1bafcb7 100644
--- a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
+++ b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
@@ -764,6 +764,14 @@
});
}
+void heaviside_kernel(TensorIterator& iter) {
+  AT_DISPATCH_ALL_TYPES_AND3(kHalf, kBool, kBFloat16, iter.dtype(), "heaviside_cpu", [&]() {
+    cpu_kernel(iter, [](scalar_t a, scalar_t b) -> scalar_t {
+      return a == 0 ? b : static_cast<scalar_t>(a > 0);
+    });
+  });
+}
+
} // namespace
REGISTER_DISPATCH(add_stub, &add_kernel);
@@ -802,6 +810,7 @@
REGISTER_DISPATCH(lcm_stub, &lcm_kernel);
REGISTER_DISPATCH(hypot_stub, &hypot_kernel);
REGISTER_DISPATCH(nextafter_stub, &nextafter_kernel);
+REGISTER_DISPATCH(heaviside_stub, &heaviside_kernel);
} // namespace native
} // namespace at
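The CPU lambda above (and the identical CUDA lambda below) encodes the whole elementwise rule. A rough Python equivalent, for reference only:

    def heaviside_scalar(a, b):
        # b (the 'values' element) is used only where the input is exactly zero;
        # otherwise the result is the 0/1 step of a.
        return b if a == 0 else (1 if a > 0 else 0)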
diff --git a/aten/src/ATen/native/cuda/BinaryMiscOpsKernels.cu b/aten/src/ATen/native/cuda/BinaryMiscOpsKernels.cu
index 2f1b554..4514083 100644
--- a/aten/src/ATen/native/cuda/BinaryMiscOpsKernels.cu
+++ b/aten/src/ATen/native/cuda/BinaryMiscOpsKernels.cu
@@ -101,6 +101,14 @@
});
}
+void heaviside_kernel_cuda(TensorIterator& iter) {
+  AT_DISPATCH_ALL_TYPES_AND3(kHalf, kBool, kBFloat16, iter.dtype(), "heaviside_cuda", [&]() {
+    gpu_kernel(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
+      return a == 0 ? b : static_cast<scalar_t>(a > 0);
+    });
+  });
+}
+
REGISTER_DISPATCH(atan2_stub, &atan2_kernel_cuda);
REGISTER_DISPATCH(smooth_l1_stub, &smooth_l1_kernel_cuda);
REGISTER_DISPATCH(mse_stub, &mse_kernel_cuda);
@@ -110,5 +118,6 @@
REGISTER_DISPATCH(lcm_stub, &lcm_kernel_cuda);
REGISTER_DISPATCH(hypot_stub, &hypot_kernel_cuda);
REGISTER_DISPATCH(nextafter_stub, &nextafter_kernel_cuda);
+REGISTER_DISPATCH(heaviside_stub, &heaviside_kernel_cuda);
}} // namespace at::native
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 9fba236..5020fb3 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -3613,6 +3613,17 @@
use_c10_dispatcher: full
variants: function
+- func: heaviside.out(Tensor self, Tensor values, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CPU, CUDA: heaviside_out
+
+- func: heaviside(Tensor self, Tensor values) -> Tensor
+  use_c10_dispatcher: full
+  variants: function, method
+
+- func: heaviside_(Tensor(a!) self, Tensor values) -> Tensor(a!)
+  variants: method
+
# For C++ only, until we have conversion from C++ numbers to Tensor
- func: rsub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor
use_c10_dispatcher: full
diff --git a/docs/source/tensors.rst b/docs/source/tensors.rst
index b4ad275..8693c07 100644
--- a/docs/source/tensors.rst
+++ b/docs/source/tensors.rst
@@ -332,6 +332,7 @@
.. automethod:: gt_
.. automethod:: half
.. automethod:: hardshrink
+ .. automethod:: heaviside
.. automethod:: histc
.. automethod:: hypot
.. automethod:: hypot_
diff --git a/docs/source/torch.rst b/docs/source/torch.rst
index fa03b90..9abcde9 100644
--- a/docs/source/torch.rst
+++ b/docs/source/torch.rst
@@ -75,6 +75,7 @@
dequantize
complex
polar
+ heaviside
Indexing, Slicing, Joining, Mutating Ops
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
diff --git a/test/test_torch.py b/test/test_torch.py
index 5d3cb7f..d62cf74 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -6261,6 +6261,71 @@
torch.bitwise_xor(torch.tensor([True, True, False], device=device),
torch.tensor([False, True, False], device=device)))
+    @onlyOnCPUAndCUDA
+    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
+    @dtypes(*list(product(torch.testing.get_all_dtypes(include_complex=False),
+                          torch.testing.get_all_dtypes(include_complex=False))))
+    def test_heaviside(self, device, dtypes):
+        input_dtype = dtypes[0]
+        values_dtype = dtypes[1]
+
+        rng = np.random.default_rng()
+        input = np.array(rng.integers(-10, 10, size=10),
+                         dtype=torch_to_numpy_dtype_dict[input_dtype if (input_dtype != torch.bfloat16) else torch.float64])
+        input[0] = input[3] = input[7] = 0
+        values = np.array(rng.integers(-10, 10, size=10),
+                          dtype=torch_to_numpy_dtype_dict[values_dtype if (values_dtype != torch.bfloat16) else torch.float64])
+        np_result = torch.from_numpy(np.heaviside(input, values)).to(device=device, dtype=input_dtype)
+
+        input = torch.from_numpy(input).to(device=device, dtype=input_dtype)
+        values = torch.from_numpy(values).to(device=device, dtype=values_dtype)
+        out = torch.empty_like(input)
+
+        if input_dtype == values_dtype:
+            torch_result = torch.heaviside(input, values)
+            self.assertEqual(np_result, torch_result)
+
+            torch_result = input.heaviside(values)
+            self.assertEqual(np_result, torch_result)
+
+            torch.heaviside(input, values, out=out)
+            self.assertEqual(np_result, out)
+
+            input.heaviside_(values)
+            self.assertEqual(np_result, input)
+        else:
+            with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for tensors with different dtypes.'):
+                torch.heaviside(input, values)
+            with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for tensors with different dtypes.'):
+                input.heaviside(values)
+            with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for tensors with different dtypes.'):
+                torch.heaviside(input, values, out=out)
+            with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for tensors with different dtypes.'):
+                input.heaviside_(values)
+
+
+    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
+    @dtypes(*list(product(torch.testing.get_all_complex_dtypes(),
+                          torch.testing.get_all_complex_dtypes())))
+    def test_heaviside_complex(self, device, dtypes):
+        input_dtype = dtypes[0]
+        values_dtype = dtypes[1]
+
+        data = (complex(0, -6), complex(-1, 3), complex(1, 1))
+        input = torch.tensor(data, device=device, dtype=input_dtype)
+        values = torch.tensor(data, device=device, dtype=values_dtype)
+        out = torch.empty_like(input)
+        real = input.real
+
+        with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for complex tensors.'):
+            torch.heaviside(input, real)
+        with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for complex tensors.'):
+            real.heaviside(values)
+        with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for complex tensors.'):
+            input.heaviside_(values)
+        with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for complex tensors.'):
+            torch.heaviside(real, real, out=out)
+
@unittest.skipIf(not TEST_NUMPY, 'Numpy not found')
@dtypes(*torch.testing.get_all_dtypes())
def test_logical_not(self, device, dtype):
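One detail worth noting in `test_heaviside` above: NumPy has no bfloat16 dtype, so the bfloat16 cases build the reference arrays as float64 and only cast on the torch side, along the lines of (illustrative):

    import numpy as np
    import torch

    a = np.array([0.0, -1.0, 2.0], dtype=np.float64)  # NumPy has no bfloat16
    t = torch.from_numpy(a).to(torch.bfloat16)        # cast after conversion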
diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py
index 18cd382..e747127 100644
--- a/torch/_tensor_docs.py
+++ b/torch/_tensor_docs.py
@@ -1535,6 +1535,20 @@
See :func:`torch.nn.functional.hardshrink`
""")
+add_docstr_all('heaviside',
+ r"""
+heaviside(values) -> Tensor
+
+See :func:`torch.heaviside`
+""")
+
+add_docstr_all('heaviside_',
+ r"""
+heaviside_(values) -> Tensor
+
+In-place version of :meth:`~Tensor.heaviside`
+""")
+
add_docstr_all('histc',
r"""
histc(bins=100, min=0, max=0) -> Tensor
diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py
index 13c2be2..157ef73 100644
--- a/torch/_torch_docs.py
+++ b/torch/_torch_docs.py
@@ -5591,6 +5591,40 @@
""".format(**common_args))
+add_docstr(torch.heaviside,
+ r"""
+heaviside(input, values, *, out=None) -> Tensor
+
+Computes the Heaviside step function for each element in :attr:`input`.
+The Heaviside step function is defined as:
+
+.. math::
+    \text{{heaviside}}(input, values) = \begin{cases}
+        0, & \text{if input < 0}\\
+        values, & \text{if input == 0}\\
+        1, & \text{if input > 0}
+    \end{cases}
+""" + r"""
+
+Args:
+    {input}
+    values (Tensor): The values to use where :attr:`input` is zero.
+
+Keyword arguments:
+    {out}
+
+Example::
+
+    >>> input = torch.tensor([-1.5, 0, 2.0])
+    >>> values = torch.tensor([0.5])
+    >>> torch.heaviside(input, values)
+    tensor([0.0000, 0.5000, 1.0000])
+    >>> values = torch.tensor([1.2, -2.0, 3.5])
+    >>> torch.heaviside(input, values)
+    tensor([0., -2., 1.])
+
+""".format(**common_args))
+
add_docstr(torch.rand,
r"""
rand(*size, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor
diff --git a/torch/overrides.py b/torch/overrides.py
index f0e6599..ea9956a 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -377,6 +377,7 @@
torch.gru_cell: lambda input, hx, w_ih, w_hh, b_ih=None, b_hh=None: -1,
torch.gt: lambda input, other, out=None: -1,
torch.hardshrink: lambda input, lambd=0.5: -1,
+ torch.heaviside: lambda input, values, out=None: -1,
torch.hinge_embedding_loss: lambda input, target, margin=1.0, size_average=None, reduce=None, reduction='mean': -1,
torch.histc: lambda input, bins=100, min=0, max=0, out=None: -1,
torch.hspmm: lambda mat1, mat2, out=None: -1,
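The `torch.overrides` entry registers `heaviside`'s signature so the new op participates in the `__torch_function__` override machinery. A minimal sketch of intercepting it, written against the current classmethod form of the protocol (the exact signature has evolved since this PR; the wrapper class here is illustrative, not part of the patch):

    import torch

    class Wrapper:
        # Illustrative tensor-like wrapper, not part of this patch.
        def __init__(self, t):
            self.t = t

        @classmethod
        def __torch_function__(cls, func, types, args=(), kwargs=None):
            kwargs = kwargs or {}
            # Unwrap Wrapper arguments, then run the underlying op.
            args = tuple(a.t if isinstance(a, Wrapper) else a for a in args)
            return func(*args, **kwargs)

    x = Wrapper(torch.tensor([-1.0, 0.0, 1.0]))
    v = Wrapper(torch.tensor([0.5, 0.5, 0.5]))
    print(torch.heaviside(x, v))  # tensor([0.0000, 0.5000, 1.0000])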