Migrate index_add cpu from TH to ATen (#28421)
Summary:
Migrate index_add cpu from TH to ATen.
I couldn't find a replacement for get1d and set1d, so I do the equivalent pointer arithmetic in place instead.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/28421
Test Plan: existing tests
Differential Revision: D18060971
Pulled By: ggoossen
fbshipit-source-id: 413719990cdb2fe578964cde14e93577e48a4342
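
For reference, a minimal sketch (not part of the patch) of what the removed TH accessors amount to and what the pointer arithmetic replaces, assuming a 1-D float tensor; the real index_add_cpu_ does the same thing generically inside AT_DISPATCH_ALL_TYPES using data<scalar_t>() and the stride along dim:

    #include <ATen/ATen.h>

    // Roughly what THTensor_(get1d)(t, i) reads: element i of a 1-D
    // tensor, located via the base pointer plus i times the stride.
    float get1d_equivalent(const at::Tensor& t, int64_t i) {
      return *(t.data_ptr<float>() + i * t.stride(0));
    }

    // Roughly what THTensor_(set1d)(t, i, v) writes: the matching
    // strided store through the same pointer arithmetic.
    void set1d_equivalent(at::Tensor& t, int64_t i, float v) {
      *(t.data_ptr<float>() + i * t.stride(0)) = v;
    }

The helper names above are hypothetical and only illustrate the pattern; in the patch the same access shows up inline as self.data<scalar_t>() + self_i * self_stride.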
diff --git a/aten/src/ATen/Declarations.cwrap b/aten/src/ATen/Declarations.cwrap
index 53b2187..a1ee2d8 100644
--- a/aten/src/ATen/Declarations.cwrap
+++ b/aten/src/ATen/Declarations.cwrap
@@ -215,6 +215,8 @@
name: _th_index_add_
cname: indexAdd
variants: function
+ backends:
+ - CUDA
return: argument 0
arguments:
- THTensor* self
diff --git a/aten/src/ATen/native/Indexing.cpp b/aten/src/ATen/native/Indexing.cpp
index 6f986e8..a0f7ecd 100644
--- a/aten/src/ATen/native/Indexing.cpp
+++ b/aten/src/ATen/native/Indexing.cpp
@@ -55,6 +55,7 @@
#include <ATen/NativeFunctions.h>
#include <ATen/ExpandUtils.h>
#include <ATen/native/TensorIterator.h>
+#include <ATen/native/BinaryOps.h>
#include <ATen/core/EnableNamedTensor.h>
#include <algorithm>
@@ -314,6 +315,69 @@
return self.clone(at::MemoryFormat::Preserve).index_copy_(dim, index, source);
}
+
+Tensor& index_add_cpu_(Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) {
+ dim = maybe_wrap_dim(dim, self.dim());
+
+ auto numel = index.numel();
+ TORCH_CHECK_INDEX(index.dim() <= 1, "index_add_(): Index is supposed to be a vector");
+ TORCH_CHECK(index.scalar_type() == ScalarType::Long, "index_add_(): Expected dtype int64 for index");
+ TORCH_CHECK(self.scalar_type() == source.scalar_type(),
+ "index_add_(): self and source must have the same scalar type");
+ TORCH_CHECK(dim == 0 || dim < source.dim(),
+ "index_add_(): Indexing dim ", dim, " is out of bounds of tensor");
+ TORCH_CHECK(numel == (source.dim() == 0 ? 1 : source.size(dim)),
+ "index_add_(): Number of indices should be equal to self.size(dim)");
+
+ auto index_contig = index.contiguous();
+ auto index_data = index_contig.data_ptr<int64_t>();
+
+ if (self.dim() > 1) {
+ // Equivalent to:
+ // for (auto i = 0; i < numel; i++) {
+ // auto selfSlice = self.select(dim, index_data[i]);
+ // auto sourceSlice = source.select(dim, i);
+ // selfSlice.add_(sourceSlice);
+ // }
+ // But much faster as this reuses the iterator from add_
+ if (numel == 0) {
+ return self;
+ }
+ auto selfSlice = self.select(dim, 0);
+ auto sourceSlice = source.select(dim, 0);
+ auto self_stride_bytes = self.stride(dim) * elementSize(self.scalar_type());
+ auto source_stride_bytes = source.stride(dim) * elementSize(source.scalar_type());
+ auto self_dim_size = self.size(dim);
+ auto iter = TensorIterator::binary_op(selfSlice, selfSlice, sourceSlice);
+
+ for (auto i = 0; i < numel; i++) {
+ auto self_i = index_data[i];
+ TORCH_CHECK_INDEX((self_i >= 0) && (self_i < self_dim_size), "index out of range in self");
+ auto self_data = static_cast<char*>(selfSlice.data_ptr()) + self_i * self_stride_bytes;
+ auto source_data = static_cast<char*>(sourceSlice.data_ptr()) + i * source_stride_bytes;
+ iter.unsafe_replace_operand(0, self_data);
+ iter.unsafe_replace_operand(1, self_data);
+ iter.unsafe_replace_operand(2, source_data);
+ add_stub(iter.device_type(), iter, 1);
+ }
+ }
+ else {
+ TORCH_CHECK(source.dim() <= 1, "source.dim() (", source.dim(), ") must be one or zero for given self.dim() (", self.dim(), ")");
+
+ AT_DISPATCH_ALL_TYPES(self.scalar_type(), "index_add_", [&] {
+ auto self_stride = self.dim() == 0 ? 1 : self.stride(dim);
+ auto source_stride = source.dim() == 0 ? 1 : source.stride(dim);
+ for (auto i = 0; i < numel; i++) {
+ auto self_i = index_data[i];
+ TORCH_CHECK_INDEX((self_i >= 0) && (self_i < self.numel()), "index out of range in self");
+ scalar_t *self_ip = self.data<scalar_t>() + self_i * self_stride;
+ *self_ip += *(source.data<scalar_t>() + i * source_stride);
+ }
+ });
+ }
+ return self;
+}
+
Tensor index_add(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) {
return self.clone(at::MemoryFormat::Preserve).index_add_(dim, index, source);
}
diff --git a/aten/src/ATen/native/TensorIterator.cpp b/aten/src/ATen/native/TensorIterator.cpp
index bfc3bf9..212d7eb 100644
--- a/aten/src/ATen/native/TensorIterator.cpp
+++ b/aten/src/ATen/native/TensorIterator.cpp
@@ -630,9 +630,8 @@
operands_.erase(operands_.begin() + arg);
}
-void TensorIterator::replace_operand(int arg, void* data, IntArrayRef stride) {
+void TensorIterator::unsafe_replace_operand(int arg, void* data) {
operands_[arg].data = data;
- operands_[arg].stride_bytes = stride;
}
void TensorIterator::remove_dimension(int dim) {
diff --git a/aten/src/ATen/native/TensorIterator.h b/aten/src/ATen/native/TensorIterator.h
index a50d2c0..da7dc44 100644
--- a/aten/src/ATen/native/TensorIterator.h
+++ b/aten/src/ATen/native/TensorIterator.h
@@ -238,8 +238,10 @@
void narrow(int dim, int64_t start, int64_t size);
/// Narrows every dim after and including `start_dim` to size one.
void select_all_keeping_dim(int start_dim, IntArrayRef starts);
- /// Replaces the data pointer and strides for the operand at index `arg`
- void replace_operand(int arg, void* data, IntArrayRef stride);
+ /// Replaces the data pointer for the operand at index `arg`.
+ /// The new pointer should have the same sizes, strides and dtype as the
+ /// original
+ void unsafe_replace_operand(int arg, void* data);
/// Splits this TensorIterator into two iterators. Together they iterate over
/// the entire operation. Used by `with_32bit_indexing()`.
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 83ee7f0..ba888c7 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -3751,7 +3751,7 @@
- func: index_add_(Tensor(a!) self, int dim, Tensor index, Tensor source) -> Tensor(a!)
variants: method
dispatch:
- CPU: legacy::cpu::_th_index_add_
+ CPU: index_add_cpu_
CUDA: legacy::cuda::_th_index_add_
- func: index_add(Tensor self, int dim, Tensor index, Tensor source) -> Tensor
diff --git a/aten/src/TH/generic/THTensorEvenMoreMath.cpp b/aten/src/TH/generic/THTensorEvenMoreMath.cpp
index f555b5a..9136919 100644
--- a/aten/src/TH/generic/THTensorEvenMoreMath.cpp
+++ b/aten/src/TH/generic/THTensorEvenMoreMath.cpp
@@ -723,47 +723,6 @@
#if !defined(TH_REAL_IS_BOOL)
-void THTensor_(indexAdd)(THTensor *tensor, int dim, THLongTensor *index, THTensor *src)
-{
- ptrdiff_t i, numel;
- THTensor *tSlice, *sSlice;
- int64_t *index_data;
-
- numel = THLongTensor_nElement(index);
- THArgCheck(THTensor_nDimensionLegacyNoScalars(index) == 1, 3, "Index is supposed to be a vector");
- THArgCheck(dim < THTensor_nDimensionLegacyNoScalars(src), 4,"Indexing dim %d is out of bounds of tensor", dim);
- THArgCheck(numel == THTensor_sizeLegacyNoScalars(src, dim),4,"Number of indices should be equal to source:size(dim)");
-
- index = THLongTensor_newContiguous(index);
- index_data = THLongTensor_data(index);
-
- if (tensor->dim() > 1)
- {
- tSlice = THTensor_(new)();
- sSlice = THTensor_(new)();
-
- for (i=0; i<numel; i++)
- {
- THTensor_(select)(tSlice, tensor, dim, index_data[i]);
- THTensor_(select)(sSlice, src, dim, i);
- THTensor_(cadd)(tSlice, tSlice, 1.0, sSlice);
- }
-
- c10::raw::intrusive_ptr::decref(tSlice);
- c10::raw::intrusive_ptr::decref(sSlice);
- }
- else
- {
- for (i=0; i<numel; i++)
- {
- THTensor_(set1d)(tensor,
- index_data[i],
- THTensor_(get1d)(src,i) + THTensor_(get1d)(tensor,index_data[i]));
- }
- }
- THLongTensor_free(index);
-}
-
accreal THTensor_(dot)(THTensor *tensor, THTensor *src)
{
#ifdef BUILD_NAMEDTENSOR
diff --git a/aten/src/TH/generic/THTensorMath.h b/aten/src/TH/generic/THTensorMath.h
index b8313d3..b1054dd 100644
--- a/aten/src/TH/generic/THTensorMath.h
+++ b/aten/src/TH/generic/THTensorMath.h
@@ -107,8 +107,6 @@
#if !defined(TH_REAL_IS_BOOL) /* non bool only part */
-TH_API void THTensor_(indexAdd)(THTensor *tensor, int dim, THLongTensor *index, THTensor *src);
-
TH_API accreal THTensor_(dot)(THTensor *t, THTensor *src);
TH_API void THTensor_(cinv)(THTensor *self, THTensor *src);
diff --git a/test/test_torch.py b/test/test_torch.py
index 8b529b1..752d75e 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -2621,26 +2621,25 @@
reference[0.0, :, 0.0] = 1
def test_index_add(self):
- num_copy, num_dest = 3, 3
- dest = torch.randn(num_dest, 4, 5)
- src = torch.randn(num_copy, 4, 5)
- idx = torch.randperm(num_dest).narrow(0, 0, num_copy)
- dest2 = dest.clone()
- dest.index_add_(0, idx, src)
- for i in range(idx.size(0)):
- dest2[idx[i]] += src[i]
- self.assertEqual(dest, dest2)
+ for dest_contig, src_contig, index_contig in product([True, False], repeat=3):
+ for other_sizes in ((), (4, 5)):
+ num_copy, num_dest = 3, 3
+ dest = torch.randn(num_dest, *other_sizes)
+ if not dest_contig:
+ dest = torch.testing.make_non_contiguous(dest)
+ src = torch.randn(num_copy, *other_sizes)
+ if not src_contig:
+ src = torch.testing.make_non_contiguous(src)
+ idx = torch.randperm(num_dest).narrow(0, 0, num_copy)
+ if not index_contig:
+ idx = torch.testing.make_non_contiguous(idx)
+ dest2 = dest.clone()
+ dest.index_add_(0, idx, src)
+ for i in range(idx.size(0)):
+ dest2[idx[i]] += src[i]
+ self.assertEqual(dest, dest2)
- dest = torch.randn(num_dest)
- src = torch.randn(num_copy)
- idx = torch.randperm(num_dest).narrow(0, 0, num_copy)
- dest2 = dest.clone()
- dest.index_add_(0, idx, src)
- for i in range(idx.size(0)):
- dest2[idx[i]] = dest2[idx[i]] + src[i]
- self.assertEqual(dest, dest2)
-
- # add coverage for issue with atomic add that appeared only for
+ # add coverage for issue with atomic add that appeared only for
# specific dtypes on cuda:
# https://github.com/pytorch/pytorch/issues/29153
def test_index_add_all_dtypes(self):