Migrate `var` & `std` to ATen (#39967)
Summary:
Not sure why there are so many issues for std & var, but this PR should close them all:
std: Fix https://github.com/pytorch/pytorch/issues/24771, Fix https://github.com/pytorch/pytorch/issues/24676, Fix https://github.com/pytorch/pytorch/issues/24639, Fix https://github.com/pytorch/pytorch/issues/24529
var: Fix https://github.com/pytorch/pytorch/issues/24782, Fix https://github.com/pytorch/pytorch/issues/24677, Fix https://github.com/pytorch/pytorch/issues/24652, Fix https://github.com/pytorch/pytorch/issues/24530
```py
import time
import torch
def _time():
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.time()

for device in (torch.device("cpu"), torch.device("cuda")):
    for size in (
        [100000000],
        [10000, 10000],
        [1000, 1000, 100],
        [100, 100, 100, 100],
    ):
        t = torch.randn(*size, device=device)
        total_time = 0
        for i in range(10):
            t1 = _time()
            t.std()
            t2 = _time()
            total_time += t2 - t1
        print(f"Tensor of size {size} on {device}: {total_time / 10}")
```
Before:
```
Tensor of size [100000000] on cpu: 0.36041643619537356
Tensor of size [10000, 10000] on cpu: 0.37235140800476074
Tensor of size [1000, 1000, 100] on cpu: 0.386572527885437
Tensor of size [100, 100, 100, 100] on cpu: 0.37404844760894773
Tensor of size [100000000] on cuda: 0.0021645784378051757
Tensor of size [10000, 10000] on cuda: 0.002090191841125488
Tensor of size [1000, 1000, 100] on cuda: 0.00208127498626709
Tensor of size [100, 100, 100, 100] on cuda: 0.0020844221115112306
```
After:
```
Tensor of size [100000000] on cpu: 0.1339871883392334
Tensor of size [10000, 10000] on cpu: 0.1343991994857788
Tensor of size [1000, 1000, 100] on cpu: 0.1346735954284668
Tensor of size [100, 100, 100, 100] on cpu: 0.11906447410583496
Tensor of size [100000000] on cuda: 0.0013531208038330077
Tensor of size [10000, 10000] on cuda: 0.0012922048568725585
Tensor of size [1000, 1000, 100] on cuda: 0.001285886764526367
Tensor of size [100, 100, 100, 100] on cuda: 0.0012899160385131836
```
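As a quick sanity check (not part of the PR), the full reduction should agree with reducing over every dimension explicitly, which is what routing through `at::var`/`at::std` with an empty dim list amounts to:
```py
import torch

t = torch.randn(1000, 100)
all_dims = tuple(range(t.dim()))

# Full reductions vs. explicit all-dims reductions; both should agree
# up to floating-point tolerance.
assert torch.allclose(t.var(), t.var(dim=all_dims, unbiased=True))
assert torch.allclose(t.std(), t.std(dim=all_dims, unbiased=True))
```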
cc: VitalyFedyunin
Pull Request resolved: https://github.com/pytorch/pytorch/pull/39967
Differential Revision: D22162469
Pulled By: VitalyFedyunin
fbshipit-source-id: 8d901c779767b00f81cd6231bc665e04f297b4c3
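For reference, a rough Python sketch (an illustration, not the shipped code) of what the removed TH `var_all`/`std_all` kernels computed; the deleted C implementations appear in the diff below:
```py
import torch

def var_all_reference(t, unbiased=True):
    # Mirrors the deleted THTensor_(var_all): sum of squared deviations from
    # the mean, divided by max(0, N - 1) when unbiased, else N. Where that
    # denominator is zero the C code does a floating-point divide (yielding
    # inf/NaN), whereas plain Python would raise ZeroDivisionError.
    mean = t.mean().item()
    sq_sum = ((t - mean) ** 2).sum().item()
    return sq_sum / max(0, t.numel() - (1 if unbiased else 0))

def std_all_reference(t, unbiased=True):
    # Mirrors the deleted THTensor_(std_all): sqrt of the variance.
    return var_all_reference(t, unbiased) ** 0.5
```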
diff --git a/aten/src/ATen/Declarations.cwrap b/aten/src/ATen/Declarations.cwrap
index 7e71b0b..e4a6237 100644
--- a/aten/src/ATen/Declarations.cwrap
+++ b/aten/src/ATen/Declarations.cwrap
@@ -211,36 +211,6 @@
- bool sorted
]]
[[
- name: _th_var
- types:
- - floating_point
- backends:
- - CPU
- - CUDA
- variants: function
- options:
- - cname: var_all
- return: accreal
- arguments:
- - THTensor* self
- - bool unbiased
-]]
-[[
- name: _th_std
- types:
- - floating_point
- backends:
- - CPU
- - CUDA
- variants: function
- options:
- - cname: std_all
- return: accreal
- arguments:
- - THTensor* self
- - bool unbiased
-]]
-[[
name: _th_renorm
cname: renorm
types:
diff --git a/aten/src/ATen/core/NamedRegistrations.cpp b/aten/src/ATen/core/NamedRegistrations.cpp
index f09bd2b..42f9445 100644
--- a/aten/src/ATen/core/NamedRegistrations.cpp
+++ b/aten/src/ATen/core/NamedRegistrations.cpp
@@ -18,8 +18,6 @@
m.impl("_sparse_log_softmax.int", CppFunction::makeFallthrough());
m.impl("_sparse_softmax.Dimname", CppFunction::makeFallthrough());
m.impl("_sparse_softmax.int", CppFunction::makeFallthrough());
- m.impl("_std", CppFunction::makeFallthrough());
- m.impl("_var", CppFunction::makeFallthrough());
m.impl("abs", CppFunction::makeFallthrough());
m.impl("abs.out", CppFunction::makeFallthrough());
m.impl("abs_", CppFunction::makeFallthrough());
diff --git a/aten/src/ATen/native/ReduceOps.cpp b/aten/src/ATen/native/ReduceOps.cpp
index d5f5ff7..ddbe5fe 100644
--- a/aten/src/ATen/native/ReduceOps.cpp
+++ b/aten/src/ATen/native/ReduceOps.cpp
@@ -835,7 +835,9 @@
TORCH_CHECK(at::isFloatingType(self.scalar_type()) || at::isComplexType(self.scalar_type()),
"var only supports floating-point dtypes");
auto trivial_return = _allreduce_return_trivial(self, std::numeric_limits<double>::quiet_NaN());
- return trivial_return.has_value() ? trivial_return.value() : at::_var(self, unbiased);
+ if (trivial_return.has_value())
+ return trivial_return.value();
+ return at::var(self, IntArrayRef{}, unbiased, false);
}
Tensor var(const Tensor& self, IntArrayRef dim, bool unbiased, bool keepdim) {
@@ -855,7 +857,9 @@
TORCH_CHECK(at::isFloatingType(self.scalar_type()) || at::isComplexType(self.scalar_type()),
"std only supports floating-point dtypes");
auto trivial_return = _allreduce_return_trivial(self, std::numeric_limits<double>::quiet_NaN());
- return trivial_return.has_value() ? trivial_return.value() : at::_std(self, unbiased);
+ if (trivial_return.has_value())
+ return trivial_return.value();
+ return at::std(self, IntArrayRef{}, unbiased, false);
}
Tensor std(const Tensor& self, IntArrayRef dim, bool unbiased, bool keepdim) {
@@ -868,7 +872,7 @@
}
Tensor std(const Tensor& self, DimnameList dim, bool unbiased, bool keepdim) {
- return at::std(self, dimnames_to_positions(self, dim), unbiased, keepdim);
+ return at::std(self, dimnames_to_positions(self, dim), unbiased, keepdim);
}
Tensor& std_out(Tensor& result, const Tensor& self, DimnameList dim, bool unbiased, bool keepdim) {
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 99ef427..8c690fd 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -5216,18 +5216,6 @@
CPU: _cumprod_out_cpu
CUDA: _cumprod_out_cuda
-- func: _var(Tensor self, bool unbiased=True) -> Tensor
- use_c10_dispatcher: full
- dispatch:
- CPU: legacy::cpu::_th_var
- CUDA: legacy::cuda::_th_var
-
-- func: _std(Tensor self, bool unbiased=True) -> Tensor
- use_c10_dispatcher: full
- dispatch:
- CPU: legacy::cpu::_th_std
- CUDA: legacy::cuda::_th_std
-
- func: _amp_non_finite_check_and_unscale_(Tensor(a!) self, Tensor(b!) found_inf, Tensor inv_scale) -> ()
variants: function
dispatch:
diff --git a/aten/src/TH/generic/THTensorMath.h b/aten/src/TH/generic/THTensorMath.h
index 394705f..dfc5e3e 100644
--- a/aten/src/TH/generic/THTensorMath.h
+++ b/aten/src/TH/generic/THTensorMath.h
@@ -49,9 +49,6 @@
TH_API void THTensor_(renorm)(THTensor *r_, THTensor *t, scalar_t value, int dimension, scalar_t maxnorm);
TH_API void THTensor_(histc)(THTensor *hist, THTensor *tensor, int64_t nbins, scalar_t minvalue, scalar_t maxvalue);
-TH_API accreal THTensor_(var_all)(THTensor *self, bool unbiased);
-TH_API accreal THTensor_(std_all)(THTensor *self, bool unbiased);
-
#endif
#endif
#endif
diff --git a/aten/src/TH/generic/THTensorMoreMath.cpp b/aten/src/TH/generic/THTensorMoreMath.cpp
index 0a4e86a..4a81a26 100644
--- a/aten/src/TH/generic/THTensorMoreMath.cpp
+++ b/aten/src/TH/generic/THTensorMoreMath.cpp
@@ -710,20 +710,6 @@
c10::raw::intrusive_ptr::decref(rowS);
}
-accreal THTensor_(var_all)(THTensor *tensor, bool unbiased)
-{
- accreal mean = THTensor_wrap(tensor).mean().item<accreal>();
- accreal sum = 0;
- TH_TENSOR_APPLY(scalar_t, tensor, sum += (*tensor_data - mean)*(*tensor_data - mean););
- sum /= std::max<int64_t>(0, THTensor_(nElement)(tensor) - (unbiased ? 1 : 0));
- return sum;
-}
-
-accreal THTensor_(std_all)(THTensor *tensor, bool unbiased)
-{
- return sqrt(THTensor_(var_all)(tensor, unbiased));
-}
-
void THTensor_(histc)(THTensor *hist, THTensor *tensor, int64_t nbins, scalar_t minvalue, scalar_t maxvalue)
{
if (nbins <= 0) {
diff --git a/aten/src/THC/generic/THCTensorMathReduce.cu b/aten/src/THC/generic/THCTensorMathReduce.cu
index cb315c8..76f470c 100644
--- a/aten/src/THC/generic/THCTensorMathReduce.cu
+++ b/aten/src/THC/generic/THCTensorMathReduce.cu
@@ -56,35 +56,6 @@
THCTensor_(free)(state, data);
}
-accreal THCTensor_(std_all)(THCState *state, THCTensor *self, bool unbiased)
-{
- THCAssertSameGPU(THCTensor_(checkGPU)(state, 1, self));
- return THCNumerics<accreal>::sqrt((THCTensor_(var_all)(state, self, unbiased)));
-}
-
-accreal THCTensor_(var_all)(THCState *state, THCTensor *self, bool unbiased)
-{
- THCAssertSameGPU(THCTensor_(checkGPU)(state, 1, self));
- accreal mean = THTensor_wrap(self).mean().item<accreal>();
-
- accreal val;
- if (!THC_reduceAll<scalar_t>(state, self,
- SquareFunctor<accreal>(mean),
- ReduceAdd<accreal>(),
- scalar_cast<accreal>(0),
- &val, 0)) {
- THArgCheck(false, 1, CUTORCH_DIM_WARNING);
- }
-
- val = THCNumerics<accreal>::div(
- val,
- scalar_cast<accreal>(std::max<int64_t>(0, THCTensor_(nElement)(state, self) - (unbiased ? 1 : 0)))
- );
-
- THCudaCheck(cudaGetLastError());
- return val;
-}
-
#endif
#endif
diff --git a/aten/src/THC/generic/THCTensorMathReduce.h b/aten/src/THC/generic/THCTensorMathReduce.h
index ebb62a6..334c9f4 100644
--- a/aten/src/THC/generic/THCTensorMathReduce.h
+++ b/aten/src/THC/generic/THCTensorMathReduce.h
@@ -9,9 +9,7 @@
THC_API void THCTensor_(renorm)(THCState *state, THCTensor* self, THCTensor* src, scalar_t value, int dimension, scalar_t max_norm);
THC_API void THCTensor_(norm)(THCState *state, THCTensor* self, THCTensor* src, scalar_t value, int dimension, int keepdim);
-THC_API accreal THCTensor_(std_all)(THCState *state, THCTensor *self, bool unbiased);
THC_API accreal THCTensor_(normall)(THCState *state, THCTensor *self, scalar_t value);
-THC_API accreal THCTensor_(var_all)(THCState *state, THCTensor *self, bool unbiased);
#endif
diff --git a/test/backward_compatibility/check_backward_compatibility.py b/test/backward_compatibility/check_backward_compatibility.py
index bf988b0..ad49e94 100644
--- a/test/backward_compatibility/check_backward_compatibility.py
+++ b/test/backward_compatibility/check_backward_compatibility.py
@@ -107,6 +107,8 @@
('aten::__or__', datetime.date(2020, 6, 30)),
('aten::__xor__', datetime.date(2020, 6, 30)),
('aten::split', datetime.date(2020, 6, 30)),
+ ('aten::_var', datetime.date(2020, 6, 30)),
+ ('aten::_std', datetime.date(2020, 6, 30)),
]