Migrate `_th_std_var` to ATen (#59258)
Summary:
Ref https://github.com/pytorch/pytorch/issues/49421
This migrates `std`/`var`'s special-case all-reduction from TH to ATen. Using the benchmark from gh-43858 that was originally used to justify keeping the TH version, I find this PR has similar (slightly better) single-threaded performance. Unlike the TH version, the new implementation is also multi-threaded, and so is much faster for large tensors.
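For reference, here is a minimal sketch of that kind of harness; the exact statements in gh-43858 may differ, and `stdfn` here is an assumed stand-in for the manual variance function benchmarked there:
```python
import torch
from torch.utils import benchmark

def stdfn(a):
    # Stand-in manual variance: two passes, like the kernel under test.
    return ((a - a.mean()) ** 2).mean()

results = []
for num_threads in (1, 8):
    for n in (8, 80, 800, 8_000, 80_000, 800_000, 8_000_000):
        a = torch.randn(n)
        for name, stmt in [
            ("torch_var", "torch.var(a)"),
            ("torch_var0", "torch.var(a, 0)"),
            ("stdfn", "stdfn(a)"),
            ("torch_sum0", "torch.sum(a, 0)"),
        ]:
            results.append(benchmark.Timer(
                stmt=stmt,
                globals={"a": a, "stdfn": stdfn},
                num_threads=num_threads,
                label="Index",
                sub_label=str(n),
                description=name,
            ).blocked_autorange(min_run_time=0.2))

benchmark.Compare(results).print()
```
`Compare` groups rows by tensor size and columns by statement, which is the layout of the tables below.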
TH Results:
```
[----------------------------- Index ------------------------------]
            |  torch_var  | torch_var0  |      stdfn  | torch_sum0
1 threads: ---------------------------------------------------------
          8 |        3.6  |        3.8  |        8.2  |        1.2
         80 |        3.7  |        3.8  |        8.4  |        1.2
        800 |        4.2  |        4.3  |        8.7  |        1.2
       8000 |        9.0  |        9.1  |       11.2  |        1.5
      80000 |       58.3  |       59.0  |       30.6  |        4.2
     800000 |      546.9  |      546.9  |      183.4  |       31.3
    8000000 |     5729.7  |     5701.0  |     6165.4  |      484.1
```
ATen results:
```
[----------------------------- Index ------------------------------]
            |  torch_var  | torch_var0  |      stdfn  | torch_sum0
1 threads: ---------------------------------------------------------
          8 |        4.0  |        4.0  |        8.7  |        1.2
         80 |        3.6  |        3.8  |        9.0  |        1.2
        800 |        4.1  |        4.3  |        8.9  |        1.2
       8000 |        8.9  |        9.2  |       10.6  |        1.5
      80000 |       57.0  |       57.4  |       28.8  |        4.3
     800000 |      526.9  |      526.9  |      178.3  |       30.2
    8000000 |     5568.1  |     5560.6  |     6042.1  |      453.2

[----------------------------- Index ------------------------------]
            |  torch_var  | torch_var0  |      stdfn  | torch_sum0
8 threads: ---------------------------------------------------------
          8 |        3.9  |        3.8  |        9.1  |        1.2
         80 |        3.8  |        3.9  |        8.8  |        1.2
        800 |        4.2  |        4.3  |        8.9  |        1.3
       8000 |        9.0  |        9.2  |       10.4  |        1.5
      80000 |       26.0  |       26.8  |       26.4  |        4.4
     800000 |       92.9  |       87.3  |       72.1  |       22.4
    8000000 |      793.5  |      791.8  |     5334.8  |      115.1
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59258
Reviewed By: mruberry
Differential Revision: D28821216
Pulled By: ngimel
fbshipit-source-id: f35992c21f08a0a8878053680dc0ca7a8facd155
diff --git a/aten/src/ATen/LegacyTHFunctionsCPU.cpp b/aten/src/ATen/LegacyTHFunctionsCPU.cpp
index 8eeadd5..50d04ea 100644
--- a/aten/src/ATen/LegacyTHFunctionsCPU.cpp
+++ b/aten/src/ATen/LegacyTHFunctionsCPU.cpp
@@ -35,25 +35,6 @@
}
}
-Scalar _th_std_var(const Tensor& self, int64_t correction, bool take_sqrt) {
- // DeviceGuard omitted
- auto dispatch_scalar_type = infer_scalar_type(self);
-
- switch (dispatch_scalar_type) {
- case ScalarType::Double: {
- auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_var", false, DeviceType::CPU, dispatch_scalar_type);
- return convert<double>(THDoubleTensor_std_var_all(self_, correction, take_sqrt));
- break;
- }
- case ScalarType::Float: {
- auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_var", false, DeviceType::CPU, dispatch_scalar_type);
- return convert<float>(THFloatTensor_std_var_all(self_, correction, take_sqrt));
- break;
- }
- default:
- AT_ERROR("_th_var not supported on CPUType for ", dispatch_scalar_type);
- }
-}
Tensor & _th_renorm_out(const Tensor & self, const Scalar& p, int64_t dim, const Scalar& maxnorm, Tensor & result) {
// DeviceGuard omitted
auto dispatch_scalar_type = infer_scalar_type(self);
diff --git a/aten/src/ATen/LegacyTHFunctionsCPU.h b/aten/src/ATen/LegacyTHFunctionsCPU.h
index 0d273ea..5a7fc4d 100644
--- a/aten/src/ATen/LegacyTHFunctionsCPU.h
+++ b/aten/src/ATen/LegacyTHFunctionsCPU.h
@@ -20,7 +20,6 @@
Tensor & _th_masked_scatter_(Tensor & self, const Tensor & mask, const Tensor & source);
Tensor & _th_masked_scatter_bool_(Tensor & self, const Tensor & mask, const Tensor & source);
-Scalar _th_std_var(const Tensor& self, int64_t correction, bool take_sqrt);
Tensor & _th_renorm_out(const Tensor & self, const Scalar& p, int64_t dim, const Scalar& maxnorm, Tensor & result);
Tensor _th_renorm(const Tensor & self, const Scalar& p, int64_t dim, const Scalar& maxnorm);
Tensor & _th_renorm_(Tensor & self, const Scalar& p, int64_t dim, const Scalar& maxnorm);
diff --git a/aten/src/ATen/native/ReduceOps.cpp b/aten/src/ATen/native/ReduceOps.cpp
index 63d0618..5b144cb 100644
--- a/aten/src/ATen/native/ReduceOps.cpp
+++ b/aten/src/ATen/native/ReduceOps.cpp
@@ -16,6 +16,8 @@
#include <ATen/core/grad_mode.h>
#include <c10/util/irange.h>
+#include <c10/util/SmallBuffer.h>
#include <algorithm>
#include <functional>
+#include <numeric>
@@ -1380,6 +1382,67 @@
return at::native::argmin_out(self, dim, keepdims, result);
}
+static double std_var_all_cpu(const Tensor& self, int64_t correction, bool take_sqrt) {
+ const auto dtype = self.scalar_type();
+ TORCH_CHECK(dtype == kDouble || dtype == kFloat,
+ "std_var_all: Unsupported dtype ", dtype);
+
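+  // Two-pass algorithm: compute the mean first, then accumulate the squared
+  // deviations from it in double precision, whatever the input dtype.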
+ auto mean = self.mean().item<double>();
+ auto iter = TensorIteratorConfig()
+ .add_input(self)
+ .build();
+
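+  // One partial sum per thread; no synchronization is needed inside the loop.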
+ const auto max_threads = at::get_num_threads();
+ c10::SmallBuffer<double, 64> partial_sums(max_threads);
+ std::fill(partial_sums.begin(), partial_sums.end(), 0.0);
+
+ at::parallel_for(0, iter.numel(), at::internal::GRAIN_SIZE, [&](int64_t begin, int64_t end) {
+ const auto tid = at::get_thread_num();
+ double thread_sum = 0.0;
+
+ AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "std_var_all_cpu", [&] {
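+    // serial_for_each hands the loop raw byte strides, hence the char* arithmetic.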
+ iter.serial_for_each([&] (char** data, const int64_t* strides, int64_t size0, int64_t size1) {
+ const double local_mean = mean;
+ const int64_t inner_stride = strides[0];
+ const int64_t outer_stride = strides[1];
+
+ double local_sum = 0.0;
+ for (int64_t i = 0; i < size1; ++i) {
+ const char* row_ptr = data[0] + outer_stride * i;
+ for (int64_t j = 0; j < size0; ++j) {
+ const auto ptr = reinterpret_cast<const scalar_t*>(row_ptr + inner_stride * j);
+ auto dx = (static_cast<double>(*ptr) - local_mean);
+ local_sum += dx * dx;
+ }
+ }
+ thread_sum += local_sum;
+ }, {begin, end});
+ });
+
+ partial_sums[tid] = thread_sum;
+ });
+
+ const double total_sum = std::accumulate(
+ partial_sums.begin(), partial_sums.end(), 0.0);
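+  // With correction >= numel() the denominator below is zero; the ubsan
+  // annotation lets the resulting inf/nan propagate, as the TH version did.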
+ const auto var = [&] () __ubsan_ignore_float_divide_by_zero__ {
+ return total_sum / std::max(int64_t{0}, self.numel() - correction);
+ }();
+ const auto result = take_sqrt ? std::sqrt(var) : var;
+
+ if (dtype == kFloat) {
+ // Convert to infinity if out of range for a float.
+ // Doing it now prevents checked_convert failing later
+ return static_cast<float>(result);
+ }
+ return result;
+}
+
static Tensor& std_var_out(
const char* fname, Tensor& result, const Tensor& self,
c10::optional<IntArrayRef> dim, c10::optional<int64_t> correction_opt,
@@ -1439,9 +1495,9 @@
iter.common_dtype() != kBFloat16 && iter.common_dtype() != kHalf) {
// NOTE: CPU performance significantly regressed when attempting to port to
// ATen,
- // so all-reduce is still implemented in TH.
+ // so all-reduce has a custom implementation.
// See https://github.com/pytorch/pytorch/pull/43858.
- result.fill_(legacy::cpu::_th_std_var(self, correction, take_sqrt));
+ result.fill_(std_var_all_cpu(self, correction, take_sqrt));
} else {
std_var_stub(iter.device_type(), iter, correction, take_sqrt);
}
diff --git a/aten/src/TH/generic/THTensorMath.h b/aten/src/TH/generic/THTensorMath.h
index b0b294c..2bb5624 100644
--- a/aten/src/TH/generic/THTensorMath.h
+++ b/aten/src/TH/generic/THTensorMath.h
@@ -26,8 +26,6 @@
TH_API void THTensor_(renorm)(THTensor *r_, THTensor *t, scalar_t value, int dimension, scalar_t maxnorm);
TH_API void THTensor_(histc)(THTensor *hist, THTensor *tensor, int64_t nbins, scalar_t minvalue, scalar_t maxvalue);
-TH_API accreal THTensor_(std_var_all)(THTensor* self, int64_t correction, bool take_sqrt);
-
#endif
#endif
#endif
diff --git a/aten/src/TH/generic/THTensorMoreMath.cpp b/aten/src/TH/generic/THTensorMoreMath.cpp
index 59629ab..b2f9790 100644
--- a/aten/src/TH/generic/THTensorMoreMath.cpp
+++ b/aten/src/TH/generic/THTensorMoreMath.cpp
@@ -301,19 +301,6 @@
c10::raw::intrusive_ptr::decref(rowS);
}
-accreal THTensor_(std_var_all)(THTensor* tensor, int64_t correction, bool take_sqrt)
- __ubsan_ignore_float_divide_by_zero__ {
- accreal mean = THTensor_wrap(tensor).mean().item<accreal>();
- accreal sum = 0;
- TH_TENSOR_APPLY(scalar_t, tensor, sum += (*tensor_data - mean)*(*tensor_data - mean););
- sum /= std::max(int64_t{0}, THTensor_(nElement)(tensor) - correction);
- if (take_sqrt) {
- return std::sqrt(sum);
- } else {
- return sum;
- }
-}
-
void THTensor_(histc)(THTensor *hist, THTensor *tensor, int64_t nbins, scalar_t minvalue, scalar_t maxvalue)
{
if (nbins <= 0) {