Sort: Use cub::WarpMergeSort for small sorts (32 < n <= 128) (#96223)

We currently use `bitonicSortKVInplace` for sorts of size `n <= 32`
and `radixSortKVInplace` for `32 < n <= 4096`. Bitonic sort is also
unstable, so stable sorts must fall back to the radix sort, which is
up to 4x slower in this small regime.

This PR adds a new kernel `warpMergeSortKVInPlace` using
`cub::WarpMergeSort` to implement sorts with `32 < n <= 128` and all
stable sorts with `n <= 128`. This results in up to a 2x speedup for
unstable sorts and up to a 15x speedup for stable sorts, depending on
the input geometry.
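
As a condensed, standalone sketch (not the PR's kernel itself, which appears in full in the `SortUtils.cuh` hunk below), the per-warp pattern is: each warp loads one slice into registers, pads the out-of-range tail with a sentinel key, stable-sorts the key/value pairs with `cub::WarpMergeSort`, and writes only the valid prefix back. The `warp_sort_slice` name, the `LessThan` comparator, and the fixed float/int types here are illustrative assumptions, and the sketch assumes the CUB version bundled with CUDA 11.6 or newer:

```cuda
// Standalone sketch of the per-warp load / stable-sort / store pattern,
// assuming one 32-thread warp sorts one contiguous slice of <= 128 elements.
#include <cub/cub.cuh>
#include <cmath>
#include <cstdio>

constexpr int kWarpSize = 32;
constexpr int kSortSize = 128;                          // upper bound of the new regime
constexpr int kItemsPerThread = kSortSize / kWarpSize;  // 4 keys per thread

// Illustrative comparator; the PR uses its own LTOp/GTOp functors.
struct LessThan {
  __device__ bool operator()(float a, float b) const { return a < b; }
};

__global__ void warp_sort_slice(float* keys, int* values, int valid_items) {
  using LoadK  = cub::WarpLoad<float, kItemsPerThread, cub::WARP_LOAD_TRANSPOSE>;
  using LoadV  = cub::WarpLoad<int, kItemsPerThread, cub::WARP_LOAD_TRANSPOSE>;
  using Sort   = cub::WarpMergeSort<float, kItemsPerThread, kWarpSize, int>;
  using StoreK = cub::WarpStore<float, kItemsPerThread, cub::WARP_STORE_TRANSPOSE>;
  using StoreV = cub::WarpStore<int, kItemsPerThread, cub::WARP_STORE_TRANSPOSE>;

  // The stages never run concurrently, so their scratch space can share one
  // union (the real kernel keeps one such union per warp in the block).
  __shared__ union {
    typename LoadK::TempStorage load_k;
    typename LoadV::TempStorage load_v;
    typename Sort::TempStorage sort;
    typename StoreK::TempStorage store_k;
    typename StoreV::TempStorage store_v;
  } tmp;

  float local_keys[kItemsPerThread];
  int local_values[kItemsPerThread];

  // Pad out-of-range slots with +inf so they sort to the back (ascending).
  const float invalid_key = INFINITY;
  LoadK(tmp.load_k).Load(keys, local_keys, valid_items, invalid_key);
  __syncwarp();
  LoadV(tmp.load_v).Load(values, local_values, valid_items, 0);
  __syncwarp();

  // Stable sort keeps equal keys (and the padding) in input order, which is
  // what makes this kernel usable for the stable-sort path as well.
  Sort(tmp.sort).StableSort(local_keys, local_values, LessThan{}, valid_items,
                            invalid_key);
  __syncwarp();

  // Write back only the valid prefix; the padded tail is discarded.
  StoreK(tmp.store_k).Store(keys, local_keys, valid_items);
  __syncwarp();
  StoreV(tmp.store_v).Store(values, local_values, valid_items);
}

int main() {
  const int n = 100;  // a slice length in the 32 < n <= 128 regime
  float* keys;
  int* values;
  cudaMallocManaged(&keys, n * sizeof(float));
  cudaMallocManaged(&values, n * sizeof(int));
  for (int i = 0; i < n; ++i) {
    keys[i] = static_cast<float>((37 * i) % n);  // a permutation of 0..n-1
    values[i] = i;                               // original positions
  }
  warp_sort_slice<<<1, kWarpSize>>>(keys, values, n);
  cudaDeviceSynchronize();
  printf("first key %g, last key %g\n", keys[0], keys[n - 1]);  // 0 and 99
  cudaFree(keys);
  cudaFree(values);
  return 0;
}
```

As the kernel comment in the diff notes, `StableSort` measured the same as `Sort` here, so the stable variant is used unconditionally to keep the sentinel padding behind all valid keys.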

This also doesn't increase the total number of kernels, since the new
kernel replaces the radix sort instantiations for sizes 32 and 128.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96223
Approved by: https://github.com/ngimel
diff --git a/aten/src/ATen/cuda/DeviceUtils.cuh b/aten/src/ATen/cuda/DeviceUtils.cuh
index dc17aa8..c0a2fc4 100644
--- a/aten/src/ATen/cuda/DeviceUtils.cuh
+++ b/aten/src/ATen/cuda/DeviceUtils.cuh
@@ -14,6 +14,12 @@
 #endif
 }
 
+__device__ __forceinline__ void WARP_SYNC(unsigned mask = 0xffffffff) {
+#if !defined(USE_ROCM)
+  return __syncwarp(mask);
+#endif
+}
+
 #if defined(USE_ROCM)
 __device__ __forceinline__ unsigned long long int WARP_BALLOT(int predicate)
 {
diff --git a/aten/src/ATen/native/cuda/Sort.cu b/aten/src/ATen/native/cuda/Sort.cu
index cb66b65..30a6149 100644
--- a/aten/src/ATen/native/cuda/Sort.cu
+++ b/aten/src/ATen/native/cuda/Sort.cu
@@ -7,6 +7,7 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <ATen/cuda/detail/KernelUtils.h>
 #include <ATen/cuda/detail/OffsetCalculator.cuh>
+#include <ATen/cuda/NumericLimits.cuh>
 #include <ATen/native/cuda/SortUtils.cuh>
 #include <ATen/native/cuda/SortingCommon.cuh>
 
@@ -28,12 +29,21 @@
   return minGridSize;
 }
 
-// For very small sorts, use bitonicSortKVInPlace which performs
-// better because it can sort multiple arrays within the same block of
-// threads, improving occupancy.
-//
-// TODO: cub in CUDA 11.6 has a WarpMergeSort primitive that could
-// replace the bitonic sort here.
+template <typename T>
+constexpr bool has_nan() {
+  if constexpr (std::numeric_limits<T>::is_specialized) {
+    return std::numeric_limits<T>::has_quiet_NaN;
+  } else if constexpr (
+      c10::is_complex<T>::value ||
+      std::is_same_v<T, c10::BFloat16> ||
+      std::is_same_v<T, c10::Half>) {
+    return true;
+  }
+}
+
+// For very small unstable sorts (n <= 32), use bitonicSortKVInPlace
+// which can sort multiple arrays within the same block of threads,
+// improving occupancy.
 struct SmallBitonicSort {
   template <int A, typename K, typename V, typename IndexType>
   void sort(
@@ -94,8 +104,79 @@
   }
 };
 
-// For medium sizes (32 < n <= 4096) use radixSortKVInplace for better
-// performance than the bitonic sort kernel.
+// For small sorts (n <= 128) we use warpMergeSortKVInPlace which
+// sorts one slice per warp and potentially multiple slices in the
+// same block for improved occupancy with large batch sizes.
+template <int sort_size>
+struct WarpMergeSort {
+
+  template <int A, typename K, typename V, typename IndexType>
+  void sort(
+      at::cuda::detail::TensorInfo<K, IndexType> keyInfo,
+      IndexType keySlices,
+      IndexType keySliceSize,
+      IndexType keySliceStride,
+      at::cuda::detail::TensorInfo<V, IndexType> valueInfo,
+      IndexType valueSliceStride,
+      bool descending) {
+    constexpr int max_block_y = 16;
+    const int block_x = at::cuda::warp_size();
+
+    TORCH_INTERNAL_ASSERT(keySliceSize <= sort_size);
+
+    // Scale batch size down if the grid would be too small
+    const auto min_grid = minimum_grid_for_occupancy(
+        warpMergeSortKVInPlace<
+            A, -1, sort_size, max_block_y,
+            K, V, LTOp<K, true>, IndexType>,
+        block_x * max_block_y);
+    const auto max_batch = std::max(IndexType{1}, keySlices / min_grid);
+    const int block_y = std::min(IndexType(max_block_y), max_batch);
+    dim3 block(block_x, block_y);
+
+    dim3 grid;
+    const int grid_count = (keySlices + block_y - 1) / block_y;
+    TORCH_INTERNAL_ASSERT(getGridFromTiles(grid_count, grid),
+                          "Too many slices to sort");
+    const auto stream = at::cuda::getCurrentCUDAStream();
+
+    if (descending) {
+      const K invalid_key = at::numeric_limits<K>::lower_bound();
+      warpMergeSortKVInPlace<A, -1, sort_size, max_block_y>
+        <<<grid, block, 0, stream>>>(
+          keyInfo,
+          keySlices,
+          keySliceSize,
+          keySliceStride,
+          valueInfo,
+          valueSliceStride,
+          GTOp<K, true>(),
+          invalid_key);
+      C10_CUDA_KERNEL_LAUNCH_CHECK();
+    } else {
+      const K invalid_key = []{
+        // NAN is sorted after inf
+        if constexpr(has_nan<K>()) {
+          return K(NAN);
+        }
+        return at::numeric_limits<K>::upper_bound();
+      }();
+      warpMergeSortKVInPlace<A, -1, sort_size, max_block_y>
+        <<<grid, block, 0, stream>>>(
+          keyInfo,
+          keySlices,
+          keySliceSize,
+          keySliceStride,
+          valueInfo,
+          valueSliceStride,
+          LTOp<K, true>(),
+          invalid_key);
+      C10_CUDA_KERNEL_LAUNCH_CHECK();
+    }
+  }
+};
+
+// For medium sizes (128 < n <= 4096) use radixSortKVInplace.
 struct MediumRadixSort {
 
   template <int A, typename K, typename V, typename IndexType>
@@ -134,14 +215,13 @@
         break;
       case 128:
       case 64:
-        HANDLE_CASE(128, 4);
-        break;
       case 32:
       case 16:
       case 8:
       case 4:
       case 2:
-        HANDLE_CASE(32, 2);
+        TORCH_INTERNAL_ASSERT(
+            false, "Expected size <= 128 to be handled by a different algorithm");
         break;
       case 1:
         /* Nothing to do, data already sorted */
@@ -272,9 +352,14 @@
     int dim,
     bool descending,
     bool stable) {
-  if (!stable && key.size(dim) <= 32) {
+  const auto sort_size = key.size(dim);
+  if (sort_size <= 1) {
+    return; // Already sorted
+  } else if (!stable && sort_size <= 32) {
     // NOTE: Bitonic sort is unstable
     sortCommon(SmallBitonicSort{}, key, value, dim, descending);
+  } else if (sort_size <= 128) {
+    sortCommon(WarpMergeSort<128>{}, key, value, dim, descending);
   } else {
     sortCommon(MediumRadixSort{}, key, value, dim, descending);
   }
diff --git a/aten/src/ATen/native/cuda/SortUtils.cuh b/aten/src/ATen/native/cuda/SortUtils.cuh
index a1d309c..172a260 100644
--- a/aten/src/ATen/native/cuda/SortUtils.cuh
+++ b/aten/src/ATen/native/cuda/SortUtils.cuh
@@ -5,6 +5,7 @@
 #include <ATen/cuda/cub.cuh>
 #include <ATen/cuda/detail/TensorInfo.cuh>
 #include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/DeviceUtils.cuh>
 #include <ATen/native/cuda/SortingCommon.cuh>
 #include <ATen/native/cuda/Sort.h>
 #include <ATen/native/StridedRandomAccessor.h>
@@ -153,6 +154,89 @@
   }
 }
 
+template <int KeyDims, int ValueDims, int sort_size, int max_block_dim_y,
+          typename K, typename V, typename Comparator, typename IndexType>
+C10_LAUNCH_BOUNDS_1(C10_WARP_SIZE * max_block_dim_y)
+__global__ void
+warpMergeSortKVInPlace(
+    at::cuda::detail::TensorInfo<K, IndexType> keys,
+    IndexType keySlices,
+    IndexType keySliceSize,
+    IndexType keySliceStride,
+    at::cuda::detail::TensorInfo<V, IndexType> values,
+    IndexType valueSliceStride,
+    Comparator comp,
+    K invalid_key) {
+  // Find the slice of the tensor that we are sorting
+  // NOTE: blockDim.y may be less than max_block_dim_y
+  const IndexType blockIndex = getLinearBlockId<IndexType>();
+  const IndexType linearIndex = blockIndex * blockDim.y + threadIdx.y;
+
+  // If this row is out of bounds, exit early
+  if (linearIndex >= keySlices) {
+    return;
+  }
+
+  const IndexType keyStartOffset =
+    at::cuda::detail::IndexToOffset<K, IndexType, KeyDims>::get(linearIndex, keys);
+  const IndexType valueStartOffset =
+    at::cuda::detail::IndexToOffset<V, IndexType, ValueDims>::get(linearIndex, values);
+
+  K *keys_slice = &keys.data[keyStartOffset];
+  V *values_slice = &values.data[valueStartOffset];
+
+  StridedRandomAccessor<K, IndexType> keys_iter(keys_slice, keySliceStride);
+  StridedRandomAccessor<V, IndexType> values_iter(values_slice, valueSliceStride);
+
+  namespace cub = ROCM_HIPCUB(at_cuda_detail::cub);
+
+  assert(blockDim.x == C10_WARP_SIZE);
+  assert(blockDim.y <= max_block_dim_y);
+  constexpr int items_per_thread = sort_size / C10_WARP_SIZE;
+  static_assert(
+      items_per_thread * C10_WARP_SIZE == sort_size,
+      "sort_size must be a multiple of C10_WARP_SIZE");
+
+
+  using LoadKeys = cub::WarpLoad<K, items_per_thread, cub::WARP_LOAD_TRANSPOSE>;
+  using LoadValues = cub::WarpLoad<V, items_per_thread, cub::WARP_LOAD_TRANSPOSE>;
+  using Sort = cub::WarpMergeSort<K, items_per_thread, C10_WARP_SIZE, V>;
+  using StoreKeys = cub::WarpStore<K, items_per_thread, cub::WARP_STORE_TRANSPOSE>;
+  using StoreValues = cub::WarpStore<V, items_per_thread, cub::WARP_STORE_TRANSPOSE>;
+
+  __shared__ union {
+    typename LoadKeys::TempStorage load_keys;
+    typename LoadValues::TempStorage load_values;
+    typename Sort::TempStorage sort;
+    typename StoreKeys::TempStorage store_keys;
+    typename StoreValues::TempStorage store_values;
+  } tmp_storage[max_block_dim_y];
+
+  auto& warp_storage = tmp_storage[threadIdx.y];
+
+  // Load inputs
+  K local_keys[items_per_thread];
+  V local_values[items_per_thread];
+
+  const auto invalid_value = V{};
+  LoadKeys(warp_storage.load_keys).Load(keys_iter, local_keys, keySliceSize, invalid_key);
+  WARP_SYNC();
+  LoadValues(warp_storage.load_values).Load(values_iter, local_values, keySliceSize, invalid_value);
+  WARP_SYNC();
+
+  // Sort! We use stable sort to ensure that invalid values are never
+  // sorted before valid values. In testing it performed the same as
+  // .Sort, so there is no downside.
+  Sort(warp_storage.sort).StableSort(
+      local_keys, local_values, comp, keySliceSize, invalid_key);
+  WARP_SYNC();
+
+  // Store outputs
+  StoreKeys(warp_storage.store_keys).Store(keys_iter, local_keys, keySliceSize);
+  WARP_SYNC();
+  StoreValues(warp_storage.store_values).Store(values_iter, local_values, keySliceSize);
+}
+
 template <int KeyDims, int ValueDims,
           int block_size, int items_per_thread,
           typename K, typename V, typename IndexType>