Sort: Use cub::WarpMergeSort for small sorts (32 < n <= 128) (#96223)
We currently use `bitonicSortKVInPlace` for sorts of size `n <= 32`
but use `radixSortKVInplace` for `32 < n <= 4096`. Bitonic sort is
also unstable, which forces stable sorts to fall back to
`radixSortKVInplace`, which is up to 4x slower in this small regime.
This PR adds a new kernel, `warpMergeSortKVInPlace`, which uses
`cub::WarpMergeSort` to implement sorts with `32 < n <= 128` and all
stable sorts with `n <= 128`. This results in up to a 2x speedup for
unstable sorts and up to a 15x speedup for stable sorts, depending on
the input geometry.
This also doesn't increase the total number of compiled kernels, since
the warp merge sort replaces the radix-sort instantiations of sizes 32 and 128.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96223
Approved by: https://github.com/ngimel
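For context, here is a minimal standalone sketch (not part of this PR's diff; the kernel name, segment layout, and launch parameters are illustrative assumptions) of the pattern the new kernel builds on: each warp sorts one fixed-size segment in registers via `cub::WarpMergeSort`, with one `TempStorage` slot per warp in the block.

```cuda
#include <cub/warp/warp_merge_sort.cuh>

struct LessThan {
  __device__ bool operator()(int a, int b) const { return a < b; }
};

// One warp sorts one SORT_SIZE-element segment of `keys` in place.
template <int SORT_SIZE, int WARPS_PER_BLOCK>
__global__ void sortSegmentsPerWarp(int *keys, int num_segments) {
  constexpr int WARP_THREADS = 32;
  constexpr int ITEMS_PER_THREAD = SORT_SIZE / WARP_THREADS;
  static_assert(SORT_SIZE % WARP_THREADS == 0,
                "SORT_SIZE must be a multiple of the warp size");

  using WarpSort = cub::WarpMergeSort<int, ITEMS_PER_THREAD, WARP_THREADS>;
  // One temp-storage slot per warp, mirroring tmp_storage[max_block_dim_y].
  __shared__ typename WarpSort::TempStorage tmp[WARPS_PER_BLOCK];

  const int warp_id = threadIdx.y;
  const int segment = blockIdx.x * WARPS_PER_BLOCK + warp_id;
  if (segment >= num_segments) {
    return;  // this warp has no segment to sort
  }

  // Blocked arrangement: each lane owns ITEMS_PER_THREAD consecutive keys.
  int local[ITEMS_PER_THREAD];
  int *seg = keys + segment * SORT_SIZE;
  for (int i = 0; i < ITEMS_PER_THREAD; ++i) {
    local[i] = seg[threadIdx.x * ITEMS_PER_THREAD + i];
  }

  // Merge-sort the warp's SORT_SIZE keys held in registers.
  WarpSort(tmp[warp_id]).Sort(local, LessThan{});

  for (int i = 0; i < ITEMS_PER_THREAD; ++i) {
    seg[threadIdx.x * ITEMS_PER_THREAD + i] = local[i];
  }
}

// Example launch: 128-key segments, 8 warps (one segment each) per block.
// sortSegmentsPerWarp<128, 8>
//     <<<(num_segments + 7) / 8, dim3(32, 8)>>>(d_keys, num_segments);
```

The actual kernel below additionally uses `cub::WarpLoad`/`cub::WarpStore` with a transposed arrangement, strided accessors for non-contiguous slices, key-value pairs, and an `invalid_key` padding value for out-of-bounds items.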
diff --git a/aten/src/ATen/cuda/DeviceUtils.cuh b/aten/src/ATen/cuda/DeviceUtils.cuh
index dc17aa8..c0a2fc4 100644
--- a/aten/src/ATen/cuda/DeviceUtils.cuh
+++ b/aten/src/ATen/cuda/DeviceUtils.cuh
@@ -14,6 +14,12 @@
#endif
}
+__device__ __forceinline__ void WARP_SYNC(unsigned mask = 0xffffffff) {
+#if !defined(USE_ROCM)
+ return __syncwarp(mask);
+#endif
+}
+
#if defined(USE_ROCM)
__device__ __forceinline__ unsigned long long int WARP_BALLOT(int predicate)
{
diff --git a/aten/src/ATen/native/cuda/Sort.cu b/aten/src/ATen/native/cuda/Sort.cu
index cb66b65..30a6149 100644
--- a/aten/src/ATen/native/cuda/Sort.cu
+++ b/aten/src/ATen/native/cuda/Sort.cu
@@ -7,6 +7,7 @@
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/KernelUtils.h>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
+#include <ATen/cuda/NumericLimits.cuh>
#include <ATen/native/cuda/SortUtils.cuh>
#include <ATen/native/cuda/SortingCommon.cuh>
@@ -28,12 +29,21 @@
return minGridSize;
}
-// For very small sorts, use bitonicSortKVInPlace which performs
-// better because it can sort multiple arrays within the same block of
-// threads, improving occupancy.
-//
-// TODO: cub in CUDA 11.6 has a WarpMergeSort primitive that could
-// replace the bitonic sort here.
+template <typename T>
+constexpr bool has_nan() {
+ if constexpr (std::numeric_limits<T>::is_specialized) {
+ return std::numeric_limits<T>::has_quiet_NaN;
+ } else if constexpr (
+ c10::is_complex<T>::value ||
+ std::is_same_v<T, c10::BFloat16> ||
+ std::is_same_v<T, c10::Half>) {
+ return true;
+ }
+}
+
+// For very small unstable sorts (n <= 32), use bitonicSortKVInPlace
+// which can sort multiple arrays within the same block of threads,
+// improving occupancy.
struct SmallBitonicSort {
template <int A, typename K, typename V, typename IndexType>
void sort(
@@ -94,8 +104,79 @@
}
};
-// For medium sizes (32 < n <= 4096) use radixSortKVInplace for better
-// performance than the bitonic sort kernel.
+// For small sorts (n <= 128) we use warpMergeSortKVInPlace which
+// sorts one slice per warp and potentially multiple slices in the
+// same block for improved occupancy with large batch sizes.
+template <int sort_size>
+struct WarpMergeSort {
+
+ template <int A, typename K, typename V, typename IndexType>
+ void sort(
+ at::cuda::detail::TensorInfo<K, IndexType> keyInfo,
+ IndexType keySlices,
+ IndexType keySliceSize,
+ IndexType keySliceStride,
+ at::cuda::detail::TensorInfo<V, IndexType> valueInfo,
+ IndexType valueSliceStride,
+ bool descending) {
+ constexpr int max_block_y = 16;
+ const int block_x = at::cuda::warp_size();
+
+ TORCH_INTERNAL_ASSERT(keySliceSize <= sort_size);
+
+ // Scale batch size down if the grid would be too small
+ const auto min_grid = minimum_grid_for_occupancy(
+ warpMergeSortKVInPlace<
+ A, -1, sort_size, max_block_y,
+ K, V, LTOp<K, true>, IndexType>,
+ block_x * max_block_y);
+ const auto max_batch = std::max(IndexType{1}, keySlices / min_grid);
+ const int block_y = std::min(IndexType(max_block_y), max_batch);
+ dim3 block(block_x, block_y);
+
+ dim3 grid;
+ const int grid_count = (keySlices + block_y - 1) / block_y;
+ TORCH_INTERNAL_ASSERT(getGridFromTiles(grid_count, grid),
+ "Too many slices to sort");
+ const auto stream = at::cuda::getCurrentCUDAStream();
+
+ if (descending) {
+ const K invalid_key = at::numeric_limits<K>::lower_bound();
+ warpMergeSortKVInPlace<A, -1, sort_size, max_block_y>
+ <<<grid, block, 0, stream>>>(
+ keyInfo,
+ keySlices,
+ keySliceSize,
+ keySliceStride,
+ valueInfo,
+ valueSliceStride,
+ GTOp<K, true>(),
+ invalid_key);
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
+ } else {
+ const K invalid_key = []{
+ // NAN is sorted after inf
+ if constexpr(has_nan<K>()) {
+ return K(NAN);
+ }
+ return at::numeric_limits<K>::upper_bound();
+ }();
+ warpMergeSortKVInPlace<A, -1, sort_size, max_block_y>
+ <<<grid, block, 0, stream>>>(
+ keyInfo,
+ keySlices,
+ keySliceSize,
+ keySliceStride,
+ valueInfo,
+ valueSliceStride,
+ LTOp<K, true>(),
+ invalid_key);
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
+ }
+ }
+};
+
+// For medium sizes (128 < n <= 4096) use radixSortKVInplace.
struct MediumRadixSort {
template <int A, typename K, typename V, typename IndexType>
@@ -134,14 +215,13 @@
break;
case 128:
case 64:
- HANDLE_CASE(128, 4);
- break;
case 32:
case 16:
case 8:
case 4:
case 2:
- HANDLE_CASE(32, 2);
+ TORCH_INTERNAL_ASSERT(
+ false, "Expected size <= 128 to be handled by a different algorithm");
break;
case 1:
/* Nothing to do, data already sorted */
@@ -272,9 +352,14 @@
int dim,
bool descending,
bool stable) {
- if (!stable && key.size(dim) <= 32) {
+ const auto sort_size = key.size(dim);
+ if (sort_size <= 1) {
+ return; // Already sorted
+ } else if (!stable && sort_size <= 32) {
// NOTE: Bitonic sort is unstable
sortCommon(SmallBitonicSort{}, key, value, dim, descending);
+ } else if (sort_size <= 128) {
+ sortCommon(WarpMergeSort<128>{}, key, value, dim, descending);
} else {
sortCommon(MediumRadixSort{}, key, value, dim, descending);
}
diff --git a/aten/src/ATen/native/cuda/SortUtils.cuh b/aten/src/ATen/native/cuda/SortUtils.cuh
index a1d309c..172a260 100644
--- a/aten/src/ATen/native/cuda/SortUtils.cuh
+++ b/aten/src/ATen/native/cuda/SortUtils.cuh
@@ -5,6 +5,7 @@
#include <ATen/cuda/cub.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>
#include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/DeviceUtils.cuh>
#include <ATen/native/cuda/SortingCommon.cuh>
#include <ATen/native/cuda/Sort.h>
#include <ATen/native/StridedRandomAccessor.h>
@@ -153,6 +154,89 @@
}
}
+template <int KeyDims, int ValueDims, int sort_size, int max_block_dim_y,
+ typename K, typename V, typename Comparator, typename IndexType>
+C10_LAUNCH_BOUNDS_1(C10_WARP_SIZE * max_block_dim_y)
+__global__ void
+warpMergeSortKVInPlace(
+ at::cuda::detail::TensorInfo<K, IndexType> keys,
+ IndexType keySlices,
+ IndexType keySliceSize,
+ IndexType keySliceStride,
+ at::cuda::detail::TensorInfo<V, IndexType> values,
+ IndexType valueSliceStride,
+ Comparator comp,
+ K invalid_key) {
+ // Find the slice of the tensor that we are sorting
+  // NOTE: blockDim.y may be less than max_block_dim_y
+ const IndexType blockIndex = getLinearBlockId<IndexType>();
+ const IndexType linearIndex = blockIndex * blockDim.y + threadIdx.y;
+
+ // If this row is out of bounds exit early
+ if (linearIndex >= keySlices) {
+ return;
+ }
+
+ const IndexType keyStartOffset =
+ at::cuda::detail::IndexToOffset<K, IndexType, KeyDims>::get(linearIndex, keys);
+ const IndexType valueStartOffset =
+ at::cuda::detail::IndexToOffset<V, IndexType, ValueDims>::get(linearIndex, values);
+
+ K *keys_slice = &keys.data[keyStartOffset];
+ V *values_slice = &values.data[valueStartOffset];
+
+ StridedRandomAccessor<K, IndexType> keys_iter(keys_slice, keySliceStride);
+ StridedRandomAccessor<V, IndexType> values_iter(values_slice, valueSliceStride);
+
+ namespace cub = ROCM_HIPCUB(at_cuda_detail::cub);
+
+ assert(blockDim.x == C10_WARP_SIZE);
+ assert(blockDim.y <= max_block_dim_y);
+ constexpr int items_per_thread = sort_size / C10_WARP_SIZE;
+ static_assert(
+ items_per_thread * C10_WARP_SIZE == sort_size,
+ "sort_size must be a multiple of C10_WARP_SIZE");
+
+
+ using LoadKeys = cub::WarpLoad<K, items_per_thread, cub::WARP_LOAD_TRANSPOSE>;
+ using LoadValues = cub::WarpLoad<V, items_per_thread, cub::WARP_LOAD_TRANSPOSE>;
+ using Sort = cub::WarpMergeSort<K, items_per_thread, C10_WARP_SIZE, V>;
+ using StoreKeys = cub::WarpStore<K, items_per_thread, cub::WARP_STORE_TRANSPOSE>;
+ using StoreValues = cub::WarpStore<V, items_per_thread, cub::WARP_STORE_TRANSPOSE>;
+
+ __shared__ union {
+ typename LoadKeys::TempStorage load_keys;
+ typename LoadValues::TempStorage load_values;
+ typename Sort::TempStorage sort;
+ typename StoreKeys::TempStorage store_keys;
+ typename StoreValues::TempStorage store_values;
+ } tmp_storage[max_block_dim_y];
+
+ auto& warp_storage = tmp_storage[threadIdx.y];
+
+ // Load inputs
+ K local_keys[items_per_thread];
+ V local_values[items_per_thread];
+
+ const auto invalid_value = V{};
+ LoadKeys(warp_storage.load_keys).Load(keys_iter, local_keys, keySliceSize, invalid_key);
+ WARP_SYNC();
+ LoadValues(warp_storage.load_values).Load(values_iter, local_values, keySliceSize, invalid_value);
+ WARP_SYNC();
+
+ // Sort! We use stable sort to ensure that invalid values are never
+ // sorted before valid values. In testing it performed the same as
+ // .Sort, so there is no down-side.
+ Sort(warp_storage.sort).StableSort(
+ local_keys, local_values, comp, keySliceSize, invalid_key);
+ WARP_SYNC();
+
+ // Store outputs
+ StoreKeys(warp_storage.store_keys).Store(keys_iter, local_keys, keySliceSize);
+ WARP_SYNC();
+ StoreValues(warp_storage.store_values).Store(values_iter, local_values, keySliceSize);
+}
+
template <int KeyDims, int ValueDims,
int block_size, int items_per_thread,
typename K, typename V, typename IndexType>