Revert "[pytorch] Accelerate indexing_backward_kernel with duplicates (#99441)"
This reverts commit 97afbcbc8007857a51c85e9c61fe6d80564ef1f9.
Reverted https://github.com/pytorch/pytorch/pull/99441 on behalf of https://github.com/ngimel due to breaks ROCM ([comment](https://github.com/pytorch/pytorch/pull/99441#issuecomment-1531804487))
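For context on what is being reverted: #99441 added two specialized backward kernels for index_put_. Both detect each run of duplicate values in the sorted index list up front, then let a single thread column (small slice sizes) or a single warp (slice size 1) reduce the whole run, accumulating in at::opmath_type<scalar_t> (fp32 for fp16/bf16) and writing back once, with no atomics. A minimal sketch of the small-stride idea follows; it is a simplification (it drops the outer-dimension loop and the indirection through the original, unsorted index positions), and run_reduce_kernel is a hypothetical name, not code from the PR.

#include <cstdint>
#include <ATen/OpMathType.h>

// Hypothetical sketch, not the PR's code: one thread column handles one run
// of equal values in a *sorted* index array, so no atomics are needed.
template <typename scalar_t>
__global__ void run_reduce_kernel(
    const int64_t* sorted_indices, const scalar_t* grad_output,
    scalar_t* grad_weight, int64_t numel, int64_t stride) {
  using opmath_t = at::opmath_type<scalar_t>;
  const int64_t idx = blockIdx.x * blockDim.y + threadIdx.y;  // which sorted index
  const int64_t col = threadIdx.x;                            // position within the slice
  if (idx >= numel || col >= stride) return;
  // Only the first element of each duplicate run does any work.
  if (idx > 0 && sorted_indices[idx] == sorted_indices[idx - 1]) return;
  // Count the duplicates that follow this element.
  int64_t dups = 1;
  while (idx + dups < numel && sorted_indices[idx + dups] == sorted_indices[idx])
    ++dups;
  // Reduce the whole run in opmath_t (fp32 for fp16/bf16), write back once.
  const int64_t row = sorted_indices[idx] * stride + col;
  opmath_t acc = static_cast<opmath_t>(grad_weight[row]);
  for (int64_t d = 0; d < dups; ++d)
    acc += static_cast<opmath_t>(grad_output[(idx + d) * stride + col]);
  grad_weight[row] = static_cast<scalar_t>(acc);
}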
diff --git a/aten/src/ATen/native/cuda/Indexing.cu b/aten/src/ATen/native/cuda/Indexing.cu
index 7fd3105..4940d1d 100644
--- a/aten/src/ATen/native/cuda/Indexing.cu
+++ b/aten/src/ATen/native/cuda/Indexing.cu
@@ -15,7 +15,6 @@
#include <ATen/native/Resize.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/CUDAUtils.h>
-#include <ATen/cuda/DeviceUtils.cuh>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
@@ -122,106 +121,6 @@
}
}
-template <typename scalar_t>
-__global__ void indexing_backward_kernel_stride_1(
- const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
- int64_t numel, int64_t stride, int64_t stride_before, int64_t outer_dim, bool accumulate) {
- using opmath_t = at::opmath_type<scalar_t>;
-
- // Number of values processed by each thread (grain size)
- for (int64_t z = blockIdx.z; z < outer_dim; z += gridDim.z){
- int64_t idx = blockIdx.x * blockDim.y + threadIdx.y;
- int64_t crnt_sorted_idx = sorted_indices[idx];
-
- if ((idx < numel) &&
- (idx == 0 || crnt_sorted_idx != sorted_indices[idx - 1]))
- {
- // Determine the number of duplicates in advance
- int64_t num_duplicates = 1;
- while (((idx + num_duplicates) < numel) && (sorted_indices[idx + num_duplicates] == crnt_sorted_idx)) {
- num_duplicates++;
- }
-
- // Continue computing weights
- const int64_t weight_row = crnt_sorted_idx * stride + z * stride_before;
- int64_t grad_row = 0;
- const opmath_t scale = (opmath_t)1.0;
-
- if (!accumulate) {
- grad_row = ((int64_t)indices[idx + num_duplicates - 1]) * stride + z * numel * stride;
- grad_weight[weight_row] =
- static_cast<scalar_t>(static_cast<opmath_t>(grad_output[grad_row]) * scale);
- } else {
- opmath_t gradient = (opmath_t)0.0;
-
- int laneIdx = threadIdx.x & 0x1f;
- int64_t num_warp_passes = num_duplicates / C10_WARP_SIZE;
- for (int64_t i = 0; i < num_warp_passes; ++i) {
- grad_row = ((int64_t) indices[idx + i * C10_WARP_SIZE + laneIdx]) * stride + z * numel * stride;
- gradient += static_cast<opmath_t>(grad_output[grad_row]) * scale;
- }
- WARP_SYNC();
- for (int offset = 16; offset > 0; offset /= 2) {
- gradient += WARP_SHFL_DOWN(gradient, offset);
- }
-
- if (laneIdx == 0) {
- for (int64_t i = num_warp_passes * C10_WARP_SIZE; i < num_duplicates; ++i) {
- grad_row = ((int64_t) indices[idx + i]) * stride + z * numel * stride;
- gradient += static_cast<opmath_t>(grad_output[grad_row]) * scale;
- }
-
- grad_weight[weight_row] = static_cast<scalar_t>(static_cast<opmath_t>(grad_weight[weight_row]) + gradient);
- }
- }
- }
- }
-}
-
-template <typename scalar_t>
-__global__ void indexing_backward_kernel_small_stride(
- const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
- int64_t numel, int64_t stride, int64_t stride_before, int64_t outer_dim, bool accumulate) {
- using opmath_t = at::opmath_type<scalar_t>;
-
- // Number of values processed by each thread (grain size)
- for (int64_t z = blockIdx.z; z < outer_dim; z += gridDim.z){
- int64_t idx = blockIdx.x * blockDim.y + threadIdx.y;
- int64_t tidx = threadIdx.x;
- int64_t crnt_sorted_idx = sorted_indices[idx];
-
- if ((idx < numel) &&
- (tidx < stride) &&
- (idx == 0 || crnt_sorted_idx != sorted_indices[idx - 1]))
- {
- // Determine the number of duplicates in advance
- int64_t num_duplicates = 1;
- while (((idx + num_duplicates) < numel) && (sorted_indices[idx + num_duplicates] == crnt_sorted_idx)) {
- num_duplicates++;
- }
-
- // Continue computing weights
- const int64_t weight_row = crnt_sorted_idx * stride + z * stride_before;
- int64_t grad_row = 0;
- const opmath_t scale = (opmath_t)1.0;
-
- if (!accumulate) {
- grad_row = ((int64_t)indices[idx + num_duplicates - 1]) * stride + z * numel * stride;
- grad_weight[weight_row + tidx] =
- static_cast<scalar_t>(static_cast<opmath_t>(grad_output[grad_row + tidx]) * scale);
- } else {
- opmath_t gradient = (opmath_t)0.0;
- for (int64_t i = 0; i < num_duplicates; ++i) {
- grad_row = ((int64_t) indices[idx + i]) * stride + z * numel * stride;
- gradient += static_cast<opmath_t>(grad_output[grad_row + tidx]) * scale;
- }
-
- grad_weight[weight_row + tidx] = static_cast<scalar_t>(static_cast<opmath_t>(grad_weight[weight_row + tidx]) + gradient);
- }
- }
- }
-}
-
template <typename scalar_t, int SZ>
__global__ void indexing_backward_kernel_quantized(
const int64_t* sorted_indices, const int64_t* indices, const float* grad_output, scalar_t* grad_weight,
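The stride-1 variant removed above differs from the small-stride one mainly in how a long duplicate run is reduced: the lanes of one warp stride through the run together, the partial sums are folded with WARP_SHFL_DOWN (supplied by the ATen/cuda/DeviceUtils.cuh include that is also dropped), and lane 0 adds the tail and performs the single write-back. Note the hard-coded 32-lane assumptions (the `& 0x1f` lane mask and the 16..1 shuffle offsets) next to the C10_WARP_SIZE-based pass count; on ROCm builds the warp (wavefront) is typically 64 lanes wide, which may be related to the breakage cited in the revert reason, though the message only says it breaks ROCm. Below is a hedged, stand-alone sketch of that reduction step in plain CUDA; warp_reduce_run is a hypothetical name, and it takes the run as a contiguous array rather than gathering through the original indices.

#include <cstdint>

// Hypothetical sketch of the warp-cooperative run reduction; assumes
// 32-lane warps and a contiguous run, unlike the removed kernel.
__device__ float warp_reduce_run(const float* run, int64_t num_duplicates) {
  const int lane = threadIdx.x & 0x1f;
  float acc = 0.f;
  // All 32 lanes stride through the run in warp-sized chunks.
  const int64_t full = (num_duplicates / 32) * 32;
  for (int64_t i = lane; i < full; i += 32)
    acc += run[i];
  __syncwarp();
  // Shuffle-down tree reduction; the warp-wide sum ends up on lane 0.
  for (int offset = 16; offset > 0; offset /= 2)
    acc += __shfl_down_sync(0xffffffffu, acc, offset);
  if (lane == 0) {
    // Lane 0 serially adds the tail that does not fill a whole warp.
    for (int64_t i = full; i < num_duplicates; ++i)
      acc += run[i];
  }
  return acc;  // meaningful only on lane 0
}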
@@ -504,56 +403,20 @@
std::min(std::max<int>(1,nElemBefore), at::cuda::getCurrentDeviceProperties()->maxGridSize[2]));
dim3 block(warp_size, indices_per_block);
- if (sliceSize == 1) {
- // This implementation is faster with high amounts of duplicates but could overflow
- // if FP16 / BF16 is used
- AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(kComplexHalf, kHalf, kBool, kBFloat16,
- expandedValue.scalar_type(), "indexing_backward_kernel_stride_1", [&] {
- indexing_backward_kernel_stride_1<scalar_t><<<grid, block, 0, stream>>>(
- sorted_indices.const_data_ptr<int64_t>(),
- orig_indices.const_data_ptr<int64_t>(),
- expandedValue.const_data_ptr<scalar_t>(),
- src_.mutable_data_ptr<scalar_t>(),
- num_indices,
- sliceSize,
- strideBefore,
- nElemBefore,
- accumulate);
- C10_CUDA_KERNEL_LAUNCH_CHECK();
- });
- } else {
- if (sliceSize <= warp_size) {
- AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(kComplexHalf, kHalf, kBool, kBFloat16,
- expandedValue.scalar_type(), "indexing_backward_kernel_small_stride", [&] {
- indexing_backward_kernel_small_stride<scalar_t><<<grid, block, 0, stream>>>(
- sorted_indices.const_data_ptr<int64_t>(),
- orig_indices.const_data_ptr<int64_t>(),
- expandedValue.const_data_ptr<scalar_t>(),
- src_.mutable_data_ptr<scalar_t>(),
- num_indices,
- sliceSize,
- strideBefore,
- nElemBefore,
- accumulate);
- C10_CUDA_KERNEL_LAUNCH_CHECK();
- });
- } else {
- AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(kComplexHalf, kHalf, kBool, kBFloat16,
- expandedValue.scalar_type(), "indexing_backward", [&] {
- indexing_backward_kernel<scalar_t, UNROLL><<<grid, block, 0, stream>>>(
- sorted_indices.const_data_ptr<int64_t>(),
- orig_indices.const_data_ptr<int64_t>(),
- expandedValue.const_data_ptr<scalar_t>(),
- src_.mutable_data_ptr<scalar_t>(),
- num_indices,
- sliceSize,
- strideBefore,
- nElemBefore,
- accumulate);
- C10_CUDA_KERNEL_LAUNCH_CHECK();
- });
- }
- }
+ AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(kComplexHalf, kHalf, kBool, kBFloat16,
+ expandedValue.scalar_type(), "indexing_backward", [&] {
+ indexing_backward_kernel<scalar_t, UNROLL><<<grid, block, 0, stream>>>(
+ sorted_indices.const_data_ptr<int64_t>(),
+ orig_indices.const_data_ptr<int64_t>(),
+ expandedValue.const_data_ptr<scalar_t>(),
+ src_.mutable_data_ptr<scalar_t>(),
+ num_indices,
+ sliceSize,
+ strideBefore,
+ nElemBefore,
+ accumulate);
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
+ });
if (permuted) {
self.copy_(src_.permute(inversePerm));
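One detail visible in the launch-site context above: gridDim.z is clamped to the device's maxGridSize[2], so the kernels cover any remaining outer slices with a grid-stride loop over blockIdx.z (the `for (z = blockIdx.z; ...; z += gridDim.z)` loops in the removed code follow this pattern, and presumably the generic kernel being re-dispatched does as well). A generic sketch of that pattern, with hypothetical names (outer_dim_kernel, slice):

#include <cstdint>

// Hypothetical, generic grid-stride loop over an "outer" dimension: the
// launch clamps gridDim.z to the device limit, and each block walks every
// z-slice assigned to it, so correctness does not depend on outer_dim
// fitting in the grid.
__global__ void outer_dim_kernel(float* data, int64_t outer_dim, int64_t slice) {
  for (int64_t z = blockIdx.z; z < outer_dim; z += gridDim.z) {
    const int64_t i = blockIdx.x * (int64_t)blockDim.x + threadIdx.x;
    if (i < slice)
      data[z * slice + i] += 1.0f;  // placeholder per-element work
  }
}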
diff --git a/test/test_scatter_gather_ops.py b/test/test_scatter_gather_ops.py
index 700cd76..5de8980 100644
--- a/test/test_scatter_gather_ops.py
+++ b/test/test_scatter_gather_ops.py
@@ -150,13 +150,7 @@
else:
expected.div_(counts, rounding_mode="floor")
- if dtype == torch.float16 or dtype == torch.bfloat16:
- # Some CUDA kernels (e.g. indexing_backward_kernel_stride_1) that are called during
- # the test use fp32 for internal accumulation for improved accuracy. When using 16 bit
- # precision types there can be small differences
- self.assertEqual(actual, expected, atol=0.04, rtol=0.05)
- else:
- self.assertEqual(actual, expected, atol=0, rtol=0)
+ self.assertEqual(actual, expected, atol=0, rtol=0)
# Tests empty index
dst = make_tensor((2, 2), device=device, dtype=dtype)
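The removed comment and tolerance branch above exist because the reverted kernels accumulate fp16/bf16 gradients in fp32 (opmath_t) and round only once at the write-back, while the reference computation rounds after every step, so the two results can legitimately differ by a few ULPs. With those kernels gone, the exact atol=0/rtol=0 comparison is restored. A tiny host-side illustration of the effect, with made-up values and float/double standing in for fp16/fp32:

#include <cstdio>

int main() {
  const int n = 4096;
  float step_rounded = 0.0f;  // rounds to the narrow type after every add
  double wide_acc = 0.0;      // accumulates in a wider type, rounds once at the end
  for (int i = 0; i < n; ++i) {
    const float v = 1.0f / (i + 1);
    step_rounded += v;
    wide_acc += v;
  }
  // The two results typically differ slightly, so an exact compare would fail.
  std::printf("%.9f vs %.9f\n", step_rounded, static_cast<float>(wide_acc));
  return 0;
}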
diff --git a/test/test_torch.py b/test/test_torch.py
index ab94511..de3f55a 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -1842,17 +1842,11 @@
x = torch.zeros(m, device=device)
res = x.scatter_add(dim, idx, src)
- # Checking if scatter_add is deterministic
- for i in range(5):
- res_next = x.scatter_add(dim, idx, src)
- self.assertEqual(res, res_next, atol=0, rtol=0)
- res = res_next
-
expected = torch.zeros(m, device=device)
for i in range(elems):
expected[idx[i]] += src[i]
- self.assertEqual(res, expected, atol=1e-4, rtol=1e-5)
+ self.assertEqual(res, expected, atol=0, rtol=0)
# FIXME: move to test_scatter_gather_ops
@onlyNativeDeviceTypes
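Similarly, the block removed from test_torch.py above had re-run scatter_add several times to check run-to-run determinism and had loosened the comparison against the sequential Python reference to atol=1e-4/rtol=1e-5: a kernel that reduces duplicates in a different order than the reference loop need not match it bitwise, because floating-point addition is not associative. The revert restores the exact comparison. A minimal host-side demonstration of the non-associativity, with made-up values:

#include <cstdio>

int main() {
  const float a = 1e8f, b = -1e8f, c = 1.0f;
  // (a + b) + c == 1, but a + (b + c) == 0: the order of summation matters,
  // which is why reductions that reorder additions may need a tolerance.
  std::printf("%g vs %g\n", (a + b) + c, a + (b + c));
  return 0;
}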