/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if defined(INTEL_MKL) && defined(ENABLE_MKL)
#include "tensorflow/core/common_runtime/mkl_layout_pass.h"
#include <algorithm>
#include <vector>
#include "absl/strings/str_join.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/mkl_graph_util.h"
#include "tensorflow/core/graph/testlib.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/util/mkl_util.h"
namespace tensorflow {
// NOTE: Unit tests in this file rely on a topologically sorted graph for
// printing. But since sibling nodes of a node in the topologically sorted
// graph can be printed in different orders, tests may fail if the order in
// which sibling nodes are visited is changed.
namespace {
const char kCPUDevice[] = "/job:a/replica:0/task:0/device:CPU:0";
const char kGPUDevice[] = "/job:a/replica:0/task:0/device:GPU:0";
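// Builds a Graph from the text-format GraphDef in `s` and assigns every node
// to `device`.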
static void InitGraph(const string& s, Graph* graph,
const string& device = kCPUDevice) {
GraphDef graph_def;
auto parser = protobuf::TextFormat::Parser();
CHECK(parser.MergeFromString(s, &graph_def)) << s;
GraphConstructorOptions opts;
TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, graph));
for (Node* node : graph->nodes()) {
node->set_assigned_device_name(device);
}
}
class MklLayoutPassTest : public ::testing::Test {
public:
MklLayoutPassTest() : graph_(OpRegistry::Global()) {}
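  // Returns the node named `name`; LOG(FATAL)s if no such node exists.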
Node* FindNode(const string& name) {
for (Node* node : graph_.nodes()) {
if (node->name() == name) return node;
}
    LOG(FATAL) << "Node not found: " << name;
}
void InitGraph(const string& s, const string& device = kCPUDevice) {
::tensorflow::InitGraph(s, &graph_, device);
original_ = CanonicalGraphString(&graph_);
}
static bool IncludeNode(const Node* n) { return n->IsOp(); }
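  // Returns a printable id for output `index` of node `n`: "name" for data
  // output 0, "name:control" for the control slot, and "name:<index>"
  // otherwise. These ids are the building blocks of the canonical graph
  // string below.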
static string EdgeId(const Node* n, int index) {
if (index == 0) {
return n->name();
} else if (index == Graph::kControlSlot) {
return strings::StrCat(n->name(), ":control");
} else {
return strings::StrCat(n->name(), ":", index);
}
}
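  // Renders the graph as "<sorted nodes>|<sorted edges>", e.g.
  // "A(Input);B(Zeta)|A->B", so that two graphs can be compared as plain
  // strings.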
string CanonicalGraphString(Graph* g) {
std::vector<string> nodes;
std::vector<string> edges;
for (const Node* n : g->nodes()) {
if (IncludeNode(n)) {
nodes.push_back(strings::StrCat(n->name(), "(", n->type_string(), ")"));
}
}
for (const Edge* e : g->edges()) {
if (IncludeNode(e->src()) && IncludeNode(e->dst())) {
edges.push_back(strings::StrCat(EdgeId(e->src(), e->src_output()), "->",
EdgeId(e->dst(), e->dst_input())));
}
}
// Canonicalize
std::sort(nodes.begin(), nodes.end());
std::sort(edges.begin(), edges.end());
return strings::StrCat(absl::StrJoin(nodes, ";"), "|",
absl::StrJoin(edges, ";"));
}
string DoMklLayoutOptimizationPass() {
string before = CanonicalGraphString(&graph_);
LOG(ERROR) << "Before MKL layout rewrite pass: " << before;
std::unique_ptr<Graph>* ug = new std::unique_ptr<Graph>(&graph_);
RunMklLayoutRewritePass(ug);
string result = CanonicalGraphString(&graph_);
LOG(ERROR) << "After MKL layout rewrite pass: " << result;
return result;
}
  // Returns the attribute value `attr` from the first node whose op type
  // matches `node_name`, or a default-constructed value if no node matches.
template <typename T>
T DoMklLayoutOptimizationPassGetAttrVal(const string& attr,
const string& node_name) {
DoMklLayoutOptimizationPass();
    T attr_val = T();  // Default-constructed fallback if no node matches.
for (const Node* n : graph_.nodes()) {
if (IncludeNode(n) && n->type_string() == node_name) {
TF_CHECK_OK(GetNodeAttr(n->def(), attr, &attr_val));
return attr_val;
}
}
return attr_val;
}
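  // Hypothetical usage of the helper above: read 'data_format' from the
  // first _MklConv2D node left in the graph after the pass:
  //   auto fmt = DoMklLayoutOptimizationPassGetAttrVal<string>(
  //       "data_format", "_MklConv2D");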
const string& OriginalGraph() const { return original_; }
Graph graph_;
string original_;
};
// TODO(nhasabni): remove these two ops later once all of the file is modified
// to use new type-specific ops.
REGISTER_OP("Input").Output("o: float").SetIsStateful();
REGISTER_OP("InputList").Output("o: N * float").Attr("N: int").SetIsStateful();
REGISTER_OP("Output2").Input("i: float").Input("i1: float").SetIsStateful();
REGISTER_OP("Float32Input").Output("o: float").SetIsStateful();
REGISTER_OP("Float32InputList")
.Output("o: N * float")
.Attr("N: int")
.SetIsStateful();
REGISTER_OP("HalfInput").Output("o: half").SetIsStateful();
REGISTER_OP("Int32Input").Output("o: int32").SetIsStateful();
REGISTER_OP("DoubleInput").Output("o: double").SetIsStateful();
REGISTER_OP("QuantizedInput").Output("o: quint8").SetIsStateful();
REGISTER_OP("_MklInput").Output("o: uint8").SetIsStateful();
REGISTER_OP("_MklInput2")
.Output("o: uint8")
.Output("o1: uint8")
.SetIsStateful();
REGISTER_OP("QuantizedUnsignedInt8Input").Output("o: quint8").SetIsStateful();
REGISTER_OP("QuantizedSignedInt8Input").Output("o: qint8").SetIsStateful();
REGISTER_OP("QuantizedSignedInt32Input").Output("o: qint32").SetIsStateful();
REGISTER_OP("Float32Output2")
.Input("i: float")
.Input("i1: float")
.SetIsStateful();
REGISTER_OP("Output").Input("i: float").SetIsStateful();
REGISTER_OP("QInt8Input").Output("o: qint8").SetIsStateful();
REGISTER_OP("QUInt8Input").Output("o: quint8").SetIsStateful();
REGISTER_OP("QInt32Input").Output("o: qint32").SetIsStateful();
#ifdef ENABLE_INTEL_MKL_BFLOAT16
REGISTER_OP("BFloat16Input").Output("o: bfloat16").SetIsStateful();
REGISTER_OP("BFloat16InputList")
.Output("o: N * bfloat16")
.Attr("N: int")
.SetIsStateful();
REGISTER_OP("BFloat16Output2")
.Input("i: bfloat16")
.Input("i1: bfloat16")
.SetIsStateful();
#endif // ENABLE_INTEL_MKL_BFLOAT16
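// NOTE: The REGISTER_TEST_ALL_TYPES / REGISTER_TEST_FLOAT32 helpers used
// throughout this file are not defined in this excerpt. The sketch below is
// an assumption, consistent with how they are invoked: each helper expands
// the locally #define'd REGISTER_TEST macro once per supported element type,
// suffixing the test name with the DT_* type and substituting the matching
// *Input op. The exact definitions in the original file may differ.
#ifndef REGISTER_TEST_ALL_TYPES
#define REGISTER_TEST_FLOAT32(TEST) REGISTER_TEST(TEST, DT_FLOAT, Float32Input)
#ifdef ENABLE_INTEL_MKL_BFLOAT16
#define REGISTER_TEST_BFLOAT16(TEST) \
  REGISTER_TEST(TEST, DT_BFLOAT16, BFloat16Input)
#define REGISTER_TEST_ALL_TYPES(TEST) \
  REGISTER_TEST_FLOAT32(TEST);        \
  REGISTER_TEST_BFLOAT16(TEST)
#else
#define REGISTER_TEST_ALL_TYPES(TEST) REGISTER_TEST_FLOAT32(TEST)
#endif  // ENABLE_INTEL_MKL_BFLOAT16
#endif  // REGISTER_TEST_ALL_TYPES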
/////////////////////////////////////////////////////////////////////
// Unit tests related to node merge optimization
/////////////////////////////////////////////////////////////////////
// clang-format off
TEST_F(MklLayoutPassTest, Basic) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }"
"node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Zeta);D(Zeta)|"
"A->C;A->D;B->C:1;B->D:1");
}
// Test set 1: Conv2D + BiasAdd
// C=Conv2D(A,B); E=BiasAdd(C,D); Z=Zeta(E,Y)
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); \
InitGraph("node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: '" #INPUT "'}" \
"node { name: 'C' op: 'Conv2D'" \
" attr { key: 'T' value { type:" #T " } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'use_cudnn_on_gpu' value { b: false } }" \
" attr { key: 'strides' value { list: {i: 1, i:1, " \
"i:1, i:1} } }" \
" attr { key: 'padding' value { s: 'SAME' } }" \
" attr { key: 'dilations' value { list: {i: 1, i:1, " \
"i:1, i:1} } }" \
" input: ['A', 'B']}" \
"node { name: 'D' op: '" #INPUT "'}" \
"node { name: 'E' op: 'BiasAdd'" \
" attr { key: 'T' value { type:" #T " } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" input: ['C', 'D'] }" \
"node { name: 'Y' op: '" #INPUT "'}" \
"node { name: 'Z' op: 'Zeta'" \
" attr {key: 'T' value { type:" #T " } }" \
" input: ['E', 'Y']}"); \
\
EXPECT_EQ( \
DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(" #INPUT ");D(" #INPUT ");DMT/_0(Const);" \
"DMT/_1(Const);DMT/_2(Const);E(_MklConv2DWithBias);Y(" #INPUT ");" \
"Z(Zeta)|A->E;A:control->DMT/_0:control;A:control->DMT/_1:control;"\
"A:control->DMT/_2:control;B->E:1;D->E:2;DMT/_0->E:3;DMT/_1->E:4;" \
"DMT/_2->E:5;E->Z;Y->Z:1"); \
}
REGISTER_TEST_ALL_TYPES(NodeMerge_Conv2DWithBias_Positive);
#undef REGISTER_TEST
// Graph contains only Conv2D, no BiasAdd.
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph("node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: '" #INPUT "'}" \
"node { name: 'C' op: 'Conv2D'" \
" attr { key: 'T' value { type:" #T " } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'use_cudnn_on_gpu' value { b: false } }" \
" attr { key: 'strides' value { list: {i: 1, i:1, " \
"i:1, i:1} } }" \
" attr { key: 'padding' value { s: 'SAME' } }" \
" attr { key: 'dilations' value { list: {i: 1, i:1, " \
"i:1, i:1} } }" \
" input: ['A', 'B']}"); \
EXPECT_EQ( \
DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(" #INPUT ");C(_MklConv2D);DMT/_0(Const);" \
"DMT/_1(Const)|A->C;A:control->DMT/_0:control;A:control->" \
"DMT/_1:control;B->C:1;DMT/_0->C:2;DMT/_1->C:3"); \
}
REGISTER_TEST_ALL_TYPES(NodeMerge_Conv2DWithBias_Negative_NoAddBias);
#undef REGISTER_TEST
// Conv2D output does not go to BiasAdd.
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph("node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: '" #INPUT "'}" \
"node { name: 'C' op: 'Conv2D'" \
" attr { key: 'T' value { type:" #T "} }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'use_cudnn_on_gpu' value { b: false } }" \
" attr { key: 'strides' value { list: {i: 1, i:1, " \
"i:1, i:1} } }" \
" attr { key: 'padding' value { s: 'SAME' } }" \
" attr { key: 'dilations' value { list: {i: 1, i:1, " \
"i:1, i:1} } }" \
" input: ['A', 'B']}" \
"node { name: 'D' op: '" #INPUT "'}" \
"node { name: 'E' op: '" #INPUT "'}" \
"node { name: 'F' op: 'BiasAdd'" \
" attr { key: 'T' value { type:" #T "} }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" input: ['D', 'E'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(" #INPUT ");C(_MklConv2D);D(" #INPUT ");" \
"DMT/_0(Const);DMT/_1(Const);E(" #INPUT ");F(BiasAdd)|A->C;" \
"A:control->DMT/_0:control;A:control->DMT/_1:control;B->C:1;" \
"D->F;DMT/_0->C:2;DMT/_1->C:3;E->F:1"); \
}
REGISTER_TEST_ALL_TYPES(NodeMerge_Conv2DWithBias_Negative_Dataflow1);
#undef REGISTER_TEST
// Conv2D has two outgoing edges: one to BiasAdd and one to another node
// (Zeta). Merge should not be done in such a case.
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph("node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: '" #INPUT "'}" \
"node { name: 'C' op: 'Conv2D'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'use_cudnn_on_gpu' value { b: false } }" \
" attr { key: 'strides' value { list: {i: 1, i:1, " \
"i:1, i:1} } }" \
" attr { key: 'padding' value { s: 'SAME' } }" \
" attr { key: 'dilations' value { list: {i: 1, i:1, " \
"i:1, i:1} } }" \
" input: ['A', 'B']}" \
"node { name: 'D' op: '" #INPUT "'}" \
"node { name: 'E' op: '" #INPUT "'}" \
"node { name: 'F' op: 'BiasAdd'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" input: ['D', 'E'] }" \
"node { name: 'G' op: 'Zeta'" \
" attr { key: 'T' value { type: " #T " } }" \
" input: ['C', 'E'] }"); \
EXPECT_EQ( \
DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(" #INPUT ");C(_MklConv2D);D(" #INPUT ");" \
"DMT/_0(Const);DMT/_1(Const);E(" #INPUT ");F(BiasAdd);G(Zeta)|" \
"A->C;A:control->DMT/_0:control;A:control->DMT/_1:control;B->C:1;" \
"C->G;D->F;DMT/_0->C:2;DMT/_1->C:3;E->F:1;E->G:1"); \
}
REGISTER_TEST_ALL_TYPES(NodeMerge_Conv2DWithBias_Negative_Dataflow2);
#undef REGISTER_TEST
// data_format attribute value mismatch. Merge should not be done
// in such a case.
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph("node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: '" #INPUT "'}" \
"node { name: 'C' op: 'Conv2D'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'use_cudnn_on_gpu' value { b: false } }" \
" attr { key: 'strides' value { list: {i: 1, i:1, " \
"i:1, i:1} } }" \
" attr { key: 'padding' value { s: 'SAME' } }" \
" attr { key: 'dilations' value { list: {i: 1, i:1, " \
"i:1, i:1} } }" \
" input: ['A', 'B']}" \
"node { name: 'D' op: '" #INPUT "'}" \
"node { name: 'E' op: 'BiasAdd'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'data_format' value { s: 'NHCW' } }" \
" input: ['C', 'D'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(" #INPUT ");C(_MklConv2D);D(" #INPUT ");" \
"DMT/_0(Const);DMT/_1(Const);E(BiasAdd)|A->C;A:control->" \
"DMT/_0:control;A:control->DMT/_1:control;B->C:1;C->E;" \
"D->E:1;DMT/_0->C:2;DMT/_1->C:3"); \
}
REGISTER_TEST_ALL_TYPES(NodeMerge_Conv2DWithBias_Negative_AttrMismatch);
#undef REGISTER_TEST
// Test set 2: BiasAddGrad + Conv2DBackpropFilter fusion tests
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph("node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: 'Int32Input'}" \
"node { name: 'C' op: '" #INPUT "'}" \
"node { name: 'D' op: 'Conv2DBackpropFilter'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'use_cudnn_on_gpu' value { b: false } }" \
" attr { key: 'strides' value { list: {i: 1, i:1, " \
"i:1, i:1} } }" \
" attr { key: 'padding' value { s: 'SAME' } }" \
" attr { key: 'dilations' value { list: {i: 1, i:1, " \
"i:1, i:1} } }" \
" input: ['A', 'B', 'C'] }" \
"node { name: 'E' op: 'BiasAddGrad'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" input: ['C'] }"); \
EXPECT_EQ( \
DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(Int32Input);C(" #INPUT ");" \
"D(_MklConv2DBackpropFilterWithBias);DMT/_0(Const);DMT/_1(Const);" \
"DMT/_2(Const)|A->D;A:control->DMT/_0:control;A:control->" \
"DMT/_1:control;A:control->DMT/_2:control;B->D:1;C->D:2;" \
"DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5"); \
}
REGISTER_TEST_ALL_TYPES(NodeMerge_Conv2DBackpropFilterFusion_Positive);
#undef REGISTER_TEST
// BiasAddGrad fusion in the presence of BackpropFilter, but the nodes do not
// match the criteria for rewrite, so no rewrite should happen: the 3rd input
// of Conv2DBackpropFilter is different from the input to BiasAddGrad.
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph("node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: 'Int32Input'}" \
"node { name: 'C' op: '" #INPUT "'}" \
"node { name: 'D' op: 'Conv2DBackpropFilter'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'use_cudnn_on_gpu' value { b: false } }" \
" attr { key: 'strides' value { list: {i: 1, i:1, " \
"i:1, i:1} } }" \
" attr { key: 'padding' value { s: 'SAME' } }" \
" attr { key: 'dilations' value { list: {i: 1, i:1, " \
"i:1, i:1} } }" \
" input: ['A', 'B', 'C'] }" \
"node { name: 'E' op: 'BiasAddGrad'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" input: ['A'] }"); \
EXPECT_EQ( \
DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(Int32Input);C(" #INPUT ");" \
"D(_MklConv2DBackpropFilter);DMT/_0(Const);DMT/_1(Const);" \
"DMT/_2(Const);E(BiasAddGrad)|A->D;A->E;A:control->DMT/_0:control;" \
"A:control->DMT/_1:control;A:control->DMT/_2:control;B->D:1;C->D:2;" \
"DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5"); \
}
REGISTER_TEST_ALL_TYPES(NodeMerge_Conv2DBackpropFilterFusion_Negative1);
#undef REGISTER_TEST
// BiasAddGrad fusion, but the nodes do not match the criteria for fusion:
// the data_format attributes differ.
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph("node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: 'Int32Input'}" \
"node { name: 'C' op: '" #INPUT "'}" \
"node { name: 'D' op: 'Conv2DBackpropFilter'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'use_cudnn_on_gpu' value { b: false } }" \
" attr { key: 'strides' value { list: {i: 1, i:1, " \
"i:1, i:1} } }" \
" attr { key: 'padding' value { s: 'SAME' } }" \
" attr { key: 'dilations' value { list: {i: 1, i:1, " \
"i:1, i:1} } }" \
" input: ['A', 'B', 'C'] }" \
"node { name: 'E' op: 'BiasAddGrad'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'data_format' value { s: 'NHWC' } }" \
" input: ['A'] }"); \
EXPECT_EQ( \
DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(Int32Input);C(" #INPUT ");" \
"D(_MklConv2DBackpropFilter);DMT/_0(Const);DMT/_1(Const);" \
"DMT/_2(Const);E(BiasAddGrad)|A->D;A->E;A:control->DMT/_0:control;" \
"A:control->DMT/_1:control;A:control->DMT/_2:control;B->D:1;C->D:2;" \
"DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5"); \
}
REGISTER_TEST_ALL_TYPES(NodeMerge_Conv2DBackpropFilterFusion_Negative2);
#undef REGISTER_TEST
// BiasAddGrad fusion in the presence of BackpropFilter only. Fusion is
// attempted before node rewrite; this test checks that ordering.
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph("node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: '" #INPUT "'}" \
"node { name: 'C' op: '" #INPUT "'}" \
"node { name: 'M' op: '_MklInput'}" \
"node { name: 'N' op: '_MklInput'}" \
"node { name: 'O' op: '_MklInput'}" \
"node { name: 'D' op: '_MklConv2DWithBias'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'use_cudnn_on_gpu' value { b: false } }" \
" attr { key: 'strides' value { list: {i: 1, i:1, " \
"i:1, i:1} } }" \
" attr { key: 'padding' value { s: 'SAME' } }" \
" attr { key: 'dilations' value { list: {i: 1, i:1, " \
"i:1, i:1} } }" \
" input: ['A', 'B', 'C', 'M', 'N', 'O']}" \
"node { name: 'E' op: 'Zeta'" \
" attr {key: 'T' value { type: " #T " } }" \
" input: ['D', 'A']}" \
"node { name: 'F' op: 'Int32Input'}" \
"node { name: 'G' op: '_MklConv2DBackpropFilter'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'use_cudnn_on_gpu' value { b: false } }" \
" attr { key: 'strides' value { list: {i: 1, i:1, " \
"i:1, i:1} } }" \
" attr { key: 'padding' value { s: 'SAME' } }" \
" input: ['E', 'F', 'A', 'M', 'N', 'O'] }" \
"node { name: 'H' op: 'BiasAddGrad'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" input: ['E'] }"); \
EXPECT_EQ( \
DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(" #INPUT ");C(" #INPUT ");D(_MklConv2DWithBias);" \
"E(Zeta);F(Int32Input);G(_MklConv2DBackpropFilter);H(BiasAddGrad);" \
"M(_MklInput);N(_MklInput);O(_MklInput)|A->D;A->E:1;A->G:2;B->D:1;" \
"C->D:2;D->E;E->G;E->H;F->G:1;M->D:3;M->G:3;N->D:4;N->G:4;O->D:5;" \
"O->G:5"); \
}
REGISTER_TEST_ALL_TYPES(NodeMerge_Conv2DBackpropFilterFusion_Negative3);
#undef REGISTER_TEST
// C=Conv2D(A,B); E=BiasAdd(C,D); Y=Zeta(E,X);
// G=Conv2DBackpropInput(F,B,E)
// This is a case of node rewrite followed by node merge followed by connecting
// filter output of Conv2DWithBias to filter input of Conv2DBackpropInput.
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); \
InitGraph("node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: '" #INPUT "'}" \
"node { name: 'C' op: 'Conv2D'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'use_cudnn_on_gpu' value { b: false } }" \
" attr { key: 'strides' value { list: {i: 1, i:1, " \
"i:1, i:1} } }" \
" attr { key: 'padding' value { s: 'SAME' } }" \
" attr { key: 'dilations' value { list: {i: 1, i:1, " \
"i:1, i:1} } }" \
" input: ['A', 'B']}" \
"node { name: 'D' op: '" #INPUT "'}" \
"node { name: 'E' op: 'BiasAdd'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" input: ['C', 'D'] }" \
"node { name: 'X' op: '" #INPUT "'}" \
"node { name: 'Y' op: 'Zeta'" \
" attr {key: 'T' value { type: " #T " } }" \
" input: ['E', 'X']}" \
"node { name: 'F' op: 'Int32Input'}" \
"node { name: 'G' op: 'Conv2DBackpropInput'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'use_cudnn_on_gpu' value { b: false } }" \
" attr { key: 'strides' value { list: {i: 1, i:1, " \
"i:1, i:1} } }" \
" attr { key: 'padding' value { s: 'SAME' } }" \
" attr { key: 'dilations' value { list: {i: 1, i:1, " \
"i:1, i:1} } }" \
" input: ['F', 'B', 'E']}" \
"node { name: 'Z' op: 'Zeta'" \
" attr {key: 'T' value { type: " #T " } }" \
" input: ['G', 'X']}"); \
EXPECT_EQ( \
DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(" #INPUT ");D(" #INPUT ");DMT/_0(Const);" \
"DMT/_1(Const);DMT/_2(Const);DMT/_3(Const);E(_MklConv2DWithBias);" \
"F(Int32Input);G(_MklConv2DBackpropInput);X(" #INPUT ");Y(Zeta);" \
"Z(Zeta)|A->E;A:control->DMT/_0:control;A:control->DMT/_1:control;" \
"A:control->DMT/_2:control;B->E:1;D->E:2;DMT/_0->E:3;" \
"DMT/_1->E:4;DMT/_2->E:5;DMT/_3->G:3;E->G:2;E->Y;E:1->G:1;E:2->G:5;" \
"E:3->G:4;F->G;F:control->DMT/_3:control;G->Z;X->Y:1;X->Z:1"); \
}
// TODO(nhasabni): Enable bfloat16 test when we enable the operator.
REGISTER_TEST_FLOAT32(NodeMerge_Conv2DWithBias_ConvBpropInput_FilterFwd);
#undef REGISTER_TEST
// Test set 3: Pad + Conv2D fusion
// padding type is VALID
// A = input(image), B = input(paddings), C = Pad, whose output feeds Conv2D,
// D = input(filter), E = Conv2D, Z = Zeta
// C=Pad(A,B); E=Conv2D(C,D); Z=Zeta(E,Y)
// After layout pass:
// _MklPadWithConv2D(A, D, B, DMT/_0, DMT/_1, DMT/_2)
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
DCHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); \
InitGraph("node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: 'Int32Input'}" \
"node { name: 'C' op: 'Pad'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'Tpaddings' value { type: DT_INT32 } }" \
" input: ['A', 'B']}" \
"node { name: 'D' op: '" #INPUT "'}" \
"node { name: 'E' op: 'Conv2D'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'data_format' value { s: 'NHWC' } }" \
" attr { key: 'use_cudnn_on_gpu' value { b: false } }" \
" attr { key: 'strides' value { list: {i: 1, i:1, " \
"i:1, i:1} } }" \
" attr { key: 'padding' value { s: 'VALID' } }" \
" attr { key: 'dilations' value { list: {i: 1, i:1, " \
"i:1, i:1} } }" \
" input: ['C', 'D'] }" \
"node { name: 'Y' op: '" #INPUT "'}" \
"node { name: 'Z' op: 'Zeta'" \
" attr {key: 'T' value { type: " #T " } }" \
" input: ['E', 'Y']}"); \
EXPECT_EQ( \
DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(Int32Input);D(" #INPUT ");DMT/_0(Const);" \
"DMT/_1(Const);DMT/_2(Const);E(_MklPadWithConv2D);Y(" #INPUT ");" \
"Z(Zeta)|A->E;A:control->DMT/_0:control;A:control->DMT/_1:control;"\
"A:control->DMT/_2:control;B->E:2;D->E:1;DMT/_0->E:3;DMT/_1->E:4;" \
"DMT/_2->E:5;E->Z;Y->Z:1"); \
}
REGISTER_TEST_ALL_TYPES(NodeMerge_PadWithConv2D_Positive);
#undef REGISTER_TEST
// Test that input control edges are not duplicated after merge.
// If both ops being merged have an input control edge from a common op,
// then the merged op will have only one control edge from that
// common op.
// padding type is VALID
// A = input(image), A1 = input, B = input(paddings),
// C = Pad, whose output feeds Conv2D,
// D = input(filter), E = Conv2D, Z = Zeta
// C=Pad(A,B); E=Conv2D(C,D); Z=Zeta(E,Y)
// A1:control->C:control
// A1:control->E:control
// After layout pass:
// _MklPadWithConv2D(A, D, B, DMT/_0, DMT/_1, DMT/_2)
// A1:control->E:control (only one control edge)
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
DCHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); \
InitGraph("node { name: 'A1' op: '" #INPUT "'}" \
"node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: 'Int32Input'}" \
"node { name: 'C' op: 'Pad'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'Tpaddings' value { type: DT_INT32 } }" \
" input: ['A', 'B']}" \
"node { name: 'D' op: '" #INPUT "'}" \
"node { name: 'E' op: 'Conv2D'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'data_format' value { s: 'NHWC' } }" \
" attr { key: 'use_cudnn_on_gpu' value { b: false } }" \
" attr { key: 'strides' value { list: {i: 1, i:1, " \
"i:1, i:1} } }" \
" attr { key: 'padding' value { s: 'VALID' } }" \
" attr { key: 'dilations' value { list: {i: 1, i:1, " \
"i:1, i:1} } }" \
" input: ['C', 'D'] }" \
"node { name: 'Y' op: '" #INPUT "'}" \
"node { name: 'Z' op: 'Zeta'" \
" attr {key: 'T' value { type: " #T " } }" \
" input: ['E', 'Y']}"); \
Node* a1 = FindNode("A1"); \
Node* c = FindNode("C"); \
Node* e = FindNode("E"); \
const Edge* edge = graph_.AddControlEdge(a1, c); \
const Edge* edge_1 = graph_.AddControlEdge(a1, e); \
ASSERT_NE(edge, nullptr); \
ASSERT_NE(edge_1, nullptr); \
EXPECT_EQ( \
DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");A1(" #INPUT ");B(Int32Input);D(" #INPUT ");" \
"DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);E(_MklPadWithConv2D);" \
"Y(" #INPUT ");Z(Zeta)|A->E;A1:control->E:control;A:control->" \
"DMT/_0:control;A:control->DMT/_1:control;A:control->" \
"DMT/_2:control;B->E:2;D->E:1;DMT/_0->E:3;DMT/_1->E:4;" \
"DMT/_2->E:5;E->Z;Y->Z:1"); \
}
REGISTER_TEST_ALL_TYPES(Input_ControlEdge_PadWithConv2D_Positive);
#undef REGISTER_TEST
// Test that output control edges are not duplicated after merge.
// If both ops being merged have an output control edge to a common op,
// then after merge, the merged op will have only one control edge
// to that common op.
// padding type is VALID
// A = input(image), B = input(paddings), C = Pad, whose output feeds Conv2D,
// D = input(filter), E = Conv2D, Z = Zeta
// C=Pad(A,B); E=Conv2D(C,D); Z=Zeta(E,Y)
// C:control->A1:control
// E:control->A1:control
// After layout pass:
// _MklPadWithConv2D(A, D, B, DMT/_0, DMT/_1, DMT/_2)
// E:control->A1:control (only one control edge)
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
DCHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); \
InitGraph("node { name: 'A1' op: '" #INPUT "'}" \
"node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: 'Int32Input'}" \
"node { name: 'C' op: 'Pad'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'Tpaddings' value { type: DT_INT32 } }" \
" input: ['A', 'B']}" \
"node { name: 'D' op: '" #INPUT "'}" \
"node { name: 'E' op: 'Conv2D'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'data_format' value { s: 'NHWC' } }" \
" attr { key: 'use_cudnn_on_gpu' value { b: false } }" \
" attr { key: 'strides' value { list: {i: 1, i:1, " \
"i:1, i:1} } }" \
" attr { key: 'padding' value { s: 'VALID' } }" \
" attr { key: 'dilations' value { list: {i: 1, i:1, " \
"i:1, i:1} } }" \
" input: ['C', 'D'] }" \
"node { name: 'Y' op: '" #INPUT "'}" \
"node { name: 'Z' op: 'Zeta'" \
" attr {key: 'T' value { type: " #T " } }" \
" input: ['E', 'Y']}"); \
Node* a1 = FindNode("A1"); \
Node* c = FindNode("C"); \
Node* e = FindNode("E"); \
const Edge* edge = graph_.AddControlEdge(c, a1); \
const Edge* edge_1 = graph_.AddControlEdge(e, a1); \
ASSERT_NE(edge, nullptr); \
ASSERT_NE(edge_1, nullptr); \
EXPECT_EQ( \
DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");A1(" #INPUT ");B(Int32Input);D(" #INPUT ");" \
"DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);E(_MklPadWithConv2D);" \
"Y(" #INPUT ");Z(Zeta)|A->E;A:control->DMT/_0:control;A:control->" \
"DMT/_1:control;A:control->DMT/_2:control;B->E:2;D->E:1;" \
"DMT/_0->E:3;DMT/_1->E:4;DMT/_2->E:5;E->Z;E:control->A1:control;" \
"Y->Z:1"); \
}
REGISTER_TEST_ALL_TYPES(Output_ControlEdge_PadWithConv2D_Positive);
#undef REGISTER_TEST
// Pad + Conv2D fusion; padding type is VALID.
// The same input node feeds both Pad and Conv2D.
// A = input(image), B = input(paddings), C = Pad
// E = Conv2D, Z = Zeta
// C=Pad(A,B); E=Conv2D(C,A); Z=Zeta(E,Y)
// After layout pass:
// _MklPadWithConv2D(A, A, B, DMT/_0, DMT/_1, DMT/_2)
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
DCHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); \
InitGraph("node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: 'Int32Input'}" \
"node { name: 'C' op: 'Pad'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'Tpaddings' value { type: DT_INT32 } }" \
" input: ['A', 'B']}" \
"node { name: 'E' op: 'Conv2D'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'data_format' value { s: 'NHWC' } }" \
" attr { key: 'use_cudnn_on_gpu' value { b: false } }" \
" attr { key: 'strides' value { list: {i: 1, i:1, " \
"i:1, i:1} } }" \
" attr { key: 'padding' value { s: 'VALID' } }" \
" attr { key: 'dilations' value { list: {i: 1, i:1, " \
"i:1, i:1} } }" \
" input: ['C', 'A'] }" \
"node { name: 'Y' op: '" #INPUT "'}" \
"node { name: 'Z' op: 'Zeta'" \
" attr {key: 'T' value { type: " #T " } }" \
" input: ['E', 'Y']}"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(Int32Input);DMT/_0(Const);DMT/_1(Const);" \
"DMT/_2(Const);E(_MklPadWithConv2D);Y(" #INPUT ");Z(Zeta)|" \
"A->E;A->E:1;A:control->DMT/_0:control;A:control->" \
"DMT/_1:control;A:control->DMT/_2:control;B->E:2;DMT/_0->E:3;"\
"DMT/_1->E:4;DMT/_2->E:5;E->Z;Y->Z:1"); \
}
REGISTER_TEST_ALL_TYPES(NodeMerge_PadWithConv2D_Common_Input);
#undef REGISTER_TEST
// Pad + Conv2D; padding type is VALID.
// The same input node feeds both Pad and Conv2D, and the outputs of both
// Pad and Conv2D feed one node (Z as Output2).
// A = input(as image), B = input(as paddings), C = Pad
// E = Conv2D, Z = Output2
// C=Pad(A,B); E=Conv2D(C,A); Z=Output2(C,E)
// After layout pass - no merging, since Pad and Conv2D both
// feed the same node (Z)
#define REGISTER_TEST(NAME, T, INPUT, OUTPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
DCHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); \
InitGraph("node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: 'Int32Input'}" \
"node { name: 'C' op: 'Pad'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'Tpaddings' value { type: DT_INT32 } }" \
" input: ['A', 'B']}" \
"node { name: 'E' op: 'Conv2D'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'data_format' value { s: 'NHWC' } }" \
" attr { key: 'use_cudnn_on_gpu' value { b: false } }" \
" attr { key: 'strides' value { list: {i: 1, i:1, " \
"i:1, i:1} } }" \
" attr { key: 'padding' value { s: 'VALID' } }" \
" attr { key: 'dilations' value { list: {i: 1, i:1, " \
"i:1, i:1} } }" \
" input: ['C', 'A'] }" \
"node { name: 'Z' op: '" #OUTPUT "'" \
" input: ['C', 'E']}"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(Int32Input);C(Pad);DMT/_0(Const);" \
"DMT/_1(Const);E(_MklConv2D);Z(" #OUTPUT ")|A->C;A->E:1;" \
"B->C:1;C->E;C->Z;C:control->DMT/_0:control;C:control->" \
"DMT/_1:control;DMT/_0->E:2;DMT/_1->E:3;E->Z:1"); \
}
REGISTER_TEST(NodeMerge_PadWithConv2D_Common_InOutput, DT_FLOAT, Float32Input,
Float32Output2);
#ifdef ENABLE_INTEL_MKL_BFLOAT16
// TODO(nhasabni): Enable bfloat16 test when we enable the operator.
REGISTER_TEST(NodeMerge_PadWithConv2D_Common_InOutput, DT_BFLOAT16,
BFloat16Input, BFloat16Output2);
#endif
#undef REGISTER_TEST
// Pad + Conv2D; padding type is SAME
// A = input(image), B = input(paddings), C = Pad, whose output feeds Conv2D,
// D = input(filter), E = Conv2D, Z = Zeta
// C=Pad(A,B); E=Conv2D(C,D); Z=Zeta(E,Y)
// After layout pass - no merging
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
DCHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); \
InitGraph("node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: 'Int32Input'}" \
"node { name: 'C' op: 'Pad'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'Tpaddings' value { type: DT_INT32 } }" \
" input: ['A', 'B']}" \
"node { name: 'D' op: '" #INPUT "'}" \
"node { name: 'E' op: 'Conv2D'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'data_format' value { s: 'NHWC' } }" \
" attr { key: 'use_cudnn_on_gpu' value { b: false } }" \
" attr { key: 'strides' value { list: {i: 1, i:1, " \
"i:1, i:1} } }" \
" attr { key: 'padding' value { s: 'SAME' } }" \
" attr { key: 'dilations' value { list: {i: 1, i:1, " \
"i:1, i:1} } }" \
" input: ['C', 'D'] }" \
"node { name: 'Y' op: '" #INPUT "'}" \
"node { name: 'Z' op: 'Zeta'" \
" attr {key: 'T' value { type: " #T " } }" \
" input: ['E', 'Y']}"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(Int32Input);C(Pad);D(" #INPUT ");" \
"DMT/_0(Const);DMT/_1(Const);E(_MklConv2D);Y(" #INPUT ");" \
"Z(Zeta)|A->C;B->C:1;C->E;C:control->DMT/_0:control;" \
"C:control->DMT/_1:control;D->E:1;DMT/_0->E:2;DMT/_1->E:3;" \
"E->Z;Y->Z:1"); \
}
REGISTER_TEST_ALL_TYPES(NodeMerge_PadWithConv2D_Negative);
#undef REGISTER_TEST
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph( \
"node { name: 'Input0' op: '" #INPUT "'}" \
"node { name: 'Input1' op: '" #INPUT "'}" \
"node { name: 'Const0' op: 'Const'" \
" attr {key: 'dtype' value { type: DT_INT32 } }" \
" attr {key: 'value' value { " \
" tensor {" \
" dtype: DT_INT32" \
" tensor_shape { dim { size: 4 } }" \
" tensor_content: " \
"'\\000\\000\\000\\000\\002\\000\\000\\000\\003\\000\\000\\000\\001\\" \
"000\\000\\000'" \
" }" \
" }" \
" }" \
"}" \
"node { name: 'Const1' op: 'Const'" \
" attr {key: 'dtype' value { type: DT_INT32 } }" \
" attr {key: 'value'" \
" value {" \
" tensor {" \
" dtype: DT_INT32" \
" tensor_shape {dim {size: 4 }}" \
" tensor_content: " \
"'\\000\\000\\000\\000\\003\\000\\000\\000\\001" \
"\\000\\000\\000\\002\\000\\000\\000'" \
" }" \
" }" \
" }" \
"}" \
"node { name: 'Transpose0' op: 'Transpose'" \
" input: ['Input0', 'Const0']" \
" attr { key: 'T' value { type: " #T "} }" \
" attr { key: 'Tperm' value { type: DT_INT32 } } }" \
"node { name: 'Conv2D' op: 'Conv2D'" \
" input: ['Transpose0', 'Input1']" \
" attr { key: 'T' value { type: " #T "} }" \
" attr { key: 'data_format' value { s: 'NHWC' }}" \
" attr { key: 'dilations' value {list: {i:1,i:1,i:1,i:1}}}" \
" attr { key: 'padding' value {s: 'SAME'}}" \
" attr { key: 'strides' value {list: {i:1,i:1,i:1,i:1}}}" \
" attr { key: 'use_cudnn_on_gpu' value {b: true}}}" \
"node { name: 'Transpose1' op: 'Transpose'" \
" input: ['Conv2D', 'Const1' ]" \
" attr { key: 'T' value { type: " #T "}}" \
" attr { key: 'Tperm' value { type: DT_INT32 }}}" \
"node { name: 'Relu' op: 'Relu'" \
" attr { key: 'T' value { type: " #T "} }" \
" input: ['Transpose1'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"Const0(Const);Const1(Const);Conv2D(_MklConv2D);DMT/_0(Const);" \
"DMT/_1(Const);Input0(" #INPUT ");Input1(" #INPUT ");" \
"Relu(_MklRelu)|Conv2D->Relu;Conv2D:2->Relu:1;DMT/_0->Conv2D:2;" \
"DMT/_1->Conv2D:3;Input0->Conv2D;Input0:control->DMT/_0:control;"\
"Input0:control->DMT/_1:control;Input1->Conv2D:1"); \
}
REGISTER_TEST_ALL_TYPES(NodeMerge_TransposeConv2DTranspose_Positive);
#undef REGISTER_TEST
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph( \
"node { name: 'Input0' op: '" #INPUT "'}" \
"node { name: 'Input1' op: '" #INPUT "'}" \
"node { name: 'Const0' op: 'Const'" \
" attr { key: 'dtype' value { type: DT_INT32 } }" \
" attr { key: 'value'" \
" value {" \
" tensor {" \
" dtype: DT_INT32" \
" tensor_shape {dim {size: 4}}" \
" tensor_content: " \
"'\\000\\000\\000\\000\\002\\000\\000\\000\\003\\000\\000\\000\\001\\" \
"000\\000\\000'" \
" }" \
" }" \
" }" \
"}" \
"node { name: 'Const1' op: 'Const'" \
" attr { key: 'dtype' value { type: DT_INT32 }}" \
" attr {" \
" key: 'value'" \
" value {" \
" tensor {" \
" dtype: DT_INT32" \
" tensor_shape { dim { size: 4 }}" \
" tensor_content: " \
"'\\000\\000\\000\\000\\002\\000\\000\\000\\003\\000\\000\\000\\001\\" \
"000\\000\\000'" \
" }" \
" }" \
" }" \
"}" \
"node {name: 'Transpose0' op: 'Transpose'" \
" input: ['Input0', 'Const0']" \
" attr { key: 'T' value { type: " #T " }}" \
" attr { key: 'Tperm' value { type: DT_INT32 }}}" \
"node { name: 'Conv2D' op: 'Conv2D'" \
" input: ['Transpose0', 'Input1']" \
" attr { key: 'T' value { type: " #T "}}" \
" attr { key: 'data_format' value { s: 'NHWC' }}" \
" attr { key: 'dilations' value { list: {i:1,i:1,i:1,i:1}}}" \
" attr { key: 'padding' value { s: 'SAME' }}" \
" attr { key: 'strides' value { list: {i:1,i:1,i:1,i:1}}}" \
" attr { key: 'use_cudnn_on_gpu' value { b: true }}}" \
"node {name: 'Transpose1' op: 'Transpose'" \
" input: ['Conv2D', 'Const1']" \
" attr { key: 'T' value { type: " #T "}}" \
" attr { key: 'Tperm' value { type: DT_INT32 }}}" \
"node { name: 'Relu' op: 'Relu'" \
" attr { key: 'T' value { type: " #T "}}" \
" input: ['Transpose1'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"Const0(Const);Const1(Const);Conv2D(_MklConv2D);DMT/_0(Const);" \
"DMT/_1(Const);DMT/_2(Const);Input0(" #INPUT ");Input1(" #INPUT \
");Relu(_MklRelu);Transpose0(_MklTranspose);" \
"Transpose1(_MklTranspose)|Const0->Transpose0:1;" \
"Const1->Transpose1:1;Conv2D->Transpose1;DMT/_0->Conv2D:2;" \
"DMT/_1->Conv2D:3;DMT/_2->Relu:1;Input0->Transpose0;" \
"Input1->Conv2D:1;Transpose0->Conv2D;Transpose0:control->DMT/" \
"_0:control;Transpose0:control->DMT/_1:control;Transpose1->Relu;"\
"Transpose1:control->DMT/_2:control"); \
}
REGISTER_TEST_ALL_TYPES(NodeMerge_TransposeConv2DTranspose_Negative);
#undef REGISTER_TEST
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph( \
"node { name: 'Input0' op: '" #INPUT "'}" \
"node { name: 'Input1' op: '" #INPUT "'}" \
"node { name: 'Const0' op: 'Const'" \
" attr { key: 'dtype' value { type: DT_INT32 } }" \
" attr { key: 'value'" \
" value {" \
" tensor {" \
" dtype: DT_INT32" \
" tensor_shape { dim {size: 5}}" \
" tensor_content:" \
"'\\000\\000\\000\\000\\002\\000\\000\\000\\003\\000\\000\\000\\004" \
"\\000\\000\\000\\001\\000\\000\\000'" \
" }" \
" }" \
" }" \
"}" \
"node { name: 'Const1' op: 'Const'" \
" attr { key: 'dtype' value { type: DT_INT32 } }" \
" attr { key: 'value'" \
" value {" \
" tensor {" \
" dtype: DT_INT32" \
" tensor_shape { dim { size: 5 }}" \
" tensor_content: " \
"'\\000\\000\\000\\000\\004\\000\\000\\000\\001\\000\\000\\000\\002" \
"\\000\\000\\000\\003\\000\\000\\000'" \
" }" \
" }" \
" }" \
"}" \
"node { name: 'Transpose0' op: 'Transpose'" \
" input: ['Input0', 'Const0']" \
" attr { key: 'T' value { type: " #T " }}" \
" attr { key: 'Tperm' value { type: DT_INT32 }}}" \
"node { name: 'Conv3D' op: 'Conv3D'" \
"input: ['Transpose0', 'Input1']" \
" attr { key: 'T' value { type: " #T " }}" \
" attr { key: 'data_format' value { s: 'NDHWC' }}" \
" attr { key: 'dilations' value { list: {i:1,i:1,i:1,i:1,i:1}}}" \
" attr { key: 'padding' value { s: 'SAME' }}" \
" attr { key: 'strides' value { list: {i:1,i:1,i:1,i:1,i:1}}}" \
" attr { key: 'use_cudnn_on_gpu' value { b: true }}}" \
"node { name: 'Transpose1' op: 'Transpose'" \
" input: ['Conv3D', 'Const1']" \
" attr { key: 'T' value { type: " #T " }}" \
" attr { key: 'Tperm' value { type: DT_INT32 }}}" \
"node { name: 'Relu' op: 'Relu'" \
" attr { key: 'T' value { type: " #T " } }" \
" input: ['Transpose1'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"Const0(Const);Const1(Const);Conv3D(_MklConv3D);DMT/_0(Const);" \
"DMT/_1(Const);Input0(" #INPUT ");Input1(" #INPUT ");" \
"Relu(_MklRelu)|Conv3D->Relu;Conv3D:2->Relu:1;" \
"DMT/_0->Conv3D:2;DMT/_1->Conv3D:3;Input0->Conv3D;" \
"Input0:control->DMT/_0:control;" \
"Input0:control->DMT/_1:control;Input1->Conv3D:1"); \
}
REGISTER_TEST_ALL_TYPES(NodeMerge_TransposeConv3DTranspose_Positive);
#undef REGISTER_TEST
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph( \
"node { name: 'Input0' op: '" #INPUT "'}" \
"node { name: 'Input1' op: '" #INPUT "'}" \
"node { name: 'Const0' op: 'Const'" \
" attr { key: 'dtype' value { type: DT_INT32 } }" \
" attr { key: 'value'" \
" value {" \
" tensor {" \
" dtype: DT_INT32" \
" tensor_shape { dim {size: 5}}" \
" tensor_content:" \
"'\\000\\000\\000\\000\\002\\000\\000\\000\\003\\000\\000\\000\\004" \
"\\000\\000\\000\\001\\000\\000\\000'" \
" }" \
" }" \
" }" \
"}" \
"node { name: 'Const1' op: 'Const'" \
" attr { key: 'dtype' value { type: DT_INT32 } }" \
" attr { key: 'value'" \
" value {" \
" tensor {" \
" dtype: DT_INT32" \
" tensor_shape { dim { size: 5 }}" \
" tensor_content: " \
"'\\000\\000\\000\\000\\002\\000\\000\\000\\003\\000\\000\\000\\004" \
"\\000\\000\\000\\001\\000\\000\\000'" \
" }" \
" }" \
" }" \
"}" \
"node { name: 'Transpose0' op: 'Transpose'" \
" input: ['Input0', 'Const0']" \
" attr { key: 'T' value { type: " #T " }}" \
" attr { key: 'Tperm' value { type: DT_INT32 }}}" \
"node { name: 'Conv3D' op: 'Conv3D'" \
"input: ['Transpose0', 'Input1']" \
" attr { key: 'T' value { type: " #T " }}" \
" attr { key: 'data_format' value { s: 'NDHWC' }}" \
" attr { key: 'dilations' value { list: {i:1,i:1,i:1,i:1,i:1}}}" \
" attr { key: 'padding' value { s: 'SAME' }}" \
" attr { key: 'strides' value { list: {i:1,i:1,i:1,i:1,i:1}}}" \
" attr { key: 'use_cudnn_on_gpu' value { b: true }}}" \
"node { name: 'Transpose1' op: 'Transpose'" \
" input: ['Conv3D', 'Const1']" \
" attr { key: 'T' value { type: " #T " }}" \
" attr { key: 'Tperm' value { type: DT_INT32 }}}" \
"node { name: 'Relu' op: 'Relu'" \
" attr { key: 'T' value { type: " #T " } }" \
" input: ['Transpose1'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"Const0(Const);Const1(Const);Conv3D(_MklConv3D);DMT/_0(Const);" \
"DMT/_1(Const);DMT/_2(Const);Input0(" #INPUT ");" \
"Input1(" #INPUT ");Relu(_MklRelu);Transpose0(_MklTranspose);" \
"Transpose1(_MklTranspose)|Const0->Transpose0:1;Const1->" \
"Transpose1:1;Conv3D->Transpose1;DMT/_0->Conv3D:2;" \
"DMT/_1->Conv3D:3;DMT/_2->Relu:1;Input0->Transpose0;Input1->" \
"Conv3D:1;Transpose0->Conv3D;Transpose0:control->" \
"DMT/_0:control;Transpose0:control->DMT/_1:control;" \
"Transpose1->Relu;Transpose1:control->DMT/_2:control"); \
}
REGISTER_TEST_ALL_TYPES(NodeMerge_TransposeConv3DTranspose_Negative);
#undef REGISTER_TEST
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph( \
"node { name: 'Input0' op: '" #INPUT "'}" \
"node { name: 'Const0' op: 'Const'" \
" attr { key: 'dtype' value { type: DT_INT32 } }" \
" attr { key: 'value'" \
" value {" \
" tensor {" \
" dtype: DT_INT32" \
" tensor_shape { dim {size: 5}}" \
" tensor_content:" \
"'\\000\\000\\000\\000\\002\\000\\000\\000\\003\\000\\000\\000\\004" \
"\\000\\000\\000\\001\\000\\000\\000'" \
" }" \
" }" \
" }" \
"}" \
"node { name: 'Const1' op: 'Const'" \
" attr { key: 'dtype' value { type: DT_INT32 } }" \
" attr { key: 'value'" \
" value {" \
" tensor {" \
" dtype: DT_INT32" \
" tensor_shape { dim { size: 5 }}" \
" tensor_content: " \
"'\\000\\000\\000\\000\\004\\000\\000\\000\\001\\000\\000\\000\\002" \
"\\000\\000\\000\\003\\000\\000\\000'" \
" }" \
" }" \
" }" \
"}" \
"node { name: 'Transpose0' op: 'Transpose'" \
" input: ['Input0', 'Const0']" \
" attr { key: 'T' value { type: " #T " }}" \
" attr { key: 'Tperm' value { type: DT_INT32 }}}" \
"node { name: 'MaxPool3D' op: 'MaxPool3D'" \
"input: ['Transpose0']" \
" attr { key: 'T' value { type: " #T " }}" \
" attr { key: 'data_format' value { s: 'NDHWC' }}" \
" attr { key: 'padding' value { s: 'SAME' }}" \
" attr { key: 'strides' value { list: {i:1,i:2,i:2,i:2,i:1}}}" \
" attr { key: 'ksize' value { list: {i:1,i:1,i:1,i:1,i:1}}}}"\
"node { name: 'Transpose1' op: 'Transpose'" \
" input: ['MaxPool3D', 'Const1']" \
" attr { key: 'T' value { type: " #T " }}" \
" attr { key: 'Tperm' value { type: DT_INT32 }}}" \
"node { name: 'Relu' op: 'Relu'" \
" attr { key: 'T' value { type: " #T " }}" \
" input: ['Transpose1'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"Const0(Const);Const1(Const);DMT/_0(Const);Input0(" #INPUT ");" \
"MaxPool3D(_MklMaxPool3D);Relu(_MklRelu)|DMT/_0->MaxPool3D:1;" \
"Input0->MaxPool3D;Input0:control->DMT/_0:control;" \
"MaxPool3D->Relu;MaxPool3D:2->Relu:1"); \
}
REGISTER_TEST_ALL_TYPES(NodeMerge_TransposeMaxPool3DTranspose_Positive);
#undef REGISTER_TEST
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph( \
"node { name: 'Input0' op: '" #INPUT "'}" \
"node { name: 'Const0' op: 'Const'" \
" attr { key: 'dtype' value { type: DT_INT32 } }" \
" attr { key: 'value'" \
" value {" \
" tensor {" \
" dtype: DT_INT32" \
" tensor_shape { dim {size: 5}}" \
" tensor_content:" \
"'\\000\\000\\000\\000\\002\\000\\000\\000\\003\\000\\000\\000\\004" \
"\\000\\000\\000\\001\\000\\000\\000'" \
" }" \
" }" \
" }" \
"}" \
"node { name: 'Const1' op: 'Const'" \
" attr { key: 'dtype' value { type: DT_INT32 } }" \
" attr { key: 'value'" \
" value {" \
" tensor {" \
" dtype: DT_INT32" \
" tensor_shape { dim { size: 5 }}" \
" tensor_content: " \
"'\\000\\000\\000\\000\\002\\000\\000\\000\\003\\000\\000\\000\\004" \
"\\000\\000\\000\\001\\000\\000\\000'" \
" }" \
" }" \
" }" \
"}" \
"node { name: 'Transpose0' op: 'Transpose'" \
" input: ['Input0', 'Const0']" \
" attr { key: 'T' value { type: " #T " }}" \
" attr { key: 'Tperm' value { type: DT_INT32 }}}" \
"node { name: 'MaxPool3D' op: 'MaxPool3D'" \
"input: ['Transpose0']" \
" attr { key: 'T' value { type: " #T " }}" \
" attr { key: 'data_format' value { s: 'NDHWC' }}" \
" attr { key: 'padding' value { s: 'SAME' }}" \
" attr { key: 'strides' value { list: {i:1,i:2,i:2,i:2,i:1}}}" \
" attr { key: 'ksize' value { list: {i:1,i:1,i:1,i:1,i:1}}}}"\
"node { name: 'Transpose1' op: 'Transpose'" \
" input: ['MaxPool3D', 'Const1']" \
" attr { key: 'T' value { type: " #T " }}" \
" attr { key: 'Tperm' value { type: DT_INT32 }}}" \
"node { name: 'Relu' op: 'Relu'" \
" attr { key: 'T' value { type: " #T " }}" \
" input: ['Transpose1'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"Const0(Const);Const1(Const);DMT/_0(Const);DMT/_1(Const);Input0(" #INPUT\
");MaxPool3D(_MklMaxPool3D);Relu(_MklRelu);Transpose0(_MklTranspose);" \
"Transpose1(_MklTranspose)|Const0->Transpose0:1;" \
"Const1->Transpose1:1;DMT/_0->MaxPool3D:1;DMT/_1->Relu:1;Input0->" \
"Transpose0;MaxPool3D->Transpose1;Transpose0->MaxPool3D;Transpose0:" \
"control->DMT/_0:control;Transpose1->Relu;Transpose1:control->" \
"DMT/_1:control"); \
}
REGISTER_TEST_ALL_TYPES(NodeMerge_TransposeMaxPool3DTranspose_Negative);
#undef REGISTER_TEST
/////////////////////////////////////////////////////////////////////
// Unit tests related to rewriting node to Mkl node
/////////////////////////////////////////////////////////////////////
// A single Conv2D op; no Mkl layer on the input or on the output.
// A dummy Mkl tensor is generated as the 2nd input of Conv2D.
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph("node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: '" #INPUT "'}" \
"node { name: 'C' op: 'Conv2D'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'use_cudnn_on_gpu' value { b: false } }" \
" attr { key: 'strides' value { list: {i: 1, i:1, " \
"i:1, i:1} } }" \
" attr { key: 'padding' value { s: 'SAME' } }" \
" attr { key: 'dilations' value { list: {i: 1, i:1, " \
"i:1, i:1} } }" \
" input: ['A', 'B']}" \
"node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: " #T \
" } }" \
" input: ['B', 'C'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(" #INPUT ");C(_MklConv2D);D(Zeta);" \
"DMT/_0(Const);DMT/_1(Const)|A->C;A:control->DMT/_0:control;" \
"A:control->DMT/_1:control;B->C:1;B->D;C->D:1;DMT/_0->C:2;" \
"DMT/_1->C:3"); \
}
REGISTER_TEST_ALL_TYPES(NodeRewrite_Conv2D_Basic);
#undef REGISTER_TEST
// Test case for the DepthwiseConv2dNative forward (FWD) pass.
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph("node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: '" #INPUT "'}" \
"node { name: 'C' op: 'DepthwiseConv2dNative'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'strides' value { list: {i: 1, i:1, " \
"i:1, i:1} } }" \
" attr { key: 'padding' value { s: 'SAME' } }" \
" attr { key: 'dilations' value { list: {i: 1, i:1, " \
"i:1, i:1} } }" \
" input: ['A', 'B']}" \
"node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: " #T \
" } }" \
" input: ['B', 'C'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(" #INPUT ");C(_MklDepthwiseConv2dNative);" \
"D(Zeta);DMT/_0(Const);DMT/_1(Const)|A->C;A:control->" \
"DMT/_0:control;A:control->DMT/_1:control;B->C:1;B->D;C->D:1;" \
"DMT/_0->C:2;DMT/_1->C:3"); \
}
REGISTER_TEST_ALL_TYPES(NodeRewrite_DepthwiseConv2dNative_Basic);
#undef REGISTER_TEST
// 2 Conv2D ops in sequence. Both should get transformed, and the 1st Conv2D
// will have 2 outputs, both of which feed the 2nd Conv2D.
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph("node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: '" #INPUT "'}" \
"node { name: 'C' op: 'Conv2D'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'use_cudnn_on_gpu' value { b: false } }" \
" attr { key: 'strides' value { list: {i: 1, i:1, " \
"i:1, i:1} } }" \
" attr { key: 'padding' value { s: 'SAME' } }" \
" attr { key: 'dilations' value { list: {i: 1, i:1, " \
"i:1, i:1} } }" \
" input: ['A', 'B']}" \
"node { name: 'D' op: 'Conv2D'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'use_cudnn_on_gpu' value { b: false } }" \
" attr { key: 'strides' value { list: {i: 1, i:1, " \
"i:1, i:1} } }" \
" attr { key: 'padding' value { s: 'SAME' } }" \
" attr { key: 'dilations' value { list: {i: 1, i:1, " \
"i:1, i:1} } }" \
" input: ['A', 'C']}" \
"node { name: 'E' op: 'Zeta' " \
" attr { key: 'T' value { type: " #T "} }" \
" input: ['C', 'D'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(" #INPUT ");C(_MklConv2D);D(_MklConv2D);" \
"DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);E(Zeta)|A->C;A->D;" \
"A:control->DMT/_0:control;A:control->DMT/_1:control;" \
"A:control->DMT/_2:control;B->C:1;C->D:1;C->E;" \
"C:2->D:3;D->E:1;DMT/_0->C:2;DMT/_1->C:3;DMT/_2->D:2"); \
}
REGISTER_TEST_ALL_TYPES(NodeRewrite_Conv2D_Positive1);
#undef REGISTER_TEST
// Conv2D with DT_HALF, which is not supported by Mkl.
TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_Negative_UnsupportedType) {
InitGraph(
"node { name: 'A' op: 'HalfInput'}"
"node { name: 'B' op: 'HalfInput'}"
"node { name: 'C' op: 'Conv2D'"
" attr { key: 'T' value { type: DT_HALF } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A', 'B']}"
"node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_HALF } }"
" input: ['B', 'C'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(HalfInput);B(HalfInput);C(Conv2D);D(Zeta)|"
"A->C;B->C:1;B->D;C->D:1");
}
TEST_F(MklLayoutPassTest, NodeRewrite_QuantizeV2Op_Negative_ConstInp) {
InitGraph(
"node { name: 'A' op: 'Const' "
" attr { key: 'dtype' value { type: DT_FLOAT } }"
" attr { key: 'value' value { "
" tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
" int_val: 0 } } } }"
"node { name: 'B' op: 'Const' "
" attr { key: 'dtype' value { type: DT_FLOAT } }"
" attr { key: 'value' value { "
" tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
" int_val: 0 } } } }"
"node { name: 'C' op: 'Const' "
" attr { key: 'dtype' value { type: DT_FLOAT } }"
" attr { key: 'value' value { "
" tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
" int_val: 0 } } } }"
"node { name: 'D' op: 'QuantizeV2'"
" attr { key: 'T' value { type: DT_QUINT8 } }"
" attr { key: 'mode' value { s: 'SCALED' } }"
" attr { key: 'round_mode' value { s: 'HALF_TO_EVEN' } }"
" input: ['A', 'B', 'C']}"
"node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_QUINT8 } }"
" input: ['D'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Const);B(Const);C(Const);D(QuantizeV2);E(Zeta)|"
"A->D;B->D:1;C->D:2;D->E");
}
TEST_F(MklLayoutPassTest, NodeRewrite_QuantizeV2Op_MinFirst) {
InitGraph(
"node { name: 'A' op: 'Input' } "
"node { name: 'B' op: 'Const' "
" attr { key: 'dtype' value { type: DT_FLOAT } }"
" attr { key: 'value' value { "
" tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
" int_val: 0 } } } }"
"node { name: 'C' op: 'Const' "
" attr { key: 'dtype' value { type: DT_FLOAT } }"
" attr { key: 'value' value { "
" tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
" int_val: 0 } } } }"
"node { name: 'D' op: 'QuantizeV2'"
" attr { key: 'T' value { type: DT_QUINT8 } }"
" attr { key: 'mode' value { s: 'MIN_FIRST' } }"
" attr { key: 'round_mode' value { s: 'HALF_TO_EVEN' } }"
" input: ['A', 'B', 'C']}"
"node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_QUINT8 } }"
" input: ['D'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Const);C(Const);D(_MklQuantizeV2);DMT/_0(Const);DMT/"
"_1(Const);DMT/_2(Const);E(Zeta)|"
"A->D;A:control->DMT/_0:control;A:control->DMT/"
"_1:control;A:control->DMT/_2:control;B->D:1;C->D:2;D->E;DMT/"
"_0->D:3;DMT/_1->D:4;DMT/_2->D:5");
}
TEST_F(MklLayoutPassTest, NodeRewrite_QuantizeV2Op_Negative_NarrowRange_True) {
InitGraph(
"node { name: 'A' op: 'Input' } "
"node { name: 'B' op: 'Const' "
" attr { key: 'dtype' value { type: DT_FLOAT } }"
" attr { key: 'value' value { "
" tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
" int_val: 0 } } } }"
"node { name: 'C' op: 'Const' "
" attr { key: 'dtype' value { type: DT_FLOAT } }"
" attr { key: 'value' value { "
" tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
" int_val: 0 } } } }"
"node { name: 'D' op: 'QuantizeV2'"
" attr { key: 'T' value { type: DT_QUINT8 } }"
" attr { key: 'mode' value { s: 'SCALED' } }"
" attr { key: 'round_mode' value { s: 'HALF_TO_EVEN' } }"
" attr { key: 'narrow_range' value { b: true } }"
" attr { key: 'axis' value { i: -1 } }"
" input: ['A', 'B', 'C']}"
"node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_QUINT8 } }"
" input: ['D'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Const);C(Const);D(QuantizeV2);E(Zeta)|"
"A->D;B->D:1;C->D:2;D->E");
}
TEST_F(MklLayoutPassTest, NodeRewrite_QuantizeV2Op_Negative_PerSlice_Enabled) {
InitGraph(
"node { name: 'A' op: 'Input' } "
"node { name: 'B' op: 'Const' "
" attr { key: 'dtype' value { type: DT_FLOAT } }"
" attr { key: 'value' value { "
" tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
" int_val: 0 } } } }"
"node { name: 'C' op: 'Const' "
" attr { key: 'dtype' value { type: DT_FLOAT } }"
" attr { key: 'value' value { "
" tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
" int_val: 0 } } } }"
"node { name: 'D' op: 'QuantizeV2'"
" attr { key: 'T' value { type: DT_QUINT8 } }"
" attr { key: 'mode' value { s: 'SCALED' } }"
" attr { key: 'round_mode' value { s: 'HALF_TO_EVEN' } }"
" attr { key: 'narrow_range' value { b: false } }"
" attr { key: 'axis' value { i: 2 } }"
" input: ['A', 'B', 'C']}"
"node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_QUINT8 } }"
" input: ['D'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Const);C(Const);D(QuantizeV2);E(Zeta)|"
"A->D;B->D:1;C->D:2;D->E");
}
TEST_F(MklLayoutPassTest, NodeRewrite_QuantizeV2Op_Negative_HalfFromZero) {
InitGraph(
"node { name: 'A' op: 'Input' } "
"node { name: 'B' op: 'Const' "
" attr { key: 'dtype' value { type: DT_FLOAT } }"
" attr { key: 'value' value { "
" tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
" int_val: 0 } } } }"
"node { name: 'C' op: 'Const' "
" attr { key: 'dtype' value { type: DT_FLOAT } }"
" attr { key: 'value' value { "
" tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
" int_val: 0 } } } }"
"node { name: 'D' op: 'QuantizeV2'"
" attr { key: 'T' value { type: DT_QUINT8 } }"
" attr { key: 'mode' value { s: 'SCALED' } }"
" attr { key: 'round_mode' value { s: 'HALF_FROM_ZERO' } }"
" input: ['A', 'B', 'C']}"
"node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_QUINT8 } }"
" input: ['D'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Const);C(Const);D(QuantizeV2);E(Zeta)|"
"A->D;B->D:1;C->D:2;D->E");
}
TEST_F(MklLayoutPassTest, NodeRewrite_QuantizeV2Op_Positive) {
InitGraph(
"node { name: 'A' op: 'Input' } "
"node { name: 'B' op: 'Const' "
" attr { key: 'dtype' value { type: DT_FLOAT } }"
" attr { key: 'value' value { "
" tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
" int_val: 0 } } } }"
"node { name: 'C' op: 'Const' "
" attr { key: 'dtype' value { type: DT_FLOAT } }"
" attr { key: 'value' value { "
" tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
" int_val: 0 } } } }"
"node { name: 'D' op: 'QuantizeV2'"
" attr { key: 'T' value { type: DT_QUINT8 } }"
" attr { key: 'mode' value { s: 'SCALED' } }"
" attr { key: 'round_mode' value { s: 'HALF_TO_EVEN' } }"
" input: ['A', 'B', 'C']}"
"node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_QUINT8 } }"
" input: ['D'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Const);C(Const);D(_MklQuantizeV2);DMT/_0(Const);DMT/"
"_1(Const);DMT/_2(Const);E(Zeta)|"
"A->D;A:control->DMT/_0:control;A:control->DMT/"
"_1:control;A:control->DMT/_2:control;B->D:1;C->D:2;D->E;DMT/"
"_0->D:3;DMT/_1->D:4;DMT/_2->D:5");
}
TEST_F(MklLayoutPassTest, NodeRewrite_Dequantize_Negative_Const_Input) {
InitGraph(
"node { name: 'A' op: 'Const' "
" attr { key: 'dtype' value { type: DT_QUINT8 } }"
" attr { key: 'value' value { "
" tensor { dtype: DT_QUINT8 tensor_shape { dim { size: 1 } } "
" int_val: 0 } } } }"
"node { name: 'B' op: 'Input'}"
"node { name: 'C' op: 'Input'}"
"node { name: 'D' op: 'Dequantize'"
" attr { key: 'T' value { type: DT_QUINT8 } }"
" attr { key: 'mode' value { s: 'SCALED' } }"
" input: ['A', 'B', 'C']}"
"node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['D'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Const);B(Input);C(Input);D(Dequantize);"
"E(Zeta)|A->D;B->D:1;C->D:2;D->E");
}
TEST_F(MklLayoutPassTest, NodeRewrite_Dequantize_Negative_Non_SCALED_Mode) {
InitGraph(
"node { name: 'A' op: 'QuantizedInput'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'C' op: 'Input'}"
"node { name: 'D' op: 'Dequantize'"
" attr { key: 'T' value { type: DT_QUINT8 } }"
" attr { key: 'mode' value { s: 'MIN_FIRST' } }"
" input: ['A', 'B', 'C']}"
"node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['D'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(QuantizedInput);B(Input);C(Input);D(Dequantize);"
"E(Zeta)|A->D;B->D:1;C->D:2;D->E");
}
// Rewrite test for _FusedConv2D Op with BiasAdd fusion
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph( \
"node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: '" #INPUT "'}" \
"node { name: 'C' op: '" #INPUT "'}" \
"node { name: 'D' op: '_FusedConv2D'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'num_args' value { i: 1 } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} " \
"} }" \
" attr { key: 'padding' value { s: 'SAME' } }" \
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} " \
"} }" \
" attr { key: 'fused_ops' value { list: {s: 'BiasAdd'} } }" \
" attr { key: 'epsilon' value { f: 0.001 }}" \
" attr { key: 'leakyrelu_alpha' value { f: 0.2 }}" \
" input: ['A', 'B', 'C']}" \
"node { name: 'E' op: 'Zeta'" \
"attr { key: 'T' value { type: " #T " } }" \
" input: ['D', 'C'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(" #INPUT ");C(" #INPUT ");D(_MklFusedConv2D);" \
"DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);E(Zeta)|A->D;" \
"A:control->DMT/_0:control;A:control->DMT/_1:control;" \
"A:control->DMT/_2:control;B->D:1;C->D:2;C->E:1;D->E;" \
"DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5"); \
}
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedConv2D_Positive1);
#undef REGISTER_TEST
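// For reference, a sketch of how one instantiation of the macro above
// expands, assuming REGISTER_TEST_ALL_TYPES invokes REGISTER_TEST with
// (NAME, DT_FLOAT, Float32Input) among its type/input pairs: the test is
// named NodeRewrite_FusedConv2D_Positive1_DT_FLOAT, every "#INPUT" becomes
// "Float32Input", every "#T" becomes "DT_FLOAT", and the expected string
// checks that D is rewritten to _MklFusedConv2D with three dummy Mkl
// tensors appended as inputs 3..5.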
// Rewrite test for _FusedConv2D Op with Relu fusion
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph("node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: '" #INPUT "'}" \
"node { name: 'C' op: '" #INPUT "'}" \
"node { name: 'D' op: '_FusedConv2D'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'num_args' value { i: 1 } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'strides' value { list: {i: 1, i:1, " \
"i:1, i:1} } }" \
" attr { key: 'padding' value { s: 'SAME' } }" \
" attr { key: 'dilations' value { list: {i: 1, i:1, " \
"i:1, i:1} } }" \
" attr { key: 'fused_ops' value { list: {s: 'Relu'} } }" \
" attr { key: 'epsilon' value { f: 0.001 }}" \
" attr { key: 'leakyrelu_alpha' value { f: 0.2 }}" \
" input: ['A', 'B', 'C']}" \
"node { name: 'E' op: 'Zeta'" \
"attr { key: 'T' value { type: " #T " } }" \
" input: ['D', 'C'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(" #INPUT ");C(" #INPUT ");D(_MklFusedConv2D);" \
"DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);E(Zeta)|A->D;" \
"A:control->DMT/_0:control;A:control->DMT/_1:control;" \
"A:control->DMT/_2:control;B->D:1;C->D:2;C->E:1;D->E;" \
"DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5"); \
}
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedConv2D_Positive2);
#undef REGISTER_TEST
// Rewrite test for _FusedConv2D Op with BiasAdd+Relu fusion
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph("node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: '" #INPUT "'}" \
"node { name: 'C' op: '" #INPUT "'}" \
"node { name: 'D' op: '_FusedConv2D'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'num_args' value { i: 1 } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'strides' value { list: {i: 1, i:1, " \
"i:1, i:1} } }" \
" attr { key: 'padding' value { s: 'SAME' } }" \
" attr { key: 'dilations' value { list: {i: 1, i:1, " \
"i:1, i:1} } }" \
" attr { key: 'fused_ops'" \
" value { list: {s: 'BiasAdd', s: 'Relu'} } }" \
" attr { key: 'epsilon' value { f: 0.001 }}" \
" attr { key: 'leakyrelu_alpha' value { f: 0.2 }}" \
" input: ['A', 'B', 'C']}" \
"node { name: 'E' op: 'Zeta'" \
"attr { key: 'T' value { type: " #T " } }" \
" input: ['D', 'C'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(" #INPUT ");C(" #INPUT ");D(_MklFusedConv2D);" \
"DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);E(Zeta)|A->D;" \
"A:control->DMT/_0:control;A:control->DMT/_1:control;" \
"A:control->DMT/_2:control;B->D:1;C->D:2;C->E:1;D->E;" \
"DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5"); \
}
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedConv2D_Positive3);
#undef REGISTER_TEST
// Rewrite test for _FusedConv2D Op with BiasAdd+Relu6 fusion
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph("node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: '" #INPUT "'}" \
"node { name: 'C' op: '" #INPUT "'}" \
"node { name: 'D' op: '_FusedConv2D'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'num_args' value { i: 1 } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'strides' value { list: {i: 1, i:1, " \
"i:1, i:1} } }" \
" attr { key: 'padding' value { s: 'SAME' } }" \
" attr { key: 'dilations' value { list: {i: 1, i:1, " \
"i:1, i:1} } }" \
" attr { key: 'fused_ops'" \
" value { list: {s: 'BiasAdd', s: 'Relu6'} } }" \
" attr { key: 'epsilon' value { f: 0.001 }}" \
" attr { key: 'leakyrelu_alpha' value { f: 0.2 }}" \
" input: ['A', 'B', 'C']}" \
"node { name: 'E' op: 'Zeta'" \
"attr { key: 'T' value { type: " #T " } }" \
" input: ['D', 'C'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(" #INPUT ");C(" #INPUT ");D(_MklFusedConv2D);" \
"DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);E(Zeta)|A->D;" \
"A:control->DMT/_0:control;A:control->DMT/_1:control;" \
"A:control->DMT/_2:control;B->D:1;C->D:2;C->E:1;D->E;" \
"DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5"); \
}
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedConv2D_Positive4);
#undef REGISTER_TEST
// Rewrite test for _FusedConv2D Op with BiasAdd+Elu fusion
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph("node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: '" #INPUT "'}" \
"node { name: 'C' op: '" #INPUT "'}" \
"node { name: 'D' op: '_FusedConv2D'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'num_args' value { i: 1 } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'strides' value { list: {i: 1, i:1, " \
"i:1, i:1} } }" \
" attr { key: 'padding' value { s: 'SAME' } }" \
" attr { key: 'dilations' value { list: {i: 1, i:1, " \
"i:1, i:1} } }" \
" attr { key: 'fused_ops'" \
" value { list: {s: 'BiasAdd', s: 'Elu'} } }" \
" attr { key: 'epsilon' value { f: 0.001 }}" \
" attr { key: 'leakyrelu_alpha' value { f: 0.2 }}" \
" input: ['A', 'B', 'C']}" \
"node { name: 'E' op: 'Zeta'" \
"attr { key: 'T' value { type: " #T " } }" \
" input: ['D', 'C'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(" #INPUT ");C(" #INPUT ");D(_MklFusedConv2D);" \
"DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);E(Zeta)|A->D;" \
"A:control->DMT/_0:control;A:control->DMT/_1:control;" \
"A:control->DMT/_2:control;B->D:1;C->D:2;C->E:1;D->E;" \
"DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5"); \
}
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedConv2D_Positive5);
#undef REGISTER_TEST
// Rewrite test for _FusedConv2D Op with BiasAdd+Add fusion
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph("node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: '" #INPUT "'}" \
"node { name: 'C' op: '" #INPUT "'}" \
"node { name: 'D' op: '" #INPUT "'}" \
"node { name: 'E' op: '_FusedConv2D'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'num_args' value { i: 2 } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'strides' value { list: {i: 1, i:1, " \
"i:1, i:1} } }" \
" attr { key: 'padding' value { s: 'SAME' } }" \
" attr { key: 'dilations' value { list: {i: 1, i:1, " \
"i:1, i:1} } }" \
" attr { key: 'fused_ops'" \
" value { list: {s: 'BiasAdd', s: 'Add'} } }" \
" attr { key: 'epsilon' value { f: 0.001 }}" \
" attr { key: 'leakyrelu_alpha' value { f: 0.2 }}" \
" input: ['A', 'B', 'C', 'D']}" \
"node { name: 'F' op: 'Zeta'" \
"attr { key: 'T' value { type: " #T " } }" \
" input: ['E', 'D'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(" #INPUT ");C(" #INPUT ");D(" #INPUT ");" \
"DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);DMT/_3(Const);" \
"E(_MklFusedConv2D);F(Zeta)|A->E;A:control->DMT/_0:control;" \
"A:control->DMT/_1:control;A:control->DMT/_2:control;" \
"A:control->DMT/_3:control;B->E:1;C->E:2;D->E:3;D->F:1;" \
"DMT/_0->E:4;DMT/_1->E:5;DMT/_2->E:6;DMT/_3->E:7;E->F"); \
}
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedConv2D_Positive6);
#undef REGISTER_TEST
// Rewrite test for _FusedConv2D Op with BiasAdd+Add+Relu fusion
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph( \
"node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: '" #INPUT "'}" \
"node { name: 'C' op: '" #INPUT "'}" \
"node { name: 'D' op: '" #INPUT "'}" \
"node { name: 'E' op: '_FusedConv2D'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'num_args' value { i: 2 } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} " \
"} }" \
" attr { key: 'padding' value { s: 'SAME' } }" \
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} " \
"} }" \
" attr { key: 'fused_ops'" \
" value { list: {s: 'BiasAdd', s: 'Add', s: 'Relu'} } }" \
" attr { key: 'epsilon' value { f: 0.001 }}" \
" attr { key: 'leakyrelu_alpha' value { f: 0.2 }}" \
" input: ['A', 'B', 'C', 'D']}" \
"node { name: 'F' op: 'Zeta'" \
"attr { key: 'T' value { type: " #T " } }" \
" input: ['E', 'D'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(" #INPUT ");C(" #INPUT ");D(" #INPUT ");" \
"DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);DMT/_3(Const);" \
"E(_MklFusedConv2D);F(Zeta)|A->E;A:control->DMT/_0:control;" \
"A:control->DMT/_1:control;A:control->DMT/_2:control;" \
"A:control->DMT/_3:control;B->E:1;C->E:2;D->E:3;D->F:1;" \
"DMT/_0->E:4;DMT/_1->E:5;DMT/_2->E:6;DMT/_3->E:7;E->F"); \
}
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedConv2D_Positive7);
#undef REGISTER_TEST
// Rewrite test for _FusedConv2D Op with BiasAdd+LeakyRelu fusion
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph("node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: '" #INPUT "'}" \
"node { name: 'C' op: '" #INPUT "'}" \
"node { name: 'D' op: '_FusedConv2D'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'num_args' value { i: 1 } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'strides' value { list: {i: 1, i:1, " \
"i:1, i:1} } }" \
" attr { key: 'padding' value { s: 'SAME' } }" \
" attr { key: 'dilations' value { list: {i: 1, i:1, " \
"i:1, i:1} } }" \
" attr { key: 'fused_ops'" \
" value { list: {s: 'BiasAdd', s: 'LeakyRelu'} } }" \
" attr { key: 'epsilon' value { f: 0.001 }}" \
" attr { key: 'leakyrelu_alpha' value { f: 0.2 }}" \
" input: ['A', 'B', 'C']}" \
"node { name: 'E' op: 'Zeta'" \
"attr { key: 'T' value { type: " #T " } }" \
" input: ['D', 'C'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(" #INPUT ");C(" #INPUT ");D(_MklFusedConv2D);" \
"DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);E(Zeta)|A->D;" \
"A:control->DMT/_0:control;A:control->DMT/_1:control;" \
"A:control->DMT/_2:control;B->D:1;C->D:2;C->E:1;D->E;" \
"DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5"); \
}
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedConv2D_Positive8);
#undef REGISTER_TEST
// Rewrite tests for _FusedDepthwiseConv2dNative Op with various fusions.
// FUSED_OPS is defined (and #undef'd) before each instantiation below to
// select the fusion list; it is only expanded when REGISTER_TEST_ALL_TYPES
// is invoked, so defining it after this macro is fine.
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph( \
"node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: '" #INPUT "'}" \
"node { name: 'C' op: '" #INPUT "'}" \
"node { name: 'D' op: '_FusedDepthwiseConv2dNative'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'num_args' value { i: 1 } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} " \
"} }" \
" attr { key: 'padding' value { s: 'SAME' } }" \
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} " \
"} }" \
" attr { key: 'fused_ops' value { list: " FUSED_OPS " } }" \
" attr { key: 'epsilon' value { f: 0.001 }}" \
" attr { key: 'leakyrelu_alpha' value { f: 0.2 }}" \
" input: ['A', 'B', 'C']}" \
"node { name: 'E' op: 'Zeta'" \
"attr { key: 'T' value { type: " #T " } }" \
" input: ['D', 'C'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(" #INPUT ");C(" #INPUT ");" \
"D(_MklFusedDepthwiseConv2dNative);" \
"DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);E(Zeta)|A->D;" \
"A:control->DMT/_0:control;A:control->DMT/_1:control;" \
"A:control->DMT/_2:control;B->D:1;C->D:2;C->E:1;D->E;" \
"DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5"); \
}
// BiasAdd fusion
#define FUSED_OPS "{s: 'BiasAdd'}"
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedDepthwiseConv2dNative_Positive1);
#undef FUSED_OPS
// BiasAdd + Relu fusion
#define FUSED_OPS "{s: 'BiasAdd', s: 'Relu'}"
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedDepthwiseConv2dNative_Positive2);
#undef FUSED_OPS
// BiasAdd + Relu6 fusion
#define FUSED_OPS "{s: 'BiasAdd', s: 'Relu6'}"
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedDepthwiseConv2dNative_Positive3);
#undef FUSED_OPS
// BiasAdd + Elu fusion
#define FUSED_OPS "{s: 'BiasAdd', s: 'Elu'}"
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedDepthwiseConv2dNative_Positive4);
#undef FUSED_OPS
#undef REGISTER_TEST
// Rewrite test for _FusedConv2D Op with unsupported fusion
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph( \
"node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: '" #INPUT "'}" \
"node { name: 'C' op: '" #INPUT "'}" \
"node { name: 'D' op: '_FusedConv2D'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'num_args' value { i: 1 } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} " \
"} }" \
" attr { key: 'padding' value { s: 'SAME' } }" \
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} " \
"} }" \
" attr { key: 'fused_ops' value { list: {s: 'Unsupported'} } }" \
" attr { key: 'epsilon' value { f: 0.001 }}" \
" attr { key: 'leakyrelu_alpha' value { f: 0.2 }}" \
" input: ['A', 'B', 'C']}" \
"node { name: 'E' op: 'Zeta'" \
"attr { key: 'T' value { type: " #T " } }" \
" input: ['D', 'C'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(" #INPUT ");C(" #INPUT ");D(_FusedConv2D);" \
"E(Zeta)|A->D;B->D:1;C->D:2;C->E:1;D->E"); \
}
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedConv2D_Negative1);
#undef REGISTER_TEST
// Rewrite test for _FusedDepthwiseConv2dNative with unsupported fusion
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph( \
"node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: '" #INPUT "'}" \
"node { name: 'C' op: '" #INPUT "'}" \
"node { name: 'D' op: '_FusedDepthwiseConv2dNative'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'num_args' value { i: 1 } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} " \
"} }" \
" attr { key: 'padding' value { s: 'SAME' } }" \
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} " \
"} }" \
" attr { key: 'fused_ops' value { list: {s: 'Unsupported'} } }" \
" attr { key: 'epsilon' value { f: 0.001 }}" \
" attr { key: 'leakyrelu_alpha' value { f: 0.2 }}" \
" input: ['A', 'B', 'C']}" \
"node { name: 'E' op: 'Zeta'" \
"attr { key: 'T' value { type: " #T " } }" \
" input: ['D', 'C'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(" #INPUT ");C(" #INPUT ");" \
"D(_FusedDepthwiseConv2dNative);" \
"E(Zeta)|A->D;B->D:1;C->D:2;C->E:1;D->E"); \
}
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedDepthwiseConv2dNative_Negative1);
#undef REGISTER_TEST
// Rewrite test for _FusedConv2D Op with unsupported type
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph( \
"node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: '" #INPUT "'}" \
"node { name: 'C' op: '" #INPUT "'}" \
"node { name: 'D' op: '_FusedConv2D'" \
" attr { key: 'T' value { type:" #T "} }" \
" attr { key: 'num_args' value { i: 1 } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} " \
"} }" \
" attr { key: 'padding' value { s: 'SAME' } }" \
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} " \
"} }" \
" attr { key: 'fused_ops' value { list: {s: 'BiasAdd'} } }" \
" attr { key: 'epsilon' value { f: 0.001 }}" \
" input: ['A', 'B', 'C']}" \
"node { name: 'E' op: 'Zeta'" \
"attr { key: 'T' value { type: " #T "} }" \
" input: ['D', 'C'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(" #INPUT ");C(" #INPUT ");" \
"D(_FusedConv2D);E(Zeta)|A->D;B->D:1;C->D:2;C->E:1;D->E"); \
}
REGISTER_TEST(NodeRewrite_FusedConv2D_Negative2, DT_DOUBLE, DoubleInput);
#undef REGISTER_TEST
// Rewrite test for _FusedDepthwiseConv2dNative Op with unsupported type
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph( \
"node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: '" #INPUT "'}" \
"node { name: 'C' op: '" #INPUT "'}" \
"node { name: 'D' op: '_FusedDepthwiseConv2dNative'" \
" attr { key: 'T' value { type:" #T "} }" \
" attr { key: 'num_args' value { i: 1 } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} " \
"} }" \
" attr { key: 'padding' value { s: 'SAME' } }" \
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} " \
"} }" \
" attr { key: 'fused_ops' value { list: {s: 'BiasAdd'} } }" \
" attr { key: 'epsilon' value { f: 0.001 }}" \
" attr { key: 'leakyrelu_alpha' value { f: 0.2 }}" \
" input: ['A', 'B', 'C']}" \
"node { name: 'E' op: 'Zeta'" \
"attr { key: 'T' value { type: " #T "} }" \
" input: ['D', 'C'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(" #INPUT ");C(" #INPUT ");" \
"D(_FusedDepthwiseConv2dNative);" \
"E(Zeta)|A->D;B->D:1;C->D:2;C->E:1;D->E"); \
}
REGISTER_TEST(NodeRewrite_FusedDepthwiseConv2dNative_Negative2,
DT_DOUBLE, DoubleInput);
#undef REGISTER_TEST
// Test set: _FusedMatMul -> _MklFusedMatMul rewrite tests
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph( \
"node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: '" #INPUT "'}" \
"node { name: 'C' op: '" #INPUT "'}" \
"node { name: 'D' op: '_FusedMatMul'" \
" attr { key: 'T' value { type:" #T "} }" \
" attr { key: 'transpose_a' value { b: false } }" \
" attr { key: 'transpose_b' value { b: false } }" \
" attr { key: 'num_args' value { i: 1 } }" \
" attr { key: 'fused_ops' value { list: {s: 'BiasAdd'} } }" \
" attr { key: 'epsilon' value { f: 0.001 }}" \
" input: ['A', 'B', 'C']}" \
"node { name: 'Z' op: 'Zeta'" \
" attr {key: 'T' value { type: " #T " } }" \
" input: ['D', 'C']}"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(" #INPUT ");C(" #INPUT ");D(_MklFusedMatMul);" \
"DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);Z(Zeta)" \
"|A->D;A:control->DMT/_0:control;A:control->DMT/_1:control;" \
"A:control->DMT/_2:control;B->D:1;C->D:2;C->Z:1;D->Z;DMT/_0->D:3;" \
"DMT/_1->D:4;DMT/_2->D:5"); \
}
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedMatMul_Positive);
#undef REGISTER_TEST
// Negative test: _FusedMatMul with transpose_a=true is not rewritten.
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph( \
"node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: '" #INPUT "'}" \
"node { name: 'C' op: '" #INPUT "'}" \
"node { name: 'D' op: '_FusedMatMul'" \
" attr { key: 'T' value { type: " #T "} }" \
" attr { key: 'transpose_a' value { b: true } }" \
" attr { key: 'transpose_b' value { b: false } }" \
" attr { key: 'num_args' value { i: 1 } }" \
" attr { key: 'fused_ops' value { list: {s: 'BiasAdd'} } }" \
" attr { key: 'epsilon' value { f: 0.001 }}" \
" input: ['A', 'B', 'C']}" \
"node { name: 'Z' op: 'Zeta'" \
" attr {key: 'T' value { type: " #T " } }" \
" input: ['D', 'C']}"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(" #INPUT ");C(" #INPUT ");D(_FusedMatMul);Z(Zeta)"\
"|A->D;B->D:1;C->D:2;C->Z:1;D->Z"); \
}
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedMatMul_Negative);
#undef REGISTER_TEST
// Merge test for Pad + _FusedConv2D with BiasAdd fusion;
// padding type is VALID.
// A = input(image), B = input(paddings), C = Pad(A, B) = input of conv2D,
// D = input(filter), E = input(bias), F = _FusedConv2D(C, D, E),
// G = Zeta(F, E)
// After the layout pass:
// F = _MklPadWithFusedConv2D(A, D, E, B, DMT/_0, DMT/_1, DMT/_2, DMT/_3)
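// Note the merged op's input order in the expected string: pad input A is
// slot 0, filter D slot 1, bias E slot 2, paddings B slot 3 (B->F:3), and
// the four dummy Mkl tensors occupy slots 4..7.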
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph( \
"node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: 'Int32Input'}" \
"node { name: 'C' op: 'Pad'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'Tpaddings' value { type: DT_INT32 } }" \
" input: ['A', 'B']}" \
"node { name: 'D' op: '" #INPUT "'}" \
"node { name: 'E' op: '" #INPUT "'}" \
"node { name: 'F' op: '_FusedConv2D'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'num_args' value { i: 1 } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} " \
"} }" \
" attr { key: 'padding' value { s: 'VALID' } }" \
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} " \
"} }" \
" attr { key: 'fused_ops' value { list: {s: 'BiasAdd'} } }" \
" attr { key: 'epsilon' value { f: 0.001 }}" \
" attr { key: 'leakyrelu_alpha' value { f: 0.2 }}" \
" input: ['C', 'D', 'E']}" \
"node { name: 'G' op: 'Zeta'" \
" attr { key: 'T' value { type: " #T " } }" \
" input: ['F', 'E'] }"); \
EXPECT_EQ( \
DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(Int32Input);D(" #INPUT ");DMT/_0(Const);DMT/_1(Const);"\
"DMT/_2(Const);DMT/_3(Const);E(" #INPUT ");F(_MklPadWithFusedConv2D);" \
"G(Zeta)|A->F;A:control->DMT/_0:control;A:control->DMT/_1:control;" \
"A:control->DMT/_2:control;A:control->DMT/_3:control;B->F:3;D->F:1;" \
"DMT/_0->F:4;DMT/_1->F:5;DMT/_2->F:6;DMT/_3->F:7;E->F:2;E->G:1;F->G"); \
}
REGISTER_TEST_ALL_TYPES(NodeMerge_PadWithFusedConv2D_Positive1);
#undef REGISTER_TEST
// Merge test for Pad + _FusedConv2D with BiasAdd+Relu fusion;
// padding type is VALID.
// A = input(image), B = input(paddings), C = Pad(A, B) = input of conv2D,
// D = input(filter), E = input(bias), F = _FusedConv2D(C, D, E) (with Relu),
// G = Zeta(F, E)
// After the layout pass:
// F = _MklPadWithFusedConv2D(A, D, E, B, DMT/_0, DMT/_1, DMT/_2, DMT/_3)
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph("node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: 'Int32Input'}" \
"node { name: 'C' op: 'Pad'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'Tpaddings' value { type: DT_INT32 } }" \
" input: ['A', 'B']}" \
"node { name: 'D' op: '" #INPUT "'}" \
"node { name: 'E' op: '" #INPUT "'}" \
"node { name: 'F' op: '_FusedConv2D'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'num_args' value { i: 1 } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'strides' value { list: {i: 1, i:1, " \
"i:1, i:1} } }" \
" attr { key: 'padding' value { s: 'VALID' } }" \
" attr { key: 'dilations' value { list: {i: 1, i:1, " \
"i:1, i:1} } }" \
" attr { key: 'fused_ops'" \
" value { list: {s: 'BiasAdd', s: 'Relu'} } }" \
" attr { key: 'epsilon' value { f: 0.001 }}" \
" attr { key: 'leakyrelu_alpha' value { f: 0.2 }}" \
" input: ['C', 'D', 'E']}" \
"node { name: 'G' op: 'Zeta'" \
"attr { key: 'T' value { type: " #T "} }" \
" input: ['F', 'E'] }"); \
EXPECT_EQ( \
DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(Int32Input);D(" #INPUT ");DMT/_0(Const);DMT/_1(Const);"\
"DMT/_2(Const);DMT/_3(Const);E(" #INPUT ");F(_MklPadWithFusedConv2D);"\
"G(Zeta)|A->F;A:control->DMT/_0:control;A:control->DMT/_1:control;" \
"A:control->DMT/_2:control;A:control->DMT/_3:control;B->F:3;" \
"D->F:1;DMT/_0->F:4;DMT/_1->F:5;DMT/_2->F:6;DMT/" \
"_3->F:7;E->F:2;E->G:1;F->G"); \
}
REGISTER_TEST_ALL_TYPES(NodeMerge_PadWithFusedConv2D_Positive2);
#undef REGISTER_TEST
// Merge test for Pad + _FusedConv2D with an unsupported fusion;
// padding type is VALID.
// A = input(image), B = input(paddings), C = Pad(A, B) = input of conv2D,
// D = input(filter), E = input(bias),
// F = _FusedConv2D(C, D, E) (with an unsupported fused op), G = Zeta(F, E)
// After the layout pass: no merging, and F is not rewritten either.
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph( \
"node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: 'Int32Input'}" \
"node { name: 'C' op: 'Pad'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'Tpaddings' value { type: DT_INT32 } }" \
" input: ['A', 'B']}" \
"node { name: 'D' op: '" #INPUT "'}" \
"node { name: 'E' op: '" #INPUT "'}" \
"node { name: 'F' op: '_FusedConv2D'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'num_args' value { i: 1 } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} " \
"} }" \
" attr { key: 'padding' value { s: 'VALID' } }" \
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} " \
"} }" \
" attr { key: 'fused_ops' value { list: {s: 'Unsupported'} } }" \
" attr { key: 'epsilon' value { f: 0.001 }}" \
" attr { key: 'leakyrelu_alpha' value { f: 0.2 }}" \
" input: ['C', 'D', 'E']}" \
"node { name: 'G' op: 'Zeta'" \
" attr { key: 'T' value { type: " #T " } }" \
" input: ['F', 'E'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(Int32Input);C(Pad);D(" #INPUT ");E(" #INPUT ");"\
"F(_FusedConv2D);G(Zeta)|A->C;B->C:1;C->F;D->F:1;E->F:2;E->G:1;" \
"F->G"); \
}
REGISTER_TEST_ALL_TYPES(NodeMerge_PadWithFusedConv2D_Negative1);
#undef REGISTER_TEST
// Merge test for Pad + _FusedConv2D with BiasAdd fusion;
// padding type is SAME, so no merging takes place.
// A = input(image), B = input(paddings), C = Pad(A, B) = input of conv2D,
// D = input(filter), E = input(bias), F = _FusedConv2D(C, D, E),
// G = Zeta(F, E)
// After the layout pass: Pad is kept, but F is still rewritten to
// _MklFusedConv2D.
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph( \
"node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: 'Int32Input'}" \
"node { name: 'C' op: 'Pad'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'Tpaddings' value { type: DT_INT32 } }" \
" input: ['A', 'B']}" \
"node { name: 'D' op: '" #INPUT "'}" \
"node { name: 'E' op: '" #INPUT "'}" \
"node { name: 'F' op: '_FusedConv2D'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'num_args' value { i: 1 } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} " \
"} }" \
" attr { key: 'padding' value { s: 'SAME' } }" \
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} " \
"} }" \
" attr { key: 'fused_ops' value { list: {s: 'BiasAdd'} } }" \
" attr { key: 'epsilon' value { f: 0.001 }}" \
" attr { key: 'leakyrelu_alpha' value { f: 0.2 }}" \
" input: ['C', 'D', 'E']}" \
"node { name: 'G' op: 'Zeta'" \
" attr { key: 'T' value { type: " #T " } }" \
" input: ['F', 'E'] }"); \
EXPECT_EQ( \
DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(Int32Input);C(Pad);D(" #INPUT ");DMT/_0(Const);DMT/" \
"_1(Const);DMT/_2(Const);E(" #INPUT ");F(_MklFusedConv2D);G(Zeta)|A->C;"\
"B->C:1;C->F;C:control->DMT/_0:control;C:control->DMT/_1:control;" \
"C:control->DMT/_2:control;D->F:1;DMT/_0->F:3;DMT/_1->F:4;DMT/" \
"_2->F:5;E->F:2;E->G:1;F->G"); \
}
REGISTER_TEST_ALL_TYPES(NodeMerge_PadWithFusedConv2D_Negative2);
#undef REGISTER_TEST
// Merge test for Pad + _FusedConv2D with BiasAdd+Relu fusion;
// padding type is SAME, so no merging takes place.
// A = input(image), B = input(paddings), C = Pad(A, B) = input of conv2D,
// D = input(filter), E = input(bias), F = _FusedConv2D(C, D, E) (with Relu),
// G = Zeta(F, E)
// After the layout pass: Pad is kept, but F is still rewritten to
// _MklFusedConv2D.
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph("node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: 'Int32Input'}" \
"node { name: 'C' op: 'Pad'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'Tpaddings' value { type: DT_INT32 } }" \
" input: ['A', 'B']}" \
"node { name: 'D' op: '" #INPUT "'}" \
"node { name: 'E' op: '" #INPUT "'}" \
"node { name: 'F' op: '_FusedConv2D'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'num_args' value { i: 1 } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'strides' value { list: {i: 1, i:1, " \
"i:1, i:1} } }" \
" attr { key: 'padding' value { s: 'SAME' } }" \
" attr { key: 'dilations' value { list: {i: 1, i:1, " \
"i:1, i:1} } }" \
" attr { key: 'fused_ops'" \
" value { list: {s: 'BiasAdd', s: 'Relu'} } }" \
" attr { key: 'epsilon' value { f: 0.001 }}" \
" attr { key: 'leakyrelu_alpha' value { f: 0.2 }}" \
" input: ['C', 'D', 'E']}" \
"node { name: 'G' op: 'Zeta'" \
" attr { key: 'T' value { type: " #T " } }" \
" input: ['F', 'E'] }"); \
EXPECT_EQ( \
DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(Int32Input);C(Pad);D(" #INPUT ");DMT/_0(Const);DMT/" \
"_1(Const);DMT/_2(Const);E(" #INPUT ");F(_MklFusedConv2D);G(Zeta)|" \
"A->C;B->C:1;C->F;C:control->DMT/_0:control;C:control->DMT/_1:control;"\
"C:control->DMT/_2:control;D->F:1;DMT/_0->F:3;DMT/_1->F:4;DMT/" \
"_2->F:5;E->F:2;E->G:1;F->G"); \
}
REGISTER_TEST_ALL_TYPES(NodeMerge_PadWithFusedConv2D_Negative3);
#undef REGISTER_TEST
// Tests that there are no duplicate input control edges after merge.
// If both of the merging ops have an input control edge from a common op,
// the merged op ends up with only one control edge from that common op.
// This test adds an input-control-edge check on top of the previous test
// NodeMerge_PadWithFusedConv2D_Positive1; padding type is VALID.
// A = input(image), X = input, B = input(paddings),
// C = Pad(A, B) = input of conv2D,
// D = input(filter), E = input(bias), F = _FusedConv2D(C, D, E),
// G = Zeta(F, E)
// X:control->C:control
// X:control->F:control
// After the layout pass:
// F = _MklPadWithFusedConv2D(A, D, E, B, DMT/_0, DMT/_1, DMT/_2, DMT/_3)
// X:control->F:control (only one control edge)
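// Since the merged node keeps the name F, both control edges collapse onto
// it, and the expected string below lists X:control->F:control exactly once.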
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
DCHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); \
InitGraph( \
"node { name: 'X' op: '" #INPUT "'}" \
"node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: 'Int32Input'}" \
"node { name: 'C' op: 'Pad'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'Tpaddings' value { type: DT_INT32 } }" \
" input: ['A', 'B']}" \
"node { name: 'D' op: '" #INPUT "'}" \
"node { name: 'E' op: '" #INPUT "'}" \
"node { name: 'F' op: '_FusedConv2D'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'num_args' value { i: 1 } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} " \
"} }" \
" attr { key: 'padding' value { s: 'VALID' } }" \
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} " \
"} }" \
" attr { key: 'fused_ops' value { list: {s: 'BiasAdd'} } }" \
" attr { key: 'epsilon' value { f: 0.001 }}" \
" attr { key: 'leakyrelu_alpha' value { f: 0.2 }}" \
" input: ['C', 'D', 'E']}" \
"node { name: 'G' op: 'Zeta'" \
" attr {key: 'T' value { type: " #T " } }" \
" input: ['F', 'E']}"); \
Node* x = FindNode("X"); \
Node* c = FindNode("C"); \
Node* f = FindNode("F"); \
const Edge* edge = graph_.AddControlEdge(x, c); \
const Edge* edge_1 = graph_.AddControlEdge(x, f); \
ASSERT_NE(edge, nullptr); \
ASSERT_NE(edge_1, nullptr); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(Int32Input);D(" #INPUT ");DMT/_0(Const);" \
"DMT/_1(Const);DMT/_2(Const);DMT/_3(Const);E(" #INPUT ");" \
"F(_MklPadWithFusedConv2D);G(Zeta);X(" #INPUT ")|A->F;A:control->"\
"DMT/_0:control;A:control->DMT/_1:control;A:control->" \
"DMT/_2:control;A:control->DMT/_3:control;B->F:3;D->F:1;" \
"DMT/_0->F:4;DMT/_1->F:5;DMT/_2->F:6;DMT/_3->F:7;E->F:2;E->G:1;" \
"F->G;X:control->F:control"); \
}
REGISTER_TEST_ALL_TYPES(Input_ControlEdge_PadWithFusedConv2D_Positive);
#undef REGISTER_TEST
// Tests that there are no duplicate output control edges after merge.
// If both of the merging ops have an output control edge to a common op,
// the merged op ends up with only one control edge to that common op.
// This test adds an output-control-edge check on top of the previous test
// NodeMerge_PadWithFusedConv2D_Positive1; padding type is VALID.
// A = input(image), B = input(paddings), C = Pad(A, B) = input of conv2D,
// D = input(filter), E = input(bias), F = _FusedConv2D(C, D, E),
// G = Zeta(F, E), X = input
// C:control->X:control
// F:control->X:control
// After the layout pass:
// F = _MklPadWithFusedConv2D(A, D, E, B, DMT/_0, DMT/_1, DMT/_2, DMT/_3)
// F:control->X:control (only one control edge)
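// As in the previous test, the merged node keeps the name F, so the
// expected string below lists F:control->X:control exactly once.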
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
DCHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); \
InitGraph( \
"node { name: 'X' op: '" #INPUT "'}" \
"node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: 'Int32Input'}" \
"node { name: 'C' op: 'Pad'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'Tpaddings' value { type: DT_INT32 } }" \
" input: ['A', 'B']}" \
"node { name: 'D' op: '" #INPUT "'}" \
"node { name: 'E' op: '" #INPUT "'}" \
"node { name: 'F' op: '_FusedConv2D'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'num_args' value { i: 1 } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} " \
"} }" \
" attr { key: 'padding' value { s: 'VALID' } }" \
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} " \
"} }" \
" attr { key: 'fused_ops' value { list: {s: 'BiasAdd'} } }" \
" attr { key: 'epsilon' value { f: 0.001 }}" \
" attr { key: 'leakyrelu_alpha' value { f: 0.2 }}" \
" input: ['C', 'D', 'E']}" \
"node { name: 'G' op: 'Zeta'" \
" attr {key: 'T' value { type: " #T " } }" \
" input: ['F', 'E']}"); \
Node* x = FindNode("X"); \
Node* c = FindNode("C"); \
Node* f = FindNode("F"); \
const Edge* edge = graph_.AddControlEdge(c, x); \
const Edge* edge_1 = graph_.AddControlEdge(f, x); \
ASSERT_NE(edge, nullptr); \
ASSERT_NE(edge_1, nullptr); \
EXPECT_EQ( \
DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(Int32Input);D(" #INPUT ");DMT/_0(Const);DMT/_1(Const);"\
"DMT/_2(Const);DMT/_3(Const);E(" #INPUT ");F(_MklPadWithFusedConv2D);" \
"G(Zeta);X(" #INPUT ")|A->F;A:control->DMT/_0:control;A:control->DMT/" \
"_1:control;A:control->DMT/_2:control;A:control->DMT/_3:control;B->F:3;"\
"D->F:1;DMT/_0->F:4;DMT/_1->F:5;DMT/_2->F:6;DMT/_3->F:7;E->F:2;E->G:1;"\
"F->G;F:control->X:control"); \
}
REGISTER_TEST_ALL_TYPES(Output_ControlEdge_PadWithFusedConv2D_Positive);
#undef REGISTER_TEST
// Pad + _FusedConv2D with padding type VALID,
// where input A feeds both Pad and _FusedConv2D,
// and the outputs of both Pad and _FusedConv2D feed one node G.
// A = input(image), B = input(paddings), C = Pad(A, B),
// E = input(bias), F = _FusedConv2D(C, A, E), G = Output(C, F)
// After the layout pass: no merging, since Pad and _FusedConv2D both
// feed the same node G (F is still rewritten to _MklFusedConv2D).
#define REGISTER_TEST(NAME, T, INPUT, OUTPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
DCHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); \
InitGraph( \
"node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: 'Int32Input'}" \
"node { name: 'C' op: 'Pad'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'Tpaddings' value { type: DT_INT32 } }" \
" input: ['A', 'B']}" \
"node { name: 'E' op: '" #INPUT "'}" \
"node { name: 'F' op: '_FusedConv2D'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'num_args' value { i: 1 } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} " \
"} }" \
" attr { key: 'padding' value { s: 'VALID' } }" \
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} " \
"} }" \
" attr { key: 'fused_ops' value { list: {s: 'BiasAdd'} } }" \
" attr { key: 'epsilon' value { f: 0.001 }}" \
" attr { key: 'leakyrelu_alpha' value { f: 0.2 }}" \
" input: ['C', 'A', 'E']}" \
"node { name: 'G' op: '" #OUTPUT "'" \
" input: ['C', 'F']}"); \
EXPECT_EQ( \
DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(Int32Input);C(Pad);DMT/_0(Const);DMT/_1(Const);DMT/" \
"_2(Const);E(" #INPUT ");F(_MklFusedConv2D);G(" #OUTPUT \
")|A->C;A->F:1;B->C:1;C->F;C->G;C:control->DMT/_0:control;" \
"C:control->DMT/_1:control;C:control->DMT/_2:control;DMT/_0->F:3;" \
"DMT/_1->F:4;DMT/_2->F:5;E->F:2;F->G:1"); \
}
REGISTER_TEST(NodeMerge_PadWithFusedConv2D_Common_InOutput, DT_FLOAT,
Float32Input, Float32Output2);
#ifdef ENABLE_INTEL_MKL_BFLOAT16
// TODO(nhasabni): Enable bfloat16 test when we enable the operator.
REGISTER_TEST(NodeMerge_PadWithFusedConv2D_Common_InOutput, DT_BFLOAT16,
BFloat16Input, BFloat16Output2);
#endif
#undef REGISTER_TEST
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph( \
"node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: 'Int32Input'}" \
"node { name: 'C' op: '" #INPUT "'}" \
"node { name: 'D' op: 'Conv2DBackpropFilter'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'use_cudnn_on_gpu' value { b: false } }" \
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" \
" attr { key: 'padding' value { s: 'SAME' } }" \
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }" \
" input: ['A', 'B', 'C']}" \
"node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: " #T " } }" \
" input: ['A', 'D'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(Int32Input);C(" #INPUT ");D(_MklConv2DBackpropFilter);" \
"DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);E(Zeta)|" \
"A->D;A->E;A:control->DMT/_0:control;A:control->DMT/_1:control;" \
"A:control->DMT/_2:control;B->D:1;C->D:2;D->E:1;DMT/_0->D:3;" \
"DMT/_1->D:4;DMT/_2->D:5"); \
}
REGISTER_TEST_ALL_TYPES(NodeRewrite_Conv2DGradFilter_Positive);
#undef REGISTER_TEST
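// Same check for Conv2DBackpropInput below. Its first input is the
// input-sizes tensor B (Int32Input), so in the rewritten graph the dummy
// Mkl constants hang off B (B:control->DMT/_n:control) rather than off A.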
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph( \
"node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: 'Int32Input'}" \
"node { name: 'C' op: '" #INPUT "'}" \
"node { name: 'D' op: 'Conv2DBackpropInput'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'use_cudnn_on_gpu' value { b: false } }" \
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" \
" attr { key: 'padding' value { s: 'SAME' } }" \
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }" \
" input: ['B', 'A', 'C']}" \
"node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: " #T " } }" \
" input: ['A', 'D'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(Int32Input);C(" #INPUT ");D(_MklConv2DBackpropInput);" \
"DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);E(Zeta)|" \
"A->D:1;A->E;B->D;B:control->DMT/_0:control;" \
"B:control->DMT/_1:control;B:control->DMT/_2:control;C->D:2;" \
"D->E:1;DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5"); \
}
REGISTER_TEST_ALL_TYPES(NodeRewrite_Conv2DGradInput_Positive);
#undef REGISTER_TEST
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph( \
"node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: 'Int32Input'}" \
"node { name: 'C' op: '" #INPUT "'}" \
"node { name: 'D' op: 'DepthwiseConv2dNativeBackpropFilter'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" \
" attr { key: 'padding' value { s: 'SAME' } }" \
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }" \
" input: ['A', 'B', 'C']}" \
"node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: " #T " } }" \
" input: ['A', 'D'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(Int32Input);C(" #INPUT ");D(_" \
"MklDepthwiseConv2dNativeBackpropFilter);" \
"DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);E(Zeta)|" \
"A->D;A->E;A:control->DMT/_0:control;A:control->DMT/_1:control;" \
"A:control->DMT/_2:control;B->D:1;C->D:2;D->E:1;DMT/_0->D:3;" \
"DMT/_1->D:4;DMT/_2->D:5"); \
}
REGISTER_TEST_ALL_TYPES(NodeRewrite_DepthwiseConv2dNativeGradFilter_Positive);
#undef REGISTER_TEST
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph( \
"node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: 'Int32Input'}" \
"node { name: 'C' op: '" #INPUT "'}" \
"node { name: 'D' op: 'DepthwiseConv2dNativeBackpropInput'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" \
" attr { key: 'padding' value { s: 'SAME' } }" \
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }" \
" input: ['B', 'A', 'C']}" \
"node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: " #T " } }" \
" input: ['A', 'D'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(Int32Input);C(" #INPUT ");D(_" \
"MklDepthwiseConv2dNativeBackpropInput);" \
"DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);E(Zeta)|" \
"A->D:1;A->E;B->D;B:control->DMT/_0:control;" \
"B:control->DMT/_1:control;B:control->DMT/_2:control;C->D:2;" \
"D->E:1;DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5"); \
}
REGISTER_TEST_ALL_TYPES(NodeRewrite_DepthwiseConv2DGradInput_Positive);
#undef REGISTER_TEST
// Check that we never rewrite BiasAddGrad.
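// Three producer variants follow: a plain element-wise op (Polygamma), a
// MatMul that itself gets rewritten to _MklMatMul, and a pre-built
// _MklConv2D; in all three cases the BiasAddGrad node E stays as-is.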
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph( \
"node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: '" #INPUT "'}" \
"node { name: 'C' op: 'Polygamma'" \
" attr { key: 'T' value { type: " #T " } }" \
" input: ['A', 'B']}" \
"node { name: 'D' op: 'Zeta'" \
" attr {key: 'T' value { type: " #T " } }" \
" input: ['C', 'A']}" \
"node { name: 'E' op: 'BiasAddGrad'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" input: ['D'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(" #INPUT ");C(Polygamma);D(Zeta);E(BiasAddGrad)|" \
"A->C;A->D:1;B->C:1;C->D;D->E"); \
}
REGISTER_TEST_ALL_TYPES(NodeRewrite_BiasAddGrad_Positive);
#undef REGISTER_TEST
// Check that we never rewrite BiasAddGrad.
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph( \
"node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: '" #INPUT "'}" \
"node { name: 'C' op: 'MatMul'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'transpose_a' value { b: false } }" \
" attr { key: 'transpose_b' value { b: false } }" \
" input: ['A', 'B']}" \
"node { name: 'D' op: 'Zeta'" \
" attr {key: 'T' value { type: " #T " } }" \
" input: ['C', 'A']}" \
"node { name: 'E' op: 'BiasAddGrad'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" input: ['D'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(" #INPUT ");C(_MklMatMul);D(Zeta);E(BiasAddGrad)" \
"|A->C;A->D:1;B->C:1;C->D;D->E"); \
}
REGISTER_TEST_ALL_TYPES(NodeRewrite_BiasAddGrad_Positive1);
#undef REGISTER_TEST
// Check that we never rewrite BiasAddGrad.
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph( \
"node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: '" #INPUT "'}" \
"node { name: 'M' op: '_MklInput'}" \
"node { name: 'N' op: '_MklInput'}" \
"node { name: 'C' op: '_MklConv2D'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'use_cudnn_on_gpu' value { b: false } }" \
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" \
" attr { key: 'padding' value { s: 'SAME' } }" \
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }" \
" input: ['A', 'B', 'M', 'N']}" \
"node { name: 'D' op: 'Zeta'" \
" attr {key: 'T' value { type: " #T " } }" \
" input: ['C', 'A']}" \
"node { name: 'E' op: 'BiasAddGrad'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" input: ['D'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(" #INPUT ");C(_MklConv2D);D(Zeta);E(BiasAddGrad);" \
"M(_MklInput);N(_MklInput)|A->C;A->D:1;B->C:1;C->D;D->E;" \
"M->C:2;N->C:3"); \
}
REGISTER_TEST_ALL_TYPES(NodeRewrite_BiasAddGrad_Positive2);
#undef REGISTER_TEST
// Concat Op test: Concat with no Mkl layer feeding it
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph( \
"node { name: 'A' op: 'Const' " \
" attr { key: 'dtype' value { type: DT_INT32 } }" \
" attr { key: 'value' value { " \
" tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } " \
" int_val: 0 } } } }" \
"node { name: 'B' op: '" #INPUT "List'" \
" attr { key: 'N' value { i: 2 } }}" \
"node { name: 'C' op: '" #INPUT "'}" \
"node { name: 'D' op: 'Concat'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'N' value { i: 2 } }" \
" input: ['A', 'B:0', 'B:1']}" \
"node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: " #T " } }" \
" input: ['C', 'D'] }"); \
EXPECT_EQ( \
DoMklLayoutOptimizationPass(), \
"A(Const);B(" #INPUT "List);C(" #INPUT ");D(_MklConcat);DMT/_0(Const);"\
"DMT/_1(Const);DMT/_2(Const);E(Zeta)|A->D;A:control->DMT/_0:control;" \
"A:control->DMT/_1:control;A:control->DMT/_2:control;B->D:1;" \
"B:1->D:2;C->E;D->E:1;DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5"); \
}
REGISTER_TEST_ALL_TYPES(NodeRewrite_Concat_Basic);
#undef REGISTER_TEST
// Concat with 2 Mkl layers feeding it
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph( \
"node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: '" #INPUT "'}" \
"node { name: 'C' op: '" #INPUT "'}" \
"node { name: 'D' op: '" #INPUT "'}" \
"node { name: 'E' op: 'Conv2D'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'use_cudnn_on_gpu' value { b: false } }" \
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" \
" attr { key: 'padding' value { s: 'SAME' } }" \
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }" \
" input: ['A', 'B']}" \
"node { name: 'F' op: 'Conv2D'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'use_cudnn_on_gpu' value { b: false } }" \
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" \
" attr { key: 'padding' value { s: 'SAME' } }" \
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }" \
" input: ['C', 'D']}" \
"node { name: 'G' op: 'Const' " \
" attr { key: 'dtype' value { type: DT_INT32 } }" \
" attr { key: 'value' value { " \
" tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } " \
" int_val: 0 } } } }" \
"node { name: 'H' op: 'Concat'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'N' value { i: 2 } }" \
" input: ['G', 'E', 'F']}" \
"node { name: 'I' op: 'Zeta' attr { key: 'T' value { type: " #T " } }" \
" input: ['A', 'H'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(" #INPUT ");C(" #INPUT ");D(" #INPUT ");" \
"DMT/_0(Const);DMT/_1(Const);" \
"DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(_MklConv2D);" \
"F(_MklConv2D);G(Const);H(_MklConcat);I(Zeta)|A->E;A->I;" \
"A:control->DMT/_0:control;A:control->DMT/_1:control;" \
"B->E:1;C->F;C:control->DMT/_2:control;C:control->DMT/_3:control;" \
"D->F:1;DMT/_0->E:2;DMT/_1->E:3;DMT/_2->F:2;DMT/_3->F:3;" \
"DMT/_4->H:3;E->H:1;E:2->H:4;F->H:2;F:2->H:5;G->H;" \
"G:control->DMT/_4:control;H->I:1"); \
}
REGISTER_TEST_ALL_TYPES(NodeRewrite_Concat_Input_Mkl);
#undef REGISTER_TEST
// Concat with 1 Mkl and 1 non-Mkl layer feeding it
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph( \
"node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: '" #INPUT "'}" \
"node { name: 'C' op: '" #INPUT "'}" \
"node { name: 'D' op: '" #INPUT "'}" \
"node { name: 'E' op: 'Conv2D'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'use_cudnn_on_gpu' value { b: false } }" \
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" \
" attr { key: 'padding' value { s: 'SAME' } }" \
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }" \
" input: ['A', 'B']}" \
"node { name: 'F' op: 'Zeta' attr { key: 'T' value { type: " #T " } }" \
" input: ['C', 'D']}" \
"node { name: 'G' op: 'Const' " \
" attr { key: 'dtype' value { type: DT_INT32 } }" \
" attr { key: 'value' value { " \
" tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } " \
" int_val: 0 } } } }" \
"node { name: 'H' op: 'Concat'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'N' value { i: 2 } }" \
" input: ['G', 'E', 'F']}" \
"node { name: 'I' op: 'Zeta' attr { key: 'T' value { type: " #T " } }" \
" input: ['A', 'H'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(" #INPUT ");C(" #INPUT ");D(" #INPUT ");" \
"DMT/_0(Const);DMT/_1(Const);" \
"DMT/_2(Const);DMT/_3(Const);E(_MklConv2D);F(Zeta);G(Const);" \
"H(_MklConcat);I(Zeta)|A->E;A->I;A:control->DMT/_0:control;" \
"A:control->DMT/_1:control;B->E:1;C->F;D->F:1;DMT/_0->E:2;" \
"DMT/_1->E:3;DMT/_2->H:3;DMT/_3->H:5;E->H:1;E:2->H:4;F->H:2;" \
"G->H;G:control->DMT/_2:control;G:control->DMT/_3:control;H->I:1"); \
  }
REGISTER_TEST_ALL_TYPES(NodeRewrite_Concat_Input_MixedMkl);
#undef REGISTER_TEST
// ConcatV2 Op test: ConcatV2 with no Mkl layer feeding it
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph( \
"node { name: 'A' op: 'Const' " \
" attr { key: 'dtype' value { type: DT_INT32 } }" \
" attr { key: 'value' value { " \
" tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } " \
" int_val: 0 } } } }" \
"node { name: 'B' op: '" #INPUT "List'" \
" attr { key: 'N' value { i: 2 } }}" \
"node { name: 'C' op: '" #INPUT "'}" \
"node { name: 'D' op: 'ConcatV2'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'Tidx' value { type: DT_INT32 } }" \
" attr { key: 'N' value { i: 2 } }" \
" input: ['B:0', 'B:1', 'A']}" \
"node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: " #T " } }" \
" input: ['C', 'D'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(Const);B(" #INPUT "List);C(" #INPUT ");D(_MklConcatV2);" \
"DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);E(Zeta)|A->D:2;B->D;" \
"B:1->D:1;B:control->DMT/_0:control;B:control->DMT/_1:control;" \
"B:control->DMT/_2:control;C->E;D->E:1;DMT/_0->D:3;" \
"DMT/_1->D:4;DMT/_2->D:5"); \
}
REGISTER_TEST_ALL_TYPES(NodeRewrite_ConcatV2_Basic);
#undef REGISTER_TEST
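// Unlike Concat, ConcatV2 takes the concatenation axis as its last input,
// which is why the axis constant A feeds D:2 here (versus slot 0 in the
// Concat test above) and why the dummy Mkl control edges originate from B.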
// ConcatV2 with 2 Mkl layers feeding it
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph( \
"node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: '" #INPUT "'}" \
"node { name: 'C' op: '" #INPUT "'}" \
"node { name: 'D' op: '" #INPUT "'}" \
"node { name: 'E' op: 'Conv2D'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'use_cudnn_on_gpu' value { b: false } }" \
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" \
" attr { key: 'padding' value { s: 'SAME' } }" \
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }" \
" input: ['A', 'B']}" \
"node { name: 'F' op: 'Conv2D'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'use_cudnn_on_gpu' value { b: false } }" \
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" \
" attr { key: 'padding' value { s: 'SAME' } }" \
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }" \
" input: ['C', 'D']}" \
"node { name: 'G' op: 'Const' " \
" attr { key: 'dtype' value { type: DT_INT32 } }" \
" attr { key: 'value' value { " \
" tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } " \
" int_val: 0 } } } }" \
"node { name: 'H' op: 'ConcatV2'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'Tidx' value { type: DT_INT32 } }" \
" attr { key: 'N' value { i: 2 } }" \
" input: ['E', 'F', 'G']}" \
"node { name: 'I' op: 'Zeta' attr { key: 'T' value { type: " #T " } }" \
" input: ['A', 'H'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(" #INPUT ");C(" #INPUT ");D(" #INPUT ");" \
"DMT/_0(Const);DMT/_1(Const);" \
"DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(_MklConv2D);" \
"F(_MklConv2D);G(Const);H(_MklConcatV2);I(Zeta)|A->E;A->I;" \
"A:control->DMT/_0:control;A:control->DMT/_1:control;B->E:1;C->F;" \
"C:control->DMT/_2:control;C:control->DMT/_3:control;" \
"D->F:1;DMT/_0->E:2;DMT/_1->E:3;DMT/_2->F:2;DMT/_3->F:3;" \
"DMT/_4->H:5;E->H;E:2->H:3;E:control->DMT/_4:control;F->H:1;" \
"F:2->H:4;G->H:2;H->I:1"); \
}
REGISTER_TEST_ALL_TYPES(NodeRewrite_ConcatV2_Input_Mkl);
#undef REGISTER_TEST
// ConcatV2 with 1 Mkl and 1 non-Mkl layer feeding it
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph( \
"node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: '" #INPUT "'}" \
"node { name: 'C' op: '" #INPUT "'}" \
"node { name: 'D' op: '" #INPUT "'}" \
"node { name: 'E' op: 'Conv2D'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'use_cudnn_on_gpu' value { b: false } }" \
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" \
" attr { key: 'padding' value { s: 'SAME' } }" \
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }" \
" input: ['A', 'B']}" \
"node { name: 'F' op: 'Zeta' attr { key: 'T' value { type: " #T " } }" \
" input: ['C', 'D']}" \
"node { name: 'G' op: 'Const' " \
" attr { key: 'dtype' value { type: DT_INT32 } }" \
" attr { key: 'value' value { " \
" tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } " \
" int_val: 0 } } } }" \
"node { name: 'H' op: 'ConcatV2'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'Tidx' value { type: DT_INT32 } }" \
" attr { key: 'N' value { i: 2 } }" \
" input: ['E', 'F', 'G']}" \
"node { name: 'I' op: 'Zeta' attr { key: 'T' value { type: " #T " } }" \
" input: ['A', 'H'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(" #INPUT ");C(" #INPUT ");D(" #INPUT ");" \
"DMT/_0(Const);DMT/_1(Const);" \
"DMT/_2(Const);DMT/_3(Const);E(_MklConv2D);F(Zeta);G(Const);" \
"H(_MklConcatV2);I(Zeta)|A->E;A->I;A:control->DMT/_0:control;" \
"A:control->DMT/_1:control;B->E:1;C->F;D->F:1;DMT/_0->E:2;" \
"DMT/_1->E:3;DMT/_2->H:4;DMT/_3->H:5;E->H;E:2->H:3;" \
"E:control->DMT/_2:control;E:control->DMT/_3:control;F->H:1;" \
"G->H:2;H->I:1"); \
}
REGISTER_TEST_ALL_TYPES(NodeRewrite_ConcatV2_Input_MixedMkl);
#undef REGISTER_TEST
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph( \
"node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: 'Relu'" \
" attr { key: 'T' value { type: " #T " } }" \
" input: ['A'] }" \
"node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: " #T " } }" \
" input: ['A', 'B'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(_MklRelu);C(Zeta);DMT/_0(Const)|A->B;A->C;" \
"A:control->DMT/_0:control;B->C:1;DMT/_0->B:1"); \
}
REGISTER_TEST_ALL_TYPES(NodeRewrite_Relu_Positive);
#undef REGISTER_TEST
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph( \
"node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: '" #INPUT "'}" \
"node { name: 'C' op: 'ReluGrad'" \
" attr { key: 'T' value { type: " #T " } }" \
" input: ['A', 'B'] }" \
"node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: " #T " } }" \
" input: ['A', 'C'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(" #INPUT ");C(_MklReluGrad);D(Zeta);DMT/_0(Const);" \
"DMT/_1(Const)|A->C;A->D;A:control->DMT/_0:control;" \
"A:control->DMT/_1:control;B->C:1;C->D:1;DMT/_0->C:2;DMT/_1->C:3"); \
}
REGISTER_TEST_ALL_TYPES(NodeRewrite_ReluGrad_Positive);
#undef REGISTER_TEST
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph( \
"node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: 'Relu'" \
" attr { key: 'T' value { type: " #T " } }" \
" input: ['A'] }" \
"node { name: 'C' op: 'ReluGrad'" \
" attr { key: 'T' value { type: " #T " } }" \
" input: ['A', 'B'] }" \
"node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: " #T " } }" \
" input: ['A', 'C'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(_MklRelu);C(_MklReluGrad);D(Zeta);DMT/_0(Const);" \
"DMT/_1(Const)|A->B;A->C;A->D;A:control->DMT/_0:control;" \
"A:control->DMT/_1:control;B->C:1;B:1->C:3;C->D:1;DMT/_0->B:1;" \
"DMT/_1->C:2"); \
}
REGISTER_TEST_ALL_TYPES(NodeRewrite_ReluReluGrad_Positive);
#undef REGISTER_TEST
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph( \
"node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: 'Relu6'" \
" attr { key: 'T' value { type: " #T " } }" \
" input: ['A'] }" \
"node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: " #T " } }" \
" input: ['A', 'B'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(_MklRelu6);C(Zeta);DMT/_0(Const)|A->B;A->C;" \
"A:control->DMT/_0:control;B->C:1;DMT/_0->B:1"); \
}
REGISTER_TEST_ALL_TYPES(NodeRewrite_Relu6_Positive);
#undef REGISTER_TEST
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph( \
"node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: '" #INPUT "'}" \
"node { name: 'C' op: 'Relu6Grad'" \
" attr { key: 'T' value { type: " #T " } }" \
" input: ['A', 'B'] }" \
"node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: " #T " } }" \
" input: ['A', 'C'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(" #INPUT ");C(_MklRelu6Grad);D(Zeta);DMT/_0(Const);" \
"DMT/_1(Const)|A->C;A->D;A:control->DMT/_0:control;" \
"A:control->DMT/_1:control;B->C:1;C->D:1;DMT/_0->C:2;DMT/_1->C:3"); \
}
REGISTER_TEST_ALL_TYPES(NodeRewrite_Relu6Grad_Positive);
#undef REGISTER_TEST
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph( \
"node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: 'Relu6'" \
" attr { key: 'T' value { type: " #T " } }" \
" input: ['A'] }" \
"node { name: 'C' op: 'Relu6Grad'" \
" attr { key: 'T' value { type: " #T " } }" \
" input: ['A', 'B'] }" \
"node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: " #T " } }" \
" input: ['A', 'C'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(_MklRelu6);C(_MklRelu6Grad);D(Zeta);DMT/_0(Const);" \
"DMT/_1(Const)|A->B;A->C;A->D;A:control->DMT/_0:control;" \
"A:control->DMT/_1:control;B->C:1;B:1->C:3;C->D:1;DMT/_0->B:1;" \
"DMT/_1->C:2"); \
}
REGISTER_TEST_ALL_TYPES(NodeRewrite_Relu6Relu6Grad_Positive);
#undef REGISTER_TEST
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph( \
"node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: 'LeakyRelu'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'alpha' value { f: 0.1 } }" \
" input: ['A'] }" \
"node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: " #T " } }" \
" input: ['A', 'B'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(_MklLeakyRelu);C(Zeta);DMT/_0(Const)|A->B;A->C;" \
"A:control->DMT/_0:control;B->C:1;DMT/_0->B:1"); \
}
REGISTER_TEST_ALL_TYPES(NodeRewrite_LeakyRelu_Positive);
#undef REGISTER_TEST
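// Negative tests: LeakyRelu/LeakyReluGrad with alpha = 2.0 are expected to
// stay un-rewritten (the Mkl rewrite appears to require alpha <= 1).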
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph( \
"node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: 'LeakyRelu'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'alpha' value { f: 2.0 } }" \
" input: ['A'] }" \
"node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: " #T " } }" \
" input: ['A', 'B'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(LeakyRelu);C(Zeta)|A->B;A->C;B->C:1"); \
}
REGISTER_TEST_ALL_TYPES(NodeRewrite_LeakyRelu_Negative);
#undef REGISTER_TEST
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph( \
"node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: '" #INPUT "'}" \
"node { name: 'C' op: 'LeakyReluGrad'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'alpha' value { f: 0.1 } }" \
" input: ['A', 'B'] }" \
"node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: " #T " } }" \
" input: ['A', 'C'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(" #INPUT ");C(_MklLeakyReluGrad);D(Zeta);DMT/_0(Const);" \
"DMT/_1(Const)|A->C;A->D;A:control->DMT/_0:control;" \
"A:control->DMT/_1:control;B->C:1;C->D:1;DMT/_0->C:2;DMT/_1->C:3"); \
}
REGISTER_TEST_ALL_TYPES(NodeRewrite_LeakyReluGrad_Positive);
#undef REGISTER_TEST
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph( \
"node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: '" #INPUT "'}" \
"node { name: 'C' op: 'LeakyReluGrad'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'alpha' value { f: 2.0 } }" \
" input: ['A', 'B'] }" \
"node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: " #T " } }" \
" input: ['A', 'C'] }"); \
EXPECT_EQ( \
DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(" #INPUT ");C(LeakyReluGrad);D(Zeta)|A->C;A->D;B->C:1;C->D:1"); \
}
REGISTER_TEST_ALL_TYPES(NodeRewrite_LeakyReluGrad_Negative);
#undef REGISTER_TEST
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph( \
"node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: 'LeakyRelu'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'alpha' value { f: 0.1 } }" \
" input: ['A'] }" \
"node { name: 'C' op: 'LeakyReluGrad'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'alpha' value { f: 0.1 } }" \
" input: ['A', 'B'] }" \
"node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: " #T " } }" \
" input: ['A', 'C'] }"); \
EXPECT_EQ( \
DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(_MklLeakyRelu);C(_MklLeakyReluGrad);D(Zeta);DMT/_0(Const);" \
"DMT/_1(Const)|A->B;A->C;A->D;A:control->DMT/_0:control;" \
"A:control->DMT/_1:control;B->C:1;B:1->C:3;C->D:1;DMT/_0->B:1;" \
"DMT/_1->C:2"); \
}
REGISTER_TEST_ALL_TYPES(NodeRewrite_LeakyReluLeakyReluGrad_Positive);
#undef REGISTER_TEST
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
DCHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); \
InitGraph( \
"node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: 'Tanh'" \
" attr { key: 'T' value { type: " #T " } }" \
" input: ['A'] }" \
"node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: " #T " } }" \
" input: ['A', 'B'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(_MklTanh);C(Zeta);DMT/_0(Const)|A->B;A->C;" \
"A:control->DMT/_0:control;B->C:1;DMT/_0->B:1"); \
}
REGISTER_TEST_ALL_TYPES(NodeRewrite_Tanh_Positive);
#undef REGISTER_TEST
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
DCHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); \
InitGraph( \
"node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: '" #INPUT "'}" \
"node { name: 'C' op: 'TanhGrad'" \
" attr { key: 'T' value { type: " #T " } }" \
" input: ['A', 'B'] }" \
"node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: " #T " } }" \
" input: ['A', 'C'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(" #INPUT ");C(_MklTanhGrad);D(Zeta);DMT/_0(Const);" \
"DMT/_1(Const)|A->C;A->D;A:control->DMT/_0:control;" \
"A:control->DMT/_1:control;B->C:1;C->D:1;DMT/_0->C:2;DMT/_1->C:3"); \
}
REGISTER_TEST_ALL_TYPES(NodeRewrite_TanhGrad_Positive);
#undef REGISTER_TEST
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
DCHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); \
InitGraph( \
"node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: 'Tanh'" \
" attr { key: 'T' value { type: " #T " } }" \
" input: ['A'] }" \
"node { name: 'C' op: 'TanhGrad'" \
" attr { key: 'T' value { type: " #T " } }" \
" input: ['B', 'A'] }" \
"node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: " #T " } }" \
" input: ['A', 'C'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(_MklTanh);C(_MklTanhGrad);D(Zeta);DMT/_0(Const);" \
"DMT/_1(Const)|A->B;A->C:1;A->D;A:control->DMT/_0:control;" \
"B->C;B:1->C:2;B:control->DMT/_1:control;C->D:1;DMT/_0->B:1;" \
"DMT/_1->C:3"); \
}
REGISTER_TEST_ALL_TYPES(NodeRewrite_TanhTanhGrad_Positive);
#undef REGISTER_TEST
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
DCHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); \
InitGraph( \
"node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: 'AvgPool'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }" \
" attr { key: 'padding' value { s: 'VALID' } }" \
" attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }" \
" input: ['A'] }" \
"node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: " #T " } }" \
" input: ['A', 'B'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(_MklAvgPool);C(Zeta);DMT/_0(Const)|A->B;A->C;" \
"A:control->DMT/_0:control;B->C:1;DMT/_0->B:1"); \
}
REGISTER_TEST_ALL_TYPES(NodeRewrite_AvgPool_Positive);
#undef REGISTER_TEST
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
DCHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); \
InitGraph( \
"node { name: 'A' op: 'Int32Input'}" \
"node { name: 'B' op: '" #INPUT "'}" \
"node { name: 'C' op: 'AvgPoolGrad' " \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }" \
" attr { key: 'padding' value { s: 'VALID' } }" \
" attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }" \
" input: ['A', 'B'] }" \
"node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: " #T " } }" \
" input: ['B', 'C'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(Int32Input);B(" #INPUT ");C(_MklAvgPoolGrad);D(Zeta);" \
"DMT/_0(Const);DMT/_1(Const)|A->C;A:control->DMT/_0:control;" \
"A:control->DMT/_1:control;B->C:1;B->D;C->D:1;DMT/_0->C:2;" \
"DMT/_1->C:3"); \
}
REGISTER_TEST_ALL_TYPES(NodeRewrite_AvgPoolGrad_Positive);
#undef REGISTER_TEST
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
DCHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); \
InitGraph( \
"node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'I' op: 'Int32Input'}" \
"node { name: 'B' op: 'AvgPool'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }" \
" attr { key: 'padding' value { s: 'VALID' } }" \
" attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }" \
" input: ['A'] }" \
"node { name: 'C' op: 'AvgPoolGrad' " \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }" \
" attr { key: 'padding' value { s: 'VALID' } }" \
" attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }" \
" input: ['I', 'B'] }" \
"node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: " #T " } }" \
" input: ['A', 'C'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(_MklAvgPool);C(_MklAvgPoolGrad);D(Zeta);DMT/_0(Const);" \
"DMT/_1(Const);I(Int32Input)|A->B;A->D;A:control->DMT/_0:control;" \
"B->C:1;B:1->C:3;C->D:1;DMT/_0->B:1;DMT/_1->C:2;I->C;" \
"I:control->DMT/_1:control"); \
}
REGISTER_TEST_ALL_TYPES(NodeRewrite_AvgPoolAvgPoolGrad_Positive);
#undef REGISTER_TEST
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
DCHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); \
InitGraph( \
"node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: '" #INPUT "'}" \
"node { name: 'C' op: '" #INPUT "'}" \
"node { name: 'D' op: '" #INPUT "'}" \
"node { name: 'E' op: '" #INPUT "'}" \
"node { name: 'F' op: 'FusedBatchNormGrad'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'epsilon' value { f: 0.0001 } }" \
" attr { key: 'is_training' value { b: true } }" \
" input: ['A', 'B', 'C', 'D', 'E'] }" \
"node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: " #T " } }" \
" input: ['A', 'F'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(" #INPUT ");C(" #INPUT ");D(" #INPUT ");" \
"DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);" \
"E(" #INPUT ");F(_MklFusedBatchNormGrad);G(Zeta)|A->F;A->G;" \
"A:control->DMT/_0:control;A:control->DMT/_1:control;" \
"A:control->DMT/_2:control;A:control->DMT/_3:control;" \
"A:control->DMT/_4:control;B->F:1;C->F:2;D->F:3;" \
"DMT/_0->F:5;DMT/_1->F:6;DMT/_2->F:7;DMT/_3->F:8;DMT/_4->F:9;" \
"E->F:4;F->G:1"); \
}
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedBatchNormGrad_Positive);
#undef REGISTER_TEST
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
DCHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); \
InitGraph( \
"node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: '" #INPUT "'}" \
"node { name: 'C' op: 'Float32Input'}" \
"node { name: 'D' op: 'Float32Input'}" \
"node { name: 'E' op: 'Float32Input'}" \
"node { name: 'F' op: 'FusedBatchNormGradV2'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'U' value { type: DT_FLOAT } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'epsilon' value { f: 0.0001 } }" \
" attr { key: 'is_training' value { b: true } }" \
" input: ['A', 'B', 'C', 'D', 'E'] }" \
"node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: " #T " } }" \
" input: ['A', 'F'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(" #INPUT ");C(Float32Input);D(Float32Input);" \
"DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);" \
"E(Float32Input);F(_MklFusedBatchNormGradV2);G(Zeta)|A->F;A->G;" \
"A:control->DMT/_0:control;A:control->DMT/_1:control;" \
"A:control->DMT/_2:control;A:control->DMT/_3:control;" \
"A:control->DMT/_4:control;B->F:1;C->F:2;D->F:3;" \
"DMT/_0->F:5;DMT/_1->F:6;DMT/_2->F:7;DMT/_3->F:8;DMT/_4->F:9;" \
"E->F:4;F->G:1"); \
}
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedBatchNormGradV2_Positive);
#undef REGISTER_TEST
// The T, U type combination is not supported by MKL, so the node will not be
// rewritten into an MKL node.
TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNormGradV2_Negative) {
DCHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
InitGraph(
"node { name: 'A' op: 'HalfInput'}"
"node { name: 'B' op: 'HalfInput'}"
"node { name: 'C' op: 'Float32Input'}"
"node { name: 'D' op: 'Float32Input'}"
"node { name: 'E' op: 'Float32Input'}"
"node { name: 'F' op: 'FusedBatchNormGradV2'"
" attr { key: 'T' value { type: DT_HALF } }"
" attr { key: 'U' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'epsilon' value { f: 0.0001 } }"
" attr { key: 'is_training' value { b: true } }"
" input: ['A', 'B', 'C', 'D', 'E'] }"
"node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: DT_HALF } }"
" input: ['A', 'F'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(HalfInput);B(HalfInput);C(Float32Input);D(Float32Input);"
"E(Float32Input);F(FusedBatchNormGradV2);G(Zeta)|A->F;A->G;"
"B->F:1;C->F:2;D->F:3;E->F:4;F->G:1");
}
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
DCHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); \
InitGraph( \
"node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: '" #INPUT "'}" \
"node { name: 'C' op: '" #INPUT "'}" \
"node { name: 'D' op: '" #INPUT "'}" \
"node { name: 'E' op: '" #INPUT "'}" \
"node { name: 'F' op: 'FusedBatchNorm'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'epsilon' value { f: 0.0001 } }" \
" attr { key: 'is_training' value { b: true } }" \
" input: ['A', 'B', 'C', 'D', 'E'] }" \
"node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: " #T " } }" \
" input: ['A', 'F'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(" #INPUT ");C(" #INPUT ");D(" #INPUT ");DMT/_0(Const);" \
"DMT/_1(Const);DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(" #INPUT ");" \
"F(_MklFusedBatchNorm);G(Zeta)|A->F;A->G;" \
"A:control->DMT/_0:control;A:control->DMT/_1:control;" \
"A:control->DMT/_2:control;A:control->DMT/_3:control;" \
"A:control->DMT/_4:control;B->F:1;C->F:2;D->F:3;" \
"DMT/_0->F:5;DMT/_1->F:6;DMT/_2->F:7;DMT/_3->F:8;DMT/_4->F:9;" \
"E->F:4;F->G:1"); \
}
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedBatchNorm_Positive);
#undef REGISTER_TEST
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
DCHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); \
InitGraph( \
"node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: 'Float32Input'}" \
"node { name: 'C' op: 'Float32Input'}" \
"node { name: 'D' op: 'Float32Input'}" \
"node { name: 'E' op: 'Float32Input'}" \
"node { name: 'F' op: 'FusedBatchNormV2'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'U' value { type: DT_FLOAT } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'epsilon' value { f: 0.0001 } }" \
" attr { key: 'is_training' value { b: true } }" \
" input: ['A', 'B', 'C', 'D', 'E'] }" \
"node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: " #T " } }" \
" input: ['A', 'F'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(Float32Input);C(Float32Input);D(Float32Input);" \
"DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);" \
"E(Float32Input);F(_MklFusedBatchNormV2);G(Zeta)|A->F;A->G;" \
"A:control->DMT/_0:control;A:control->DMT/_1:control;" \
"A:control->DMT/_2:control;A:control->DMT/_3:control;" \
"A:control->DMT/_4:control;B->F:1;C->F:2;D->F:3;" \
"DMT/_0->F:5;DMT/_1->F:6;DMT/_2->F:7;DMT/_3->F:8;DMT/_4->F:9;" \
"E->F:4;F->G:1"); \
}
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedBatchNormV2_Positive);
#undef REGISTER_TEST
// The T, U type combination is not supported by MKL, so the node will not be
// rewritten into an MKL node.
TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNormV2_Negative) {
DCHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
InitGraph(
"node { name: 'A' op: 'HalfInput'}"
"node { name: 'B' op: 'Float32Input'}"
"node { name: 'C' op: 'Float32Input'}"
"node { name: 'D' op: 'Float32Input'}"
"node { name: 'E' op: 'Float32Input'}"
"node { name: 'F' op: 'FusedBatchNormV2'"
" attr { key: 'T' value { type: DT_HALF } }"
" attr { key: 'U' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'epsilon' value { f: 0.0001 } }"
" attr { key: 'is_training' value { b: true } }"
" input: ['A', 'B', 'C', 'D', 'E'] }"
"node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: DT_HALF } }"
" input: ['A', 'F'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(HalfInput);B(Float32Input);C(Float32Input);D(Float32Input);"
"E(Float32Input);F(FusedBatchNormV2);G(Zeta)|A->F;A->G;"
"B->F:1;C->F:2;D->F:3;E->F:4;F->G:1");
}
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
DCHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); \
InitGraph( \
"node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: 'Float32Input'}" \
"node { name: 'C' op: 'Float32Input'}" \
"node { name: 'D' op: 'Float32Input'}" \
"node { name: 'E' op: 'Float32Input'}" \
"node { name: 'F' op: 'FusedBatchNormV3'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'U' value { type: DT_FLOAT } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'epsilon' value { f: 0.0001 } }" \
" attr { key: 'is_training' value { b: true } }" \
" input: ['A', 'B', 'C', 'D', 'E'] }" \
"node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: " #T " } }" \
" input: ['A', 'F'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(Float32Input);C(Float32Input);D(Float32Input);" \
"DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);" \
"E(Float32Input);F(_MklFusedBatchNormV3);G(Zeta)|A->F;A->G;" \
"A:control->DMT/_0:control;A:control->DMT/_1:control;" \
"A:control->DMT/_2:control;A:control->DMT/_3:control;" \
"A:control->DMT/_4:control;B->F:1;C->F:2;D->F:3;" \
"DMT/_0->F:5;DMT/_1->F:6;DMT/_2->F:7;DMT/_3->F:8;DMT/_4->F:9;" \
"E->F:4;F->G:1"); \
}
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedBatchNormV3_Positive);
#undef REGISTER_TEST
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph( \
"node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: 'Float32Input'}" \
"node { name: 'C' op: 'Float32Input'}" \
"node { name: 'D' op: 'Float32Input'}" \
"node { name: 'E' op: 'Float32Input'}" \
"node { name: 'F' op: 'FusedBatchNormV3'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'U' value { type: DT_FLOAT } }" \
" attr { key: 'data_format' value { s: " DATA_FORMAT " } }" \
" attr { key: 'epsilon' value { f: 0.0001 } }" \
" attr { key: 'is_training' value { b: true } }" \
" input: ['A', 'B', 'C', 'D', 'E'] }" \
"node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: " #T " } }" \
" input: ['A', 'F'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(Float32Input);C(Float32Input);" \
"D(Float32Input);E(Float32Input);F(FusedBatchNormV3);G(Zeta)" \
"|A->F;A->G;B->F:1;C->F:2;D->F:3;E->F:4;F->G:1"); \
}
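// FusedBatchNormV3 with a 5D data format (NCDHW or NDHWC) is expected to stay
// un-rewritten; only 4D data formats appear to be handled by the Mkl rewrite.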
#define DATA_FORMAT "'NCDHW'"
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedBatchNormV3_5D_Negative_1);
#undef DATA_FORMAT
#define DATA_FORMAT "'NDHWC'"
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedBatchNormV3_5D_Negative_2);
#undef DATA_FORMAT
#undef REGISTER_TEST
TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNormV3_Negative) {
DCHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
InitGraph(
"node { name: 'A' op: 'HalfInput'}"
"node { name: 'B' op: 'Float32Input'}"
"node { name: 'C' op: 'Float32Input'}"
"node { name: 'D' op: 'Float32Input'}"
"node { name: 'E' op: 'Float32Input'}"
"node { name: 'F' op: 'FusedBatchNormV3'"
" attr { key: 'T' value { type: DT_HALF } }"
" attr { key: 'U' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'epsilon' value { f: 0.0001 } }"
" attr { key: 'is_training' value { b: true } }"
" input: ['A', 'B', 'C', 'D', 'E'] }"
"node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: DT_HALF } }"
" input: ['A', 'F'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(HalfInput);B(Float32Input);C(Float32Input);D(Float32Input);"
"E(Float32Input);F(FusedBatchNormV3);G(Zeta)|A->F;A->G;"
"B->F:1;C->F:2;D->F:3;E->F:4;F->G:1");
}
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph( \
"node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: '" #INPUT "'}" \
"node { name: 'C' op: 'Float32Input'}" \
"node { name: 'D' op: 'Float32Input'}" \
"node { name: 'E' op: 'Float32Input'}" \
"node { name: 'F' op: 'Float32Input'}" \
"node { name: 'G' op: 'FusedBatchNormGradV3'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'U' value { type: DT_FLOAT } }" \
" attr { key: 'data_format' value { s: " DATA_FORMAT " } }" \
" attr { key: 'epsilon' value { f: 0.0001 } }" \
" attr { key: 'is_training' value { b: true } }" \
" input: ['A', 'B', 'C', 'D', 'E', 'F'] }" \
"node { name: 'H' op: 'Zeta' attr { key: 'T' value { type: " #T " } }" \
" input: ['A', 'G'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(" #INPUT ");C(Float32Input);D(Float32Input);" \
"E(Float32Input);F(Float32Input);G(FusedBatchNormGradV3);H(Zeta)" \
"|A->G;A->H;B->G:1;C->G:2;D->G:3;E->G:4;F->G:5;G->H:1"); \
}
#define DATA_FORMAT "'NCDHW'"
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedBatchNormGradV3_5D_Negative_1);
#undef DATA_FORMAT
#define DATA_FORMAT "'NDHWC'"
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedBatchNormGradV3_5D_Negative_2);
#undef DATA_FORMAT
#undef REGISTER_TEST
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph("node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: 'Input'}" \
"node { name: 'C' op: 'Input'}" \
"node { name: 'D' op: 'Input'}" \
"node { name: 'E' op: 'Input'}" \
"node { name: 'F' op: '_FusedBatchNormEx'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'U' value { type: DT_FLOAT } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'epsilon' value { f: 0.0001 } }" \
" attr { key: 'num_side_inputs' value { i: 0 } }" \
" attr { key: 'is_training' value { b: true } }" \
" attr { key: 'activation_mode' value { s: 'Relu' } }" \
" input: ['A', 'B', 'C', 'D', 'E'] }" \
"node { name: 'G' op: 'Zeta'" \
" attr { key: 'T' value { type: " #T " } }" \
" input: ['A', 'F'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(Input);C(Input);D(Input);" \
"DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);DMT/_3(Const);" \
"DMT/_4(Const);E(Input);" \
"F(_MklFusedBatchNormEx);G(Zeta)|A->F;A->G;" \
"A:control->DMT/_0:control;A:control->DMT/_1:control;" \
"A:control->DMT/_2:control;A:control->DMT/_3:control;" \
"A:control->DMT/_4:control;B->F:1;C->F:2;D->F:3;" \
"DMT/_0->F:5;DMT/_1->F:6;DMT/_2->F:7;DMT/_3->F:8;DMT/_4->F:9;" \
"E->F:4;F->G:1"); \
}
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedBatchNormEx_Positive);
#undef REGISTER_TEST
// Rewrite test for the _FusedBatchNormEx op with a side input: a side input
// blocks the rewrite, so the node stays _FusedBatchNormEx.
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph("node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: 'Input'}" \
"node { name: 'C' op: 'Input'}" \
"node { name: 'D' op: 'Input'}" \
"node { name: 'E' op: 'Input'}" \
"node { name: 'F' op: '" #INPUT "'}" \
"node { name: 'G' op: '_FusedBatchNormEx'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'U' value { type: DT_FLOAT } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'epsilon' value { f: 0.0001 } }" \
" attr { key: 'num_side_inputs' value { i: 1 } }" \
" attr { key: 'is_training' value { b: true } }" \
" attr { key: 'activation_mode' value { s: 'Relu' } }" \
" input: ['A', 'B', 'C', 'D', 'E', 'F'] }" \
"node { name: 'H' op: 'Zeta'" \
" attr { key: 'T' value { type: " #T " } }" \
" input: ['A', 'G'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(Input);C(Input);D(Input);E(Input);" \
"F(" #INPUT ");G(_FusedBatchNormEx);H(Zeta)|A->G;A->H;" \
"B->G:1;C->G:2;D->G:3;E->G:4;F->G:5;G->H:1"); \
}
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedBatchNormEx_Negative1);
#undef REGISTER_TEST
// Rewrite test for the _FusedBatchNormEx op with Identity activation: only
// Relu activation is rewritten, so the node stays _FusedBatchNormEx.
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph("node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: 'Input'}" \
"node { name: 'C' op: 'Input'}" \
"node { name: 'D' op: 'Input'}" \
"node { name: 'E' op: 'Input'}" \
"node { name: 'G' op: '_FusedBatchNormEx'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'U' value { type: DT_FLOAT } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'epsilon' value { f: 0.0001 } }" \
" attr { key: 'num_side_inputs' value { i: 1 } }" \
" attr { key: 'is_training' value { b: true } }" \
" attr { key: 'activation_mode' value { s: 'Identity' } }" \
" input: ['A', 'B', 'C', 'D', 'E'] }" \
"node { name: 'H' op: 'Zeta'" \
" attr { key: 'T' value { type: " #T " } }" \
" input: ['A', 'G'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(Input);C(Input);D(Input);E(Input);" \
"G(_FusedBatchNormEx);H(Zeta)|A->G;A->H;" \
"B->G:1;C->G:2;D->G:3;E->G:4;G->H:1"); \
}
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedBatchNormEx_Negative2);
#undef REGISTER_TEST
TEST_F(MklLayoutPassTest, NodeRewrite_QuantizedDepthwiseConv2D_Positive) {
InitGraph(
"node { name: 'A' op: 'QuantizedUnsignedInt8Input'}"
"node { name: 'B' op: 'QuantizedSignedInt8Input'}"
"node { name: 'C' op: 'Input'}"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: 'Input'}"
"node { name: 'F' op: 'Input'}"
"node { name: 'G' op: 'QuantizedSignedInt32Input'}"
"node { name: 'H' op: 'QuantizedDepthwiseConv2D'"
" attr { key: 'Tinput' value { type: DT_QUINT8 } }"
" attr { key: 'Tfilter' value { type: DT_QINT8 } }"
" attr { key: 'out_type' value { type: DT_QINT32 } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A', 'B', 'C', 'D', 'E', 'F'] }"
"node { name: 'I' op: 'Zeta' attr { key: 'T' value { type: DT_QINT32 } }"
" input: ['G', 'H'] }");
EXPECT_EQ(
DoMklLayoutOptimizationPass(),
"A(QuantizedUnsignedInt8Input);B(QuantizedSignedInt8Input);C(Input);"
"D(Input);DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);DMT/_3(Const);"
"DMT/_4(Const);DMT/_5(Const);E(Input);F(Input);"
"G(QuantizedSignedInt32Input);H(_MklQuantizedDepthwiseConv2D);I(Zeta)"
"|A->H;A:control->DMT/_0:control;A:control->DMT/_1:control;"
"A:control->DMT/_2:control;A:control->DMT/_3:control;"
"A:control->DMT/_4:control;A:control->DMT/_5:control;B->H:1;C->H:2;"
"D->H:3;DMT/_0->H:6;DMT/_1->H:7;DMT/_2->H:8;DMT/_3->H:9;DMT/_4->H:10;"
"DMT/_5->H:11;E->H:4;F->H:5;G->I;H->I:1");
}
/////////////////////////////////////////////////////////////////////
// Unit tests related to context-based node rewrite
/////////////////////////////////////////////////////////////////////
// If any of the inputs is an MKL op, then rewrite Slice to Mkl op.
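// Here Slice consumes the output of an _MklConv2D, so it is rewritten to
// _MklSlice: the Conv2D's Mkl metadata output (C:2) feeds F's metadata input
// F:3, while dummy Mkl tensors cover the plain Int32 begin/size inputs.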
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph( \
"node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: '" #INPUT "'}" \
"node { name: 'M' op: '_MklInput'}" \
"node { name: 'N' op: '_MklInput'}" \
"node { name: 'C' op: '_MklConv2D'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'use_cudnn_on_gpu' value { b: false } }" \
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" \
" attr { key: 'padding' value { s: 'SAME' } }" \
" input: ['A', 'B', 'M', 'N']}" \
"node { name: 'D' op: 'Int32Input'}" \
"node { name: 'E' op: 'Int32Input'}" \
"node { name: 'F' op: 'Slice'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'Index' value { type: DT_INT32 } }" \
" input: ['C', 'D', 'E'] }" \
"node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: " #T " } }" \
" input: ['A', 'C'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(" #INPUT ");C(_MklConv2D);D(Int32Input);" \
"DMT/_0(Const);DMT/_1(Const);" \
"E(Int32Input);F(_MklSlice);G(Zeta);M(_MklInput);N(_MklInput)|" \
"A->C;A->G;B->C:1;C->F;C->G:1;C:2->F:3;" \
"C:control->DMT/_0:control;C:control->DMT/" \
"_1:control;" \
"D->F:1;DMT/_0->F:4;DMT/_1->F:5;" \
"E->F:2;M->C:2;N->C:3"); \
}
REGISTER_TEST_ALL_TYPES(NodeRewrite_Ctxbased_Slice_Positive);
#undef REGISTER_TEST
// If none of the inputs is an MKL op, then Slice should not be rewritten.
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph( \
"node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: 'Int32Input'}" \
"node { name: 'C' op: 'Int32Input'}" \
"node { name: 'D' op: 'Slice'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'Index' value { type: DT_INT32 } }" \
" input: ['A', 'B', 'C'] }" \
"node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: " #T " } }" \
" input: ['A', 'D'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(Int32Input);C(Int32Input);" \
"D(Slice);E(Zeta)|A->D;A->E;B->D:1;C->D:2;D->E:1"); \
}
REGISTER_TEST_ALL_TYPES(NodeRewrite_Ctxbased_Slice_Negative);
#undef REGISTER_TEST
/////////////////////////////////////////////////////////////////////
// Unit tests related to rewriting node for workspace edges
/////////////////////////////////////////////////////////////////////
/* Test LRN->MaxPool->MaxPoolGrad->LRNGrad replacement by workspace nodes. */
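// Workspace edges connect a forward op to its matching backward op so that
// the gradient kernel can reuse data computed in the forward pass: _MklLRN
// feeds its workspace output to _MklLRNGrad (B:1->G:3), and _MklMaxPool does
// the same for _MklMaxPoolGrad (C:1->E:3).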
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph( \
"node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: 'LRN'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'alpha' value { f: 0.001 } }" \
" attr { key: 'beta' value { f: 0.75 } }" \
" attr { key: 'bias' value { f: 1.0 } }" \
" attr { key: 'depth_radius' value { i: 2 } }" \
" input: ['A'] }" \
"node { name: 'C' op: 'MaxPool'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }" \
" attr { key: 'padding' value { s: 'VALID' } }" \
" attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }" \
" input: ['B'] }" \
"node { name: 'D' op: '" #INPUT "'}" \
"node { name: 'E' op: 'MaxPoolGrad'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }" \
" attr { key: 'padding' value { s: 'VALID' } }" \
" attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }" \
" input: ['B', 'C', 'D'] }" \
"node { name: 'F' op: '" #INPUT "'}" \
"node { name: 'G' op: 'LRNGrad'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'alpha' value { f: 0.001 } }" \
" attr { key: 'beta' value { f: 0.75 } }" \
" attr { key: 'bias' value { f: 1.0 } }" \
" attr { key: 'depth_radius' value { i: 2 } }" \
" input: ['E', 'F', 'B'] }" \
"node { name: 'H' op: '" #INPUT "'}" \
"node { name: 'I' op: 'Zeta' attr { key: 'T' value { type: " #T " } }" \
" input: ['H', 'G'] }"); \
EXPECT_EQ( \
DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(_MklLRN);C(_MklMaxPool);D(" #INPUT ");DMT/_0(Const);" \
"DMT/_1(Const);DMT/_2(Const);E(_MklMaxPoolGrad);F(" #INPUT ");" \
"G(_MklLRNGrad);H(" #INPUT ");I(Zeta)|A->B;A:control->DMT/_0:control;" \
"B->C;B->E;B->G:2;B:1->G:3;B:2->C:1;B:2->E:4;B:2->G:6;B:3->G:7;" \
"B:control->DMT/_1:control;C->E:1;C:1->E:3;C:2->E:5;C:3->E:7;D->E:2;" \
"DMT/_0->B:1;DMT/_1->E:6;DMT/_2->G:5;E->G;E:1->G:4;" \
"E:control->DMT/_2:control;F->G:1;G->I:1;H->I"); \
}
REGISTER_TEST_FLOAT32(MaxPoolLRN_Positive);
// TODO(nhasabni): Enable bfloat16 test when we enable the operator.
#undef REGISTER_TEST
/* Test LRN->LRNGrad replacement by workspace nodes. */
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph( \
"node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: 'LRN'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'alpha' value { f: 0.001 } }" \
" attr { key: 'beta' value { f: 0.75 } }" \
" attr { key: 'bias' value { f: 1.0 } }" \
" attr { key: 'depth_radius' value { i: 2 } }" \
" input: ['A'] }" \
"node { name: 'C' op: '" #INPUT "'}" \
"node { name: 'D' op: '" #INPUT "'}" \
"node { name: 'E' op: 'LRNGrad'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'alpha' value { f: 0.001 } }" \
" attr { key: 'beta' value { f: 0.75 } }" \
" attr { key: 'bias' value { f: 1.0 } }" \
" attr { key: 'depth_radius' value { i: 2 } }" \
" input: ['C', 'D', 'B'] }" \
"node { name: 'F' op: 'Zeta' attr { key: 'T' value { type: " #T " } }" \
" input: ['C', 'E'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(_MklLRN);C(" #INPUT ");D(" #INPUT ");" \
"DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);E(_MklLRNGrad);F(Zeta)|" \
"A->B;A:control->DMT/_0:control;B->E:2;B:1->E:3;B:2->E:6;B:3->E:7;" \
"C->E;C->F;C:control->DMT/_1:control;C:control->DMT/_2:control;" \
"D->E:1;DMT/_0->B:1;DMT/_1->E:4;DMT/_2->E:5;E->F:1"); \
}
REGISTER_TEST_FLOAT32(LRN_Positive);
#undef REGISTER_TEST
/* Test LRN->LRNGrad replacement when only LRN is present. */
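// LRN alone is still rewritten to _MklLRN; its extra workspace output simply
// has no consumer.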
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph( \
"node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: 'LRN'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'alpha' value { f: 0.001 } }" \
" attr { key: 'beta' value { f: 0.75 } }" \
" attr { key: 'bias' value { f: 1.0 } }" \
" attr { key: 'depth_radius' value { i: 2 } }" \
" input: ['A'] }" \
"node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: " #T " } }" \
" input: ['A', 'B'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(_MklLRN);C(Zeta);DMT/_0(Const)|" \
"A->B;A->C;A:control->DMT/_0:control;B->C:1;DMT/_0->B:1"); \
}
REGISTER_TEST_FLOAT32(LRN_Negative1);
#undef REGISTER_TEST
/* Test LRN->LRNGrad replacement when only LRNGrad is present. */
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph( \
"node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: '" #INPUT "'}" \
"node { name: 'C' op: '" #INPUT "'}" \
"node { name: 'D' op: 'LRNGrad'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'alpha' value { f: 0.001 } }" \
" attr { key: 'beta' value { f: 0.75 } }" \
" attr { key: 'bias' value { f: 1.0 } }" \
" attr { key: 'depth_radius' value { i: 2 } }" \
" input: ['A', 'B', 'C'] }" \
"node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: " #T " } }" \
" input: ['A', 'D'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(" #INPUT ");C(" #INPUT ");D(LRNGrad);" \
"E(Zeta)|A->D;A->E;B->D:1;C->D:2;D->E:1"); \
}
REGISTER_TEST_ALL_TYPES(LRN_Negative2);
#undef REGISTER_TEST
/* Test the LRN->LRNGrad negative case, where a single LRN feeds
   two LRNGrad nodes at different slots. */
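// Only E, which consumes the LRN output at input slot 2 (LRNGrad's
// output_image slot), is rewritten with workspace edges; F consumes B at
// slot 1 and is expected to stay a plain LRNGrad.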
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph( \
"node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: 'LRN'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'alpha' value { f: 0.001 } }" \
" attr { key: 'beta' value { f: 0.75 } }" \
" attr { key: 'bias' value { f: 1.0 } }" \
" attr { key: 'depth_radius' value { i: 2 } }" \
" input: ['A'] }" \
"node { name: 'C' op: '" #INPUT "'}" \
"node { name: 'D' op: '" #INPUT "'}" \
"node { name: 'E' op: 'LRNGrad'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'alpha' value { f: 0.001 } }" \
" attr { key: 'beta' value { f: 0.75 } }" \
" attr { key: 'bias' value { f: 1.0 } }" \
" attr { key: 'depth_radius' value { i: 2 } }" \
" input: ['C', 'D', 'B'] }" \
"node { name: 'F' op: 'LRNGrad'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'alpha' value { f: 0.001 } }" \
" attr { key: 'beta' value { f: 0.75 } }" \
" attr { key: 'bias' value { f: 1.0 } }" \
" attr { key: 'depth_radius' value { i: 2 } }" \
" input: ['C', 'B', 'D'] }" \
"node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: " #T " } }" \
" input: ['E', 'F'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(_MklLRN);C(" #INPUT ");D(" #INPUT ");" \
"DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);E(_MklLRNGrad);" \
"F(LRNGrad);G(Zeta)|A->B;A:control->DMT/_0:control;B->E:2;B->F:1;" \
"B:1->E:3;B:2->E:6;B:3->E:7;C->E;C->F;C:control->DMT/_1:control;" \
"C:control->DMT/_2:control;D->E:1;D->F:2;DMT/_0->B:1;" \
"DMT/_1->E:4;DMT/_2->E:5;E->G;F->G:1"); \
}
REGISTER_TEST_FLOAT32(LRN_Negative3);
#undef REGISTER_TEST
/* Test MaxPool->MaxPoolGrad replacement by workspace+rewrite nodes. */
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph( \
"node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: 'MaxPool'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }" \
" attr { key: 'padding' value { s: 'VALID' } }" \
" attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }" \
" input: ['A'] }" \
"node { name: 'C' op: '" #INPUT "'}" \
"node { name: 'D' op: '" #INPUT "'}" \
"node { name: 'E' op: 'MaxPoolGrad'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }" \
" attr { key: 'padding' value { s: 'VALID' } }" \
" attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }" \
" input: ['C', 'B', 'D'] }" \
"node { name: 'F' op: 'Zeta' attr { key: 'T' value { type: " #T " } }" \
" input: ['C', 'E'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(_MklMaxPool);C(" #INPUT ");D(" #INPUT ");" \
"DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);E(_MklMaxPoolGrad);" \
"F(Zeta)|A->B;A:control->DMT/_0:control;B->E:1;B:1->E:3;B:2->E:5;" \
"B:3->E:7;C->E;C->F;C:control->DMT/_1:control;" \
"C:control->DMT/_2:control;D->E:2;DMT/_0->B:1;DMT/_1->E:4;" \
"DMT/_2->E:6;E->F:1"); \
}
REGISTER_TEST_ALL_TYPES(NodeWorkspace_MaxPool_Positive);
#undef REGISTER_TEST
// Test MaxPool->MaxPoolGrad replacement when only MaxPool is present.
// In this case, we rewrite the MaxPool node, but workspace edges will not
// be present.
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph( \
"node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: 'MaxPool'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }" \
" attr { key: 'padding' value { s: 'VALID' } }" \
" attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }" \
" input: ['A'] }" \
"node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: " #T " } }" \
" input: ['A', 'B'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(_MklMaxPool);C(Zeta);DMT/_0(Const)|" \
"A->B;A->C;A:control->DMT/_0:control;B->C:1;DMT/_0->B:1"); \
}
REGISTER_TEST_ALL_TYPES(NodeWorkspace_MaxPool_Negative1);
#undef REGISTER_TEST
// Test MaxPoolGrad handling when only MaxPoolGrad is present. In this case,
// MaxPoolGrad is not rewritten, since the workspace tensor from a matching
// MaxPool is unavailable.
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph( \
"node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: '" #INPUT "'}" \
"node { name: 'C' op: '" #INPUT "'}" \
"node { name: 'D' op: 'MaxPoolGrad'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }" \
" attr { key: 'padding' value { s: 'VALID' } }" \
" attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }" \
" input: ['A', 'B', 'C'] }" \
"node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: " #T " } }" \
" input: ['A', 'D'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(" #INPUT ");C(" #INPUT ");D(MaxPoolGrad);" \
"E(Zeta)|A->D;A->E;B->D:1;C->D:2;D->E:1"); \
}
REGISTER_TEST_ALL_TYPES(NodeWorkspace_MaxPool_Negative2);
#undef REGISTER_TEST
// Test MaxPool handling for batch-wise pooling (NCHW).
// No rewrite should take place in such a case.
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph( \
"node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: 'MaxPool'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'ksize' value { list: {i: 2, i:1, i:1, i:1} } }" \
" attr { key: 'padding' value { s: 'VALID' } }" \
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" \
" input: ['A'] }" \
"node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: " #T " } }" \
" input: ['A', 'B'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(MaxPool);C(Zeta)|A->B;A->C;B->C:1"); \
}
REGISTER_TEST_ALL_TYPES(NodeWorkspace_MaxPool_Negative3);
#undef REGISTER_TEST
// Test MaxPool handling for batch-wise pooling (NCHW).
// No rewrite should take place in such a case.
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph( \
"node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: 'MaxPool'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'ksize' value { list: {i: 1, i:1, i:1, i:1} } }" \
" attr { key: 'padding' value { s: 'VALID' } }" \
" attr { key: 'strides' value { list: {i: 2, i:1, i:1, i:1} } }" \
" input: ['A'] }" \
"node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: " #T " } }" \
" input: ['A', 'B'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(MaxPool);C(Zeta)|A->B;A->C;B->C:1"); \
}
REGISTER_TEST_ALL_TYPES(NodeWorkspace_MaxPool_Negative4);
#undef REGISTER_TEST
// Test MaxPool handling for depth-wise pooling (NCHW).
// No rewrite should take place in such a case.
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph( \
"node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: 'MaxPool'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'ksize' value { list: {i: 1, i:2, i:1, i:1} } }" \
" attr { key: 'padding' value { s: 'VALID' } }" \
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" \
" input: ['A'] }" \
"node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: " #T " } }" \
" input: ['A', 'B'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(MaxPool);C(Zeta)|A->B;A->C;B->C:1"); \
}
REGISTER_TEST_ALL_TYPES(NodeWorkspace_MaxPool_Negative5);
#undef REGISTER_TEST
// Test MaxPool handling for depth-wise pooling (NCHW).
// No rewrite should take place in such a case.
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph( \
"node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: 'MaxPool'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'data_format' value { s: 'NCHW' } }" \
" attr { key: 'ksize' value { list: {i: 1, i:1, i:1, i:1} } }" \
" attr { key: 'padding' value { s: 'VALID' } }" \
" attr { key: 'strides' value { list: {i: 1, i:2, i:1, i:1} } }" \
" input: ['A'] }" \
"node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: " #T " } }" \
" input: ['A', 'B'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(MaxPool);C(Zeta)|A->B;A->C;B->C:1"); \
}
REGISTER_TEST_ALL_TYPES(NodeWorkspace_MaxPool_Negative6);
#undef REGISTER_TEST
// Test MaxPool handling for batch-wise pooling (NHWC).
// No rewrite should take place in such a case.
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph( \
"node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: 'MaxPool'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'data_format' value { s: 'NHWC' } }" \
" attr { key: 'ksize' value { list: {i: 2, i:1, i:1, i:1} } }" \
" attr { key: 'padding' value { s: 'VALID' } }" \
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" \
" input: ['A'] }" \
"node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: " #T " } }" \
" input: ['A', 'B'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(MaxPool);C(Zeta)|A->B;A->C;B->C:1"); \
}
REGISTER_TEST_ALL_TYPES(NodeWorkspace_MaxPool_Negative7);
#undef REGISTER_TEST
// Test MaxPool handling for batch-wise pooling (NHWC).
// No rewrite should take place in such a case.
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph( \
"node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: 'MaxPool'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'data_format' value { s: 'NHWC' } }" \
" attr { key: 'ksize' value { list: {i: 1, i:1, i:1, i:1} } }" \
" attr { key: 'padding' value { s: 'VALID' } }" \
" attr { key: 'strides' value { list: {i: 2, i:1, i:1, i:1} } }" \
" input: ['A'] }" \
"node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: " #T " } }" \
" input: ['A', 'B'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(MaxPool);C(Zeta)|A->B;A->C;B->C:1"); \
}
REGISTER_TEST_ALL_TYPES(NodeWorkspace_MaxPool_Negative8);
#undef REGISTER_TEST
// Test MaxPool handling for depth-wise pooling (NHWC).
// No rewrite should take place in such a case.
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph( \
"node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: 'MaxPool'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'data_format' value { s: 'NHWC' } }" \
" attr { key: 'ksize' value { list: {i: 1, i:1, i:1, i:2} } }" \
" attr { key: 'padding' value { s: 'VALID' } }" \
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" \
" input: ['A'] }" \
"node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: " #T " } }" \
" input: ['A', 'B'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(MaxPool);C(Zeta)|A->B;A->C;B->C:1"); \
}
REGISTER_TEST_ALL_TYPES(NodeWorkspace_MaxPool_Negative9);
#undef REGISTER_TEST
// Test MaxPool handling for depth-wise pooling (NHWC).
// No rewrite should take place in such a case.
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph( \
"node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: 'MaxPool'" \
" attr { key: 'T' value { type: " #T " } }" \
" attr { key: 'data_format' value { s: 'NHWC' } }" \
" attr { key: 'ksize' value { list: {i: 1, i:1, i:1, i:1} } }" \
" attr { key: 'padding' value { s: 'VALID' } }" \
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:2} } }" \
" input: ['A'] }" \
"node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: " #T " } }" \
" input: ['A', 'B'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(MaxPool);C(Zeta)|A->B;A->C;B->C:1"); \
}
REGISTER_TEST_ALL_TYPES(NodeWorkspace_MaxPool_Negative10);
#undef REGISTER_TEST
// clang-format on
/////////////////////////////////////////////////////////////////////
// Single Conv2D op on GPU device.
// No rewrite should happen.
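// The _DeviceTest cases below assign every node to kGPUDevice; the Mkl layout
// pass only rewrites nodes placed on CPU, so each expected graph is identical
// to the input graph.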
TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_DeviceTest) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'C' op: 'Conv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A', 'B']}"
"node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['B', 'C'] }",
kGPUDevice);
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Conv2D);D(Zeta)|A->C;B->C:1;B->D;C->D:1");
}
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_DeviceTest) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'C' op: 'Input'}"
"node { name: 'M' op: '_MklInput'}"
"node { name: 'N' op: '_MklInput'}"
"node { name: 'O' op: '_MklInput'}"
"node { name: 'D' op: '_MklConv2DWithBias'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A', 'B', 'C', 'M', 'N', 'O']}"
"node { name: 'E' op: 'Zeta'"
" attr {key: 'T' value { type: DT_FLOAT } }"
" input: ['D', 'A']}"
"node { name: 'F' op: 'BiasAddGrad'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" input: ['E'] }",
kGPUDevice);
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(_MklConv2DWithBias);"
"E(Zeta);F(BiasAddGrad);M(_MklInput);N(_MklInput);"
"O(_MklInput)|A->D;A->E:1;B->D:1;C->D:2;D->E;E->F;"
"M->D:3;N->D:4;O->D:5");
}
TEST_F(MklLayoutPassTest, NodeRewrite_Conv2DGradFilter_DeviceTest) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Int32Input'}"
"node { name: 'C' op: 'Input'}"
"node { name: 'D' op: 'Conv2DBackpropFilter'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A', 'B', 'C']}"
"node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'D'] }",
kGPUDevice);
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Int32Input);C(Input);D(Conv2DBackpropFilter);E(Zeta)|"
"A->D;A->E;B->D:1;C->D:2;D->E:1");
}
TEST_F(MklLayoutPassTest,
NodeRewrite_DepthwiseConv2dNativeGradFilter_DeviceTest) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Int32Input'}"
"node { name: 'C' op: 'Input'}"
"node { name: 'D' op: 'DepthwiseConv2dNativeBackpropFilter'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A', 'B', 'C']}"
"node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'D'] }",
kGPUDevice);
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Int32Input);C(Input);D("
"DepthwiseConv2dNativeBackpropFilter);E(Zeta)|"
"A->D;A->E;B->D:1;C->D:2;D->E:1");
}
TEST_F(MklLayoutPassTest, NodeRewrite_Relu_DeviceTest) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Relu'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A'] }"
"node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }",
kGPUDevice);
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Relu);C(Zeta)|A->B;A->C;B->C:1");
}
TEST_F(MklLayoutPassTest, NodeRewrite_ReluGrad_DeviceTest) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'C' op: 'ReluGrad'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }"
"node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'C'] }",
kGPUDevice);
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(ReluGrad);D(Zeta)|A->C;A->D;B->C:1;C->D:1");
}
TEST_F(MklLayoutPassTest, NodeRewrite_Relu6_DeviceTest) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Relu6'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A'] }"
"node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }",
kGPUDevice);
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Relu6);C(Zeta)|A->B;A->C;B->C:1");
}
TEST_F(MklLayoutPassTest, NodeRewrite_Relu6Grad_DeviceTest) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'C' op: 'Relu6Grad'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }"
"node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'C'] }",
kGPUDevice);
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Relu6Grad);D(Zeta)|A->C;A->D;B->C:1;C->D:1");
}
TEST_F(MklLayoutPassTest, NodeRewrite_MaxPool_DeviceTest) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'MaxPool'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NHWC' } }"
" attr { key: 'ksize' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'VALID' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A'] }"
"node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }",
kGPUDevice);
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1");
}
TEST_F(MklLayoutPassTest, NodeRewrite_AvgPool_DeviceTest) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'AvgPool'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NHWC' } }"
" attr { key: 'ksize' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'VALID' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A'] }"
"node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }",
kGPUDevice);
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(AvgPool);C(Zeta)|A->B;A->C;B->C:1");
}
// Concat Op test: Concat with no Mkl layer feeding it
TEST_F(MklLayoutPassTest, NodeRewrite_Concat_DeviceTest) {
InitGraph(
"node { name: 'A' op: 'Const' "
" attr { key: 'dtype' value { type: DT_INT32 } }"
" attr { key: 'value' value { "
" tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
" int_val: 0 } } } }"
"node { name: 'B' op: 'InputList'"
" attr { key: 'N' value { i: 2 } }}"
"node { name: 'C' op: 'Input'}"
"node { name: 'D' op: 'Concat'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'N' value { i: 2 } }"
" input: ['A', 'B:0', 'B:1']}"
"node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['C', 'D'] }",
kGPUDevice);
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Const);B(InputList);C(Input);D(Concat);E(Zeta)|A->D;"
"B->D:1;B:1->D:2;C->E;D->E:1");
}
TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_DeviceTest) {
InitGraph(
"node { name: 'A' op: 'Const' "
" attr { key: 'dtype' value { type: DT_INT32 } }"
" attr { key: 'value' value { "
" tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
" int_val: 0 } } } }"
"node { name: 'B' op: 'InputList'"
" attr { key: 'N' value { i: 2 } }}"
"node { name: 'C' op: 'Input'}"
"node { name: 'D' op: 'ConcatV2'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'Tidx' value { type: DT_INT32 } }"
" attr { key: 'N' value { i: 2 } }"
" input: ['B:0', 'B:1', 'A']}"
"node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['C', 'D'] }",
kGPUDevice);
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Const);B(InputList);C(Input);D(ConcatV2);E(Zeta)|"
"A->D:2;B->D;B:1->D:1;C->E;D->E:1");
}
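// ConcatV2 where the axis tensor (Tidx) is DT_INT64: expect no rewrite.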
TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_IndexTest) {
InitGraph(
"node { name: 'A' op: 'Const' "
" attr { key: 'dtype' value { type: DT_INT64 } }"
" attr { key: 'value' value { "
" tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
" int_val: 0 } } } }"
"node { name: 'B' op: 'InputList'"
" attr { key: 'N' value { i: 2 } }}"
"node { name: 'C' op: 'Input'}"
"node { name: 'D' op: 'ConcatV2'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'Tidx' value { type: DT_INT64 } }"
" attr { key: 'N' value { i: 2 } }"
" input: ['B:0', 'B:1', 'A']}"
"node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['C', 'D'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Const);B(InputList);C(Input);D(ConcatV2);E(Zeta)|"
"A->D:2;B->D;B:1->D:1;C->E;D->E:1");
}
TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNorm_DeviceTest) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'C' op: 'Input'}"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: 'Input'}"
"node { name: 'F' op: 'FusedBatchNorm'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'epsilon' value { f: 0.0001 } }"
" attr { key: 'is_training' value { b: true } }"
" input: ['A', 'B', 'C', 'D', 'E'] }"
"node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'F'] }",
kGPUDevice);
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(Input);E(Input);"
"F(FusedBatchNorm);G(Zeta)|A->F;A->G;B->F:1;C->F:2;D->F:3;"
"E->F:4;F->G:1");
}
TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNormV2_DeviceTest) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'C' op: 'Input'}"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: 'Input'}"
"node { name: 'F' op: 'FusedBatchNorm'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'U' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'epsilon' value { f: 0.0001 } }"
" attr { key: 'is_training' value { b: true } }"
" input: ['A', 'B', 'C', 'D', 'E'] }"
"node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'F'] }",
kGPUDevice);
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(Input);E(Input);"
"F(FusedBatchNorm);G(Zeta)|A->F;A->G;B->F:1;C->F:2;D->F:3;"
"E->F:4;F->G:1");
}
TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNormV3_DeviceTest) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'C' op: 'Input'}"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: 'Input'}"
"node { name: 'F' op: 'FusedBatchNorm'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'U' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'epsilon' value { f: 0.0001 } }"
" attr { key: 'is_training' value { b: true } }"
" input: ['A', 'B', 'C', 'D', 'E'] }"
"node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'F'] }",
kGPUDevice);
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(Input);E(Input);"
"F(FusedBatchNorm);G(Zeta)|A->F;A->G;B->F:1;C->F:2;D->F:3;"
"E->F:4;F->G:1");
}
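// Conv2D + BiasAdd assigned to a GPU device: the merge into
// _MklConv2DWithBias should not happen.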
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_DeviceTest) {
CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'M' op: '_MklInput'}"
"node { name: 'N' op: '_MklInput'}"
"node { name: 'C' op: '_MklConv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" input: ['A', 'B', 'M', 'N']}"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: 'BiasAdd'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" input: ['C', 'D'] }"
"node { name: 'Y' op: 'Input'}"
"node { name: 'Z' op: 'Zeta'"
" attr {key: 'T' value { type: DT_FLOAT } }"
" input: ['E', 'Y']}",
kGPUDevice);
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(_MklConv2D);D(Input);E(BiasAdd);"
"M(_MklInput);N(_MklInput);Y(Input);Z(Zeta)|A->C;"
"B->C:1;C->E;D->E:1;E->Z;M->C:2;N->C:3;Y->Z:1");
}
TEST_F(MklLayoutPassTest, NodeRewrite_Slice_DeviceTest) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Int32Input'}"
"node { name: 'C' op: 'Int32Input'}"
"node { name: 'D' op: 'Slice'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'Index' value { type: DT_INT32 } }"
" input: ['A', 'B', 'C'] }"
"node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'D'] }",
kGPUDevice);
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Int32Input);C(Int32Input);D(Slice);E(Zeta)|A->D;A->E;"
"B->D:1;C->D:2;D->E:1");
}
// The following positive and negative tests check the rewrite of Add and
// AddV2 to their MKL versions. An operator is rewritten only if at least one
// of its inputs comes from another MKL operator.
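// In PositiveRewriteAdd below, Relu becomes _MklRelu, so Add sees an MKL
// input and is rewritten to _MklAdd; dummy MKL tensors (the DMT/_* Const
// nodes) are inserted for the plain input B.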
TEST_F(MklLayoutPassTest, PositiveRewriteAdd) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'M' op: 'Relu'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A']}"
"node { name: 'N' op: 'Add'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['M', 'B']}");
EXPECT_EQ(
DoMklLayoutOptimizationPass(),
"A(Input);B(Input);DMT/_0(Const);DMT/_1(Const);M(_MklRelu);N(_MklAdd)"
"|A->M;A:control->DMT/_0:control;B->N:1;DMT/_0->M:1;DMT/_1->N:3;M->N;"
"M:1->N:2;M:control->DMT/_1:control");
}
TEST_F(MklLayoutPassTest, NegativeRewriteAdd) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'N' op: 'Add'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B']}");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);N(Add)|A->N;B->N:1");
}
TEST_F(MklLayoutPassTest, PositiveRewriteAddV2) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'M' op: 'Relu'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A']}"
"node { name: 'N' op: 'AddV2'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['M', 'B']}");
EXPECT_EQ(
DoMklLayoutOptimizationPass(),
"A(Input);B(Input);DMT/_0(Const);DMT/_1(Const);M(_MklRelu);N(_MklAddV2)"
"|A->M;A:control->DMT/_0:control;B->N:1;DMT/_0->M:1;DMT/_1->N:3;M->N;"
"M:1->N:2;M:control->DMT/_1:control");
}
TEST_F(MklLayoutPassTest, NegativeRewriteAddV2) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'N' op: 'AddV2'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B']}");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);N(AddV2)|A->N;B->N:1");
}
/////////////////////////////////////////////////////////////////////
// Post-rewrite fixup pass test
/////////////////////////////////////////////////////////////////////
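// _MklAdd's MKL metadata input for C initially points at the dummy constant
// D; the fixup pass rewires it to C's actual MKL output (C:2->E:2).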
TEST_F(MklLayoutPassTest, PostRewriteFixUpPass) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'M' op: '_MklInput'}"
"node { name: 'N' op: '_MklInput'}"
"node { name: 'C' op: '_MklConv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A', 'B', 'M', 'N']}"
"node { name: 'D' op: 'Const' "
" attr { key: 'dtype' value { type: DT_UINT8 } }"
" attr { key: 'value' value { "
" tensor { dtype: DT_UINT8 tensor_shape { dim { size: 1 } } "
" int_val: 0 } } } }"
"node { name: 'E' op: '_MklAdd'"
" attr {key: 'T' value { type: DT_FLOAT } }"
" input: ['C', 'A', 'D', 'D']}");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(_MklConv2D);D(Const);E(_MklAdd);"
"M(_MklInput);N(_MklInput)|A->C;A->E:1;B->C:1;C->E;C:2->E:2;"
"D->E:3;M->C:2;N->C:3");
}
/////////////////////////////////////////////////////////////////////
// Unit tests related to filter caching.
//
// These tests check that the attribute `is_filter_const` is set to true
// when the filter is a constant and to false otherwise, for operators such
// as Conv2D, Conv2DWithBias, and Conv3D. A constant filter lets the MKL
// kernel cache the reordered filter across invocations.
/////////////////////////////////////////////////////////////////////
// Conv2D op where filter is a constant.
TEST_F(MklLayoutPassTest, Conv2D_FilterCaching_Positive) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Const' " // Filter
" attr { key: 'dtype' value { type: DT_FLOAT } }"
" attr { key: 'value' value { "
" tensor { dtype: DT_FLOAT tensor_shape { dim { size: 1 } } "
" int_val: 0 } } } }"
"node { name: 'C' op: 'Conv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A', 'B']}"
"node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['B', 'C'] }");
EXPECT_TRUE(DoMklLayoutOptimizationPassGetAttrVal<bool>("is_filter_const",
"_MklConv2D"));
}
// Conv2D op where filter is NOT a constant.
TEST_F(MklLayoutPassTest, Conv2D_FilterCaching_Negative) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}" // Filter
"node { name: 'C' op: 'Conv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A', 'B']}"
"node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['B', 'C'] }");
EXPECT_FALSE(DoMklLayoutOptimizationPassGetAttrVal<bool>("is_filter_const",
"_MklConv2D"));
}
// Conv2D + BiasAdd fusion where filter is a constant.
TEST_F(MklLayoutPassTest, Conv2DWithBias_FilterCaching_Positive) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Const'" // Filter
" attr { key: 'dtype' value { type: DT_FLOAT } }"
" attr { key: 'value' value { "
" tensor { dtype: DT_FLOAT tensor_shape { dim { size: 1 } } "
" int_val: 0 } } } }"
"node { name: 'C' op: 'Conv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A', 'B']}"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: 'BiasAdd'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" input: ['C', 'D'] }"
"node { name: 'Y' op: 'Input'}"
"node { name: 'Z' op: 'Zeta'"
" attr {key: 'T' value { type: DT_FLOAT } }"
" input: ['E', 'Y']}");
EXPECT_TRUE(DoMklLayoutOptimizationPassGetAttrVal<bool>(
"is_filter_const", "_MklConv2DWithBias"));
}
// Conv2D + BiasAdd fusion where filter is NOT a constant.
TEST_F(MklLayoutPassTest, Conv2DWithBias_FilterCaching_Negative) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}" // Filter
"node { name: 'C' op: 'Conv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A', 'B']}"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: 'BiasAdd'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" input: ['C', 'D'] }"
"node { name: 'Y' op: 'Input'}"
"node { name: 'Z' op: 'Zeta'"
" attr {key: 'T' value { type: DT_FLOAT } }"
" input: ['E', 'Y']}");
EXPECT_FALSE(DoMklLayoutOptimizationPassGetAttrVal<bool>(
"is_filter_const", "_MklConv2DWithBias"));
}
// Conv3D op where filter is a constant.
TEST_F(MklLayoutPassTest, Conv3D_FilterCaching_Positive) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Const' " // Filter
" attr { key: 'dtype' value { type: DT_FLOAT } }"
" attr { key: 'value' value { "
" tensor { dtype: DT_FLOAT tensor_shape { dim { size: 1 } } "
" int_val: 0 } } } }"
"node { name: 'C' op: 'Conv3D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCDHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1, "
"i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1, "
"i:1} } }"
" input: ['A', 'B']}"
"node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['B', 'C'] }");
EXPECT_TRUE(DoMklLayoutOptimizationPassGetAttrVal<bool>("is_filter_const",
"_MklConv3D"));
}
// Conv3D op where filter is NOT a constant.
TEST_F(MklLayoutPassTest, Conv3D_FilterCaching_Negative) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}" // Filter
"node { name: 'C' op: 'Conv3D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCDHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1, "
"i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1, "
"i:1} } }"
" input: ['A', 'B']}"
"node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['B', 'C'] }");
EXPECT_FALSE(DoMklLayoutOptimizationPassGetAttrVal<bool>("is_filter_const",
"_MklConv3D"));
}
// Pad + Conv2D fusion where filter is a constant.
TEST_F(MklLayoutPassTest, PadWithConv2D_FilterCaching_Positive) {
DCHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Int32Input'}"
"node { name: 'C' op: 'Pad'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'Tpaddings' value { type: DT_INT32 } }"
" input: ['A', 'B']}"
"node { name: 'D' op: 'Const'" // Filter
" attr { key: 'dtype' value { type: DT_FLOAT } }"
" attr { key: 'value' value { "
" tensor { dtype: DT_FLOAT tensor_shape { dim { size: 1 } } "
" int_val: 0 } } } }"
"node { name: 'E' op: 'Conv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NHWC' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'VALID' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['C', 'D'] }"
"node { name: 'Y' op: 'Input'}"
"node { name: 'Z' op: 'Zeta'"
" attr {key: 'T' value { type: DT_FLOAT } }"
" input: ['E', 'Y']}");
EXPECT_TRUE(DoMklLayoutOptimizationPassGetAttrVal<bool>("is_filter_const",
"_MklPadWithConv2D"));
}
// Pad + Conv2D fusion where filter is NOT a constant.
TEST_F(MklLayoutPassTest, PadWithConv2D_FilterCaching_Negative) {
DCHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Int32Input'}"
"node { name: 'C' op: 'Pad'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'Tpaddings' value { type: DT_INT32 } }"
" input: ['A', 'B']}"
"node { name: 'D' op: 'Input'}" // Filter
"node { name: 'E' op: 'Conv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NHWC' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'VALID' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['C', 'D'] }"
"node { name: 'Y' op: 'Input'}"
"node { name: 'Z' op: 'Zeta'"
" attr {key: 'T' value { type: DT_FLOAT } }"
" input: ['E', 'Y']}");
EXPECT_FALSE(DoMklLayoutOptimizationPassGetAttrVal<bool>(
"is_filter_const", "_MklPadWithConv2D"));
}
// _FusedConv2D + BiasAdd fusion where filter is a constant.
TEST_F(MklLayoutPassTest, FusedConv2DWithBias_FilterCaching_Positive) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Const'" // Filter
" attr { key: 'dtype' value { type: DT_FLOAT } }"
" attr { key: 'value' value { "
" tensor { dtype: DT_FLOAT tensor_shape { dim { size: 1 } } "
" int_val: 0 } } } }"
"node { name: 'C' op: 'Input'}"
"node { name: 'D' op: '_FusedConv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'num_args' value { i: 1 } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'fused_ops' value { list: {s: 'BiasAdd'} } }"
" attr { key: 'epsilon' value { f: 0.001 }}"
" attr { key: 'leakyrelu_alpha' value { f: 0.2 }}"
" input: ['A', 'B', 'C']}"
"node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['D', 'C'] }");
EXPECT_TRUE(DoMklLayoutOptimizationPassGetAttrVal<bool>("is_filter_const",
"_MklFusedConv2D"));
}
// _FusedDepthwiseConv2dNative + BiasAdd fusion where filter is a constant.
TEST_F(MklLayoutPassTest,
FusedDepthwiseConv2dNativeWithBias_FilterCaching_Positive) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Const'" // Filter
" attr { key: 'dtype' value { type: DT_FLOAT } }"
" attr { key: 'value' value { "
" tensor { dtype: DT_FLOAT tensor_shape { dim { size: 1 } } "
" int_val: 0 } } } }"
"node { name: 'C' op: 'Input'}"
"node { name: 'D' op: '_FusedDepthwiseConv2dNative'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'num_args' value { i: 1 } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'fused_ops' value { list: {s: 'BiasAdd'} } }"
" attr { key: 'epsilon' value { f: 0.001 }}"
" attr { key: 'leakyrelu_alpha' value { f: 0.2 }}"
" input: ['A', 'B', 'C']}"
"node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['D', 'C'] }");
EXPECT_TRUE(DoMklLayoutOptimizationPassGetAttrVal<bool>(
"is_filter_const", "_MklFusedDepthwiseConv2dNative"));
}
// _FusedConv2D + BiasAdd fusion where filter is NOT a constant.
TEST_F(MklLayoutPassTest, FusedConv2DWithBias_FilterCaching_Negative) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}" // Filter
"node { name: 'C' op: 'Input'}"
"node { name: 'D' op: '_FusedConv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'num_args' value { i: 1 } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'fused_ops' value { list: {s: 'BiasAdd'} } }"
" attr { key: 'epsilon' value { f: 0.001 }}"
" attr { key: 'leakyrelu_alpha' value { f: 0.2 }}"
" input: ['A', 'B', 'C']}"
"node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['D', 'C'] }");
EXPECT_FALSE(DoMklLayoutOptimizationPassGetAttrVal<bool>("is_filter_const",
"_MklFusedConv2D"));
}
// _FusedDepthwiseConv2dNative + BiasAdd fusion where filter is NOT a constant.
TEST_F(MklLayoutPassTest,
FusedDepthwiseConv2dNativeWithBias_FilterCaching_Negative) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}" // Filter
"node { name: 'C' op: 'Input'}"
"node { name: 'D' op: '_FusedDepthwiseConv2dNative'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'num_args' value { i: 1 } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'fused_ops' value { list: {s: 'BiasAdd'} } }"
" attr { key: 'epsilon' value { f: 0.001 }}"
" attr { key: 'leakyrelu_alpha' value { f: 0.2 }}"
" input: ['A', 'B', 'C']}"
"node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['D', 'C'] }");
EXPECT_FALSE(DoMklLayoutOptimizationPassGetAttrVal<bool>(
"is_filter_const", "_MklFusedDepthwiseConv2dNative"));
}
// Depthwise Conv2D op where filter is a constant.
TEST_F(MklLayoutPassTest, DepthwiseConv2dNative_FilterCaching_Positive) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Const'" // Filter
" attr { key: 'dtype' value { type: DT_FLOAT } }"
" attr { key: 'value' value { "
" tensor { dtype: DT_FLOAT tensor_shape { dim { size: 1 } } "
" int_val: 0 } } } }"
"node { name: 'C' op: 'DepthwiseConv2dNative'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A', 'B']}"
"node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['B', 'C'] }");
EXPECT_TRUE(DoMklLayoutOptimizationPassGetAttrVal<bool>(
"is_filter_const", "_MklDepthwiseConv2dNative"));
}
// Depthwise Conv2D op where filter is NOT a constant.
TEST_F(MklLayoutPassTest, DepthwiseConv2dNative_FilterCaching_Negative) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}" // Filter
"node { name: 'C' op: 'DepthwiseConv2dNative'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A', 'B']}"
"node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['B', 'C'] }");
EXPECT_FALSE(DoMklLayoutOptimizationPassGetAttrVal<bool>(
"is_filter_const", "_MklDepthwiseConv2dNative"));
}
// Fused QuantizedMatMulWithBias op rewrite test:
// QuantizedMatMulWithBias should be rewritten to _MklQuantizedMatMulWithBias.
TEST_F(MklLayoutPassTest, NodeRewrite_QuantizedMatMulWithBias_Positive) {
InitGraph(
"node { name: 'A' op: 'QUInt8Input' }"
"node { name: 'B' op: 'QInt8Input' }"
"node { name: 'C' op: 'QInt32Input' }"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: 'Input'}"
"node { name: 'F' op: 'Input'}"
"node { name: 'G' op: 'Input'}"
"node { name: 'H' op: 'QInt32Input'}"
"node { name: 'I' op: 'QuantizedMatMulWithBias'"
" attr { key: 'T1' value { type: DT_QUINT8 } }"
" attr { key: 'T2' value { type: DT_QINT8 } }"
" attr { key: 'Tbias' value { type: DT_QINT32 } }"
" attr { key: 'Toutput' value { type: DT_QINT32 } }"
" input: ['A', 'B', 'C', 'D', 'E', 'F', 'G']}"
"node { name: 'J' op: 'Zeta' attr { key: 'T' value { type: DT_QINT32 } }"
" input: ['I', 'H'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(QUInt8Input);B(QInt8Input);C(QInt32Input);D(Input);"
"DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);DMT/_3(Const);"
"DMT/_4(Const);DMT/_5(Const);DMT/_6(Const);E(Input);F(Input);"
"G(Input);H(QInt32Input);I(_MklQuantizedMatMulWithBias);"
"J(Zeta)|A->I;"
"A:control->DMT/_0:control;A:control->DMT/_1:control;"
"A:control->DMT/_2:control;A:control->DMT/_3:control;"
"A:control->DMT/_4:control;A:control->DMT/_5:control;"
"A:control->DMT/_6:control;B->I:1;C->I:2;D->I:3;DMT/_0->I:7;"
"DMT/_1->I:8;DMT/_2->I:9;DMT/_3->I:10;DMT/_4->I:11;DMT/_5->I:12;"
"DMT/_6->I:13;E->I:4;F->I:5;G->I:6;H->J:1;I->J");
}
// Rewrite test for QuantizedMatMulWithBias op with an unsupported input
// type; the rewrite should not happen.
TEST_F(MklLayoutPassTest, NodeRewrite_QuantizedMatMulWithBias_Negative) {
InitGraph(
"node { name: 'A' op: 'QUInt8Input' }"
"node { name: 'B' op: 'QUInt8Input' }"
"node { name: 'C' op: 'QInt32Input' }"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: 'Input'}"
"node { name: 'F' op: 'Input'}"
"node { name: 'G' op: 'Input'}"
"node { name: 'H' op: 'QInt32Input'}"
"node { name: 'I' op: 'QuantizedMatMulWithBias'"
" attr { key: 'T1' value { type: DT_QUINT8 } }"
" attr { key: 'T2' value { type: DT_QUINT8 } }"
" attr { key: 'Tbias' value { type: DT_QINT32 } }"
" attr { key: 'Toutput' value { type: DT_QINT32 } }"
" input: ['A', 'B', 'C', 'D', 'E', 'F', 'G']}"
"node { name: 'J' op: 'Zeta' attr { key: 'T' value { type: DT_QINT32 } }"
" input: ['I', 'H'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(QUInt8Input);B(QUInt8Input);C(QInt32Input);D(Input);E(Input);"
"F(Input);G(Input);H(QInt32Input);I(QuantizedMatMulWithBias);"
"J(Zeta)|A->I;B->I:1;C->I:2;D->I:3;"
"E->I:4;F->I:5;G->I:6;H->J:1;I->J");
}
// Fused QuantizedMatMulWithBiasAndRelu op rewrite test:
// QuantizedMatMulWithBiasAndRelu should be rewritten to
// _MklQuantizedMatMulWithBiasAndRelu.
TEST_F(MklLayoutPassTest, NodeRewrite_QuantizedMatMulWithBiasAndRelu_Positive) {
InitGraph(
"node { name: 'A' op: 'QUInt8Input' }"
"node { name: 'B' op: 'QInt8Input' }"
"node { name: 'C' op: 'Input' }"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: 'Input'}"
"node { name: 'F' op: 'Input'}"
"node { name: 'G' op: 'Input'}"
"node { name: 'H' op: 'QInt32Input'}"
"node { name: 'I' op: 'QuantizedMatMulWithBiasAndRelu'"
" attr { key: 'T1' value { type: DT_QUINT8 } }"
" attr { key: 'T2' value { type: DT_QINT8 } }"
" attr { key: 'Toutput' value { type: DT_QINT32 } }"
" input: ['A', 'B', 'C', 'D', 'E', 'F', 'G']}"
"node { name: 'J' op: 'Zeta' attr { key: 'T' value { type: DT_QINT32 } }"
" input: ['I', 'H'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(QUInt8Input);B(QInt8Input);C(Input);D(Input);"
"DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);DMT/_3(Const);"
"DMT/_4(Const);DMT/_5(Const);DMT/_6(Const);E(Input);F(Input);"
"G(Input);H(QInt32Input);I(_MklQuantizedMatMulWithBiasAndRelu);"
"J(Zeta)|A->I;"
"A:control->DMT/_0:control;A:control->DMT/_1:control;"
"A:control->DMT/_2:control;A:control->DMT/_3:control;"
"A:control->DMT/_4:control;A:control->DMT/_5:control;"
"A:control->DMT/_6:control;B->I:1;C->I:2;D->I:3;DMT/_0->I:7;"
"DMT/_1->I:8;DMT/_2->I:9;DMT/_3->I:10;DMT/_4->I:11;DMT/_5->I:12;"
"DMT/_6->I:13;E->I:4;F->I:5;G->I:6;H->J:1;I->J");
}
// Rewrite test for QuantizedMatMulWithBiasAndRelu op with an unsupported
// input type; the rewrite should not happen.
TEST_F(MklLayoutPassTest, NodeRewrite_QuantizedMatMulWithBiasAndRelu_Negative) {
InitGraph(
"node { name: 'A' op: 'QUInt8Input' }"
"node { name: 'B' op: 'QUInt8Input' }"
"node { name: 'C' op: 'Input' }"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: 'Input'}"
"node { name: 'F' op: 'Input'}"
"node { name: 'G' op: 'Input'}"
"node { name: 'H' op: 'QInt32Input'}"
"node { name: 'I' op: 'QuantizedMatMulWithBiasAndRelu'"
" attr { key: 'T1' value { type: DT_QUINT8 } }"
" attr { key: 'T2' value { type: DT_QUINT8 } }"
" attr { key: 'Toutput' value { type: DT_QINT32 } }"
" input: ['A', 'B', 'C', 'D', 'E', 'F', 'G']}"
"node { name: 'J' op: 'Zeta' attr { key: 'T' value { type: DT_QINT32 } }"
" input: ['I', 'H'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(QUInt8Input);B(QUInt8Input);C(Input);D(Input);"
"E(Input);F(Input);G(Input);H(QInt32Input);"
"I(QuantizedMatMulWithBiasAndRelu);J(Zeta)|A->I;"
"B->I:1;C->I:2;D->I:3;E->I:4;F->I:5;"
"G->I:6;H->J:1;I->J");
}
// Fused QuantizedMatMulWithBiasAndReluAndRequantize op rewrite test:
// QuantizedMatMulWithBiasAndReluAndRequantize should be rewritten to
// _MklQuantizedMatMulWithBiasAndReluAndRequantize.
TEST_F(MklLayoutPassTest,
NodeRewrite_QuantizedMatMulWithBiasAndReluAndRequantize_Positive) {
InitGraph(
"node { name: 'A' op: 'QUInt8Input' }"
"node { name: 'B' op: 'QInt8Input' }"
"node { name: 'C' op: 'QInt32Input' }"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: 'Input'}"
"node { name: 'F' op: 'Input'}"
"node { name: 'G' op: 'Input'}"
"node { name: 'H' op: 'Input'}"
"node { name: 'I' op: 'Input'}"
"node { name: 'J' op: 'QUInt8Input'}"
"node { name: 'K' op: 'QuantizedMatMulWithBiasAndReluAndRequantize'"
" attr { key: 'T1' value { type: DT_QUINT8 } }"
" attr { key: 'T2' value { type: DT_QINT8 } }"
" attr { key: 'Tbias' value { type: DT_QINT32 } }"
" attr { key: 'Toutput' value { type: DT_QUINT8 } }"
" input: ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I']}"
"node { name: 'L' op: 'Zeta' attr { key: 'T' value { type: DT_QUINT8 } }"
" input: ['K', 'J'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(QUInt8Input);B(QInt8Input);C(QInt32Input);"
"D(Input);DMT/_0(Const);"
"DMT/_1(Const);DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);"
"DMT/_5(Const);DMT/_6(Const);DMT/_7(Const);DMT/_8(Const);E(Input);"
"F(Input);G(Input);H(Input);I(Input);J(QUInt8Input);"
"K(_MklQuantizedMatMulWithBiasAndReluAndRequantize);L(Zeta)|A->K;"
"A:control->DMT/_0:control;A:control->DMT/_1:control;"
"A:control->DMT/_2:control;A:control->DMT/_3:control;"
"A:control->DMT/_4:control;A:control->DMT/_5:control;"
"A:control->DMT/_6:control;A:control->DMT/_7:control;"
"A:control->DMT/_8:control;B->K:1;C->K:2;D->K:3;DMT/_0->K:9;"
"DMT/_1->K:10;DMT/_2->K:11;DMT/_3->K:12;DMT/_4->K:13;DMT/_5->K:14;"
"DMT/_6->K:15;DMT/_7->K:16;DMT/_8->K:17;E->K:4;F->K:5;G->K:6;"
"H->K:7;I->K:8;J->L:1;K->L");
}
// Rewrite test for QuantizedMatMulWithBiasAndReluAndRequantize op with an
// unsupported input type; the rewrite should not happen.
TEST_F(MklLayoutPassTest,
NodeRewrite_QuantizedMatMulWithBiasAndReluAndRequantize_Negative) {
InitGraph(
"node { name: 'A' op: 'QUInt8Input' }"
"node { name: 'B' op: 'QUInt8Input' }"
"node { name: 'C' op: 'QInt32Input' }"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: 'Input'}"
"node { name: 'F' op: 'Input'}"
"node { name: 'G' op: 'Input'}"
"node { name: 'H' op: 'Input'}"
"node { name: 'I' op: 'Input'}"
"node { name: 'J' op: 'QUInt8Input'}"
"node { name: 'K' op: 'QuantizedMatMulWithBiasAndReluAndRequantize'"
" attr { key: 'T1' value { type: DT_QUINT8 } }"
" attr { key: 'T2' value { type: DT_QUINT8 } }"
" attr { key: 'Tbias' value { type: DT_QINT32 } }"
" attr { key: 'Toutput' value { type: DT_QUINT8 } }"
" input: ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I']}"
"node { name: 'L' op: 'Zeta' attr { key: 'T' value { type: DT_QUINT8 } }"
" input: ['K', 'J'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(QUInt8Input);B(QUInt8Input);C(QInt32Input);"
"D(Input);E(Input);F(Input);G(Input);H(Input);I(Input);"
"J(QUInt8Input);"
"K(QuantizedMatMulWithBiasAndReluAndRequantize);L(Zeta)|A->K;"
"B->K:1;C->K:2;D->K:3;E->K:4;F->K:5;G->K:6;"
"H->K:7;I->K:8;J->L:1;K->L");
}
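// MatMul, BatchMatMul, and BatchMatMulV2 are rewritten to their _Mkl*
// counterparts by a plain name change: no MKL metadata (DMT) inputs are
// added.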
TEST_F(MklLayoutPassTest, MatMul_Positive) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'C' op: 'MatMul'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B']}");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(_MklMatMul)|A->C;B->C:1");
}
TEST_F(MklLayoutPassTest, BatchMatMul_Positive) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'C' op: 'BatchMatMul'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B']}");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(_MklBatchMatMul)|A->C;B->C:1");
}
TEST_F(MklLayoutPassTest, BatchMatMulV2_Positive) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'C' op: 'BatchMatMulV2'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B']}");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(_MklBatchMatMulV2)|A->C;B->C:1");
}
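// A hedged sketch, not part of the original suite: assuming the MatMul
// rewrite is restricted to MKL-supported float types, an int32 MatMul
// should be left untouched.
TEST_F(MklLayoutPassTest, MatMul_Int32_Negative) {
  InitGraph(
      "node { name: 'A' op: 'Int32Input'}"
      "node { name: 'B' op: 'Int32Input'}"
      "node { name: 'C' op: 'MatMul'"
      " attr { key: 'T' value { type: DT_INT32 } }"
      " input: ['A', 'B']}");
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Int32Input);B(Int32Input);C(MatMul)|A->C;B->C:1");
}
// Benchmarks the rewrite pass itself: builds a graph of 'op_nodes' Zeta
// nodes wired randomly to 10 inputs, then reports time per graph node.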
static void BM_MklLayoutRewritePass(int iters, int op_nodes) {
testing::StopTiming();
string s;
for (int in = 0; in < 10; in++) {
s += strings::Printf("node { name: 'in%04d' op: 'Input'}", in);
}
random::PhiloxRandom philox(301, 17);
random::SimplePhilox rnd(&philox);
for (int op = 0; op < op_nodes; op++) {
s += strings::Printf(
"node { name: 'op%04d' op: 'Zeta' attr { key: 'T' value { "
"type: DT_FLOAT } } input: ['in%04d', 'in%04d' ] }",
op, rnd.Uniform(10), rnd.Uniform(10));
}
bool first = true;
while (iters > 0) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
InitGraph(s, graph.get());
int N = graph->num_node_ids();
if (first) {
testing::SetLabel(strings::StrCat("Per graph node. Nodes: ", N));
first = false;
}
{
testing::StartTiming();
std::unique_ptr<Graph> ug(graph.get());
RunMklLayoutRewritePass(&ug);
testing::StopTiming();
}
iters -= N; // Our benchmark units are individual graph nodes,
// not whole graphs
}
}
BENCHMARK(BM_MklLayoutRewritePass)->Arg(1000)->Arg(10000);
} // namespace
} // namespace tensorflow
#endif // INTEL_MKL && ENABLE_MKL