Resubmission of PR 1175: fp16 BatchMatMul
Summary: PR 1175 caused a build error because GemmBatched was only defined under a CUDA-version-specific #ifdef. GemmBatched is now defined unconditionally, with a looped-Gemm fallback when cublasSgemmStridedBatched is unavailable (CUDA < 8), so the build succeeds.
Reviewed By: asaadaldien
Differential Revision: D5834868
fbshipit-source-id: 072a64c8f4b259ff7504104121766115b46b8aa0
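
A minimal sketch of how the new fp16 path might be exercised from Python (blob names and shapes are illustrative, and a CUDA build is assumed); the operator now dispatches on the input dtype, so float16 inputs take the half-precision GemmBatched path added in math_gpu.cu:

    import numpy as np
    from caffe2.proto import caffe2_pb2
    from caffe2.python import core, workspace

    # Run BatchMatMul on fp16 inputs on the GPU (illustrative blob names).
    dev = core.DeviceOption(caffe2_pb2.CUDA, 0)
    A = (np.random.rand(4, 3, 5) - 0.5).astype(np.float16)
    B = (np.random.rand(4, 5, 2) - 0.5).astype(np.float16)
    workspace.FeedBlob("A", A, device_option=dev)
    workspace.FeedBlob("B", B, device_option=dev)
    op = core.CreateOperator("BatchMatMul", ["A", "B"], ["C"], device_option=dev)
    workspace.RunOperatorOnce(op)
    C = workspace.FetchBlob("C")  # shape (4, 3, 2), dtype float16
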
diff --git a/caffe2/operators/batch_matmul_op.cc b/caffe2/operators/batch_matmul_op.cc
index c2e578d..571758d 100644
--- a/caffe2/operators/batch_matmul_op.cc
+++ b/caffe2/operators/batch_matmul_op.cc
@@ -3,7 +3,7 @@
namespace caffe2 {
-REGISTER_CPU_OPERATOR(BatchMatMul, BatchMatMulOp<float, CPUContext>);
+REGISTER_CPU_OPERATOR(BatchMatMul, BatchMatMulOp<CPUContext>);
OPERATOR_SCHEMA(BatchMatMul)
.NumInputs(2)
@@ -55,15 +55,22 @@
trans_b = GetArgument(Def(), "trans_b").i();
}
- const auto no_trans_arg = vector<Argument>();
- const auto trans_a_arg = vector<Argument>{
+ auto no_trans_arg = vector<Argument>();
+ auto trans_a_arg = vector<Argument>{
MakeArgument<int>("trans_a", 1)};
- const auto trans_b_arg = vector<Argument>{
+ auto trans_b_arg = vector<Argument>{
MakeArgument<int>("trans_b", 1)};
- const auto trans_both_arg = vector<Argument>{
+ auto trans_both_arg = vector<Argument>{
MakeArgument<int>("trans_a", 1),
MakeArgument<int>("trans_b", 1)};
+ if (ArgumentHelper::HasArgument(Def(), "use_scratch")) {
+ no_trans_arg.push_back(MakeArgument<int>("use_scratch", 1));
+ trans_a_arg.push_back(MakeArgument<int>("use_scratch", 1));
+ trans_b_arg.push_back(MakeArgument<int>("use_scratch", 1));
+ trans_both_arg.push_back(MakeArgument<int>("use_scratch", 1));
+ }
+
if (trans_a) {
if (trans_b) {
// A'B':
diff --git a/caffe2/operators/batch_matmul_op.cu b/caffe2/operators/batch_matmul_op.cu
index 81111be..2eee5f7 100644
--- a/caffe2/operators/batch_matmul_op.cu
+++ b/caffe2/operators/batch_matmul_op.cu
@@ -4,84 +4,24 @@
namespace caffe2 {
-#if __CUDACC_VER_MAJOR__ >= 8
-// CUDA 8 introduced a cublasSgemmStridedBatched function that allows us
-// to carry out batched sgemm more efficiently. This is the specialized
-// version that implements this.
template <>
-bool BatchMatMulOp<float, CUDAContext, DefaultEngine>::RunOnDevice() {
- const auto& A = Input(0);
- const auto& B = Input(1);
- auto* Y = Output(0);
-
- CAFFE_ENFORCE_EQ(A.ndim(), 3);
- CAFFE_ENFORCE_EQ(B.ndim(), 3);
- CAFFE_ENFORCE_EQ(A.dim32(0), B.dim32(0));
-
- int a_dim0, a_dim1, b_dim0, b_dim1;
-
- if (trans_a_) {
- a_dim0 = A.dim32(2);
- a_dim1 = A.dim32(1);
- } else {
- a_dim0 = A.dim32(1);
- a_dim1 = A.dim32(2);
- }
-
- if (trans_b_) {
- b_dim0 = B.dim32(2);
- b_dim1 = B.dim32(1);
- } else {
- b_dim0 = B.dim32(1);
- b_dim1 = B.dim32(2);
- }
-
- // Error checking
- CAFFE_ENFORCE(
- a_dim1 == b_dim0,
- "Dimension mismatch: ",
- trans_a_ ? "trans(A): " : "A: ",
- a_dim0,
- " ",
- a_dim1,
- trans_b_ ? ", trans(B): " : ", B: ",
- b_dim0,
- " ",
- b_dim1);
-
- Y->Resize(A.dim(0), a_dim0, b_dim1);
-
- if (!A.dim(0)) {
- Y->mutable_data<float>(); // create output tensor
- return true;
- }
-
- float alpha = 1;
- float beta = 0;
-
- CUBLAS_ENFORCE(cublasSgemmStridedBatched(
- context_.cublas_handle(),
- trans_b_ ? CUBLAS_OP_T : CUBLAS_OP_N,
- trans_a_ ? CUBLAS_OP_T : CUBLAS_OP_N,
- b_dim1,
- a_dim0,
- a_dim1,
- &alpha,
- B.data<float>(),
- trans_b_ ? a_dim1 : b_dim1, // ldb
- B.size() / B.dim(0), // b stride
- A.data<float>(),
- trans_a_ ? a_dim0 : a_dim1, // lda
- A.size() / A.dim(0), // a stride
- &beta,
- Y->mutable_data<float>(),
- b_dim1,
- a_dim0 * b_dim1, // y stride
- A.dim32(0) // batch count
- ));
- return true;
+bool BatchMatMulOp<CUDAContext, DefaultEngine>::RunOnDevice() {
+ return DispatchHelper<TensorTypes<float, float16>>::call(this, Input(0));
}
-#endif // __CUDACC_VER_MAJOR__ >= 8
-REGISTER_CUDA_OPERATOR(BatchMatMul, BatchMatMulOp<float, CUDAContext>);
+REGISTER_CUDA_OPERATOR(BatchMatMul, BatchMatMulOp<CUDAContext>);
+
+#if CUDA_VERSION >= 9000
+
+template <>
+bool BatchMatMulOp<CUDAContext, TensorCoreEngine>::RunOnDevice() {
+ return DispatchHelper<TensorTypes<float, float16>>::call(this, Input(0));
+}
+
+REGISTER_CUDA_OPERATOR_WITH_ENGINE(
+ BatchMatMul,
+ TENSORCORE,
+ BatchMatMulOp<CUDAContext, TensorCoreEngine>);
+#endif
+
} // namespace caffe2
diff --git a/caffe2/operators/batch_matmul_op.h b/caffe2/operators/batch_matmul_op.h
index de080c5..9b80acb 100644
--- a/caffe2/operators/batch_matmul_op.h
+++ b/caffe2/operators/batch_matmul_op.h
@@ -7,17 +7,26 @@
namespace caffe2 {
-template <typename T, class Context, class Engine = DefaultEngine>
+template <class Context, class Engine = DefaultEngine>
class BatchMatMulOp final : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
BatchMatMulOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws),
trans_a_(OperatorBase::GetSingleArgument<int>("trans_a", 0)),
- trans_b_(OperatorBase::GetSingleArgument<int>("trans_b", 0)) {}
+ trans_b_(OperatorBase::GetSingleArgument<int>("trans_b", 0)),
+ use_scratch_(OperatorBase::GetSingleArgument<int>("use_scratch", 0)) {
+ if (use_scratch_)
+ scratch_ = std::make_shared<Tensor<Context> >();
+ }
~BatchMatMulOp() {}
bool RunOnDevice() override {
+ return DispatchHelper<TensorTypes<float>>::call(this, Input(0));
+ }
+
+ template <typename T>
+ bool DoRunWithType() {
const auto& A = Input(0);
const auto& B = Input(1);
auto* Y = Output(0);
@@ -65,29 +74,32 @@
}
// Y = A * B
- auto a_offset = A.size() / A.dim(0);
- auto b_offset = B.size() / B.dim(0);
- auto y_offset = a_dim0 * b_dim1;
- for (int i = 0; i < A.dim32(0); ++i) {
- math::Gemm<T, Context, Engine>(
- trans_a_ ? CblasTrans : CblasNoTrans,
- trans_b_ ? CblasTrans : CblasNoTrans,
- a_dim0,
- b_dim1,
- a_dim1,
- 1,
- A.template data<T>() + a_offset * i,
- B.template data<T>() + b_offset * i,
- 0,
- Y->template mutable_data<T>() + y_offset * i,
- &context_);
- }
+ math::GemmBatched<T, Context, Engine>(
+ trans_a_ ? CblasTrans : CblasNoTrans,
+ trans_b_ ? CblasTrans : CblasNoTrans,
+ A.size(),
+ A.dim32(0),
+ B.size(),
+ B.dim32(0),
+ a_dim0, // M
+ b_dim1, // N
+ a_dim1, // K
+ 1,
+ A.template data<T>(),
+ B.template data<T>(),
+ 0,
+ Y->template mutable_data<T>(),
+ &context_,
+ use_scratch_ ? scratch_.get() : nullptr);
return true;
}
protected:
bool trans_a_;
bool trans_b_;
+
+ bool use_scratch_;
+ std::shared_ptr<Tensor<Context> > scratch_;
};
} // namespace caffe2
diff --git a/caffe2/python/attention.py b/caffe2/python/attention.py
index 961f001..da9066b 100644
--- a/caffe2/python/attention.py
+++ b/caffe2/python/attention.py
@@ -27,7 +27,8 @@
scope,
):
# [batch_size, encoder_output_dim, 1]
- attention_weighted_encoder_context = model.net.BatchMatMul(
+ attention_weighted_encoder_context = brew.batch_mat_mul(
+ model,
[encoder_outputs_transposed, attention_weights_3d],
s(scope, 'attention_weighted_encoder_context'),
)
diff --git a/caffe2/python/brew.py b/caffe2/python/brew.py
index 7fd610e..ea9e215 100644
--- a/caffe2/python/brew.py
+++ b/caffe2/python/brew.py
@@ -61,6 +61,7 @@
'add_weight_decay': add_weight_decay,
'elementwise_linear': elementwise_linear,
'layer_norm': layer_norm,
+ 'batch_mat_mul' : batch_mat_mul,
}
def __init__(self, wrapped):
diff --git a/caffe2/python/helpers/algebra.py b/caffe2/python/helpers/algebra.py
index 8531c3d..6bc3779 100644
--- a/caffe2/python/helpers/algebra.py
+++ b/caffe2/python/helpers/algebra.py
@@ -16,3 +16,11 @@
def sum(model, blob_in, blob_out, **kwargs):
"""Sum"""
return model.net.Sum(blob_in, blob_out, **kwargs)
+
+
+def batch_mat_mul(model, blob_in, blob_out,
+ enable_tensor_core=False, **kwargs):
+ if enable_tensor_core:
+ kwargs['engine'] = 'TENSORCORE'
+
+ return model.net.BatchMatMul(blob_in, blob_out, **kwargs)
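
For reference, a hedged sketch of calling the new brew helper (model and blob names are illustrative); enable_tensor_core simply routes the op to the TENSORCORE engine registered above, which requires CUDA >= 9 and tensor-core-capable hardware:

    from caffe2.python import brew, model_helper

    model = model_helper.ModelHelper(name="bmm_example")  # illustrative name
    # Plain batched matmul, and a variant that selects the TENSORCORE engine.
    c = brew.batch_mat_mul(model, ["a", "b"], "c")
    c_tc = brew.batch_mat_mul(model, ["a", "b"], "c_tc", enable_tensor_core=True)
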
diff --git a/caffe2/python/operator_test/matmul_op_test.py b/caffe2/python/operator_test/matmul_op_test.py
index 32696ae..d3c119b 100644
--- a/caffe2/python/operator_test/matmul_op_test.py
+++ b/caffe2/python/operator_test/matmul_op_test.py
@@ -5,9 +5,10 @@
import numpy as np
-from hypothesis import given
+from hypothesis import assume, given, settings
import hypothesis.strategies as st
+from caffe2.proto import caffe2_pb2
from caffe2.python import core
import caffe2.python.hypothesis_test_util as hu
@@ -49,19 +50,26 @@
class TestBatchMatMul(hu.HypothesisTestCase):
+ @settings(max_examples=30)
@given(C=st.integers(min_value=1, max_value=10),
M=st.integers(min_value=1, max_value=10),
K=st.integers(min_value=1, max_value=10),
N=st.integers(min_value=1, max_value=10),
trans_a=st.booleans(),
trans_b=st.booleans(),
+ dtype=st.sampled_from([np.float32, np.float16]),
**hu.gcs)
- def test_batch_matmul(self, C, M, K, N, trans_a, trans_b, gc, dc):
- X = np.random.rand(C, M, K).astype(np.float32) - 0.5
+ def test_batch_matmul(self, C, M, K, N, trans_a, trans_b, dtype, gc, dc):
+ if dtype == np.float16:
+ # fp16 is only supported with CUDA
+ assume(gc.device_type == caffe2_pb2.CUDA)
+ dc = [d for d in dc if d.device_type == caffe2_pb2.CUDA]
+
+ X = np.random.rand(C, M, K).astype(dtype) - 0.5
if trans_a:
X = X.swapaxes(1, 2)
- Y = np.random.rand(C, K, N).astype(np.float32) - 0.5
+ Y = np.random.rand(C, K, N).astype(dtype) - 0.5
if trans_b:
Y = Y.swapaxes(1, 2)
@@ -82,10 +90,16 @@
matmul_ref)
# Check over multiple devices
self.assertDeviceChecks(dc, op, [X, Y], [0])
+
+ kwargs = {}
+ if dtype == np.float16:
+ kwargs['threshold'] = 0.75 # default is 0.005
+
# Gradient check wrt X
- self.assertGradientChecks(gc, op, [X, Y], 0, [0])
+ self.assertGradientChecks(gc, op, [X, Y], 0, [0], **kwargs)
# Gradient check wrt Y
- self.assertGradientChecks(gc, op, [X, Y], 1, [0])
+ self.assertGradientChecks(gc, op, [X, Y], 1, [0], **kwargs)
+
if __name__ == "__main__":
import unittest
diff --git a/caffe2/utils/math.h b/caffe2/utils/math.h
index 6c352dc..0576b18 100644
--- a/caffe2/utils/math.h
+++ b/caffe2/utils/math.h
@@ -218,6 +218,27 @@
const int ldc,
Context* context);
+// GemmBatched provides a simple abstraction into library routines
+template <typename T, class Context, class Engine = DefaultEngine>
+void GemmBatched(
+ const CBLAS_TRANSPOSE TransA,
+ const CBLAS_TRANSPOSE TransB,
+ const int A_size,
+ const int A_batches,
+ const int B_size,
+ const int B_batches,
+ const int M,
+ const int N,
+ const int K,
+ const float alpha,
+ const T* A,
+ const T* B,
+ const float beta,
+ T* C,
+ Context* context,
+ Tensor<Context>* scratch = nullptr,
+ TensorProto::DataType math_type = TensorProto_DataType_FLOAT);
+
// Gemv always takes in a M*N matrix A, and depending on whether we set TransA
// to Trans, the output is:
// CblasNoTrans: x is an N dim vector and y is an M dim vector.
diff --git a/caffe2/utils/math_cpu.cc b/caffe2/utils/math_cpu.cc
index 99d339f..009d507 100644
--- a/caffe2/utils/math_cpu.cc
+++ b/caffe2/utils/math_cpu.cc
@@ -399,6 +399,45 @@
#endif // CAFFE2_USE_EIGEN_FOR_BLAS
+template <>
+void GemmBatched<float, CPUContext>(
+ const CBLAS_TRANSPOSE TransA,
+ const CBLAS_TRANSPOSE TransB,
+ const int A_size,
+ const int A_batches,
+ const int B_size,
+ const int B_batches,
+ const int M,
+ const int N,
+ const int K,
+ const float alpha,
+ const float* A,
+ const float* B,
+ const float beta,
+ float* C,
+ CPUContext* context,
+ Tensor<CPUContext>*, /* scratch */
+ TensorProto::DataType /* math_type */) {
+
+ auto a_offset = A_size / A_batches;
+ auto b_offset = B_size / B_batches;
+ auto y_offset = M * N;
+ // loop over matrices in the batch
+ for (int i = 0; i < A_batches; ++i) {
+ math::Gemm<float, CPUContext>(
+ TransA,
+ TransB,
+ M,
+ N,
+ K,
+ 1,
+ A + a_offset * i,
+ B + b_offset * i,
+ 0,
+ C + y_offset * i,
+ context);
+ }
+}
////////////////////////////////////////////////////////////////////////////////
// MKL VML alternatives.
diff --git a/caffe2/utils/math_gpu.cu b/caffe2/utils/math_gpu.cu
index 974639e..fabd971 100644
--- a/caffe2/utils/math_gpu.cu
+++ b/caffe2/utils/math_gpu.cu
@@ -261,6 +261,240 @@
}
}
+template <>
+void GemmBatched<float, CUDAContext>(
+ const CBLAS_TRANSPOSE TransA,
+ const CBLAS_TRANSPOSE TransB,
+ const int A_size,
+ const int A_batches,
+ const int B_size,
+ const int B_batches,
+ const int M,
+ const int N,
+ const int K,
+ const float alpha,
+ const float* A,
+ const float* B,
+ const float beta,
+ float* C,
+ CUDAContext* context,
+ Tensor<CUDAContext>* scratch,
+ TensorProto::DataType math_type) {
+
+#if __CUDACC_VER_MAJOR__ < 8
+ auto a_offset = A_size / A_batches;
+ auto b_offset = B_size / B_batches;
+ auto y_offset = M * N;
+ // loop over matrices in the batch
+ for (int i = 0; i < A_batches; ++i) {
+ math::Gemm<float, CUDAContext>(
+ TransA,
+ TransB,
+ M,
+ N,
+ K,
+ 1,
+ A + a_offset * i,
+ B + b_offset * i,
+ 0,
+ C + y_offset * i,
+ context);
+ }
+#else
+ // Note that cublas follows fortran order, so the order is different from
+ // the cblas convention.
+ int lda = (TransA == CblasNoTrans) ? K : M;
+ int ldb = (TransB == CblasNoTrans) ? N : K;
+
+ cublasOperation_t cuTransA =
+ (TransA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
+ cublasOperation_t cuTransB =
+ (TransB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
+ CUBLAS_ENFORCE(cublasSgemmStridedBatched(
+ context->cublas_handle(),
+ cuTransB,
+ cuTransA,
+ N,
+ M,
+ K,
+ &alpha,
+ B,
+ ldb,
+ B_size / B_batches, // B stride
+ A,
+ lda,
+ A_size / A_batches, // A stride
+ &beta,
+ C,
+ N,
+ M*N, // C stride
+ A_batches));
+#endif
+}
+
+namespace {
+
+__global__ void FloatToHalfKernel(const int N, const float* X, half* Y) {
+ CUDA_1D_KERNEL_LOOP(i, N) {
+ Y[i] = __float2half(X[i]);
+ }
+}
+
+__global__ void HalfToFloatKernel(const int N, const half* X, float* Y) {
+ CUDA_1D_KERNEL_LOOP(i, N) {
+ Y[i] = __half2float(X[i]);
+ }
+}
+
+};
+
+template <>
+void GemmBatched<float16, CUDAContext>(
+ const CBLAS_TRANSPOSE TransA,
+ const CBLAS_TRANSPOSE TransB,
+ const int A_size,
+ const int A_batches,
+ const int B_size,
+ const int B_batches,
+ const int M,
+ const int N,
+ const int K,
+ const float alpha,
+ const float16* A,
+ const float16* B,
+ const float beta,
+ float16* C,
+ CUDAContext* context,
+ Tensor<CUDAContext>* scratch,
+ TensorProto::DataType math_type) {
+
+#if __CUDACC_VER_MAJOR__ < 8
+ auto a_offset = A_size / A_batches;
+ auto b_offset = B_size / B_batches;
+ auto y_offset = M * N;
+ // loop over matrices in the batch
+ for (int i = 0; i < A_batches; ++i) {
+ math::Gemm<float16, CUDAContext>(
+ TransA,
+ TransB,
+ M,
+ N,
+ K,
+ 1,
+ A + a_offset * i,
+ B + b_offset * i,
+ 0,
+ C + y_offset * i,
+ context);
+ }
+#else
+ // 3 options:
+ // 1) scratch != null = cast to fp32, SgemmStridedBatched, cast result to fp16
+ // 2) math_type == FLOAT, scratch == nullptr = looped SgemmEx
+ // 3) math_type == FLOAT16, scratch == nullptr = batched Hgemm
+
+ if (scratch != nullptr) {
+ // cast, cublasSgemmStridedBatched, cast
+ size_t in_elems = A_size + B_size;
+ size_t out_elems = A_batches*M*N;
+
+ scratch->Resize(in_elems+out_elems);
+ float* scratch_ptr = scratch->mutable_data<float>();
+
+ float* A_fp32 = scratch_ptr;
+ float* B_fp32 = scratch_ptr + A_size;
+ float* C_fp32 = scratch_ptr + A_size + B_size;
+
+ // cast A, B into fp32
+ HalfToFloatKernel<<<CAFFE_GET_BLOCKS(A_size),
+ CAFFE_CUDA_NUM_THREADS,
+ 0,
+ context->cuda_stream()>>>(A_size, (half*)A, A_fp32);
+ HalfToFloatKernel<<<CAFFE_GET_BLOCKS(B_size),
+ CAFFE_CUDA_NUM_THREADS,
+ 0,
+ context->cuda_stream()>>>(B_size, (half*)B, B_fp32);
+
+ // run fp32 batched Gemm
+ GemmBatched<float,CUDAContext>(
+ TransA,
+ TransB,
+ A_size,
+ A_batches,
+ B_size,
+ B_batches,
+ M,
+ N,
+ K,
+ alpha,
+ A_fp32,
+ B_fp32,
+ beta,
+ C_fp32,
+ context);
+
+ // cast result back to fp16
+ FloatToHalfKernel<<<CAFFE_GET_BLOCKS(A_batches*M*N),
+ CAFFE_CUDA_NUM_THREADS,
+ 0,
+ context->cuda_stream()>>>(A_batches*M*N, C_fp32, (half*)C);
+ } else {
+ if (math_type == TensorProto_DataType_FLOAT) {
+ auto a_offset = A_size / A_batches;
+ auto b_offset = B_size / B_batches;
+ auto y_offset = M * N;
+ // loop over matrices in the batch
+ for (int i = 0; i < A_batches; ++i) {
+ math::Gemm<float16, CUDAContext>(
+ TransA,
+ TransB,
+ M,
+ N,
+ K,
+ 1,
+ A + a_offset * i,
+ B + b_offset * i,
+ 0,
+ C + y_offset * i,
+ context);
+ }
+ } else if (math_type == TensorProto_DataType_FLOAT16) {
+ // Note that cublas follows fortran order, so the order is different from
+ // the cblas convention.
+ int lda = (TransA == CblasNoTrans) ? K : M;
+ int ldb = (TransB == CblasNoTrans) ? N : K;
+ cublasOperation_t cuTransA =
+ (TransA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
+ cublasOperation_t cuTransB =
+ (TransB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
+
+ // convert alpha, beta from float -> __half
+ auto alpha_fp16 = convert::floatToHalf(alpha);
+ auto beta_fp16 = convert::floatToHalf(beta);
+ CUBLAS_ENFORCE(cublasHgemmStridedBatched(
+ context->cublas_handle(),
+ cuTransB,
+ cuTransA,
+ N,
+ M,
+ K,
+ &alpha_fp16,
+ (const __half*)B,
+ ldb,
+ B_size / B_batches,
+ (const __half*)A,
+ lda,
+ A_size / A_batches,
+ &beta_fp16,
+ (__half*)C,
+ N,
+ M*N,
+ A_batches));
+ }
+ }
+#endif
+}
+
#if CUDA_VERSION >= 9000
// No change, but required. Defer to default CUDA engine
@@ -351,6 +585,84 @@
}
}
+template <>
+void GemmBatched<float, CUDAContext, TensorCoreEngine>(
+ const CBLAS_TRANSPOSE TransA,
+ const CBLAS_TRANSPOSE TransB,
+ const int A_size,
+ const int A_batches,
+ const int B_size,
+ const int B_batches,
+ const int M,
+ const int N,
+ const int K,
+ const float alpha,
+ const float* A,
+ const float* B,
+ const float beta,
+ float* C,
+ CUDAContext* context,
+ Tensor<CUDAContext>* scratch,
+ TensorProto::DataType math_type) {
+ return GemmBatched<float, CUDAContext, DefaultEngine>(
+ TransA,
+ TransB,
+ A_size,
+ A_batches,
+ B_size,
+ B_batches,
+ M,
+ N,
+ K,
+ alpha,
+ A,
+ B,
+ beta,
+ C,
+ context,
+ scratch,
+ math_type);
+}
+
+template <>
+void GemmBatched<float16, CUDAContext, TensorCoreEngine>(
+ const CBLAS_TRANSPOSE TransA,
+ const CBLAS_TRANSPOSE TransB,
+ const int A_size,
+ const int A_batches,
+ const int B_size,
+ const int B_batches,
+ const int M,
+ const int N,
+ const int K,
+ const float alpha,
+ const float16* A,
+ const float16* B,
+ const float beta,
+ float16* C,
+ CUDAContext* context,
+ Tensor<CUDAContext>* scratch,
+ TensorProto::DataType math_type) {
+ return GemmBatched<float16, CUDAContext, DefaultEngine>(
+ TransA,
+ TransB,
+ A_size,
+ A_batches,
+ B_size,
+ B_batches,
+ M,
+ N,
+ K,
+ alpha,
+ A,
+ B,
+ beta,
+ C,
+ context,
+ scratch,
+ math_type);
+}
+
#endif // CUDA_VERSION >= 9000
template <>
@@ -434,7 +746,9 @@
for (int j = 0; j < batch; j++) {
const T* x = first + j * stripe;
CUDA_1D_KERNEL_LOOP(i, N) {
- Y[i] += x[i];
+ float tmpY = convert::To<T, float>(Y[i]);
+ tmpY += convert::To<T,float>(x[i]);
+ Y[i] = convert::To<float,T>(tmpY);
}
}
}
@@ -457,7 +771,7 @@
}
CAFFE2_SPECIALIZED_CUDA_ADD_STRIPED_BATCH(float);
-CAFFE2_SPECIALIZED_CUDA_ADD_STRIPED_BATCH(double);
+CAFFE2_SPECIALIZED_CUDA_ADD_STRIPED_BATCH(float16);
#undef CAFFE2_SPECIALIZED_CUDA_ADD_STRIPED_BATCH
template <>