Add support for MklBatchMatMulV2
diff --git a/tensorflow/core/common_runtime/eager/mkl_eager_op_rewrite.cc b/tensorflow/core/common_runtime/eager/mkl_eager_op_rewrite.cc
index c487aa9..c862cdd 100644
--- a/tensorflow/core/common_runtime/eager/mkl_eager_op_rewrite.cc
+++ b/tensorflow/core/common_runtime/eager/mkl_eager_op_rewrite.cc
@@ -45,9 +45,10 @@
static Status SetupNewOp(EagerOperation* orig_op, const string mkl_op_name,
std::unique_ptr<EagerOperation>* new_mkl_op);
- // Creates new MKL op for MatMul
- static Status CreateMklMatMul(EagerOperation* orig_op,
- std::unique_ptr<EagerOperation>* mkl_matmul_op);
+ // Generic rewrite that can be used for any mkl op that doesn't need
+ // special processing.
+ static Status CreateGenericMklOp(EagerOperation* orig_op,
+ std::unique_ptr<EagerOperation>* mkl_op);
// Creates new MKL op for Conv2D, Conv2DBackpropInput and
// Conv2DBackpropFilter.
@@ -75,12 +76,16 @@
// Constructor
MklEagerOpRewrite::MklEagerOpRewrite(string name, string file, string line)
: EagerOpRewrite(name, file, line) {
+ mkl_eager_ops_.push_back({"BatchMatMulV2", AlwaysRewrite,
+ CreateGenericMklOp}); // No need to check for V1 as
+ // it has been obsoleted
+ // already
mkl_eager_ops_.push_back({"Conv2D", RewriteConv2D, CreateMklConv2DOp});
mkl_eager_ops_.push_back(
{"Conv2DBackpropInput", RewriteConv2D, CreateMklConv2DOp});
mkl_eager_ops_.push_back(
{"Conv2DBackpropFilter", RewriteConv2D, CreateMklConv2DOp});
- mkl_eager_ops_.push_back({"MatMul", AlwaysRewrite, CreateMklMatMul});
+ mkl_eager_ops_.push_back({"MatMul", AlwaysRewrite, CreateGenericMklOp});
}
Status MklEagerOpRewrite::Run(
@@ -133,10 +138,10 @@
return Status::OK();
}
-Status MklEagerOpRewrite::CreateMklMatMul(
- EagerOperation* orig_op, std::unique_ptr<EagerOperation>* mkl_matmul_op) {
+Status MklEagerOpRewrite::CreateGenericMklOp(
+ EagerOperation* orig_op, std::unique_ptr<EagerOperation>* mkl_op) {
const string mkl_op_name = mkl_op_registry::GetMklOpName(orig_op->Name());
- TF_CHECK_OK(SetupNewOp(orig_op, mkl_op_name, mkl_matmul_op));
+ TF_CHECK_OK(SetupNewOp(orig_op, mkl_op_name, mkl_op));
return Status::OK();
}
diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc
index c97cbd8..5ba4fc6 100644
--- a/tensorflow/core/graph/mkl_layout_pass.cc
+++ b/tensorflow/core/graph/mkl_layout_pass.cc
@@ -246,6 +246,7 @@
csinfo_.avg_pool3d = "AvgPool3D";
csinfo_.avg_pool3d_grad = "AvgPool3DGrad";
csinfo_.batch_matmul = "BatchMatMul";
+ csinfo_.batch_matmul_v2 = "BatchMatMulV2";
csinfo_.bias_add = "BiasAdd";
csinfo_.bias_add_grad = "BiasAddGrad";
csinfo_.concat = "Concat";
@@ -380,6 +381,9 @@
rinfo_.push_back({csinfo_.batch_matmul,
mkl_op_registry::GetMklOpName(csinfo_.batch_matmul),
CopyAttrsAll, AlwaysRewrite, kRewriteForOpNameChange});
+ rinfo_.push_back({csinfo_.batch_matmul_v2,
+ mkl_op_registry::GetMklOpName(csinfo_.batch_matmul_v2),
+ CopyAttrsAll, AlwaysRewrite, kRewriteForOpNameChange});
rinfo_.push_back(
{csinfo_.concat, mkl_op_registry::GetMklOpName(csinfo_.concat),
CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
@@ -868,6 +872,7 @@
string avg_pool3d;
string avg_pool3d_grad;
string batch_matmul;
+ string batch_matmul_v2;
string bias_add;
string bias_add_grad;
string concat;
diff --git a/tensorflow/core/graph/mkl_layout_pass_test.cc b/tensorflow/core/graph/mkl_layout_pass_test.cc
index df54c9f..b69a30e 100644
--- a/tensorflow/core/graph/mkl_layout_pass_test.cc
+++ b/tensorflow/core/graph/mkl_layout_pass_test.cc
@@ -4307,6 +4307,39 @@
"H->K:7;I->K:8;J->L:1;K->L");
}
+TEST_F(MklLayoutPassTest, MatMul_Positive) {
+ InitGraph(
+ "node { name: 'A' op: 'Input'}"
+ "node { name: 'B' op: 'Input'}"
+ "node { name: 'C' op: 'MatMul'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'B']}");
+ EXPECT_EQ(DoMklLayoutOptimizationPass(),
+ "A(Input);B(Input);C(_MklMatMul)|A->C;B->C:1");
+}
+
+TEST_F(MklLayoutPassTest, BatchMatMul_Positive) {
+ InitGraph(
+ "node { name: 'A' op: 'Input'}"
+ "node { name: 'B' op: 'Input'}"
+ "node { name: 'C' op: 'BatchMatMul'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'B']}");
+ EXPECT_EQ(DoMklLayoutOptimizationPass(),
+ "A(Input);B(Input);C(_MklBatchMatMul)|A->C;B->C:1");
+}
+
+TEST_F(MklLayoutPassTest, BatchMatMulV2_Positive) {
+ InitGraph(
+ "node { name: 'A' op: 'Input'}"
+ "node { name: 'B' op: 'Input'}"
+ "node { name: 'C' op: 'BatchMatMulV2'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'B']}");
+ EXPECT_EQ(DoMklLayoutOptimizationPass(),
+ "A(Input);B(Input);C(_MklBatchMatMulV2)|A->C;B->C:1");
+}
+
static void BM_MklLayoutRewritePass(int iters, int op_nodes) {
testing::StopTiming();
string s;
diff --git a/tensorflow/core/kernels/mkl_batch_matmul_op.cc b/tensorflow/core/kernels/mkl_batch_matmul_op.cc
index 5a0401c..956ed97 100644
--- a/tensorflow/core/kernels/mkl_batch_matmul_op.cc
+++ b/tensorflow/core/kernels/mkl_batch_matmul_op.cc
@@ -29,7 +29,6 @@
#include <vector>
#include "mkl_cblas.h"
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
@@ -41,13 +40,17 @@
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/matmul_bcast.h"
#include "tensorflow/core/util/mkl_util.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
-template <typename Device, typename Scalar>
+// The third template parameter v2_bcast is true for BatchMatMulV2, whose batch
+// dimensions may be broadcast, and false for V1, which requires them to match.
+template <typename Device, typename Scalar, bool v2_bcast>
class BatchMatMulMkl : public OpKernel {
public:
explicit BatchMatMulMkl(OpKernelConstruction *context) : OpKernel(context) {
@@ -60,28 +63,54 @@
void Compute(OpKernelContext *ctx) override {
const Tensor &lhs = ctx->input(0);
const Tensor &rhs = ctx->input(1);
- OP_REQUIRES(ctx, lhs.dims() == rhs.dims(),
- errors::InvalidArgument("lhs and rhs has different ndims: ",
- lhs.shape().DebugString(), " vs. ",
- rhs.shape().DebugString()));
- const int ndims = lhs.dims();
- OP_REQUIRES(
- ctx, ndims >= 2,
- errors::InvalidArgument("lhs and rhs ndims must be >= 2: ", ndims));
- TensorShape out_shape;
- for (int i = 0; i < ndims - 2; ++i) {
- OP_REQUIRES(ctx, lhs.dim_size(i) == rhs.dim_size(i),
- errors::InvalidArgument(
- "lhs.dim(", i, ") and rhs.dim(", i,
- ") must be the same: ", lhs.shape().DebugString(), " vs ",
- rhs.shape().DebugString()));
- out_shape.AddDim(lhs.dim_size(i));
+
+ if (!v2_bcast) {
+ // Using V1, so check to make sure lhs and rhs dimensions are correct and
+ // no broadcasting is needed.
+ OP_REQUIRES(ctx, lhs.dims() == rhs.dims(),
+ errors::InvalidArgument("lhs and rhs has different ndims: ",
+ lhs.shape().DebugString(), " vs. ",
+ rhs.shape().DebugString()));
+ const int ndims = lhs.dims();
+ OP_REQUIRES(
+ ctx, ndims >= 2,
+ errors::InvalidArgument("lhs and rhs ndims must be >= 2: ", ndims));
+ for (int i = 0; i < ndims - 2; ++i) {
+ OP_REQUIRES(ctx, lhs.dim_size(i) == rhs.dim_size(i),
+ errors::InvalidArgument("lhs.dim(", i, ") and rhs.dim(", i,
+ ") must be the same: ",
+ lhs.shape().DebugString(), " vs ",
+ rhs.shape().DebugString()));
+ }
+ } else {
+ OP_REQUIRES(
+ ctx, lhs.dims() >= 2,
+ errors::InvalidArgument("In[0] ndims must be >= 2: ", lhs.dims()));
+ OP_REQUIRES(
+ ctx, rhs.dims() >= 2,
+ errors::InvalidArgument("In[1] ndims must be >= 2: ", rhs.dims()));
}
- auto batch_size = (ndims == 2) ? 1 : out_shape.num_elements();
- auto lhs_rows = lhs.dim_size(ndims - 2);
- auto lhs_cols = lhs.dim_size(ndims - 1);
- auto rhs_rows = rhs.dim_size(ndims - 2);
- auto rhs_cols = rhs.dim_size(ndims - 1);
+
+ // With V2, lhs and rhs can have different ranks, so track each separately.
+ const int ndims_lhs = lhs.dims();
+ const int ndims_rhs = rhs.dims();
+
+ // Get broadcast info
+ MatMulBCast bcast(lhs.shape().dim_sizes(), rhs.shape().dim_sizes());
+ OP_REQUIRES(
+ ctx, bcast.IsValid(),
+ errors::InvalidArgument(
+ "In[0] and In[1] must have compatible batch dimensions: ",
+ lhs.shape().DebugString(), " vs. ", rhs.shape().DebugString()));
+
+ TensorShape out_shape = bcast.output_batch_shape();
+ auto batch_size = bcast.output_batch_size();
+
+ auto lhs_rows = lhs.dim_size(ndims_lhs - 2);
+ auto lhs_cols = lhs.dim_size(ndims_lhs - 1);
+ auto rhs_rows = rhs.dim_size(ndims_rhs - 2);
+ auto rhs_cols = rhs.dim_size(ndims_rhs - 1);
+
if (adj_x_) std::swap(lhs_rows, lhs_cols);
if (adj_y_) std::swap(rhs_rows, rhs_cols);
OP_REQUIRES(ctx, lhs_cols == rhs_rows,
@@ -89,8 +118,10 @@
"lhs mismatch rhs shape: ", lhs_cols, " vs. ", rhs_rows,
": ", lhs.shape().DebugString(), " ",
rhs.shape().DebugString(), " ", adj_x_, " ", adj_y_));
+
out_shape.AddDim(lhs_rows);
out_shape.AddDim(rhs_cols);
+
Tensor *out = nullptr;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out));
if (out->NumElements() == 0) {
@@ -122,10 +153,24 @@
a_array.reserve(batch_size);
b_array.reserve(batch_size);
c_array.reserve(batch_size);
- for (int64 i = 0; i < batch_size; i++) {
- a_array.push_back(&lhs_reshaped(i, 0, 0));
- b_array.push_back(&rhs_reshaped(i, 0, 0));
- c_array.push_back(&out_reshaped(i, 0, 0));
+
+ if (!bcast.IsBroadcastingRequired()) {
+ for (int64 i = 0; i < batch_size; i++) {
+ a_array.push_back(&lhs_reshaped(i, 0, 0));
+ b_array.push_back(&rhs_reshaped(i, 0, 0));
+ c_array.push_back(&out_reshaped(i, 0, 0));
+ }
+ } else {
+ // Broadcasting is needed, so get the mapping from flattened output batch
+ // indices to x's and y's flattened batch indices.
+ const std::vector<int64> &a_batch_indices = bcast.x_batch_indices();
+ const std::vector<int64> &b_batch_indices = bcast.y_batch_indices();
+
+ for (int64 i = 0; i < batch_size; i++) {
+ a_array.push_back(&lhs_reshaped(a_batch_indices[i], 0, 0));
+ b_array.push_back(&rhs_reshaped(b_batch_indices[i], 0, 0));
+ c_array.push_back(&out_reshaped(i, 0, 0));
+ }
}
MklCblasGemmBatch(CblasRowMajor, adj_x_, adj_y_, &m_array[0], &n_array[0],
@@ -226,13 +271,25 @@
.Device(DEVICE_CPU) \
.TypeConstraint<TYPE>("T") \
.Label(mkl_op_registry::kMklNameChangeOpLabel), \
- BatchMatMulMkl<CPUDevice, TYPE>)
+ BatchMatMulMkl<CPUDevice, TYPE, false>)
+
+#define REGISTER_BATCH_MATMUL_MKL_V2(TYPE) \
+ REGISTER_KERNEL_BUILDER(Name("_MklBatchMatMulV2") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<TYPE>("T") \
+ .Label(mkl_op_registry::kMklNameChangeOpLabel), \
+ BatchMatMulMkl<CPUDevice, TYPE, true>)
#ifdef ENABLE_MKL
TF_CALL_float(REGISTER_BATCH_MATMUL_MKL);
TF_CALL_double(REGISTER_BATCH_MATMUL_MKL);
TF_CALL_complex64(REGISTER_BATCH_MATMUL_MKL);
TF_CALL_complex128(REGISTER_BATCH_MATMUL_MKL);
+
+TF_CALL_float(REGISTER_BATCH_MATMUL_MKL_V2);
+TF_CALL_double(REGISTER_BATCH_MATMUL_MKL_V2);
+TF_CALL_complex64(REGISTER_BATCH_MATMUL_MKL_V2);
+TF_CALL_complex128(REGISTER_BATCH_MATMUL_MKL_V2);
#endif // ENABLE_MKL
} // end namespace tensorflow
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
index e681262..b9947cb 100644
--- a/tensorflow/core/ops/math_ops.cc
+++ b/tensorflow/core/ops/math_ops.cc
@@ -66,8 +66,8 @@
} else if (shapes_and_types && shapes_and_types_i) {
if (shapes_and_types_i->size() != shapes_and_types->size()) {
return errors::InvalidArgument(
- "shapes_and_types[", i,
- "].size() == ", shapes_and_types_i->size(),
+ "shapes_and_types[", i, "].size() == ",
+ shapes_and_types_i->size(),
" != shapes_and_types[0].size() == ",
shapes_and_types->size());
}
@@ -142,10 +142,21 @@
.Input("x: T")
.Input("y: T")
.Output("output: T")
- .Attr("T: {bfloat16, half, float, double, int32, complex64, complex128}")
+ .Attr(
+ "T: {bfloat16, half, float, double, int32, int64, complex64, "
+ "complex128}")
.Attr("adj_x: bool = false")
.Attr("adj_y: bool = false")
.SetShapeFn(shape_inference::BatchMatMulShape);
+
+REGISTER_OP("_MklBatchMatMulV2")
+ .Input("x: T")
+ .Input("y: T")
+ .Output("output: T")
+ .Attr("T: {bfloat16, half, float, double, int32, complex64, complex128}")
+ .Attr("adj_x: bool = false")
+ .Attr("adj_y: bool = false")
+ .SetShapeFn(shape_inference::BatchMatMulV2Shape);
#endif // INTEL_MKL
// --------------------------------------------------------------------------
@@ -1355,12 +1366,12 @@
T limit = limit_t->scalar<T>()();
T delta = delta_t->scalar<T>()();
if (start > limit && delta > 0) {
- return errors::InvalidArgument(
- "Requires start <= limit when delta > 0: ", start, "/", limit);
+ return errors::InvalidArgument("Requires start <= limit when delta > 0: ",
+ start, "/", limit);
}
if (start < limit && delta < 0) {
- return errors::InvalidArgument(
- "Requires start >= limit when delta < 0: ", start, "/", limit);
+ return errors::InvalidArgument("Requires start >= limit when delta < 0: ",
+ start, "/", limit);
}
if (delta == 0) {
return errors::InvalidArgument("Requires delta != 0");