Revert D16627924: [pytorch][PR] Port addcdiv operator from the TH code to Aten

Differential Revision:
D16627924

Original commit changeset: 960856d30fd3

fbshipit-source-id: a375a3ede5ef956a07fb55c7b4a5d4fc34c96ddb
diff --git a/aten/src/ATen/Declarations.cwrap b/aten/src/ATen/Declarations.cwrap
index 0354148..5d639c3 100644
--- a/aten/src/ATen/Declarations.cwrap
+++ b/aten/src/ATen/Declarations.cwrap
@@ -2322,6 +2322,38 @@
     - THTensor* batch2
 ]]
 [[
+  name: _th_addcdiv  # out-of-place binding: writes into a separate `result` tensor
+  cname: addcdiv  # name of the underlying C function this declaration wraps
+  variants:
+    - function
+  return: argument 0  # argument 0 is `result`
+  arguments:
+    - arg: THTensor* result
+      output: True
+    - arg: THTensor* self
+      broadcast: tensor1,tensor2 fallback
+    - arg: real value
+      default: AS_REAL(1)  # scale factor defaults to 1
+      kwarg_only: True
+    - THTensor* tensor1
+    - THTensor* tensor2
+]]
+[[
+  name: _th_addcdiv_  # in-place binding: shares cname `addcdiv` with the out-of-place declaration
+  cname: addcdiv
+  variants: function
+  return: argument 0  # argument 0 is `self`
+  arguments:
+    - THTensor* self  # NOTE(review): `self` is listed twice; this entry appears to be the destination slot of the C call - confirm
+    - arg: THTensor* self  # ...and this entry the source operand that tensor1/tensor2 are broadcast against
+      broadcast: tensor1,tensor2 inplace fallback
+    - arg: real value
+      default: AS_REAL(1)  # scale factor defaults to 1
+      kwarg_only: True
+    - THTensor* tensor1
+    - THTensor* tensor2
+]]
+[[
   name: _th_gels
   cname: gels
   types:
diff --git a/aten/src/ATen/native/PointwiseOps.cpp b/aten/src/ATen/native/PointwiseOps.cpp
index b5e5ca3..cc62734 100644
--- a/aten/src/ATen/native/PointwiseOps.cpp
+++ b/aten/src/ATen/native/PointwiseOps.cpp
@@ -50,45 +50,7 @@
   return result;
 }
 
-Tensor addcdiv(
-    const Tensor& self,
-    const Tensor& tensor1,
-    const Tensor& tensor2,
-    Scalar value) {
-  Tensor result = at::empty({0}, self.options());
-  return at::addcdiv_out(result, self, tensor1, tensor2, value);
-}
-
-Tensor& addcdiv_(
-    Tensor& self,
-    const Tensor& tensor1,
-    const Tensor& tensor2,
-    Scalar value) {
-  return at::addcdiv_out(self, self, tensor1, tensor2, value);
-}
-
-Tensor& addcdiv_out(
-    Tensor& result,
-    const Tensor& self,
-    const Tensor& tensor1,
-    const Tensor& tensor2,
-    Scalar value) {
-  checkBackend("addcdiv_cpu", result, self.type().backend());
-  auto iter = at::TensorIterator();
-  iter.check_and_add_output(result);
-  iter.add_input(self);
-  iter.add_input(tensor1);
-  iter.add_input(tensor2);
-  iter.build();
-  addcdiv_stub(iter.device_type(), iter, value);
-#ifdef BUILD_NAMEDTENSOR
-  at::namedinference::propagate_names(result, self);
-#endif
-  return result;
-}
-
 DEFINE_DISPATCH(addcmul_stub);
-DEFINE_DISPATCH(addcdiv_stub);
 
 } // namespace native
 } // namespace at
diff --git a/aten/src/ATen/native/PointwiseOps.h b/aten/src/ATen/native/PointwiseOps.h
index 57f34c4..10bc8ba 100644
--- a/aten/src/ATen/native/PointwiseOps.h
+++ b/aten/src/ATen/native/PointwiseOps.h
@@ -13,6 +13,5 @@
 using pointwise_fn = void (*)(TensorIterator&, Scalar scalar);
 
 DECLARE_DISPATCH(pointwise_fn, addcmul_stub);
-DECLARE_DISPATCH(pointwise_fn, addcdiv_stub);
 } // namespace native
 } // namespace at
