THTensor_varOuterDim numeric stability (#3533)
diff --git a/aten/src/THC/THCTensorMathReduce.cuh b/aten/src/THC/THCTensorMathReduce.cuh
index 1e81189..c2d8f0f 100644
--- a/aten/src/THC/THCTensorMathReduce.cuh
+++ b/aten/src/THC/THCTensorMathReduce.cuh
@@ -305,31 +305,38 @@
* outer dimensions, which contains several "inner rows").
* Each thread processes a single inner row at a time.
*/
-template<typename Real, bool flag, bool apply_sqrt>
+template<typename Real, typename Accreal, bool flag, bool apply_sqrt>
__global__ void THCTensor_kernel_varOuterDim(Real *tgt, Real *src_, unsigned num_orows, unsigned num_irows, unsigned row_size)
{
for (unsigned orow = blockIdx.x; orow < num_orows; orow += gridDim.x) {
for (unsigned irow = blockIdx.y * blockDim.x + threadIdx.x; irow < num_irows; irow += gridDim.y * blockDim.x) {
Real *src = src_ + orow * row_size * num_irows + irow;
- Real sum = ScalarConvert<int, Real>::to(0), sum2 = ScalarConvert<int, Real>::to(0);
+ Accreal mean = ScalarConvert<int, Accreal>::to(0);
+ Accreal m2 = ScalarConvert<int, Accreal>::to(0);
for (unsigned col = 0; col < row_size; ++col) {
- Real val = *src;
- sum = THCNumerics<Real>::add(sum, val);
- sum2 = THCNumerics<Real>::add(
- sum2,
- THCNumerics<Real>::mul(val, val)
- );
-
+ Accreal val = ScalarConvert<Real, Accreal>::to(*src);
+ Accreal delta = THCNumerics<Accreal>::sub(val, mean);
+ mean = THCNumerics<Accreal>::add(mean,
+ THCNumerics<Accreal>::div(delta, ScalarConvert<int, Accreal>::to(col + 1)));
+ Accreal delta2 = THCNumerics<Accreal>::sub(val, mean);
+ m2 = THCNumerics<Accreal>::add(m2,
+ THCNumerics<Accreal>::mul(delta, delta2));
src += num_irows;
}
-
- tgt[orow * num_irows + irow] = THCTensor_computeVar<Real, flag, apply_sqrt>(sum, sum2, row_size);
+
+ if (flag) {
+ m2 = THCNumerics<Accreal>::div(m2, ScalarConvert<int, Accreal>::to(row_size));
+ } else {
+ m2 = THCNumerics<Accreal>::div(m2, ScalarConvert<int, Accreal>::to(row_size - 1));
+ }
+ tgt[orow * num_irows + irow] = ScalarConvert<Accreal, Real>::to(
+ apply_sqrt ? THCNumerics<Accreal>::sqrt(m2) : m2);
}
}
}
-template<typename TensorTypeK, typename Real, bool apply_sqrt>
+template<typename TensorTypeK, typename Real, typename Accreal, bool apply_sqrt>
__host__ void THCTensor_varOuterDim(THCState *state, TensorTypeK *tgt, TensorTypeK *src, int64_t dimension, int flag)
{
unsigned ndim = TensorUtils<TensorTypeK>::getDims(state, src);
@@ -350,10 +357,10 @@
dim3 grid(min(maxGridDim, num_orows), min(maxGridDim, THCCeilDiv(num_irows, threads.x)));
if (flag) {
- THCTensor_kernel_varOuterDim<Real, true, apply_sqrt><<<grid, threads, 0, THCState_getCurrentStream(state)>>>(
+ THCTensor_kernel_varOuterDim<Real, Accreal, true, apply_sqrt><<<grid, threads, 0, THCState_getCurrentStream(state)>>>(
TensorUtils<TensorTypeK>::getData(state, tgt), TensorUtils<TensorTypeK>::getData(state, src), num_orows, num_irows, row_size);
} else {
- THCTensor_kernel_varOuterDim<Real, false, apply_sqrt><<<grid, threads, 0, THCState_getCurrentStream(state)>>>(
+ THCTensor_kernel_varOuterDim<Real, Accreal, false, apply_sqrt><<<grid, threads, 0, THCState_getCurrentStream(state)>>>(
TensorUtils<TensorTypeK>::getData(state, tgt), TensorUtils<TensorTypeK>::getData(state, src), num_orows, num_irows, row_size);
}
cudaError errcode = cudaGetLastError();
diff --git a/aten/src/THC/generic/THCTensorMathReduce.cu b/aten/src/THC/generic/THCTensorMathReduce.cu
index 494f2f7..ff662eb 100644
--- a/aten/src/THC/generic/THCTensorMathReduce.cu
+++ b/aten/src/THC/generic/THCTensorMathReduce.cu
@@ -88,7 +88,7 @@
if (dimension == THCTensor_(nDimension)(state, src) - 1) {
THCTensor_varInnermostDim<THCTensor, real, accreal, true>(state, self, src, biased);
} else {
- THCTensor_varOuterDim<THCTensor, real, true>(state, self, src, dimension, biased);
+ THCTensor_varOuterDim<THCTensor, real, accreal, true>(state, self, src, dimension, biased);
}
THCTensor_(free)(state, src);
@@ -114,7 +114,7 @@
if (dimension == THCTensor_(nDimension)(state, src) - 1) {
THCTensor_varInnermostDim<THCTensor, real, accreal, false>(state, self, src, biased);
} else {
- THCTensor_varOuterDim<THCTensor, real, false>(state, self, src, dimension, biased);
+ THCTensor_varOuterDim<THCTensor, real, accreal, false>(state, self, src, dimension, biased);
}
THCTensor_(free)(state, src);
diff --git a/test/test_cuda.py b/test/test_cuda.py
index cbd0a2e..75d9f16 100644
--- a/test/test_cuda.py
+++ b/test/test_cuda.py
@@ -1009,9 +1009,17 @@
def test_var_stability(self):
tensor = torch.FloatTensor([2281.5, 2281.25]).cuda()
+
+ # Stability for inner dim
self.assertEqual(tensor.var(0)[0], 0.03125)
+
+ # General stability
self.assertEqual(tensor.var(), 0.03125)
+ # Stability for outer dimensions
+ tensor = tensor.unsqueeze(1)
+ self.assertEqual(tensor.var(0)[0], 0.03125)
+
def test_arange(self):
for t in ['IntTensor', 'LongTensor', 'FloatTensor', 'DoubleTensor']:
a = torch.cuda.__dict__[t]()