Migrate `_th_std_var` to ATen (#59258)

Summary:
Ref https://github.com/pytorch/pytorch/issues/49421

This migrates `std`/`var`'s special-case all-reduction from TH to ATen. Using the benchmark from gh-43858 that was originally used to justify keeping the TH version, I find this PR has similar (slightly better) single-threaded performance. And unlike the TH version, the new implementation is multi-threaded, so it is much faster for large tensors.

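For context, below is a minimal sketch of how such a comparison can be reproduced with `torch.utils.benchmark`. The exact script lives in gh-43858; the statement definitions here (in particular reading `torch_var0` as `unbiased=False`) are assumptions, not the original benchmark.

```
# Hypothetical reconstruction of a gh-43858-style benchmark; the stmt
# definitions below are assumptions, not the original script.
import torch
from torch.utils.benchmark import Timer

torch.set_num_threads(1)  # repeat with 8 for the multi-threaded table

for n in (8, 80, 800, 8000, 80000, 800000, 8000000):
    a = torch.randn(n)
    for label, stmt in (
        ("torch_var", "torch.var(a)"),
        ("torch_var0", "torch.var(a, unbiased=False)"),  # assumed meaning of "var0"
        ("stdfn", "((a - a.mean()) ** 2).mean().sqrt()"),
        ("torch_sum0", "torch.sum(a)"),
    ):
        t = Timer(stmt, globals={"torch": torch, "a": a})
        # report the median runtime in microseconds, as in the tables below
        print(f"{n:>9}  {label:<12}{t.blocked_autorange().median * 1e6:10.1f} us")
```
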
TH results:
```
[----------------------------- Index ------------------------------]
               |  torch_var  |  torch_var0  |  stdfn   |  torch_sum0
1 threads: ---------------------------------------------------------
      8        |       3.6   |       3.8    |     8.2  |      1.2
      80       |       3.7   |       3.8    |     8.4  |      1.2
      800      |       4.2   |       4.3    |     8.7  |      1.2
      8000     |       9.0   |       9.1    |    11.2  |      1.5
      80000    |      58.3   |      59.0    |    30.6  |      4.2
      800000   |     546.9   |     546.9    |   183.4  |     31.3
      8000000  |    5729.7   |    5701.0    |  6165.4  |    484.1
```

ATen results:
```
[----------------------------- Index ------------------------------]
               |  torch_var  |  torch_var0  |  stdfn   |  torch_sum0
1 threads: ---------------------------------------------------------
      8        |       4.0   |       4.0    |     8.7  |      1.2
      80       |       3.6   |       3.8    |     9.0  |      1.2
      800      |       4.1   |       4.3    |     8.9  |      1.2
      8000     |       8.9   |       9.2    |    10.6  |      1.5
      80000    |      57.0   |      57.4    |    28.8  |      4.3
      800000   |     526.9   |     526.9    |   178.3  |     30.2
      8000000  |    5568.1   |    5560.6    |  6042.1  |    453.2

[----------------------------- Index ------------------------------]
               |  torch_var  |  torch_var0  |  stdfn   |  torch_sum0
8 threads: ---------------------------------------------------------
      8        |      3.9    |      3.8     |     9.1  |      1.2
      80       |      3.8    |      3.9     |     8.8  |      1.2
      800      |      4.2    |      4.3     |     8.9  |      1.3
      8000     |      9.0    |      9.2     |    10.4  |      1.5
      80000    |     26.0    |     26.8     |    26.4  |      4.4
      800000   |     92.9    |     87.3     |    72.1  |     22.4
      8000000  |    793.5    |    791.8     |  5334.8  |    115.1
```
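
For reference, the replacement kernel keeps the TH semantics: it computes the mean in a first pass, then accumulates squared deviations in double precision (with per-thread partial sums under `at::parallel_for`), and applies the correction-adjusted divisor:

```
\mathrm{var}(x) = \frac{\sum_{i=1}^{N} \left(x_i - \bar{x}\right)^2}{\max\!\left(0,\; N - \mathrm{correction}\right)},
\qquad
\mathrm{std}(x) = \sqrt{\mathrm{var}(x)}
```

A quick parity check against the public op (using `unbiased` as it maps onto `correction` in this release):

```
import torch

x = torch.randn(10_000)
for correction in (0, 1):
    # two-pass formula, matching std_var_all_cpu in the diff below
    var = ((x - x.mean()) ** 2).sum() / max(0, x.numel() - correction)
    assert torch.allclose(var, torch.var(x, unbiased=bool(correction)))
```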

Pull Request resolved: https://github.com/pytorch/pytorch/pull/59258

Reviewed By: mruberry

Differential Revision: D28821216

Pulled By: ngimel

fbshipit-source-id: f35992c21f08a0a8878053680dc0ca7a8facd155
diff --git a/aten/src/ATen/LegacyTHFunctionsCPU.cpp b/aten/src/ATen/LegacyTHFunctionsCPU.cpp
index 8eeadd5..50d04ea 100644
--- a/aten/src/ATen/LegacyTHFunctionsCPU.cpp
+++ b/aten/src/ATen/LegacyTHFunctionsCPU.cpp
@@ -35,25 +35,6 @@
   }
 }
 
-Scalar _th_std_var(const Tensor& self, int64_t correction, bool take_sqrt) {
-    // DeviceGuard omitted
-    auto dispatch_scalar_type = infer_scalar_type(self);
-
-    switch (dispatch_scalar_type) {
-        case ScalarType::Double: {
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_var", false, DeviceType::CPU, dispatch_scalar_type);
-            return convert<double>(THDoubleTensor_std_var_all(self_, correction, take_sqrt));
-            break;
-        }
-        case ScalarType::Float: {
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_var", false, DeviceType::CPU, dispatch_scalar_type);
-            return convert<float>(THFloatTensor_std_var_all(self_, correction, take_sqrt));
-            break;
-        }
-        default:
-            AT_ERROR("_th_var not supported on CPUType for ", dispatch_scalar_type);
-    }
-}
 Tensor & _th_renorm_out(const Tensor & self, const Scalar& p, int64_t dim, const Scalar& maxnorm, Tensor & result) {
     // DeviceGuard omitted
     auto dispatch_scalar_type = infer_scalar_type(self);
diff --git a/aten/src/ATen/LegacyTHFunctionsCPU.h b/aten/src/ATen/LegacyTHFunctionsCPU.h
index 0d273ea..5a7fc4d 100644
--- a/aten/src/ATen/LegacyTHFunctionsCPU.h
+++ b/aten/src/ATen/LegacyTHFunctionsCPU.h
@@ -20,7 +20,6 @@
 
 Tensor & _th_masked_scatter_(Tensor & self, const Tensor & mask, const Tensor & source);
 Tensor & _th_masked_scatter_bool_(Tensor & self, const Tensor & mask, const Tensor & source);
-Scalar _th_std_var(const Tensor& self, int64_t correction, bool take_sqrt);
 Tensor & _th_renorm_out(const Tensor & self, const Scalar& p, int64_t dim, const Scalar& maxnorm, Tensor & result);
 Tensor _th_renorm(const Tensor & self, const Scalar& p, int64_t dim, const Scalar& maxnorm);
 Tensor & _th_renorm_(Tensor & self, const Scalar& p, int64_t dim, const Scalar& maxnorm);
diff --git a/aten/src/ATen/native/ReduceOps.cpp b/aten/src/ATen/native/ReduceOps.cpp
index 63d0618..5b144cb 100644
--- a/aten/src/ATen/native/ReduceOps.cpp
+++ b/aten/src/ATen/native/ReduceOps.cpp
@@ -16,6 +16,7 @@
 #include <ATen/core/grad_mode.h>
 
 #include <c10/util/irange.h>
+#include <c10/util/SmallBuffer.h>
 
 #include <algorithm>
 #include <functional>
@@ -1380,6 +1381,61 @@
   return at::native::argmin_out(self, dim, keepdims, result);
 }
 
+static double std_var_all_cpu(const Tensor& self, int64_t correction, bool take_sqrt) {
+  const auto dtype = self.scalar_type();
+  TORCH_CHECK(dtype == kDouble || dtype == kFloat,
+              "std_var_all: Unsupported dtype ", dtype);
+
+  auto mean = self.mean().item<double>();
+  auto iter = TensorIteratorConfig()
+      .add_input(self)
+      .build();
+
+  const auto max_threads = at::get_num_threads();
+  c10::SmallBuffer<double, 64> partial_sums(max_threads);
+  std::fill(partial_sums.begin(), partial_sums.end(), 0.0);
+
+  at::parallel_for(0, iter.numel(), at::internal::GRAIN_SIZE, [&](int64_t begin, int64_t end) {
+    const auto tid = at::get_thread_num();
+    double thread_sum = 0.0;
+
+    AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "std_var_all_cpu", [&] {
+      iter.serial_for_each([&] (char** data, const int64_t* strides, int64_t size0, int64_t size1) {
+        const double local_mean = mean;
+        const int64_t inner_stride = strides[0];
+        const int64_t outer_stride = strides[1];
+
+        double local_sum = 0.0;
+        for (int64_t i = 0; i < size1; ++i) {
+          const char* row_ptr = data[0] + outer_stride * i;
+          for (int64_t j = 0; j < size0; ++j) {
+            const auto ptr = reinterpret_cast<const scalar_t*>(row_ptr + inner_stride * j);
+            auto dx = (static_cast<double>(*ptr) - local_mean);
+            local_sum += dx * dx;
+          }
+        }
+        thread_sum += local_sum;
+      }, {begin, end});
+    });
+
+    partial_sums[tid] = thread_sum;
+  });
+
+  const double total_sum = std::accumulate(
+      partial_sums.begin(), partial_sums.end(), 0.0);
+  const auto var = [&] () __ubsan_ignore_float_divide_by_zero__ {
+    return total_sum / std::max(int64_t{0}, self.numel() - correction);
+  }();
+  const auto result = take_sqrt ? std::sqrt(var) : var;
+
+  if (dtype == kFloat) {
+    // Convert to infinity if out of range for a float.
+    // Doing it now prevents checked_convert failing later
+    return static_cast<float>(result);
+  }
+  return result;
+}
+
 static Tensor& std_var_out(
     const char* fname, Tensor& result, const Tensor& self,
     c10::optional<IntArrayRef> dim, c10::optional<int64_t> correction_opt,
@@ -1439,9 +1495,9 @@
       iter.common_dtype() != kBFloat16 && iter.common_dtype() != kHalf) {
     // NOTE: CPU performance significantly regressed when attempting to port to
     // ATen,
-    //   so all-reduce is still implemented in TH.
+    //   so all-reduce has a custom implementation.
     //   See https://github.com/pytorch/pytorch/pull/43858.
-    result.fill_(legacy::cpu::_th_std_var(self, correction, take_sqrt));
+    result.fill_(std_var_all_cpu(self, correction, take_sqrt));
   } else {
     std_var_stub(iter.device_type(), iter, correction, take_sqrt);
   }
diff --git a/aten/src/TH/generic/THTensorMath.h b/aten/src/TH/generic/THTensorMath.h
index b0b294c..2bb5624 100644
--- a/aten/src/TH/generic/THTensorMath.h
+++ b/aten/src/TH/generic/THTensorMath.h
@@ -26,8 +26,6 @@
 TH_API void THTensor_(renorm)(THTensor *r_, THTensor *t, scalar_t value, int dimension, scalar_t maxnorm);
 TH_API void THTensor_(histc)(THTensor *hist, THTensor *tensor, int64_t nbins, scalar_t minvalue, scalar_t maxvalue);
 
-TH_API accreal THTensor_(std_var_all)(THTensor* self, int64_t correction, bool take_sqrt);
-
 #endif
 #endif
 #endif
diff --git a/aten/src/TH/generic/THTensorMoreMath.cpp b/aten/src/TH/generic/THTensorMoreMath.cpp
index 59629ab..b2f9790 100644
--- a/aten/src/TH/generic/THTensorMoreMath.cpp
+++ b/aten/src/TH/generic/THTensorMoreMath.cpp
@@ -301,19 +301,6 @@
   c10::raw::intrusive_ptr::decref(rowS);
 }
 
-accreal THTensor_(std_var_all)(THTensor* tensor, int64_t correction, bool take_sqrt)
-    __ubsan_ignore_float_divide_by_zero__ {
-  accreal mean = THTensor_wrap(tensor).mean().item<accreal>();
-  accreal sum = 0;
-  TH_TENSOR_APPLY(scalar_t, tensor, sum += (*tensor_data - mean)*(*tensor_data - mean););
-  sum /= std::max(int64_t{0}, THTensor_(nElement)(tensor) - correction);
-  if (take_sqrt) {
-    return std::sqrt(sum);
-  } else {
-    return sum;
-  }
-}
-
 void THTensor_(histc)(THTensor *hist, THTensor *tensor, int64_t nbins, scalar_t minvalue, scalar_t maxvalue)
 {
   if (nbins <= 0) {