fp16 support for FullyConnected op (Fixed)

Summary: This diff resolves some issues in the reverted PR246 (fp16 support for the FullyConnected op).

Differential Revision: D4911821

fbshipit-source-id: 0a6fa47f4c2405475697e40fb926758c534f8ef7
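
The key structural change: FullyConnectedOp, FullyConnectedGradientOp, SumOp and WeightedSumOp drop the storage-type template parameter, and RunOnDevice() now dispatches at runtime on the input blob's element type to a templated DoRunWithType<...>(), pairing fp16 storage with fp32 math by default. A minimal standalone sketch of that dispatch pattern follows (illustrative names only, not the Caffe2 API):

```cpp
// Standalone sketch of the runtime type-dispatch pattern this diff adopts:
// RunOnDevice() checks the input's element type and forwards to a templated
// DoRunWithType<Storage, Math>(), so fp32 and fp16 share one operator class
// while fp16 tensors can still accumulate in fp32.
// Illustrative names only; this is not the Caffe2 API.
#include <cstdint>
#include <iostream>
#include <typeindex>
#include <typeinfo>

struct float16 { uint16_t x; };  // stand-in for caffe2::float16

struct FcLikeOp {
  std::type_index input_type;  // stands in for Input(0).IsType<T>()

  template <typename T, typename Math>
  bool DoRunWithType() {
    std::cout << "storage bytes: " << sizeof(T)
              << ", math bytes: " << sizeof(Math) << "\n";
    return true;
  }

  bool RunOnDevice() {
    if (input_type == std::type_index(typeid(float))) {
      return DoRunWithType<float, float>();    // fp32 storage, fp32 math
    }
    if (input_type == std::type_index(typeid(float16))) {
      return DoRunWithType<float16, float>();  // fp16 storage, fp32 math
    }
    return false;  // unsupported element type
  }
};

int main() {
  FcLikeOp fp32_op{std::type_index(typeid(float))};
  FcLikeOp fp16_op{std::type_index(typeid(float16))};
  fp32_op.RunOnDevice();
  fp16_op.RunOnDevice();
  return 0;
}
```
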
diff --git a/caffe2/contrib/nervana/nervana_fc_op_gpu.cc b/caffe2/contrib/nervana/nervana_fc_op_gpu.cc
index b232850..8d33a7c 100644
--- a/caffe2/contrib/nervana/nervana_fc_op_gpu.cc
+++ b/caffe2/contrib/nervana/nervana_fc_op_gpu.cc
@@ -5,8 +5,11 @@
 
 namespace caffe2 {
 REGISTER_CUDA_OPERATOR_WITH_ENGINE(
-    FC, NERVANA, FullyConnectedOp<float, CUDAContext, NervanaEngine>);
+    FC,
+    NERVANA,
+    FullyConnectedOp<CUDAContext, NervanaEngine>);
 REGISTER_CUDA_OPERATOR_WITH_ENGINE(
-    FCGradient, NERVANA,
-    FullyConnectedGradientOp<float, CUDAContext, NervanaEngine>);
+    FCGradient,
+    NERVANA,
+    FullyConnectedGradientOp<CUDAContext, NervanaEngine>);
 }  // namespace caffe2
diff --git a/caffe2/contrib/nervana/nervana_fc_op_gpu_test.cc b/caffe2/contrib/nervana/nervana_fc_op_gpu_test.cc
index a3ae3bb..3eb0fc3 100644
--- a/caffe2/contrib/nervana/nervana_fc_op_gpu_test.cc
+++ b/caffe2/contrib/nervana/nervana_fc_op_gpu_test.cc
@@ -49,7 +49,7 @@
   AddConstInput(std::vector<int>{6, 10}, 1., "W", &ws);
   AddConstInput(std::vector<int>{6}, 0.1, "B", &ws);
   unique_ptr<OperatorBase> op(
-      new FullyConnectedOp<float, CUDAContext, NervanaEngine>(def, &ws));
+      new FullyConnectedOp<CUDAContext, NervanaEngine>(def, &ws));
   EXPECT_NE(nullptr, op.get());
   EXPECT_TRUE(op->Run());
   Blob* Yblob = ws.GetBlob("Y");
diff --git a/caffe2/contrib/nervana/nervana_math_gpu.cc b/caffe2/contrib/nervana/nervana_math_gpu.cc
index f3010b9..09c70e4 100644
--- a/caffe2/contrib/nervana/nervana_math_gpu.cc
+++ b/caffe2/contrib/nervana/nervana_math_gpu.cc
@@ -11,10 +11,18 @@
 // limitation that the data has to be contiguous in memory.
 template <>
 void Gemm<float, CUDAContext, NervanaEngine>(
-    const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB,
-    const int M, const int N, const int K, const float alpha, const float* A,
-    const float* B, const float beta, float* C, CUDAContext* context) {
-
+    const CBLAS_TRANSPOSE TransA,
+    const CBLAS_TRANSPOSE TransB,
+    const int M,
+    const int N,
+    const int K,
+    const float alpha,
+    const float* A,
+    const float* B,
+    const float beta,
+    float* C,
+    CUDAContext* context,
+    TensorProto::DataType math_type) {
   // Note that cublas follows fortran order, so the order is different from
   // the cblas convention.
   int lda = (TransA == CblasNoTrans) ? K : M;
diff --git a/caffe2/operators/elementwise_op.cu b/caffe2/operators/elementwise_op.cu
index 016a555..e4c89a3 100644
--- a/caffe2/operators/elementwise_op.cu
+++ b/caffe2/operators/elementwise_op.cu
@@ -5,6 +5,7 @@
 #include "caffe2/core/common_gpu.h"
 #include "caffe2/core/context_gpu.h"
 #include "caffe2/operators/elementwise_op.h"
+#include "caffe2/utils/conversions.h"
 
 namespace caffe2 {
 
@@ -62,9 +63,6 @@
     name, BinaryElementwiseOp< \
         input_type, CUDAContext, Cuda##name##Functor, output_type>)
 
-#define CUDA_ADD(x, y) ((x) + (y))
-CUDA_FUNCTOR(Add, CUDA_ADD, NumericTypes, SameTypeAsInput);
-#undef CUDA_ADD
 #define CUDA_SUB(x, y) ((x) - (y))
 CUDA_FUNCTOR(Sub, CUDA_SUB, NumericTypes, SameTypeAsInput);
 #undef CUDA_SUB
@@ -264,4 +262,165 @@
 
 REGISTER_CUDA_OPERATOR(SumReduceLike, SumReduceLikeOp<CUDAContext>);
 
+namespace {
+
+template <bool is_scalar, typename T, typename M>
+__global__ void binary_add_kernel(const int N, const T* a, const T* b, T* r) {
+  CUDA_1D_KERNEL_LOOP(idx, N) {
+    r[idx] = convert::To<M, T>(
+        convert::To<T, M>(a[idx]) +
+        convert::To<T, M>(is_scalar ? b[0] : b[idx]));
+  }
+}
+
+template <bool no_post, typename T, typename M>
+__global__ void binary_add_kernel_broadcast(
+    const T* a,
+    const T* b,
+    T* r,
+    const int pre,
+    const int post,
+    const int n) {
+  CUDA_1D_KERNEL_LOOP(idx, no_post ? pre * n : pre * post * n) {
+    r[idx] = convert::To<M, T>(
+        convert::To<T, M>(a[idx]) +
+        convert::To<T, M>(no_post ? b[idx % n] : b[(idx / post) % n]));
+  }
+}
+} // namespace
+
+// Hand-written Add operator: the CUDA_FUNCTOR macros above cannot cover fp16.
+class CUDAAddOp final : public Operator<CUDAContext> {
+ public:
+  CUDAAddOp(const OperatorDef& operator_def, Workspace* ws)
+      : Operator<CUDAContext>(operator_def, ws),
+        OP_SINGLE_ARG(bool, "broadcast", enable_broadcast_, 0),
+        OP_SINGLE_ARG(int, "axis", axis_, -1),
+        OP_SINGLE_ARG(string, "axis_str", axis_str_, ""),
+        OP_SINGLE_ARG(string, "order", order_, "NCHW") {
+    // Figure out the correct axis to use.
+    if (enable_broadcast_) {
+      if (axis_ != -1) {
+        // Get axis from an explicit axis argument.
+        CAFFE_ENFORCE_EQ(
+            axis_str_.size(),
+            0,
+            "Args axis and axis_str cannot be used simultaneously.");
+      } else if (axis_str_.size()) {
+        // Get the axis index semantically.
+        CAFFE_ENFORCE_EQ(
+            axis_str_.size(), 1, "Unsupported axis string", axis_str_);
+        size_t semantic_axis_ = order_.find(axis_str_);
+        CAFFE_ENFORCE_NE(
+            semantic_axis_,
+            string::npos,
+            "Unrecognizable axis string ",
+            axis_str_,
+            " from order string ",
+            order_);
+        axis_ = semantic_axis_;
+      }
+    } else {
+      CAFFE_ENFORCE(
+          axis_ == -1 && axis_str_.size() == 0,
+          "Do not specify axis or axis_str if broadcast is not enabled.");
+    }
+  }
+
+  ~CUDAAddOp() {}
+
+  template <typename T, typename M>
+  bool DoRunWithType() {
+    auto& X0 = Input(0);
+    auto& X1 = Input(1);
+    auto* output = Output(0);
+
+    output->ResizeLike(X0);
+
+    const T* X0data = X0.template data<T>();
+    const T* X1data = X1.template data<T>();
+    T* outputData = output->template mutable_data<T>();
+
+    if (!enable_broadcast_) {
+      CAFFE_ENFORCE_EQ(
+          X0.dims(),
+          X1.dims(),
+          "Dimension mismatch - did you forget to set broadcast=1?");
+      binary_add_kernel<false, T, M><<<
+          CAFFE_GET_BLOCKS(X0.size()),
+          CAFFE_CUDA_NUM_THREADS,
+          0,
+          context_.cuda_stream()>>>(X0.size(), X0data, X1data, outputData);
+    } else if (X1.size() == 1) {
+      binary_add_kernel<true, T, M><<<
+          CAFFE_GET_BLOCKS(X0.size()),
+          CAFFE_CUDA_NUM_THREADS,
+          0,
+          context_.cuda_stream()>>>(X0.size(), X0data, X1data, outputData);
+    } else {
+      CAFFE_ENFORCE_GT(
+          X0.ndim(),
+          X1.ndim(),
+          "If you are doing broadcasting, input1 should have "
+          "a smaller number of dimensions.");
+      const int axis = (axis_ == -1 ? X0.ndim() - X1.ndim() : axis_);
+      CAFFE_ENFORCE(
+          axis >= 0 && axis < X0.ndim(),
+          "Broadcast axis should be in the range of the number "
+          "of dimensions of the first input.");
+      size_t pre = 1, n = 1, post = 1;
+      for (int i = 0; i < axis; ++i) {
+        pre *= X0.dim(i);
+      }
+      for (int i = 0; i < X1.ndim(); ++i) {
+        CAFFE_ENFORCE_EQ(
+            X0.dim(i + axis), X1.dim(i), "Broadcast dimension mismatch.");
+        n *= X1.dim(i);
+      }
+      for (int i = axis + X1.ndim(); i < X0.ndim(); ++i) {
+        post *= X0.dim(i);
+      }
+
+      if (post == 1) {
+        binary_add_kernel_broadcast<true, T, M><<<
+            CAFFE_GET_BLOCKS(pre * n),
+            CAFFE_CUDA_NUM_THREADS,
+            0,
+            context_.cuda_stream()>>>(X0data, X1data, outputData, pre, post, n);
+      } else {
+        binary_add_kernel_broadcast<false, T, M><<<
+            CAFFE_GET_BLOCKS(pre * post * n),
+            CAFFE_CUDA_NUM_THREADS,
+            0,
+            context_.cuda_stream()>>>(X0data, X1data, outputData, pre, post, n);
+      }
+    }
+    return true;
+  }
+
+  bool RunOnDevice() override {
+    if (Input(0).IsType<float>()) {
+      return DoRunWithType<float, float>();
+    } else if (Input(0).IsType<float16>()) {
+      return DoRunWithType<float16, float>();
+    } else if (Input(0).IsType<int32_t>()) {
+      return DoRunWithType<int32_t, int32_t>();
+    } else if (Input(0).IsType<int64_t>()) {
+      return DoRunWithType<int64_t, int64_t>();
+    } else {
+      return false;
+    }
+  }
+
+ private:
+  bool enable_broadcast_;
+  int axis_;
+  string axis_str_;
+  string order_;
+};
+
+namespace {
+REGISTER_CUDA_OPERATOR(Add, CUDAAddOp);
+} // namespace
+
 }  // namespace caffe2
diff --git a/caffe2/operators/fully_connected_op.cc b/caffe2/operators/fully_connected_op.cc
index 7a0e0bc..c00f199 100644
--- a/caffe2/operators/fully_connected_op.cc
+++ b/caffe2/operators/fully_connected_op.cc
@@ -3,8 +3,8 @@
 namespace caffe2 {
 namespace {
 
-REGISTER_CPU_OPERATOR(FC, FullyConnectedOp<float, CPUContext>);
-REGISTER_CPU_OPERATOR(FCGradient, FullyConnectedGradientOp<float, CPUContext>);
+REGISTER_CPU_OPERATOR(FC, FullyConnectedOp<CPUContext>);
+REGISTER_CPU_OPERATOR(FCGradient, FullyConnectedGradientOp<CPUContext>);
 
 OPERATOR_SCHEMA(FC)
   .NumInputs(3)
diff --git a/caffe2/operators/fully_connected_op.h b/caffe2/operators/fully_connected_op.h
index 6e24c91..45adc0a 100644
--- a/caffe2/operators/fully_connected_op.h
+++ b/caffe2/operators/fully_connected_op.h
@@ -3,12 +3,13 @@
 
 #include "caffe2/core/context.h"
 #include "caffe2/core/operator.h"
+#include "caffe2/utils/conversions.h"
 #include "caffe2/utils/math.h"
 
 namespace caffe2 {
 
 // This is Caffe's InnerProductOp, with a name that fits its purpose better.
-template <typename T, class Context, class Engine = DefaultEngine>
+template <class Context, class Engine = DefaultEngine>
 class FullyConnectedOp final : public Operator<Context> {
  public:
   USE_OPERATOR_CONTEXT_FUNCTIONS;
@@ -17,7 +18,13 @@
         axis_(OperatorBase::GetSingleArgument<int32_t>("axis", 1)) {}
   ~FullyConnectedOp() {}
 
-  bool RunOnDevice() override {
+  template <
+      typename T_X,
+      typename T_W,
+      typename T_B,
+      typename T_Y,
+      typename MATH>
+  bool DoRunWithType() {
     const auto& X = Input(0);
     const auto& W = Input(1);
     const auto& b = Input(2);
@@ -63,44 +70,53 @@
     Y->Resize(Y_shape_cache_);
     CAFFE_ENFORCE(M * N == Y->size(), dimErrorString());
 
-    // X * W^T
-    math::Gemm<T, Context, Engine>(
+    // Y = X * W^T
+    math::Gemm<T_X, Context, Engine>(
         CblasNoTrans,
         CblasTrans,
         M,
         N,
         K,
         1,
-        X.template data<T>(),
-        W.template data<T>(),
+        X.template data<T_X>(),
+        W.template data<T_W>(),
         0,
-        Y->template mutable_data<T>(),
+        Y->template mutable_data<T_Y>(),
         &context_);
     // Add bias term
     if (bias_multiplier_.size() != M) {
       // If the helper bias multiplier is not M, reshape and fill it with one.
       bias_multiplier_.Resize(M);
-      math::Set<T, Context>(
+      math::Set<T_B, Context>(
           M,
-          static_cast<T>(1),
-          bias_multiplier_.template mutable_data<T>(),
+          convert::To<float, T_B>(1),
+          bias_multiplier_.template mutable_data<T_B>(),
           &context_);
     }
-    math::Gemm<T, Context, Engine>(
+    math::Gemm<T_B, Context, Engine>(
         CblasNoTrans,
         CblasNoTrans,
         M,
         N,
         1,
         1,
-        bias_multiplier_.template data<T>(),
-        b.template data<T>(),
+        bias_multiplier_.template data<T_B>(),
+        b.template data<T_B>(),
         1,
-        Y->template mutable_data<T>(),
+        Y->template mutable_data<T_Y>(),
         &context_);
     return true;
   }
 
+  bool RunOnDevice() override {
+    return DoRunWithType<
+        float, // X
+        float, // W
+        float, // B
+        float, // Y
+        float>(); // Math
+  }
+
  protected:
   size_t axis_{1};
   // A local vector to cache the output shape so we don't need to recreate
@@ -109,7 +125,7 @@
   Tensor<Context> bias_multiplier_;
 };
 
-template <typename T, class Context, class Engine = DefaultEngine>
+template <class Context, class Engine = DefaultEngine>
 class FullyConnectedGradientOp : public Operator<Context> {
  public:
   USE_OPERATOR_CONTEXT_FUNCTIONS;
@@ -118,7 +134,16 @@
         axis_(OperatorBase::GetSingleArgument<int32_t>("axis", 1)) {}
   ~FullyConnectedGradientOp() {}
 
-  bool RunOnDevice() override {
+  template <
+      typename T_X,
+      typename T_W,
+      typename T_DY,
+      typename T_B,
+      typename T_DX,
+      typename T_DW,
+      typename T_DB,
+      typename MATH>
+  bool DoRunWithType() {
     const auto& X = Input(0);
     const auto& W = Input(1);
     const auto& dY = Input(2);
@@ -137,60 +162,72 @@
     db->Resize(N);
 
     // Compute dW
-    math::Gemm<T, Context, Engine>(
+    math::Gemm<T_DY, Context, Engine>(
         CblasTrans,
         CblasNoTrans,
         N,
         K,
         M,
-        1,
-        dY.template data<T>(),
-        X.template data<T>(),
-        0,
-        dW->template mutable_data<T>(),
+        convert::To<float, MATH>(1),
+        dY.template data<T_DY>(),
+        X.template data<T_X>(),
+        convert::To<float, MATH>(0),
+        dW->template mutable_data<T_DW>(),
         &context_);
     if (bias_multiplier_.size() != M) {
       // If the helper bias multiplier is not M, reshape and fill it
       // with one.
       bias_multiplier_.Resize(M);
-      math::Set<T, Context>(
+      math::Set<T_B, Context>(
           M,
-          static_cast<T>(1),
-          bias_multiplier_.template mutable_data<T>(),
+          convert::To<float, T_B>(1),
+          bias_multiplier_.template mutable_data<T_B>(),
           &context_);
     }
     // Compute dB
-    math::Gemv<T, Context>(
+    math::Gemv<T_DY, Context>(
         CblasTrans,
         M,
         N,
-        1,
-        dY.template data<T>(),
-        bias_multiplier_.template data<T>(),
-        0,
-        db->template mutable_data<T>(),
+        convert::To<float, MATH>(1),
+        dY.template data<T_DY>(),
+        bias_multiplier_.template data<T_B>(),
+        convert::To<float, MATH>(0),
+        db->template mutable_data<T_DB>(),
         &context_);
 
     // Compute dX
     if (OutputSize() == 3) {
       auto* dX = Output(2);
       dX->ResizeLike(X);
-      math::Gemm<T, Context, Engine>(
+      math::Gemm<T_DX, Context, Engine>(
           CblasNoTrans,
           CblasNoTrans,
           M,
           K,
           N,
-          1,
-          dY.template data<T>(),
-          W.template data<T>(),
-          0,
-          dX->template mutable_data<T>(),
+          convert::To<float, MATH>(1),
+          dY.template data<T_DY>(),
+          W.template data<T_W>(),
+          convert::To<float, MATH>(0),
+          dX->template mutable_data<T_DX>(),
           &context_);
     }
     return true;
   }
 
+  bool RunOnDevice() override {
+    return DoRunWithType<
+        float, //  X
+        float, //  W
+        float, // dY
+        float, //  B
+        float, // dX
+        float, // dW
+        float, // dB
+        float>(); // Math
+  }
+
  protected:
   size_t axis_{1};
   Tensor<Context> bias_multiplier_;
diff --git a/caffe2/operators/fully_connected_op_gpu.cc b/caffe2/operators/fully_connected_op_gpu.cc
index 8ee67ac..0743186 100644
--- a/caffe2/operators/fully_connected_op_gpu.cc
+++ b/caffe2/operators/fully_connected_op_gpu.cc
@@ -2,9 +2,60 @@
 #include "caffe2/operators/fully_connected_op.h"
 
 namespace caffe2 {
+
+template <>
+bool FullyConnectedOp<CUDAContext>::RunOnDevice() {
+  if (Input(0).IsType<float>()) {
+    return DoRunWithType<
+        float, // X
+        float, // W
+        float, // B
+        float, // Y
+        float>(); // Math
+  } else if (Input(0).IsType<float16>()) {
+    return DoRunWithType<
+        float16, // X
+        float16, // W
+        float16, // B
+        float16, // Y
+        float>(); // Math
+  } else {
+    CAFFE_THROW("Unsupported type");
+  }
+  return false;
+}
+
+template <>
+bool FullyConnectedGradientOp<CUDAContext>::RunOnDevice() {
+  if (Input(0).IsType<float>()) {
+    return DoRunWithType<
+        float, //  X
+        float, //  W
+        float, // dY
+        float, //  B
+        float, // dX
+        float, // dW
+        float, // dB
+        float>(); // Math
+  } else if (Input(0).IsType<float16>()) {
+    return DoRunWithType<
+        float16, //  X
+        float16, //  W
+        float16, // dY
+        float16, //  B
+        float16, // dX
+        float16, // dW
+        float16, // dB
+        float>(); // Math
+  } else {
+    CAFFE_THROW("Unsupported type");
+  }
+  return false;
+}
+
 namespace {
-REGISTER_CUDA_OPERATOR(FC, FullyConnectedOp<float, CUDAContext>);
-REGISTER_CUDA_OPERATOR(FCGradient,
-                       FullyConnectedGradientOp<float, CUDAContext>);
+
+REGISTER_CUDA_OPERATOR(FC, FullyConnectedOp<CUDAContext>);
+REGISTER_CUDA_OPERATOR(FCGradient, FullyConnectedGradientOp<CUDAContext>);
 }  // namespace
 }  // namespace caffe2
diff --git a/caffe2/operators/sparse_to_dense_op.h b/caffe2/operators/sparse_to_dense_op.h
index d48b617..439d96c 100644
--- a/caffe2/operators/sparse_to_dense_op.h
+++ b/caffe2/operators/sparse_to_dense_op.h
@@ -50,7 +50,6 @@
     return DispatchHelper<
         TensorTypes2<
             float,
-            double,
             int32_t,
             int64_t,
             GenericTensorImplementation>,
diff --git a/caffe2/operators/square_root_divide_op.h b/caffe2/operators/square_root_divide_op.h
index 644c2bd..df018bf 100644
--- a/caffe2/operators/square_root_divide_op.h
+++ b/caffe2/operators/square_root_divide_op.h
@@ -17,7 +17,7 @@
       : Operator<Context>(operator_def, ws) {}
 
   bool RunOnDevice() override {
-    return DispatchHelper<TensorTypes<float, double>>::call(this, Input(DATA));
+    return DispatchHelper<TensorTypes<float>>::call(this, Input(DATA));
   }
 
  private:
diff --git a/caffe2/operators/utility_ops.cc b/caffe2/operators/utility_ops.cc
index 771da9c..99f0f20 100644
--- a/caffe2/operators/utility_ops.cc
+++ b/caffe2/operators/utility_ops.cc
@@ -3,6 +3,12 @@
 #include <cmath>
 
 namespace caffe2 {
+
+template <>
+bool WeightedSumOp<CPUContext>::RunOnDevice() {
+  return DoRunWithType<float>();
+}
+
 namespace {
 
 REGISTER_CPU_OPERATOR(WallClockTime, WallClockTimeOp<CPUContext>);
@@ -12,10 +18,9 @@
 
 REGISTER_CPU_OPERATOR(Alias, AliasOp<CPUContext>);
 REGISTER_CPU_OPERATOR(ResizeLike, ResizeLikeOp<CPUContext>);
-REGISTER_CPU_OPERATOR(Sum, SumOp<float, CPUContext>);
-REGISTER_CPU_OPERATOR(SumInt, SumOp<int, CPUContext>);
-
-REGISTER_CPU_OPERATOR(WeightedSum, WeightedSumOp<float, CPUContext>);
+REGISTER_CPU_OPERATOR(Sum, SumOp<CPUContext>);
+REGISTER_CPU_OPERATOR(SumInt, SumOp<CPUContext>);
+REGISTER_CPU_OPERATOR(WeightedSum, WeightedSumOp<CPUContext>);
 REGISTER_CPU_OPERATOR(
     ScatterWeightedSum,
     ScatterWeightedSumOp<float, CPUContext>);
diff --git a/caffe2/operators/utility_ops.h b/caffe2/operators/utility_ops.h
index 1722e14..28321b2 100644
--- a/caffe2/operators/utility_ops.h
+++ b/caffe2/operators/utility_ops.h
@@ -250,13 +250,14 @@
   }
 };
 
-template <typename T, class Context>
+template <class Context>
 class SumOp : public Operator<Context> {
  public:
   USE_OPERATOR_CONTEXT_FUNCTIONS;
   USE_SIMPLE_CTOR_DTOR(SumOp);
 
-  bool RunOnDevice() override {
+  template <typename T, typename M>
+  bool DoRunWithType() {
     auto& input0 = Input(0);
     auto* output = Output(0);
     if (InputSize() == 1) {
@@ -297,6 +298,16 @@
     }
     return true;
   }
+
+  bool RunOnDevice() override {
+    if (Input(0).template IsType<float>()) {
+      return DoRunWithType<float, float>();
+    } else if (Input(0).template IsType<int>()) {
+      return DoRunWithType<int, int>();
+    } else {
+      return false;
+    }
+  }
 };
 
 // WeightedSumOp computes the weighted sum of several tensors. The input should
@@ -304,13 +315,14 @@
 // shape, and weight_i are size 1 tensors that specifies the weight of each
 // vector. Note that if one wants to do in-place computation, it could only be
 // done with X_0 also as the output, but not other X_i.
-template <typename T, class Context>
+template <class Context>
 class WeightedSumOp : public Operator<Context> {
  public:
   USE_OPERATOR_CONTEXT_FUNCTIONS;
   USE_SIMPLE_CTOR_DTOR(WeightedSumOp);
 
-  bool RunOnDevice() override {
+  template <typename DstType>
+  bool DoRunWithType() {
     DCHECK_EQ(InputSize() % 2, 0);
     auto& X0 = Input(0);
     auto& weight0 = Input(1);
@@ -319,11 +331,11 @@
     int size = X0.size();
     auto* output = Output(0);
     output->ResizeLike(X0);
-    math::Scale<T, Context>(
+    math::Scale<DstType, Context>(
         size,
-        weight0.template data<T>(),
-        X0.template data<T>(),
-        output->template mutable_data<T>(),
+        weight0.template data<float>(),
+        X0.template data<DstType>(),
+        output->template mutable_data<DstType>(),
         &context_);
     for (int i = 2; i < InputSize(); i += 2) {
       auto& X = Input(i);
@@ -338,15 +350,16 @@
       auto& weight = Input(i + 1);
       DCHECK_EQ(X.size(), size);
       DCHECK_EQ(weight.size(), 1);
-      math::Axpy<T, Context>(
+      math::Axpy<DstType, Context>(
           size,
-          weight.template data<T>(),
-          X.template data<T>(),
-          output->template mutable_data<T>(),
+          weight.template data<float>(),
+          X.template data<DstType>(),
+          output->template mutable_data<DstType>(),
           &context_);
     }
     return true;
   }
+  bool RunOnDevice() override;
 };
 
 /**
diff --git a/caffe2/operators/utility_ops_gpu.cc b/caffe2/operators/utility_ops_gpu.cc
index b3df226..7d41fa2 100644
--- a/caffe2/operators/utility_ops_gpu.cc
+++ b/caffe2/operators/utility_ops_gpu.cc
@@ -5,6 +5,30 @@
 namespace caffe2 {
 
 template <>
+bool WeightedSumOp<CUDAContext>::RunOnDevice() {
+  if (Input(0).IsType<float>()) {
+    return DoRunWithType<float>();
+  } else if (Input(0).IsType<float16>()) {
+    return DoRunWithType<float16>();
+  } else {
+    CAFFE_THROW("Unsupported inputs");
+  }
+  return false;
+}
+
+template <>
+bool SumOp<CUDAContext>::RunOnDevice() {
+  if (Input(0).IsType<float>()) {
+    return DoRunWithType<float, float>();
+  } else if (Input(0).IsType<float16>()) {
+    return DoRunWithType<float16, float16>();
+  } else {
+    CAFFE_THROW("Unsupported inputs");
+  }
+  return false;
+}
+
+template <>
 class CopyOnDeviceLikeOp<CUDAContext, CUDAContext, CUDAContext>
     : public Operator<CUDAContext> {
  public:
@@ -35,9 +59,8 @@
 REGISTER_CUDA_OPERATOR(ExpandDims, ExpandDimsOp<CUDAContext>);
 REGISTER_CUDA_OPERATOR(Alias, AliasOp<CUDAContext>);
 REGISTER_CUDA_OPERATOR(ResizeLike, ResizeLikeOp<CUDAContext>);
-REGISTER_CUDA_OPERATOR(Sum, SumOp<float, CUDAContext>);
-
-REGISTER_CUDA_OPERATOR(WeightedSum, WeightedSumOp<float, CUDAContext>);
+REGISTER_CUDA_OPERATOR(Sum, SumOp<CUDAContext>);
+REGISTER_CUDA_OPERATOR(WeightedSum, WeightedSumOp<CUDAContext>);
 REGISTER_CUDA_OPERATOR(Shape, ShapeOp<CUDAContext>);
 // From whatever the current context, ensure the output is TensorCPU
 REGISTER_CUDA_OPERATOR(
diff --git a/caffe2/utils/conversions.h b/caffe2/utils/conversions.h
new file mode 100644
index 0000000..0c6c323
--- /dev/null
+++ b/caffe2/utils/conversions.h
@@ -0,0 +1,182 @@
+#pragma once
+
+#include <caffe2/core/types.h>
+
+#ifdef __CUDA_ARCH__
+#include <cuda_fp16.h>
+#endif
+
+#ifdef __CUDA_ARCH__
+#define CONVERSIONS_DECL __host__ __device__ inline
+#else
+#define CONVERSIONS_DECL inline
+#endif
+
+namespace caffe2 {
+
+namespace convert {
+
+namespace {
+inline float16 cpu_float2half_rn(float f) {
+  float16 ret;
+
+  static_assert(
+      sizeof(unsigned int) == sizeof(float),
+      "Programming error sizeof(unsigned int) != sizeof(float)");
+
+  unsigned* xp = reinterpret_cast<unsigned int*>(&f);
+  unsigned x = *xp;
+  unsigned u = (x & 0x7fffffff), remainder, shift, lsb, lsb_s1, lsb_m1;
+  unsigned sign, exponent, mantissa;
+
+  // Get rid of +NaN/-NaN case first.
+  if (u > 0x7f800000) {
+    ret.x = 0x7fffU;
+    return ret;
+  }
+
+  sign = ((x >> 16) & 0x8000);
+
+  // Get rid of +Inf/-Inf, +0/-0.
+  if (u > 0x477fefff) {
+    ret.x = sign | 0x7c00U;
+    return ret;
+  }
+  if (u < 0x33000001) {
+    ret.x = (sign | 0x0000);
+    return ret;
+  }
+
+  exponent = ((u >> 23) & 0xff);
+  mantissa = (u & 0x7fffff);
+
+  if (exponent > 0x70) {
+    shift = 13;
+    exponent -= 0x70;
+  } else {
+    shift = 0x7e - exponent;
+    exponent = 0;
+    mantissa |= 0x800000;
+  }
+  lsb = (1 << shift);
+  lsb_s1 = (lsb >> 1);
+  lsb_m1 = (lsb - 1);
+
+  // Round to nearest even.
+  remainder = (mantissa & lsb_m1);
+  mantissa >>= shift;
+  if (remainder > lsb_s1 || (remainder == lsb_s1 && (mantissa & 0x1))) {
+    ++mantissa;
+    if (!(mantissa & 0x3ff)) {
+      ++exponent;
+      mantissa = 0;
+    }
+  }
+
+  ret.x = (sign | (exponent << 10) | mantissa);
+
+  return ret;
+}
+
+inline float cpu_half2float(float16 h) {
+  unsigned sign = ((h.x >> 15) & 1);
+  unsigned exponent = ((h.x >> 10) & 0x1f);
+  unsigned mantissa = ((h.x & 0x3ff) << 13);
+
+  if (exponent == 0x1f) { /* NaN or Inf */
+    mantissa = (mantissa ? (sign = 0, 0x7fffff) : 0);
+    exponent = 0xff;
+  } else if (!exponent) { /* Denorm or Zero */
+    if (mantissa) {
+      unsigned int msb;
+      exponent = 0x71;
+      do {
+        msb = (mantissa & 0x400000);
+        mantissa <<= 1; /* normalize */
+        --exponent;
+      } while (!msb);
+      mantissa &= 0x7fffff; /* 1.mantissa is implicit */
+    }
+  } else {
+    exponent += 0x70;
+  }
+
+  int temp = ((sign << 31) | (exponent << 23) | mantissa);
+
+  float* rp = reinterpret_cast<float*>(&temp);
+  return *rp;
+}
+
+}; // anonymous
+// general version: defer to static_cast
+template <typename IN, typename OUT>
+CONVERSIONS_DECL OUT To(const IN in) {
+  return static_cast<OUT>(in);
+}
+
+#if __CUDA_ARCH__
+__device__ __inline__ __half inf_clip(__half h) {
+  int isi = __hisinf(h);
+  if (isi > 0) {
+    // Exponent all ones except LSB (0x1e), mantissa is all ones (0x3ff)
+    h.x = 0x7bffU;
+  } else if (isi < 0) {
+    // As above, negated
+    h.x = 0x7bffU ^ 0x8000;
+  }
+  return h;
+}
+#endif
+
+// explicit for fp16
+template <>
+CONVERSIONS_DECL float16 To(const float in) {
+#if __CUDA_ARCH__
+  // hacky interface between C2 fp16 and CUDA
+  float16 ret;
+  __half r;
+  // r.x = __float2half_rn(in);
+  // ret.x = inf_clip(r).x;
+  ret.x = __float2half(in).x;
+  return ret;
+#else
+  return cpu_float2half_rn(in);
+#endif
+}
+
+template <>
+CONVERSIONS_DECL float To(const float16 in) {
+#if __CUDA_ARCH__
+  __half tmp;
+  tmp.x = in.x;
+  return __half2float(tmp);
+#else
+  return cpu_half2float(in);
+#endif
+};
+
+template <>
+CONVERSIONS_DECL float To(const float in) {
+  return in;
+}
+
+template <typename OUT, typename IN>
+CONVERSIONS_DECL OUT Get(IN x) {
+  return static_cast<OUT>(x);
+}
+
+template <>
+CONVERSIONS_DECL float Get(float16 x) {
+  return To<float16, float>(x);
+}
+
+template <>
+CONVERSIONS_DECL float16 Get(float x) {
+  return To<float, float16>(x);
+}
+
+}; // namespace convert
+
+}; // namespace caffe2
+
+#undef CONVERSIONS_DECL
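
A hedged usage sketch of the conversion helpers added above, assuming a Caffe2 checkout on the include path; only convert::To<IN, OUT> as declared in conversions.h is relied on:

```cpp
// Usage sketch for caffe2/utils/conversions.h (compile inside a Caffe2 tree;
// the include paths below are an assumption about the local setup).
// convert::To<IN, OUT> defers to static_cast for ordinary numeric types and
// to the fp16 routines above for float <-> float16 on the CPU.
#include <cstdio>

#include "caffe2/core/types.h"
#include "caffe2/utils/conversions.h"

int main() {
  const float x = 3.14159f;
  const caffe2::float16 h = caffe2::convert::To<float, caffe2::float16>(x);
  const float back = caffe2::convert::To<caffe2::float16, float>(h);
  // fp16 keeps roughly 3 decimal digits, so expect a small round-trip error.
  std::printf("x = %f, round-tripped = %f\n", x, back);
  return 0;
}
```
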
diff --git a/caffe2/utils/math-detail.h b/caffe2/utils/math-detail.h
index 35a880a..07a1f99 100644
--- a/caffe2/utils/math-detail.h
+++ b/caffe2/utils/math-detail.h
@@ -11,8 +11,12 @@
 
 template<typename T, class Context, int FixedSize>
 struct ScaleImpl {
-  inline void
-  operator()(const int N, const T alpha, const T* x, T* y, Context* context) {
+  inline void operator()(
+      const int N,
+      const float alpha,
+      const T* x,
+      T* y,
+      Context* context) {
     Scale(N, alpha, x, y, context);
   }
 };
@@ -22,7 +26,7 @@
 struct ScaleImpl<T, CPUContext, 1> {
   inline void operator()(
       const int N,
-      const T alpha,
+      const float alpha,
       const T* x,
       T* y,
       CPUContext* context) {
@@ -33,8 +37,12 @@
 
 template<typename T, class Context, int FixedSize>
 struct AxpyImpl {
-  inline void
-  operator()(const int N, const T alpha, const T* x, T* y, Context* context) {
+  inline void operator()(
+      const int N,
+      const float alpha,
+      const T* x,
+      T* y,
+      Context* context) {
     Axpy(N, alpha, x, y, context);
   }
 };
@@ -44,7 +52,7 @@
 struct AxpyImpl<T, CPUContext, 1> {
   inline void operator()(
       const int N,
-      const T alpha,
+      const float alpha,
       const T* x,
       T* y,
       CPUContext* context) {
@@ -57,14 +65,22 @@
 }  // namespace detail
 
 template <typename T, class Context, int FixedSize>
-inline void
-ScaleFixedSize(const int N, const T alpha, const T* x, T* y, Context* context) {
+inline void ScaleFixedSize(
+    const int N,
+    const float alpha,
+    const T* x,
+    T* y,
+    Context* context) {
   detail::ScaleImpl<T, Context, FixedSize>()(N, alpha, x, y, context);
 }
 
 template <typename T, class Context, int FixedSize>
-inline void
-AxpyFixedSize(const int N, const T alpha, const T* x, T* y, Context* context) {
+inline void AxpyFixedSize(
+    const int N,
+    const float alpha,
+    const T* x,
+    T* y,
+    Context* context) {
   detail::AxpyImpl<T, Context, FixedSize>()(N, alpha, x, y, context);
 }
 
diff --git a/caffe2/utils/math.h b/caffe2/utils/math.h
index a2472c0..105cb19 100644
--- a/caffe2/utils/math.h
+++ b/caffe2/utils/math.h
@@ -141,10 +141,20 @@
 
 // Decaf gemm provides a simpler interface to the gemm functions, with the
 // limitation that the data has to be contiguous in memory.
-template <typename T, class Context, class Engine=DefaultEngine>
-void Gemm(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB,
-    const int M, const int N, const int K, const T alpha, const T* A,
-    const T* B, const T beta, T* C, Context* context);
+template <typename T, class Context, class Engine = DefaultEngine>
+void Gemm(
+    const CBLAS_TRANSPOSE TransA,
+    const CBLAS_TRANSPOSE TransB,
+    const int M,
+    const int N,
+    const int K,
+    const float alpha,
+    const T* A,
+    const T* B,
+    const float beta,
+    T* C,
+    Context* context,
+    TensorProto::DataType math_type = TensorProto_DataType_FLOAT);
 
 // We also provide a gemm that has explicit lda, ldb and ldc specified.
 // In most cases you probably want to use the function above, though.
@@ -169,10 +179,18 @@
 // to Trans, the output is:
 // CblasNoTrans: x is an N dim vector and y is an M dim vector.
 // CblasTrans:   x is an M dim vector and y is an N dim vector.
-template <typename T, class Context, class Engine=DefaultEngine>
-void Gemv(const CBLAS_TRANSPOSE TransA, const int M, const int N,
-    const T alpha, const T* A, const T* x, const T beta,
-    T* y, Context* context);
+template <typename T, class Context, class Engine = DefaultEngine>
+void Gemv(
+    const CBLAS_TRANSPOSE TransA,
+    const int M,
+    const int N,
+    const float alpha,
+    const T* A,
+    const T* x,
+    const float beta,
+    T* y,
+    Context* context,
+    TensorProto::DataType math_type = TensorProto_DataType_FLOAT);
 
 template <typename T, class Context>
 void Set(const TIndex N, const T alpha, T* X, Context* context);
@@ -218,28 +236,31 @@
             Context* context);
 
 template <typename T, class Context>
-void Scale(const int N, const T alpha, const T* x, T* y, Context* context);
+void Scale(const int N, const float alpha, const T* x, T* y, Context* context);
 
 // Different from the Scale function above, if alpha is passed in
 // as a pointer, we will assume that it lives on the Context device,
 // for example on GPU.
 template <typename T, class Context>
-void Scale(const int N, const T* alpha, const T* x, T* y,
-           Context* context);
+void Scale(const int N, const float* alpha, const T* x, T* y, Context* context);
 
 template <typename T, class Context>
-void Axpy(const int N, const T alpha, const T* x, T* y, Context* context);
+void Axpy(const int N, const float alpha, const T* x, T* y, Context* context);
 
 // Different from the Axpy function above, if alpha is passed in
 // as a pointer, we will assume that it lives on the Context device,
 // for example on GPU.
 template <typename T, class Context>
-void Axpy(const int N, const T* alpha, const T* x, T* y,
-          Context* context);
+void Axpy(const int N, const float* alpha, const T* x, T* y, Context* context);
 
 template <typename T, class Context>
-void Axpby(const int N, const T alpha, const T* x, const T b, T* y,
-           Context* context);
+void Axpby(
+    const int N,
+    const float alpha,
+    const T* x,
+    const T b,
+    T* y,
+    Context* context);
 
 template <typename T, class Context, int order>
 void Im2colNd(
diff --git a/caffe2/utils/math_cpu.cc b/caffe2/utils/math_cpu.cc
index 5cac0c8..e4340df 100644
--- a/caffe2/utils/math_cpu.cc
+++ b/caffe2/utils/math_cpu.cc
@@ -58,9 +58,18 @@
 // CblasTrans, respectively, for each of A and B.
 template <>
 void Gemm<float, CPUContext>(
-    const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB,
-    const int M, const int N, const int K, const float alpha, const float* A,
-    const float* B, const float beta, float* C, CPUContext* context) {
+    const CBLAS_TRANSPOSE TransA,
+    const CBLAS_TRANSPOSE TransB,
+    const int M,
+    const int N,
+    const int K,
+    const float alpha,
+    const float* A,
+    const float* B,
+    const float beta,
+    float* C,
+    CPUContext* context,
+    TensorProto::DataType math_type) {
   auto C_mat = EigenMatrixMap<float>(C, N, M);
   if (beta == 0) {
     C_mat.setZero();
@@ -178,7 +187,8 @@
     const float* x,
     const float beta,
     float* y,
-    CPUContext* context) {
+    CPUContext* context,
+    TensorProto::DataType math_type) {
   EigenVectorMap<float> y_vec(y, TransA == CblasNoTrans ? M : N);
   if (beta == 0) {
     // In Caffe2 we often do a lazy initialization, which may contain NaNs in
@@ -205,19 +215,22 @@
   }
 }
 
-#define CAFFE2_SPECIALIZED_SCALE(T)                                         \
-  template <>                                                               \
-  void Scale<T, CPUContext>(                                                \
-      const int n, const T alpha, const T* x, T* y, CPUContext* context) {  \
-    EigenVectorMap<T>(y, n) = ConstEigenVectorMap<T>(x, n) * alpha;         \
-  }                                                                         \
-  template <>                                                               \
-  void Scale<T, CPUContext>(                                                \
-      const int n, const T* alpha, const T* x, T* y, CPUContext* context) { \
-    EigenVectorMap<T>(y, n) = ConstEigenVectorMap<T>(x, n) * (*alpha);      \
+#define CAFFE2_SPECIALIZED_SCALE(T)                                            \
+  template <>                                                                  \
+  void Scale<T, CPUContext>(                                                   \
+      const int n, const float alpha, const T* x, T* y, CPUContext* context) { \
+    EigenVectorMap<T>(y, n) = ConstEigenVectorMap<T>(x, n) * alpha;            \
+  }                                                                            \
+  template <>                                                                  \
+  void Scale<T, CPUContext>(                                                   \
+      const int n,                                                             \
+      const float* alpha,                                                      \
+      const T* x,                                                              \
+      T* y,                                                                    \
+      CPUContext* context) {                                                   \
+    EigenVectorMap<T>(y, n) = ConstEigenVectorMap<T>(x, n) * (*alpha);         \
   }
 CAFFE2_SPECIALIZED_SCALE(float)
-CAFFE2_SPECIALIZED_SCALE(double)
 #undef CAFFE2_SPECIALIZED_SCALE
 
 #define CAFFE2_SPECIALIZED_DOT(T)                                              \
@@ -228,7 +241,6 @@
   *y = ConstEigenVectorMap<T>(a, N).dot(ConstEigenVectorMap<T>(b, N));         \
 }
 CAFFE2_SPECIALIZED_DOT(float)
-CAFFE2_SPECIALIZED_DOT(double)
 #undef CAFFE2_SPECIALIZED_DOT
 
 #define CAFFE2_SPECIALIZED_AXPY(T)                                          \
@@ -243,7 +255,6 @@
     EigenVectorMap<T>(Y, N) += ConstEigenVectorMap<T>(x, N) * (*alpha);     \
   }
 CAFFE2_SPECIALIZED_AXPY(float)
-CAFFE2_SPECIALIZED_AXPY(double)
 #undef CAFFE2_SPECIALIZED_AXPY
 
 #define CAFFE2_SPECIALIZED_AXPBY(T)                                            \
@@ -254,16 +265,24 @@
   y_vec = y_vec * beta + ConstEigenVectorMap<T>(x, N) * alpha;                 \
 }
 CAFFE2_SPECIALIZED_AXPBY(float)
-CAFFE2_SPECIALIZED_AXPBY(double)
 #undef CAFFE2_SPECIALIZED_AXPBY
 
 #else  // CAFFE2_USE_EIGEN_FOR_BLAS
 
 template <>
 void Gemm<float, CPUContext>(
-    const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB,
-    const int M, const int N, const int K, const float alpha, const float* A,
-    const float* B, const float beta, float* C, CPUContext* context) {
+    const CBLAS_TRANSPOSE TransA,
+    const CBLAS_TRANSPOSE TransB,
+    const int M,
+    const int N,
+    const int K,
+    const float alpha,
+    const float* A,
+    const float* B,
+    const float beta,
+    float* C,
+    CPUContext* context,
+    TensorProto::DataType math_type) {
   int lda = (TransA == CblasNoTrans) ? K : M;
   int ldb = (TransB == CblasNoTrans) ? N : K;
   cblas_sgemm(CblasRowMajor, TransA, TransB, M, N, K, alpha, A, lda, B, ldb,
@@ -292,29 +311,39 @@
 
 template <>
 void Gemv<float, CPUContext>(
-    const CBLAS_TRANSPOSE TransA, const int M, const int N, const float alpha,
-    const float* A, const float* x, const float beta, float* y,
-    CPUContext* context) {
+    const CBLAS_TRANSPOSE TransA,
+    const int M,
+    const int N,
+    const float alpha,
+    const float* A,
+    const float* x,
+    const float beta,
+    float* y,
+    CPUContext* context,
+    TensorProto::DataType math_type) {
   cblas_sgemv(CblasRowMajor, TransA, M, N, alpha, A, N, x, 1, beta, y, 1);
 }
 
-#define CAFFE2_SPECIALIZED_SCALE(T, prefix)                                 \
-  template <>                                                               \
-  void Scale<T, CPUContext>(                                                \
-      const int n, const T alpha, const T* x, T* y, CPUContext* context) {  \
-    if (y != x)                                                             \
-      cblas_##prefix##copy(n, x, 1, y, 1);                                  \
-    cblas_##prefix##scal(n, alpha, y, 1);                                   \
-  }                                                                         \
-  template <>                                                               \
-  void Scale<T, CPUContext>(                                                \
-      const int n, const T* alpha, const T* x, T* y, CPUContext* context) { \
-    if (y != x)                                                             \
-      cblas_##prefix##copy(n, x, 1, y, 1);                                  \
-    cblas_##prefix##scal(n, *alpha, y, 1);                                  \
+#define CAFFE2_SPECIALIZED_SCALE(T, prefix)                                    \
+  template <>                                                                  \
+  void Scale<T, CPUContext>(                                                   \
+      const int n, const float alpha, const T* x, T* y, CPUContext* context) { \
+    if (y != x)                                                                \
+      cblas_##prefix##copy(n, x, 1, y, 1);                                     \
+    cblas_##prefix##scal(n, static_cast<float>(alpha), y, 1);                  \
+  }                                                                            \
+  template <>                                                                  \
+  void Scale<T, CPUContext>(                                                   \
+      const int n,                                                             \
+      const float* alpha,                                                      \
+      const T* x,                                                              \
+      T* y,                                                                    \
+      CPUContext* context) {                                                   \
+    if (y != x)                                                                \
+      cblas_##prefix##copy(n, x, 1, y, 1);                                     \
+    cblas_##prefix##scal(n, static_cast<float>(*alpha), y, 1);                 \
   }
 CAFFE2_SPECIALIZED_SCALE(float, s)
-CAFFE2_SPECIALIZED_SCALE(double, d)
 #undef CAFFE2_SPECIALIZED_SCALE
 
 #define CAFFE2_SPECIALIZED_DOT(T, prefix)                                      \
@@ -325,7 +354,6 @@
   *y = cblas_##prefix##dot(N, a, 1, b, 1);                                     \
 }
 CAFFE2_SPECIALIZED_DOT(float, s)
-CAFFE2_SPECIALIZED_DOT(double, d)
 #undef CAFFE2_SPECIALIZED_DOT
 
 #define CAFFE2_SPECIALIZED_AXPY(T, prefix)                                  \
@@ -340,7 +368,6 @@
     cblas_##prefix##axpy(N, *alpha, x, 1, y, 1);                            \
   }
 CAFFE2_SPECIALIZED_AXPY(float, s)
-CAFFE2_SPECIALIZED_AXPY(double, d)
 #undef CAFFE2_SPECIALIZED_AXPY
 
 // cblas_[sd]axpby is not a standard blas function, and if MKL is not present,
@@ -362,7 +389,6 @@
 }
 #endif  // CAFFE2_USE_MKL
 CAFFE2_SPECIALIZED_AXPBY(float, s)
-CAFFE2_SPECIALIZED_AXPBY(double, d)
 #undef CAFFE2_SPECIALIZED_AXPBY
 
 #endif  // CAFFE2_USE_EIGEN_FOR_BLAS
@@ -436,11 +462,8 @@
   EigenVectorMap<T>(y, N) = ConstEigenVectorMap<T>(x, N).array().expr();       \
 }
 DELEGATE_SIMPLE_UNARY_FUNCTION(float, Exp, exp)
-DELEGATE_SIMPLE_UNARY_FUNCTION(double, Exp, exp)
 DELEGATE_SIMPLE_UNARY_FUNCTION(float, Log, log)
-DELEGATE_SIMPLE_UNARY_FUNCTION(double, Log, log)
 DELEGATE_SIMPLE_UNARY_FUNCTION(float, Sqr, square)
-DELEGATE_SIMPLE_UNARY_FUNCTION(double, Sqr, square)
 #undef DELEGATE_SIMPLE_UNARY_FUNCTION
 
 #define DELEGATE_POWX_FUNCTION(T)                                              \
@@ -450,7 +473,6 @@
   EigenVectorMap<T>(y, N) = ConstEigenVectorMap<T>(a, N).array().pow(b);       \
 }
 DELEGATE_POWX_FUNCTION(float)
-DELEGATE_POWX_FUNCTION(double)
 #undef DELEGATE_POWX_FUNCTION
 
 #endif  // CAFFE2_USE_MKL
@@ -476,7 +498,6 @@
 
 #define DEFINE_SIMPLE_BINARY_FUNCTION(Funcname, expr)                          \
 EIGEN_SIMPLE_BINARY_FUNCTION(float, Funcname, expr)                            \
-EIGEN_SIMPLE_BINARY_FUNCTION(double, Funcname, expr)                           \
 EIGEN_SIMPLE_BINARY_FUNCTION(int32_t, Funcname, expr)                          \
 EIGEN_SIMPLE_BINARY_FUNCTION(int64_t, Funcname, expr)
 
@@ -546,7 +567,6 @@
   DELEGATE_BROADCAST_BINARY_FUNCTION(int32_t, name, op)                  \
   DELEGATE_BROADCAST_BINARY_FUNCTION(int64_t, name, op)                  \
   DELEGATE_BROADCAST_BINARY_FUNCTION(float, name, op)                    \
-  DELEGATE_BROADCAST_BINARY_FUNCTION(double, name, op)
 
 DEFINE_BROADCAST_BINARY_FUNCTION(Add, +)
 DEFINE_BROADCAST_BINARY_FUNCTION(Sub, -)
@@ -602,7 +622,6 @@
 
 #define CAFFE2_DEFINE_BINARY_OP(name, op)         \
   CAFFE2_INSTANTIATE_BINARY_OP(name, op, float)   \
-  CAFFE2_INSTANTIATE_BINARY_OP(name, op, double)  \
   CAFFE2_INSTANTIATE_BINARY_OP(name, op, int32_t) \
   CAFFE2_INSTANTIATE_BINARY_OP(name, op, int64_t)
 
@@ -644,7 +663,6 @@
   }
 
 CAFFE2_SPECIALIZED_CPU_ADD_STRIPED_BATCH(float);
-CAFFE2_SPECIALIZED_CPU_ADD_STRIPED_BATCH(double);
 #undef CAFFE2_SPECIALIZED_CPU_ADD_STRIPED_BATCH
 
 template <>
@@ -717,7 +735,6 @@
   }
 
 CAFFE2_SPECIALIZED_SUM(float);
-CAFFE2_SPECIALIZED_SUM(double);
 CAFFE2_SPECIALIZED_SUM(int32_t);
 CAFFE2_SPECIALIZED_SUM(int64_t);
 
diff --git a/caffe2/utils/math_gpu.cu b/caffe2/utils/math_gpu.cu
index cfbe91b..46c5bc0 100644
--- a/caffe2/utils/math_gpu.cu
+++ b/caffe2/utils/math_gpu.cu
@@ -5,8 +5,9 @@
 #include <thrust/system/cuda/detail/par.h>
 #include <thrust/version.h>
 
-#include "caffe2/utils/math.h"
 #include "caffe2/core/context_gpu.h"
+#include "caffe2/utils/conversions.h"
+#include "caffe2/utils/math.h"
 
 #if THRUST_VERSION >= 100800
 #define THRUST_SUPPORTS_PER_THREAD
@@ -32,33 +33,30 @@
 }
 
 DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Exp, expf);
-DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(double, Exp, exp);
 DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Log, logf);
-DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(double, Log, log);
 
 __device__ float cuda_sqrf(const float x) { return x * x; }
-__device__ double cuda_sqr(const double x) { return x * x; }
 
 DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Sqr, cuda_sqrf);
-DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(double, Sqr, cuda_sqr);
 
 #undef DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION
 
-#define DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(T, Funcname, expr)          \
-  __global__ void _Kernel_##T##_##Funcname(                              \
-      const int N, const T* a, const T* b, T* y) {                       \
-    CUDA_1D_KERNEL_LOOP(i, N) {                                          \
-      y[i] = a[i] expr b[i];                                             \
-    }                                                                    \
-  }                                                                      \
-  template <>                                                            \
-  void Funcname<T, CUDAContext>(                                         \
-      const int N, const T* a, const T* b, T* y, CUDAContext* context) { \
-    _Kernel_##T##_##Funcname<<<                                          \
-        CAFFE_GET_BLOCKS(N),                                             \
-        CAFFE_CUDA_NUM_THREADS,                                          \
-        0,                                                               \
-        context->cuda_stream()>>>(N, a, b, y);                           \
+#define DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(T, Funcname, expr)               \
+  __global__ void _Kernel_##T##_##Funcname(                                   \
+      const int N, const T* a, const T* b, T* y) {                            \
+    CUDA_1D_KERNEL_LOOP(i, N) {                                               \
+      float r = convert::To<T, float>(a[i]) expr convert::To<T, float>(b[i]); \
+      y[i] = convert::To<float, T>(r);                                        \
+    }                                                                         \
+  }                                                                           \
+  template <>                                                                 \
+  void Funcname<T, CUDAContext>(                                              \
+      const int N, const T* a, const T* b, T* y, CUDAContext* context) {      \
+    _Kernel_##T##_##Funcname<<<                                               \
+        CAFFE_GET_BLOCKS(N),                                                  \
+        CAFFE_CUDA_NUM_THREADS,                                               \
+        0,                                                                    \
+        context->cuda_stream()>>>(N, a, b, y);                                \
   }
 
 DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(float, Add, +);
@@ -66,13 +64,27 @@
 DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(float, Mul, *);
 DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(float, Div, /);
 
+DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(float16, Add, +);
+DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(float16, Sub, -);
+DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(float16, Mul, *);
+DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(float16, Div, /);
+
 // Caffe2 gemm provides a simpler interface to the gemm functions, with the
 // limitation that the data has to be contiguous in memory.
 template <>
 void Gemm<float, CUDAContext>(
-    const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB,
-    const int M, const int N, const int K, const float alpha, const float* A,
-    const float* B, const float beta, float* C, CUDAContext* context) {
+    const CBLAS_TRANSPOSE TransA,
+    const CBLAS_TRANSPOSE TransB,
+    const int M,
+    const int N,
+    const int K,
+    const float alpha,
+    const float* A,
+    const float* B,
+    const float beta,
+    float* C,
+    CUDAContext* context,
+    TensorProto::DataType math_type) {
   // Note that cublas follows fortran order, so the order is different from
   // the cblas convention.
   int lda = (TransA == CblasNoTrans) ? K : M;
@@ -99,11 +111,91 @@
 }
 
 template <>
+void Gemm<float16, CUDAContext>(
+    const CBLAS_TRANSPOSE TransA,
+    const CBLAS_TRANSPOSE TransB,
+    const int M,
+    const int N,
+    const int K,
+    const float alpha,
+    const float16* A,
+    const float16* B,
+    const float beta,
+    float16* C,
+    CUDAContext* context,
+    TensorProto::DataType math_type) {
+  // Note that cublas follows fortran order, so the order is different from
+  // the cblas convention.
+  int lda = (TransA == CblasNoTrans) ? K : M;
+  int ldb = (TransB == CblasNoTrans) ? N : K;
+  cublasOperation_t cuTransA =
+      (TransA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
+  cublasOperation_t cuTransB =
+      (TransB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
+  if (math_type == TensorProto_DataType_FLOAT) {
+    CUBLAS_CHECK(cublasSgemmEx(
+        context->cublas_handle(),
+        cuTransB,
+        cuTransA,
+        N,
+        M,
+        K,
+        &alpha,
+        B,
+        CUDA_R_16F,
+        ldb,
+        A,
+        CUDA_R_16F,
+        lda,
+        &beta,
+        C,
+        CUDA_R_16F,
+        N));
+
+  } else if (math_type == TensorProto_DataType_FLOAT16) {
+    // convert alpha, beta from caffe2::float16 -> __half
+    __half alpha_fp16;
+    alpha_fp16.x = convert::To<float, float16>(alpha).x;
+    __half beta_fp16;
+    beta_fp16.x = convert::To<float, float16>(beta).x;
+    // call cublasHgemm
+    CUBLAS_CHECK(cublasHgemm(
+        context->cublas_handle(),
+        cuTransB,
+        cuTransA,
+        N,
+        M,
+        K,
+        &alpha_fp16,
+        (const __half*)B,
+        ldb,
+        (const __half*)A,
+        lda,
+        &beta_fp16,
+        (__half*)C,
+        N));
+  } else {
+    // fail
+    CAFFE_THROW("Unsupported math type");
+  }
+}
+
+template <>
 void GemmEx<float, CUDAContext>(
-    const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB,
-    const int M, const int N, const int K, const float alpha, const float* A,
-    const int lda, const float* B, const int ldb, const float beta, float* C,
-    const int ldc, CUDAContext* context) {
+    const CBLAS_TRANSPOSE TransA,
+    const CBLAS_TRANSPOSE TransB,
+    const int M,
+    const int N,
+    const int K,
+    const float alpha,
+    const float* A,
+    const int lda,
+    const float* B,
+    const int ldb,
+    const float beta,
+    float* C,
+    const int ldc,
+    CUDAContext* context) {
   // Note that cublas follows fortran order, so the order is different from
   // the cblas convention.
   cublasOperation_t cuTransA =
@@ -129,40 +221,19 @@
 
 template <>
 void Gemv<float, CUDAContext>(
-    const CBLAS_TRANSPOSE TransA, const int M, const int N, const float alpha,
-    const float* A, const float* x, const float beta, float* y,
-    CUDAContext* context) {
-  cublasOperation_t cuTransA =
-      (TransA == CblasNoTrans) ? CUBLAS_OP_T : CUBLAS_OP_N;
-  CUBLAS_ENFORCE(cublasSgemv(
-      context->cublas_handle(),
-      cuTransA,
-      N,
-      M,
-      &alpha,
-      A,
-      N,
-      x,
-      1,
-      &beta,
-      y,
-      1));
-}
-
-template <>
-void Gemv<double, CUDAContext>(
     const CBLAS_TRANSPOSE TransA,
     const int M,
     const int N,
-    const double alpha,
-    const double* A,
-    const double* x,
-    const double beta,
-    double* y,
-    CUDAContext* context) {
+    const float alpha,
+    const float* A,
+    const float* x,
+    const float beta,
+    float* y,
+    CUDAContext* context,
+    TensorProto::DataType math_type) {
   cublasOperation_t cuTransA =
       (TransA == CblasNoTrans) ? CUBLAS_OP_T : CUBLAS_OP_N;
-  CUBLAS_ENFORCE(cublasDgemv(
+  CUBLAS_ENFORCE(cublasSgemv(
       context->cublas_handle(),
       cuTransA,
       N,
@@ -216,6 +287,73 @@
 CAFFE2_SPECIALIZED_CUDA_ADD_STRIPED_BATCH(double);
 #undef CAFFE2_SPECIALIZED_CUDA_ADD_STRIPED_BATCH
 
+template <>
+void Gemv<float16, CUDAContext>(
+    const CBLAS_TRANSPOSE TransA,
+    const int M,
+    const int N,
+    const float alpha,
+    const float16* A,
+    const float16* x,
+    const float beta,
+    float16* y,
+    CUDAContext* context,
+    TensorProto::DataType math_type) {
+  cublasOperation_t cuTransA =
+      (TransA == CblasNoTrans) ? CUBLAS_OP_T : CUBLAS_OP_N;
+
+  // sort out what we need to call cublasSgemmEx / cublasHgemm
+  int m = (cuTransA == CUBLAS_OP_N) ? N : M;
+  int k = (cuTransA == CUBLAS_OP_N) ? M : N;
+  int LDA = (cuTransA == CUBLAS_OP_N) ? m : k;
+  int LDC = m;
+
+  if (math_type == TensorProto_DataType_FLOAT) {
+    CUBLAS_CHECK(cublasSgemmEx(
+        context->cublas_handle(),
+        cuTransA,
+        CUBLAS_OP_N,
+        m,
+        1,
+        k,
+        &alpha,
+        A,
+        CUDA_R_16F,
+        LDA,
+        x,
+        CUDA_R_16F,
+        k,
+        &beta,
+        y,
+        CUDA_R_16F,
+        LDC));
+  } else if (math_type == TensorProto_DataType_FLOAT16) {
+    __half alpha_fp16;
+    alpha_fp16.x = convert::To<float, float16>(alpha).x;
+    __half beta_fp16;
+    beta_fp16.x = convert::To<float, float16>(beta).x;
+
+    CUBLAS_CHECK(cublasHgemm(
+        context->cublas_handle(),
+        cuTransA,
+        CUBLAS_OP_N,
+        m,
+        1,
+        k,
+        &alpha_fp16,
+        (const __half*)A,
+        LDA,
+        (const __half*)x,
+        k,
+        &beta_fp16,
+        (__half*)y,
+        LDC));
+  } else {
+    // fail
+    CAFFE_THROW("Unsupported math type");
+  }
+}
+
 namespace {
 template <typename T>
 __global__ void SetKernel(const int N, const T alpha, T* Y) {
@@ -238,6 +376,7 @@
 CAFFE2_SPECIALIZED_CUDA_SET(bool);
 CAFFE2_SPECIALIZED_CUDA_SET(int8_t);
 CAFFE2_SPECIALIZED_CUDA_SET(int16_t);
+CAFFE2_SPECIALIZED_CUDA_SET(float16);
 CAFFE2_SPECIALIZED_CUDA_SET(int);
 CAFFE2_SPECIALIZED_CUDA_SET(int64_t);
 CAFFE2_SPECIALIZED_CUDA_SET(char);
@@ -247,11 +386,11 @@
 
 namespace {
 template <typename T>
-__global__ void UniformShift(const int N, const T min, const T max,
-                             T* x) {
-  T scale = max - min;
+__global__ void
+UniformShift(const int N, const float min, const float max, T* x) {
+  float scale = max - min;
   CUDA_1D_KERNEL_LOOP(i, N) {
-    x[i] = x[i] * scale + min;
+    x[i] = convert::To<float, T>(convert::To<T, float>(x[i]) * scale + min);
   }
 }
 
@@ -336,7 +475,6 @@
       context->curand_generator(), r, even_n, mean, std));
 }
 
-
 template<>
 void Dot<float, CUDAContext>(
     const int n, const float* a, const float* b, float* y,
@@ -346,13 +484,28 @@
   context->Copy<float, CPUContext, CUDAContext>(1, &result, y);
 }
 
-template<>
-void Dot<double, CUDAContext>(
-    const int n, const double* a, const double* b, double* y,
+template <>
+void Dot<float16, CUDAContext>(
+    const int n,
+    const float16* a,
+    const float16* b,
+    float16* y,
     CUDAContext* context) {
-  double result;
-  CUBLAS_ENFORCE(cublasDdot(context->cublas_handle(), n, a, 1, b, 1, y));
-  context->Copy<double, CPUContext, CUDAContext>(1, &result, y);
+  float16 result;
+  // execute with 32-bit math
+  CUBLAS_CHECK(cublasDotEx(
+      context->cublas_handle(),
+      n,
+      a,
+      CUDA_R_16F,
+      1,
+      b,
+      CUDA_R_16F,
+      1,
+      &result,
+      CUDA_R_16F,
+      CUDA_R_32F));
+  context->Copy<float16, CPUContext, CUDAContext>(1, &result, y);
 }
 
 // A previous version of caffe2 used Thrust but it turns out that thrust
@@ -363,7 +516,7 @@
 template <typename T>
 __global__ void SumKernel(const int N, const T* X, T* Y, bool square) {
   const int idx = threadIdx.x;
-  __shared__ T reduction_buffer[SUM_KERNEL_NTHREADS];
+  __shared__ float reduction_buffer[SUM_KERNEL_NTHREADS];
 
   reduction_buffer[idx] = 0;
 
@@ -371,11 +524,12 @@
   // N -> 128
   if (!square) {
     for (int i = idx; i < N; i += SUM_KERNEL_NTHREADS) {
-      reduction_buffer[idx] += X[i];
+      reduction_buffer[idx] += convert::To<T, float>(X[i]);
     }
   } else {
     for (int i = idx; i < N; i += SUM_KERNEL_NTHREADS) {
-      reduction_buffer[idx] += X[i] * X[i];
+      float Xi = convert::To<T, float>(X[i]);
+      reduction_buffer[idx] += Xi * Xi;
     }
   }
   __syncthreads();
@@ -393,7 +547,7 @@
     for (int i = 0; i < 32; ++i) {
       tmp += reduction_buffer[i];
     }
-    *Y = tmp;
+    *Y = convert::To<float, T>(tmp);
   }
 }
 
@@ -406,7 +560,7 @@
   }
 
 CAFFE2_MATH_SUM_FUNC(float)
-CAFFE2_MATH_SUM_FUNC(double)
+CAFFE2_MATH_SUM_FUNC(float16)
 #undef CAFFE2_MATH_SUM_FUNC
 
 #define CAFFE2_MATH_SUMSQR_FUNC(T)                                    \
@@ -438,18 +592,33 @@
                         0, context->cuda_stream()>>>(N, D, x, idx, y);
 }
 
+template <>
+void Select<float16, CUDAContext>(
+    const int N,
+    const int D,
+    const float16* x,
+    const int* idx,
+    float16* y,
+    CUDAContext* context) {
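+  // SelectKernel is templated on the element type, so fp16 only needs this
+  // explicit instantiation of the launcher.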
+  SelectKernel<float16><<<
+      CAFFE_GET_BLOCKS(N),
+      CAFFE_CUDA_NUM_THREADS,
+      0,
+      context->cuda_stream()>>>(N, D, x, idx, y);
+}
+
 namespace {
 template <typename T>
-__global__ void ScaleKernel(
-    const int n, const T alpha, const T* x, T* y) {
+__global__ void ScaleKernel(const int n, const float alpha, const T* x, T* y) {
   CUDA_1D_KERNEL_LOOP(i, n) {
-    y[i] = x[i] * alpha;
+    // Up-convert to fp32, scale, then convert back to T (a no-op when T is float).
+    y[i] = convert::Get<T>(convert::Get<float>(x[i]) * alpha);
   }
 }
 
 template <typename T>
-__global__ void ScaleKernelDeviceAlpha(
-    const int n, const T* alpha, const T* x, T* y) {
+__global__ void
+ScaleKernelDeviceAlpha(const int n, const float* alpha, const T* x, T* y) {
   CUDA_1D_KERNEL_LOOP(i, n) {
     y[i] = x[i] * (*alpha);
   }
@@ -461,6 +630,20 @@
     y[i] = powf(x[i], exponent);
   }
 }
+
+// fp16 specialization: the device-side alpha is applied in fp32, then rounded back to half.
+template <>
+__global__ void ScaleKernelDeviceAlpha(
+    const int n,
+    const float* alpha,
+    const float16* x,
+    float16* y) {
+  CUDA_1D_KERNEL_LOOP(i, n) {
+    y[i] = convert::To<float, float16>(
+        convert::To<float16, float>(x[i]) * (*alpha));
+  }
+}
+
 }  // namespace
 
 template <>
@@ -489,12 +672,17 @@
 }
 
 template <>
-void Scale<double, CUDAContext>(
-    const int n, const double alpha, const double *x, double* y,
+void Scale<float16, CUDAContext>(
+    const int n,
+    const float alpha,
+    const float16* x,
+    float16* y,
     CUDAContext* context) {
-  ScaleKernel<double><<<
-      CAFFE_GET_BLOCKS(n), CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(
-          n, alpha, x, y);
+  ScaleKernel<float16><<<
+      CAFFE_GET_BLOCKS(n),
+      CAFFE_CUDA_NUM_THREADS,
+      0,
+      context->cuda_stream()>>>(n, alpha, x, y);
 }
 
 template <>
@@ -507,11 +695,17 @@
 }
 
 template <>
-void Scale<double, CUDAContext>(
-    const int n, const double* alpha, const double *x, double* y,
+void Scale<float16, CUDAContext>(
+    const int n,
+    const float* alpha,
+    const float16* x,
+    float16* y,
     CUDAContext* context) {
-  ScaleKernelDeviceAlpha<double><<<CAFFE_GET_BLOCKS(n), CAFFE_CUDA_NUM_THREADS,
-                       0, context->cuda_stream()>>>(n, alpha, x, y);
+  ScaleKernelDeviceAlpha<float16><<<
+      CAFFE_GET_BLOCKS(n),
+      CAFFE_CUDA_NUM_THREADS,
+      0,
+      context->cuda_stream()>>>(n, alpha, x, y);
 }
 
 template <>
@@ -527,18 +721,42 @@
 template <>
 void Axpy<double, CUDAContext>(
     const int N,
-    const double alpha,
+    const float alpha,
     const double* X,
     double* Y,
     CUDAContext* context) {
-  CUBLAS_ENFORCE(cublasDaxpy(context->cublas_handle(), N, &alpha, X, 1, Y, 1));
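+  // alpha is now a float in every Axpy signature; widen it before calling the
+  // double-precision cublasDaxpy.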
+  double alpha_d{alpha};
+  CUBLAS_ENFORCE(
+      cublasDaxpy(context->cublas_handle(), N, &alpha_d, X, 1, Y, 1));
+}
+
+template <>
+void Axpy<float16, CUDAContext>(
+    const int N,
+    const float alpha,
+    const float16* X,
+    float16* Y,
+    CUDAContext* context) {
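+  // cublasAxpyEx: X and Y stay in fp16, the math runs in fp32, and alpha is a
+  // host float, so its data type is CUDA_R_32F.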
+  CUBLAS_CHECK(cublasAxpyEx(
+      context->cublas_handle(),
+      N,
+      &alpha,
+      CUDA_R_32F,
+      X,
+      CUDA_R_16F,
+      1,
+      Y,
+      CUDA_R_16F,
+      1,
+      CUDA_R_32F));
 }
 
 namespace {
 template <typename T>
-__global__ void AxpyKernel(const int n, const T* a, const T* x, T* y) {
+__global__ void AxpyKernel(const int n, const float* a, const T* x, T* y) {
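+  // alpha arrives as a device pointer to float; the multiply-add is done in
+  // fp32 and the result rounded back to T.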
   CUDA_1D_KERNEL_LOOP(index, n) {
-    y[index] += x[index] * (*a);
+    y[index] = convert::Get<T>(
+        convert::Get<float>(x[index]) * (*a) + convert::Get<float>(y[index]));
   }
 }
 }  // namespace
@@ -552,14 +770,19 @@
 }
 
 template <>
-void Axpy<double, CUDAContext>(
-    const int n, const double* alpha, const double* X,
-    double* Y, CUDAContext* context) {
-  AxpyKernel<double><<<CAFFE_GET_BLOCKS(n), CAFFE_CUDA_NUM_THREADS,
-                       0, context->cuda_stream()>>>(n, alpha, X, Y);
+void Axpy<float16, CUDAContext>(
+    const int n,
+    const float* alpha,
+    const float16* X,
+    float16* Y,
+    CUDAContext* context) {
+  AxpyKernel<float16><<<
+      CAFFE_GET_BLOCKS(n),
+      CAFFE_CUDA_NUM_THREADS,
+      0,
+      context->cuda_stream()>>>(n, alpha, X, Y);
 }
 
-
 namespace {
 template <typename T>
 __global__ void AxpbyKernel(const int n, const T a, const T* x,
@@ -578,14 +801,6 @@
                        0, context->cuda_stream()>>>(n, a, x, b, y);
 }
 
-template <>
-void Axpby<double, CUDAContext>(
-    const int n, const double a, const double* x, const double b, double* y,
-    CUDAContext* context) {
-  AxpbyKernel<double><<<CAFFE_GET_BLOCKS(n), CAFFE_CUDA_NUM_THREADS,
-                        0, context->cuda_stream()>>>(n, a, x, b, y);
-}
-
 namespace {
 
 template <typename T>
diff --git a/caffe2/utils/math_gpu_test.cc b/caffe2/utils/math_gpu_test.cc
index 2ceeddd..b1f930b 100644
--- a/caffe2/utils/math_gpu_test.cc
+++ b/caffe2/utils/math_gpu_test.cc
@@ -67,61 +67,4 @@
   }
 }
 
-#define TEST_GEMV_WITH_TYPE(field_name)                                      \
-  TEST(MathUtilGPUTest, testGemv_##field_name) {                             \
-    if (!HasCudaGPU())                                                       \
-      return;                                                                \
-    Workspace ws;                                                            \
-    DeviceOption option;                                                     \
-    option.set_device_type(CUDA);                                            \
-    CUDAContext context(option);                                             \
-    Blob* blobx = ws.CreateBlob("X");                                        \
-    Blob* bloby = ws.CreateBlob("Y");                                        \
-    Blob* blobz = ws.CreateBlob("Z");                                        \
-    Blob* bloby_host = ws.CreateBlob("Y_host");                              \
-                                                                             \
-    vector<int> shapex{64, 128};                                             \
-    vector<int> shapey{64};                                                  \
-    vector<int> shapez{128};                                                 \
-                                                                             \
-    auto* tensorx = blobx->GetMutable<Tensor<CUDAContext>>();                \
-    tensorx->Resize(shapex);                                                 \
-    math::Set<field_name, CUDAContext>(                                      \
-        64 * 128,                                                            \
-        (field_name)1.0,                                                     \
-        tensorx->mutable_data<field_name>(),                                 \
-        &context);                                                           \
-                                                                             \
-    auto* tensory = bloby->GetMutable<Tensor<CUDAContext>>();                \
-    tensory->Resize(shapey);                                                 \
-    math::Set<field_name, CUDAContext>(                                      \
-        64, (field_name)1.0, tensory->mutable_data<field_name>(), &context); \
-                                                                             \
-    auto* tensorz = blobz->GetMutable<Tensor<CUDAContext>>();                \
-    tensorz->Resize(shapez);                                                 \
-                                                                             \
-    math::Gemv<field_name, CUDAContext>(                                     \
-        CblasTrans,                                                          \
-        64,                                                                  \
-        128,                                                                 \
-        1.0,                                                                 \
-        tensorx->template data<field_name>(),                                \
-        tensory->mutable_data<field_name>(),                                 \
-        0.0,                                                                 \
-        tensorz->template mutable_data<field_name>(),                        \
-        &context);                                                           \
-    context.FinishDeviceComputation();                                       \
-                                                                             \
-    auto* tensory_host = bloby_host->GetMutable<Tensor<CPUContext>>();       \
-    tensory_host->CopyFrom<CUDAContext, CUDAContext>(*tensorz, &context);    \
-    context.FinishDeviceComputation();                                       \
-                                                                             \
-    for (int i = 0; i < 128; i++) {                                          \
-      EXPECT_EQ(tensory_host->data<field_name>()[i], 64.0);                  \
-    }                                                                        \
-  }
-
-TEST_GEMV_WITH_TYPE(float);
-TEST_GEMV_WITH_TYPE(double);
-
 } // namespace caffe2