diff --git a/aten/src/ATen/native/cpu/PointwiseOpsKernel.cpp b/aten/src/ATen/native/cpu/PointwiseOpsKernel.cpp
index d4c8a16..fde58c8 100644
--- a/aten/src/ATen/native/cpu/PointwiseOpsKernel.cpp
+++ b/aten/src/ATen/native/cpu/PointwiseOpsKernel.cpp
@@ -28,28 +28,9 @@
   });
 }
 
-static void addcdiv_cpu_kernel(TensorIterator& iter, Scalar value) {
-  ScalarType dtype = iter.dtype(0);
-  AT_DISPATCH_ALL_TYPES(dtype, "addcdiv_cpu_out", [&] {
-    scalar_t scalar_val = value.to<scalar_t>();
-    auto scalar_vec = Vec256<scalar_t>(scalar_val);
-    cpu_kernel_vec(
-        iter,
-        [=](scalar_t self_val, scalar_t t1_val, scalar_t t2_val) -> scalar_t {
-          return self_val + scalar_val * t1_val / t2_val;
-        },
-        [=](Vec256<scalar_t> self_vec,
-            Vec256<scalar_t> t1_vec,
-            Vec256<scalar_t> t2_vec) {
-          return self_vec + scalar_vec * t1_vec / t2_vec;
-        });
-  });
-}
-
 } // anonymous namespace
 
 REGISTER_DISPATCH(addcmul_stub, &addcmul_cpu_kernel);
-REGISTER_DISPATCH(addcdiv_stub, &addcdiv_cpu_kernel);
 
 } // namespace native
 } // namespace at
