add bfloat16 impl for nextafter (#61829)
Summary:
Add `BFloat16` support for `nextafter`.
* [x] Add OpInfo
* [x] Add Implementation Test (C++ tests)
* [x] Add credit
Pull Request resolved: https://github.com/pytorch/pytorch/pull/61829
Reviewed By: ejguan
Differential Revision: D29932498
Pulled By: mruberry
fbshipit-source-id: 89524531a4800569ba1addd08a4ace330a6f72a4
diff --git a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
index 4c3954e..2a8f73c 100644
--- a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
+++ b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
@@ -929,7 +929,15 @@
}
void nextafter_kernel(TensorIteratorBase& iter) {
- AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "nextafter_cpu", [&]() {
+ if (iter.common_dtype() == kBFloat16) {
+ using scalar_t = c10::BFloat16;
+ cpu_kernel(
+ iter,
+ [=](scalar_t a, scalar_t b) -> scalar_t {
+ return std::nextafter(a, b);
+ });
+ } else {
+ AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "nextafter_cpu", [&]() {
cpu_kernel_vec(
iter,
[=](scalar_t a, scalar_t b) -> scalar_t {
@@ -939,6 +947,7 @@
return a.nextafter(b);
});
});
+ }
}
void heaviside_kernel(TensorIteratorBase& iter) {
diff --git a/aten/src/ATen/native/cuda/StepKernel.cu b/aten/src/ATen/native/cuda/StepKernel.cu
index 25161d7..94472b4 100644
--- a/aten/src/ATen/native/cuda/StepKernel.cu
+++ b/aten/src/ATen/native/cuda/StepKernel.cu
@@ -3,6 +3,7 @@
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/BinaryOps.h>
+#include <c10/util/BFloat16-math.h>
// NOTE: CUDA on Windows requires that the enclosing function
// of a __device__ lambda not have internal linkage.
@@ -10,9 +11,9 @@
namespace at { namespace native {
void nextafter_kernel_cuda(TensorIteratorBase& iter) {
- AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "nextafter_cuda", [&]() {
+ AT_DISPATCH_FLOATING_TYPES_AND(kBFloat16, iter.common_dtype(), "nextafter_cuda", [&]() {
gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
- return ::nextafter(a, b);
+ return std::nextafter(a, b);
});
});
}
diff --git a/c10/test/util/bfloat16_test.cpp b/c10/test/util/bfloat16_test.cpp
index 54f03e8..2cce81f 100644
--- a/c10/test/util/bfloat16_test.cpp
+++ b/c10/test/util/bfloat16_test.cpp
@@ -1,4 +1,7 @@
+// clang-format off
#include <c10/util/BFloat16.h>
+#include <c10/util/BFloat16-math.h>
+// clang-format on
#include <gtest/gtest.h>
namespace {
@@ -139,6 +142,22 @@
EXPECT_EQ(res, expected);
}
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+TEST(BFloat16Math, NextAfterZero) {
+ const c10::BFloat16 zero{0};
+
+ auto check_nextafter =
+ [](c10::BFloat16 from, c10::BFloat16 to, c10::BFloat16 expected) {
+ c10::BFloat16 actual = std::nextafter(from, to);
+ // Compare the raw bit patterns (XOR must be 0) so that the sign of a zero result is checked too.
+ ASSERT_EQ(actual.x ^ expected.x, uint16_t{0});
+ };
+ check_nextafter(zero, zero, /*expected=*/zero);
+ check_nextafter(zero, -zero, /*expected=*/-zero);
+ check_nextafter(-zero, zero, /*expected=*/zero);
+ check_nextafter(-zero, -zero, /*expected=*/-zero);
+}
+
float BinaryToFloat(uint32_t bytes) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
float res;
diff --git a/c10/util/BFloat16-math.h b/c10/util/BFloat16-math.h
index ac625ba..2760100 100644
--- a/c10/util/BFloat16-math.h
+++ b/c10/util/BFloat16-math.h
@@ -91,4 +91,81 @@
return std::fmod(float(a), float(b));
}
+/*
+ The following function is inspired by the implementation in `musl`
+ Link to License: https://git.musl-libc.org/cgit/musl/tree/COPYRIGHT
+ ----------------------------------------------------------------------
+ Copyright © 2005-2020 Rich Felker, et al.
+
+ Permission is hereby granted, free of charge, to any person obtaining
+ a copy of this software and associated documentation files (the
+ "Software"), to deal in the Software without restriction, including
+ without limitation the rights to use, copy, modify, merge, publish,
+ distribute, sublicense, and/or sell copies of the Software, and to
+ permit persons to whom the Software is furnished to do so, subject to
+ the following conditions:
+
+ The above copyright notice and this permission notice shall be
+ included in all copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+ EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+ MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
+ IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
+ CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
+ TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
+ SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+ ----------------------------------------------------------------------
+ */
+C10_HOST_DEVICE inline c10::BFloat16 nextafter(
+ c10::BFloat16 from,
+ c10::BFloat16 to) {
+ // Returns the BFloat16 value adjacent to `from` in the direction of `to`; modelled on:
+ // https://git.musl-libc.org/cgit/musl/tree/src/math/nextafter.c
+ using int_repr_t = uint16_t;
+ using float_t = c10::BFloat16;
+ constexpr uint8_t bits = 16;
+ union {
+ float_t f;
+ int_repr_t i;
+ } ufrom = {from}, uto = {to};
+
+ // mask selecting only the sign bit, i.e. the MSB of the 16-bit pattern
+ int_repr_t sign_mask = int_repr_t{1} << (bits - 1);
+
+ // short-circuit: a NaN compares unequal to itself; return `from + to`, which is expected to be NaN too
+ if (from != from || to != to) {
+ return from + to;
+ }
+
+ // short-circuit: identical bit patterns, so `from` is returned unchanged.
+ if (ufrom.i == uto.i) {
+ return from;
+ }
+
+ // clear the sign bit of each pattern to keep only the magnitude bits,
+ // i.e. the representation of abs(x)
+ int_repr_t abs_from = ufrom.i & ~sign_mask;
+ int_repr_t abs_to = uto.i & ~sign_mask;
+ if (abs_from == 0) {
+ // both magnitudes are zero but the patterns differ, so only the sign differs:
+ // return `to` to preserve its sign.
+ if (abs_to == 0) {
+ return to;
+ }
+ // magnitude bits 0x0001 (smallest subnormal) combined with the sign bit of `to`.
+ ufrom.i = (uto.i & sign_mask) | int_repr_t{1};
+ return ufrom.f;
+ }
+
+ // decrement the pattern (towards zero) if abs(from) > abs(to) or the signs differ; otherwise increment it
+ if (abs_from > abs_to || ((ufrom.i ^ uto.i) & sign_mask)) {
+ ufrom.i--;
+ } else {
+ ufrom.i++;
+ }
+
+ return ufrom.f;
+}
+
} // namespace std
diff --git a/test/test_binary_ufuncs.py b/test/test_binary_ufuncs.py
index ff8f2f9..3de7f2b 100644
--- a/test/test_binary_ufuncs.py
+++ b/test/test_binary_ufuncs.py
@@ -1798,6 +1798,49 @@
expected = np.nextafter(a.cpu().numpy(), b.cpu().numpy())
self.assertEqual(actual, expected, atol=0, rtol=0)
+ @onlyOnCPUAndCUDA
+ @dtypes(torch.bfloat16)
+ def test_nextafter_bfloat16(self, device, dtype):
+ nan = float('nan')
+ inf = float('inf')
+ cases = (
+ # (from, to, expected)
+ (0, 1, 9.183549615799121e-41),
+ (0, -1, -9.183549615799121e-41),
+ (1, -2, 0.99609375),
+ (1, 0, 0.99609375),
+ (1, 2, 1.0078125),
+ (-1, -2, -1.0078125),
+ (-1, 0, -0.99609375),
+ (2, -1, 1.9921875),
+ (2, 1, 1.9921875),
+ (20, 3000, 20.125),
+ (20, -3000, 19.875),
+ (3000, -20, 2992.0),
+ (-3000, 20, -2992.0),
+ (65536, 0, 65280.0) ,
+ (65536, inf, 66048.0),
+ (-65536, 0, -65280.0),
+ (-65536, -inf, -66048.0),
+ (nan, 0, nan),
+ (0, nan, nan),
+ (nan, nan, nan),
+ (nan, inf, nan),
+ (inf, nan, nan),
+ (inf, -inf, 3.3895313892515355e+38),
+ (-inf, inf, -3.3895313892515355e+38),
+ (inf, 0, 3.3895313892515355e+38),
+ (0, inf, 9.183549615799121e-41),
+ (-inf, 0, -3.3895313892515355e+38),
+ (0, -inf, -9.183549615799121e-41),
+ )
+
+ for from_v, to_v, expected in cases:
+ from_t = torch.tensor([from_v], device=device, dtype=dtype)
+ to_t = torch.tensor([to_v], device=device, dtype=dtype)
+ actual = torch.nextafter(from_t, to_t).item()
+ self.assertEqual(actual, expected, atol=0, rtol=0)
+
def _test_cop(self, torchfn, mathfn, dtype, device):
def reference_implementation(res2):
for i, j in iter_indices(sm1):
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 91138d4..e4e6051 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -3626,6 +3626,23 @@
return list(sample_generator())
+
+def sample_inputs_nextafter(op_info, device, dtype, requires_grad, **kwargs):
+ make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
+
+ cases = (
+ ((S, S), (S, S), False),
+ ((S, S), (S,), False),
+ ((S, ), (S, S), True)
+ )
+
+ def generator():
+ for shape, other_shape, broadcasts_input in cases:
+ yield SampleInput(make_arg(shape), args=(make_arg(other_shape),), broadcasts_input=broadcasts_input)
+
+ return list(generator())
+
+
def sample_inputs_diag(op_info, device, dtype, requires_grad, **kwargs):
vec_sample = SampleInput(make_tensor((M, ), device, dtype, low=None, high=None, requires_grad=requires_grad))
@@ -6488,6 +6505,10 @@
'TestUnaryUfuncs', 'test_reference_numerics_extremal'),
],
),
+ OpInfo('nextafter',
+ dtypes=floating_types_and(torch.bfloat16),
+ supports_autograd=False,
+ sample_inputs_func=sample_inputs_nextafter),
OpInfo('topk',
dtypes=all_types(),
dtypesIfCPU=all_types_and(torch.bfloat16),