Implement torch.sort with cub::DeviceSegmentedRadixSort (#56821)

Summary:
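The large-input CUDA path of `torch.sort` previously ran two passes: a global `cub::DeviceRadixSort::SortPairs` over the values paired with packed `(segment id, index)` int2 payloads, then a stable `sort_keys` over the segment-id bits, followed by a `sort_postprocess_kernel` to scatter the results. This PR replaces that with a single call to `cub::DeviceSegmentedRadixSort` through a new `segmented_sort_pairs` wrapper in `aten/src/ATen/cuda/cub.cuh`. Segment boundaries are described by a tiny `offset_t` functor (`begin_offsets[i] = i * nsort`, `end_offsets[i] = (i + 1) * nsort`) instead of materialized offset arrays, the per-segment `arange` of indices is passed directly as the value payload so the postprocess kernel is no longer needed, and the `start_bit` parameters are renamed to `begin_bit` to match cub's terminology.

Below is a minimal standalone sketch of that call pattern (a hypothetical example program, not PyTorch code; it assumes a CUDA 11+ toolkit where cub ships with nvcc, and it relies on the same property the PR's `offset_t` does: cub's segmented radix sort only ever subscripts the offset iterators on the device):

```cuda
// Sketch: sort each row of a row-major [nsegments, nsort] matrix with
// cub::DeviceSegmentedRadixSort, mirroring segmented_sort_pairs() in the PR.
// Error checking is omitted for brevity.
#include <cub/cub.cuh>
#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>
#include <vector>

// Mirrors the PR's offset_t: segment i starts at stride * (begin + i), so
// {nsort, 0} yields row starts and {nsort, 1} yields row ends without
// allocating an offsets array.
struct offset_t {
  int stride;
  int begin;
  __device__ int operator[](int i) const { return stride * (begin + i); }
};

int main() {
  const int nsegments = 4, nsort = 8, n = nsegments * nsort;

  std::vector<float> h_keys(n);
  std::vector<int64_t> h_vals(n);
  for (int i = 0; i < n; ++i) {
    h_keys[i] = static_cast<float>((i * 37) % 11);  // arbitrary unsorted keys
    h_vals[i] = i % nsort;                          // per-row index payload
  }

  float *d_keys_in, *d_keys_out;
  int64_t *d_vals_in, *d_vals_out;
  cudaMalloc((void**)&d_keys_in, n * sizeof(float));
  cudaMalloc((void**)&d_keys_out, n * sizeof(float));
  cudaMalloc((void**)&d_vals_in, n * sizeof(int64_t));
  cudaMalloc((void**)&d_vals_out, n * sizeof(int64_t));
  cudaMemcpy(d_keys_in, h_keys.data(), n * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d_vals_in, h_vals.data(), n * sizeof(int64_t), cudaMemcpyHostToDevice);

  // Standard two-phase cub call: query temp storage size, allocate, then sort.
  void *d_temp = nullptr;
  size_t temp_bytes = 0;
  cub::DeviceSegmentedRadixSort::SortPairs(
      d_temp, temp_bytes, d_keys_in, d_keys_out, d_vals_in, d_vals_out,
      n, nsegments, offset_t{nsort, 0}, offset_t{nsort, 1});
  cudaMalloc(&d_temp, temp_bytes);
  cub::DeviceSegmentedRadixSort::SortPairs(
      d_temp, temp_bytes, d_keys_in, d_keys_out, d_vals_in, d_vals_out,
      n, nsegments, offset_t{nsort, 0}, offset_t{nsort, 1});

  cudaMemcpy(h_keys.data(), d_keys_out, n * sizeof(float), cudaMemcpyDeviceToHost);
  cudaMemcpy(h_vals.data(), d_vals_out, n * sizeof(int64_t), cudaMemcpyDeviceToHost);
  for (int s = 0; s < nsegments; ++s) {
    printf("row %d:", s);
    for (int j = 0; j < nsort; ++j)
      printf(" %.0f(%lld)", h_keys[s * nsort + j], (long long)h_vals[s * nsort + j]);
    printf("\n");
  }

  cudaFree(d_temp);
  cudaFree(d_keys_in); cudaFree(d_keys_out);
  cudaFree(d_vals_in); cudaFree(d_vals_out);
  return 0;
}
```

Each row comes back sorted with its original in-row index alongside, which is the per-slice (values, indices) pair `torch.sort` returns.
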
Benchmark (run in an IPython session, since the script uses `%timeit`):
```python
import torch
import itertools

# Warm up the CUDA context and caching allocator before timing.
for i in range(1000):
    torch.arange(100000, device='cuda')

# Each measurement times 50 sorts, then synchronizes once.
def run50_sync(f):
    for _ in range(50):
        f()
    torch.cuda.synchronize()

for i, j in itertools.product([512, 4096, 8192], repeat=2):
    print(i, j)
    t = torch.randn(i, j, device='cuda')
    torch.cuda.synchronize()
    %timeit run50_sync(lambda: torch.sort(t))         # sort along the last dim
    torch.cuda.synchronize()
    %timeit run50_sync(lambda: torch.sort(t, dim=0))  # sort along dim 0
    print()
```

Before
```
512 512
4.02 ms ± 28.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
5 ms ± 15.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

512 4096
40.7 ms ± 74.2 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
33.9 ms ± 186 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

512 8192
71.7 ms ± 636 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
66.4 ms ± 163 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

4096 512
27.6 ms ± 27.8 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
46.6 ms ± 101 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

4096 4096
262 ms ± 1.14 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
321 ms ± 1.32 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

4096 8192
520 ms ± 5.47 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
661 ms ± 853 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

8192 512
54.6 ms ± 133 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
83.2 ms ± 320 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

8192 4096
521 ms ± 1.06 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
645 ms ± 1.47 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

8192 8192
1.04 s ± 2.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.34 s ± 541 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
```

After
```
512 512
4.65 ms ± 62.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
5.75 ms ± 62.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

512 4096
30.3 ms ± 261 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
39.4 ms ± 421 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

512 8192
59.7 ms ± 344 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
77 ms ± 601 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

4096 512
32.2 ms ± 376 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
37.1 ms ± 211 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

4096 4096
204 ms ± 471 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
272 ms ± 1.87 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

4096 8192
422 ms ± 3.25 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
562 ms ± 4.66 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

8192 512
63.1 ms ± 595 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
72.7 ms ± 532 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

8192 4096
401 ms ± 3.08 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
573 ms ± 2.59 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

8192 8192
831 ms ± 7.86 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.2 s ± 9.17 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/56821

Reviewed By: mrshenli

Differential Revision: D28172609

Pulled By: ngimel

fbshipit-source-id: 87314a6985a84d326304ff5220df5661ef00d710
diff --git a/aten/src/ATen/cuda/cub.cuh b/aten/src/ATen/cuda/cub.cuh
index 60e6dc9..84e673d 100644
--- a/aten/src/ATen/cuda/cub.cuh
+++ b/aten/src/ATen/cuda/cub.cuh
@@ -63,7 +63,7 @@
 template<typename key_t>
 static inline void sort_keys(
     const key_t *keys_in, key_t *keys_out,
-    int64_t n, bool descending=false, int64_t start_bit=0, int64_t end_bit=sizeof(key_t)*8
+    int64_t n, bool descending=false, int64_t begin_bit=0, int64_t end_bit=sizeof(key_t)*8
 ) {
   using key_t_ = typename cuda_type<key_t>::type;
 
@@ -73,11 +73,11 @@
   if (descending) {
     CUB_WRAPPER(NO_ROCM(detail)::cub::DeviceRadixSort::SortKeysDescending,
       keys_in_, keys_out_, n,
-      start_bit, end_bit, c10::cuda::getCurrentCUDAStream());
+      begin_bit, end_bit, c10::cuda::getCurrentCUDAStream());
   } else {
     CUB_WRAPPER(NO_ROCM(detail)::cub::DeviceRadixSort::SortKeys,
       keys_in_, keys_out_, n,
-      start_bit, end_bit, c10::cuda::getCurrentCUDAStream());
+      begin_bit, end_bit, c10::cuda::getCurrentCUDAStream());
   }
 }
 
@@ -85,23 +85,18 @@
 static inline void sort_pairs(
     const key_t *keys_in, key_t *keys_out,
     const value_t *values_in, value_t *values_out,
-    int64_t n, bool descending=false, int64_t start_bit=0, int64_t end_bit=sizeof(key_t)*8
+    int64_t n, bool descending=false, int64_t begin_bit=0, int64_t end_bit=sizeof(key_t)*8
 ) {
   using key_t_ = typename cuda_type<key_t>::type;
   using value_t_ = typename cuda_type<value_t>::type;
 
   auto allocator = c10::cuda::CUDACachingAllocator::get();
   c10::DataPtr keys_out_owner;
-  c10::DataPtr values_out_owner;
 
   if (keys_out == nullptr) {
     keys_out_owner = allocator->allocate(n * sizeof(key_t));
     keys_out = reinterpret_cast<key_t *>(keys_out_owner.get());
   }
-  if (values_out == nullptr) {
-    values_out_owner = allocator->allocate(n * sizeof(value_t));
-    values_out = reinterpret_cast<value_t *>(values_out_owner.get());
-  }
 
   const key_t_ *keys_in_ = reinterpret_cast<const key_t_*>(keys_in);
   key_t_ *keys_out_ = reinterpret_cast<key_t_*>(keys_out);
@@ -111,11 +106,48 @@
   if (descending) {
     CUB_WRAPPER(NO_ROCM(detail)::cub::DeviceRadixSort::SortPairsDescending,
       keys_in_, keys_out_, values_in_, values_out_, n,
-      start_bit, end_bit, c10::cuda::getCurrentCUDAStream());
+      begin_bit, end_bit, c10::cuda::getCurrentCUDAStream());
   } else {
     CUB_WRAPPER(NO_ROCM(detail)::cub::DeviceRadixSort::SortPairs,
       keys_in_, keys_out_, values_in_, values_out_, n,
-      start_bit, end_bit, c10::cuda::getCurrentCUDAStream());
+      begin_bit, end_bit, c10::cuda::getCurrentCUDAStream());
+  }
+}
+
+template<typename key_t, typename value_t, typename OffsetIteratorT>
+static inline void segmented_sort_pairs(
+    const key_t *keys_in, key_t *keys_out,
+    const value_t *values_in, value_t *values_out,
+    int64_t num_elements, int64_t num_segments,
+    OffsetIteratorT begin_offsets, OffsetIteratorT end_offsets,
+    bool descending=false, int64_t begin_bit=0, int64_t end_bit=sizeof(key_t)*8
+) {
+  using key_t_ = typename cuda_type<key_t>::type;
+  using value_t_ = typename cuda_type<value_t>::type;
+
+  auto allocator = c10::cuda::CUDACachingAllocator::get();
+  c10::DataPtr keys_out_owner;
+
+  if (keys_out == nullptr) {
+    keys_out_owner = allocator->allocate(num_elements * sizeof(key_t));
+    keys_out = reinterpret_cast<key_t *>(keys_out_owner.get());
+  }
+
+  const key_t_ *keys_in_ = reinterpret_cast<const key_t_*>(keys_in);
+  key_t_ *keys_out_ = reinterpret_cast<key_t_*>(keys_out);
+  const value_t_ *values_in_ = reinterpret_cast<const value_t_*>(values_in);
+  value_t_ *values_out_ = reinterpret_cast<value_t_*>(values_out);
+
+  if (descending) {
+    CUB_WRAPPER(NO_ROCM(detail)::cub::DeviceSegmentedRadixSort::SortPairsDescending,
+      keys_in_, keys_out_, values_in_, values_out_,
+      num_elements, num_segments, begin_offsets, end_offsets,
+      begin_bit, end_bit, c10::cuda::getCurrentCUDAStream());
+  } else {
+    CUB_WRAPPER(NO_ROCM(detail)::cub::DeviceSegmentedRadixSort::SortPairs,
+      keys_in_, keys_out_, values_in_, values_out_,
+      num_elements, num_segments, begin_offsets, end_offsets,
+      begin_bit, end_bit, c10::cuda::getCurrentCUDAStream());
   }
 }
 
diff --git a/aten/src/ATen/native/cuda/Sort.cu b/aten/src/ATen/native/cuda/Sort.cu
index 4ae74e2..928d209 100644
--- a/aten/src/ATen/native/cuda/Sort.cu
+++ b/aten/src/ATen/native/cuda/Sort.cu
@@ -11,28 +11,6 @@
 #include <ATen/native/cuda/SortUtils.cuh>
 #include <ATen/native/cuda/SortingCommon.cuh>
 
-namespace {
-
-template<typename scalar_t>
-__global__ void sort_postprocess_kernel(const scalar_t *in, scalar_t *out, int64_t *index, const int2 *i_s_ptr, int nsegments, int nsort) {
-  CUDA_KERNEL_LOOP(i, nsegments * nsort) {
-    int segment = i / nsort;
-    int j = i % nsort;
-
-    int offset = segment * nsort;
-    const scalar_t *in_ = in + offset;
-    scalar_t *out_ = out + offset;
-    int64_t *index_ = index + offset;
-    const int2 *i_s_ptr_ = i_s_ptr + offset;
-
-    int idx = i_s_ptr_[j].y;
-    index_[j] = idx;
-    out_[j] = in_[idx];
-  }
-}
-
-}
-
 namespace at { namespace native {
 
 bool should_use_small_sort(const Tensor &self, int64_t dim) {
@@ -209,34 +187,21 @@
 #undef HANDLE_A_CASE
 }
 
-// We perform a vectorized segmented sort in cub with inputs that have
+namespace {
+
+struct offset_t {
+  int stride;
+  int begin;
+  __device__ int operator[](int i) {
+    return stride * (begin + i);
+  }
+};
+
+}
+
+// We perform a segmented sort in cub with inputs that have
 // more than 1024/2048 elements along the selected dimension.
 // Otherwise, we do an inplace bitonic sort (see sortKeyValueInplace).
-// Large sort algorithm:.
-// Say we are sorting a (2, 3) tensor. We have in flattened form:
-// values       0.4 1.2 5.3 6.2 1.3 2.3
-// indices        0   1   2   0   1   2
-// segment_id     0   0   0   1   1   1
-
-// First we sort by values, globally:
-// values       6.2 5.3 2.3 1.2 1.3 0.4
-// indices        0   2   2   1   1   0
-// segment_id     1   0   1   0   1   0
-
-// Then we stable sort by segment id:
-// values       5.3 1.2 0.4 6.2 2.3 1.3
-// indices        2   1   0   0   2   1
-// segment_id     0   0   0   1   1   1
-
-// This method can only work if the slice we are sorting (`dim`) is
-// innermost, and both values and indices are contiguous. We do this
-// by re-arranging the input into this form as needed, which will
-// unfortunately allocate memory if the request is not in this form.
-// Vectorized sort is slower than iterated sort if the number of
-// slices is small (since we're sorting twice, instead of invoking a
-// smaller sort `numSlices` times), but the cub sort
-// implementation here is a catch-all, so we're not looking for
-// efficiency, but instead correctness.
 std::tuple<Tensor &,Tensor &> sort_out_stable_cuda(const Tensor & self, c10::optional<bool> stable, int64_t dim, bool descending, Tensor & values, Tensor & indices) {
   // this algorithm is always stable
   TORCH_INTERNAL_ASSERT(stable.has_value(), "sort_out(): c10::optional<bool> for stable has to have value.");
@@ -365,32 +330,12 @@
     while (remaining > 0) {
       int64_t n = std::min(remaining, nbatch);
       int64_t nsegments = n / nsort;
-      int64_t segment_bits = std::max<int64_t>(1L, static_cast<int64_t>(std::ceil(std::log2(nsegments))));
 
-      auto int_options = indices.options().dtype(kInt);
-      auto indices_and_segment = at::empty({nsegments, nsort, 2}, int_options);
-      indices_and_segment.select(-1, 0).copy_(  // segment id
-        at::arange(nsegments, int_options).view({nsegments, 1}).expand({nsegments, nsort}));
-      indices_and_segment.select(-1, 1).copy_(  // reverse indices
-        at::arange(nsort, int_options).view({1, nsort}).expand({nsegments, nsort}));
+      auto reverse_indices = at::arange(nsort, indices.options()).view({1, nsort}).expand({nsegments, nsort}).contiguous();
 
-      auto i_s_ptr = reinterpret_cast<int2 *>(indices_and_segment.data_ptr<int>());
-      auto indices_and_segment2 = at::empty_like(indices_and_segment);
-      auto i_s_ptr2 = reinterpret_cast<int2 *>(indices_and_segment2.data_ptr<int>());
-
-      at::cuda::cub::sort_pairs<scalar_t, int2>(
-        self_ptr, nullptr, i_s_ptr, i_s_ptr2,
-        n, descending);
-
-      TORCH_INTERNAL_ASSERT(segment_bits <= 32);
-
-      // sort on lower 32bits, i.e. segment index
-      at::cuda::cub::sort_keys<int64_t>(
-        reinterpret_cast<int64_t *>(i_s_ptr2), reinterpret_cast<int64_t *>(i_s_ptr),
-        n, false, 0, segment_bits);
-
-      sort_postprocess_kernel<<<(n + 511) / 512, 512, 0, at::cuda::getCurrentCUDAStream()>>>(
-        self_ptr, values_ptr, indices_ptr, i_s_ptr, nsegments, nsort);
+      at::cuda::cub::segmented_sort_pairs(self_ptr, values_ptr,
+        reverse_indices.data_ptr<int64_t>(), indices_ptr, n, nsegments,
+        offset_t{(int)nsort, 0}, offset_t{(int)nsort, 1}, descending);
 
       remaining -= n;
       self_ptr += n;
diff --git a/test/test_sort_and_select.py b/test/test_sort_and_select.py
index e3224716..38cb069 100644
--- a/test/test_sort_and_select.py
+++ b/test/test_sort_and_select.py
@@ -121,14 +121,7 @@
     # FIXME: remove torch.bool from unsupported types once support is added for cub sort
     @dtypes(*set(torch.testing.get_all_dtypes()) - {torch.bool, torch.bfloat16, torch.complex64, torch.complex128})
     def test_stable_sort(self, device, dtype):
-        if self.device_type == 'cpu':
-            sizes = (100, 1000, 10000)
-        elif self.device_type == 'cuda':
-            # On CUDA, stable sort is supported only when the size of
-            # the sorted dim is greater than 2048
-            sizes = (1025, 10000)
-        else:
-            return
+        sizes = (100, 1000, 10000)
         for ncopies in sizes:
             x = torch.tensor([0, 1] * ncopies, dtype=dtype, device=device)
             _, idx = x.sort(stable=True)