diff --git a/aten/src/ATen/native/cuda/PointwiseOpsKernel.cu b/aten/src/ATen/native/cuda/PointwiseOpsKernel.cu
index abc2ea5..04550ab 100644
--- a/aten/src/ATen/native/cuda/PointwiseOpsKernel.cu
+++ b/aten/src/ATen/native/cuda/PointwiseOpsKernel.cu
@@ -8,15 +8,6 @@
 
 namespace at { namespace native {
 
-void addcdiv_cuda_kernel(TensorIterator& iter, Scalar value) {
-  AT_DISPATCH_ALL_TYPES_AND(kHalf, iter.dtype(), "addcdiv_cuda", [&]() {
-    auto alpha = value.to<scalar_t>();
-    gpu_kernel(iter, [alpha]GPU_LAMBDA(scalar_t a, scalar_t b, scalar_t c) -> scalar_t {
-      return a + alpha * (b / c);
-    });
-  });
-}
-
 void addcmul_cuda_kernel(TensorIterator& iter, Scalar value) {
   AT_DISPATCH_ALL_TYPES_AND(kHalf, iter.dtype(), "addcmul_cuda", [&]() {
     auto alpha = value.to<scalar_t>();
@@ -26,7 +17,6 @@
   });
 }
 
-REGISTER_DISPATCH(addcdiv_stub, &addcdiv_cuda_kernel);
 REGISTER_DISPATCH(addcmul_stub, &addcmul_cuda_kernel);
 
 }} // namespace at::native
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index a914883..d65e4ad 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -3439,6 +3439,9 @@
 
 - func: addcdiv_(Tensor(a!) self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor(a!)
   variants: method
+  dispatch:
+    CPU: legacy::cpu::_th_addcdiv_
+    CUDA: legacy::cuda::_th_addcdiv_
 
 - func: random_.from(Tensor(a!) self, int from, int to, *, Generator? generator=None) -> Tensor(a!)
   variants: method
@@ -3754,9 +3757,15 @@
   variants: method
 
 - func: addcdiv.out(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CPU: legacy::cpu::_th_addcdiv_out
+    CUDA: legacy::cuda::_th_addcdiv_out
 
 - func: addcdiv(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor
   variants: method, function
+  dispatch:
+    CPU: legacy::cpu::_th_addcdiv
+    CUDA: legacy::cuda::_th_addcdiv
 
 - func: lstsq.X(Tensor self, Tensor A, *, Tensor(a!) X, Tensor(b!) qr) -> (Tensor(a!) solution, Tensor(b!) QR)
   dispatch:
diff --git a/aten/src/TH/generic/THTensorMath.cpp b/aten/src/TH/generic/THTensorMath.cpp
index 29ba043..5299ced 100644
--- a/aten/src/TH/generic/THTensorMath.cpp
+++ b/aten/src/TH/generic/THTensorMath.cpp
@@ -575,6 +575,28 @@
   }
 }
 
+void THTensor_(addcdiv)(THTensor *r_, THTensor *t, scalar_t value, THTensor *src1, THTensor *src2)  // r_ = t + (value * src1) / src2, elementwise
+{
+  if(r_ != t)  // out-of-place call: seed r_ with t; when r_ == t the update below accumulates into t directly
+  {
+    THTensor_(resizeAs)(r_, t);
+    at::Tensor r__wrap = THTensor_wrap(r_);
+    at::Tensor t_wrap = THTensor_wrap(t);
+    at::native::copy_(r__wrap, t_wrap);
+  }
+  int64_t r_Size = THTensor_(nElement)(r_);
+  int64_t src1Size = THTensor_(nElement)(src1);
+  int64_t src2Size = THTensor_(nElement)(src2);
+  int r_Contig = THTensor_(isContiguous)(r_);
+  int src1Contig = THTensor_(isContiguous)(src1);
+  int src2Contig = THTensor_(isContiguous)(src2);
+  if( (src1Size == src2Size) && (src1Size == r_Size) ){  // parallel apply is used only when all three element counts agree
+    TH_TENSOR_APPLY3_PARALLEL(r_Size, r_Contig, src1Contig, src2Contig, scalar_t, r_, scalar_t, src1, scalar_t, src2, *r__data += value * *src1_data / *src2_data;, UNCERTAIN_TH_OMP_OVERHEAD_THRESHOLD);
+  } else {  // element counts differ: same update expression through TH_TENSOR_APPLY3
+    TH_TENSOR_APPLY3(scalar_t, r_, scalar_t, src1, scalar_t, src2, *r__data += value * *src1_data / *src2_data;);
+  }
+}
+
 void THTensor_(addmv)(THTensor *r_, scalar_t beta, THTensor *t, scalar_t alpha, THTensor *mat, THTensor *vec)
 {
   if( (mat->dim() != 2) || (THTensor_nDimension(vec) != 1) )
diff --git a/aten/src/TH/generic/THTensorMath.h b/aten/src/TH/generic/THTensorMath.h
index ca79035..ffc6724 100644
--- a/aten/src/TH/generic/THTensorMath.h
+++ b/aten/src/TH/generic/THTensorMath.h
@@ -129,6 +129,8 @@
 TH_API void THTensor_(cfmod)(THTensor *r_, THTensor *t, THTensor *src);
 TH_API void THTensor_(cremainder)(THTensor *r_, THTensor *t, THTensor *src);
 
+TH_API void THTensor_(addcdiv)(THTensor *r_, THTensor *t, scalar_t value, THTensor *src1, THTensor *src2);  // r_ = t + (value * src1) / src2, elementwise
+
 TH_API void THTensor_(addmv)(THTensor *r_, scalar_t beta, THTensor *t, scalar_t alpha, THTensor *mat,  THTensor *vec);
 TH_API void THTensor_(addmm)(THTensor *r_, scalar_t beta, THTensor *t, scalar_t alpha, THTensor *mat1, THTensor *mat2);
 TH_API void THTensor_(addr)(THTensor *r_,  scalar_t beta, THTensor *t, scalar_t alpha, THTensor *vec1, THTensor *vec2);
diff --git a/aten/src/THC/THCTensorMathPointwise.cuh b/aten/src/THC/THCTensorMathPointwise.cuh
index 5b90632..e8c041d 100644
--- a/aten/src/THC/THCTensorMathPointwise.cuh
+++ b/aten/src/THC/THCTensorMathPointwise.cuh
@@ -393,6 +393,23 @@
 };
 
 template <typename T>
+struct TensorAddCDivOp {  // pointwise functor: *out = *out + val * (*in1 / *in2)
+  TensorAddCDivOp(T v) : val(v) {}
+
+  __device__ __forceinline__ void operator()(T* out, T* in1, T* in2) {  // *out is both read and overwritten
+    *out = THCNumerics<T>::add(
+      *out,
+      THCNumerics<T>::mul(
+        val,
+        THCNumerics<T>::div(*in1, *in2)  // quotient is formed first, then scaled by val
+      )
+    );
+  }
+
+  T val;  // scale factor applied to the quotient
+};
+
+template <typename T>
 struct TensorLShiftOp {
   __device__ __forceinline__ void
   operator()(T* out, T* in) {
diff --git a/aten/src/THC/generic/THCTensorMathPointwise.cu b/aten/src/THC/generic/THCTensorMathPointwise.cu
index da5f177..6846f3c 100644
--- a/aten/src/THC/generic/THCTensorMathPointwise.cu
+++ b/aten/src/THC/generic/THCTensorMathPointwise.cu
@@ -595,5 +595,28 @@
   }
 }
 
+void THCTensor_(addcdiv)(THCState *state, THCTensor *self_, THCTensor *t, scalar_t value, THCTensor *src1, THCTensor *src2)  // self_ = t + value * (src1 / src2), applied via TensorAddCDivOp
+{
+  THCAssertSameGPU(THCTensor_(checkGPU)(state, 4, self_, t, src1, src2));  // asserts the checkGPU result over the 4 tensors
+  if(self_ != t)  // out-of-place call: resize self_ like t and copy t into it before accumulating
+  {
+    THCTensor_(resizeAs)(state, self_, t);
+    THCTensor_(copy)(state, self_, t);
+  }
+  else  // in-place call: self_ vs src1 element count is validated only on this branch
+  {
+    THArgCheck(THCTensor_(nElement)(state, self_) == THCTensor_(nElement)(state, src1),
+               1, "sizes do not match");
+  }
+  THArgCheck(THCTensor_(nElement)(state, src1) == THCTensor_(nElement)(state, src2),  // checked for both in-place and out-of-place calls
+             3, "sizes do not match");
+
+  if (!THC_pointwiseApply3<scalar_t, scalar_t, scalar_t>(state, self_, src1, src2, TensorAddCDivOp<scalar_t>(value))) {
+    THArgCheck(false, 2, CUTORCH_DIM_WARNING);  // a false return from the apply is reported as an argument error
+  }
+
+  THCudaCheck(cudaGetLastError());  // check for a pending CUDA error before returning
+}
+
 #endif
 #endif
diff --git a/aten/src/THC/generic/THCTensorMathPointwise.h b/aten/src/THC/generic/THCTensorMathPointwise.h
index 8d4c008..28c58de 100644
--- a/aten/src/THC/generic/THCTensorMathPointwise.h
+++ b/aten/src/THC/generic/THCTensorMathPointwise.h
@@ -68,5 +68,7 @@
 THC_API void THCTensor_(cfmod)(THCState *state, THCTensor *self, THCTensor *src1, THCTensor *src2);
 THC_API void THCTensor_(cremainder)(THCState *state, THCTensor *self, THCTensor *src1, THCTensor *src2);
 
+THC_API void THCTensor_(addcdiv)(THCState *state, THCTensor *self, THCTensor* t, scalar_t value, THCTensor *src1, THCTensor *src2);  // self = t + value * (src1 / src2), elementwise
+
 #endif
 #endif
diff --git a/test/test_torch.py b/test/test_torch.py
index 37154df..4226c32 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -1632,27 +1632,6 @@
         res2 = x1.sum(axis=(0, 2), keepdims=True)
         self.assertEqual(res1, res2)
 
-    def test_addcdiv(self):
-        def _test_addcdiv(a, alpha, b, c):
-            actual = torch.addcdiv(a, alpha, b, c)
-            expected = a + (alpha * b) / c
-            self.assertTrue(torch.allclose(expected, actual, equal_nan=True))
-        def non_zero_rand(size, dtype, device):
-            if dtype.is_floating_point:
-                a = torch.rand(size=size, dtype=dtype, device=device)
-            elif dtype == torch.uint8:
-                a = torch.randint(1, 5, size=size, dtype=dtype, device=device)
-            else:
-                a = torch.randint(-5, 5, size=size, dtype=dtype, device=device)
-            return a + (a == 0).type(dtype)
-        for device in torch.testing.get_all_device_types():
-            for dtype in torch.testing.get_all_math_dtypes(device):
-                _test_addcdiv(
-                    non_zero_rand((2, 2), dtype=dtype, device=device),
-                    0.5,
-                    non_zero_rand((2, 2), dtype=dtype, device=device),
-                    non_zero_rand((2, 2), dtype=dtype, device=device))
-
     def test_add(self):
         for device in torch.testing.get_all_device_types():
             # [res] torch.add([res,] tensor1, tensor2)