/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
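
// This pass rewrites supported TensorFlow ops into their Intel MKL
// counterparts during eager execution. It is compiled only when TensorFlow is
// built with Intel MKL support.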
#ifdef INTEL_MKL
#include "tensorflow/core/common_runtime/eager/eager_op_rewrite_registry.h"
#include "tensorflow/core/graph/mkl_graph_util.h"
#include "tensorflow/core/graph/mkl_layout_pass.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/util/mkl_util.h"
#include "tensorflow/core/util/util.h"
namespace tensorflow {
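
// Eager op rewrite that replaces supported ops (MatMul, BatchMatMulV2, Conv2D
// and its gradients) with their MKL counterparts before execution.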
class MklEagerOpRewrite : public EagerOpRewrite {
public:
MklEagerOpRewrite(string name, string file, string line);
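  // One rewrite registration: the op name to match, a predicate that decides
  // whether the op should be rewritten, and a factory that creates the
  // replacement MKL op.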
  struct MklEagerOp {
    string op_name;
    std::function<bool(EagerOperation*)> RewriteRule;
    std::function<Status(EagerOperation*, std::unique_ptr<EagerOperation>*)>
        CreateMklOp;
  };
private:
// TODO(intel-tf): refactor with unordered_map;
// especially when adding more ops/rewrite rules in future.
std::vector<MklEagerOp> mkl_eager_ops_;
// The entry point to execute the op rewrite.
Status Run(EagerOperation* orig_op,
std::unique_ptr<tensorflow::EagerOperation>* out_op);
// Initializes the new op and sets up its inputs and attributes
static Status SetupNewOp(EagerOperation* orig_op, const string mkl_op_name,
std::unique_ptr<EagerOperation>* new_mkl_op);
// Generic rewrite that can be used for any mkl op that doesn't need
// special processing.
static Status CreateGenericMklOp(EagerOperation* orig_op,
std::unique_ptr<EagerOperation>* mkl_op);
// Creates new MKL op for Conv2D, Conv2DBackpropInput and
// Conv2DBackpropFilter.
static Status CreateMklConv2DOp(
EagerOperation* orig_op, std::unique_ptr<EagerOperation>* mkl_conv2d_op);
// Rewrite rule for Conv2D, Conv2DBackpropInput and Conv2DBackpropFilter.
static bool RewriteConv2D(EagerOperation* op);
// Calls op-specific rewrite function to create new MKL op.
Status RewriteToMklOp(EagerOperation* orig_op,
std::unique_ptr<EagerOperation>* mkl_op,
const int op_idx);
// Checks whether we can rewrite the op to MKL one or not.
bool ShouldRewriteOp(EagerOperation* op, int* op_idx);
// Default rewrite rule to be used when rewrite should happen without any
// restriction.
static bool AlwaysRewrite(EagerOperation* op) { return true; }
};
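
// Register the pass so it is invoked for every eager op before execution.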
REGISTER_REWRITE(EagerOpRewriteRegistry::PRE_EXECUTION, MklEagerOpRewrite);
// Constructor
MklEagerOpRewrite::MklEagerOpRewrite(string name, string file, string line)
: EagerOpRewrite(name, file, line) {
  // No need to check for BatchMatMul (V1), as it has already been obsoleted.
  mkl_eager_ops_.push_back(
      {"BatchMatMulV2", AlwaysRewrite, CreateGenericMklOp});
mkl_eager_ops_.push_back({"Conv2D", RewriteConv2D, CreateMklConv2DOp});
mkl_eager_ops_.push_back(
{"Conv2DBackpropInput", RewriteConv2D, CreateMklConv2DOp});
mkl_eager_ops_.push_back(
{"Conv2DBackpropFilter", RewriteConv2D, CreateMklConv2DOp});
mkl_eager_ops_.push_back({"MatMul", AlwaysRewrite, CreateGenericMklOp});
}
Status MklEagerOpRewrite::Run(
EagerOperation* orig_op,
std::unique_ptr<tensorflow::EagerOperation>* out_op) {
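  // If a registered rewrite rule matches this op, build the replacement MKL op
  // into out_op; otherwise leave out_op unset so the original op executes.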
int found_op_idx = -1;
if (ShouldRewriteOp(orig_op, &found_op_idx)) {
    TF_RETURN_IF_ERROR(RewriteToMklOp(orig_op, out_op, found_op_idx));
}
return Status::OK();
}
Status MklEagerOpRewrite::SetupNewOp(
EagerOperation* orig_op, const string mkl_op_name,
std::unique_ptr<EagerOperation>* new_mkl_op) {
const tensorflow::AttrTypeMap* types;
bool is_function = false;
TF_RETURN_IF_ERROR(
tensorflow::AttrTypeMapForOp(mkl_op_name.c_str(), &types, &is_function));
EagerContext* ctx = orig_op->EagerContext();
new_mkl_op->reset(new tensorflow::EagerOperation(ctx, mkl_op_name.c_str(),
is_function, types));
int num_inputs = orig_op->Inputs().size();
// Add all inputs to the new op.
for (int i = 0; i < num_inputs; ++i) {
(*new_mkl_op)->AddInput(orig_op->Inputs()[i]);
}
// Copy all attributes to the new op.
const NodeDef& orig_ndef = orig_op->MutableAttrs()->BuildNodeDef();
AttrSlice attr_list(orig_ndef);
for (const auto& attr : attr_list) {
(*new_mkl_op)->MutableAttrs()->Set(attr.first, attr.second);
}
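  // Label the new op so kernel lookup selects the MKL kernel registered under
  // kMklNameChangeOpLabel.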
(*new_mkl_op)
->MutableAttrs()
->Set("_kernel", mkl_op_registry::kMklNameChangeOpLabel);
if (orig_op->Device() != nullptr) {
(*new_mkl_op)->SetDevice(orig_op->Device());
} else {
string device_name =
DeviceNameUtils::ParsedNameToString(orig_op->GetDeviceName());
(*new_mkl_op)->SetDeviceName(device_name.c_str());
}
return Status::OK();
}
Status MklEagerOpRewrite::CreateGenericMklOp(
EagerOperation* orig_op, std::unique_ptr<EagerOperation>* mkl_op) {
const string mkl_op_name = mkl_op_registry::GetMklOpName(orig_op->Name());
  return SetupNewOp(orig_op, mkl_op_name, mkl_op);
}
Status MklEagerOpRewrite::CreateMklConv2DOp(
EagerOperation* orig_op, std::unique_ptr<EagerOperation>* mkl_conv2d_op) {
const string mkl_op_name =
mkl_op_registry::GetMklEagerOpName(orig_op->Name());
  return SetupNewOp(orig_op, mkl_op_name, mkl_conv2d_op);
}
bool MklEagerOpRewrite::ShouldRewriteOp(EagerOperation* op, int* op_idx) {
// Don't rewrite the op if MKL use is disabled at runtime.
if (DisableMKL()) {
return false;
}
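  // The op must carry a "T" attribute; its data type is used below to look for
  // a registered MKL kernel.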
DataType data_type;
if (op->Attrs().Get("T", &data_type) != Status::OK()) {
return false;
}
// Check if we have registered MKL kernel for this op.
if (!mkl_op_registry::IsMklNameChangeOp(
mkl_op_registry::GetMklEagerOpName(op->Name()), data_type) &&
!mkl_op_registry::IsMklNameChangeOp(
mkl_op_registry::GetMklOpName(op->Name()), data_type)) {
return false;
}
*op_idx = -1;
// Find and call the op's rewrite rule that determines whether we need to
// rewrite this op or not.
for (auto it = mkl_eager_ops_.begin(); it != mkl_eager_ops_.end(); ++it) {
    if (it->op_name == op->Name() && it->RewriteRule(op)) {
*op_idx = it - mkl_eager_ops_.begin();
return true;
}
}
return false;
}
Status MklEagerOpRewrite::RewriteToMklOp(
EagerOperation* orig_op, std::unique_ptr<EagerOperation>* mkl_op,
const int op_idx) {
  return mkl_eager_ops_[op_idx].CreateMklOp(orig_op, mkl_op);
}
bool MklEagerOpRewrite::RewriteConv2D(EagerOperation* op) {
const NodeDef& ndef = op->MutableAttrs()->BuildNodeDef();
string padding;
TF_CHECK_OK(GetNodeAttr(ndef, "padding", &padding));
// Right now MKL Conv2D does not support explicit padding.
return (padding != "EXPLICIT");
}
} // namespace tensorflow
#endif // INTEL_MKL