Add support for MklBatchMatMulv2
diff --git a/tensorflow/core/common_runtime/eager/mkl_eager_op_rewrite.cc b/tensorflow/core/common_runtime/eager/mkl_eager_op_rewrite.cc
index c487aa9..c862cdd 100644
--- a/tensorflow/core/common_runtime/eager/mkl_eager_op_rewrite.cc
+++ b/tensorflow/core/common_runtime/eager/mkl_eager_op_rewrite.cc
@@ -45,9 +45,10 @@
   static Status SetupNewOp(EagerOperation* orig_op, const string mkl_op_name,
                            std::unique_ptr<EagerOperation>* new_mkl_op);
 
-  // Creates new MKL op for MatMul
-  static Status CreateMklMatMul(EagerOperation* orig_op,
-                                std::unique_ptr<EagerOperation>* mkl_matmul_op);
+  // Generic rewrite that can be used for any mkl op that doesn't need
+  // special processing.
+  static Status CreateGenericMklOp(EagerOperation* orig_op,
+                                   std::unique_ptr<EagerOperation>* mkl_op);
 
   // Creates new MKL op for Conv2D, Conv2DBackpropInput and
   // Conv2DBackpropFilter.
@@ -75,12 +76,16 @@
 // Constructor
 MklEagerOpRewrite::MklEagerOpRewrite(string name, string file, string line)
     : EagerOpRewrite(name, file, line) {
+  mkl_eager_ops_.push_back({"BatchMatMulV2", AlwaysRewrite,
+                            CreateGenericMklOp});  // No need to check for V1 as
+                                                   // it has been obsoleted
+                                                   // already
   mkl_eager_ops_.push_back({"Conv2D", RewriteConv2D, CreateMklConv2DOp});
   mkl_eager_ops_.push_back(
       {"Conv2DBackpropInput", RewriteConv2D, CreateMklConv2DOp});
   mkl_eager_ops_.push_back(
       {"Conv2DBackpropFilter", RewriteConv2D, CreateMklConv2DOp});
-  mkl_eager_ops_.push_back({"MatMul", AlwaysRewrite, CreateMklMatMul});
+  mkl_eager_ops_.push_back({"MatMul", AlwaysRewrite, CreateGenericMklOp});
 }
 
 Status MklEagerOpRewrite::Run(
@@ -133,10 +138,10 @@
   return Status::OK();
 }
 
