Move cumprod and cumsum to Aten(CPU) (#33280)

Summary:
This PR moves the CPU implementations of cumsum and cumprod from TH to ATen.
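For reference, the user-facing semantics are unchanged: `cumsum`/`cumprod` compute a running sum/product along one dimension. A minimal illustration (not part of this PR):
```
import torch

x = torch.tensor([[1., 2., 3.],
                  [4., 5., 6.]])

# Running sum / running product along dim=1.
print(x.cumsum(dim=1))   # [[ 1.,  3.,  6.],
                         #  [ 4.,  9., 15.]]
print(x.cumprod(dim=1))  # [[  1.,   2.,   6.],
                         #  [  4.,  20., 120.]]
```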
Test script:
```
import torch
import torch.nn as nn
import time

torch.manual_seed(0)

def _time():
    return time.time()

device = "cpu"

#torch.set_num_threads(1)

#warm up
for n in [10, 300]:
    input = torch.randn(n, n, n, requires_grad=False, device=device)
    input = input * 0.01 + 1
    for dim in range(input.dim()):
        for i in range(100):
            #output = input.cumsum(dim)
            output = input.cumprod(dim)

for n in [10, 300]:
    input = torch.randn(n, n, n, requires_grad=False, device=device)
    input = input * 0.01 + 1
    for dim in range(input.dim()):
        fwd_t = 0
        for i in range(1000):
            t1 = _time()
            #output = input.cumsum(dim)
            output = input.cumprod(dim)
            t2 = _time()
            fwd_t = fwd_t + (t2 - t1)
        fwd_avg = fwd_t / 1000 * 1000
        print("size = (%d, %d, %d); reduce dim=%d; compute time is %.4f(ms)" % (n, n, n, dim, fwd_avg))
```
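The same measurement can also be sketched with the standard `timeit` module (shown for reference only; the numbers below come from the script above):
```
import timeit
import torch

input = torch.randn(300, 300, 300) * 0.01 + 1
for dim in range(input.dim()):
    # Average of 100 runs, reported in milliseconds.
    t = timeit.timeit(lambda: input.cumprod(dim), number=100) / 100 * 1000
    print("reduce dim=%d; compute time is %.4f(ms)" % (dim, t))
```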
Test device: **skx-8180**.
Performance:
```
cumsum:
Before:
size = (10, 10, 10); reduce dim=0; compute time is 0.0098(ms)
size = (10, 10, 10); reduce dim=1; compute time is 0.0089(ms)
size = (10, 10, 10); reduce dim=2; compute time is 0.0089(ms)
size = (300, 300, 300); reduce dim=0; compute time is 208.9403(ms)
size = (300, 300, 300); reduce dim=1; compute time is 241.5989(ms)
size = (300, 300, 300); reduce dim=2; compute time is 66.2587(ms)
After:
size = (10, 10, 10); reduce dim=0; compute time is 0.0065(ms)
size = (10, 10, 10); reduce dim=1; compute time is 0.0063(ms)
size = (10, 10, 10); reduce dim=2; compute time is 0.0053(ms)
size = (300, 300, 300); reduce dim=0; compute time is 36.0139(ms)
size = (300, 300, 300); reduce dim=1; compute time is 36.0776(ms)
size = (300, 300, 300); reduce dim=2; compute time is 21.0111(ms)
number_threads = 1:
size = (10, 10, 10); reduce dim=0; compute time is 0.0053(ms)
size = (10, 10, 10); reduce dim=1; compute time is 0.0052(ms)
size = (10, 10, 10); reduce dim=2; compute time is 0.0051(ms)
size = (300, 300, 300); reduce dim=0; compute time is 81.8831(ms)
size = (300, 300, 300); reduce dim=1; compute time is 88.5687(ms)
size = (300, 300, 300); reduce dim=2; compute time is 54.9922(ms)

cumprod:
Before:
size = (10, 10, 10); reduce dim=0; compute time is 0.0096(ms)
size = (10, 10, 10); reduce dim=1; compute time is 0.0088(ms)
size = (10, 10, 10); reduce dim=2; compute time is 0.0088(ms)
size = (300, 300, 300); reduce dim=0; compute time is 221.2601(ms)
size = (300, 300, 300); reduce dim=1; compute time is 249.7894(ms)
size = (300, 300, 300); reduce dim=2; compute time is 71.5182(ms)
number_threads = 1:
size = (10, 10, 10); reduce dim=0; compute time is 0.0100(ms)
size = (10, 10, 10); reduce dim=1; compute time is 0.0093(ms)
size = (10, 10, 10); reduce dim=2; compute time is 0.0093(ms)
size = (300, 300, 300); reduce dim=0; compute time is 207.6287(ms)
size = (300, 300, 300); reduce dim=1; compute time is 241.6693(ms)
size = (300, 300, 300); reduce dim=2; compute time is 66.2977(ms)
After:
size = (10, 10, 10); reduce dim=0; compute time is 0.0063(ms)
size = (10, 10, 10); reduce dim=1; compute time is 0.0062(ms)
size = (10, 10, 10); reduce dim=2; compute time is 0.0053(ms)
size = (300, 300, 300); reduce dim=0; compute time is 36.4283(ms)
size = (300, 300, 300); reduce dim=1; compute time is 38.1139(ms)
size = (300, 300, 300); reduce dim=2; compute time is 20.9140(ms)
number_threads = 1:
size = (10, 10, 10); reduce dim=0; compute time is 0.0052(ms)
size = (10, 10, 10); reduce dim=1; compute time is 0.0052(ms)
size = (10, 10, 10); reduce dim=2; compute time is 0.0050(ms)
size = (300, 300, 300); reduce dim=0; compute time is 82.6926(ms)
size = (300, 300, 300); reduce dim=1; compute time is 90.1265(ms)
size = (300, 300, 300); reduce dim=2; compute time is 55.0196(ms)
```
Fixes https://github.com/pytorch/pytorch/issues/24668 and https://github.com/pytorch/pytorch/issues/24669.
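
Implementation note: the new CPU path registers `cumsum_stub`/`cumprod_stub` and implements them in `ReduceOpsKernel.cpp`. `cpu_cum_base_kernel` builds a `TensorIterator` over the input and output with the scanned dimension restrided to size 1 (via `restride_dim`), so each iterator step corresponds to one row along `dim`; the per-row lambda then walks that row with the original stride and accumulates in `at::acc_type<scalar_t, false>` before casting back to `scalar_t`. A rough Python analogue of that row-wise scan, for intuition only (not the actual kernel):
```
import torch

def cumsum_rowwise(t, dim):
    # Flatten everything except `dim` into "rows", then scan each row,
    # accumulating in a Python float (a stand-in for at::acc_type).
    src = t.movedim(dim, -1).reshape(-1, t.size(dim))
    out = torch.empty_like(src)
    for r in range(src.size(0)):
        acc = 0.0
        for i in range(src.size(1)):
            acc += src[r, i].item()
            out[r, i] = acc
    return out.reshape(t.movedim(dim, -1).shape).movedim(-1, dim)

x = torch.randn(4, 5, 6)
assert torch.allclose(cumsum_rowwise(x, 1), x.cumsum(1), atol=1e-6)
```
Since `iter.for_each` parallelizes across those rows, the large-tensor cases also benefit from multiple threads, as the numbers above show.
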
Pull Request resolved: https://github.com/pytorch/pytorch/pull/33280

Differential Revision: D20076997

Pulled By: VitalyFedyunin

fbshipit-source-id: 12225767da8cfdc5e44257462a432bffa04cd469
diff --git a/aten/src/ATen/Declarations.cwrap b/aten/src/ATen/Declarations.cwrap
index 6a305f8..fd52eba 100644
--- a/aten/src/ATen/Declarations.cwrap
+++ b/aten/src/ATen/Declarations.cwrap
@@ -582,7 +582,8 @@
 [[
   name: _th_cumsum
   cname: cumsum
-  cpu_bool: True
+  backends:
+    - CUDA
   cuda_bool: True
   variants: function
   return: argument 0
@@ -595,7 +596,8 @@
 [[
   name: _th_cumprod
   cname: cumprod
-  cpu_bool: True
+  backends:
+    - CUDA
   cuda_bool: True
   variants: function
   return: argument 0
diff --git a/aten/src/ATen/native/ReduceOps.cpp b/aten/src/ATen/native/ReduceOps.cpp
index 10011b3..31dc3e3 100644
--- a/aten/src/ATen/native/ReduceOps.cpp
+++ b/aten/src/ATen/native/ReduceOps.cpp
@@ -35,6 +35,8 @@
 DEFINE_DISPATCH(max_values_stub);
 DEFINE_DISPATCH(argmax_stub);
 DEFINE_DISPATCH(argmin_stub);
+DEFINE_DISPATCH(cumsum_stub);
+DEFINE_DISPATCH(cumprod_stub);
 
 #define OPTION_TYPE_EQUALITY_CHECK(option, out, self) \
 { \
@@ -184,6 +186,17 @@
   return TensorIterator::reduce_op(viewed_result1, viewed_result2, self.to(dtype));
 }
 
+Tensor _cumsum_cpu(const Tensor& self, int64_t dim) {
+  Tensor result = at::empty_like(self, MemoryFormat::Contiguous);
+  cumsum_stub(self.device().type(), result, self, dim);
+  return result;
+}
+
+Tensor& _cumsum_out_cpu(Tensor& result, const Tensor& self, int64_t dim) {
+  cumsum_stub(self.device().type(), result, self, dim);
+  return result;
+}
+
 Tensor cumsum(const Tensor& self, int64_t dim, c10::optional<ScalarType> dtype) {
   auto result = [&]() {
     NoNamesGuard guard;
@@ -210,6 +223,17 @@
   return result;
 }
 
+Tensor _cumprod_cpu(const Tensor& self, int64_t dim) {
+  Tensor result = at::empty_like(self, MemoryFormat::Contiguous);
+  cumprod_stub(self.device().type(), result, self, dim);
+  return result;
+}
+
+Tensor& _cumprod_out_cpu(Tensor& result, const Tensor& self, int64_t dim) {
+  cumprod_stub(self.device().type(), result, self, dim);
+  return result;
+}
+
 Tensor cumprod(const Tensor& self, int64_t dim, c10::optional<ScalarType> dtype) {
   auto result = [&]() {
     NoNamesGuard guard;
diff --git a/aten/src/ATen/native/ReduceOps.h b/aten/src/ATen/native/ReduceOps.h
index b2a270f..80fb24c 100644
--- a/aten/src/ATen/native/ReduceOps.h
+++ b/aten/src/ATen/native/ReduceOps.h
@@ -33,4 +33,8 @@
 using reduce_fn_flag = void(*)(TensorIterator &, Scalar);
 DECLARE_DISPATCH(reduce_fn_flag, norm_stub);
 
+using cum_fn = void (*)(Tensor & result, const Tensor & self, int64_t dim);
+DECLARE_DISPATCH(cum_fn, cumsum_stub);
+DECLARE_DISPATCH(cum_fn, cumprod_stub);
+
 }} // namespace at::native
diff --git a/aten/src/ATen/native/ReduceOpsUtils.h b/aten/src/ATen/native/ReduceOpsUtils.h
index 761a1c6..9bdf5a2 100644
--- a/aten/src/ATen/native/ReduceOpsUtils.h
+++ b/aten/src/ATen/native/ReduceOpsUtils.h
@@ -2,6 +2,35 @@
 
 namespace at { namespace native {
 
+static inline int64_t ensure_nonempty_dim(int64_t dim) {
+  return std::max<int64_t>(dim, 1);
+}
+
+static inline int64_t ensure_nonempty_size(const Tensor& t, int64_t dim) {
+  return t.dim() == 0 ? 1 : t.size(dim);
+}
+
+static inline int64_t ensure_nonempty_stride(const Tensor& t, int64_t dim) {
+  return t.dim() == 0 ? 1 : t.stride(dim);
+}
+
+using IdxVec = std::vector<int64_t>;
+static inline IdxVec ensure_nonempty_vec(IdxVec vec) {
+  if (vec.size() == 0) {
+    vec.push_back(1);
+  }
+  return vec;
+}
+
+static inline Tensor restride_dim(
+  const Tensor& src, int64_t dim,
+  IntArrayRef replacement_shape
+) {
+  auto strides = ensure_nonempty_vec(src.strides().vec());
+  strides[dim] = 0;
+  return src.as_strided(replacement_shape, strides);
+}
+
 inline Tensor &_dimreduce_setup(Tensor &result, const Tensor &self,
                                 int64_t dim) {
   IntArrayRef self_sizes = self.sizes();
diff --git a/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp b/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp
index 5117667..31e077b 100644
--- a/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp
+++ b/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp
@@ -6,15 +6,104 @@
 #include <ATen/Dispatch.h>
 #include <ATen/cpu/vec256/vec256.h>
 #include <ATen/native/ReduceOps.h>
+#include <ATen/native/ReduceOpsUtils.h>
 #include <ATen/native/TensorIterator.h>
 #include <ATen/native/SharedReduceOps.h>
 #include <ATen/native/cpu/Reduce.h>
 #include <c10/util/Optional.h>
+#include <ATen/AccumulateType.h>
 
 namespace at { namespace native { namespace {
 
 using namespace vec256;
 
+template <typename scalar_t, typename func_t>
+static inline void cpu_cum_base_kernel(Tensor& result,
+    const Tensor& self,
+    int64_t dim,
+    const func_t& f,
+    scalar_t init_val) {
+  if (result.sizes() != self.sizes()) {
+    result.resize_as_(self);
+  }
+  if (self.numel() == 0) {
+    return;
+  }
+  const auto input_ndim = self.dim();
+  if (input_ndim == 0) {
+    result.fill_(self);
+    return;
+  }
+
+  auto self_sizes = ensure_nonempty_vec(self.sizes().vec());
+  self_sizes[dim] = 1;
+
+  auto result_restrided = restride_dim(result, dim, self_sizes);
+  auto self_restrided = restride_dim(self, dim, self_sizes);
+
+  auto iter = TensorIterator();
+  iter.dont_compute_common_dtype();
+  iter.dont_resize_outputs();
+  iter.add_output(result_restrided);
+  iter.add_input(self_restrided);
+  iter.build();
+
+  auto result_dim_stride = ensure_nonempty_stride(result, dim);
+  auto self_dim_stride = ensure_nonempty_stride(self, dim);
+
+  auto loop = [&](char** data, const int64_t* strides, int64_t n) {
+    auto* result_data_bytes = data[0];
+    const auto* self_data_bytes = data[1];
+
+    for (int64_t i = 0; i < n; ++i) {
+      f(
+        (scalar_t*)result_data_bytes, result_dim_stride,
+        (scalar_t*)self_data_bytes, self_dim_stride, init_val
+      );
+      result_data_bytes += strides[0];
+      self_data_bytes += strides[1];
+    }
+  };
+
+  iter.for_each(loop);
+}
+
+static void cumsum_cpu_kernel(Tensor& result, const Tensor& self, int64_t dim) {
+  auto wrap_dim = maybe_wrap_dim(dim, self.dim());
+  int64_t self_dim_size = ensure_nonempty_size(self, wrap_dim);
+
+  AT_DISPATCH_ALL_TYPES(self.scalar_type(), "cumsum_out_cpu", [&] {
+    cpu_cum_base_kernel<scalar_t>(result, self, wrap_dim, [&] (
+      scalar_t* result_data, auto result_dim_stride,
+      const scalar_t* self_data, auto self_dim_stride, scalar_t init_val) {
+        auto cum_number = (at::acc_type<scalar_t, false>)init_val;
+        for (int64_t i = 0; i < self_dim_size; ++i) {
+          cum_number += self_data[i * self_dim_stride];
+          result_data[i * result_dim_stride] = (scalar_t)cum_number;
+        }
+      }, /*init_val=*/ 0
+    );
+  });
+}
+
+static void cumprod_cpu_kernel(Tensor& result, const Tensor& self, int64_t dim) {
+  auto wrap_dim = maybe_wrap_dim(dim, self.dim());
+  int64_t self_dim_size = ensure_nonempty_size(self, wrap_dim);
+
+  AT_DISPATCH_ALL_TYPES(self.scalar_type(), "cumprod_out_cpu", [&] {
+    cpu_cum_base_kernel<scalar_t>(result, self, wrap_dim, [&] (
+      scalar_t* result_data, auto result_dim_stride,
+      const scalar_t* self_data, auto self_dim_stride, scalar_t init_val) {
+        auto cum_number = (at::acc_type<scalar_t, false>)init_val;
+        for (int64_t i = 0; i < self_dim_size; ++i) {
+          cum_number *= self_data[i * self_dim_stride];
+          result_data[i * result_dim_stride] = (scalar_t)cum_number;
+        }
+      }, /*init_val=*/ 1
+    );
+  });
+}
+
 static void sum_kernel_impl(TensorIterator& iter) {
   AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
       ScalarType::BFloat16, ScalarType::Bool, iter.dtype(), "sum_cpu", [&] {
@@ -220,5 +309,7 @@
 REGISTER_DISPATCH(max_values_stub, &max_values_kernel_impl);
 REGISTER_DISPATCH(argmax_stub, &argmax_kernel_impl);
 REGISTER_DISPATCH(argmin_stub, &argmin_kernel_impl);
+REGISTER_DISPATCH(cumprod_stub, &cumprod_cpu_kernel);
+REGISTER_DISPATCH(cumsum_stub, &cumsum_cpu_kernel);
 
 }}  // namespace at::native
diff --git a/aten/src/ATen/native/cpu/ScatterGatherKernel.cpp b/aten/src/ATen/native/cpu/ScatterGatherKernel.cpp
index 0761d89..62e8bd8 100644
--- a/aten/src/ATen/native/cpu/ScatterGatherKernel.cpp
+++ b/aten/src/ATen/native/cpu/ScatterGatherKernel.cpp
@@ -1,4 +1,5 @@
 #include <ATen/native/ScatterGatherShapeChecks.h>
+#include <ATen/native/ReduceOpsUtils.h>
 #include <ATen/native/TensorIterator.h>
 #include <ATen/Parallel.h>
 
@@ -6,26 +7,6 @@
 
 namespace {
 
-static inline int64_t ensure_nonempty_dim(int64_t dim) {
-  return std::max<int64_t>(dim, 1);
-}
-
-static inline int64_t ensure_nonempty_size(const Tensor& t, int64_t dim) {
-  return t.dim() == 0 ? 1 : t.size(dim);
-}
-
-static inline int64_t ensure_nonempty_stride(const Tensor& t, int64_t dim) {
-  return t.dim() == 0 ? 1 : t.stride(dim);
-}
-
-using IdxVec = std::vector<int64_t>;
-static inline IdxVec ensure_nonempty_vec(IdxVec vec) {
-  if (vec.size() == 0) {
-    vec.push_back(1);
-  }
-  return vec;
-}
-
 // Used for `gather`-like methods
 // Test:
 // 1. index.size(d) == self.size(d) for all d != dim
@@ -99,15 +80,6 @@
   }
 }
 
-static Tensor restride_dim(
-  const Tensor& src, int64_t dim,
-  IntArrayRef replacement_shape
-) {
-  auto strides = ensure_nonempty_vec(src.strides().vec());
-  strides[dim] = 0;
-  return src.as_strided(replacement_shape, strides);
-}
-
 template <typename func_t>
 void cpu_scatter_gather_base_kernel(
   Tensor& self, int64_t dim,
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 18bdbda..a095ea2 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -5261,23 +5261,23 @@
 - func: _cumsum(Tensor self, int dim) -> Tensor
   use_c10_dispatcher: full
   dispatch:
-    CPU: legacy::cpu::_th_cumsum
+    CPU: _cumsum_cpu
     CUDA: legacy::cuda::_th_cumsum
 
 - func: _cumsum.out(Tensor self, int dim, *, Tensor(a!) out) -> Tensor(a!)
   dispatch:
-    CPU: legacy::cpu::_th_cumsum_out
+    CPU: _cumsum_out_cpu
     CUDA: legacy::cuda::_th_cumsum_out
 
 - func: _cumprod(Tensor self, int dim) -> Tensor
   use_c10_dispatcher: full
   dispatch:
-    CPU: legacy::cpu::_th_cumprod
+    CPU: _cumprod_cpu
     CUDA: legacy::cuda::_th_cumprod
 
 - func: _cumprod.out(Tensor self, int dim, *, Tensor(a!) out) -> Tensor(a!)
   dispatch:
-    CPU: legacy::cpu::_th_cumprod_out
+    CPU: _cumprod_out_cpu
     CUDA: legacy::cuda::_th_cumprod_out
 
 - func: _var(Tensor self, bool unbiased=True) -> Tensor
diff --git a/aten/src/TH/generic/THTensorMath.h b/aten/src/TH/generic/THTensorMath.h
index 552bf39..56145a7 100644
--- a/aten/src/TH/generic/THTensorMath.h
+++ b/aten/src/TH/generic/THTensorMath.h
@@ -91,9 +91,6 @@
 TH_API void THTensor_(put)(THTensor *tensor, THLongTensor *index, THTensor *src, int accumulate);
 TH_API void THTensor_(indexFill)(THTensor *tensor, int dim, THLongTensor *index, scalar_t val);
 
-TH_API void THTensor_(cumsum)(THTensor *r_, THTensor *t, int dimension);
-TH_API void THTensor_(cumprod)(THTensor *r_, THTensor *t, int dimension);
-
 #if !defined(TH_REAL_IS_BOOL) /* non bool only part */
 
 TH_API accreal THTensor_(dot)(THTensor *t, THTensor *src);
diff --git a/aten/src/TH/generic/THTensorMoreMath.cpp b/aten/src/TH/generic/THTensorMoreMath.cpp
index 0168a22..40fce564 100644
--- a/aten/src/TH/generic/THTensorMoreMath.cpp
+++ b/aten/src/TH/generic/THTensorMoreMath.cpp
@@ -256,42 +256,6 @@
                    *r_data = *t_data < *src_data ? *t_data : *src_data;);
 }
 
-void THTensor_(cumsum)(THTensor *r_, THTensor *t, int dimension)
-{
-  dimension = at::maybe_wrap_dim(dimension, t);
-  THArgCheck(dimension >= 0 && dimension < THTensor_(nDimensionLegacyNoScalars)(t), 2, "dimension %d out of range",
-      dimension);
-
-  THTensor_(resizeAs)(r_, t);
-
-  TH_TENSOR_DIM_APPLY2(scalar_t, t, scalar_t, r_, dimension,
-                       accreal cumsum = 0;
-                       int64_t i;
-                       for(i = 0; i < t_size; i++)
-                       {
-                         cumsum += t_data[i*t_stride];
-                         r__data[i*r__stride] = (scalar_t)cumsum;
-                       });
-}
-
-void THTensor_(cumprod)(THTensor *r_, THTensor *t, int dimension)
-{
-  dimension = at::maybe_wrap_dim(dimension, t);
-  THArgCheck(dimension >= 0 && dimension < THTensor_(nDimensionLegacyNoScalars)(t), 2, "dimension %d out of range",
-      dimension);
-
-  THTensor_(resizeAs)(r_, t);
-
-  TH_TENSOR_DIM_APPLY2(scalar_t, t, scalar_t, r_, dimension,
-                       accreal cumprod = 1;
-                       int64_t i;
-                       for(i = 0; i < t_size; i++)
-                       {
-                         cumprod *= t_data[i*t_stride];
-                         r__data[i*r__stride] = (scalar_t)cumprod;
-                       });
-}
-
 #if !defined(TH_REAL_IS_BOOL) /* non bool only part */
 
 void THTensor_(baddbmm)(THTensor *result, scalar_t beta, THTensor *t, scalar_t alpha, THTensor *batch1, THTensor *batch2)
diff --git a/test/test_torch.py b/test/test_torch.py
index 9fd4ccb..bef89db 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -10043,6 +10043,7 @@
         # Check a scalar example
         raw_tensor = torch.tensor(3., requires_grad=True)
         integrated = raw_tensor.cumsum(dim=-1)
+        self.assertEqual(raw_tensor, integrated)
         # Check that backward does not crash
         integrated.sum().backward()
         # Check that output maintained correct shape
@@ -10090,6 +10091,7 @@
         # Check a scalar example
         raw_tensor = torch.tensor(3., requires_grad=True)
         integrated = raw_tensor.cumprod(dim=-1)
+        self.assertEqual(raw_tensor, integrated)
         # Check that backward does not crash
         integrated.sum().backward()
         # Check that output maintained correct shape