[Intel MKL] Supporting MatMul, Transpose and Softmax with BFloat16 type
This PR enables the _MklMatMul, _MklTranspose and _MklSoftmax kernels for the BFloat16 type.
Some of the changes were suggested by the clang-format checker.
diff --git a/tensorflow/core/graph/mkl_graph_util.h b/tensorflow/core/graph/mkl_graph_util.h
index ba2526b..fd6e2be 100644
--- a/tensorflow/core/graph/mkl_graph_util.h
+++ b/tensorflow/core/graph/mkl_graph_util.h
@@ -179,7 +179,16 @@
search_string += string(";") + string(" T in [");
search_string += DataType_Name(T) + string("]");
- return kernel.find(search_string) != string::npos;
+ // Temporarily replace the earlier check with a type-specific check so
+ // that we can selectively decide which types are supported by MKL operators.
+ // That way kernel registration alone does not decide which operators we
+ // support. We are using this change to temporarily disable BFLOAT16 support.
+ // Once we want to enable it, we will go back to the earlier check.
+ if (kernel.find(search_string) != string::npos) {
+ return T == DT_COMPLEX128 || T == DT_COMPLEX64 || T == DT_DOUBLE ||
+ T == DT_FLOAT;
+ }
+ return false;
}
// Check if the operator with 'op_name' and type 'T' is an MKL operator that
diff --git a/tensorflow/core/kernels/mkl_matmul_op.cc b/tensorflow/core/kernels/mkl_matmul_op.cc
index e3b5bb6..6f9262b 100644
--- a/tensorflow/core/kernels/mkl_matmul_op.cc
+++ b/tensorflow/core/kernels/mkl_matmul_op.cc
@@ -34,10 +34,10 @@
// This header file is part of MKL ML, need equivalent file in MKL DNN
#ifndef INTEL_MKL_DNN_ONLY
#include "mkl_cblas.h"
-#else
-#include "mkldnn.h"
#endif
+#include "mkldnn.h"
+
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
@@ -63,11 +63,11 @@
dim_pair[0].first = transpose_a_ ? 0 : 1;
dim_pair[0].second = transpose_b_ ? 1 : 0;
- OP_REQUIRES(
- ctx, a.dim_size(dim_pair[0].first) == b.dim_size(dim_pair[0].second),
- errors::InvalidArgument(
- "Matrix size-incompatible: In[0]: ", a.shape().DebugString(),
- ", In[1]: ", b.shape().DebugString()));
+ OP_REQUIRES(ctx,
+ a.dim_size(dim_pair[0].first) == b.dim_size(dim_pair[0].second),
+ errors::InvalidArgument("Matrix size-incompatible: In[0]: ",
+ a.shape().DebugString(), ", In[1]: ",
+ b.shape().DebugString()));
int a_dim_remaining = 1 - dim_pair[0].first;
int b_dim_remaining = 1 - dim_pair[0].second;
TensorShape out_shape(
@@ -100,8 +100,8 @@
auto b_ptr = (b.template flat<T>().data());
auto c_ptr = (out->template flat<T>().data());
- MklBlasGemm(transpose_a, transpose_b, m, n, k, a_ptr, transpose_a ? m : k,
- b_ptr, transpose_b ? k : n, c_ptr, n);
+ MklBlasGemm(ctx, transpose_a, transpose_b, m, n, k, a_ptr,
+ transpose_a ? m : k, b_ptr, transpose_b ? k : n, c_ptr, n);
}
private:
@@ -147,9 +147,9 @@
// layout, leading dimension is the stride between consecutive rows, max(1,n)
//
// --------------------------------------------------------------------------
- void MklBlasGemm(bool transa, bool transb, const int m, const int n,
- const int k, const float* a, const int lda, const float* b,
- const int ldb, float* c, const int ldc) {
+ void MklBlasGemm(OpKernelContext* ctx, bool transa, bool transb, const int m,
+ const int n, const int k, const float* a, const int lda,
+ const float* b, const int ldb, float* c, const int ldc) {
// BLAS GEMM API defines Matrix Multiplication as c = alpha * op(a) * op(b)
// + beta * c.
// Since TF MatMul does not have parameters for alpha, beta, we set them to
@@ -173,14 +173,38 @@
#endif
}
- // MKLDNN only supports SGEMM
+ void MklBlasGemm(OpKernelContext* ctx, bool transa, bool transb, const int m,
+ const int n, const int k, const bfloat16* a, const int lda,
+ const bfloat16* b, const int ldb, bfloat16* c,
+ const int ldc) {
+ const float alpha = 1.0f;
+ const float beta = 0.0f;
+ const char* const ftrans[] = {"N", "T", "C"};
+ int index_transa = transa ? 1 : 0;
+ int index_transb = transb ? 1 : 0;
+
+ Tensor c_float;
+ OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, {m, n}, &c_float));
+
+ // MKL-DNN only supports the Fortran (column-major) GEMM API while TensorFlow
+ // is row-major, so A and B (with their transpose flags and m/n) are swapped.
+ mkldnn_gemm_bf16bf16f32(ftrans[index_transb], ftrans[index_transa], &n, &m,
+ &k, &alpha,
+ reinterpret_cast<const mkldnn_bfloat16_t*>(b), &ldb,
+ reinterpret_cast<const mkldnn_bfloat16_t*>(a), &lda,
+ &beta, c_float.flat<float>().data(), &ldc);
+
+ FloatToBFloat16(c_float.flat<float>().data(), c, c_float.NumElements());
+ }
+
+// MKL-DNN only supports SGEMM (plus the bf16 GEMM used above); the rest need MKL ML.
#ifndef INTEL_MKL_DNN_ONLY
// Matrix-Matrix Multiplication with FP64 tensors. For detailed info about
// parameters, look at FP32 function description.
- void MklBlasGemm(bool transa, bool transb, const int m, const int n,
- const int k, const double* a, const int lda, const double* b,
- const int ldb, double* c, const int ldc) {
+ void MklBlasGemm(OpKernelContext* ctx, bool transa, bool transb, const int m,
+ const int n, const int k, const double* a, const int lda,
+ const double* b, const int ldb, double* c, const int ldc) {
const double alpha = 1.0;
const double beta = 0.0;
cblas_dgemm(CblasRowMajor, transa ? CblasTrans : CblasNoTrans,
@@ -190,8 +214,8 @@
// Matrix-Matrix Multiplication with Complex64 (std::complex<float>) tensors.
// For detailed info about parameters, look at FP32 function description.
- void MklBlasGemm(bool transa, bool transb, const int m, const int n,
- const int k, const complex64* a, const int lda,
+ void MklBlasGemm(OpKernelContext* ctx, bool transa, bool transb, const int m,
+ const int n, const int k, const complex64* a, const int lda,
const complex64* b, const int ldb, complex64* c,
int const ldc) {
const MKL_Complex8 alpha = {1.0f, 0.0f};
@@ -206,8 +230,8 @@
// Matrix-Matrix Multiplication with Complex128 (std::complex<double>)
// tensors. For detailed info about parameters, look at FP32 function
// description.
- void MklBlasGemm(bool transa, bool transb, const int m, const int n,
- const int k, const complex128* a, const int lda,
+ void MklBlasGemm(OpKernelContext* ctx, bool transa, bool transb, const int m,
+ const int n, const int k, const complex128* a, const int lda,
const complex128* b, const int ldb, complex128* c,
const int ldc) {
const MKL_Complex16 alpha = {1.0, 0.0};
@@ -233,6 +257,7 @@
// TODO(inteltf) Consider template specialization when adding/removing
// additional types
TF_CALL_float(REGISTER_CPU);
+TF_CALL_bfloat16(REGISTER_CPU);
#ifndef INTEL_MKL_DNN_ONLY
TF_CALL_double(REGISTER_CPU);
diff --git a/tensorflow/core/kernels/mkl_softmax_op.cc b/tensorflow/core/kernels/mkl_softmax_op.cc
index e84d007..acbedaa 100644
--- a/tensorflow/core/kernels/mkl_softmax_op.cc
+++ b/tensorflow/core/kernels/mkl_softmax_op.cc
@@ -17,7 +17,6 @@
#ifdef INTEL_MKL
#include "mkldnn.hpp"
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
@@ -25,6 +24,7 @@
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/mkl_util.h"
#include "tensorflow/core/util/tensor_format.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
using mkldnn::prop_kind;
using mkldnn::softmax_forward;
@@ -168,9 +168,9 @@
net.push_back(softmax_fwd);
stream(stream::kind::eager).submit(net).wait();
} catch (mkldnn::error& e) {
- string error_msg = "Status: " + std::to_string(e.status) +
- ", message: " + string(e.message) + ", in file " +
- string(__FILE__) + ":" + std::to_string(__LINE__);
+ string error_msg = "Status: " + std::to_string(e.status) + ", message: " +
+ string(e.message) + ", in file " + string(__FILE__) +
+ ":" + std::to_string(__LINE__);
OP_REQUIRES_OK(
context,
errors::Aborted("Operation received an exception:", error_msg));
@@ -188,6 +188,7 @@
.Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
MklSoftmaxOp<CPUDevice, type>);
TF_CALL_float(REGISTER_SOFTMAX_MKL_SUPPORTED_KERNELS_TYPES);
+TF_CALL_bfloat16(REGISTER_SOFTMAX_MKL_SUPPORTED_KERNELS_TYPES);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mkl_transpose_op.cc b/tensorflow/core/kernels/mkl_transpose_op.cc
index f6d8470..8598bf3 100644
--- a/tensorflow/core/kernels/mkl_transpose_op.cc
+++ b/tensorflow/core/kernels/mkl_transpose_op.cc
@@ -78,9 +78,8 @@
mkl_comatcopy(
'R', trans, in.dim_size(0), in.dim_size(1), alpha,
reinterpret_cast<const MKL_Complex8*>(in.flat<complex64>().data()),
- in.dim_size(1),
- reinterpret_cast<MKL_Complex8*>(
- const_cast<complex64*>(out->flat<complex64>().data())),
+ in.dim_size(1), reinterpret_cast<MKL_Complex8*>(const_cast<complex64*>(
+ out->flat<complex64>().data())),
in.dim_size(0));
return Status::OK();
}
@@ -92,9 +91,8 @@
mkl_zomatcopy(
'R', trans, in.dim_size(0), in.dim_size(1), alpha,
reinterpret_cast<const MKL_Complex16*>(in.flat<complex128>().data()),
- in.dim_size(1),
- reinterpret_cast<MKL_Complex16*>(
- const_cast<complex128*>(out->flat<complex128>().data())),
+ in.dim_size(1), reinterpret_cast<MKL_Complex16*>(const_cast<complex128*>(
+ out->flat<complex128>().data())),
in.dim_size(0));
return Status::OK();
}
@@ -145,8 +143,8 @@
stream(stream::kind::eager).submit(net).wait();
return Status::OK();
} catch (mkldnn::error& e) {
- string error_msg = "Status: " + std::to_string(e.status) +
- ", message: " + std::string(e.message) + ", in file " +
+ string error_msg = "Status: " + std::to_string(e.status) + ", message: " +
+ std::string(e.message) + ", in file " +
std::string(__FILE__) + ":" + std::to_string(__LINE__);
return errors::Aborted("Operation received an exception:", error_msg);
}
@@ -184,10 +182,9 @@
case DT_FLOAT:
return MKLTransposeND<float>(ctx, in, out, perm);
break;
- // TODO(nhasabni): Enable this case when we turn on bfloat16 compilation.
- // case DT_BFLOAT16:
- // return MKLTransposeND<bfloat16>(ctx, in, out, perm);
- // break;
+ case DT_BFLOAT16:
+ return MKLTransposeND<bfloat16>(ctx, in, out, perm);
+ break;
// TODO(nhasabni): support other types such as INT8.
default:
break;
@@ -232,10 +229,9 @@
case DT_FLOAT:
return MKLTransposeND<float>(ctx, in, out, perm);
break;
- // TODO(nhasabni): Enable this case when we turn on bfloat16 compilation.
- // case DT_BFLOAT16:
- // return MKLTransposeND<bfloat16>(ctx, in, out, perm);
- // break;
+ case DT_BFLOAT16:
+ return MKLTransposeND<bfloat16>(ctx, in, out, perm);
+ break;
// TODO(nhasabni): support other types such as INT8.
default:
break;
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
index dde2db3..660fc9d 100644
--- a/tensorflow/core/ops/math_ops.cc
+++ b/tensorflow/core/ops/math_ops.cc
@@ -917,7 +917,7 @@
.Output("product: T")
.Attr("transpose_a: bool = false")
.Attr("transpose_b: bool = false")
- .Attr("T: {float, double, complex64, complex128}")
+ .Attr("T: {bfloat16, float, double, complex64, complex128}")
.SetShapeFn(shape_inference::MatMulShape);
#endif // INTEL_MKL
diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc
index 4d10054..ae4ccdf 100644
--- a/tensorflow/core/ops/nn_ops.cc
+++ b/tensorflow/core/ops/nn_ops.cc
@@ -2200,7 +2200,7 @@
.Input("mkl_logits: uint8")
.Output("softmax: T")
.Output("mkl_softmax: uint8")
- .Attr("T: {half, float, double}")
+ .Attr("T: {bfloat16, half, float, double}")
.SetShapeFn([](InferenceContext* c) {
return shape_inference::UnchangedShapeWithRankAtLeast(c, 1);
})