Remove Eigen from CUDA math utilities and update the ReduceTensor and Moments algorithms (#6922)
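
This drops the Eigen-based reduction and transpose paths from math_gpu.cu and replaces them with cub::BlockReduce kernels. Instead of materializing a transposed copy of X in a caller-provided scratch tensor, the new ReduceTensorCUDAKernel and MomentsCUDAKernel map each output element back to the input through strides computed for the transposed axis order (ComputeTransposedStrides), so the scratch_ptr argument can be removed from the ReduceMin/Max/Sum/Mean and Moments APIs. When the reduced axes are already the innermost ones, the RowwiseReduceKernel / RowwiseMomentsCUDAKernel fast path is kept.

The snippet below is a minimal CPU-side sketch, not part of this patch, of the per-element index arithmetic the new kernels perform; the dims/axes values are hypothetical example inputs.

    // Illustrative only: mirrors ComputeTransposedStrides and the kernel's
    // inner loop that decodes a linear output index into transposed
    // coordinates and maps it back to the original input layout.
    #include <cstdio>
    #include <vector>

    int main() {
      const std::vector<int> dims = {2, 3, 4};  // X is 2 x 3 x 4 (example)
      const std::vector<int> axes = {1, 0, 2};  // transposed order (example)
      const int D = static_cast<int>(dims.size());

      // Same computation as ComputeTransposedStrides in math_gpu.cu.
      std::vector<int> buff(D), X_strides(D), Y_dims(D);
      int cur_stride = 1;
      for (int i = D - 1; i >= 0; --i) {
        buff[i] = cur_stride;
        cur_stride *= dims[i];
      }
      for (int i = 0; i < D; ++i) {
        X_strides[i] = buff[axes[i]];
        Y_dims[i] = dims[axes[i]];
      }

      // Reduce (sum) over the innermost transposed dimension, mirroring the
      // kernel's per-thread loop: decode Y_index into coordinates, then
      // accumulate X[X_index] without ever building a transposed copy of X.
      const int inner_size = Y_dims[D - 1];
      const int outer_size = cur_stride / inner_size;
      std::vector<float> X(cur_stride);
      for (int i = 0; i < cur_stride; ++i) {
        X[i] = static_cast<float>(i);
      }
      for (int i = 0; i < outer_size; ++i) {
        float val = 0.0f;
        for (int j = 0; j < inner_size; ++j) {
          int X_index = 0;
          int Y_index = i * inner_size + j;
          for (int d = D - 1; d >= 0; --d) {
            X_index += (Y_index % Y_dims[d]) * X_strides[d];
            Y_index /= Y_dims[d];
          }
          val += X[X_index];
        }
        std::printf("Y[%d] = %f\n", i, val);
      }
      return 0;
    }

Avoiding the transposed scratch copy removes one global-memory round trip per reduction and drops the dependency on the unsupported Eigen CUDA tensor module.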

diff --git a/caffe2/operators/reduce_ops.h b/caffe2/operators/reduce_ops.h
index 137b159..7ba4d4b 100644
--- a/caffe2/operators/reduce_ops.h
+++ b/caffe2/operators/reduce_ops.h
@@ -65,8 +65,6 @@
 
   std::vector<int> axes_;
   const int keep_dims_;
-
-  Tensor<Context> buffer_;
 };
 
 template <typename T, class Context>
@@ -90,8 +88,7 @@
         axes.data(),
         X_data,
         Y_data,
-        &context_,
-        &this->buffer_);
+        &context_);
     return true;
   }
 };
@@ -117,8 +114,7 @@
         axes.data(),
         X_data,
         Y_data,
-        &context_,
-        &this->buffer_);
+        &context_);
     return true;
   }
 };
@@ -144,8 +140,7 @@
         axes.data(),
         X_data,
         Y_data,
-        &context_,
-        &this->buffer_);
+        &context_);
     return true;
   }
 };
@@ -171,8 +166,7 @@
         axes.data(),
         X_data,
         Y_data,
-        &context_,
-        &this->buffer_);
+        &context_);
     return true;
   }
 };
diff --git a/caffe2/utils/math.h b/caffe2/utils/math.h
index 4185da1..7278593 100644
--- a/caffe2/utils/math.h
+++ b/caffe2/utils/math.h
@@ -160,8 +160,7 @@
     const int* axes,
     const T* X,
     T* Y,
-    Context* context,
-    Tensor<Context>* scratch_ptr = nullptr);
+    Context* context);
 
 template <typename T, class Context>
 void ReduceMax(
@@ -171,8 +170,7 @@
     const int* axes,
     const T* X,
     T* Y,
-    Context* context,
-    Tensor<Context>* scratch_ptr = nullptr);
+    Context* context);
 
 template <typename T, class Context>
 void ReduceSum(
@@ -182,8 +180,7 @@
     const int* axes,
     const T* X,
     T* Y,
-    Context* context,
-    Tensor<Context>* scratch_ptr = nullptr);
+    Context* context);
 
 template <typename T, class Context>
 void ReduceMean(
@@ -193,10 +190,9 @@
     const int* axes,
     const T* X,
     T* Y,
-    Context* context,
-    Tensor<Context>* scratch_ptr = nullptr);
+    Context* context);
 
-// Broadcasts X with X_dims to Y with Y_dims and multiply the data by scale.
+// Broadcasts X with X_dims to Y with Y_dims.
 template <typename T, class Context>
 void Broadcast(
     const int X_ndim,
@@ -207,6 +203,7 @@
     T* Y,
     Context* context);
 
+// Computes mean and variance over axes.
 template <typename T, class Context>
 void Moments(
     const int num_dims,
@@ -216,8 +213,7 @@
     const T* X,
     T* mean,
     T* variance,
-    Context* context,
-    Tensor<Context>* scratch_ptr = nullptr);
+    Context* context);
 
 // Adds batch sub-tensors elementwise to output. Stripe is the stripe length
 // and N is the number of elements to add (size of Y).
diff --git a/caffe2/utils/math_cpu.cc b/caffe2/utils/math_cpu.cc
index bc7a208..74d6db2 100644
--- a/caffe2/utils/math_cpu.cc
+++ b/caffe2/utils/math_cpu.cc
@@ -971,8 +971,7 @@
       const int* axes,                                               \
       const T* X,                                                    \
       T* Y,                                                          \
-      CPUContext* context,                                           \
-      Tensor<CPUContext>* /* scratch_ptr */) {                       \
+      CPUContext* context) {                                         \
     ReduceMinImpl<T>(num_dims, dims, num_axes, axes, X, Y, context); \
   }
 CAFFE2_SPECIALIZED_REDUCE_MIN(float)
@@ -987,8 +986,7 @@
       const int* axes,                                               \
       const T* X,                                                    \
       T* Y,                                                          \
-      CPUContext* context,                                           \
-      Tensor<CPUContext>* /* scratch_ptr */) {                       \
+      CPUContext* context) {                                         \
     ReduceMaxImpl<T>(num_dims, dims, num_axes, axes, X, Y, context); \
   }
 CAFFE2_SPECIALIZED_REDUCE_MAX(float)
@@ -1003,8 +1001,7 @@
       const int* axes,                                               \
       const T* X,                                                    \
       T* Y,                                                          \
-      CPUContext* context,                                           \
-      Tensor<CPUContext>* /* scratch_ptr */) {                       \
+      CPUContext* context) {                                         \
     ReduceSumImpl<T>(num_dims, dims, num_axes, axes, X, Y, context); \
   }
 CAFFE2_SPECIALIZED_REDUCE_SUM(float)
@@ -1019,8 +1016,7 @@
       const int* axes,                                                \
       const T* X,                                                     \
       T* Y,                                                           \
-      CPUContext* context,                                            \
-      Tensor<CPUContext>* /* scratch_ptr */) {                        \
+      CPUContext* context) {                                          \
     ReduceMeanImpl<T>(num_dims, dims, num_axes, axes, X, Y, context); \
   }
 CAFFE2_SPECIALIZED_REDUCE_MEAN(float)
@@ -1122,8 +1118,7 @@
       const T* X,                                                    \
       T* mean,                                                       \
       T* variance,                                                   \
-      CPUContext* context,                                           \
-      Tensor<CPUContext>* /* scratch_ptr */) {                       \
+      CPUContext* context) {                                         \
     MomentsImpl<T>(                                                  \
         num_dims, dims, num_axes, axes, X, mean, variance, context); \
   }
diff --git a/caffe2/utils/math_gpu.cu b/caffe2/utils/math_gpu.cu
index 0b28a04..8813aae 100644
--- a/caffe2/utils/math_gpu.cu
+++ b/caffe2/utils/math_gpu.cu
@@ -1,7 +1,5 @@
 // Implements the math functions for GPU.
 
-#define EIGEN_USE_GPU
-
 #include "caffe2/utils/math.h"
 
 #include <limits>
@@ -14,21 +12,11 @@
 #include "caffe2/core/context_gpu.h"
 #include "caffe2/utils/conversions.h"
 
-#if EIGEN_VERSION_AT_LEAST(3, 3, 0)
-#include "unsupported/Eigen/CXX11/Tensor"
-#endif // EIGEN_VERSION_AT_LEAST(3, 3, 0)
-
 #if THRUST_VERSION >= 100800
 #define THRUST_SUPPORTS_PER_THREAD
 #endif  // THRUST_VERSION >= 100800
 
 namespace caffe2 {
-
-#if EIGEN_VERSION_AT_LEAST(3, 3, 0)
-template <typename T, int D>
-using EigenTensorMap = Eigen::TensorMap<Eigen::Tensor<T, D>>;
-#endif // EIGEN_VERSION_AT_LEAST(3, 3, 0)
-
 namespace math {
 
 #define DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(T, Funcname, function)             \
@@ -2159,123 +2147,7 @@
 
 namespace {
 
-#if EIGEN_VERSION_AT_LEAST(3, 3, 0)
-
-template <typename T, class Reducer, int kNumDims, int kNumAxes>
-void EigenReduceTensorCUDAImpl(
-    const int* dims,
-    const int* axes,
-    const Reducer& reducer,
-    const T* X,
-    T* Y,
-    CUDAContext* context) {
-  Eigen::DSizes<Eigen::DenseIndex, kNumDims> X_dims;
-  Eigen::DSizes<Eigen::DenseIndex, kNumDims> Y_dims;
-  Eigen::array<Eigen::DenseIndex, kNumAxes> reduce_dims;
-  for (int i = 0; i < kNumDims; ++i) {
-    X_dims[i] = static_cast<Eigen::DenseIndex>(dims[kNumDims - 1 - i]);
-    Y_dims[i] = static_cast<Eigen::DenseIndex>(dims[kNumDims - 1 - i]);
-  }
-  for (int i = 0; i < kNumAxes; ++i) {
-    Y_dims[kNumDims - 1 - axes[i]] = static_cast<Eigen::DenseIndex>(1);
-    reduce_dims[kNumAxes - 1 - i] =
-        static_cast<Eigen::DenseIndex>(kNumDims - 1 - axes[i]);
-  }
-  const cudaStream_t cuda_stream = context->cuda_stream();
-  const Eigen::CudaStreamDevice stream_device(
-      &cuda_stream, context->cuda_gpu_id());
-  const Eigen::GpuDevice gpu_device(&stream_device);
-  EigenTensorMap<T, kNumDims>(Y, Y_dims).device(gpu_device) =
-      EigenTensorMap<T, kNumDims>(const_cast<T*>(X), X_dims)
-          .reduce(reduce_dims, reducer);
-}
-
-#endif // EIGEN_VERSION_AT_LEAST(3, 3, 0)
-
-template <typename T, class Reducer>
-bool EigenReduceTensorCUDA(
-    const int num_dims,
-    const int* dims,
-    const int num_axes,
-    const int* axes,
-    const Reducer& reducer,
-    const T* X,
-    T* Y,
-    CUDAContext* context) {
-  switch (num_dims) {
-    case 1: {
-      switch (num_axes) {
-        case 1: {
-          EigenReduceTensorCUDAImpl<T, Reducer, 1, 1>(
-              dims, axes, reducer, X, Y, context);
-          return true;
-        }
-        default: { return false; }
-      }
-    }
-    case 2: {
-      switch (num_axes) {
-        case 1: {
-          EigenReduceTensorCUDAImpl<T, Reducer, 2, 1>(
-              dims, axes, reducer, X, Y, context);
-          return true;
-        }
-        case 2: {
-          EigenReduceTensorCUDAImpl<T, Reducer, 2, 2>(
-              dims, axes, reducer, X, Y, context);
-          return true;
-        }
-        default: { return false; }
-      }
-    }
-    case 3: {
-      switch (num_axes) {
-        case 1: {
-          EigenReduceTensorCUDAImpl<T, Reducer, 3, 1>(
-              dims, axes, reducer, X, Y, context);
-          return true;
-        }
-        case 2: {
-          EigenReduceTensorCUDAImpl<T, Reducer, 3, 2>(
-              dims, axes, reducer, X, Y, context);
-          return true;
-        }
-        case 3: {
-          EigenReduceTensorCUDAImpl<T, Reducer, 3, 3>(
-              dims, axes, reducer, X, Y, context);
-          return true;
-        }
-        default: { return false; }
-      }
-    }
-    case 4: {
-      switch (num_axes) {
-        case 1: {
-          EigenReduceTensorCUDAImpl<T, Reducer, 4, 1>(
-              dims, axes, reducer, X, Y, context);
-          return true;
-        }
-        case 2: {
-          EigenReduceTensorCUDAImpl<T, Reducer, 4, 2>(
-              dims, axes, reducer, X, Y, context);
-          return true;
-        }
-        case 3: {
-          EigenReduceTensorCUDAImpl<T, Reducer, 4, 3>(
-              dims, axes, reducer, X, Y, context);
-          return true;
-        }
-        case 4: {
-          EigenReduceTensorCUDAImpl<T, Reducer, 4, 4>(
-              dims, axes, reducer, X, Y, context);
-          return true;
-        }
-        default: { return false; }
-      }
-    }
-    default: { return false; }
-  }
-}
+constexpr int kCUDAReduceTensorMaxDims = 8;
 
 std::vector<int> MakeTransposeAxes(
     const int num_dims,
@@ -2298,6 +2170,82 @@
   return transpose_axes;
 }
 
+template <int D>
+void ComputeTransposedStrides(
+    const int* X_dims,
+    const int* axes,
+    int* X_strides) {
+  int buff[D];
+  int cur_stride = 1;
+  for (int i = D - 1; i >= 0; --i) {
+    buff[i] = cur_stride;
+    cur_stride *= X_dims[i];
+  }
+  for (int i = 0; i < D; ++i) {
+    X_strides[i] = buff[axes[i]];
+  }
+}
+
+template <typename T, class Reducer, int D>
+__global__ void ReduceTensorCUDAKernel(
+    const int outer_size,
+    const int inner_size,
+    SimpleArray<int, D> X_strides,
+    SimpleArray<int, D> Y_dims,
+    const Reducer reducer,
+    const T init,
+    const T* X,
+    T* Y) {
+  __shared__ typename BlockReduce<T>::TempStorage temp_storage;
+  for (int i = blockIdx.x; i < outer_size; i += gridDim.x) {
+    T val = init;
+    for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
+      int X_index = 0;
+      int Y_index = i * inner_size + j;
+#pragma unroll
+      for (int i = D - 1; i >= 0; --i) {
+        X_index += (Y_index % Y_dims.data[i]) * X_strides.data[i];
+        Y_index /= Y_dims.data[i];
+      }
+#if __CUDA_ARCH__ >= 350
+      val = reducer(val, __ldg(X + X_index));
+#else
+      val = reducer(val, X[X_index]);
+#endif
+    }
+    val = BlockReduce<T>(temp_storage).Reduce(val, reducer);
+    if (threadIdx.x == 0) {
+      Y[i] = val;
+    }
+    __syncthreads();
+  }
+}
+
+template <typename T, class Reducer, int D>
+void ReduceTensorCUDAImpl(
+    const int outer_size,
+    const int inner_size,
+    const int* dims,
+    const int* axes,
+    const Reducer& reducer,
+    const T& init,
+    const T* X,
+    T* Y,
+    CUDAContext* context) {
+  SimpleArray<int, D> X_strides;
+  SimpleArray<int, D> Y_dims;
+  ComputeTransposedStrides<D>(dims, axes, X_strides.data);
+  for (int i = 0; i < D; ++i) {
+    Y_dims.data[i] = dims[axes[i]];
+  }
+  ReduceTensorCUDAKernel<T, Reducer, D>
+      <<<std::min(outer_size, CAFFE_MAXIMUM_NUM_BLOCKS),
+         CAFFE_CUDA_NUM_THREADS,
+         0,
+         context->cuda_stream()>>>(
+          outer_size, inner_size, X_strides, Y_dims, reducer, init, X, Y);
+}
+
 template <typename T, class Reducer>
 void ReduceTensorCUDA(
     const int num_dims,
@@ -2308,148 +2256,136 @@
     const T& init,
     const T* X,
     T* Y,
-    CUDAContext* context,
-    Tensor<CUDAContext>* scratch_ptr) {
+    CUDAContext* context) {
+  CAFFE_ENFORCE_LE(num_dims, kCUDAReduceTensorMaxDims);
+  CAFFE_ENFORCE_LE(num_axes, num_dims);
   const std::vector<int> transpose_axes =
       MakeTransposeAxes(num_dims, dims, num_axes, axes);
-  const int d = num_dims - num_axes;
+  const int pivot = num_dims - num_axes;
   int outer_size = 1;
-  for (int i = 0; i < d; ++i) {
+  for (int i = 0; i < pivot; ++i) {
     outer_size *= dims[transpose_axes[i]];
   }
   int inner_size = 1;
-  for (int i = d; i < num_dims; ++i) {
+  for (int i = pivot; i < num_dims; ++i) {
     inner_size *= dims[transpose_axes[i]];
   }
-  const T* X_data = X;
-  if (transpose_axes[d] != d) {
-    scratch_ptr->Resize(std::vector<int>{outer_size, inner_size});
-    Transpose<T, CUDAContext>(
-        num_dims,
-        dims,
-        transpose_axes.data(),
-        X,
-        scratch_ptr->mutable_data<T>(),
-        context);
-    X_data = scratch_ptr->data<T>();
-  }
-  RowwiseReduceKernel<T>
-      <<<std::min(outer_size, CAFFE_MAXIMUM_NUM_BLOCKS),
-         CAFFE_CUDA_NUM_THREADS,
-         0,
-         context->cuda_stream()>>>(
-          outer_size, inner_size, reducer, init, X_data, Y);
-}
-
-template <typename T>
-void ReduceMinCUDAImpl(
-    const int num_dims,
-    const int* dims,
-    const int num_axes,
-    const int* axes,
-    const T* X,
-    T* Y,
-    CUDAContext* context,
-    Tensor<CUDAContext>* scratch_ptr) {
-  CAFFE_ENFORCE_LE(num_axes, num_dims);
-#if EIGEN_VERSION_AT_LEAST(3, 3, 0)
-  if (EigenReduceTensorCUDA(
-          num_dims,
-          dims,
-          num_axes,
-          axes,
-          Eigen::internal::MinReducer<T>(),
-          X,
-          Y,
-          context)) {
+  if (transpose_axes[pivot] == pivot) {
+    RowwiseReduceKernel<T>
+        <<<std::min(outer_size, CAFFE_MAXIMUM_NUM_BLOCKS),
+           CAFFE_CUDA_NUM_THREADS,
+           0,
+           context->cuda_stream()>>>(
+            outer_size, inner_size, reducer, init, X, Y);
     return;
   }
-#endif // EIGEN_VERSION_AT_LEAST(3, 3, 0)
-  ReduceTensorCUDA(
-      num_dims,
-      dims,
-      num_axes,
-      axes,
-      cub::Min(),
-      std::numeric_limits<T>::max(),
-      X,
-      Y,
-      context,
-      scratch_ptr);
-}
-
-template <typename T>
-void ReduceMaxCUDAImpl(
-    const int num_dims,
-    const int* dims,
-    const int num_axes,
-    const int* axes,
-    const T* X,
-    T* Y,
-    CUDAContext* context,
-    Tensor<CUDAContext>* scratch_ptr) {
-  CAFFE_ENFORCE_LE(num_axes, num_dims);
-#if EIGEN_VERSION_AT_LEAST(3, 3, 0)
-  if (EigenReduceTensorCUDA(
-          num_dims,
+  switch (num_dims) {
+    case 1: {
+      ReduceTensorCUDAImpl<T, Reducer, 1>(
+          outer_size,
+          inner_size,
           dims,
-          num_axes,
-          axes,
-          Eigen::internal::MaxReducer<T>(),
+          transpose_axes.data(),
+          reducer,
+          init,
           X,
           Y,
-          context)) {
-    return;
-  }
-#endif // EIGEN_VERSION_AT_LEAST(3, 3, 0)
-  ReduceTensorCUDA(
-      num_dims,
-      dims,
-      num_axes,
-      axes,
-      cub::Max(),
-      std::numeric_limits<T>::lowest(),
-      X,
-      Y,
-      context,
-      scratch_ptr);
-}
-
-template <typename T>
-void ReduceSumCUDAImpl(
-    const int num_dims,
-    const int* dims,
-    const int num_axes,
-    const int* axes,
-    const T* X,
-    T* Y,
-    CUDAContext* context,
-    Tensor<CUDAContext>* scratch_ptr) {
-  CAFFE_ENFORCE_LE(num_axes, num_dims);
-#if EIGEN_VERSION_AT_LEAST(3, 3, 0)
-  if (EigenReduceTensorCUDA(
-          num_dims,
+          context);
+      break;
+    }
+    case 2: {
+      ReduceTensorCUDAImpl<T, Reducer, 2>(
+          outer_size,
+          inner_size,
           dims,
-          num_axes,
-          axes,
-          Eigen::internal::SumReducer<T>(),
+          transpose_axes.data(),
+          reducer,
+          init,
           X,
           Y,
-          context)) {
-    return;
+          context);
+      break;
+    }
+    case 3: {
+      ReduceTensorCUDAImpl<T, Reducer, 3>(
+          outer_size,
+          inner_size,
+          dims,
+          transpose_axes.data(),
+          reducer,
+          init,
+          X,
+          Y,
+          context);
+      break;
+    }
+    case 4: {
+      ReduceTensorCUDAImpl<T, Reducer, 4>(
+          outer_size,
+          inner_size,
+          dims,
+          transpose_axes.data(),
+          reducer,
+          init,
+          X,
+          Y,
+          context);
+      break;
+    }
+    case 5: {
+      ReduceTensorCUDAImpl<T, Reducer, 5>(
+          outer_size,
+          inner_size,
+          dims,
+          transpose_axes.data(),
+          reducer,
+          init,
+          X,
+          Y,
+          context);
+      break;
+    }
+    case 6: {
+      ReduceTensorCUDAImpl<T, Reducer, 6>(
+          outer_size,
+          inner_size,
+          dims,
+          transpose_axes.data(),
+          reducer,
+          init,
+          X,
+          Y,
+          context);
+      break;
+    }
+    case 7: {
+      ReduceTensorCUDAImpl<T, Reducer, 7>(
+          outer_size,
+          inner_size,
+          dims,
+          transpose_axes.data(),
+          reducer,
+          init,
+          X,
+          Y,
+          context);
+      break;
+    }
+    case 8: {
+      ReduceTensorCUDAImpl<T, Reducer, 8>(
+          outer_size,
+          inner_size,
+          dims,
+          transpose_axes.data(),
+          reducer,
+          init,
+          X,
+          Y,
+          context);
+      break;
+    }
+    default: { break; }
   }
-#endif // EIGEN_VERSION_AT_LEAST(3, 3, 0)
-  ReduceTensorCUDA(
-      num_dims,
-      dims,
-      num_axes,
-      axes,
-      cub::Sum(),
-      T(0),
-      X,
-      Y,
-      context,
-      scratch_ptr);
 }
 
 template <typename T>
@@ -2460,33 +2396,9 @@
     const int* axes,
     const T* X,
     T* Y,
-    CUDAContext* context,
-    Tensor<CUDAContext>* scratch_ptr) {
-  CAFFE_ENFORCE_LE(num_axes, num_dims);
-#if EIGEN_VERSION_AT_LEAST(3, 3, 0)
-  if (EigenReduceTensorCUDA(
-          num_dims,
-          dims,
-          num_axes,
-          axes,
-          Eigen::internal::MeanReducer<T>(),
-          X,
-          Y,
-          context)) {
-    return;
-  }
-#endif // EIGEN_VERSION_AT_LEAST(3, 3, 0)
+    CUDAContext* context) {
   ReduceTensorCUDA(
-      num_dims,
-      dims,
-      num_axes,
-      axes,
-      cub::Sum(),
-      T(0),
-      X,
-      Y,
-      context,
-      scratch_ptr);
+      num_dims, dims, num_axes, axes, cub::Sum(), T(0), X, Y, context);
   const int X_size =
       std::accumulate(dims, dims + num_dims, 1, std::multiplies<int>());
   int scale = 1;
@@ -2500,70 +2412,81 @@
 
 } // namespace
 
-#define CAFFE2_SPECIALIZED_CUDA_REDUCE_MIN(T)                        \
-  template <>                                                        \
-  void ReduceMin<T, CUDAContext>(                                    \
-      const int num_dims,                                            \
-      const int* dims,                                               \
-      const int num_axes,                                            \
-      const int* axes,                                               \
-      const T* X,                                                    \
-      T* Y,                                                          \
-      CUDAContext* context,                                          \
-      Tensor<CUDAContext>* scratch_ptr) {                            \
-    ReduceMinCUDAImpl<T>(                                            \
-        num_dims, dims, num_axes, axes, X, Y, context, scratch_ptr); \
+#define CAFFE2_SPECIALIZED_CUDA_REDUCE_MIN(T) \
+  template <>                                 \
+  void ReduceMin<T, CUDAContext>(             \
+      const int num_dims,                     \
+      const int* dims,                        \
+      const int num_axes,                     \
+      const int* axes,                        \
+      const T* X,                             \
+      T* Y,                                   \
+      CUDAContext* context) {                 \
+    ReduceTensorCUDA(                         \
+        num_dims,                             \
+        dims,                                 \
+        num_axes,                             \
+        axes,                                 \
+        cub::Min(),                           \
+        std::numeric_limits<T>::max(),        \
+        X,                                    \
+        Y,                                    \
+        context);                             \
   }
 CAFFE2_SPECIALIZED_CUDA_REDUCE_MIN(float)
 #undef CAFFE2_SPECIALIZED_CUDA_REDUCE_MIN
 
-#define CAFFE2_SPECIALIZED_CUDA_REDUCE_MAX(T)                        \
-  template <>                                                        \
-  void ReduceMax<T, CUDAContext>(                                    \
-      const int num_dims,                                            \
-      const int* dims,                                               \
-      const int num_axes,                                            \
-      const int* axes,                                               \
-      const T* X,                                                    \
-      T* Y,                                                          \
-      CUDAContext* context,                                          \
-      Tensor<CUDAContext>* scratch_ptr) {                            \
-    ReduceMaxCUDAImpl<T>(                                            \
-        num_dims, dims, num_axes, axes, X, Y, context, scratch_ptr); \
+#define CAFFE2_SPECIALIZED_CUDA_REDUCE_MAX(T) \
+  template <>                                 \
+  void ReduceMax<T, CUDAContext>(             \
+      const int num_dims,                     \
+      const int* dims,                        \
+      const int num_axes,                     \
+      const int* axes,                        \
+      const T* X,                             \
+      T* Y,                                   \
+      CUDAContext* context) {                 \
+    ReduceTensorCUDA(                         \
+        num_dims,                             \
+        dims,                                 \
+        num_axes,                             \
+        axes,                                 \
+        cub::Max(),                           \
+        std::numeric_limits<T>::lowest(),     \
+        X,                                    \
+        Y,                                    \
+        context);                             \
   }
 CAFFE2_SPECIALIZED_CUDA_REDUCE_MAX(float)
 #undef CAFFE2_SPECIALIZED_CUDA_REDUCE_MAX
 
-#define CAFFE2_SPECIALIZED_CUDA_REDUCE_SUM(T)                        \
-  template <>                                                        \
-  void ReduceSum<T, CUDAContext>(                                    \
-      const int num_dims,                                            \
-      const int* dims,                                               \
-      const int num_axes,                                            \
-      const int* axes,                                               \
-      const T* X,                                                    \
-      T* Y,                                                          \
-      CUDAContext* context,                                          \
-      Tensor<CUDAContext>* scratch_ptr) {                            \
-    ReduceSumCUDAImpl<T>(                                            \
-        num_dims, dims, num_axes, axes, X, Y, context, scratch_ptr); \
+#define CAFFE2_SPECIALIZED_CUDA_REDUCE_SUM(T)                             \
+  template <>                                                             \
+  void ReduceSum<T, CUDAContext>(                                         \
+      const int num_dims,                                                 \
+      const int* dims,                                                    \
+      const int num_axes,                                                 \
+      const int* axes,                                                    \
+      const T* X,                                                         \
+      T* Y,                                                               \
+      CUDAContext* context) {                                             \
+    ReduceTensorCUDA(                                                     \
+        num_dims, dims, num_axes, axes, cub::Sum(), T(0), X, Y, context); \
   }
 CAFFE2_SPECIALIZED_CUDA_REDUCE_SUM(float)
 #undef CAFFE2_SPECIALIZED_CUDA_REDUCE_SUM
 
-#define CAFFE2_SPECIALIZED_CUDA_REDUCE_MEAN(T)                       \
-  template <>                                                        \
-  void ReduceMean<T, CUDAContext>(                                   \
-      const int num_dims,                                            \
-      const int* dims,                                               \
-      const int num_axes,                                            \
-      const int* axes,                                               \
-      const T* X,                                                    \
-      T* Y,                                                          \
-      CUDAContext* context,                                          \
-      Tensor<CUDAContext>* scratch_ptr) {                            \
-    ReduceMeanCUDAImpl<T>(                                           \
-        num_dims, dims, num_axes, axes, X, Y, context, scratch_ptr); \
+#define CAFFE2_SPECIALIZED_CUDA_REDUCE_MEAN(T)                            \
+  template <>                                                             \
+  void ReduceMean<T, CUDAContext>(                                        \
+      const int num_dims,                                                 \
+      const int* dims,                                                    \
+      const int num_axes,                                                 \
+      const int* axes,                                                    \
+      const T* X,                                                         \
+      T* Y,                                                               \
+      CUDAContext* context) {                                             \
+    ReduceMeanCUDAImpl<T>(num_dims, dims, num_axes, axes, X, Y, context); \
   }
 CAFFE2_SPECIALIZED_CUDA_REDUCE_MEAN(float)
 #undef CAFFE2_SPECIALIZED_CUDA_REDUCE_MEAN
@@ -2693,8 +2616,6 @@
 
 namespace {
 
-constexpr int kCUDAMomentsMaxDims = 8;
-
 template <typename T>
 __global__ void RowwiseMomentsCUDAKernel(
     const int rows,
@@ -2727,8 +2648,72 @@
   }
 }
 
-template <typename T>
+template <typename T, int D>
+__global__ void MomentsCUDAKernel(
+    const int outer_size,
+    const int inner_size,
+    SimpleArray<int, D> X_strides,
+    SimpleArray<int, D> Y_dims,
+    const T* X,
+    T* mean,
+    T* variance) {
+  __shared__ typename BlockReduce<T>::TempStorage m_storage;
+  __shared__ typename BlockReduce<T>::TempStorage v_storage;
+  for (int i = blockIdx.x; i < outer_size; i += gridDim.x) {
+    T m_val = 0;
+    T v_val = 0;
+    for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
+      int X_index = 0;
+      int Y_index = i * inner_size + j;
+#pragma unroll
+      for (int i = D - 1; i >= 0; --i) {
+        X_index += (Y_index % Y_dims.data[i]) * X_strides.data[i];
+        Y_index /= Y_dims.data[i];
+      }
+#if __CUDA_ARCH__ >= 350
+      m_val += __ldg(X + X_index);
+      v_val += __ldg(X + X_index) * __ldg(X + X_index);
+#else
+      m_val += X[X_index];
+      v_val += X[X_index] * X[X_index];
+#endif
+    }
+    m_val = BlockReduce<T>(m_storage).Reduce(m_val, cub::Sum());
+    v_val = BlockReduce<T>(v_storage).Reduce(v_val, cub::Sum());
+    if (threadIdx.x == 0) {
+      mean[i] = m_val / static_cast<T>(inner_size);
+      variance[i] = v_val / static_cast<T>(inner_size) - mean[i] * mean[i];
+    }
+    __syncthreads();
+  }
+}
+
+template <typename T, int D>
 void MomentsCUDAImpl(
+    const int outer_size,
+    const int inner_size,
+    const int* dims,
+    const int* axes,
+    const T* X,
+    T* mean,
+    T* variance,
+    CUDAContext* context) {
+  SimpleArray<int, D> X_strides;
+  SimpleArray<int, D> Y_dims;
+  ComputeTransposedStrides<D>(dims, axes, X_strides.data);
+  for (int i = 0; i < D; ++i) {
+    Y_dims.data[i] = dims[axes[i]];
+  }
+  MomentsCUDAKernel<T, D>
+      <<<std::min(outer_size, CAFFE_MAXIMUM_NUM_BLOCKS),
+         CAFFE_CUDA_NUM_THREADS,
+         0,
+         context->cuda_stream()>>>(
+          outer_size, inner_size, X_strides, Y_dims, X, mean, variance);
+}
+
+template <typename T>
+void MomentsCUDA(
     const int num_dims,
     const int* dims,
     const int num_axes,
@@ -2736,65 +2721,144 @@
     const T* X,
     T* mean,
     T* variance,
-    CUDAContext* context,
-    Tensor<CUDAContext>* scratch_ptr) {
+    CUDAContext* context) {
+  CAFFE_ENFORCE_LE(num_dims, kCUDAReduceTensorMaxDims);
+  CAFFE_ENFORCE_LE(num_axes, num_dims);
   const std::vector<int> transpose_axes =
       MakeTransposeAxes(num_dims, dims, num_axes, axes);
-  const int d = num_dims - num_axes;
+  const int pivot = num_dims - num_axes;
   int outer_size = 1;
-  for (int i = 0; i < d; ++i) {
+  for (int i = 0; i < pivot; ++i) {
     outer_size *= dims[transpose_axes[i]];
   }
   int inner_size = 1;
-  for (int i = d; i < num_dims; ++i) {
+  for (int i = pivot; i < num_dims; ++i) {
     inner_size *= dims[transpose_axes[i]];
   }
-  const T* X_data = X;
-  if (transpose_axes[d] != d) {
-    scratch_ptr->Resize(std::vector<int>{outer_size, inner_size});
-    Transpose<T, CUDAContext>(
-        num_dims,
-        dims,
-        transpose_axes.data(),
-        X,
-        scratch_ptr->mutable_data<T>(),
-        context);
-    X_data = scratch_ptr->data<T>();
-  }
-  RowwiseMomentsCUDAKernel<T>
+  if (transpose_axes[pivot] == pivot) {
+    RowwiseMomentsCUDAKernel<T>
       <<<std::min(outer_size, CAFFE_MAXIMUM_NUM_BLOCKS),
          CAFFE_CUDA_NUM_THREADS,
          0,
-         context->cuda_stream()>>>(
-          outer_size, inner_size, X_data, mean, variance);
+         context->cuda_stream()>>>(outer_size, inner_size, X, mean, variance);
+    return;
+  }
+  switch (num_dims) {
+    case 1: {
+      MomentsCUDAImpl<T, 1>(
+          outer_size,
+          inner_size,
+          dims,
+          transpose_axes.data(),
+          X,
+          mean,
+          variance,
+          context);
+      break;
+    }
+    case 2: {
+      MomentsCUDAImpl<T, 2>(
+          outer_size,
+          inner_size,
+          dims,
+          transpose_axes.data(),
+          X,
+          mean,
+          variance,
+          context);
+      break;
+    }
+    case 3: {
+      MomentsCUDAImpl<T, 3>(
+          outer_size,
+          inner_size,
+          dims,
+          transpose_axes.data(),
+          X,
+          mean,
+          variance,
+          context);
+      break;
+    }
+    case 4: {
+      MomentsCUDAImpl<T, 4>(
+          outer_size,
+          inner_size,
+          dims,
+          transpose_axes.data(),
+          X,
+          mean,
+          variance,
+          context);
+      break;
+    }
+    case 5: {
+      MomentsCUDAImpl<T, 5>(
+          outer_size,
+          inner_size,
+          dims,
+          transpose_axes.data(),
+          X,
+          mean,
+          variance,
+          context);
+      break;
+    }
+    case 6: {
+      MomentsCUDAImpl<T, 6>(
+          outer_size,
+          inner_size,
+          dims,
+          transpose_axes.data(),
+          X,
+          mean,
+          variance,
+          context);
+      break;
+    }
+    case 7: {
+      MomentsCUDAImpl<T, 7>(
+          outer_size,
+          inner_size,
+          dims,
+          transpose_axes.data(),
+          X,
+          mean,
+          variance,
+          context);
+      break;
+    }
+    case 8: {
+      MomentsCUDAImpl<T, 8>(
+          outer_size,
+          inner_size,
+          dims,
+          transpose_axes.data(),
+          X,
+          mean,
+          variance,
+          context);
+      break;
+    }
+    default: { break; }
+  }
 }
 
 } // namespace
 
-#define CAFFE2_SPECIALIZED_CUDA_MOMENTS(T)           \
-  template <>                                        \
-  void Moments<T, CUDAContext>(                      \
-      const int num_dims,                            \
-      const int* dims,                               \
-      const int num_axes,                            \
-      const int* axes,                               \
-      const T* X,                                    \
-      T* mean,                                       \
-      T* variance,                                   \
-      CUDAContext* context,                          \
-      Tensor<CUDAContext>* scratch_ptr) {            \
-    CAFFE_ENFORCE_LE(num_dims, kCUDAMomentsMaxDims); \
-    CAFFE_ENFORCE_LE(num_axes, num_dims);            \
-    MomentsCUDAImpl<T>(                              \
-        num_dims,                                    \
-        dims,                                        \
-        num_axes,                                    \
-        axes,                                        \
-        X,                                           \
-        mean,                                        \
-        variance,                                    \
-        context,                                     \
-        scratch_ptr);                                \
+#define CAFFE2_SPECIALIZED_CUDA_MOMENTS(T)                           \
+  template <>                                                        \
+  void Moments<T, CUDAContext>(                                      \
+      const int num_dims,                                            \
+      const int* dims,                                               \
+      const int num_axes,                                            \
+      const int* axes,                                               \
+      const T* X,                                                    \
+      T* mean,                                                       \
+      T* variance,                                                   \
+      CUDAContext* context) {                                        \
+    MomentsCUDA<T>(                                                  \
+        num_dims, dims, num_axes, axes, X, mean, variance, context); \
   }
 CAFFE2_SPECIALIZED_CUDA_MOMENTS(float)
 #undef CAFFE2_SPECIALIZED_CUDA_MOMENTS
@@ -2804,77 +2868,6 @@
 constexpr int kCUDATransposeMaxDims = 8;
 
 template <typename T, int D>
-void EigenTransposeCUDAImpl(
-    const int* dims,
-    const int* axes,
-    const T* X,
-    T* Y,
-    CUDAContext* context) {
-  Eigen::DSizes<Eigen::DenseIndex, D> X_dims;
-  Eigen::DSizes<Eigen::DenseIndex, D> Y_dims;
-  Eigen::array<Eigen::DenseIndex, D> axes_array;
-  for (int i = 0; i < D; ++i) {
-    X_dims[i] = static_cast<Eigen::DenseIndex>(dims[D - 1 - i]);
-    Y_dims[i] = static_cast<Eigen::DenseIndex>(dims[D - 1 - axes[i]]);
-    axes_array[D - 1 - i] = static_cast<Eigen::DenseIndex>(D - 1 - axes[i]);
-  }
-  const cudaStream_t cuda_stream = context->cuda_stream();
-  const Eigen::CudaStreamDevice stream_device(
-      &cuda_stream, context->cuda_gpu_id());
-  const Eigen::GpuDevice gpu_device(&stream_device);
-  EigenTensorMap<T, D>(Y, Y_dims).device(gpu_device) =
-      EigenTensorMap<T, D>(const_cast<T*>(X), X_dims).shuffle(axes_array);
-}
-
-template <typename T>
-bool EigenTransposeCUDA(
-    const int ndim,
-    const int* dims,
-    const int* axes,
-    const T* X,
-    T* Y,
-    CUDAContext* context) {
-#if EIGEN_VERSION_AT_LEAST(3, 3, 0)
-  switch (ndim) {
-    case 1: {
-      EigenTransposeCUDAImpl<T, 1>(dims, axes, X, Y, context);
-      return true;
-    }
-    case 2: {
-      EigenTransposeCUDAImpl<T, 2>(dims, axes, X, Y, context);
-      return true;
-    }
-    case 3: {
-      EigenTransposeCUDAImpl<T, 3>(dims, axes, X, Y, context);
-      return true;
-    }
-    case 4: {
-      EigenTransposeCUDAImpl<T, 4>(dims, axes, X, Y, context);
-      return true;
-    }
-    case 5: {
-      EigenTransposeCUDAImpl<T, 5>(dims, axes, X, Y, context);
-      return true;
-    }
-    case 6: {
-      EigenTransposeCUDAImpl<T, 6>(dims, axes, X, Y, context);
-      return true;
-    }
-    case 7: {
-      EigenTransposeCUDAImpl<T, 7>(dims, axes, X, Y, context);
-      return true;
-    }
-    case 8: {
-      EigenTransposeCUDAImpl<T, 8>(dims, axes, X, Y, context);
-      return true;
-    }
-    default: { return false; }
-  }
-#endif // EIGEN_VERSION_AT_LEAST(3, 3, 0)
-  return false;
-}
-
-template <typename T, int D>
 __global__ void TransposeCUDAKernel(
     const int size,
     const SimpleArray<int, D> X_strides,
@@ -2897,22 +2890,6 @@
   }
 }
 
-template <int D>
-void ComputeXStride(
-    const int* X_dims,
-    const int* axes,
-    int* X_strides) {
-  int buff[D];
-  int cur_stride = 1;
-  for (int i = D - 1; i >= 0; --i) {
-    buff[i] = cur_stride;
-    cur_stride *= X_dims[i];
-  }
-  for (int i = 0; i < D; ++i) {
-    X_strides[i] = buff[axes[i]];
-  }
-}
-
 template <typename T, int D>
 void TransposeCUDAImpl(
     const int* dims,
@@ -2922,7 +2899,7 @@
     CUDAContext* context) {
   SimpleArray<int, D> X_strides;
   SimpleArray<int, D> Y_dims;
-  ComputeXStride<D>(dims, axes, X_strides.data);
+  ComputeTransposedStrides<D>(dims, axes, X_strides.data);
   int size = 1;
   for (int i = 0; i < D; ++i) {
     Y_dims.data[i] = dims[axes[i]];
@@ -2943,6 +2920,8 @@
     const T* X,
     T* Y,
     CUDAContext* context) {
+  CAFFE_ENFORCE_LE(
+      ndim, kCUDATransposeMaxDims, "ndim exceeds compile time max.");
   switch (ndim) {
     case 1: {
       TransposeCUDAImpl<T, 1>(dims, axes, X, Y, context);
@@ -2991,11 +2970,6 @@
       const T* X,                                                       \
       T* Y,                                                             \
       CUDAContext* context) {                                           \
-    CAFFE_ENFORCE_LE(                                                   \
-        ndim, kCUDATransposeMaxDims, "ndim exceeds compile time max."); \
-    if (EigenTransposeCUDA(ndim, dims, axes, X, Y, context)) {          \
-      return;                                                           \
-    }                                                                   \
     TransposeCUDA<T>(ndim, dims, axes, X, Y, context);                  \
   }
 CAFFE2_SPECIALIZED_CUDA_TRANSPOSE(float)
diff --git a/caffe2/utils/math_gpu_test.cc b/caffe2/utils/math_gpu_test.cc
index 1622367..8de888f 100644
--- a/caffe2/utils/math_gpu_test.cc
+++ b/caffe2/utils/math_gpu_test.cc
@@ -331,10 +331,8 @@
     cuda_context_ = make_unique<CUDAContext>(option_);
     Blob* blob_x = ws_.CreateBlob("X");
     Blob* blob_y = ws_.CreateBlob("Y");
-    Blob* blob_scratch = ws_.CreateBlob("scratch");
     X_ = blob_x->GetMutable<Tensor<CUDAContext>>();
     Y_ = blob_y->GetMutable<Tensor<CUDAContext>>();
-    scratch_ptr_ = blob_scratch->GetMutable<Tensor<CUDAContext>>();
   }
 
   void SetUpData(
@@ -378,8 +376,7 @@
         axes.data(),
         X_->data<float>(),
         Y_->mutable_data<float>(),
-        cuda_context_.get(),
-        scratch_ptr_);
+        cuda_context_.get());
     VerifyResult(Y_data);
   }
 
@@ -388,7 +385,6 @@
   std::unique_ptr<CUDAContext> cuda_context_;
   Tensor<CUDAContext>* X_ = nullptr;
   Tensor<CUDAContext>* Y_ = nullptr;
-  Tensor<CUDAContext>* scratch_ptr_ = nullptr;
 };
 
 TEST_F(ReduceTensorGPUTest, ReduceMinGPUTest) {
@@ -401,10 +397,9 @@
                               const int* axes,
                               const float* X,
                               float* Y,
-                              CUDAContext* context,
-                              Tensor<CUDAContext>* scratch_ptr) {
+                              CUDAContext* context) {
     return math::ReduceMin<float, CUDAContext>(
-        num_dims, dims, num_axes, axes, X, Y, context, scratch_ptr);
+        num_dims, dims, num_axes, axes, X, Y, context);
   };
   // Test for 1D tensor.
   RunRedcueTensorTest(
@@ -465,10 +460,9 @@
                               const int* axes,
                               const float* X,
                               float* Y,
-                              CUDAContext* context,
-                              Tensor<CUDAContext>* scratch_ptr) {
+                              CUDAContext* context) {
     return math::ReduceMax<float, CUDAContext>(
-        num_dims, dims, num_axes, axes, X, Y, context, scratch_ptr);
+        num_dims, dims, num_axes, axes, X, Y, context);
   };
   // Test for 1D tensor.
   RunRedcueTensorTest(
@@ -711,11 +705,9 @@
     Blob* blob_x = ws_.CreateBlob("X");
     Blob* blob_mean = ws_.CreateBlob("mean");
     Blob* blob_variance = ws_.CreateBlob("variance");
-    Blob* blob_scratch = ws_.CreateBlob("scratch");
     X_ = blob_x->GetMutable<Tensor<CUDAContext>>();
     mean_ = blob_mean->GetMutable<Tensor<CUDAContext>>();
     variance_ = blob_variance->GetMutable<Tensor<CUDAContext>>();
-    scratch_ptr_ = blob_scratch->GetMutable<Tensor<CUDAContext>>();
   }
 
   void SetUpData(
@@ -771,8 +763,7 @@
         X_->data<float>(),
         mean_->mutable_data<float>(),
         variance_->mutable_data<float>(),
-        cuda_context_.get(),
-        scratch_ptr_);
+        cuda_context_.get());
     VerifyResult(mean_data, variance_data);
   }
 
@@ -782,7 +773,6 @@
   Tensor<CUDAContext>* X_ = nullptr;
   Tensor<CUDAContext>* mean_ = nullptr;
   Tensor<CUDAContext>* variance_ = nullptr;
-  Tensor<CUDAContext>* scratch_ptr_ = nullptr;
 };
 
 TEST_F(MomentsGPUTest, MomentsGPUFloatTest) {
diff --git a/caffe2/utils/math_test.cc b/caffe2/utils/math_test.cc
index 1a7c5f4..8ade7d6 100644
--- a/caffe2/utils/math_test.cc
+++ b/caffe2/utils/math_test.cc
@@ -407,8 +407,7 @@
         axes.data(),
         X_.data<float>(),
         Y_.mutable_data<float>(),
-        cpu_context_.get(),
-        nullptr);
+        cpu_context_.get());
     ASSERT_EQ(Y_data.size(), Y_.size());
     for (int i = 0; i < Y_.size(); ++i) {
       EXPECT_FLOAT_EQ(Y_data[i], Y_.data<float>()[i]);
@@ -428,10 +427,9 @@
                               const int* axes,
                               const float* X,
                               float* Y,
-                              CPUContext* context,
-                              TensorCPU* scratch_ptr) {
+                              CPUContext* context) {
     return math::ReduceMin<float, CPUContext>(
-        num_dims, dims, num_axes, axes, X, Y, context, scratch_ptr);
+        num_dims, dims, num_axes, axes, X, Y, context);
   };
   // Test for 1D tensor.
   RunRedcueTensorTest(
@@ -489,10 +487,9 @@
                               const int* axes,
                               const float* X,
                               float* Y,
-                              CPUContext* context,
-                              TensorCPU* scratch_ptr) {
+                              CPUContext* context) {
     return math::ReduceMax<float, CPUContext>(
-        num_dims, dims, num_axes, axes, X, Y, context, scratch_ptr);
+        num_dims, dims, num_axes, axes, X, Y, context);
   };
   // Test for 1D tensor.
   RunRedcueTensorTest(