Resubmission of PR 1175: fp16 BatchMatMul

Summary: PR 1175 caused a build error because GemmBatched was defined only under a specific #ifdef. This resubmission moves GemmBatched outside that #ifdef and handles the CUDA-version check inside the function itself (pre-CUDA-8 builds fall back to a per-matrix Gemm loop), so the build succeeds. The change otherwise adds fp16 support to BatchMatMul: the operator is now templated on Context only and dispatches on input type (float on CPU, float or float16 on CUDA), a new math::GemmBatched utility wraps the batched cuBLAS routines (with an optional fp32 scratch path for fp16), a TENSORCORE engine is registered for CUDA >= 9, and brew gains a batch_mat_mul helper.
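
Example (a minimal usage sketch, not part of this diff; blob names, shapes and
the net name are illustrative, and a CUDA-enabled Caffe2 build is assumed): the
new brew.batch_mat_mul helper runs BatchMatMul on fp16 inputs on the GPU, and
enable_tensor_core=True selects the TENSORCORE engine registered for CUDA >= 9:

    import numpy as np
    from caffe2.proto import caffe2_pb2
    from caffe2.python import brew, core, model_helper, workspace

    device = core.DeviceOption(caffe2_pb2.CUDA, 0)
    with core.DeviceScope(device):
        model = model_helper.ModelHelper(name="fp16_bmm_example")
        # 8 batches of (4 x 6) x (6 x 5); fp16 is only supported on CUDA
        workspace.FeedBlob("A", np.random.rand(8, 4, 6).astype(np.float16), device)
        workspace.FeedBlob("B", np.random.rand(8, 6, 5).astype(np.float16), device)
        # enable_tensor_core requires CUDA >= 9 and a Volta GPU; omit it otherwise
        brew.batch_mat_mul(model, ["A", "B"], "Y", enable_tensor_core=True)
        workspace.RunNetOnce(model.net)
        Y = workspace.FetchBlob("Y")  # shape (8, 4, 5), dtype float16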

Reviewed By: asaadaldien

Differential Revision: D5834868

fbshipit-source-id: 072a64c8f4b259ff7504104121766115b46b8aa0
diff --git a/caffe2/operators/batch_matmul_op.cc b/caffe2/operators/batch_matmul_op.cc
index c2e578d..571758d 100644
--- a/caffe2/operators/batch_matmul_op.cc
+++ b/caffe2/operators/batch_matmul_op.cc
@@ -3,7 +3,7 @@
 
 namespace caffe2 {
 
-REGISTER_CPU_OPERATOR(BatchMatMul, BatchMatMulOp<float, CPUContext>);
+REGISTER_CPU_OPERATOR(BatchMatMul, BatchMatMulOp<CPUContext>);
 
 OPERATOR_SCHEMA(BatchMatMul)
     .NumInputs(2)
@@ -55,15 +55,22 @@
       trans_b = GetArgument(Def(), "trans_b").i();
     }
 
-    const auto no_trans_arg = vector<Argument>();
-    const auto trans_a_arg = vector<Argument>{
+    auto no_trans_arg = vector<Argument>();
+    auto trans_a_arg = vector<Argument>{
         MakeArgument<int>("trans_a", 1)};
-    const auto trans_b_arg = vector<Argument>{
+    auto trans_b_arg = vector<Argument>{
         MakeArgument<int>("trans_b", 1)};
-    const auto trans_both_arg = vector<Argument>{
+    auto trans_both_arg = vector<Argument>{
         MakeArgument<int>("trans_a", 1),
         MakeArgument<int>("trans_b", 1)};
 
+    if (ArgumentHelper::HasArgument(Def(), "use_scratch")) {
+      no_trans_arg.push_back(MakeArgument<int>("use_scratch", 1));
+      trans_a_arg.push_back(MakeArgument<int>("use_scratch", 1));
+      trans_b_arg.push_back(MakeArgument<int>("use_scratch", 1));
+      trans_both_arg.push_back(MakeArgument<int>("use_scratch", 1));
+    }
+
     if (trans_a) {
       if (trans_b) {
         // A'B':
diff --git a/caffe2/operators/batch_matmul_op.cu b/caffe2/operators/batch_matmul_op.cu
index 81111be..2eee5f7 100644
--- a/caffe2/operators/batch_matmul_op.cu
+++ b/caffe2/operators/batch_matmul_op.cu
@@ -4,84 +4,24 @@
 
 namespace caffe2 {
 
-#if __CUDACC_VER_MAJOR__ >= 8
-// CUDA 8 introduced a cublasSgemmStridedBatched function that allows us
-// to carry out batched sgemm more efficiently. This is the specialized
-// version that implements this.
 template <>
-bool BatchMatMulOp<float, CUDAContext, DefaultEngine>::RunOnDevice() {
-  const auto& A = Input(0);
-  const auto& B = Input(1);
-  auto* Y = Output(0);
-
-  CAFFE_ENFORCE_EQ(A.ndim(), 3);
-  CAFFE_ENFORCE_EQ(B.ndim(), 3);
-  CAFFE_ENFORCE_EQ(A.dim32(0), B.dim32(0));
-
-  int a_dim0, a_dim1, b_dim0, b_dim1;
-
-  if (trans_a_) {
-    a_dim0 = A.dim32(2);
-    a_dim1 = A.dim32(1);
-  } else {
-    a_dim0 = A.dim32(1);
-    a_dim1 = A.dim32(2);
-  }
-
-  if (trans_b_) {
-    b_dim0 = B.dim32(2);
-    b_dim1 = B.dim32(1);
-  } else {
-    b_dim0 = B.dim32(1);
-    b_dim1 = B.dim32(2);
-  }
-
-  // Error checking
-  CAFFE_ENFORCE(
-      a_dim1 == b_dim0,
-      "Dimension mismatch: ",
-      trans_a_ ? "trans(A): " : "A: ",
-      a_dim0,
-      " ",
-      a_dim1,
-      trans_b_ ? ", trans(B): " : ", B: ",
-      b_dim0,
-      " ",
-      b_dim1);
-
-  Y->Resize(A.dim(0), a_dim0, b_dim1);
-
-  if (!A.dim(0)) {
-    Y->mutable_data<float>(); // create output tensor
-    return true;
-  }
-
-  float alpha = 1;
-  float beta = 0;
-
-  CUBLAS_ENFORCE(cublasSgemmStridedBatched(
-      context_.cublas_handle(),
-      trans_b_ ? CUBLAS_OP_T : CUBLAS_OP_N,
-      trans_a_ ? CUBLAS_OP_T : CUBLAS_OP_N,
-      b_dim1,
-      a_dim0,
-      a_dim1,
-      &alpha,
-      B.data<float>(),
-      trans_b_ ? a_dim1 : b_dim1, // ldb
-      B.size() / B.dim(0), // b stride
-      A.data<float>(),
-      trans_a_ ? a_dim0 : a_dim1, // lda
-      A.size() / A.dim(0), // a stride
-      &beta,
-      Y->mutable_data<float>(),
-      b_dim1,
-      a_dim0 * b_dim1, // y stride
-      A.dim32(0) // batch count
-      ));
-  return true;
+bool BatchMatMulOp<CUDAContext, DefaultEngine>::RunOnDevice() {
+    return DispatchHelper<TensorTypes<float, float16>>::call(this, Input(0));
 }
-#endif // __CUDACC_VER_MAJOR__ >= 8
 
-REGISTER_CUDA_OPERATOR(BatchMatMul, BatchMatMulOp<float, CUDAContext>);
+REGISTER_CUDA_OPERATOR(BatchMatMul, BatchMatMulOp<CUDAContext>);
+
+#if CUDA_VERSION >= 9000
+
+template <>
+bool BatchMatMulOp<CUDAContext, TensorCoreEngine>::RunOnDevice() {
+    return DispatchHelper<TensorTypes<float, float16>>::call(this, Input(0));
+}
+
+REGISTER_CUDA_OPERATOR_WITH_ENGINE(
+    BatchMatMul,
+    TENSORCORE,
+    BatchMatMulOp<CUDAContext, TensorCoreEngine>);
+#endif
+
 } // namespace caffe2
diff --git a/caffe2/operators/batch_matmul_op.h b/caffe2/operators/batch_matmul_op.h
index de080c5..9b80acb 100644
--- a/caffe2/operators/batch_matmul_op.h
+++ b/caffe2/operators/batch_matmul_op.h
@@ -7,17 +7,26 @@
 
 namespace caffe2 {
 
-template <typename T, class Context, class Engine = DefaultEngine>
+template <class Context, class Engine = DefaultEngine>
 class BatchMatMulOp final : public Operator<Context> {
  public:
   USE_OPERATOR_CONTEXT_FUNCTIONS;
   BatchMatMulOp(const OperatorDef& operator_def, Workspace* ws)
       : Operator<Context>(operator_def, ws),
         trans_a_(OperatorBase::GetSingleArgument<int>("trans_a", 0)),
-        trans_b_(OperatorBase::GetSingleArgument<int>("trans_b", 0)) {}
+        trans_b_(OperatorBase::GetSingleArgument<int>("trans_b", 0)),
+        use_scratch_(OperatorBase::GetSingleArgument<int>("use_scratch", 0)) {
+    if (use_scratch_)
+      scratch_ = std::make_shared<Tensor<Context> >();
+  }
   ~BatchMatMulOp() {}
 
   bool RunOnDevice() override {
+    return DispatchHelper<TensorTypes<float>>::call(this, Input(0));
+  }
+
+  template <typename T>
+  bool DoRunWithType() {
     const auto& A = Input(0);
     const auto& B = Input(1);
     auto* Y = Output(0);
@@ -65,29 +74,32 @@
     }
 
     // Y = A * B
-    auto a_offset = A.size() / A.dim(0);
-    auto b_offset = B.size() / B.dim(0);
-    auto y_offset = a_dim0 * b_dim1;
-    for (int i = 0; i < A.dim32(0); ++i) {
-      math::Gemm<T, Context, Engine>(
-          trans_a_ ? CblasTrans : CblasNoTrans,
-          trans_b_ ? CblasTrans : CblasNoTrans,
-          a_dim0,
-          b_dim1,
-          a_dim1,
-          1,
-          A.template data<T>() + a_offset * i,
-          B.template data<T>() + b_offset * i,
-          0,
-          Y->template mutable_data<T>() + y_offset * i,
-          &context_);
-    }
+    math::GemmBatched<T, Context, Engine>(
+        trans_a_ ? CblasTrans : CblasNoTrans,
+        trans_b_ ? CblasTrans : CblasNoTrans,
+        A.size(),
+        A.dim32(0),
+        B.size(),
+        B.dim32(0),
+        a_dim0, // M
+        b_dim1, // N
+        a_dim1, // K
+        1,
+        A.template data<T>(),
+        B.template data<T>(),
+        0,
+        Y->template mutable_data<T>(),
+        &context_,
+        use_scratch_ ? scratch_.get() : nullptr);
     return true;
   }
 
  protected:
   bool trans_a_;
   bool trans_b_;
+
+  bool use_scratch_;
+  std::shared_ptr<Tensor<Context> > scratch_;
 };
 
 } // namespace caffe2
diff --git a/caffe2/python/attention.py b/caffe2/python/attention.py
index 961f001..da9066b 100644
--- a/caffe2/python/attention.py
+++ b/caffe2/python/attention.py
@@ -27,7 +27,8 @@
     scope,
 ):
     # [batch_size, encoder_output_dim, 1]
-    attention_weighted_encoder_context = model.net.BatchMatMul(
+    attention_weighted_encoder_context = brew.batch_mat_mul(
+        model,
         [encoder_outputs_transposed, attention_weights_3d],
         s(scope, 'attention_weighted_encoder_context'),
     )
diff --git a/caffe2/python/brew.py b/caffe2/python/brew.py
index 7fd610e..ea9e215 100644
--- a/caffe2/python/brew.py
+++ b/caffe2/python/brew.py
@@ -61,6 +61,7 @@
         'add_weight_decay': add_weight_decay,
         'elementwise_linear': elementwise_linear,
         'layer_norm': layer_norm,
+        'batch_mat_mul': batch_mat_mul,
     }
 
     def __init__(self, wrapped):
diff --git a/caffe2/python/helpers/algebra.py b/caffe2/python/helpers/algebra.py
index 8531c3d..6bc3779 100644
--- a/caffe2/python/helpers/algebra.py
+++ b/caffe2/python/helpers/algebra.py
@@ -16,3 +16,11 @@
 def sum(model, blob_in, blob_out, **kwargs):
     """Sum"""
     return model.net.Sum(blob_in, blob_out, **kwargs)
+
+
+def batch_mat_mul(model, blob_in, blob_out,
+                  enable_tensor_core=False, **kwargs):
+    if enable_tensor_core:
+        kwargs['engine'] = 'TENSORCORE'
+
+    return model.net.BatchMatMul(blob_in, blob_out, **kwargs)
diff --git a/caffe2/python/operator_test/matmul_op_test.py b/caffe2/python/operator_test/matmul_op_test.py
index 32696ae..d3c119b 100644
--- a/caffe2/python/operator_test/matmul_op_test.py
+++ b/caffe2/python/operator_test/matmul_op_test.py
@@ -5,9 +5,10 @@
 
 import numpy as np
 
-from hypothesis import given
+from hypothesis import assume, given, settings
 import hypothesis.strategies as st
 
+from caffe2.proto import caffe2_pb2
 from caffe2.python import core
 import caffe2.python.hypothesis_test_util as hu
 
@@ -49,19 +50,26 @@
 
 
 class TestBatchMatMul(hu.HypothesisTestCase):
+    @settings(max_examples=30)
     @given(C=st.integers(min_value=1, max_value=10),
            M=st.integers(min_value=1, max_value=10),
            K=st.integers(min_value=1, max_value=10),
            N=st.integers(min_value=1, max_value=10),
            trans_a=st.booleans(),
            trans_b=st.booleans(),
+           dtype=st.sampled_from([np.float32, np.float16]),
            **hu.gcs)
-    def test_batch_matmul(self, C, M, K, N, trans_a, trans_b, gc, dc):
-        X = np.random.rand(C, M, K).astype(np.float32) - 0.5
+    def test_batch_matmul(self, C, M, K, N, trans_a, trans_b, dtype, gc, dc):
+        if dtype == np.float16:
+            # fp16 is only supported with CUDA
+            assume(gc.device_type == caffe2_pb2.CUDA)
+            dc = [d for d in dc if d.device_type == caffe2_pb2.CUDA]
+
+        X = np.random.rand(C, M, K).astype(dtype) - 0.5
         if trans_a:
             X = X.swapaxes(1, 2)
 
-        Y = np.random.rand(C, K, N).astype(np.float32) - 0.5
+        Y = np.random.rand(C, K, N).astype(dtype) - 0.5
         if trans_b:
             Y = Y.swapaxes(1, 2)
 
@@ -82,10 +90,16 @@
                                    matmul_ref)
         # Check over multiple devices
         self.assertDeviceChecks(dc, op, [X, Y], [0])
+
+        kwargs = {}
+        if dtype == np.float16:
+            kwargs['threshold'] = 0.75  # default is 0.005
+
         # Gradient check wrt X
-        self.assertGradientChecks(gc, op, [X, Y], 0, [0])
+        self.assertGradientChecks(gc, op, [X, Y], 0, [0], **kwargs)
         # Gradient check wrt Y
-        self.assertGradientChecks(gc, op, [X, Y], 1, [0])
+        self.assertGradientChecks(gc, op, [X, Y], 1, [0], **kwargs)
+
 
 if __name__ == "__main__":
     import unittest
diff --git a/caffe2/utils/math.h b/caffe2/utils/math.h
index 6c352dc..0576b18 100644
--- a/caffe2/utils/math.h
+++ b/caffe2/utils/math.h
@@ -218,6 +218,27 @@
     const int ldc,
     Context* context);
 
+// GemmBatched provides a simple abstraction over batched GEMM library routines
+template <typename T, class Context, class Engine = DefaultEngine>
+void GemmBatched(
+    const CBLAS_TRANSPOSE TransA,
+    const CBLAS_TRANSPOSE TransB,
+    const int A_size,
+    const int A_batches,
+    const int B_size,
+    const int B_batches,
+    const int M,
+    const int N,
+    const int K,
+    const float alpha,
+    const T* A,
+    const T* B,
+    const float beta,
+    T* C,
+    Context* context,
+    Tensor<Context>* scratch = nullptr,
+    TensorProto::DataType math_type = TensorProto_DataType_FLOAT);
+
 // Gemv always takes in a M*N matrix A, and depending on whether we set TransA
 // to Trans, the output is:
 // CblasNoTrans: x is an N dim vector and y is an M dim vector.
diff --git a/caffe2/utils/math_cpu.cc b/caffe2/utils/math_cpu.cc
index 99d339f..009d507 100644
--- a/caffe2/utils/math_cpu.cc
+++ b/caffe2/utils/math_cpu.cc
@@ -399,6 +399,45 @@
 
 #endif  // CAFFE2_USE_EIGEN_FOR_BLAS
 
+template <>
+void GemmBatched<float, CPUContext>(
+    const CBLAS_TRANSPOSE TransA,
+    const CBLAS_TRANSPOSE TransB,
+    const int A_size,
+    const int A_batches,
+    const int B_size,
+    const int B_batches,
+    const int M,
+    const int N,
+    const int K,
+    const float alpha,
+    const float* A,
+    const float* B,
+    const float beta,
+    float* C,
+    CPUContext* context,
+    Tensor<CPUContext>*, /* scratch */
+    TensorProto::DataType /* math_type */) {
+
+  auto a_offset = A_size / A_batches;
+  auto b_offset = B_size / B_batches;
+  auto y_offset = M * N;
+  // loop over matrices in the batch
+  for (int i = 0; i < A_batches; ++i) {
+    math::Gemm<float, CPUContext>(
+        TransA,
+        TransB,
+        M,
+        N,
+        K,
+        alpha,
+        A + a_offset * i,
+        B + b_offset * i,
+        beta,
+        C + y_offset * i,
+        context);
+  }
+}
 
 ////////////////////////////////////////////////////////////////////////////////
 // MKL VML alternatives.
diff --git a/caffe2/utils/math_gpu.cu b/caffe2/utils/math_gpu.cu
index 974639e..fabd971 100644
--- a/caffe2/utils/math_gpu.cu
+++ b/caffe2/utils/math_gpu.cu
@@ -261,6 +261,240 @@
   }
 }
 
+template <>
+void GemmBatched<float, CUDAContext>(
+    const CBLAS_TRANSPOSE TransA,
+    const CBLAS_TRANSPOSE TransB,
+    const int A_size,
+    const int A_batches,
+    const int B_size,
+    const int B_batches,
+    const int M,
+    const int N,
+    const int K,
+    const float alpha,
+    const float* A,
+    const float* B,
+    const float beta,
+    float* C,
+    CUDAContext* context,
+    Tensor<CUDAContext>* scratch,
+    TensorProto::DataType math_type) {
+
+#if __CUDACC_VER_MAJOR__ < 8
+  auto a_offset = A_size / A_batches;
+  auto b_offset = B_size / B_batches;
+  auto y_offset = M * N;
+  // loop over matrices in the batch
+  for (int i = 0; i < A_batches; ++i) {
+    math::Gemm<float, CUDAContext>(
+        TransA,
+        TransB,
+        M,
+        N,
+        K,
+        alpha,
+        A + a_offset * i,
+        B + b_offset * i,
+        beta,
+        C + y_offset * i,
+        context);
+  }
+#else
+  // Note that cublas follows fortran order, so the order is different from
+  // the cblas convention.
+  int lda = (TransA == CblasNoTrans) ? K : M;
+  int ldb = (TransB == CblasNoTrans) ? N : K;
+
+  cublasOperation_t cuTransA =
+      (TransA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
+  cublasOperation_t cuTransB =
+      (TransB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
+  CUBLAS_ENFORCE(cublasSgemmStridedBatched(
+        context->cublas_handle(),
+        cuTransB,
+        cuTransA,
+        N,
+        M,
+        K,
+        &alpha,
+        B,
+        ldb,
+        B_size / B_batches, // B stride
+        A,
+        lda,
+        A_size / A_batches, // A stride
+        &beta,
+        C,
+        N,
+        M*N,                // C stride
+        A_batches));
+#endif
+}
+
+namespace {
+
+__global__ void FloatToHalfKernel(const int N, const float* X, half* Y) {
+  CUDA_1D_KERNEL_LOOP(i, N) {
+    Y[i] = __float2half(X[i]);
+  }
+}
+
+__global__ void HalfToFloatKernel(const int N, const half* X, float* Y) {
+  CUDA_1D_KERNEL_LOOP(i, N) {
+    Y[i] = __half2float(X[i]);
+  }
+}
+
+} // namespace
+
+template <>
+void GemmBatched<float16, CUDAContext>(
+    const CBLAS_TRANSPOSE TransA,
+    const CBLAS_TRANSPOSE TransB,
+    const int A_size,
+    const int A_batches,
+    const int B_size,
+    const int B_batches,
+    const int M,
+    const int N,
+    const int K,
+    const float alpha,
+    const float16* A,
+    const float16* B,
+    const float beta,
+    float16* C,
+    CUDAContext* context,
+    Tensor<CUDAContext>* scratch,
+    TensorProto::DataType math_type) {
+
+#if __CUDACC_VER_MAJOR__ < 8
+  auto a_offset = A_size / A_batches;
+  auto b_offset = B_size / B_batches;
+  auto y_offset = M * N;
+  // loop over matrices in the batch
+  for (int i = 0; i < A_batches; ++i) {
+    math::Gemm<float16, CUDAContext>(
+        TransA,
+        TransB,
+        M,
+        N,
+        K,
+        alpha,
+        A + a_offset * i,
+        B + b_offset * i,
+        beta,
+        C + y_offset * i,
+        context);
+  }
+#else
+  // 3 options:
+  // 1) scratch != null = cast to fp32, SgemmStridedBatched, cast result to fp16
+  // 2) math_type == FLOAT, scratch == nullptr = looped SgemmEx
+  // 3) math_type == FLOAT16, scratch == nullptr = batched Hgemm
+
+  if (scratch != nullptr) {
+    // cast, cublasSgemmStridedBatched, cast
+    size_t in_elems = A_size + B_size;
+    size_t out_elems = A_batches*M*N;
+
+    scratch->Resize(in_elems+out_elems);
+    float* scratch_ptr = scratch->mutable_data<float>();
+
+    float* A_fp32 = scratch_ptr;
+    float* B_fp32 = scratch_ptr + A_size;
+    float* C_fp32 = scratch_ptr + A_size + B_size;
+
+    // cast A, B into fp32
+    HalfToFloatKernel<<<CAFFE_GET_BLOCKS(A_size),
+                        CAFFE_CUDA_NUM_THREADS,
+                        0,
+                        context->cuda_stream()>>>(A_size, (half*)A, A_fp32);
+    HalfToFloatKernel<<<CAFFE_GET_BLOCKS(B_size),
+                        CAFFE_CUDA_NUM_THREADS,
+                        0,
+                        context->cuda_stream()>>>(B_size, (half*)B, B_fp32);
+
+    // run fp32 batched Gemm
+    GemmBatched<float,CUDAContext>(
+        TransA,
+        TransB,
+        A_size,
+        A_batches,
+        B_size,
+        B_batches,
+        M,
+        N,
+        K,
+        alpha,
+        A_fp32,
+        B_fp32,
+        beta,
+        C_fp32,
+        context);
+
+    // cast result back to fp16
+    FloatToHalfKernel<<<CAFFE_GET_BLOCKS(A_batches*M*N),
+                        CAFFE_CUDA_NUM_THREADS,
+                        0,
+                        context->cuda_stream()>>>(A_batches*M*N, C_fp32, (half*)C);
+  } else {
+    if (math_type == TensorProto_DataType_FLOAT) {
+      auto a_offset = A_size / A_batches;
+      auto b_offset = B_size / B_batches;
+      auto y_offset = M * N;
+      // loop over matrices in the batch
+      for (int i = 0; i < A_batches; ++i) {
+        math::Gemm<float16, CUDAContext>(
+            TransA,
+            TransB,
+            M,
+            N,
+            K,
+            alpha,
+            A + a_offset * i,
+            B + b_offset * i,
+            beta,
+            C + y_offset * i,
+            context);
+      }
+    } else if (math_type == TensorProto_DataType_FLOAT16) {
+      // Note that cublas follows fortran order, so the order is different from
+      // the cblas convention.
+      int lda = (TransA == CblasNoTrans) ? K : M;
+      int ldb = (TransB == CblasNoTrans) ? N : K;
+      cublasOperation_t cuTransA =
+          (TransA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
+      cublasOperation_t cuTransB =
+          (TransB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
+
+      // convert alpha, beta from float -> __half
+      auto alpha_fp16 = convert::floatToHalf(alpha);
+      auto beta_fp16 = convert::floatToHalf(beta);
+      CUBLAS_ENFORCE(cublasHgemmStridedBatched(
+            context->cublas_handle(),
+            cuTransB,
+            cuTransA,
+            N,
+            M,
+            K,
+            &alpha_fp16,
+            (const __half*)B,
+            ldb,
+            B_size / B_batches,
+            (const __half*)A,
+            lda,
+            A_size / A_batches,
+            &beta_fp16,
+            (__half*)C,
+            N,
+            M*N,
+            A_batches));
+    }
+  }
+#endif
+}
+
 #if CUDA_VERSION >= 9000
 
 // No change, but required. Defer to default CUDA engine
@@ -351,6 +585,84 @@
   }
 }
 
+template <>
+void GemmBatched<float, CUDAContext, TensorCoreEngine>(
+    const CBLAS_TRANSPOSE TransA,
+    const CBLAS_TRANSPOSE TransB,
+    const int A_size,
+    const int A_batches,
+    const int B_size,
+    const int B_batches,
+    const int M,
+    const int N,
+    const int K,
+    const float alpha,
+    const float* A,
+    const float* B,
+    const float beta,
+    float* C,
+    CUDAContext* context,
+    Tensor<CUDAContext>* scratch,
+    TensorProto::DataType math_type) {
+  return GemmBatched<float, CUDAContext, DefaultEngine>(
+      TransA,
+      TransB,
+      A_size,
+      A_batches,
+      B_size,
+      B_batches,
+      M,
+      N,
+      K,
+      alpha,
+      A,
+      B,
+      beta,
+      C,
+      context,
+      scratch,
+      math_type);
+}
+
+template <>
+void GemmBatched<float16, CUDAContext, TensorCoreEngine>(
+    const CBLAS_TRANSPOSE TransA,
+    const CBLAS_TRANSPOSE TransB,
+    const int A_size,
+    const int A_batches,
+    const int B_size,
+    const int B_batches,
+    const int M,
+    const int N,
+    const int K,
+    const float alpha,
+    const float16* A,
+    const float16* B,
+    const float beta,
+    float16* C,
+    CUDAContext* context,
+    Tensor<CUDAContext>* scratch,
+    TensorProto::DataType math_type) {
+  return GemmBatched<float16, CUDAContext, DefaultEngine>(
+      TransA,
+      TransB,
+      A_size,
+      A_batches,
+      B_size,
+      B_batches,
+      M,
+      N,
+      K,
+      alpha,
+      A,
+      B,
+      beta,
+      C,
+      context,
+      scratch,
+      math_type);
+}
+
 #endif // CUDA_VERSION >= 9000
 
 template <>
@@ -434,7 +746,9 @@
   for (int j = 0; j < batch; j++) {
     const T* x = first + j * stripe;
     CUDA_1D_KERNEL_LOOP(i, N) {
-      Y[i] += x[i];
+      float tmpY = convert::To<T, float>(Y[i]);
+      tmpY += convert::To<T,float>(x[i]);
+      Y[i] = convert::To<float,T>(tmpY);
     }
   }
 }
@@ -457,7 +771,7 @@
   }
 
 CAFFE2_SPECIALIZED_CUDA_ADD_STRIPED_BATCH(float);
-CAFFE2_SPECIALIZED_CUDA_ADD_STRIPED_BATCH(double);
+CAFFE2_SPECIALIZED_CUDA_ADD_STRIPED_BATCH(float16);
 #undef CAFFE2_SPECIALIZED_CUDA_ADD_STRIPED_BATCH
 
 template <>