Revert "use scatter_add for index_add when dim is the most inner dim (#88729)"
This reverts commit 13dbad63696f0ad39d63e4457eeebf800fb80dff.
Reverted https://github.com/pytorch/pytorch/pull/88729 on behalf of https://github.com/desertfire due to causing an inductor test failure
diff --git a/aten/src/ATen/native/TensorAdvancedIndexing.cpp b/aten/src/ATen/native/TensorAdvancedIndexing.cpp
index 7663475..330558b 100644
--- a/aten/src/ATen/native/TensorAdvancedIndexing.cpp
+++ b/aten/src/ATen/native/TensorAdvancedIndexing.cpp
@@ -859,63 +859,6 @@
if (numel == 0) {
return;
}
-
- // When the slice of source or result is noncontiguous,
- // the original index_add is slow because it calls add on each sliced tensor:
- // it is serial over the index and parallel within each slice to avoid write conflicts.
- // Parallelizing within a slice is suboptimal because the slice may be too small
- // to parallelize well, and it launches one parallel region per index entry.
- // scatter_add is used to speed up this case: it parallelizes over the outer
- // dimensions of the input and is serial along the inner dimension to avoid
- // write conflicts, so it needs only one parallel region, over the larger
- // outer extent.
-
- // TODO: When https://github.com/pytorch/pytorch/pull/82703 lands,
- // using scatter_add will also give an obvious speedup for the dim == 0 case.
- if ((result.stride(dim) == 1 || source.stride(dim) == 1) &&
- // Data type of index should be long and alpha should be 1 to use scatter_add.
- alpha.equal(1.0) && index_contig.scalar_type() == ScalarType::Long &&
- result.numel() > at::internal::GRAIN_SIZE &&
- // scatter_add does not support ComplexHalf
- source.scalar_type() != ScalarType::ComplexHalf &&
- result.scalar_type() != ScalarType::ComplexHalf) {
- std::vector<int64_t> ep_sizes(result.sizes().size());
- std::vector<int64_t> ep_strides(source.sizes().size());
-
- // Check whether result and source have matching sizes in every dimension other than dim.
- // Note the broadcast case:
- // source.select(dim, i) may broadcast into result.select(dim, index_data[i]);
- // that broadcast case cannot be expressed with scatter_add.
- auto check_sizes = [&ep_sizes, &ep_strides, &numel](IntArrayRef a, IntArrayRef b, int64_t dim) -> bool {
- if (a.size() != b.size()) {
- return false;
- }
-
- ep_sizes[dim] = numel;
- ep_strides[dim] = 1;
- for (const auto i : c10::irange(a.size())) {
- if (i == dim) {
- continue;
- }
-
- if (a[i] != b[i]) {
- return false;
- }
- ep_sizes[i] = a[i];
- ep_strides[i] = 0;
-
- }
- return true;
- };
-
- if (check_sizes(result.sizes(), source.sizes(), dim)) {
- auto ep_index = index_contig.as_strided(ep_sizes, ep_strides);
- result.scatter_add_(dim, ep_index, source);
- return;
- }
-
- }
-
auto selfSlice = result.select(dim, 0);
auto sourceSlice = source.select(dim, 0);
auto self_stride_bytes = result.stride(dim) * elementSize(result.scalar_type());
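For reference, the fast path removed above rewrites index_add as a single scatter_add by viewing the 1-D index as a tensor of source's shape with stride 0 in every dimension except dim. A minimal Python sketch of that equivalence (shapes and variable names below are illustrative, mirroring the removed C++ rather than copied from it):

import torch

dim = -1                                    # innermost dimension, as required by the fast path
result = torch.zeros(4, 8, 16)
source = torch.rand(4, 8, 16)
index = torch.randint(0, 16, (16,), dtype=torch.long)

# Reference: plain index_add along the innermost dimension (alpha defaults to 1).
ref = result.index_add(dim, index, source)

# The removed fast path, mirrored in Python: expand the 1-D index to source's
# shape with zero strides everywhere except `dim`, then do one scatter_add_.
ep_sizes = list(source.size())
ep_strides = [0] * source.dim()
ep_strides[dim] = 1
ep_index = index.as_strided(ep_sizes, ep_strides)
out = result.clone().scatter_add_(dim, ep_index, source)

# The two paths may differ only by floating-point summation order.
torch.testing.assert_close(out, ref)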
diff --git a/test/test_torch.py b/test/test_torch.py
index aae0a72..6f54c51 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -5679,54 +5679,6 @@
added = zeros.index_add(0, torch.arange(0, size[0], dtype=idx_dtype, device=device), tensor, alpha=-1)
self.assertEqual(added, -tensor)
- def test_index_add_correctness(self):
- # Check that index_add produces the correct result when alpha is 1
- # and the index dtype is torch.long, i.e., when the scatter_add
- # path is used
- def helper(dim, dtype, device, size_result, size_source):
- tensor = torch.zeros(size_result, dtype=dtype, device=device)
- index = torch.randint(0, size_result[dim], (size_source[dim],),
- dtype=torch.long, device=device)
- if dtype.is_floating_point or dtype.is_complex:
- source = torch.rand(size_source, dtype=dtype, device=device)
- elif dtype.is_signed:
- source = torch.randint(-2, 5, size_source, dtype=dtype, device=device)
- else:
- source = torch.randint(0, 5, size_source, dtype=dtype, device=device)
-
- ref_out = tensor.index_add(dim, index, source, alpha=2.) / 2.
- ref_out = ref_out.to(dtype=dtype)
- out = tensor.index_add(dim, index, source)
- if device == 'cuda':
- self.assertEqual(out, ref_out, atol=1e-2, rtol=1e-2)
- else:
- self.assertEqual(out, ref_out.to(dtype=dtype))
-
- for dim in [-1, -2, -3]:
- for dtype in all_types_and_complex_and(torch.half, torch.bfloat16):
- for device in get_all_device_types():
- for size in [(2, 512, 256), (5, 256, 256)]:
- helper(dim, dtype, device, size, size)
-
- # Check broadcast cases on CPU
- size_result = (2, 512, 256)
- size_source = (1, 512, 256)
- helper(dim, dtype, 'cpu', size_result, size_source)
- size_result = (2, 512, 512)
- size_source = (1, 512, 1)
- helper(dim, dtype, 'cpu', size_result, size_source)
- size_result = (2, 512, 256)
- size_source = (2, 1, 256)
- helper(dim, dtype, 'cpu', size_result, size_source)
-
- # Check invalid index length and out-of-range index values
- result = torch.zeros(1, 512, 256, dtype=dtype)
- source = torch.ones(1, 512, 256, dtype=dtype)
- index = torch.ones(257).to(dtype=torch.long)
- self.assertRaises(RuntimeError, lambda: result.index_add_(dim, index, source))
- index = (torch.ones(256) * 257).to(dtype=torch.long)
- self.assertRaises(RuntimeError, lambda: result.index_add_(dim, index, source))
-
# FIXME: move to shape ops test suite
def test_unflatten(self):
# test args: tensor, int, sizes
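The broadcast cases the removed test exercised on CPU are exactly the ones the check_sizes() guard above excludes from the fast path: when source has size 1 in a dimension other than dim, the slice-wise add path broadcasts each source slice into the matching result slice, whereas scatter_add_ with the expanded index only ever writes the slices covered by the index's own shape. A small sketch of that mismatch (shapes here are illustrative, not the sizes from the removed test):

import torch

dim = -1
result = torch.zeros(2, 3, 4)
source = torch.ones(1, 3, 4)        # size 1 (not 2) along dimension 0
index = torch.arange(4)             # int64, one entry per source column

# scatter_add_ with the zero-stride expanded index only writes result[0];
# broadcasting source.select(dim, i) into both rows of result.select(dim, index[i])
# cannot be expressed this way, so the removed fast path bailed out on such shapes.
ep_index = index.as_strided(list(source.size()), [0, 0, 1])
out = result.clone().scatter_add_(dim, ep_index, source)
print(out[0].sum().item(), out[1].sum().item())   # 12.0 0.0 -- second row untouched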