Implement torch.sort with cub::DeviceSegmentedRadixSort (#56821)
Summary:
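Large CUDA sorts previously went through a two-pass scheme: sort the flattened (value, segment-id) pairs globally, then stable-sort by segment id and post-process (see the code removed from Sort.cu in the diff below). This PR replaces that with a single call to `cub::DeviceSegmentedRadixSort` through a new `at::cuda::cub::segmented_sort_pairs` wrapper. As a rough illustration only, here is a minimal standalone sketch of the standard two-phase CUB call pattern that the `CUB_WRAPPER` macro hides in the real code; it uses plain `cudaMalloc` instead of PyTorch's caching allocator and omits error checking:
```cpp
#include <cub/cub.cuh>

// Minimal sketch, not the PR's wrapper: segmented key/value radix sort.
template <typename key_t, typename value_t, typename OffsetIter>
void segmented_sort_pairs_sketch(
    const key_t *keys_in, key_t *keys_out,
    const value_t *values_in, value_t *values_out,
    int num_items, int num_segments,
    OffsetIter begin_offsets, OffsetIter end_offsets,
    cudaStream_t stream) {
  // Phase 1: pass nullptr for temp storage to query how many scratch bytes
  // DeviceSegmentedRadixSort needs for this problem size.
  size_t temp_bytes = 0;
  cub::DeviceSegmentedRadixSort::SortPairs(
      nullptr, temp_bytes,
      keys_in, keys_out, values_in, values_out,
      num_items, num_segments, begin_offsets, end_offsets,
      /*begin_bit=*/0, /*end_bit=*/int(sizeof(key_t) * 8), stream);

  // Phase 2: allocate the scratch buffer and run the actual sort.
  void *temp_storage = nullptr;
  cudaMalloc(&temp_storage, temp_bytes);
  cub::DeviceSegmentedRadixSort::SortPairs(
      temp_storage, temp_bytes,
      keys_in, keys_out, values_in, values_out,
      num_items, num_segments, begin_offsets, end_offsets,
      /*begin_bit=*/0, /*end_bit=*/int(sizeof(key_t) * 8), stream);
  cudaFree(temp_storage);
}
```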
Benchmark:
```python
import torch
import itertools

# warm up the CUDA context and caching allocator
for i in range(1000):
    torch.arange(100000, device='cuda')

def run50_sync(f):
    for _ in range(50):
        f()
    torch.cuda.synchronize()

for i, j in itertools.product([512, 4096, 8192], repeat=2):
    print(i, j)
    t = torch.randn(i, j, device='cuda')
    torch.cuda.synchronize()
    %timeit run50_sync(lambda: torch.sort(t))
    torch.cuda.synchronize()
    %timeit run50_sync(lambda: torch.sort(t, dim=0))
    print()
```
Before
```
512 512
4.02 ms ± 28.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
5 ms ± 15.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
512 4096
40.7 ms ± 74.2 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
33.9 ms ± 186 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
512 8192
71.7 ms ± 636 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
66.4 ms ± 163 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
4096 512
27.6 ms ± 27.8 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
46.6 ms ± 101 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
4096 4096
262 ms ± 1.14 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
321 ms ± 1.32 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
4096 8192
520 ms ± 5.47 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
661 ms ± 853 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
8192 512
54.6 ms ± 133 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
83.2 ms ± 320 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
8192 4096
521 ms ± 1.06 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
645 ms ± 1.47 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
8192 8192
1.04 s ± 2.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.34 s ± 541 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
```
After
```
512 512
4.65 ms ± 62.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
5.75 ms ± 62.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
512 4096
30.3 ms ± 261 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
39.4 ms ± 421 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
512 8192
59.7 ms ± 344 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
77 ms ± 601 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
4096 512
32.2 ms ± 376 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
37.1 ms ± 211 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
4096 4096
204 ms ± 471 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
272 ms ± 1.87 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
4096 8192
422 ms ± 3.25 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
562 ms ± 4.66 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
8192 512
63.1 ms ± 595 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
72.7 ms ± 532 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
8192 4096
401 ms ± 3.08 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
573 ms ± 2.59 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
8192 8192
831 ms ± 7.86 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.2 s ± 9.17 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/56821
Reviewed By: mrshenli
Differential Revision: D28172609
Pulled By: ngimel
fbshipit-source-id: 87314a6985a84d326304ff5220df5661ef00d710
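The Sort.cu change below feeds the segment offsets to CUB through a tiny `offset_t` functor instead of a materialized offsets array: every segment holds exactly `nsort` contiguous elements, so `offset_t{nsort, 0}` acts as the begin-offset iterator and `offset_t{nsort, 1}` as the end-offset iterator. A host-side sketch of that indexing, with hypothetical sizes and a host-callable variant of the device-only functor, purely for illustration:
```cpp
#include <cstdio>

// Same indexing as the offset_t functor added in Sort.cu, made host-callable
// here only so it can be printed; the real one is __device__ only.
struct offset_t {
  int stride;
  int begin;
  int operator[](int i) const { return stride * (begin + i); }
};

int main() {
  const int nsort = 4;      // hypothetical segment length
  const int nsegments = 3;  // hypothetical number of segments
  offset_t begin_offsets{nsort, 0};  // yields 0, 4, 8
  offset_t end_offsets{nsort, 1};    // yields 4, 8, 12
  for (int seg = 0; seg < nsegments; ++seg) {
    std::printf("segment %d covers [%d, %d)\n",
                seg, begin_offsets[seg], end_offsets[seg]);
  }
  return 0;
}
```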
diff --git a/aten/src/ATen/cuda/cub.cuh b/aten/src/ATen/cuda/cub.cuh
index 60e6dc9..84e673d 100644
--- a/aten/src/ATen/cuda/cub.cuh
+++ b/aten/src/ATen/cuda/cub.cuh
@@ -63,7 +63,7 @@
template<typename key_t>
static inline void sort_keys(
const key_t *keys_in, key_t *keys_out,
- int64_t n, bool descending=false, int64_t start_bit=0, int64_t end_bit=sizeof(key_t)*8
+ int64_t n, bool descending=false, int64_t begin_bit=0, int64_t end_bit=sizeof(key_t)*8
) {
using key_t_ = typename cuda_type<key_t>::type;
@@ -73,11 +73,11 @@
if (descending) {
CUB_WRAPPER(NO_ROCM(detail)::cub::DeviceRadixSort::SortKeysDescending,
keys_in_, keys_out_, n,
- start_bit, end_bit, c10::cuda::getCurrentCUDAStream());
+ begin_bit, end_bit, c10::cuda::getCurrentCUDAStream());
} else {
CUB_WRAPPER(NO_ROCM(detail)::cub::DeviceRadixSort::SortKeys,
keys_in_, keys_out_, n,
- start_bit, end_bit, c10::cuda::getCurrentCUDAStream());
+ begin_bit, end_bit, c10::cuda::getCurrentCUDAStream());
}
}
@@ -85,23 +85,18 @@
static inline void sort_pairs(
const key_t *keys_in, key_t *keys_out,
const value_t *values_in, value_t *values_out,
- int64_t n, bool descending=false, int64_t start_bit=0, int64_t end_bit=sizeof(key_t)*8
+ int64_t n, bool descending=false, int64_t begin_bit=0, int64_t end_bit=sizeof(key_t)*8
) {
using key_t_ = typename cuda_type<key_t>::type;
using value_t_ = typename cuda_type<value_t>::type;
auto allocator = c10::cuda::CUDACachingAllocator::get();
c10::DataPtr keys_out_owner;
- c10::DataPtr values_out_owner;
if (keys_out == nullptr) {
keys_out_owner = allocator->allocate(n * sizeof(key_t));
keys_out = reinterpret_cast<key_t *>(keys_out_owner.get());
}
- if (values_out == nullptr) {
- values_out_owner = allocator->allocate(n * sizeof(value_t));
- values_out = reinterpret_cast<value_t *>(values_out_owner.get());
- }
const key_t_ *keys_in_ = reinterpret_cast<const key_t_*>(keys_in);
key_t_ *keys_out_ = reinterpret_cast<key_t_*>(keys_out);
@@ -111,11 +106,48 @@
if (descending) {
CUB_WRAPPER(NO_ROCM(detail)::cub::DeviceRadixSort::SortPairsDescending,
keys_in_, keys_out_, values_in_, values_out_, n,
- start_bit, end_bit, c10::cuda::getCurrentCUDAStream());
+ begin_bit, end_bit, c10::cuda::getCurrentCUDAStream());
} else {
CUB_WRAPPER(NO_ROCM(detail)::cub::DeviceRadixSort::SortPairs,
keys_in_, keys_out_, values_in_, values_out_, n,
- start_bit, end_bit, c10::cuda::getCurrentCUDAStream());
+ begin_bit, end_bit, c10::cuda::getCurrentCUDAStream());
+ }
+}
+
+template<typename key_t, typename value_t, typename OffsetIteratorT>
+static inline void segmented_sort_pairs(
+ const key_t *keys_in, key_t *keys_out,
+ const value_t *values_in, value_t *values_out,
+ int64_t num_elements, int64_t num_segments,
+ OffsetIteratorT begin_offsets, OffsetIteratorT end_offsets,
+ bool descending=false, int64_t begin_bit=0, int64_t end_bit=sizeof(key_t)*8
+) {
+ using key_t_ = typename cuda_type<key_t>::type;
+ using value_t_ = typename cuda_type<value_t>::type;
+
+ auto allocator = c10::cuda::CUDACachingAllocator::get();
+ c10::DataPtr keys_out_owner;
+
+ if (keys_out == nullptr) {
+ keys_out_owner = allocator->allocate(num_elements * sizeof(key_t));
+ keys_out = reinterpret_cast<key_t *>(keys_out_owner.get());
+ }
+
+ const key_t_ *keys_in_ = reinterpret_cast<const key_t_*>(keys_in);
+ key_t_ *keys_out_ = reinterpret_cast<key_t_*>(keys_out);
+ const value_t_ *values_in_ = reinterpret_cast<const value_t_*>(values_in);
+ value_t_ *values_out_ = reinterpret_cast<value_t_*>(values_out);
+
+ if (descending) {
+ CUB_WRAPPER(NO_ROCM(detail)::cub::DeviceSegmentedRadixSort::SortPairsDescending,
+ keys_in_, keys_out_, values_in_, values_out_,
+ num_elements, num_segments, begin_offsets, end_offsets,
+ begin_bit, end_bit, c10::cuda::getCurrentCUDAStream());
+ } else {
+ CUB_WRAPPER(NO_ROCM(detail)::cub::DeviceSegmentedRadixSort::SortPairs,
+ keys_in_, keys_out_, values_in_, values_out_,
+ num_elements, num_segments, begin_offsets, end_offsets,
+ begin_bit, end_bit, c10::cuda::getCurrentCUDAStream());
}
}
diff --git a/aten/src/ATen/native/cuda/Sort.cu b/aten/src/ATen/native/cuda/Sort.cu
index 4ae74e2..928d209 100644
--- a/aten/src/ATen/native/cuda/Sort.cu
+++ b/aten/src/ATen/native/cuda/Sort.cu
@@ -11,28 +11,6 @@
#include <ATen/native/cuda/SortUtils.cuh>
#include <ATen/native/cuda/SortingCommon.cuh>
-namespace {
-
-template<typename scalar_t>
-__global__ void sort_postprocess_kernel(const scalar_t *in, scalar_t *out, int64_t *index, const int2 *i_s_ptr, int nsegments, int nsort) {
- CUDA_KERNEL_LOOP(i, nsegments * nsort) {
- int segment = i / nsort;
- int j = i % nsort;
-
- int offset = segment * nsort;
- const scalar_t *in_ = in + offset;
- scalar_t *out_ = out + offset;
- int64_t *index_ = index + offset;
- const int2 *i_s_ptr_ = i_s_ptr + offset;
-
- int idx = i_s_ptr_[j].y;
- index_[j] = idx;
- out_[j] = in_[idx];
- }
-}
-
-}
-
namespace at { namespace native {
bool should_use_small_sort(const Tensor &self, int64_t dim) {
@@ -209,34 +187,21 @@
#undef HANDLE_A_CASE
}
-// We perform a vectorized segmented sort in cub with inputs that have
+namespace {
+
+struct offset_t {
+ int stride;
+ int begin;
+ __device__ int operator[](int i) {
+ return stride * (begin + i);
+ }
+};
+
+}
+
+// We perform a segmented sort in cub with inputs that have
// more than 1024/2048 elements along the selected dimension.
// Otherwise, we do an inplace bitonic sort (see sortKeyValueInplace).
-// Large sort algorithm:.
-// Say we are sorting a (2, 3) tensor. We have in flattened form:
-// values 0.4 1.2 5.3 6.2 1.3 2.3
-// indices 0 1 2 0 1 2
-// segment_id 0 0 0 1 1 1
-
-// First we sort by values, globally:
-// values 6.2 5.3 2.3 1.2 1.3 0.4
-// indices 0 2 2 1 1 0
-// segment_id 1 0 1 0 1 0
-
-// Then we stable sort by segment id:
-// values 5.3 1.2 0.4 6.2 2.3 1.3
-// indices 2 1 0 0 2 1
-// segment_id 0 0 0 1 1 1
-
-// This method can only work if the slice we are sorting (`dim`) is
-// innermost, and both values and indices are contiguous. We do this
-// by re-arranging the input into this form as needed, which will
-// unfortunately allocate memory if the request is not in this form.
-// Vectorized sort is slower than iterated sort if the number of
-// slices is small (since we're sorting twice, instead of invoking a
-// smaller sort `numSlices` times), but the cub sort
-// implementation here is a catch-all, so we're not looking for
-// efficiency, but instead correctness.
std::tuple<Tensor &,Tensor &> sort_out_stable_cuda(const Tensor & self, c10::optional<bool> stable, int64_t dim, bool descending, Tensor & values, Tensor & indices) {
// this algorithm is always stable
TORCH_INTERNAL_ASSERT(stable.has_value(), "sort_out(): c10::optional<bool> for stable has to have value.");
@@ -365,32 +330,12 @@
while (remaining > 0) {
int64_t n = std::min(remaining, nbatch);
int64_t nsegments = n / nsort;
- int64_t segment_bits = std::max<int64_t>(1L, static_cast<int64_t>(std::ceil(std::log2(nsegments))));
- auto int_options = indices.options().dtype(kInt);
- auto indices_and_segment = at::empty({nsegments, nsort, 2}, int_options);
- indices_and_segment.select(-1, 0).copy_( // segment id
- at::arange(nsegments, int_options).view({nsegments, 1}).expand({nsegments, nsort}));
- indices_and_segment.select(-1, 1).copy_( // reverse indices
- at::arange(nsort, int_options).view({1, nsort}).expand({nsegments, nsort}));
+ auto reverse_indices = at::arange(nsort, indices.options()).view({1, nsort}).expand({nsegments, nsort}).contiguous();
- auto i_s_ptr = reinterpret_cast<int2 *>(indices_and_segment.data_ptr<int>());
- auto indices_and_segment2 = at::empty_like(indices_and_segment);
- auto i_s_ptr2 = reinterpret_cast<int2 *>(indices_and_segment2.data_ptr<int>());
-
- at::cuda::cub::sort_pairs<scalar_t, int2>(
- self_ptr, nullptr, i_s_ptr, i_s_ptr2,
- n, descending);
-
- TORCH_INTERNAL_ASSERT(segment_bits <= 32);
-
- // sort on lower 32bits, i.e. segment index
- at::cuda::cub::sort_keys<int64_t>(
- reinterpret_cast<int64_t *>(i_s_ptr2), reinterpret_cast<int64_t *>(i_s_ptr),
- n, false, 0, segment_bits);
-
- sort_postprocess_kernel<<<(n + 511) / 512, 512, 0, at::cuda::getCurrentCUDAStream()>>>(
- self_ptr, values_ptr, indices_ptr, i_s_ptr, nsegments, nsort);
+ at::cuda::cub::segmented_sort_pairs(self_ptr, values_ptr,
+ reverse_indices.data_ptr<int64_t>(), indices_ptr, n, nsegments,
+ offset_t{(int)nsort, 0}, offset_t{(int)nsort, 1}, descending);
remaining -= n;
self_ptr += n;
diff --git a/test/test_sort_and_select.py b/test/test_sort_and_select.py
index e3224716..38cb069 100644
--- a/test/test_sort_and_select.py
+++ b/test/test_sort_and_select.py
@@ -121,14 +121,7 @@
# FIXME: remove torch.bool from unsupported types once support is added for cub sort
@dtypes(*set(torch.testing.get_all_dtypes()) - {torch.bool, torch.bfloat16, torch.complex64, torch.complex128})
def test_stable_sort(self, device, dtype):
- if self.device_type == 'cpu':
- sizes = (100, 1000, 10000)
- elif self.device_type == 'cuda':
- # On CUDA, stable sort is supported only when the size of
- # the sorted dim is greater than 2048
- sizes = (1025, 10000)
- else:
- return
+ sizes = (100, 1000, 10000)
for ncopies in sizes:
x = torch.tensor([0, 1] * ncopies, dtype=dtype, device=device)
_, idx = x.sort(stable=True)