[Intel MKL] Supporting MatMul, Transpose and Softmax with BFloat16 type

This PR enables _MklMatMul, _MklTranspose and _MklSoftmax with the BFloat16 type: the bfloat16 kernels are registered and the ops' "T" attribute lists are extended accordingly. Note that the type check in mkl_graph_util.h temporarily excludes BFLOAT16, so the graph-rewrite pass will not select the new kernels yet; once we are ready to enable bfloat16 there, that check reverts to its earlier form. Some of the changes were suggested by the clang-format checker. Short illustrative sketches follow this note and the mkl_graph_util.h and mkl_matmul_op.cc diffs below.
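
For reviewers, here is a minimal standalone sketch of the bfloat16 GEMM pattern used in mkl_matmul_op.cc below: inputs are bfloat16, accumulation happens in a temporary float buffer, and only the final result is narrowed back to bfloat16. A naive triple loop stands in for mkldnn_gemm_bf16bf16f32, a truncating conversion stands in for FloatToBFloat16 (which rounds), and the bf16 struct is an illustration-only stand-in for TensorFlow's bfloat16 type.

#include <cstdint>
#include <cstring>
#include <vector>

// Illustration-only bfloat16: the upper 16 bits of an IEEE-754 float.
struct bf16 { uint16_t bits; };

float ToFloat(bf16 v) {
  uint32_t u = static_cast<uint32_t>(v.bits) << 16;
  float f;
  std::memcpy(&f, &u, sizeof(f));
  return f;
}

bf16 FromFloat(float f) {  // truncation for brevity; the real conversion rounds
  uint32_t u;
  std::memcpy(&u, &f, sizeof(u));
  return bf16{static_cast<uint16_t>(u >> 16)};
}

// C(m x n) = A(m x k) * B(k x n), all row-major. The float scratch buffer
// plays the role of the allocate_temp(DT_FLOAT) tensor in the kernel.
void GemmBf16(int m, int n, int k, const bf16* a, const bf16* b, bf16* c) {
  std::vector<float> c_float(static_cast<size_t>(m) * n, 0.0f);
  for (int i = 0; i < m; ++i)
    for (int j = 0; j < n; ++j)
      for (int p = 0; p < k; ++p)
        c_float[i * n + j] += ToFloat(a[i * k + p]) * ToFloat(b[p * n + j]);
  for (size_t idx = 0; idx < c_float.size(); ++idx)
    c[idx] = FromFloat(c_float[idx]);  // FloatToBFloat16 in the real kernel
}

The important property is that partial products never round through bfloat16; only the final result is narrowed, which matches what mkldnn_gemm_bf16bf16f32 does with its float C matrix.
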
diff --git a/tensorflow/core/graph/mkl_graph_util.h b/tensorflow/core/graph/mkl_graph_util.h
index ba2526b..fd6e2be 100644
--- a/tensorflow/core/graph/mkl_graph_util.h
+++ b/tensorflow/core/graph/mkl_graph_util.h
@@ -179,7 +179,16 @@
   search_string += string(";") + string(" T in [");
   search_string += DataType_Name(T) + string("]");
 
-  return kernel.find(search_string) != string::npos;
+  // Temporarily replace the earlier check with a type-specific check so we
+  // can selectively decide which types MKL operators support; that way
+  // kernel registration alone does not decide. We use this to temporarily
+  // disable BFLOAT16 support; once we want to enable it, we will revert to
+  // the earlier check.
+  if (kernel.find(search_string) != string::npos) {
+    return T == DT_COMPLEX128 || T == DT_COMPLEX64 || T == DT_DOUBLE ||
+           T == DT_FLOAT;
+  }
+  return false;
 }
 
 // Check if the operator with 'op_name' and type 'T' is an MKL operator that
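
To make the new allow-list concrete, here is a hypothetical condensed form (the function name and the snippet below are invented for this note; the real logic is exactly the patched check above, and string/DataType are TensorFlow's types):

// Registration alone no longer enables an MKL op for a type: the kernel
// string must match AND the type must be on the allow-list.
bool MklSupportsType(const string& kernel, const string& search_string,
                     DataType T) {
  if (kernel.find(search_string) == string::npos) return false;
  return T == DT_COMPLEX128 || T == DT_COMPLEX64 || T == DT_DOUBLE ||
         T == DT_FLOAT;
}

So even with the bfloat16 kernels registered below (TF_CALL_bfloat16), a match on "T in [DT_BFLOAT16]" still returns false here, and the layout pass leaves the node on the default Eigen path until BFLOAT16 joins the list.
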
diff --git a/tensorflow/core/kernels/mkl_matmul_op.cc b/tensorflow/core/kernels/mkl_matmul_op.cc
index e3b5bb6..6f9262b 100644
--- a/tensorflow/core/kernels/mkl_matmul_op.cc
+++ b/tensorflow/core/kernels/mkl_matmul_op.cc
@@ -34,10 +34,10 @@
 // This header file is part of MKL ML, need equivalent file in MKL DNN
 #ifndef INTEL_MKL_DNN_ONLY
 #include "mkl_cblas.h"
-#else
-#include "mkldnn.h"
 #endif
 
+#include "mkldnn.h"
+
 namespace tensorflow {
 
 typedef Eigen::ThreadPoolDevice CPUDevice;
@@ -63,11 +63,11 @@
     dim_pair[0].first = transpose_a_ ? 0 : 1;
     dim_pair[0].second = transpose_b_ ? 1 : 0;
 
-    OP_REQUIRES(
-        ctx, a.dim_size(dim_pair[0].first) == b.dim_size(dim_pair[0].second),
-        errors::InvalidArgument(
-            "Matrix size-incompatible: In[0]: ", a.shape().DebugString(),
-            ", In[1]: ", b.shape().DebugString()));
+    OP_REQUIRES(ctx,
+                a.dim_size(dim_pair[0].first) == b.dim_size(dim_pair[0].second),
+                errors::InvalidArgument("Matrix size-incompatible: In[0]: ",
+                                        a.shape().DebugString(), ", In[1]: ",
+                                        b.shape().DebugString()));
     int a_dim_remaining = 1 - dim_pair[0].first;
     int b_dim_remaining = 1 - dim_pair[0].second;
     TensorShape out_shape(
@@ -100,8 +100,8 @@
     auto b_ptr = (b.template flat<T>().data());
     auto c_ptr = (out->template flat<T>().data());
 
-    MklBlasGemm(transpose_a, transpose_b, m, n, k, a_ptr, transpose_a ? m : k,
-                b_ptr, transpose_b ? k : n, c_ptr, n);
+    MklBlasGemm(ctx, transpose_a, transpose_b, m, n, k, a_ptr,
+                transpose_a ? m : k, b_ptr, transpose_b ? k : n, c_ptr, n);
   }
 
  private:
@@ -147,9 +147,9 @@
   // layout, leading dimension is the stride between consecutive rows, max(1,n)
   //
   // --------------------------------------------------------------------------
-  void MklBlasGemm(bool transa, bool transb, const int m, const int n,
-                   const int k, const float* a, const int lda, const float* b,
-                   const int ldb, float* c, const int ldc) {
+  void MklBlasGemm(OpKernelContext* ctx, bool transa, bool transb, const int m,
+                   const int n, const int k, const float* a, const int lda,
+                   const float* b, const int ldb, float* c, const int ldc) {
     // BLAS GEMM API defines Matrix Multiplication as c = alpha * op(a) * op(b)
     // + beta * c.
     // Since TF MatMul does not have parameters for alpha, beta, we set them to
@@ -173,14 +173,38 @@
 #endif
   }
 
-  // MKLDNN only supports SGEMM
+  void MklBlasGemm(OpKernelContext* ctx, bool transa, bool transb, const int m,
+                   const int n, const int k, const bfloat16* a, const int lda,
+                   const bfloat16* b, const int ldb, bfloat16* c,
+                   const int ldc) {
+    const float alpha = 1.0f;
+    const float beta = 0.0f;
+    const char* const ftrans[] = {"N", "T", "C"};
+    int index_transa = transa ? 1 : 0;
+    int index_transb = transb ? 1 : 0;
+
+    Tensor c_float;
+    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, {m, n}, &c_float));
+
+    // MKL-DNN supports only the Fortran (column-major) API, while TensorFlow
+    // uses row-major layout, so we swap the order of A and B.
+    mkldnn_gemm_bf16bf16f32(ftrans[index_transb], ftrans[index_transa], &n, &m,
+                            &k, &alpha,
+                            reinterpret_cast<const mkldnn_bfloat16_t*>(b), &ldb,
+                            reinterpret_cast<const mkldnn_bfloat16_t*>(a), &lda,
+                            &beta, c_float.flat<float>().data(), &ldc);
+
+    FloatToBFloat16(c_float.flat<float>().data(), c, c_float.NumElements());
+  }
+
+// MKLDNN only supports SGEMM and the bfloat16 GEMM above
 #ifndef INTEL_MKL_DNN_ONLY
 
   // Matrix-Matrix Multiplication with FP64 tensors. For detailed info about
   // parameters, look at FP32 function description.
-  void MklBlasGemm(bool transa, bool transb, const int m, const int n,
-                   const int k, const double* a, const int lda, const double* b,
-                   const int ldb, double* c, const int ldc) {
+  void MklBlasGemm(OpKernelContext* ctx, bool transa, bool transb, const int m,
+                   const int n, const int k, const double* a, const int lda,
+                   const double* b, const int ldb, double* c, const int ldc) {
     const double alpha = 1.0;
     const double beta = 0.0;
     cblas_dgemm(CblasRowMajor, transa ? CblasTrans : CblasNoTrans,
@@ -190,8 +214,8 @@
 
   // Matrix-Matrix Multiplication with Complex64 (std::complex<float>) tensors.
   // For detailed info about parameters, look at FP32 function description.
-  void MklBlasGemm(bool transa, bool transb, const int m, const int n,
-                   const int k, const complex64* a, const int lda,
+  void MklBlasGemm(OpKernelContext* ctx, bool transa, bool transb, const int m,
+                   const int n, const int k, const complex64* a, const int lda,
                    const complex64* b, const int ldb, complex64* c,
                    int const ldc) {
     const MKL_Complex8 alpha = {1.0f, 0.0f};
@@ -206,8 +230,8 @@
   // Matrix-Matrix Multiplication with Complex128 (std::complex<double>)
   // tensors. For detailed info about parameters, look at FP32 function
   // description.
-  void MklBlasGemm(bool transa, bool transb, const int m, const int n,
-                   const int k, const complex128* a, const int lda,
+  void MklBlasGemm(OpKernelContext* ctx, bool transa, bool transb, const int m,
+                   const int n, const int k, const complex128* a, const int lda,
                    const complex128* b, const int ldb, complex128* c,
                    const int ldc) {
     const MKL_Complex16 alpha = {1.0, 0.0};
@@ -233,6 +257,7 @@
 // TODO(inteltf) Consider template specialization when adding/removing
 // additional types
 TF_CALL_float(REGISTER_CPU);
+TF_CALL_bfloat16(REGISTER_CPU);
 
 #ifndef INTEL_MKL_DNN_ONLY
 TF_CALL_double(REGISTER_CPU);
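
Since the operand swap in the bfloat16 MklBlasGemm is easy to get wrong, here is a self-contained worked example of the trick the comment describes: a row-major matrix reinterpreted as column-major is its transpose, so asking a column-major GEMM for B^T * A^T with swapped dimensions writes exactly the row-major C = A * B. The naive gemm_colmajor below is a stand-in for the Fortran-convention MKL call.

#include <cassert>

// Column-major reference GEMM: C(m x n) = A(m x k) * B(k x n).
void gemm_colmajor(int m, int n, int k, const float* a, int lda,
                   const float* b, int ldb, float* c, int ldc) {
  for (int j = 0; j < n; ++j)
    for (int i = 0; i < m; ++i) {
      float s = 0.0f;
      for (int p = 0; p < k; ++p) s += a[i + p * lda] * b[p + j * ldb];
      c[i + j * ldc] = s;
    }
}

int main() {
  // Row-major A (2x3) and B (3x2); expected row-major C = A*B (2x2).
  const float A[] = {1, 2, 3, 4, 5, 6};
  const float B[] = {7, 8, 9, 10, 11, 12};
  float C[4];
  // Swap operands and dimensions: the column-major routine computes
  // C^T(2x2) = B^T(2x3) * A^T(3x2), which lands in memory as row-major C.
  gemm_colmajor(/*m=*/2, /*n=*/2, /*k=*/3, B, /*lda=*/2, A, /*ldb=*/3, C,
                /*ldc=*/2);
  assert(C[0] == 58 && C[1] == 64 && C[2] == 139 && C[3] == 154);
  return 0;
}
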
diff --git a/tensorflow/core/kernels/mkl_softmax_op.cc b/tensorflow/core/kernels/mkl_softmax_op.cc
index e84d007..acbedaa 100644
--- a/tensorflow/core/kernels/mkl_softmax_op.cc
+++ b/tensorflow/core/kernels/mkl_softmax_op.cc
@@ -17,7 +17,6 @@
 #ifdef INTEL_MKL
 
 #include "mkldnn.hpp"
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/framework/numeric_op.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
@@ -25,6 +24,7 @@
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/util/mkl_util.h"
 #include "tensorflow/core/util/tensor_format.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 
 using mkldnn::prop_kind;
 using mkldnn::softmax_forward;
@@ -168,9 +168,9 @@
       net.push_back(softmax_fwd);
       stream(stream::kind::eager).submit(net).wait();
     } catch (mkldnn::error& e) {
-      string error_msg = "Status: " + std::to_string(e.status) +
-                         ", message: " + string(e.message) + ", in file " +
-                         string(__FILE__) + ":" + std::to_string(__LINE__);
+      string error_msg = "Status: " + std::to_string(e.status) + ", message: " +
+                         string(e.message) + ", in file " + string(__FILE__) +
+                         ":" + std::to_string(__LINE__);
       OP_REQUIRES_OK(
           context,
           errors::Aborted("Operation received an exception:", error_msg));
@@ -188,6 +188,7 @@
           .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
       MklSoftmaxOp<CPUDevice, type>);
 TF_CALL_float(REGISTER_SOFTMAX_MKL_SUPPORTED_KERNELS_TYPES);
+TF_CALL_bfloat16(REGISTER_SOFTMAX_MKL_SUPPORTED_KERNELS_TYPES);
 
 }  // namespace tensorflow
 
diff --git a/tensorflow/core/kernels/mkl_transpose_op.cc b/tensorflow/core/kernels/mkl_transpose_op.cc
index f6d8470..8598bf3 100644
--- a/tensorflow/core/kernels/mkl_transpose_op.cc
+++ b/tensorflow/core/kernels/mkl_transpose_op.cc
@@ -78,9 +78,8 @@
   mkl_comatcopy(
       'R', trans, in.dim_size(0), in.dim_size(1), alpha,
       reinterpret_cast<const MKL_Complex8*>(in.flat<complex64>().data()),
-      in.dim_size(1),
-      reinterpret_cast<MKL_Complex8*>(
-          const_cast<complex64*>(out->flat<complex64>().data())),
+      in.dim_size(1), reinterpret_cast<MKL_Complex8*>(const_cast<complex64*>(
+                          out->flat<complex64>().data())),
       in.dim_size(0));
   return Status::OK();
 }
@@ -92,9 +91,8 @@
   mkl_zomatcopy(
       'R', trans, in.dim_size(0), in.dim_size(1), alpha,
       reinterpret_cast<const MKL_Complex16*>(in.flat<complex128>().data()),
-      in.dim_size(1),
-      reinterpret_cast<MKL_Complex16*>(
-          const_cast<complex128*>(out->flat<complex128>().data())),
+      in.dim_size(1), reinterpret_cast<MKL_Complex16*>(const_cast<complex128*>(
+                          out->flat<complex128>().data())),
       in.dim_size(0));
   return Status::OK();
 }
@@ -145,8 +143,8 @@
     stream(stream::kind::eager).submit(net).wait();
     return Status::OK();
   } catch (mkldnn::error& e) {
-    string error_msg = "Status: " + std::to_string(e.status) +
-                       ", message: " + std::string(e.message) + ", in file " +
+    string error_msg = "Status: " + std::to_string(e.status) + ", message: " +
+                       std::string(e.message) + ", in file " +
                        std::string(__FILE__) + ":" + std::to_string(__LINE__);
     return errors::Aborted("Operation received an exception:", error_msg);
   }
@@ -184,10 +182,9 @@
       case DT_FLOAT:
         return MKLTransposeND<float>(ctx, in, out, perm);
         break;
-      // TODO(nhasabni): Enable this case when we turn on bfloat16 compilation.
-      // case DT_BFLOAT16:
-      //  return MKLTransposeND<bfloat16>(ctx, in, out, perm);
-      //  break;
+      case DT_BFLOAT16:
+        return MKLTransposeND<bfloat16>(ctx, in, out, perm);
+        break;
       // TODO(nhasabni): support other types such as INT8.
       default:
         break;
@@ -232,10 +229,9 @@
       case DT_FLOAT:
         return MKLTransposeND<float>(ctx, in, out, perm);
         break;
-      // TODO(nhasabni): Enable this case when we turn on bfloat16 compilation.
-      // case DT_BFLOAT16:
-      //  return MKLTransposeND<bfloat16>(ctx, in, out, perm);
-      //  break;
+      case DT_BFLOAT16:
+        return MKLTransposeND<bfloat16>(ctx, in, out, perm);
+        break;
       // TODO(nhasabni): support other types such as INT8.
       default:
         break;
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
index dde2db3..660fc9d 100644
--- a/tensorflow/core/ops/math_ops.cc
+++ b/tensorflow/core/ops/math_ops.cc
@@ -917,7 +917,7 @@
     .Output("product: T")
     .Attr("transpose_a: bool = false")
     .Attr("transpose_b: bool = false")
-    .Attr("T: {float, double, complex64, complex128}")
+    .Attr("T: {bfloat16, float, double, complex64, complex128}")
     .SetShapeFn(shape_inference::MatMulShape);
 #endif  // INTEL_MKL
 
diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc
index 4d10054..ae4ccdf 100644
--- a/tensorflow/core/ops/nn_ops.cc
+++ b/tensorflow/core/ops/nn_ops.cc
@@ -2200,7 +2200,7 @@
     .Input("mkl_logits: uint8")
     .Output("softmax: T")
     .Output("mkl_softmax: uint8")
-    .Attr("T: {half, float, double}")
+    .Attr("T: {bfloat16, half, float, double}")
     .SetShapeFn([](InferenceContext* c) {
       return shape_inference::UnchangedShapeWithRankAtLeast(c, 1);
     })