add bfloat16 impl for nextafter (#61829)

Summary:
Add `BFloat16` support for `nextafter`.

* [x] Add OpInfo
* [x] Add Implementation Test (C++ tests)
* [x] Add credit

Pull Request resolved: https://github.com/pytorch/pytorch/pull/61829

Reviewed By: ejguan

Differential Revision: D29932498

Pulled By: mruberry

fbshipit-source-id: 89524531a4800569ba1addd08a4ace330a6f72a4
diff --git a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
index 4c3954e..2a8f73c 100644
--- a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
+++ b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
@@ -929,7 +929,15 @@
 }
 
 void nextafter_kernel(TensorIteratorBase& iter) {
-  AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "nextafter_cpu", [&]() {
+  if (iter.common_dtype() == kBFloat16) {
+    using scalar_t = c10::BFloat16;
+    cpu_kernel(
+        iter,
+        [=](scalar_t a, scalar_t b) -> scalar_t {
+            return std::nextafter(a, b);
+        });
+  } else {
+    AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "nextafter_cpu", [&]() {
     cpu_kernel_vec(
         iter,
         [=](scalar_t a, scalar_t b) -> scalar_t {
@@ -939,6 +947,7 @@
             return a.nextafter(b);
         });
   });
+  }
 }
 
 void heaviside_kernel(TensorIteratorBase& iter) {
diff --git a/aten/src/ATen/native/cuda/StepKernel.cu b/aten/src/ATen/native/cuda/StepKernel.cu
index 25161d7..94472b4 100644
--- a/aten/src/ATen/native/cuda/StepKernel.cu
+++ b/aten/src/ATen/native/cuda/StepKernel.cu
@@ -3,6 +3,7 @@
 #include <ATen/native/cuda/Loops.cuh>
 #include <ATen/native/TensorIterator.h>
 #include <ATen/native/BinaryOps.h>
+#include <c10/util/BFloat16-math.h>
 
 // NOTE: CUDA on Windows requires that the enclosing function
 // of a __device__ lambda not have internal linkage.
@@ -10,9 +11,9 @@
 namespace at { namespace native {
 
 void nextafter_kernel_cuda(TensorIteratorBase& iter) {
-  AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "nextafter_cuda", [&]() {
+  AT_DISPATCH_FLOATING_TYPES_AND(kBFloat16, iter.common_dtype(), "nextafter_cuda", [&]() {
     gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
-      return ::nextafter(a, b);
+      return std::nextafter(a, b);
     });
   });
 }
diff --git a/c10/test/util/bfloat16_test.cpp b/c10/test/util/bfloat16_test.cpp
index 54f03e8..2cce81f 100644
--- a/c10/test/util/bfloat16_test.cpp
+++ b/c10/test/util/bfloat16_test.cpp
@@ -1,4 +1,7 @@
+// clang-format off
 #include <c10/util/BFloat16.h>
+#include <c10/util/BFloat16-math.h>
+// clang-format on
 #include <gtest/gtest.h>
 
 namespace {
@@ -139,6 +142,22 @@
   EXPECT_EQ(res, expected);
 }
 
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+TEST(BFloat16Math, NextAfterZero) {
+  const c10::BFloat16 zero{0};
+
+  auto check_nextafter =
+      [](c10::BFloat16 from, c10::BFloat16 to, c10::BFloat16 expected) {
+        c10::BFloat16 actual = std::nextafter(from, to);
+        // Check for bitwise equality!
+        ASSERT_EQ(actual.x ^ expected.x, uint16_t{0});
+      };
+  check_nextafter(zero, zero, /*expected=*/zero);
+  check_nextafter(zero, -zero, /*expected=*/-zero);
+  check_nextafter(-zero, zero, /*expected=*/zero);
+  check_nextafter(-zero, -zero, /*expected=*/-zero);
+}
+
 float BinaryToFloat(uint32_t bytes) {
   // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
   float res;
diff --git a/c10/util/BFloat16-math.h b/c10/util/BFloat16-math.h
index ac625ba..2760100 100644
--- a/c10/util/BFloat16-math.h
+++ b/c10/util/BFloat16-math.h
@@ -91,4 +91,81 @@
   return std::fmod(float(a), float(b));
 }
 
+/*
+  The following function is inspired from the implementation in `musl`
+  Link to License: https://git.musl-libc.org/cgit/musl/tree/COPYRIGHT
+  ----------------------------------------------------------------------
+  Copyright © 2005-2020 Rich Felker, et al.
+
+  Permission is hereby granted, free of charge, to any person obtaining
+  a copy of this software and associated documentation files (the
+  "Software"), to deal in the Software without restriction, including
+  without limitation the rights to use, copy, modify, merge, publish,
+  distribute, sublicense, and/or sell copies of the Software, and to
+  permit persons to whom the Software is furnished to do so, subject to
+  the following conditions:
+
+  The above copyright notice and this permission notice shall be
+  included in all copies or substantial portions of the Software.
+
+  THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+  EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+  MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
+  IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
+  CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
+  TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
+  SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+  ----------------------------------------------------------------------
+ */
+C10_HOST_DEVICE inline c10::BFloat16 nextafter(
+    c10::BFloat16 from,
+    c10::BFloat16 to) {
+  // Reference:
+  // https://git.musl-libc.org/cgit/musl/tree/src/math/nextafter.c
+  using int_repr_t = uint16_t;
+  using float_t = c10::BFloat16;
+  constexpr uint8_t bits = 16;
+  union {
+    float_t f;
+    int_repr_t i;
+  } ufrom = {from}, uto = {to};
+
+  // get a mask to get the sign bit i.e. MSB
+  int_repr_t sign_mask = int_repr_t{1} << (bits - 1);
+
+  // short-circuit: if either is NaN, return NaN
+  if (from != from || to != to) {
+    return from + to;
+  }
+
+  // short-circuit: if they are exactly the same.
+  if (ufrom.i == uto.i) {
+    return from;
+  }
+
+  // mask the sign-bit to zero i.e. positive
+  // equivalent to abs(x)
+  int_repr_t abs_from = ufrom.i & ~sign_mask;
+  int_repr_t abs_to = uto.i & ~sign_mask;
+  if (abs_from == 0) {
+    // if both are zero but with different sign,
+    // preserve the sign of `to`.
+    if (abs_to == 0) {
+      return to;
+    }
+    // smallest subnormal with sign of `to`.
+    ufrom.i = (uto.i & sign_mask) | int_repr_t{1};
+    return ufrom.f;
+  }
+
+  // if abs(from) > abs(to) or sign(from) != sign(to)
+  if (abs_from > abs_to || ((ufrom.i ^ uto.i) & sign_mask)) {
+    ufrom.i--;
+  } else {
+    ufrom.i++;
+  }
+
+  return ufrom.f;
+}
+
 } // namespace std
diff --git a/test/test_binary_ufuncs.py b/test/test_binary_ufuncs.py
index ff8f2f9..3de7f2b 100644
--- a/test/test_binary_ufuncs.py
+++ b/test/test_binary_ufuncs.py
@@ -1798,6 +1798,49 @@
         expected = np.nextafter(a.cpu().numpy(), b.cpu().numpy())
         self.assertEqual(actual, expected, atol=0, rtol=0)
 
+    @onlyOnCPUAndCUDA
+    @dtypes(torch.bfloat16)
+    def test_nextafter_bfloat16(self, device, dtype):
+        nan = float('nan')
+        inf = float('inf')
+        cases = (
+            # (from, to, expected)
+            (0, 1, 9.183549615799121e-41),
+            (0, -1, -9.183549615799121e-41),
+            (1, -2, 0.99609375),
+            (1, 0, 0.99609375),
+            (1, 2, 1.0078125),
+            (-1, -2, -1.0078125),
+            (-1, 0, -0.99609375),
+            (2, -1, 1.9921875),
+            (2, 1, 1.9921875),
+            (20, 3000, 20.125),
+            (20, -3000, 19.875),
+            (3000, -20, 2992.0),
+            (-3000, 20, -2992.0),
+            (65536, 0, 65280.0) ,
+            (65536, inf, 66048.0),
+            (-65536, 0, -65280.0),
+            (-65536, -inf, -66048.0),
+            (nan, 0, nan),
+            (0, nan, nan),
+            (nan, nan, nan),
+            (nan, inf, nan),
+            (inf, nan, nan),
+            (inf, -inf, 3.3895313892515355e+38),
+            (-inf, inf, -3.3895313892515355e+38),
+            (inf, 0, 3.3895313892515355e+38),
+            (0, inf, 9.183549615799121e-41),
+            (-inf, 0, -3.3895313892515355e+38),
+            (0, -inf, -9.183549615799121e-41),
+        )
+
+        for from_v, to_v, expected in cases:
+            from_t = torch.tensor([from_v], device=device, dtype=dtype)
+            to_t = torch.tensor([to_v], device=device, dtype=dtype)
+            actual = torch.nextafter(from_t, to_t).item()
+            self.assertEqual(actual, expected, atol=0, rtol=0)
+
     def _test_cop(self, torchfn, mathfn, dtype, device):
         def reference_implementation(res2):
             for i, j in iter_indices(sm1):
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 91138d4..e4e6051 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -3626,6 +3626,23 @@
 
     return list(sample_generator())
 
+
+def sample_inputs_nextafter(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
+
+    cases = (
+        ((S, S), (S, S), False),
+        ((S, S), (S,), False),
+        ((S, ), (S, S), True)
+    )
+
+    def generator():
+        for shape, other_shape, broadcasts_input in cases:
+            yield SampleInput(make_arg(shape), args=(make_arg(other_shape),), broadcasts_input=broadcasts_input)
+
+    return list(generator())
+
+
 def sample_inputs_diag(op_info, device, dtype, requires_grad, **kwargs):
     vec_sample = SampleInput(make_tensor((M, ), device, dtype, low=None, high=None, requires_grad=requires_grad))
 
@@ -6488,6 +6505,10 @@
                 'TestUnaryUfuncs', 'test_reference_numerics_extremal'),
         ],
     ),
+    OpInfo('nextafter',
+           dtypes=floating_types_and(torch.bfloat16),
+           supports_autograd=False,
+           sample_inputs_func=sample_inputs_nextafter),
     OpInfo('topk',
            dtypes=all_types(),
            dtypesIfCPU=all_types_and(torch.bfloat16),