-Status MklEagerOpRewrite::CreateMklMatMul(
-    EagerOperation* orig_op, std::unique_ptr<EagerOperation>* mkl_matmul_op) {
+Status MklEagerOpRewrite::CreateGenericMklOp(
+    EagerOperation* orig_op, std::unique_ptr<EagerOperation>* mkl_op) {
   const string mkl_op_name = mkl_op_registry::GetMklOpName(orig_op->Name());
-  TF_CHECK_OK(SetupNewOp(orig_op, mkl_op_name, mkl_matmul_op));
+  TF_CHECK_OK(SetupNewOp(orig_op, mkl_op_name, mkl_op));
   return Status::OK();
 }
 
diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc
index c97cbd8..5ba4fc6 100644
--- a/tensorflow/core/graph/mkl_layout_pass.cc
+++ b/tensorflow/core/graph/mkl_layout_pass.cc
@@ -246,6 +246,7 @@
     csinfo_.avg_pool3d = "AvgPool3D";
     csinfo_.avg_pool3d_grad = "AvgPool3DGrad";
     csinfo_.batch_matmul = "BatchMatMul";
+    csinfo_.batch_matmul_v2 = "BatchMatMulV2";
     csinfo_.bias_add = "BiasAdd";
     csinfo_.bias_add_grad = "BiasAddGrad";
     csinfo_.concat = "Concat";
@@ -380,6 +381,9 @@
     rinfo_.push_back({csinfo_.batch_matmul,
                       mkl_op_registry::GetMklOpName(csinfo_.batch_matmul),
                       CopyAttrsAll, AlwaysRewrite, kRewriteForOpNameChange});
+    rinfo_.push_back({csinfo_.batch_matmul_v2,
+                      mkl_op_registry::GetMklOpName(csinfo_.batch_matmul_v2),
+                      CopyAttrsAll, AlwaysRewrite, kRewriteForOpNameChange});
     rinfo_.push_back(
         {csinfo_.concat, mkl_op_registry::GetMklOpName(csinfo_.concat),
          CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
@@ -868,6 +872,7 @@
     string avg_pool3d;
     string avg_pool3d_grad;
     string batch_matmul;
+    string batch_matmul_v2;
     string bias_add;
     string bias_add_grad;
     string concat;
diff --git a/tensorflow/core/graph/mkl_layout_pass_test.cc b/tensorflow/core/graph/mkl_layout_pass_test.cc
index df54c9f..b69a30e 100644
--- a/tensorflow/core/graph/mkl_layout_pass_test.cc
+++ b/tensorflow/core/graph/mkl_layout_pass_test.cc
@@ -4307,6 +4307,39 @@
             "H->K:7;I->K:8;J->L:1;K->L");
 }
 
+TEST_F(MklLayoutPassTest, MatMul_Positive) {
+  InitGraph(
+      "node { name: 'A' op: 'Input'}"
+      "node { name: 'B' op: 'Input'}"
+      "node { name: 'C' op: 'MatMul'"
+      " attr { key: 'T'      value { type: DT_FLOAT } }"
+      " input: ['A', 'B']}");
+  EXPECT_EQ(DoMklLayoutOptimizationPass(),
+            "A(Input);B(Input);C(_MklMatMul)|A->C;B->C:1");
+}
+
+TEST_F(MklLayoutPassTest, BatchMatMul_Positive) {
+  InitGraph(
+      "node { name: 'A' op: 'Input'}"
+      "node { name: 'B' op: 'Input'}"
+      "node { name: 'C' op: 'BatchMatMul'"
+      " attr { key: 'T'      value { type: DT_FLOAT } }"
+      " input: ['A', 'B']}");
+  EXPECT_EQ(DoMklLayoutOptimizationPass(),
+            "A(Input);B(Input);C(_MklBatchMatMul)|A->C;B->C:1");
+}
+
+TEST_F(MklLayoutPassTest, BatchMatMulV2_Positive) {
+  InitGraph(
+      "node { name: 'A' op: 'Input'}"
+      "node { name: 'B' op: 'Input'}"
+      "node { name: 'C' op: 'BatchMatMulV2'"
+      " attr { key: 'T'      value { type: DT_FLOAT } }"
+      " input: ['A', 'B']}");
+  EXPECT_EQ(DoMklLayoutOptimizationPass(),
+            "A(Input);B(Input);C(_MklBatchMatMulV2)|A->C;B->C:1");
+}
+
 static void BM_MklLayoutRewritePass(int iters, int op_nodes) {
   testing::StopTiming();
   string s;
diff --git a/tensorflow/core/kernels/mkl_batch_matmul_op.cc b/tensorflow/core/kernels/mkl_batch_matmul_op.cc
index 5a0401c..956ed97 100644
--- a/tensorflow/core/kernels/mkl_batch_matmul_op.cc
+++ b/tensorflow/core/kernels/mkl_batch_matmul_op.cc
@@ -29,7 +29,6 @@
 #include <vector>
 
 #include "mkl_cblas.h"
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
@@ -41,13 +40,17 @@
 #include "tensorflow/core/kernels/fill_functor.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/matmul_bcast.h"
 #include "tensorflow/core/util/mkl_util.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 
 namespace tensorflow {
 
 typedef Eigen::ThreadPoolDevice CPUDevice;
 
-template <typename Device, typename Scalar>
+//  The third parameter v2_bcast is set to true if we are using V2 otherwise
+//  we set it to false.
+template <typename Device, typename Scalar, bool v2_bcast>
 class BatchMatMulMkl : public OpKernel {
  public:
   explicit BatchMatMulMkl(OpKernelConstruction *context) : OpKernel(context) {
@@ -60,28 +63,54 @@
   void Compute(OpKernelContext *ctx) override {
     const Tensor &lhs = ctx->input(0);
     const Tensor &rhs = ctx->input(1);
-    OP_REQUIRES(ctx, lhs.dims() == rhs.dims(),
-                errors::InvalidArgument("lhs and rhs has different ndims: ",
-                                        lhs.shape().DebugString(), " vs. ",
-                                        rhs.shape().DebugString()));
-    const int ndims = lhs.dims();
-    OP_REQUIRES(
-        ctx, ndims >= 2,
-        errors::InvalidArgument("lhs and rhs ndims must be >= 2: ", ndims));
-    TensorShape out_shape;
-    for (int i = 0; i < ndims - 2; ++i) {
-      OP_REQUIRES(ctx, lhs.dim_size(i) == rhs.dim_size(i),
-                  errors::InvalidArgument(
-                      "lhs.dim(", i, ") and rhs.dim(", i,
-                      ") must be the same: ", lhs.shape().DebugString(), " vs ",
-                      rhs.shape().DebugString()));
-      out_shape.AddDim(lhs.dim_size(i));
+
+    if (!v2_bcast) {
+      // Using V1, so check to make sure lhs and rhs dimensions are correct and
+      // no braocasting is needed.
+      OP_REQUIRES(ctx, lhs.dims() == rhs.dims(),
+                  errors::InvalidArgument("lhs and rhs has different ndims: ",
+                                          lhs.shape().DebugString(), " vs. ",
+                                          rhs.shape().DebugString()));
+      const int ndims = lhs.dims();
+      OP_REQUIRES(
+          ctx, ndims >= 2,
+          errors::InvalidArgument("lhs and rhs ndims must be >= 2: ", ndims));
+      for (int i = 0; i < ndims - 2; ++i) {
+        OP_REQUIRES(ctx, lhs.dim_size(i) == rhs.dim_size(i),
+                    errors::InvalidArgument("lhs.dim(", i, ") and rhs.dim(", i,
+                                            ") must be the same: ",
+                                            lhs.shape().DebugString(), " vs ",
+                                            rhs.shape().DebugString()));
+      }
+    } else {
+      OP_REQUIRES(
+          ctx, lhs.dims() >= 2,
+          errors::InvalidArgument("In[0] ndims must be >= 2: ", lhs.dims()));
+      OP_REQUIRES(
+          ctx, rhs.dims() >= 2,
+          errors::InvalidArgument("In[1] ndims must be >= 2: ", rhs.dims()));
     }
-    auto batch_size = (ndims == 2) ? 1 : out_shape.num_elements();
-    auto lhs_rows = lhs.dim_size(ndims - 2);
-    auto lhs_cols = lhs.dim_size(ndims - 1);
-    auto rhs_rows = rhs.dim_size(ndims - 2);
-    auto rhs_cols = rhs.dim_size(ndims - 1);
+
+    // lhs and rhs can have different dimensions
+    const int ndims_lhs = lhs.dims();
+    const int ndims_rhs = rhs.dims();
+
+    // Get broadcast info
+    MatMulBCast bcast(lhs.shape().dim_sizes(), rhs.shape().dim_sizes());
+    OP_REQUIRES(
+        ctx, bcast.IsValid(),
+        errors::InvalidArgument(
+            "In[0] and In[1] must have compatible batch dimensions: ",
+            lhs.shape().DebugString(), " vs. ", rhs.shape().DebugString()));
+
+    TensorShape out_shape = bcast.output_batch_shape();
+    auto batch_size = bcast.output_batch_size();
+
+    auto lhs_rows = lhs.dim_size(ndims_lhs - 2);
+    auto lhs_cols = lhs.dim_size(ndims_lhs - 1);
+    auto rhs_rows = rhs.dim_size(ndims_rhs - 2);
+    auto rhs_cols = rhs.dim_size(ndims_rhs - 1);
+
     if (adj_x_) std::swap(lhs_rows, lhs_cols);
     if (adj_y_) std::swap(rhs_rows, rhs_cols);
     OP_REQUIRES(ctx, lhs_cols == rhs_rows,
@@ -89,8 +118,10 @@
                     "lhs mismatch rhs shape: ", lhs_cols, " vs. ", rhs_rows,
                     ": ", lhs.shape().DebugString(), " ",
                     rhs.shape().DebugString(), " ", adj_x_, " ", adj_y_));
+
     out_shape.AddDim(lhs_rows);
     out_shape.AddDim(rhs_cols);
+
     Tensor *out = nullptr;
     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out));
     if (out->NumElements() == 0) {
@@ -122,10 +153,24 @@
     a_array.reserve(batch_size);
     b_array.reserve(batch_size);
     c_array.reserve(batch_size);
-    for (int64 i = 0; i < batch_size; i++) {
-      a_array.push_back(&lhs_reshaped(i, 0, 0));
-      b_array.push_back(&rhs_reshaped(i, 0, 0));
-      c_array.push_back(&out_reshaped(i, 0, 0));
+
+    if (!bcast.IsBroadcastingRequired()) {
+      for (int64 i = 0; i < batch_size; i++) {
+        a_array.push_back(&lhs_reshaped(i, 0, 0));
+        b_array.push_back(&rhs_reshaped(i, 0, 0));
+        c_array.push_back(&out_reshaped(i, 0, 0));
+      }
+    } else {
+      // Broadcasting is needed, so get the mapping from flattened output batch
+      // indices to x's and y's flattened batch indices.
+      const std::vector<int64> &a_batch_indices = bcast.x_batch_indices();
+      const std::vector<int64> &b_batch_indices = bcast.y_batch_indices();
+
+      for (int64 i = 0; i < batch_size; i++) {
+        a_array.push_back(&lhs_reshaped(a_batch_indices[i], 0, 0));
+        b_array.push_back(&rhs_reshaped(b_batch_indices[i], 0, 0));
+        c_array.push_back(&out_reshaped(i, 0, 0));
+      }
     }
 
     MklCblasGemmBatch(CblasRowMajor, adj_x_, adj_y_, &m_array[0], &n_array[0],
@@ -226,13 +271,25 @@
                               .Device(DEVICE_CPU)                             \
                               .TypeConstraint<TYPE>("T")                      \
                               .Label(mkl_op_registry::kMklNameChangeOpLabel), \
-                          BatchMatMulMkl<CPUDevice, TYPE>)
+                          BatchMatMulMkl<CPUDevice, TYPE, false>)
+
+#define REGISTER_BATCH_MATMUL_MKL_V2(TYPE)                                    \
+  REGISTER_KERNEL_BUILDER(Name("_MklBatchMatMulV2")                           \
+                              .Device(DEVICE_CPU)                             \
+                              .TypeConstraint<TYPE>("T")                      \
+                              .Label(mkl_op_registry::kMklNameChangeOpLabel), \
+                          BatchMatMulMkl<CPUDevice, TYPE, true>)
 
 #ifdef ENABLE_MKL
 TF_CALL_float(REGISTER_BATCH_MATMUL_MKL);
 TF_CALL_double(REGISTER_BATCH_MATMUL_MKL);
 TF_CALL_complex64(REGISTER_BATCH_MATMUL_MKL);
 TF_CALL_complex128(REGISTER_BATCH_MATMUL_MKL);
+
+TF_CALL_float(REGISTER_BATCH_MATMUL_MKL_V2);
+TF_CALL_double(REGISTER_BATCH_MATMUL_MKL_V2);
+TF_CALL_complex64(REGISTER_BATCH_MATMUL_MKL_V2);
+TF_CALL_complex128(REGISTER_BATCH_MATMUL_MKL_V2);
 #endif  // ENABLE_MKL
 
 }  // end namespace tensorflow
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
index e681262..b9947cb 100644
--- a/tensorflow/core/ops/math_ops.cc
+++ b/tensorflow/core/ops/math_ops.cc
@@ -66,8 +66,8 @@
           } else if (shapes_and_types && shapes_and_types_i) {
             if (shapes_and_types_i->size() != shapes_and_types->size()) {
               return errors::InvalidArgument(
-                  "shapes_and_types[", i,
-                  "].size() == ", shapes_and_types_i->size(),
+                  "shapes_and_types[", i, "].size() == ",
+                  shapes_and_types_i->size(),
                   " != shapes_and_types[0].size() == ",
                   shapes_and_types->size());
             }
@@ -142,10 +142,21 @@
     .Input("x: T")
     .Input("y: T")
     .Output("output: T")
-    .Attr("T: {bfloat16, half, float, double, int32, complex64, complex128}")
+    .Attr(
+        "T: {bfloat16, half, float, double, int32, int64, complex64, "
+        "complex128}")
     .Attr("adj_x: bool = false")
     .Attr("adj_y: bool = false")
     .SetShapeFn(shape_inference::BatchMatMulShape);
+
+REGISTER_OP("_MklBatchMatMulV2")
+    .Input("x: T")
+    .Input("y: T")
+    .Output("output: T")
+    .Attr("T: {bfloat16, half, float, double, int32, complex64, complex128}")
+    .Attr("adj_x: bool = false")
+    .Attr("adj_y: bool = false")
+    .SetShapeFn(shape_inference::BatchMatMulV2Shape);
 #endif  // INTEL_MKL
 
 // --------------------------------------------------------------------------
@@ -1355,12 +1366,12 @@
   T limit = limit_t->scalar<T>()();
   T delta = delta_t->scalar<T>()();
   if (start > limit && delta > 0) {
-    return errors::InvalidArgument(
-        "Requires start <= limit when delta > 0: ", start, "/", limit);
+    return errors::InvalidArgument("Requires start <= limit when delta > 0: ",
+                                   start, "/", limit);
   }
   if (start < limit && delta < 0) {
-    return errors::InvalidArgument(
-        "Requires start >= limit when delta < 0: ", start, "/", limit);
+    return errors::InvalidArgument("Requires start >= limit when delta < 0: ",
+                                   start, "/", limit);
   }
   if (delta == 0) {
     return errors::InvalidArgument("Requires delta != 0");