[torch] Add cuda support for segment reduction 'max' (#54175)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/54175

Building on top of previous PR. This PR adds cuda support for 1D max reduction.

Next steps:
- Add support for other major reduction types (e.g. min, sum) for 1D tensor
- Documentation for the op
- Perf optimizations and benchmark util
- Backward support  (not high priority)
- Support for multi dimensional tensors (on data and lengths) (not high priority)
- Support for 'indices' (not high priority)

Test Plan: Added unit test

Reviewed By: ngimel

Differential Revision: D27121170

fbshipit-source-id: 1c2565f42e2903e6fc089d56983ce8857efbfa3c
diff --git a/BUILD.bazel b/BUILD.bazel
index 9f0d759..bed4ebc 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -498,6 +498,7 @@
         "aten/src/ATen/native/cuda/Repeat.cu.cc",
         "aten/src/ATen/native/cuda/ReplicationPadding.cu.cc",
         "aten/src/ATen/native/cuda/Resize.cu.cc",
+        "aten/src/ATen/native/cuda/SegmentReduce.cu.cc",
         "aten/src/ATen/native/cuda/SoftMax.cu.cc",
         "aten/src/ATen/native/cuda/SortingKthValue.cu.cc",
         "aten/src/ATen/native/cuda/SparseMM.cu.cc",
diff --git a/aten/src/ATen/native/SegmentReduce.cpp b/aten/src/ATen/native/SegmentReduce.cpp
index 6c59c9e..d474d8f 100644
--- a/aten/src/ATen/native/SegmentReduce.cpp
+++ b/aten/src/ATen/native/SegmentReduce.cpp
@@ -1,43 +1,22 @@
 #include <ATen/native/SegmentReduce.h>
 
 #include <ATen/ATen.h>
+#include <ATen/Dispatch.h>
 #include <ATen/NumericUtils.h>
 
 namespace at {
 namespace native {
 
-DEFINE_DISPATCH(segment_reduce_stub);
+DEFINE_DISPATCH(_segment_reduce_stub);
 
-enum ReductionType { MAX };
-const std::map<std::string, ReductionType> reduce2REDUCE = {
-    {"max", MAX},
-};
+namespace {
 
-Tensor _segment_reduce_cpu(
+Tensor _segment_reduce_cpu_kernel(
     const Tensor& data,
-    std::string reduce,
-    const c10::optional<Tensor>& lengths,
-    const c10::optional<Tensor>& indices,
+    const Tensor& lengths,
     int64_t axis,
     bool unsafe) {
-  axis = maybe_wrap_dim(axis, data.ndimension());
-  TORCH_CHECK(axis == 0, "Currently only dim=0 is supported!");
-  TORCH_CHECK(data.dim() == 1);
-  TORCH_CHECK(data.numel() > 0);
-  TORCH_CHECK(
-      reduce2REDUCE.at(reduce) == MAX,
-      "Currently only 'max' reduction is supported!");
-
-  // length related checks
-  TORCH_CHECK(
-      lengths.has_value() && !indices.has_value(),
-      "Currently only lengths based reduction is supported!")
-  const auto& lengths_value = lengths.value();
-  TORCH_CHECK(lengths_value.dim() == 1);
-  TORCH_CHECK(data.get_device() == lengths_value.get_device());
-  TORCH_CHECK(data.dim() >= lengths_value.dim());
-
-  const auto lengths_contig = lengths_value.contiguous();
+  const auto lengths_contig = lengths.contiguous();
   const auto data_contig = data.contiguous();
 
   int64_t batch_size = lengths_contig.numel();
@@ -47,7 +26,8 @@
   if (!unsafe) {
     int64_t sum = 0;
     for (int64_t i = 0; i < batch_size; ++i) {
-      TORCH_CHECK(lengths_data[i] > 0);
+      TORCH_CHECK(
+          (lengths_data[i] > 0), "lengths contains non positive value!");
       sum += lengths_data[i];
     }
     TORCH_CHECK(sum == data.numel());
@@ -80,5 +60,49 @@
   return output;
 }
 
+} // namespace
+
+enum SegmentReductionType { MAX };
+static const std::map<std::string, SegmentReductionType> segmentReduce2REDUCE =
+    {
+        {"max", MAX},
+};
+
+Tensor segment_reduce_kernel(
+    const Tensor& data,
+    std::string reduce,
+    const c10::optional<Tensor>& lengths,
+    const c10::optional<Tensor>& indices,
+    int64_t axis,
+    bool unsafe) {
+  axis = maybe_wrap_dim(axis, data.ndimension());
+  TORCH_CHECK(axis == 0, "Currently only dim=0 is supported!");
+  TORCH_CHECK(data.dim() == 1);
+  TORCH_CHECK(data.numel() > 0);
+  TORCH_CHECK(
+      at::native::segmentReduce2REDUCE.at(reduce) == MAX,
+      "Currently only 'max' reduction is supported!");
+
+  // length related checks
+  TORCH_CHECK(
+      lengths.has_value() && !indices.has_value(),
+      "Currently only lengths based reduction is supported!")
+  const auto& lengths_value = lengths.value();
+  TORCH_CHECK(lengths_value.dim() == 1);
+  TORCH_CHECK(data.get_device() == lengths_value.get_device());
+  TORCH_CHECK(data.dim() >= lengths_value.dim());
+
+  return _segment_reduce_stub(
+      data.device().type(), data, lengths_value, axis, unsafe);
+}
+
+REGISTER_ARCH_DISPATCH(
+    _segment_reduce_stub,
+    DEFAULT,
+    &_segment_reduce_cpu_kernel);
+REGISTER_AVX_DISPATCH(_segment_reduce_stub, &_segment_reduce_cpu_kernel);
+REGISTER_AVX2_DISPATCH(_segment_reduce_stub, &_segment_reduce_cpu_kernel);
+REGISTER_VSX_DISPATCH(_segment_reduce_stub, &_segment_reduce_cpu_kernel);
+
 } // namespace native
 } // namespace at
diff --git a/aten/src/ATen/native/SegmentReduce.h b/aten/src/ATen/native/SegmentReduce.h
index 94eef1f..302690a 100644
--- a/aten/src/ATen/native/SegmentReduce.h
+++ b/aten/src/ATen/native/SegmentReduce.h
@@ -7,14 +7,9 @@
 namespace at {
 namespace native {
 
-using segment_reduce_fn = void (*)(
-    const Tensor&,
-    std::string,
-    const c10::optional<Tensor>&,
-    const c10::optional<Tensor>&,
-    int64_t,
-    bool);
-DECLARE_DISPATCH(segment_reduce_fn, segment_reduce_stub);
+using segment_reduce_fn =
+    Tensor (*)(const Tensor&, const Tensor&, int64_t, bool);
+DECLARE_DISPATCH(segment_reduce_fn, _segment_reduce_stub);
 
 } // namespace native
 } // namespace at
diff --git a/aten/src/ATen/native/cuda/SegmentReduce.cu b/aten/src/ATen/native/cuda/SegmentReduce.cu
new file mode 100644
index 0000000..10395a9
--- /dev/null
+++ b/aten/src/ATen/native/cuda/SegmentReduce.cu
@@ -0,0 +1,122 @@
+
+#include <ATen/native/SegmentReduce.h>
+
+#include <ATen/ATen.h>
+#include <ATen/NumericUtils.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/detail/KernelUtils.h>
+#include <c10/cuda/CUDACachingAllocator.h>
+#include <ATen/cuda/CubUtils.cuh>
+#include <iostream>
+
+namespace at {
+namespace native {
+
+struct CustomMax {
+  template <typename OutputT>
+  __host__ __device__ __forceinline__ OutputT
+  operator()(const OutputT& a, const OutputT& b) {
+    if (at::_isnan(a)) {
+      return a;
+    } else if (at::_isnan(b)) {
+      return b;
+    }
+    return std::max<OutputT>(a, b);
+  }
+};
+
+Tensor _get_complete_sum(const Tensor& lengths) {
+  int64_t segment_count = lengths.numel();
+  auto offsets = at::empty({segment_count + 1}, lengths.options());
+  offsets[0].zero_();
+  auto* lengths_data_ptr = lengths.data_ptr<int64_t>();
+  auto* offsets_data_ptr = offsets.data_ptr<int64_t>();
+  size_t temp_storage_bytes = 0;
+  AT_CUDA_CHECK(cub::DeviceScan::InclusiveSum(
+                    nullptr,
+                    temp_storage_bytes,
+                    lengths_data_ptr,
+                    offsets_data_ptr + 1,
+                    segment_count,
+                    at::cuda::getCurrentCUDAStream()););
+
+  auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
+  auto dataPtr = allocator.allocate(temp_storage_bytes);
+
+  AT_CUDA_CHECK(cub::DeviceScan::InclusiveSum(
+                    dataPtr.get(),
+                    temp_storage_bytes,
+                    lengths_data_ptr,
+                    offsets_data_ptr + 1,
+                    segment_count,
+                    at::cuda::getCurrentCUDAStream()););
+
+  return offsets;
+}
+
+Tensor _segment_reduce_cuda_kernel(
+    const Tensor& data,
+    const Tensor& lengths,
+    int64_t axis,
+    bool unsafe) {
+  if (!unsafe) {
+    TORCH_CHECK(
+        (lengths.min().item<int64_t>() > 0),
+        "lengths contains non positive value!");
+    TORCH_CHECK(lengths.sum().item<int64_t>() == data.numel());
+  }
+
+  int64_t segment_count = lengths.numel();
+  const auto data_contig = data.contiguous();
+  auto output = at::empty({segment_count}, data.options());
+
+  const auto lengths_contig = lengths.contiguous();
+  auto offsets = _get_complete_sum(lengths_contig);
+  auto* offsets_data_ptr = offsets.data_ptr<int64_t>();
+
+  AT_DISPATCH_ALL_TYPES_AND2(
+      at::ScalarType::Half,
+      at::ScalarType::BFloat16,
+      data.scalar_type(),
+      "segment_reduce_cuda",
+      [&]() {
+        auto* data_contig_data_ptr = data_contig.data_ptr<scalar_t>();
+        auto* output_data_ptr = output.data_ptr<scalar_t>();
+
+        CustomMax max_op{};
+        size_t temp_storage_bytes = 0;
+        scalar_t initial_value = std::numeric_limits<scalar_t>::lowest();
+        AT_CUDA_CHECK(cub::DeviceSegmentedReduce::Reduce(
+                          nullptr,
+                          temp_storage_bytes,
+                          data_contig_data_ptr,
+                          output_data_ptr,
+                          segment_count,
+                          offsets_data_ptr,
+                          offsets_data_ptr + 1,
+                          max_op,
+                          initial_value,
+                          at::cuda::getCurrentCUDAStream()););
+
+        auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
+        auto dataPtr = allocator.allocate(temp_storage_bytes);
+
+        AT_CUDA_CHECK(cub::DeviceSegmentedReduce::Reduce(
+                          dataPtr.get(),
+                          temp_storage_bytes,
+                          data_contig_data_ptr,
+                          output_data_ptr,
+                          segment_count,
+                          offsets_data_ptr,
+                          offsets_data_ptr + 1,
+                          max_op,
+                          initial_value,
+                          at::cuda::getCurrentCUDAStream()););
+      });
+  return output;
+}
+
+REGISTER_DISPATCH(_segment_reduce_stub, &_segment_reduce_cuda_kernel);
+
+} // namespace native
+} // namespace at
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index d106f77..052cb7d 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -8872,4 +8872,4 @@
 - func: segment_reduce(Tensor data, str reduce, *, Tensor? lengths=None, Tensor? indices=None, int axis=0, bool unsafe=False) -> Tensor
   variants: function
   dispatch:
-    CPU: _segment_reduce_cpu
+    CPU, CUDA: segment_reduce_kernel
diff --git a/test/test_segment_reductions.py b/test/test_segment_reductions.py
index b65d799..4483db2 100644
--- a/test/test_segment_reductions.py
+++ b/test/test_segment_reductions.py
@@ -1,8 +1,8 @@
 import torch
 from torch.testing._internal.common_device_type import (
     instantiate_device_type_tests,
-    onlyCPU,
     dtypes,
+    dtypesIfCUDA,
 )
 from torch.testing._internal.common_utils import (
     TestCase,
@@ -11,25 +11,29 @@
 
 
 class TestSegmentReductions(TestCase):
-    @onlyCPU
-    @dtypes(torch.half, torch.bfloat16, torch.float, torch.double)
-    def test_max_simple_1d(self, device, dtype):
+    def _test_max_simple_1d(self, device, dtype, unsafe):
         lengths = torch.tensor([1, 2, 3], device=device)
         data = torch.tensor([1, float("nan"), 3, 4, 5, 6], device=device, dtype=dtype)
         expected_result = torch.tensor([1, float("nan"), 6], device=device, dtype=dtype)
         actual_result = torch.segment_reduce(
-            data=data, reduce="max", lengths=lengths, axis=0, unsafe=False
+            data=data, reduce="max", lengths=lengths, axis=0, unsafe=unsafe
         )
         self.assertEqual(
             expected_result, actual_result, rtol=1e-03, atol=1e-05, equal_nan=True
         )
         actual_result = torch.segment_reduce(
-            data=data, reduce="max", lengths=lengths, axis=-1, unsafe=False
+            data=data, reduce="max", lengths=lengths, axis=-1, unsafe=unsafe
         )
         self.assertEqual(
             expected_result, actual_result, rtol=1e-03, atol=1e-05, equal_nan=True
         )
 
+    @dtypesIfCUDA(torch.half, torch.bfloat16, torch.float, torch.double)
+    @dtypes(torch.half, torch.bfloat16, torch.float, torch.double)
+    def test_max_simple_1d(self, device, dtype):
+        self._test_max_simple_1d(device, dtype, False)
+        self._test_max_simple_1d(device, dtype, True)
+
 
 instantiate_device_type_tests(TestSegmentReductions, globals())