#include "caffe2/operators/batch_matmul_op.h"
#include "caffe2/core/context_gpu.h"
namespace caffe2 {
#if __CUDACC_VER_MAJOR__ >= 8
// CUDA 8 introduced cublasSgemmStridedBatched, which lets us carry out the
// whole batched SGEMM in a single cuBLAS call rather than one GEMM per batch
// entry. This is the specialized version that uses it.
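// For example, with trans_a_ == trans_b_ == false, A of shape (batch, M, K)
// and B of shape (batch, K, N) yield Y of shape (batch, M, N).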
template <>
bool BatchMatMulOp<float, CUDAContext, DefaultEngine>::RunOnDevice() {
  const auto& A = Input(0);
  const auto& B = Input(1);
  auto* Y = Output(0);

  CAFFE_ENFORCE_EQ(A.ndim(), 3);
  CAFFE_ENFORCE_EQ(B.ndim(), 3);
  CAFFE_ENFORCE_EQ(A.dim32(0), B.dim32(0));

  int a_dim0, a_dim1, b_dim0, b_dim1;
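  // a_dim0 x a_dim1 and b_dim0 x b_dim1 are the effective (post-transpose)
  // shapes of one A slice and one B slice, so each batch entry of the product
  // is a_dim0 x b_dim1 with inner dimension a_dim1 == b_dim0.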
  if (trans_a_) {
    a_dim0 = A.dim32(2);
    a_dim1 = A.dim32(1);
  } else {
    a_dim0 = A.dim32(1);
    a_dim1 = A.dim32(2);
  }

  if (trans_b_) {
    b_dim0 = B.dim32(2);
    b_dim1 = B.dim32(1);
  } else {
    b_dim0 = B.dim32(1);
    b_dim1 = B.dim32(2);
  }
  // Error checking
  CAFFE_ENFORCE(
      a_dim1 == b_dim0,
      "Dimension mismatch: ",
      trans_a_ ? "trans(A): " : "A: ",
      a_dim0,
      " ",
      a_dim1,
      trans_b_ ? ", trans(B): " : ", B: ",
      b_dim0,
      " ",
      b_dim1);
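  // Output holds one a_dim0 x b_dim1 result per batch entry.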
  Y->Resize(A.dim(0), a_dim0, b_dim1);

  if (!A.dim(0)) {
    Y->mutable_data<float>(); // create output tensor
    return true;
  }
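  // Single strided-batched SGEMM computing Y[i] = A[i] * B[i] for every i.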
  float alpha = 1;
  float beta = 0;
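  // cuBLAS expects column-major matrices while Caffe2 tensors are row-major.
  // Passing B first and swapping the transpose flags computes the
  // column-major product B^T * A^T, which is exactly the row-major A * B.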
  CUBLAS_ENFORCE(cublasSgemmStridedBatched(
      context_.cublas_handle(),
      trans_b_ ? CUBLAS_OP_T : CUBLAS_OP_N,
      trans_a_ ? CUBLAS_OP_T : CUBLAS_OP_N,
      b_dim1,
      a_dim0,
      a_dim1,
      &alpha,
      B.data<float>(),
      trans_b_ ? a_dim1 : b_dim1, // ldb
      B.size() / B.dim(0), // b stride
      A.data<float>(),
      trans_a_ ? a_dim0 : a_dim1, // lda
      A.size() / A.dim(0), // a stride
      &beta,
      Y->mutable_data<float>(),
      b_dim1, // ldc
      a_dim0 * b_dim1, // y stride
      A.dim32(0) // batch count
      ));
  return true;
}
#endif // __CUDACC_VER_MAJOR__ >= 8
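// When building with CUDA older than 8, the specialization above is compiled
// out and the generic BatchMatMulOp template from batch_matmul_op.h is used.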
REGISTER_CUDA_OPERATOR(BatchMatMul, BatchMatMulOp<float, CUDAContext>);
} // namespace caffe2