THCTensor_varOuterDim numeric stability (#3533)

diff --git a/aten/src/THC/THCTensorMathReduce.cuh b/aten/src/THC/THCTensorMathReduce.cuh
index 1e81189..c2d8f0f 100644
--- a/aten/src/THC/THCTensorMathReduce.cuh
+++ b/aten/src/THC/THCTensorMathReduce.cuh
@@ -305,31 +305,38 @@
  * outer dimensions, which contains several "inner rows").
  * Each thread processes a single inner row at a time.
  */
-template<typename Real, bool flag, bool apply_sqrt>
+template<typename Real, typename Accreal, bool flag, bool apply_sqrt>
 __global__ void THCTensor_kernel_varOuterDim(Real *tgt, Real *src_, unsigned num_orows, unsigned num_irows, unsigned row_size)
 {
   for (unsigned orow = blockIdx.x; orow < num_orows; orow += gridDim.x) {
     for (unsigned irow = blockIdx.y * blockDim.x + threadIdx.x; irow < num_irows; irow += gridDim.y * blockDim.x) {
       Real *src = src_ + orow * row_size * num_irows + irow;
-      Real sum = ScalarConvert<int, Real>::to(0), sum2 = ScalarConvert<int, Real>::to(0);
+      Accreal mean = ScalarConvert<int, Accreal>::to(0);  // running mean of the values read so far
+      Accreal m2 = ScalarConvert<int, Accreal>::to(0);    // running sum of squared deviations from that mean
 
       for (unsigned col = 0; col < row_size; ++col) {
-        Real val = *src;
-        sum = THCNumerics<Real>::add(sum, val);
-        sum2 = THCNumerics<Real>::add(
-          sum2,
-          THCNumerics<Real>::mul(val, val)
-        );
-
+        // Welford's online update in Accreal: avoids the sum / sum-of-squares cancellation of the old code.
+        Accreal val = ScalarConvert<Real, Accreal>::to(*src);
+        Accreal delta = THCNumerics<Accreal>::sub(val, mean);  // deviation from the mean before this element
+        mean = THCNumerics<Accreal>::add(mean,
+            THCNumerics<Accreal>::div(delta, ScalarConvert<int, Accreal>::to(col + 1)));
+        Accreal delta2 = THCNumerics<Accreal>::sub(val, mean);  // deviation from the mean after this element
+        m2 = THCNumerics<Accreal>::add(m2, THCNumerics<Accreal>::mul(delta, delta2));
         src += num_irows;
       }
 
-      tgt[orow * num_irows + irow] = THCTensor_computeVar<Real, flag, apply_sqrt>(sum, sum2, row_size);
+      // flag == true divides by row_size (biased); flag == false divides by row_size - 1 (unbiased).
+      if (flag) {
+        m2 = THCNumerics<Accreal>::div(m2, ScalarConvert<int, Accreal>::to(row_size));
+      } else {
+        m2 = THCNumerics<Accreal>::div(m2, ScalarConvert<int, Accreal>::to(row_size - 1));
+      }
+      tgt[orow * num_irows + irow] = ScalarConvert<Accreal, Real>::to(apply_sqrt ? THCNumerics<Accreal>::sqrt(m2) : m2);
     }
   }
 }
 
-template<typename TensorTypeK, typename Real, bool apply_sqrt>
+template<typename TensorTypeK, typename Real, typename Accreal, bool apply_sqrt>
 __host__ void THCTensor_varOuterDim(THCState *state, TensorTypeK *tgt, TensorTypeK *src, int64_t dimension, int flag)
 {
   unsigned ndim = TensorUtils<TensorTypeK>::getDims(state, src);
@@ -350,10 +357,10 @@
   dim3 grid(min(maxGridDim, num_orows), min(maxGridDim, THCCeilDiv(num_irows, threads.x)));
 
   if (flag) {
-    THCTensor_kernel_varOuterDim<Real, true, apply_sqrt><<<grid, threads, 0, THCState_getCurrentStream(state)>>>(
+    THCTensor_kernel_varOuterDim<Real, Accreal, true, apply_sqrt><<<grid, threads, 0, THCState_getCurrentStream(state)>>>(
         TensorUtils<TensorTypeK>::getData(state, tgt), TensorUtils<TensorTypeK>::getData(state, src), num_orows, num_irows, row_size);
   } else {
-    THCTensor_kernel_varOuterDim<Real, false, apply_sqrt><<<grid, threads, 0, THCState_getCurrentStream(state)>>>(
+    THCTensor_kernel_varOuterDim<Real, Accreal, false, apply_sqrt><<<grid, threads, 0, THCState_getCurrentStream(state)>>>(
         TensorUtils<TensorTypeK>::getData(state, tgt), TensorUtils<TensorTypeK>::getData(state, src), num_orows, num_irows, row_size);
   }
   cudaError errcode = cudaGetLastError();
diff --git a/aten/src/THC/generic/THCTensorMathReduce.cu b/aten/src/THC/generic/THCTensorMathReduce.cu
index 494f2f7..ff662eb 100644
--- a/aten/src/THC/generic/THCTensorMathReduce.cu
+++ b/aten/src/THC/generic/THCTensorMathReduce.cu
@@ -88,7 +88,7 @@
   if (dimension == THCTensor_(nDimension)(state, src) - 1) {
     THCTensor_varInnermostDim<THCTensor, real, accreal, true>(state, self, src, biased);
   } else {
-    THCTensor_varOuterDim<THCTensor, real, true>(state, self, src, dimension, biased);
+    THCTensor_varOuterDim<THCTensor, real, accreal, true>(state, self, src, dimension, biased);
   }
 
   THCTensor_(free)(state, src);
@@ -114,7 +114,7 @@
   if (dimension == THCTensor_(nDimension)(state, src) - 1) {
     THCTensor_varInnermostDim<THCTensor, real, accreal, false>(state, self, src, biased);
   } else {
-    THCTensor_varOuterDim<THCTensor, real, false>(state, self, src, dimension, biased);
+    THCTensor_varOuterDim<THCTensor, real, accreal, false>(state, self, src, dimension, biased);
   }
 
   THCTensor_(free)(state, src);
diff --git a/test/test_cuda.py b/test/test_cuda.py
index cbd0a2e..75d9f16 100644
--- a/test/test_cuda.py
+++ b/test/test_cuda.py
@@ -1009,9 +1009,17 @@
 
     def test_var_stability(self):
         tensor = torch.FloatTensor([2281.5, 2281.25]).cuda()
+
+        # 1-D input reduced over dim 0, which is its innermost (last) dimension
         self.assertEqual(tensor.var(0)[0], 0.03125)
+
+        # Full reduction over every element of the tensor
         self.assertEqual(tensor.var(), 0.03125)
 
+        # Shape (2, 1) reduced over dim 0, which is an outer (non-last) dimension
+        tensor = tensor.unsqueeze(1)
+        self.assertEqual(tensor.var(0)[0], 0.03125)
+
     def test_arange(self):
         for t in ['IntTensor', 'LongTensor', 'FloatTensor', 'DoubleTensor']:
             a = torch.cuda.__dict__[t]()