Migrate index_add cpu from TH to ATen (#28421)

Summary:
Migrate index_add cpu from TH to ATen.

I couldn't find a replacement for get1d and set1d, so the equivalent pointer arithmetic is done in place instead.
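
For illustration, a minimal standalone sketch of that pattern (the helper name `index_add_1d_sketch` is hypothetical; the real code is the 1-d branch of `index_add_cpu_` in the diff below):

    #include <ATen/ATen.h>

    // Illustrative sketch, not the patch itself: element i of a possibly
    // non-contiguous 1-d tensor lives at base + i * stride, so
    // `self[idx[i]] += source[i]` becomes a strided pointer read/write,
    // standing in for THTensor_(get1d)/THTensor_(set1d).
    void index_add_1d_sketch(at::Tensor& self, const at::Tensor& index,
                             const at::Tensor& source) {
      auto index_contig = index.contiguous();
      const int64_t* index_data = index_contig.data_ptr<int64_t>();
      AT_DISPATCH_ALL_TYPES(self.scalar_type(), "index_add_1d_sketch", [&] {
        scalar_t* self_data = self.data<scalar_t>();
        const scalar_t* source_data = source.data<scalar_t>();
        int64_t self_stride = self.dim() == 0 ? 1 : self.stride(0);
        int64_t source_stride = source.dim() == 0 ? 1 : source.stride(0);
        for (int64_t i = 0; i < index.numel(); i++) {
          // get1d(source, i)    -> source_data[i * source_stride]
          // set1d(self, idx, v) -> self_data[idx * self_stride] = v
          self_data[index_data[i] * self_stride] +=
              source_data[i * source_stride];
        }
      });
    }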
Pull Request resolved: https://github.com/pytorch/pytorch/pull/28421

Test Plan: existing tests

Differential Revision: D18060971

Pulled By: ggoossen

fbshipit-source-id: 413719990cdb2fe578964cde14e93577e48a4342
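
For reviewers, a condensed sketch of the n-d fast path (the name `index_add_nd_sketch` is hypothetical; bounds checks are omitted and the code is assumed to live in `at::native` inside Indexing.cpp, like the real `index_add_cpu_` below). Rather than re-deriving an iterator per index, a single TensorIterator for a slice-wise add is built once and only the operands' raw data pointers are swapped via the new unsafe_replace_operand():

    // Sketch only; see index_add_cpu_ in the diff for the checked version.
    Tensor& index_add_nd_sketch(Tensor& self, int64_t dim,
                                const Tensor& index, const Tensor& source) {
      auto index_contig = index.contiguous();
      const int64_t* index_data = index_contig.data_ptr<int64_t>();
      auto selfSlice = self.select(dim, 0);
      auto sourceSlice = source.select(dim, 0);
      auto self_stride_bytes = self.stride(dim) * elementSize(self.scalar_type());
      auto source_stride_bytes = source.stride(dim) * elementSize(source.scalar_type());
      auto iter = TensorIterator::binary_op(selfSlice, selfSlice, sourceSlice);
      for (int64_t i = 0; i < index.numel(); i++) {
        auto self_data = static_cast<char*>(selfSlice.data_ptr()) +
                         index_data[i] * self_stride_bytes;
        auto source_data = static_cast<char*>(sourceSlice.data_ptr()) +
                           i * source_stride_bytes;
        iter.unsafe_replace_operand(0, self_data);    // output slice of self
        iter.unsafe_replace_operand(1, self_data);    // lhs input (same memory)
        iter.unsafe_replace_operand(2, source_data);  // rhs input slice of source
        add_stub(iter.device_type(), iter, 1);        // run the prebuilt add kernel
      }
      return self;
    }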
diff --git a/aten/src/ATen/Declarations.cwrap b/aten/src/ATen/Declarations.cwrap
index 53b2187..a1ee2d8 100644
--- a/aten/src/ATen/Declarations.cwrap
+++ b/aten/src/ATen/Declarations.cwrap
@@ -215,6 +215,8 @@
   name: _th_index_add_
   cname: indexAdd
   variants: function
+  backends:
+    - CUDA
   return: argument 0
   arguments:
     - THTensor* self
diff --git a/aten/src/ATen/native/Indexing.cpp b/aten/src/ATen/native/Indexing.cpp
index 6f986e8..a0f7ecd 100644
--- a/aten/src/ATen/native/Indexing.cpp
+++ b/aten/src/ATen/native/Indexing.cpp
@@ -55,6 +55,7 @@
 #include <ATen/NativeFunctions.h>
 #include <ATen/ExpandUtils.h>
 #include <ATen/native/TensorIterator.h>
+#include <ATen/native/BinaryOps.h>
 #include <ATen/core/EnableNamedTensor.h>
 
 #include <algorithm>
@@ -314,6 +315,69 @@
   return self.clone(at::MemoryFormat::Preserve).index_copy_(dim, index, source);
 }
 
+
+Tensor& index_add_cpu_(Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) {
+  dim = maybe_wrap_dim(dim, self.dim());
+
+  auto numel = index.numel();
+  TORCH_CHECK_INDEX(index.dim() <= 1, "index_add_(): Index is supposed to be a vector");
+  TORCH_CHECK(index.scalar_type() == ScalarType::Long, "index_add_(): Expected dtype int64 for index");
+  TORCH_CHECK(self.scalar_type() == source.scalar_type(),
+              "index_add_(): self and source must have the same scalar type");
+  TORCH_CHECK(dim == 0 || dim < source.dim(),
+              "index_add_(): Indexing dim ", dim, " is out of bounds of tensor");
+  TORCH_CHECK(numel == (source.dim() == 0 ? 1 : source.size(dim)),
+              "index_add_(): Number of indices should be equal to source.size(dim)");
+
+  auto index_contig = index.contiguous();
+  auto index_data = index_contig.data_ptr<int64_t>();
+
+  if (self.dim() > 1) {
+    // Equivalent to:
+    //   for (auto i = 0; i < numel; i++) {
+    //     auto selfSlice = self.select(dim, index_data[i]);
+    //     auto sourceSlice = source.select(dim, i);
+    //     selfSlice.add_(sourceSlice);
+    //   }
+    // But much faster as this reuses the iterator from add_
+    if (numel == 0) {
+      return self;
+    }
+    auto selfSlice = self.select(dim, 0);
+    auto sourceSlice = source.select(dim, 0);
+    auto self_stride_bytes = self.stride(dim) * elementSize(self.scalar_type());
+    auto source_stride_bytes = source.stride(dim) * elementSize(source.scalar_type());
+    auto self_dim_size = self.size(dim);
+    auto iter = TensorIterator::binary_op(selfSlice, selfSlice, sourceSlice);
+
+    for (auto i = 0; i < numel; i++) {
+      auto self_i = index_data[i];
+      TORCH_CHECK_INDEX((self_i >= 0) && (self_i < self_dim_size), "index out of range in self");
+      auto self_data = static_cast<char*>(selfSlice.data_ptr()) + self_i * self_stride_bytes;
+      auto source_data = static_cast<char*>(sourceSlice.data_ptr()) + i * source_stride_bytes;
+      iter.unsafe_replace_operand(0, self_data);
+      iter.unsafe_replace_operand(1, self_data);
+      iter.unsafe_replace_operand(2, source_data);
+      add_stub(iter.device_type(), iter, 1);
+    }
+  }
+  else {
+  TORCH_CHECK(source.dim() <= 1, "source.dim() (", source.dim(), ") must be one or zero for the given self.dim() (", self.dim(), ")");
+
+    AT_DISPATCH_ALL_TYPES(self.scalar_type(), "index_add_", [&] {
+      auto self_stride = self.dim() == 0 ? 1 : self.stride(dim);
+      auto source_stride = source.dim() == 0 ? 1 : source.stride(dim);
+      for (auto i = 0; i < numel; i++) {
+        auto self_i = index_data[i];
+        TORCH_CHECK_INDEX((self_i >= 0) && (self_i < self.numel()), "index out of range in self");
+        scalar_t *self_ip = self.data<scalar_t>() + self_i * self_stride;
+        *self_ip += *(source.data<scalar_t>() + i * source_stride);
+      }
+    });
+  }
+  return self;
+}
+
 Tensor index_add(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) {
   return self.clone(at::MemoryFormat::Preserve).index_add_(dim, index, source);
 }
diff --git a/aten/src/ATen/native/TensorIterator.cpp b/aten/src/ATen/native/TensorIterator.cpp
index bfc3bf9..212d7eb 100644
--- a/aten/src/ATen/native/TensorIterator.cpp
+++ b/aten/src/ATen/native/TensorIterator.cpp
@@ -630,9 +630,8 @@
   operands_.erase(operands_.begin() + arg);
 }
 
-void TensorIterator::replace_operand(int arg, void* data, IntArrayRef stride) {
+void TensorIterator::unsafe_replace_operand(int arg, void* data) {
   operands_[arg].data = data;
-  operands_[arg].stride_bytes = stride;
 }
 
 void TensorIterator::remove_dimension(int dim) {
diff --git a/aten/src/ATen/native/TensorIterator.h b/aten/src/ATen/native/TensorIterator.h
index a50d2c0..da7dc44 100644
--- a/aten/src/ATen/native/TensorIterator.h
+++ b/aten/src/ATen/native/TensorIterator.h
@@ -238,8 +238,10 @@
   void narrow(int dim, int64_t start, int64_t size);
   /// Narrows every dim after and including `start_dim` to size one.
   void select_all_keeping_dim(int start_dim, IntArrayRef starts);
-  /// Replaces the data pointer and strides for the operand at index `arg`
-  void replace_operand(int arg, void* data, IntArrayRef stride);
+  /// Replaces the data pointer for the operand at index `arg`.
+  /// The new pointer must point to memory with the same sizes, strides, and
+  /// dtype as the original.
+  void unsafe_replace_operand(int arg, void* data);
 
   /// Splits this TensorIterator into two iterators. Together they iterate over
   /// the entire operation. Used by `with_32bit_indexing()`.
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 83ee7f0..ba888c7 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -3751,7 +3751,7 @@
 - func: index_add_(Tensor(a!) self, int dim, Tensor index, Tensor source) -> Tensor(a!)
   variants: method
   dispatch:
-    CPU: legacy::cpu::_th_index_add_
+    CPU: index_add_cpu_
     CUDA: legacy::cuda::_th_index_add_
 
 - func: index_add(Tensor self, int dim, Tensor index, Tensor source) -> Tensor
diff --git a/aten/src/TH/generic/THTensorEvenMoreMath.cpp b/aten/src/TH/generic/THTensorEvenMoreMath.cpp
index f555b5a..9136919 100644
--- a/aten/src/TH/generic/THTensorEvenMoreMath.cpp
+++ b/aten/src/TH/generic/THTensorEvenMoreMath.cpp
@@ -723,47 +723,6 @@
 
 #if !defined(TH_REAL_IS_BOOL)
 
-void THTensor_(indexAdd)(THTensor *tensor, int dim, THLongTensor *index, THTensor *src)
-{
-  ptrdiff_t i, numel;
-  THTensor *tSlice, *sSlice;
-  int64_t *index_data;
-
-  numel = THLongTensor_nElement(index);
-  THArgCheck(THTensor_nDimensionLegacyNoScalars(index) == 1, 3, "Index is supposed to be a vector");
-  THArgCheck(dim < THTensor_nDimensionLegacyNoScalars(src), 4,"Indexing dim %d is out of bounds of tensor", dim);
-  THArgCheck(numel == THTensor_sizeLegacyNoScalars(src, dim),4,"Number of indices should be equal to source:size(dim)");
-
-  index = THLongTensor_newContiguous(index);
-  index_data = THLongTensor_data(index);
-
-  if (tensor->dim() > 1)
-  {
-    tSlice = THTensor_(new)();
-    sSlice = THTensor_(new)();
-
-    for (i=0; i<numel; i++)
-    {
-      THTensor_(select)(tSlice, tensor, dim, index_data[i]);
-      THTensor_(select)(sSlice, src, dim, i);
-      THTensor_(cadd)(tSlice, tSlice, 1.0, sSlice);
-    }
-
-    c10::raw::intrusive_ptr::decref(tSlice);
-    c10::raw::intrusive_ptr::decref(sSlice);
-  }
-  else
-  {
-    for (i=0; i<numel; i++)
-    {
-      THTensor_(set1d)(tensor,
-              index_data[i],
-              THTensor_(get1d)(src,i) + THTensor_(get1d)(tensor,index_data[i]));
-    }
-  }
-  THLongTensor_free(index);
-}
-
 accreal THTensor_(dot)(THTensor *tensor, THTensor *src)
 {
 #ifdef BUILD_NAMEDTENSOR
diff --git a/aten/src/TH/generic/THTensorMath.h b/aten/src/TH/generic/THTensorMath.h
index b8313d3..b1054dd 100644
--- a/aten/src/TH/generic/THTensorMath.h
+++ b/aten/src/TH/generic/THTensorMath.h
@@ -107,8 +107,6 @@
 
 #if !defined(TH_REAL_IS_BOOL) /* non bool only part */
 
-TH_API void THTensor_(indexAdd)(THTensor *tensor, int dim, THLongTensor *index, THTensor *src);
-
 TH_API accreal THTensor_(dot)(THTensor *t, THTensor *src);
 
 TH_API void THTensor_(cinv)(THTensor *self, THTensor *src);
diff --git a/test/test_torch.py b/test/test_torch.py
index 8b529b1..752d75e 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -2621,26 +2621,25 @@
             reference[0.0, :, 0.0] = 1
 
     def test_index_add(self):
-        num_copy, num_dest = 3, 3
-        dest = torch.randn(num_dest, 4, 5)
-        src = torch.randn(num_copy, 4, 5)
-        idx = torch.randperm(num_dest).narrow(0, 0, num_copy)
-        dest2 = dest.clone()
-        dest.index_add_(0, idx, src)
-        for i in range(idx.size(0)):
-            dest2[idx[i]] += src[i]
-        self.assertEqual(dest, dest2)
+        for dest_contig, src_contig, index_contig in product([True, False], repeat=3):
+            for other_sizes in ((), (4, 5)):
+                num_copy, num_dest = 3, 3
+                dest = torch.randn(num_dest, *other_sizes)
+                if not dest_contig:
+                    dest = torch.testing.make_non_contiguous(dest)
+                src = torch.randn(num_copy, *other_sizes)
+                if not src_contig:
+                    src = torch.testing.make_non_contiguous(src)
+                idx = torch.randperm(num_dest).narrow(0, 0, num_copy)
+                if not index_contig:
+                    idx = torch.testing.make_non_contiguous(idx)
+                dest2 = dest.clone()
+                dest.index_add_(0, idx, src)
+                for i in range(idx.size(0)):
+                    dest2[idx[i]] += src[i]
+                self.assertEqual(dest, dest2)
 
-        dest = torch.randn(num_dest)
-        src = torch.randn(num_copy)
-        idx = torch.randperm(num_dest).narrow(0, 0, num_copy)
-        dest2 = dest.clone()
-        dest.index_add_(0, idx, src)
-        for i in range(idx.size(0)):
-            dest2[idx[i]] = dest2[idx[i]] + src[i]
-        self.assertEqual(dest, dest2)
-
-    # add coverage for issue with atomic add that appeared only for 
+    # add coverage for issue with atomic add that appeared only for
     # specific dtypes on cuda:
     # https://github.com/pytorch/pytorch/issues/29153
     def test_index_add_all_dtypes(self):