/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if defined(INTEL_MKL) && defined(ENABLE_MKL)
#include "tensorflow/core/graph/mkl_layout_pass.h"
#include <algorithm>
#include <vector>
#include "absl/strings/str_join.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/mkl_graph_util.h"
#include "tensorflow/core/graph/testlib.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
namespace tensorflow {
// NOTE: Unit tests in this file rely on a topologically sorted graph for
// printing. But since sibling nodes in a topologically sorted graph can be
// printed in different orders, tests may fail if the order in which sibling
// nodes are visited is changed.
namespace {
const char kCPUDevice[] = "/job:a/replica:0/task:0/device:CPU:0";
const char kGPUDevice[] = "/job:a/replica:0/task:0/device:GPU:0";
static void InitGraph(const string& s, Graph* graph,
const string& device = kCPUDevice) {
GraphDef graph_def;
auto parser = protobuf::TextFormat::Parser();
// parser.AllowRelaxedWhitespace(true);
CHECK(parser.MergeFromString(s, &graph_def)) << s;
GraphConstructorOptions opts;
TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, graph));
for (Node* node : graph->nodes()) {
node->set_assigned_device_name(device);
}
}
class MklLayoutPassTest : public ::testing::Test {
public:
MklLayoutPassTest() : graph_(OpRegistry::Global()) {}
// Returns the node in the graph with the given name, or dies if no such
// node exists.
Node* FindNode(const string& name) {
for (Node* node : graph_.nodes()) {
if (node->name() == name) return node;
}
LOG(FATAL) << "Node not found: " << name;
return nullptr;  // Unreachable; placates compilers that warn on missing return.
}
void InitGraph(const string& s, const string& device = kCPUDevice) {
::tensorflow::InitGraph(s, &graph_, device);
original_ = CanonicalGraphString(&graph_);
}
static bool IncludeNode(const Node* n) { return n->IsOp(); }
static string EdgeId(const Node* n, int index) {
if (index == 0) {
return n->name();
} else if (index == Graph::kControlSlot) {
return strings::StrCat(n->name(), ":control");
} else {
return strings::StrCat(n->name(), ":", index);
}
}
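// For example, for a node n named "C": EdgeId(n, 0) is "C",
// EdgeId(n, Graph::kControlSlot) is "C:control", and EdgeId(n, 2) is "C:2".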
string CanonicalGraphString(Graph* g) {
std::vector<string> nodes;
std::vector<string> edges;
for (const Node* n : g->nodes()) {
if (IncludeNode(n)) {
nodes.push_back(strings::StrCat(n->name(), "(", n->type_string(), ")"));
}
}
for (const Edge* e : g->edges()) {
if (IncludeNode(e->src()) && IncludeNode(e->dst())) {
edges.push_back(strings::StrCat(EdgeId(e->src(), e->src_output()), "->",
EdgeId(e->dst(), e->dst_input())));
}
}
// Canonicalize
std::sort(nodes.begin(), nodes.end());
std::sort(edges.begin(), edges.end());
return strings::StrCat(absl::StrJoin(nodes, ";"), "|",
absl::StrJoin(edges, ";"));
}
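// The canonical form is "<sorted node list>|<sorted edge list>". For
// example, the Basic test below expects
// "A(Input);B(Input);C(Zeta);D(Zeta)|A->C;A->D;B->C:1;B->D:1".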
string DoMklLayoutOptimizationPass() {
string before = CanonicalGraphString(&graph_);
LOG(ERROR) << "Before MKL layout rewrite pass: " << before;
std::unique_ptr<Graph>* ug = new std::unique_ptr<Graph>(&graph_);
RunMklLayoutRewritePass(ug);
string result = CanonicalGraphString(&graph_);
LOG(ERROR) << "After MKL layout rewrite pass: " << result;
return result;
}
// Runs the layout pass, then returns the value of attribute `attr` from the
// first node whose op type matches `node_name`.
template <typename T>
T DoMklLayoutOptimizationPassGetAttrVal(const string& attr,
const string& node_name) {
DoMklLayoutOptimizationPass();
T attr_val;
for (const Node* n : graph_.nodes()) {
if (IncludeNode(n) && n->type_string() == node_name) {
TF_CHECK_OK(GetNodeAttr(n->def(), attr, &attr_val));
return attr_val;
}
}
return attr_val;  // May be uninitialized if no matching node was found.
}
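// Example use (a sketch; the attribute name here is hypothetical):
//   bool v = DoMklLayoutOptimizationPassGetAttrVal<bool>("is_filter_const",
//                                                        "_MklConv2D");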
const string& OriginalGraph() const { return original_; }
Graph graph_;
string original_;
};
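// The ops registered below are lightweight stand-ins that exist only so the
// text-format graphs in these tests can be constructed; SetIsStateful()
// keeps them from being pruned or constant-folded away.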
REGISTER_OP("Input").Output("o: float").SetIsStateful();
REGISTER_OP("InputList").Output("o: N * float").Attr("N: int").SetIsStateful();
REGISTER_OP("HalfInput").Output("o: half").SetIsStateful();
REGISTER_OP("Int32Input").Output("o: int32").SetIsStateful();
REGISTER_OP("DoubleInput").Output("o: double").SetIsStateful();
REGISTER_OP("QuantizedInput").Output("o: quint8").SetIsStateful();
REGISTER_OP("_MklInput").Output("o: uint8").SetIsStateful();
REGISTER_OP("_MklInput2")
.Output("o: uint8")
.Output("o1: uint8")
.SetIsStateful();
REGISTER_OP("QuantizedUnsignedInt8Input").Output("o: quint8").SetIsStateful();
REGISTER_OP("QuantizedSignedInt8Input").Output("o: qint8").SetIsStateful();
REGISTER_OP("QuantizedSignedInt32Input").Output("o: qint32").SetIsStateful();
REGISTER_OP("Output2").Input("i: float").Input("i1: float").SetIsStateful();
REGISTER_OP("Output").Input("i: float").SetIsStateful();
REGISTER_OP("QInt8Input").Output("o: qint8").SetIsStateful();
REGISTER_OP("QUInt8Input").Output("o: quint8").SetIsStateful();
REGISTER_OP("QInt32Input").Output("o: qint32").SetIsStateful();
/////////////////////////////////////////////////////////////////////
// Unit tests related to node merge optimization
/////////////////////////////////////////////////////////////////////
TEST_F(MklLayoutPassTest, Basic) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }"
"node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Zeta);D(Zeta)|"
"A->C;A->D;B->C:1;B->D:1");
}
// Test set 1: Conv2D + BiasAdd merge
// C=Conv2D(A,B); E=BiasAdd(C,D); Z=Zeta(E,Y)
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Positive) {
CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'C' op: 'Conv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A', 'B']}"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: 'BiasAdd'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" input: ['C', 'D'] }"
"node { name: 'Y' op: 'Input'}"
"node { name: 'Z' op: 'Zeta'"
" attr {key: 'T' value { type: DT_FLOAT } }"
" input: ['E', 'Y']}");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
"DMT/_2(Const);E(_MklConv2DWithBias);Y(Input);Z(Zeta)|A->E;"
"A:control->DMT/_0:control;A:control->DMT/_1:control;"
"A:control->DMT/_2:control;B->E:1;D->E:2;DMT/_0->E:3;DMT/_1->E:4;"
"DMT/_2->E:5;E->Z;Y->Z:1");
}
// Graph contains only Conv2D, no BiasAdd.
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_NoAddBias) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'C' op: 'Conv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A', 'B']}");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(_MklConv2D);DMT/_0(Const);DMT/_1(Const)|"
"A->C;A:control->DMT/_0:control;A:control->DMT/_1:control;B->C:1;"
"DMT/_0->C:2;DMT/_1->C:3");
}
// Conv2D output does not go to BiasAdd.
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_Dataflow1) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'C' op: 'Conv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A', 'B']}"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: 'Input'}"
"node { name: 'F' op: 'BiasAdd'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" input: ['D', 'E'] }"); // Output of _MklConv2D does not go to BiasAdd.
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(_MklConv2D);D(Input);DMT/_0(Const);"
"DMT/_1(Const);E(Input);F(BiasAdd)|A->C;A:control->DMT/_0:control;"
"A:control->DMT/_1:control;B->C:1;D->F;DMT/_0->C:2;DMT/_1->C:3;"
"E->F:1");
}
// Conv2D has two outgoing edges: one to BiasAdd and another to a dummy node
// (Zeta). Merge should not be done in such a case.
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_Dataflow2) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'C' op: 'Conv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A', 'B']}"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: 'Input'}"
"node { name: 'F' op: 'BiasAdd'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" input: ['D', 'E'] }" // Conv2D has two outputs.
// No merge should happen.
"node { name: 'G' op: 'Zeta'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['C', 'E'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(_MklConv2D);D(Input);DMT/_0(Const);"
"DMT/_1(Const);E(Input);F(BiasAdd);G(Zeta)|A->C;"
"A:control->DMT/_0:control;A:control->DMT/_1:control;B->C:1;C->G;"
"D->F;DMT/_0->C:2;DMT/_1->C:3;E->F:1;E->G:1");
}
// data_format attribute value mismatch between Conv2D and BiasAdd.
// Merge should not be done in such a case.
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_AttrMismatch) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'C' op: 'Conv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A', 'B']}"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: 'BiasAdd'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NHCW' } }"
" input: ['C', 'D'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(_MklConv2D);D(Input);DMT/_0(Const);"
"DMT/_1(Const);E(BiasAdd)|A->C;A:control->DMT/_0:control;"
"A:control->DMT/_1:control;B->C:1;C->E;D->E:1;DMT/_0->C:2;"
"DMT/_1->C:3");
}
// Test set 2: BiasAddGrad + Conv2DBackpropFilter fusion tests
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackpropFilterFusion_Positive) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Int32Input'}"
"node { name: 'C' op: 'Input'}"
"node { name: 'D' op: 'Conv2DBackpropFilter'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A', 'B', 'C'] }"
"node { name: 'E' op: 'BiasAddGrad'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" input: ['C'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Int32Input);C(Input);"
"D(_MklConv2DBackpropFilterWithBias);DMT/_0(Const);DMT/_1(Const);"
"DMT/_2(Const)|A->D;A:control->DMT/_0:control;"
"A:control->DMT/_1:control;A:control->DMT/_2:control;B->D:1;C->D:2;"
"DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5");
}
// BiasAddGrad fusion in the presence of BackpropFilter. But the nodes do not
// match the criteria for fusion, so no fusion should happen: the 3rd input of
// Conv2DBackpropFilter is different from the input to BiasAddGrad.
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackpropFilterFusion_Negative1) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Int32Input'}"
"node { name: 'C' op: 'Input'}"
"node { name: 'D' op: 'Conv2DBackpropFilter'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A', 'B', 'C'] }"
"node { name: 'E' op: 'BiasAddGrad'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" input: ['A'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Int32Input);C(Input);"
"D(_MklConv2DBackpropFilter);DMT/_0(Const);DMT/_1(Const);"
"DMT/_2(Const);E(BiasAddGrad)|A->D;A->E;A:control->DMT/_0:control;"
"A:control->DMT/_1:control;A:control->DMT/_2:control;B->D:1;C->D:2;"
"DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5");
}
// BiasAddGrad fusion, but the nodes do not match the criteria for fusion:
// the data_format attributes differ (NCHW vs. NHWC).
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackpropFilterFusion_Negative2) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Int32Input'}"
"node { name: 'C' op: 'Input'}"
"node { name: 'D' op: 'Conv2DBackpropFilter'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A', 'B', 'C'] }"
"node { name: 'E' op: 'BiasAddGrad'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NHWC' } }"
" input: ['A'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Int32Input);C(Input);"
"D(_MklConv2DBackpropFilter);DMT/_0(Const);DMT/_1(Const);"
"DMT/_2(Const);E(BiasAddGrad)|A->D;A->E;A:control->DMT/_0:control;"
"A:control->DMT/_1:control;A:control->DMT/_2:control;B->D:1;C->D:2;"
"DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5");
}
// BiasAddGrad fusion in the presence of BackpropFilter only. Fusion is done
// before node rewrite; this test checks that ordering.
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackpropFilterFusion_Negative3) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'C' op: 'Input'}"
"node { name: 'M' op: '_MklInput'}"
"node { name: 'N' op: '_MklInput'}"
"node { name: 'O' op: '_MklInput'}"
"node { name: 'D' op: '_MklConv2DWithBias'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A', 'B', 'C', 'M', 'N', 'O']}"
"node { name: 'E' op: 'Zeta'"
" attr {key: 'T' value { type: DT_FLOAT } }"
" input: ['D', 'A']}"
"node { name: 'F' op: 'Int32Input'}"
"node { name: 'G' op: '_MklConv2DBackpropFilter'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" input: ['E', 'F', 'A', 'M', 'N', 'O'] }"
"node { name: 'H' op: 'BiasAddGrad'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" input: ['E'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(_MklConv2DWithBias);"
"E(Zeta);F(Int32Input);G(_MklConv2DBackpropFilter);H(BiasAddGrad);"
"M(_MklInput);N(_MklInput);O(_MklInput)|A->D;A->E:1;A->G:2;B->D:1;"
"C->D:2;D->E;E->G;E->H;F->G:1;M->D:3;M->G:3;N->D:4;N->G:4;O->D:5;"
"O->G:5");
}
// C=Conv2D(A,B); E=BiasAdd(C,D); Y=Zeta(E,X);
// G=Conv2DBackpropInput(F,B,E)
// This is a case of node rewrite, followed by node merge, followed by
// connecting the filter output of Conv2DWithBias to the filter input of
// Conv2DBackpropInput.
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_ConvBpropInput_FilterFwd) {
CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'C' op: 'Conv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A', 'B']}"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: 'BiasAdd'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" input: ['C', 'D'] }"
"node { name: 'X' op: 'Input'}"
"node { name: 'Y' op: 'Zeta'"
" attr {key: 'T' value { type: DT_FLOAT } }"
" input: ['E', 'X']}"
"node { name: 'F' op: 'Int32Input'}"
"node { name: 'G' op: 'Conv2DBackpropInput'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['F', 'B', 'E']}"
"node { name: 'Z' op: 'Zeta'"
" attr {key: 'T' value { type: DT_FLOAT } }"
" input: ['G', 'X']}");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
"DMT/_2(Const);DMT/_3(Const);E(_MklConv2DWithBias);F(Int32Input);"
"G(_MklConv2DBackpropInput);X(Input);Y(Zeta);Z(Zeta)|"
"A->E;A:control->DMT/_0:control;A:control->DMT/_1:control;"
"A:control->DMT/_2:control;B->E:1;D->E:2;DMT/_0->E:3;"
"DMT/_1->E:4;DMT/_2->E:5;DMT/_3->G:3;E->G:2;E->Y;E:1->G:1;E:2->G:5;"
"E:3->G:4;F->G;F:control->DMT/_3:control;G->Z;X->Y:1;X->Z:1");
}
// Test set 3: Pad + Conv2D fusion
// padding type is VALID
// A = input(image), B = input(paddings), C = Pad = input of Conv2D,
// D = input(filter), E = Conv2D, Z = Zeta
// C = Pad(A, B); E = Conv2D(C, D); Z = Zeta(E, Y)
// After layout pass:
// E = _MklPadWithConv2D(A, D, B, DMT/_0, DMT/_1, DMT/_2)
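// Note the input reordering in the merged op: the filter (D) becomes input 1
// and the paddings (B) become input 2, which is why the expected string below
// contains D->E:1 and B->E:2.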
TEST_F(MklLayoutPassTest, NodeMerge_PadWithConv2D_Positive) {
DCHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Int32Input'}"
"node { name: 'C' op: 'Pad'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'Tpaddings' value { type: DT_INT32 } }"
" input: ['A', 'B']}"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: 'Conv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NHWC' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'VALID' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['C', 'D'] }"
"node { name: 'Y' op: 'Input'}"
"node { name: 'Z' op: 'Zeta'"
" attr {key: 'T' value { type: DT_FLOAT } }"
" input: ['E', 'Y']}");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Int32Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
"DMT/_2(Const);E(_MklPadWithConv2D);Y(Input);Z(Zeta)|A->E;"
"A:control->DMT/_0:control;A:control->DMT/_1:control;"
"A:control->DMT/_2:control;B->E:2;D->E:1;DMT/_0->E:3;DMT/_1->E:4;"
"DMT/_2->E:5;E->Z;Y->Z:1");
}
// Test that input control edges are not duplicated after merge.
// If both of the ops being merged have an input control edge from a
// common op, then the merged op should have only one control edge from
// that common op.
// padding type is VALID
// A = input(image), A1 = input, B = input(paddings),
// C = Pad = input of Conv2D,
// D = input(filter), E = Conv2D, Z = Zeta
// C = Pad(A, B); E = Conv2D(C, D); Z = Zeta(E, Y)
// A1:control->C:control
// A1:control->E:control
// After layout pass:
// E = _MklPadWithConv2D(A, D, B, DMT/_0, DMT/_1, DMT/_2)
// A1:control->E:control (only one control edge)
TEST_F(MklLayoutPassTest, Input_ControlEdge_PadWithConv2D_Positive) {
DCHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
InitGraph(
"node { name: 'A1' op: 'Input'}"
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Int32Input'}"
"node { name: 'C' op: 'Pad'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'Tpaddings' value { type: DT_INT32 } }"
" input: ['A', 'B']}"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: 'Conv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NHWC' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'VALID' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['C', 'D'] }"
"node { name: 'Y' op: 'Input'}"
"node { name: 'Z' op: 'Zeta'"
" attr {key: 'T' value { type: DT_FLOAT } }"
" input: ['E', 'Y']}");
Node* a1 = FindNode("A1");
Node* c = FindNode("C");
Node* e = FindNode("E");
const Edge* edge = graph_.AddControlEdge(a1, c);
const Edge* edge_1 = graph_.AddControlEdge(a1, e);
ASSERT_NE(edge, nullptr);
ASSERT_NE(edge_1, nullptr);
EXPECT_EQ(
DoMklLayoutOptimizationPass(),
"A(Input);A1(Input);B(Int32Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
"DMT/_2(Const);E(_MklPadWithConv2D);Y(Input);Z(Zeta)|A->E;"
"A1:control->E:control;A:control->DMT/_0:control;A:control->DMT/"
"_1:control;"
"A:control->DMT/_2:control;B->E:2;D->E:1;DMT/_0->E:3;DMT/_1->E:4;"
"DMT/_2->E:5;E->Z;Y->Z:1");
}
// Test that output control edges are not duplicated after merge.
// If both of the ops being merged have an output control edge to a
// common op, then the merged op should have only one control edge to
// that common op.
// padding type is VALID
// A = input(image), B = input(paddings), C = Pad = input of Conv2D,
// D = input(filter), E = Conv2D, Z = Zeta
// C = Pad(A, B); E = Conv2D(C, D); Z = Zeta(E, Y)
// C:control->A1:control
// E:control->A1:control
// After layout pass:
// E = _MklPadWithConv2D(A, D, B, DMT/_0, DMT/_1, DMT/_2)
// E:control->A1:control (only one control edge)
TEST_F(MklLayoutPassTest, Output_ControlEdge_PadWithConv2D_Positive) {
DCHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
InitGraph(
"node { name: 'A1' op: 'Input'}"
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Int32Input'}"
"node { name: 'C' op: 'Pad'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'Tpaddings' value { type: DT_INT32 } }"
" input: ['A', 'B']}"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: 'Conv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NHWC' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'VALID' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['C', 'D'] }"
"node { name: 'Y' op: 'Input'}"
"node { name: 'Z' op: 'Zeta'"
" attr {key: 'T' value { type: DT_FLOAT } }"
" input: ['E', 'Y']}");
Node* a1 = FindNode("A1");
Node* c = FindNode("C");
Node* e = FindNode("E");
const Edge* edge = graph_.AddControlEdge(c, a1);
const Edge* edge_1 = graph_.AddControlEdge(e, a1);
ASSERT_NE(edge, nullptr);
ASSERT_NE(edge_1, nullptr);
EXPECT_EQ(
DoMklLayoutOptimizationPass(),
"A(Input);A1(Input);B(Int32Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
"DMT/_2(Const);E(_MklPadWithConv2D);Y(Input);Z(Zeta)|A->E;"
"A:control->DMT/_0:control;A:control->DMT/_1:control;"
"A:control->DMT/_2:control;B->E:2;D->E:1;DMT/_0->E:3;DMT/_1->E:4;"
"DMT/_2->E:5;E->Z;E:control->A1:control;Y->Z:1");
}
// Pad + Conv2D fusion; padding type is VALID.
// The input node feeds both Pad and Conv2D.
// A = input(image), B = input(paddings), C = Pad
// E = Conv2D, Z = Zeta
// C = Pad(A, B); E = Conv2D(C, A); Z = Zeta(E, Y)
// After layout pass:
// E = _MklPadWithConv2D(A, A, B, DMT/_0, DMT/_1, DMT/_2)
TEST_F(MklLayoutPassTest, NodeMerge_PadWithConv2D_Common_Input) {
DCHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Int32Input'}"
"node { name: 'C' op: 'Pad'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'Tpaddings' value { type: DT_INT32 } }"
" input: ['A', 'B']}"
"node { name: 'E' op: 'Conv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NHWC' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'VALID' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['C', 'A'] }"
"node { name: 'Y' op: 'Input'}"
"node { name: 'Z' op: 'Zeta'"
" attr {key: 'T' value { type: DT_FLOAT } }"
" input: ['E', 'Y']}");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Int32Input);DMT/_0(Const);DMT/_1(Const);"
"DMT/_2(Const);E(_MklPadWithConv2D);Y(Input);Z(Zeta)|A->E;A->E:1;"
"A:control->DMT/_0:control;A:control->DMT/_1:control;"
"A:control->DMT/_2:control;B->E:2;DMT/_0->E:3;DMT/_1->E:4;"
"DMT/_2->E:5;E->Z;Y->Z:1");
}
// Pad + Conv2D; padding type is VALID.
// The input node feeds both Pad and Conv2D, and the outputs of both Pad and
// Conv2D feed one common node (Z as Output2).
// A = input(image), B = input(paddings), C = Pad
// E = Conv2D, Z = Output2
// C = Pad(A, B); E = Conv2D(C, A); Z = Output2(C, E)
// After layout pass: no merging, since Pad and Conv2D both feed the same
// node (Z).
TEST_F(MklLayoutPassTest, NodeMerge_PadWithConv2D_Common_InOutput) {
DCHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Int32Input'}"
"node { name: 'C' op: 'Pad'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'Tpaddings' value { type: DT_INT32 } }"
" input: ['A', 'B']}"
"node { name: 'E' op: 'Conv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NHWC' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'VALID' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['C', 'A'] }"
"node { name: 'Z' op: 'Output2'"
" input: ['C', 'E']}");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Int32Input);C(Pad);DMT/_0(Const);DMT/_1(Const);"
"E(_MklConv2D);Z(Output2)|A->C;A->E:1;B->C:1;C->E;C->Z;"
"C:control->DMT/_0:control;C:control->DMT/_1:control;"
"DMT/_0->E:2;DMT/_1->E:3;E->Z:1");
}
// Pad + Conv2D; padding type is SAME.
// A = input(image), B = input(paddings), C = Pad = input of Conv2D,
// D = input(filter), E = Conv2D, Z = Zeta
// C = Pad(A, B); E = Conv2D(C, D); Z = Zeta(E, Y)
// After layout pass: no merging (padding is SAME, not VALID).
TEST_F(MklLayoutPassTest, NodeMerge_PadWithConv2D_Negative) {
DCHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Int32Input'}"
"node { name: 'C' op: 'Pad'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'Tpaddings' value { type: DT_INT32 } }"
" input: ['A', 'B']}"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: 'Conv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NHWC' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['C', 'D'] }"
"node { name: 'Y' op: 'Input'}"
"node { name: 'Z' op: 'Zeta'"
" attr {key: 'T' value { type: DT_FLOAT } }"
" input: ['E', 'Y']}");
EXPECT_EQ(
DoMklLayoutOptimizationPass(),
"A(Input);B(Int32Input);C(Pad);D(Input);DMT/_0(Const);DMT/_1(Const);"
"E(_MklConv2D);Y(Input);Z(Zeta)|A->C;B->C:1;C->E;"
"C:control->DMT/_0:control;C:control->DMT/_1:control;"
"D->E:1;DMT/_0->E:2;DMT/_1->E:3;E->Z;Y->Z:1");
}
TEST_F(MklLayoutPassTest, NodeMerge_TransposeConv2DTranspose_Positive) {
InitGraph(
"node { name: 'Input0' op: 'Input'}"
"node { name: 'Input1' op: 'Input'}"
"node { name: 'Const0' op: 'Const'"
" attr {"
" key: 'dtype'"
" value {"
" type: DT_INT32"
" }"
" }"
" attr {"
" key: 'value'"
" value {"
" tensor {"
" dtype: DT_INT32"
" tensor_shape {"
" dim {"
" size: 4"
" }"
" }"
" tensor_content: "
"'\\000\\000\\000\\000\\002\\000\\000\\000\\003\\000\\000\\000\\001\\000"
"\\000\\000'"
" }"
" }"
" }"
"}"
"node { name: 'Const1' op: 'Const'"
" attr {"
" key: 'dtype'"
" value {"
" type: DT_INT32"
" }"
" }"
" attr {"
" key: 'value'"
" value {"
" tensor {"
" dtype: DT_INT32"
" tensor_shape {"
" dim {"
" size: 4"
" }"
" }"
" tensor_content: "
"'\\000\\000\\000\\000\\003\\000\\000\\000\\001\\000\\000\\000\\002\\000"
"\\000\\000'"
" }"
" }"
" }"
"}"
"node { \
name: 'Transpose0' \
op: 'Transpose' \
input: 'Input0' \
input: 'Const0' \
attr { \
key: 'T' \
value { \
type: DT_FLOAT \
} \
} \
attr { \
key: 'Tperm' \
value { \
type: DT_INT32 \
} \
} \
}"
"node { \
name: 'Conv2D' \
op: 'Conv2D' \
input: 'Transpose0' \
input: 'Input1' \
attr { \
key: 'T' \
value { \
type: DT_FLOAT \
} \
} \
attr { \
key: 'data_format' \
value { \
s: 'NHWC' \
} \
} \
attr { \
key: 'dilations' \
value { \
list { \
i: 1 \
i: 1 \
i: 1 \
i: 1 \
} \
} \
} \
attr { \
key: 'padding' \
value { \
s: 'SAME' \
} \
} \
attr { \
key: 'strides' \
value { \
list { \
i: 1 \
i: 1 \
i: 1 \
i: 1 \
} \
} \
} \
attr { \
key: 'use_cudnn_on_gpu' \
value { \
b: true \
} \
} \
}"
"node { \
name: 'Transpose1' \
op: 'Transpose' \
input: 'Conv2D' \
input: 'Const1' \
attr { \
key: 'T' \
value { \
type: DT_FLOAT \
} \
} \
attr { \
key: 'Tperm' \
value { \
type: DT_INT32 \
} \
} \
}"
"node { name: 'Relu' op: 'Relu'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['Transpose1'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"Const0(Const);Const1(Const);"
"Conv2D(_MklConv2D);DMT/_0(Const);DMT/_1(Const);Input0(Input);"
"Input1(Input);Relu(_MklRelu)|Conv2D->Relu;Conv2D:2->Relu:1;DMT/"
"_0->Conv2D:2;DMT/_1->Conv2D:3;Input0->Conv2D;"
"Input0:control->DMT/_0:control;Input0:control->DMT/"
"_1:control;Input1->Conv2D:1");
}
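// Negative variant: here Const1 holds the same permutation as Const0
// ([0, 2, 3, 1]), so Transpose1 does not invert Transpose0 and no folding
// happens; the Transposes are merely rewritten to _MklTranspose.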
TEST_F(MklLayoutPassTest, NodeMerge_TransposeConv2DTranspose_Negative) {
InitGraph(
"node { name: 'Input0' op: 'Input'}"
"node { name: 'Input1' op: 'Input'}"
"node { name: 'Const0' op: 'Const'"
" attr {"
" key: 'dtype'"
" value {"
" type: DT_INT32"
" }"
" }"
" attr {"
" key: 'value'"
" value {"
" tensor {"
" dtype: DT_INT32"
" tensor_shape {"
" dim {"
" size: 4"
" }"
" }"
" tensor_content: "
"'\\000\\000\\000\\000\\002\\000\\000\\000\\003\\000\\000\\000\\001\\000"
"\\000\\000'"
" }"
" }"
" }"
"}"
"node { name: 'Const1' op: 'Const'"
" attr {"
" key: 'dtype'"
" value {"
" type: DT_INT32"
" }"
" }"
" attr {"
" key: 'value'"
" value {"
" tensor {"
" dtype: DT_INT32"
" tensor_shape {"
" dim {"
" size: 4"
" }"
" }"
" tensor_content: "
"'\\000\\000\\000\\000\\002\\000\\000\\000\\003\\000\\000\\000\\001\\000"
"\\000\\000'"
" }"
" }"
" }"
"}"
"node { \
name: 'Transpose0' \
op: 'Transpose' \
input: 'Input0' \
input: 'Const0' \
attr { \
key: 'T' \
value { \
type: DT_FLOAT \
} \
} \
attr { \
key: 'Tperm' \
value { \
type: DT_INT32 \
} \
} \
}"
"node { \
name: 'Conv2D' \
op: 'Conv2D' \
input: 'Transpose0' \
input: 'Input1' \
attr { \
key: 'T' \
value { \
type: DT_FLOAT \
} \
} \
attr { \
key: 'data_format' \
value { \
s: 'NHWC' \
} \
} \
attr { \
key: 'dilations' \
value { \
list { \
i: 1 \
i: 1 \
i: 1 \
i: 1 \
} \
} \
} \
attr { \
key: 'padding' \
value { \
s: 'SAME' \
} \
} \
attr { \
key: 'strides' \
value { \
list { \
i: 1 \
i: 1 \
i: 1 \
i: 1 \
} \
} \
} \
attr { \
key: 'use_cudnn_on_gpu' \
value { \
b: true \
} \
} \
}"
"node { \
name: 'Transpose1' \
op: 'Transpose' \
input: 'Conv2D' \
input: 'Const1' \
attr { \
key: 'T' \
value { \
type: DT_FLOAT \
} \
} \
attr { \
key: 'Tperm' \
value { \
type: DT_INT32 \
} \
} \
}"
"node { name: 'Relu' op: 'Relu'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['Transpose1'] }");
EXPECT_EQ(
DoMklLayoutOptimizationPass(),
"Const0(Const);Const1(Const);Conv2D(_MklConv2D);"
"DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);Input0(Input);"
"Input1(Input);Relu(_MklRelu);Transpose0(_MklTranspose);"
"Transpose1(_MklTranspose)|Const0->Transpose0:1;"
"Const1->Transpose1:1;Conv2D->Transpose1;DMT/_0->Conv2D:2;"
"DMT/_1->Conv2D:3;DMT/_2->Relu:1;Input0->Transpose0;"
"Input1->Conv2D:1;Transpose0->Conv2D;Transpose0:control->DMT/_0:control;"
"Transpose0:control->DMT/_1:control;Transpose1->Relu;"
"Transpose1:control->DMT/_2:control");
}
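// 3D analog of the positive case above: Const0 is [0, 2, 3, 4, 1]
// (NCDHW -> NDHWC) and Const1 is its inverse [0, 4, 1, 2, 3], so both
// Transposes fold into the Mkl Conv3D.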
TEST_F(MklLayoutPassTest, NodeMerge_TransposeConv3DTranspose_Positive) {
InitGraph(
"node { name: 'Input0' op: 'Input'} \
node { name: 'Input1' op: 'Input'} \
node { name: 'Const0' op: 'Const' \
attr { key: 'dtype' value { type: DT_INT32 } } \
attr { \
key: 'value' \
value { \
tensor { \
dtype: DT_INT32 \
tensor_shape { \
dim { \
size: 5 \
} \
} \
tensor_content: \
'\\000\\000\\000\\000\\002\\000\\000\\000\\003\\000\\000\\000\\004' \
'\\000\\000\\000\\001\\000\\000\\000' \
} \
} \
} \
} \
node { name: 'Const1' op: 'Const' \
attr { key: 'dtype' value { type: DT_INT32 } } \
attr { \
key: 'value' \
value { \
tensor { \
dtype: DT_INT32 \
tensor_shape { \
dim { \
size: 5 \
} \
} \
tensor_content: \
'\\000\\000\\000\\000\\004\\000\\000\\000\\001\\000\\000\\000\\002' \
'\\000\\000\\000\\003\\000\\000\\000' \
} \
} \
} \
}"
"node { \
name: 'Transpose0' \
op: 'Transpose' \
input: 'Input0' \
input: 'Const0' \
attr { \
key: 'T' \
value { \
type: DT_FLOAT \
} \
} \
attr { \
key: 'Tperm' \
value { \
type: DT_INT32 \
} \
} \
}"
"node { \
name: 'Conv3D' \
op: 'Conv3D' \
input: 'Transpose0' \
input: 'Input1' \
attr { \
key: 'T' \
value { \
type: DT_FLOAT \
} \
} \
attr { \
key: 'data_format' \
value { \
s: 'NDHWC' \
} \
} \
attr { \
key: 'dilations' \
value { \
list { \
i: 1 \
i: 1 \
i: 1 \
i: 1 \
i: 1 \
} \
} \
} \
attr { \
key: 'padding' \
value { \
s: 'SAME' \
} \
} \
attr { \
key: 'strides' \
value { \
list { \
i: 1 \
i: 1 \
i: 1 \
i: 1 \
i: 1 \
} \
} \
} \
attr { \
key: 'use_cudnn_on_gpu' \
value { \
b: true \
} \
} \
}"
"node { \
name: 'Transpose1' \
op: 'Transpose' \
input: 'Conv3D' \
input: 'Const1' \
attr { \
key: 'T' \
value { \
type: DT_FLOAT \
} \
} \
attr { \
key: 'Tperm' \
value { \
type: DT_INT32 \
} \
} \
}"
"node { name: 'Relu' op: 'Relu'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['Transpose1'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"Const0(Const);Const1(Const);Conv3D(_MklConv3D);DMT/_0(Const);"
"DMT/_1(Const);Input0(Input);Input1(Input);"
"Relu(_MklRelu)|Conv3D->Relu;Conv3D:2->Relu:1;"
"DMT/_0->Conv3D:2;DMT/_1->Conv3D:3;Input0->Conv3D;"
"Input0:control->DMT/_0:control;"
"Input0:control->DMT/_1:control;Input1->Conv3D:1");
}
TEST_F(MklLayoutPassTest, NodeMerge_TransposeConv3DTranspose_Negative) {
InitGraph(
"node { name: 'Input0' op: 'Input'} \
node { name: 'Input1' op: 'Input'} \
node { name: 'Const0' op: 'Const' \
attr { \
key: 'dtype' \
value { \
type: DT_INT32 \
} \
} \
attr { \
key: 'value' \
value { \
tensor { \
dtype: DT_INT32 \
tensor_shape { \
dim { \
size: 5 \
} \
} \
tensor_content: \
'\\000\\000\\000\\000\\002\\000\\000\\000\\003\\000\\000\\000\\004' \
'\\000\\000\\000\\001\\000\\000\\000' \
} \
} \
} \
} \
node { name: 'Const1' op: 'Const' \
attr { \
key: 'dtype' \
value { \
type: DT_INT32 \
} \
} \
attr { \
key: 'value' \
value { \
tensor { \
dtype: DT_INT32 \
tensor_shape { \
dim { \
size: 5 \
} \
} \
tensor_content: \
'\\000\\000\\000\\000\\002\\000\\000\\000\\003\\000\\000\\000\\004' \
'\\000\\000\\000\\001\\000\\000\\000' \
} \
} \
} \
}"
"node { \
name: 'Transpose0' \
op: 'Transpose' \
input: 'Input0' \
input: 'Const0' \
attr { \
key: 'T' \
value { \
type: DT_FLOAT \
} \
} \
attr { \
key: 'Tperm' \
value { \
type: DT_INT32 \
} \
} \
}"
"node { \
name: 'Conv3D' \
op: 'Conv3D' \
input: 'Transpose0' \
input: 'Input1' \
attr { \
key: 'T' \
value { \
type: DT_FLOAT \
} \
} \
attr { \
key: 'data_format' \
value { \
s: 'NDHWC' \
} \
} \
attr { \
key: 'dilations' \
value { \
list { \
i: 1 \
i: 1 \
i: 1 \
i: 1 \
i: 1 \
} \
} \
} \
attr { \
key: 'padding' \
value { \
s: 'SAME' \
} \
} \
attr { \
key: 'strides' \
value { \
list { \
i: 1 \
i: 1 \
i: 1 \
i: 1 \
i: 1 \
} \
} \
} \
attr { \
key: 'use_cudnn_on_gpu' \
value { \
b: true \
} \
} \
}"
"node { \
name: 'Transpose1' \
op: 'Transpose' \
input: 'Conv3D' \
input: 'Const1' \
attr { \
key: 'T' \
value { \
type: DT_FLOAT \
} \
} \
attr { \
key: 'Tperm' \
value { \
type: DT_INT32 \
} \
} \
}"
"node { name: 'Relu' op: 'Relu'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['Transpose1'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"Const0(Const);Const1(Const);Conv3D(_MklConv3D);"
"DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);"
"Input0(Input);Input1(Input);Relu(_MklRelu);"
"Transpose0(_MklTranspose);Transpose1(_MklTranspose)"
"|Const0->Transpose0:1;Const1->Transpose1:1;"
"Conv3D->Transpose1;DMT/_0->Conv3D:2;DMT/_1->Conv3D:3;"
"DMT/_2->Relu:1;Input0->Transpose0;Input1->Conv3D:1;"
"Transpose0->Conv3D;Transpose0:control->DMT/_0:control;"
"Transpose0:control->DMT/_1:control;Transpose1->Relu;"
"Transpose1:control->DMT/_2:control");
}
/////////////////////////////////////////////////////////////////////
// Unit tests related to rewriting node to Mkl node
/////////////////////////////////////////////////////////////////////
// Single Conv2D op; no Mkl layer on the input or the output.
// We will generate a dummy Mkl tensor as the 2nd input of Conv2D.
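// In the expected graphs, DMT/_* are these dummy Mkl metadata tensors
// (Const ops). The pass also adds control edges from the rewritten op's
// first input to each DMT node, presumably to keep the dummies ordered with,
// and in the same execution frame as, the op they feed.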
TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_Basic) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'C' op: 'Conv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A', 'B']}"
"node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['B', 'C'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(_MklConv2D);D(Zeta);DMT/_0(Const);"
"DMT/_1(Const)|A->C;A:control->DMT/_0:control;"
"A:control->DMT/_1:control;B->C:1;B->D;C->D:1;DMT/_0->C:2;"
"DMT/_1->C:3");
}
// Test case for the DepthwiseConv2dNative forward (FWD) rewrite.
TEST_F(MklLayoutPassTest, NodeRewrite_DepthwiseConv2dNative_Basic) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'C' op: 'DepthwiseConv2dNative'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A', 'B']}"
"node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['B', 'C'] }");
EXPECT_EQ(
DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(_MklDepthwiseConv2dNative);D(Zeta);DMT/_0(Const);"
"DMT/_1(Const)|A->C;A:control->DMT/_0:control;"
"A:control->DMT/_1:control;B->C:1;B->D;C->D:1;DMT/_0->C:2;"
"DMT/_1->C:3");
}
// Two Conv2D ops in sequence. Both should get transformed, and the 1st Conv2D
// will have 2 outputs (data and Mkl metadata), both of which become inputs to
// the next Conv2D.
TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_Positive1) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'C' op: 'Conv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A', 'B']}"
"node { name: 'D' op: 'Conv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A', 'C']}"
"node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['C', 'D'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(_MklConv2D);D(_MklConv2D);DMT/_0(Const);"
"DMT/_1(Const);DMT/_2(Const);E(Zeta)|A->C;A->D;"
"A:control->DMT/_0:control;A:control->DMT/_1:control;"
"A:control->DMT/_2:control;B->C:1;C->D:1;C->E;"
"C:2->D:3;D->E:1;DMT/_0->C:2;DMT/_1->C:3;DMT/_2->D:2");
}
// Conv2D with DT_HALF, which is not supported by Mkl.
TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_Negative_UnsupportedType) {
InitGraph(
"node { name: 'A' op: 'HalfInput'}"
"node { name: 'B' op: 'HalfInput'}"
"node { name: 'C' op: 'Conv2D'"
" attr { key: 'T' value { type: DT_HALF } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A', 'B']}"
"node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_HALF } }"
" input: ['B', 'C'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(HalfInput);B(HalfInput);C(Conv2D);D(Zeta)|"
"A->C;B->C:1;B->D;C->D:1");
}
TEST_F(MklLayoutPassTest, NodeRewrite_QuantizeV2Op_Negative_ConstInp) {
InitGraph(
"node { name: 'A' op: 'Const' "
" attr { key: 'dtype' value { type: DT_FLOAT } }"
" attr { key: 'value' value { "
" tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
" int_val: 0 } } } }"
"node { name: 'B' op: 'Const' "
" attr { key: 'dtype' value { type: DT_FLOAT } }"
" attr { key: 'value' value { "
" tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
" int_val: 0 } } } }"
"node { name: 'C' op: 'Const' "
" attr { key: 'dtype' value { type: DT_FLOAT } }"
" attr { key: 'value' value { "
" tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
" int_val: 0 } } } }"
"node { name: 'D' op: 'QuantizeV2'"
" attr { key: 'T' value { type: DT_QUINT8 } }"
" attr { key: 'mode' value { s: 'SCALED' } }"
" attr { key: 'round_mode' value { s: 'HALF_TO_EVEN' } }"
" input: ['A', 'B', 'C']}"
"node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_QUINT8 } }"
" input: ['D'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Const);B(Const);C(Const);D(QuantizeV2);E(Zeta)|"
"A->D;B->D:1;C->D:2;D->E");
}
TEST_F(MklLayoutPassTest, NodeRewrite_QuantizeV2Op_Negative_MinFirst) {
InitGraph(
"node { name: 'A' op: 'Input' } "
"node { name: 'B' op: 'Const' "
" attr { key: 'dtype' value { type: DT_FLOAT } }"
" attr { key: 'value' value { "
" tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
" int_val: 0 } } } }"
"node { name: 'C' op: 'Const' "
" attr { key: 'dtype' value { type: DT_FLOAT } }"
" attr { key: 'value' value { "
" tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
" int_val: 0 } } } }"
"node { name: 'D' op: 'QuantizeV2'"
" attr { key: 'T' value { type: DT_QUINT8 } }"
" attr { key: 'mode' value { s: 'MIN_FIRST' } }"
" attr { key: 'round_mode' value { s: 'HALF_TO_EVEN' } }"
" input: ['A', 'B', 'C']}"
"node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_QUINT8 } }"
" input: ['D'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Const);C(Const);D(QuantizeV2);E(Zeta)|"
"A->D;B->D:1;C->D:2;D->E");
}
TEST_F(MklLayoutPassTest, NodeRewrite_QuantizeV2Op_Negative_HalfFromZero) {
InitGraph(
"node { name: 'A' op: 'Input' } "
"node { name: 'B' op: 'Const' "
" attr { key: 'dtype' value { type: DT_FLOAT } }"
" attr { key: 'value' value { "
" tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
" int_val: 0 } } } }"
"node { name: 'C' op: 'Const' "
" attr { key: 'dtype' value { type: DT_FLOAT } }"
" attr { key: 'value' value { "
" tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
" int_val: 0 } } } }"
"node { name: 'D' op: 'QuantizeV2'"
" attr { key: 'T' value { type: DT_QUINT8 } }"
" attr { key: 'mode' value { s: 'SCALED' } }"
" attr { key: 'round_mode' value { s: 'HALF_FROM_ZERO' } }"
" input: ['A', 'B', 'C']}"
"node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_QUINT8 } }"
" input: ['D'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Const);C(Const);D(QuantizeV2);E(Zeta)|"
"A->D;B->D:1;C->D:2;D->E");
}
TEST_F(MklLayoutPassTest, NodeRewrite_QuantizeV2Op_Positive) {
InitGraph(
"node { name: 'A' op: 'Input' } "
"node { name: 'B' op: 'Const' "
" attr { key: 'dtype' value { type: DT_FLOAT } }"
" attr { key: 'value' value { "
" tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
" int_val: 0 } } } }"
"node { name: 'C' op: 'Const' "
" attr { key: 'dtype' value { type: DT_FLOAT } }"
" attr { key: 'value' value { "
" tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
" int_val: 0 } } } }"
"node { name: 'D' op: 'QuantizeV2'"
" attr { key: 'T' value { type: DT_QUINT8 } }"
" attr { key: 'mode' value { s: 'SCALED' } }"
" attr { key: 'round_mode' value { s: 'HALF_TO_EVEN' } }"
" input: ['A', 'B', 'C']}"
"node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_QUINT8 } }"
" input: ['D'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Const);C(Const);D(_MklQuantizeV2);DMT/_0(Const);DMT/"
"_1(Const);DMT/_2(Const);E(Zeta)|"
"A->D;A:control->DMT/_0:control;A:control->DMT/"
"_1:control;A:control->DMT/_2:control;B->D:1;C->D:2;D->E;DMT/"
"_0->D:3;DMT/_1->D:4;DMT/_2->D:5");
}
TEST_F(MklLayoutPassTest, NodeRewrite_Dequantize_Negative_Const_Input) {
InitGraph(
"node { name: 'A' op: 'Const' "
" attr { key: 'dtype' value { type: DT_QUINT8 } }"
" attr { key: 'value' value { "
" tensor { dtype: DT_QUINT8 tensor_shape { dim { size: 1 } } "
" int_val: 0 } } } }"
"node { name: 'B' op: 'Input'}"
"node { name: 'C' op: 'Input'}"
"node { name: 'D' op: 'Dequantize'"
" attr { key: 'T' value { type: DT_QUINT8 } }"
" attr { key: 'mode' value { s: 'SCALED' } }"
" input: ['A', 'B', 'C']}"
"node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['D'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Const);B(Input);C(Input);D(Dequantize);"
"E(Zeta)|A->D;B->D:1;C->D:2;D->E");
}
TEST_F(MklLayoutPassTest, NodeRewrite_Dequantize_Negative_Non_SCALED_Mode) {
InitGraph(
"node { name: 'A' op: 'QuantizedInput'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'C' op: 'Input'}"
"node { name: 'D' op: 'Dequantize'"
" attr { key: 'T' value { type: DT_QUINT8 } }"
" attr { key: 'mode' value { s: 'MIN_FIRST' } }"
" input: ['A', 'B', 'C']}"
"node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['D'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(QuantizedInput);B(Input);C(Input);D(Dequantize);"
"E(Zeta)|A->D;B->D:1;C->D:2;D->E");
}
// Rewrite test for _FusedConv2D Op with BiasAdd fusion
TEST_F(MklLayoutPassTest, NodeRewrite_FusedConv2D_Positive1) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'C' op: 'Input'}"
"node { name: 'D' op: '_FusedConv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'num_args' value { i: 1 } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'fused_ops' value { list: {s: 'BiasAdd'} } }"
" attr { key: 'epsilon' value { f: 0.001 }}"
" input: ['A', 'B', 'C']}"
"node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['D', 'C'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(_MklFusedConv2D);DMT/_0(Const);"
"DMT/_1(Const);DMT/_2(Const);E(Zeta)|A->D;"
"A:control->DMT/_0:control;A:control->DMT/_1:control;"
"A:control->DMT/_2:control;B->D:1;C->D:2;C->E:1;D->E;"
"DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5");
}
// Rewrite test for _FusedConv2D Op with Relu fusion
TEST_F(MklLayoutPassTest, NodeRewrite_FusedConv2D_Positive2) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'C' op: 'Input'}"
"node { name: 'D' op: '_FusedConv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'num_args' value { i: 1 } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'fused_ops' value { list: {s: 'Relu'} } }"
" attr { key: 'epsilon' value { f: 0.001 }}"
" input: ['A', 'B', 'C']}"
"node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['D', 'C'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(_MklFusedConv2D);DMT/_0(Const);"
"DMT/_1(Const);DMT/_2(Const);E(Zeta)|A->D;"
"A:control->DMT/_0:control;A:control->DMT/_1:control;"
"A:control->DMT/_2:control;B->D:1;C->D:2;C->E:1;D->E;"
"DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5");
}
// Rewrite test for _FusedConv2D Op with BiasAdd+Relu fusion
TEST_F(MklLayoutPassTest, NodeRewrite_FusedConv2D_Positive3) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'C' op: 'Input'}"
"node { name: 'D' op: '_FusedConv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'num_args' value { i: 1 } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'fused_ops'"
" value { list: {s: 'BiasAdd', s: 'Relu'} } }"
" attr { key: 'epsilon' value { f: 0.001 }}"
" input: ['A', 'B', 'C']}"
"node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['D', 'C'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(_MklFusedConv2D);DMT/_0(Const);"
"DMT/_1(Const);DMT/_2(Const);E(Zeta)|A->D;"
"A:control->DMT/_0:control;A:control->DMT/_1:control;"
"A:control->DMT/_2:control;B->D:1;C->D:2;C->E:1;D->E;"
"DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5");
}
// Rewrite test for _FusedConv2D Op with BiasAdd+Relu6 fusion
TEST_F(MklLayoutPassTest, NodeRewrite_FusedConv2D_Positive4) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'C' op: 'Input'}"
"node { name: 'D' op: '_FusedConv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'num_args' value { i: 1 } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'fused_ops'"
" value { list: {s: 'BiasAdd', s: 'Relu6'} } }"
" attr { key: 'epsilon' value { f: 0.001 }}"
" input: ['A', 'B', 'C']}"
"node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['D', 'C'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(_MklFusedConv2D);DMT/_0(Const);"
"DMT/_1(Const);DMT/_2(Const);E(Zeta)|A->D;"
"A:control->DMT/_0:control;A:control->DMT/_1:control;"
"A:control->DMT/_2:control;B->D:1;C->D:2;C->E:1;D->E;"
"DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5");
}
// Rewrite test for _FusedConv2D Op with BiasAdd+Elu fusion
TEST_F(MklLayoutPassTest, NodeRewrite_FusedConv2D_Positive5) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'C' op: 'Input'}"
"node { name: 'D' op: '_FusedConv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'num_args' value { i: 1 } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'fused_ops'"
" value { list: {s: 'BiasAdd', s: 'Elu'} } }"
" attr { key: 'epsilon' value { f: 0.001 }}"
" input: ['A', 'B', 'C']}"
"node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['D', 'C'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(_MklFusedConv2D);DMT/_0(Const);"
"DMT/_1(Const);DMT/_2(Const);E(Zeta)|A->D;"
"A:control->DMT/_0:control;A:control->DMT/_1:control;"
"A:control->DMT/_2:control;B->D:1;C->D:2;C->E:1;D->E;"
"DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5");
}
// Rewrite test for _FusedConv2D Op with BiasAdd+Add fusion
TEST_F(MklLayoutPassTest, NodeRewrite_FusedConv2D_Positive6) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'C' op: 'Input'}"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: '_FusedConv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'num_args' value { i: 2 } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'fused_ops'"
" value { list: {s: 'BiasAdd', s: 'Add'} } }"
" attr { key: 'epsilon' value { f: 0.001 }}"
" input: ['A', 'B', 'C', 'D']}"
"node { name: 'F' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['E', 'D'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
"DMT/_2(Const);DMT/_3(Const);E(_MklFusedConv2D);F(Zeta)|A->E;"
"A:control->DMT/_0:control;A:control->DMT/_1:control;"
"A:control->DMT/_2:control;A:control->DMT/_3:control;"
"B->E:1;C->E:2;D->E:3;D->F:1;DMT/_0->E:4;DMT/_1->E:5;"
"DMT/_2->E:6;DMT/_3->E:7;E->F");
}
// Rewrite test for _FusedConv2D Op with BiasAdd+Add+Relu fusion
TEST_F(MklLayoutPassTest, NodeRewrite_FusedConv2D_Positive7) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'C' op: 'Input'}"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: '_FusedConv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'num_args' value { i: 2 } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'fused_ops'"
" value { list: {s: 'BiasAdd', s: 'Add', s: 'Relu'} } }"
" attr { key: 'epsilon' value { f: 0.001 }}"
" input: ['A', 'B', 'C', 'D']}"
"node { name: 'F' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['E', 'D'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
"DMT/_2(Const);DMT/_3(Const);E(_MklFusedConv2D);F(Zeta)|A->E;"
"A:control->DMT/_0:control;A:control->DMT/_1:control;"
"A:control->DMT/_2:control;A:control->DMT/_3:control;B->E:1;"
"C->E:2;D->E:3;D->F:1;DMT/_0->E:4;DMT/_1->E:5;DMT/_2->E:6;"
"DMT/_3->E:7;E->F");
}
// Rewrite test for _FusedConv2D Op with unsupported fusion
TEST_F(MklLayoutPassTest, NodeRewrite_FusedConv2D_Negative1) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'C' op: 'Input'}"
"node { name: 'D' op: '_FusedConv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'num_args' value { i: 1 } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'fused_ops' value { list: {s: 'Unsupported'} } }"
" attr { key: 'epsilon' value { f: 0.001 }}"
" input: ['A', 'B', 'C']}"
"node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['D', 'C'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(_FusedConv2D);E(Zeta)|A->D;"
"B->D:1;C->D:2;C->E:1;D->E");
}
// Rewrite test for _FusedConv2D Op with unsupported type
TEST_F(MklLayoutPassTest, NodeRewrite_FusedConv2D_Negative2) {
InitGraph(
"node { name: 'A' op: 'DoubleInput'}"
"node { name: 'B' op: 'DoubleInput'}"
"node { name: 'C' op: 'DoubleInput'}"
"node { name: 'D' op: '_FusedConv2D'"
" attr { key: 'T' value { type: DT_DOUBLE } }"
" attr { key: 'num_args' value { i: 1 } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'fused_ops' value { list: {s: 'BiasAdd'} } }"
" attr { key: 'epsilon' value { f: 0.001 }}"
" input: ['A', 'B', 'C']}"
"node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_DOUBLE } }"
" input: ['D', 'C'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(DoubleInput);B(DoubleInput);C(DoubleInput);"
"D(_FusedConv2D);E(Zeta)|A->D;B->D:1;C->D:2;C->E:1;D->E");
}
// Merge test for PadWithFusedConv2D Op with BiasAdd fusion
// Padding type is VALID
// A = input(image), B = input(paddings), C = Pad(A, B) = input of conv2D,
// D = input(filter), E = input(bias), F = _FusedConv2D(C, D, E)
// G = Zeta(F, E)
// After layout pass
// _MklPadWithFusedConv2D(A, D, E, B, DMT/_0, DMT/_1, DMT/_2, DMT/_3)
TEST_F(MklLayoutPassTest, NodeMerge_PadWithFusedConv2D_Positive1) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Int32Input'}"
"node { name: 'C' op: 'Pad'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'Tpaddings' value { type: DT_INT32 } }"
" input: ['A', 'B']}"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: 'Input'}"
"node { name: 'F' op: '_FusedConv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'num_args' value { i: 1 } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'VALID' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'fused_ops' value { list: {s: 'BiasAdd'} } }"
" attr { key: 'epsilon' value { f: 0.001 }}"
" input: ['C', 'D', 'E']}"
"node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['F', 'E'] }");
EXPECT_EQ(
DoMklLayoutOptimizationPass(),
"A(Input);B(Int32Input);D(Input);DMT/_0(Const);DMT/_1(Const);DMT/"
"_2(Const);DMT/_3(Const);E(Input);F(_MklPadWithFusedConv2D);"
"G(Zeta)|A->F;A:control->DMT/_0:control;A:control->DMT/_1:control;"
"A:control->DMT/_2:control;A:control->DMT/_3:control;B->F:3;D->F:1;DMT/"
"_0->F:4;DMT/_1->F:5;DMT/_2->F:6;DMT/_3->F:7;E->F:2;E->G:1;F->G");
}
// Merge test for PadWithFusedConv2D Op with BiasAdd+Relu fusion
// Padding type is VALID
// A = input(image), B = input(paddings), C = Pad(A, B) = input of conv2D,
// D = input(filter), E = input(bias), F = _FusedConv2D(C, D, E) (With relu)
// G = Zeta(F, E)
// After layout pass
// _MklPadWithFusedConv2D(A, D, E, B, DMT/_0, DMT/_1, DMT/_2, DMT/_3)
TEST_F(MklLayoutPassTest, NodeMerge_PadWithFusedConv2D_Positive2) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Int32Input'}"
"node { name: 'C' op: 'Pad'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'Tpaddings' value { type: DT_INT32 } }"
" input: ['A', 'B']}"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: 'Input'}"
"node { name: 'F' op: '_FusedConv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'num_args' value { i: 1 } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'VALID' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'fused_ops'"
" value { list: {s: 'BiasAdd', s: 'Relu'} } }"
" attr { key: 'epsilon' value { f: 0.001 }}"
" input: ['C', 'D', 'E']}"
"node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['F', 'E'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Int32Input);D(Input);DMT/_0(Const);DMT/_1(Const);DMT/"
"_2(Const);DMT/_3(Const);E(Input);F(_MklPadWithFusedConv2D);"
"G(Zeta)|A->F;A:control->DMT/_0:control;A:control->DMT/_1:control;"
"A:control->DMT/_2:control;A:control->DMT/_3:control;B->F:3;"
"D->F:1;DMT/_0->F:4;DMT/_1->F:5;DMT/_2->F:6;DMT/"
"_3->F:7;E->F:2;E->G:1;F->G");
}
// Merge test for PadWithFusedConv2D Op with unsupported fusion
// Padding type is VALID
// A = input(image), B = input(paddings), C = Pad(A, B) = input of conv2D,
// D = input(filter), E = input(bias),
// F = _FusedConv2D(C, D, E) (With Unsupported), G = Zeta(F, E)
// After layout pass - No merging
TEST_F(MklLayoutPassTest, NodeMerge_PadWithFusedConv2D_Negative1) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Int32Input'}"
"node { name: 'C' op: 'Pad'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'Tpaddings' value { type: DT_INT32 } }"
" input: ['A', 'B']}"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: 'Input'}"
"node { name: 'F' op: '_FusedConv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'num_args' value { i: 1 } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'VALID' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'fused_ops' value { list: {s: 'Unsupported'} } }"
" attr { key: 'epsilon' value { f: 0.001 }}"
" input: ['C', 'D', 'E']}"
"node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['F', 'E'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Int32Input);C(Pad);D(Input);E(Input);F(_FusedConv2D);G("
"Zeta)|A->C;B->C:1;C->F;D->F:1;E->F:2;E->G:1;F->G");
}
// Merge test for PadWithFusedConv2D Op with BiasAdd fusion
// Padding type is SAME
// A = input(image), B = input(paddings), C = Pad(A, B) = input of conv2D,
// D = input(filter), E = input(bias), F = _FusedConv2D(C,D,E)
// G = Zeta(F,E)
// After layout pass - No merging
TEST_F(MklLayoutPassTest, NodeMerge_PadWithFusedConv2D_Negative2) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Int32Input'}"
"node { name: 'C' op: 'Pad'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'Tpaddings' value { type: DT_INT32 } }"
" input: ['A', 'B']}"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: 'Input'}"
"node { name: 'F' op: '_FusedConv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'num_args' value { i: 1 } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'fused_ops' value { list: {s: 'BiasAdd'} } }"
" attr { key: 'epsilon' value { f: 0.001 }}"
" input: ['C', 'D', 'E']}"
"node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['F', 'E'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Int32Input);C(Pad);D(Input);DMT/_0(Const);DMT/"
"_1(Const);DMT/_2(Const);E(Input);F(_MklFusedConv2D);G(Zeta)|A->C;"
"B->C:1;C->F;C:control->DMT/_0:control;C:control->DMT/_1:control;"
"C:control->DMT/_2:control;D->F:1;DMT/_0->F:3;DMT/_1->F:4;DMT/"
"_2->F:5;E->F:2;E->G:1;F->G");
}
// Merge test for PadWithFusedConv2D Op with BiasAdd+Relu fusion
// Padding type is SAME
// A = input(image), B = input(paddings), C = Pad(A, B) = input of conv2D,
// D = input(filter), E = input(bias), F = _FusedConv2D(C,D,E)(With relu)
// G = Zeta(F,E)
// After layout pass - No merging
TEST_F(MklLayoutPassTest, NodeMerge_PadWithFusedConv2D_Negative3) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Int32Input'}"
"node { name: 'C' op: 'Pad'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'Tpaddings' value { type: DT_INT32 } }"
" input: ['A', 'B']}"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: 'Input'}"
"node { name: 'F' op: '_FusedConv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'num_args' value { i: 1 } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'fused_ops'"
" value { list: {s: 'BiasAdd', s: 'Relu'} } }"
" attr { key: 'epsilon' value { f: 0.001 }}"
" input: ['C', 'D', 'E']}"
"node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['F', 'E'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Int32Input);C(Pad);D(Input);DMT/_0(Const);DMT/"
"_1(Const);DMT/_2(Const);E(Input);F(_MklFusedConv2D);G(Zeta)|A->C;"
"B->C:1;C->F;C:control->DMT/_0:control;C:control->DMT/_1:control;"
"C:control->DMT/_2:control;D->F:1;DMT/_0->F:3;DMT/_1->F:4;DMT/"
"_2->F:5;E->F:2;E->G:1;F->G");
}
// Tests that there are no duplicate input control edges after merge.
// If both of the ops being merged have an input control edge from a common
// op, the merged op will have only one control edge from that common op.
// This test only adds the additional input-control-edge check on top of the
// previous test NodeMerge_PadWithFusedConv2D_Positive1.
// Padding type is VALID
// A = input(image), X = input, B = input(paddings),
// C = Pad(A, B) = input of conv2D,
// D = input(filter), E = input(bias), F = _FusedConv2D(C, D, E)
// G = Zeta(F, E)
// X:control->C:control
// X:control->F:control
// After layout pass:
// _MklPadWithFusedConv2D(A, D, E, B, DMT/_0, DMT/_1, DMT/_2, DMT/_3)
// X:control->F:control (only one control edge)
TEST_F(MklLayoutPassTest, Input_ControlEdge_PadWithFusedConv2D_Positive) {
DCHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
InitGraph(
"node { name: 'X' op: 'Input'}"
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Int32Input'}"
"node { name: 'C' op: 'Pad'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'Tpaddings' value { type: DT_INT32 } }"
" input: ['A', 'B']}"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: 'Input'}"
"node { name: 'F' op: '_FusedConv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'num_args' value { i: 1 } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'VALID' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'fused_ops' value { list: {s: 'BiasAdd'} } }"
" attr { key: 'epsilon' value { f: 0.001 }}"
" input: ['C', 'D', 'E']}"
"node { name: 'G' op: 'Zeta'"
" attr {key: 'T' value { type: DT_FLOAT } }"
" input: ['F', 'E']}");
Node* x = FindNode("X");
Node* c = FindNode("C");
Node* f = FindNode("F");
const Edge* edge = graph_.AddControlEdge(x, c);
const Edge* edge_1 = graph_.AddControlEdge(x, f);
ASSERT_NE(edge, nullptr);
ASSERT_NE(edge_1, nullptr);
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Int32Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
"DMT/_2(Const);DMT/_3(Const);E(Input);F(_MklPadWithFusedConv2D);"
"G(Zeta);X(Input)|A->F;A:control->DMT/_0:control;"
"A:control->DMT/_1:control;A:control->DMT/_2:control;"
"A:control->DMT/_3:control;B->F:3;D->F:1;DMT/_0->F:4;"
"DMT/_1->F:5;DMT/_2->F:6;DMT/_3->F:7;E->F:2;E->G:1;F->G;"
"X:control->F:control");
}
// Tests that there are no duplicate output control edges after merge.
// If both of the ops being merged have an output control edge to a common
// op, then after the merge, the merged op will have only one control edge
// to that common op. This test only adds the additional output-control-edge
// check on top of the previous test NodeMerge_PadWithFusedConv2D_Positive1.
// Padding type is VALID
// A = input(image), B = input(paddings), C = Pad(A, B) = input of conv2D,
// D = input(filter), E = input(bias), F = _FusedConv2D(C, D, E)
// G = Zeta(F, E), X = input
// C:control->X:control
// F:control->X:control
// After layout pass:
// _MklPadWithFusedConv2D(A, D, E, B, DMT/_0, DMT/_1, DMT/_2, DMT/_3)
// F:control->X:control (only one control edge)
TEST_F(MklLayoutPassTest, Output_ControlEdge_PadWithFusedConv2D_Positive) {
DCHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
InitGraph(
"node { name: 'X' op: 'Input'}"
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Int32Input'}"
"node { name: 'C' op: 'Pad'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'Tpaddings' value { type: DT_INT32 } }"
" input: ['A', 'B']}"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: 'Input'}"
"node { name: 'F' op: '_FusedConv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'num_args' value { i: 1 } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'VALID' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'fused_ops' value { list: {s: 'BiasAdd'} } }"
" attr { key: 'epsilon' value { f: 0.001 }}"
" input: ['C', 'D', 'E']}"
"node { name: 'G' op: 'Zeta'"
" attr {key: 'T' value { type: DT_FLOAT } }"
" input: ['F', 'E']}");
Node* x = FindNode("X");
Node* c = FindNode("C");
Node* f = FindNode("F");
const Edge* edge = graph_.AddControlEdge(c, x);
const Edge* edge_1 = graph_.AddControlEdge(f, x);
ASSERT_NE(edge, nullptr);
ASSERT_NE(edge_1, nullptr);
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Int32Input);D(Input);DMT/_0(Const);DMT/_1(Const);DMT/"
"_2(Const);DMT/_3(Const);E(Input);F(_MklPadWithFusedConv2D);"
"G(Zeta);X(Input)|A->F;A:control->DMT/_0:control;A:control->DMT/"
"_1:control;A:control->DMT/_2:control;A:control->DMT/"
"_3:control;B->F:3;D->F:1;DMT/_0->F:4;DMT/_1->F:5;DMT/_2->F:6;DMT/"
"_3->F:7;E->F:2;E->G:1;F->G;F:control->X:control");
}
// Pad + _FusedConv2D with VALID padding.
// The input node feeds both Pad and _FusedConv2D, and the outputs of both
// Pad and _FusedConv2D feed the same node (G as Output2).
// A = input(as image), B = input(as paddings), C = Pad(A, B)
// E = input(as bias), F = _FusedConv2D(C, A, E), G = Output2(C, F)
// After layout pass - No merging, since Pad and _FusedConv2D both
// feed the same node (G)
TEST_F(MklLayoutPassTest, NodeMerge_PadWithFusedConv2D_Common_InOutput) {
DCHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Int32Input'}"
"node { name: 'C' op: 'Pad'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'Tpaddings' value { type: DT_INT32 } }"
" input: ['A', 'B']}"
"node { name: 'E' op: 'Input'}"
"node { name: 'F' op: '_FusedConv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'num_args' value { i: 1 } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'VALID' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'fused_ops' value { list: {s: 'BiasAdd'} } }"
" attr { key: 'epsilon' value { f: 0.001 }}"
" input: ['C', 'A', 'E']}"
"node { name: 'G' op: 'Output2'"
" input: ['C', 'F']}");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Int32Input);C(Pad);DMT/_0(Const);DMT/_1(Const);DMT/"
"_2(Const);E(Input);F(_MklFusedConv2D);G(Output2)|A->C;A->F:1;B->C:"
"1;C->F;C->G;C:control->DMT/_0:control;C:control->DMT/"
"_1:control;C:control->DMT/_2:control;DMT/_0->F:3;DMT/_1->F:4;DMT/"
"_2->F:5;E->F:2;F->G:1");
}
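// Rewrite test for Conv2DBackpropFilter: the op is rewritten into
// _MklConv2DBackpropFilter, with dummy Mkl tensors (DMT) appended as
// metadata inputs and control edges added from its first input, A.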
TEST_F(MklLayoutPassTest, NodeRewrite_Conv2DGradFilter_Positive) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Int32Input'}"
"node { name: 'C' op: 'Input'}"
"node { name: 'D' op: 'Conv2DBackpropFilter'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A', 'B', 'C']}"
"node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'D'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Int32Input);C(Input);D(_MklConv2DBackpropFilter);"
"DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);E(Zeta)|"
"A->D;A->E;A:control->DMT/_0:control;A:control->DMT/_1:control;"
"A:control->DMT/_2:control;B->D:1;C->D:2;D->E:1;DMT/_0->D:3;"
"DMT/_1->D:4;DMT/_2->D:5");
}
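// Rewrite test for Conv2DBackpropInput: rewritten into
// _MklConv2DBackpropInput; here the control edges to the dummy Mkl tensors
// originate from B, the op's first input.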
TEST_F(MklLayoutPassTest, NodeRewrite_Conv2DGradInput_Positive) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Int32Input'}"
"node { name: 'C' op: 'Input'}"
"node { name: 'D' op: 'Conv2DBackpropInput'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['B', 'A', 'C']}"
"node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'D'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Int32Input);C(Input);D(_MklConv2DBackpropInput);"
"DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);E(Zeta)|"
"A->D:1;A->E;B->D;B:control->DMT/_0:control;"
"B:control->DMT/_1:control;B:control->DMT/_2:control;C->D:2;"
"D->E:1;DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5");
}
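// Rewrite test for DepthwiseConv2dNativeBackpropFilter into its _Mkl
// counterpart, analogous to the Conv2DBackpropFilter test above.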
TEST_F(MklLayoutPassTest,
NodeRewrite_DepthwiseConv2dNativeGradFilter_Positive) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Int32Input'}"
"node { name: 'C' op: 'Input'}"
"node { name: 'D' op: 'DepthwiseConv2dNativeBackpropFilter'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A', 'B', 'C']}"
"node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'D'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Int32Input);C(Input);D(_"
"MklDepthwiseConv2dNativeBackpropFilter);"
"DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);E(Zeta)|"
"A->D;A->E;A:control->DMT/_0:control;A:control->DMT/_1:control;"
"A:control->DMT/_2:control;B->D:1;C->D:2;D->E:1;DMT/_0->D:3;"
"DMT/_1->D:4;DMT/_2->D:5");
}
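// Rewrite test for DepthwiseConv2dNativeBackpropInput into its _Mkl
// counterpart, analogous to the Conv2DBackpropInput test above.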
TEST_F(MklLayoutPassTest, NodeRewrite_DepthwiseConv2dNativeGradInput_Positive) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Int32Input'}"
"node { name: 'C' op: 'Input'}"
"node { name: 'D' op: 'DepthwiseConv2dNativeBackpropInput'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['B', 'A', 'C']}"
"node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'D'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Int32Input);C(Input);D(_"
"MklDepthwiseConv2dNativeBackpropInput);"
"DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);E(Zeta)|"
"A->D:1;A->E;B->D;B:control->DMT/_0:control;"
"B:control->DMT/_1:control;B:control->DMT/_2:control;C->D:2;"
"D->E:1;DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5");
}
// Check that we never rewrite BiasAddGrad.
TEST_F(MklLayoutPassTest, NodeRewrite_BiasAddGrad_Positive) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'C' op: 'Polygamma'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B']}"
"node { name: 'D' op: 'Zeta'"
" attr {key: 'T' value { type: DT_FLOAT } }"
" input: ['C', 'A']}"
"node { name: 'E' op: 'BiasAddGrad'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" input: ['D'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Polygamma);D(Zeta);E(BiasAddGrad)|"
"A->C;A->D:1;B->C:1;C->D;D->E");
}
// Check that we never rewrite BiasAddGrad.
TEST_F(MklLayoutPassTest, NodeRewrite_BiasAddGrad_Positive1) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'C' op: 'MatMul'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'transpose_a' value { b: false } }"
" attr { key: 'transpose_b' value { b: false } }"
" input: ['A', 'B']}"
"node { name: 'D' op: 'Zeta'"
" attr {key: 'T' value { type: DT_FLOAT } }"
" input: ['C', 'A']}"
"node { name: 'E' op: 'BiasAddGrad'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" input: ['D'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(_MklMatMul);D(Zeta);E(BiasAddGrad)"
"|A->C;A->D:1;B->C:1;C->D;D->E");
}
// Check that we never rewrite BiasAddGrad.
TEST_F(MklLayoutPassTest, NodeRewrite_BiasAddGrad_Positive2) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'M' op: '_MklInput'}"
"node { name: 'N' op: '_MklInput'}"
"node { name: 'C' op: '_MklConv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A', 'B', 'M', 'N']}"
"node { name: 'D' op: 'Zeta'"
" attr {key: 'T' value { type: DT_FLOAT } }"
" input: ['C', 'A']}"
"node { name: 'E' op: 'BiasAddGrad'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" input: ['D'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(_MklConv2D);D(Zeta);E(BiasAddGrad);"
"M(_MklInput);N(_MklInput)|A->C;A->D:1;B->C:1;C->D;D->E;"
"M->C:2;N->C:3");
}
// Concat Op test: Concat with no Mkl layer feeding it
TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Basic) {
InitGraph(
"node { name: 'A' op: 'Const' "
" attr { key: 'dtype' value { type: DT_INT32 } }"
" attr { key: 'value' value { "
" tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
" int_val: 0 } } } }"
"node { name: 'B' op: 'InputList'"
" attr { key: 'N' value { i: 2 } }}"
"node { name: 'C' op: 'Input'}"
"node { name: 'D' op: 'Concat'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'N' value { i: 2 } }"
" input: ['A', 'B:0', 'B:1']}"
"node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['C', 'D'] }");
EXPECT_EQ(
DoMklLayoutOptimizationPass(),
"A(Const);B(InputList);C(Input);D(_MklConcat);DMT/_0(Const);"
"DMT/_1(Const);DMT/_2(Const);E(Zeta)|A->D;A:control->DMT/_0:control;"
"A:control->DMT/_1:control;A:control->DMT/_2:control;B->D:1;"
"B:1->D:2;C->E;D->E:1;DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5");
}
// Concat with 2 Mkl layers feeding it
TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Input_Mkl) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'C' op: 'Input'}"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: 'Conv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A', 'B']}"
"node { name: 'F' op: 'Conv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['C', 'D']}"
"node { name: 'G' op: 'Const' "
" attr { key: 'dtype' value { type: DT_INT32 } }"
" attr { key: 'value' value { "
" tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
" int_val: 0 } } } }"
"node { name: 'H' op: 'Concat'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'N' value { i: 2 } }"
" input: ['G', 'E', 'F']}"
"node { name: 'I' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'H'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
"DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(_MklConv2D);"
"F(_MklConv2D);G(Const);H(_MklConcat);I(Zeta)|A->E;A->I;"
"A:control->DMT/_0:control;A:control->DMT/_1:control;"
"B->E:1;C->F;C:control->DMT/_2:control;C:control->DMT/_3:control;"
"D->F:1;DMT/_0->E:2;DMT/_1->E:3;DMT/_2->F:2;DMT/_3->F:3;"
"DMT/_4->H:3;E->H:1;E:2->H:4;F->H:2;F:2->H:5;G->H;"
"G:control->DMT/_4:control;H->I:1");
}
// Concat with 1 Mkl and 1 non-Mkl layer feeding it
TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Input_MixedMkl) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'C' op: 'Input'}"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: 'Conv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A', 'B']}"
"node { name: 'F' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['C', 'D']}"
"node { name: 'G' op: 'Const' "
" attr { key: 'dtype' value { type: DT_INT32 } }"
" attr { key: 'value' value { "
" tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
" int_val: 0 } } } }"
"node { name: 'H' op: 'Concat'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'N' value { i: 2 } }"
" input: ['G', 'E', 'F']}"
"node { name: 'I' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'H'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
"DMT/_2(Const);DMT/_3(Const);E(_MklConv2D);F(Zeta);G(Const);"
"H(_MklConcat);I(Zeta)|A->E;A->I;A:control->DMT/_0:control;"
"A:control->DMT/_1:control;B->E:1;C->F;D->F:1;DMT/_0->E:2;"
"DMT/_1->E:3;DMT/_2->H:3;DMT/_3->H:5;E->H:1;E:2->H:4;F->H:2;"
"G->H;G:control->DMT/_2:control;G:control->DMT/_3:control;H->I:1");
}
// ConcatV2 Op test: ConcatV2 with no Mkl layer feeding it
TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_Basic) {
InitGraph(
"node { name: 'A' op: 'Const' "
" attr { key: 'dtype' value { type: DT_INT32 } }"
" attr { key: 'value' value { "
" tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
" int_val: 0 } } } }"
"node { name: 'B' op: 'InputList'"
" attr { key: 'N' value { i: 2 } }}"
"node { name: 'C' op: 'Input'}"
"node { name: 'D' op: 'ConcatV2'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'Tidx' value { type: DT_INT32 } }"
" attr { key: 'N' value { i: 2 } }"
" input: ['B:0', 'B:1', 'A']}"
"node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['C', 'D'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Const);B(InputList);C(Input);D(_MklConcatV2);DMT/_0(Const);"
"DMT/_1(Const);DMT/_2(Const);E(Zeta)|A->D:2;B->D;B:1->D:1;"
"B:control->DMT/_0:control;B:control->DMT/_1:control;"
"B:control->DMT/_2:control;C->E;D->E:1;DMT/_0->D:3;"
"DMT/_1->D:4;DMT/_2->D:5");
}
// ConcatV2 with 2 Mkl layers feeding it
TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_Input_Mkl) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'C' op: 'Input'}"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: 'Conv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A', 'B']}"
"node { name: 'F' op: 'Conv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['C', 'D']}"
"node { name: 'G' op: 'Const' "
" attr { key: 'dtype' value { type: DT_INT32 } }"
" attr { key: 'value' value { "
" tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
" int_val: 0 } } } }"
"node { name: 'H' op: 'ConcatV2'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'Tidx' value { type: DT_INT32 } }"
" attr { key: 'N' value { i: 2 } }"
" input: ['E', 'F', 'G']}"
"node { name: 'I' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'H'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
"DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(_MklConv2D);"
"F(_MklConv2D);G(Const);H(_MklConcatV2);I(Zeta)|A->E;A->I;"
"A:control->DMT/_0:control;A:control->DMT/_1:control;B->E:1;C->F;"
"C:control->DMT/_2:control;C:control->DMT/_3:control;"
"D->F:1;DMT/_0->E:2;DMT/_1->E:3;DMT/_2->F:2;DMT/_3->F:3;"
"DMT/_4->H:5;E->H;E:2->H:3;E:control->DMT/_4:control;F->H:1;"
"F:2->H:4;G->H:2;H->I:1");
}
// ConcatV2 with 1 Mkl and 1 non-Mkl layer feeding it
TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_Input_MixedMkl) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'C' op: 'Input'}"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: 'Conv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A', 'B']}"
"node { name: 'F' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['C', 'D']}"
"node { name: 'G' op: 'Const' "
" attr { key: 'dtype' value { type: DT_INT32 } }"
" attr { key: 'value' value { "
" tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
" int_val: 0 } } } }"
"node { name: 'H' op: 'ConcatV2'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'Tidx' value { type: DT_INT32 } }"
" attr { key: 'N' value { i: 2 } }"
" input: ['E', 'F', 'G']}"
"node { name: 'I' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'H'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
"DMT/_2(Const);DMT/_3(Const);E(_MklConv2D);F(Zeta);G(Const);"
"H(_MklConcatV2);I(Zeta)|A->E;A->I;A:control->DMT/_0:control;"
"A:control->DMT/_1:control;B->E:1;C->F;D->F:1;DMT/_0->E:2;"
"DMT/_1->E:3;DMT/_2->H:4;DMT/_3->H:5;E->H;E:2->H:3;"
"E:control->DMT/_2:control;E:control->DMT/_3:control;F->H:1;"
"G->H:2;H->I:1");
}
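// Rewrite tests for element-wise ops. Relu is rewritten into _MklRelu,
// which gets a single dummy Mkl tensor as its metadata input (DMT/_0->B:1).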
TEST_F(MklLayoutPassTest, NodeRewrite_Relu_Positive) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Relu'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A'] }"
"node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(_MklRelu);C(Zeta);DMT/_0(Const)|A->B;A->C;"
"A:control->DMT/_0:control;B->C:1;DMT/_0->B:1");
}
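// ReluGrad has two data inputs, so the rewritten _MklReluGrad gets two
// dummy Mkl tensor metadata inputs (DMT/_0->C:2 and DMT/_1->C:3).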
TEST_F(MklLayoutPassTest, NodeRewrite_ReluGrad_Positive) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'C' op: 'ReluGrad'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }"
"node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'C'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(_MklReluGrad);D(Zeta);DMT/_0(Const);"
"DMT/_1(Const)|A->C;A->D;A:control->DMT/_0:control;"
"A:control->DMT/_1:control;B->C:1;C->D:1;DMT/_0->C:2;DMT/_1->C:3");
}
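// When Relu feeds ReluGrad, the grad op consumes Relu's Mkl metadata
// output directly (B:1->C:3), so only one dummy tensor is created for C.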
TEST_F(MklLayoutPassTest, NodeRewrite_ReluReluGrad_Positive) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Relu'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A'] }"
"node { name: 'C' op: 'ReluGrad'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }"
"node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'C'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(_MklRelu);C(_MklReluGrad);D(Zeta);DMT/_0(Const);"
"DMT/_1(Const)|A->B;A->C;A->D;A:control->DMT/_0:control;"
"A:control->DMT/_1:control;B->C:1;B:1->C:3;C->D:1;DMT/_0->B:1;"
"DMT/_1->C:2");
}
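// The Relu6 tests below mirror the Relu/ReluGrad tests above.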
TEST_F(MklLayoutPassTest, NodeRewrite_Relu6_Positive) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Relu6'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A'] }"
"node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(_MklRelu6);C(Zeta);DMT/_0(Const)|A->B;A->C;"
"A:control->DMT/_0:control;B->C:1;DMT/_0->B:1");
}
TEST_F(MklLayoutPassTest, NodeRewrite_Relu6Grad_Positive) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'C' op: 'Relu6Grad'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }"
"node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'C'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(_MklRelu6Grad);D(Zeta);DMT/_0(Const);"
"DMT/_1(Const)|A->C;A->D;A:control->DMT/_0:control;"
"A:control->DMT/_1:control;B->C:1;C->D:1;DMT/_0->C:2;DMT/_1->C:3");
}
TEST_F(MklLayoutPassTest, NodeRewrite_Relu6Relu6Grad_Positive) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Relu6'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A'] }"
"node { name: 'C' op: 'Relu6Grad'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }"
"node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'C'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(_MklRelu6);C(_MklRelu6Grad);D(Zeta);DMT/_0(Const);"
"DMT/_1(Const)|A->B;A->C;A->D;A:control->DMT/_0:control;"
"A:control->DMT/_1:control;B->C:1;B:1->C:3;C->D:1;DMT/_0->B:1;"
"DMT/_1->C:2");
}
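// LeakyRelu tests: whether the op is rewritten depends on the value of
// the 'alpha' attribute, as the positive/negative cases below show.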
TEST_F(MklLayoutPassTest, NodeRewrite_LeakyRelu_Positive) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'LeakyRelu'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'alpha' value { f: 0.1 } }"
" input: ['A'] }"
"node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(_MklLeakyRelu);C(Zeta);DMT/_0(Const)|A->B;A->C;"
"A:control->DMT/_0:control;B->C:1;DMT/_0->B:1");
}
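// With alpha = 2.0 the node is left unchanged, which suggests the pass
// rewrites LeakyRelu only for alpha values no greater than 1 (inferred
// from the positive case with alpha = 0.1 above).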
TEST_F(MklLayoutPassTest, NodeRewrite_LeakyRelu_Negative) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'LeakyRelu'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'alpha' value { f: 2.0 } }"
" input: ['A'] }"
"node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(LeakyRelu);C(Zeta)|A->B;A->C;B->C:1");
}
TEST_F(MklLayoutPassTest, NodeRewrite_LeakyReluGrad_Positive) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'C' op: 'LeakyReluGrad'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'alpha' value { f: 0.1 } }"
" input: ['A', 'B'] }"
"node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'C'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(_MklLeakyReluGrad);D(Zeta);DMT/_0(Const);"
"DMT/_1(Const)|A->C;A->D;A:control->DMT/_0:control;"
"A:control->DMT/_1:control;B->C:1;C->D:1;DMT/_0->C:2;DMT/_1->C:3");
}
TEST_F(MklLayoutPassTest, NodeRewrite_LeakyReluGrad_Negative) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'C' op: 'LeakyReluGrad'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'alpha' value { f: 2.0 } }"
" input: ['A', 'B'] }"
"node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'C'] }");
EXPECT_EQ(
DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(LeakyReluGrad);D(Zeta)|A->C;A->D;B->C:1;C->D:1");
}
TEST_F(MklLayoutPassTest, NodeRewrite_LeakyReluLeakyReluGrad_Positive) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'LeakyRelu'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'alpha' value { f: 0.1 } }"
" input: ['A'] }"
"node { name: 'C' op: 'LeakyReluGrad'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'alpha' value { f: 0.1 } }"
" input: ['A', 'B'] }"
"node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'C'] }");
EXPECT_EQ(
DoMklLayoutOptimizationPass(),
"A(Input);B(_MklLeakyRelu);C(_MklLeakyReluGrad);D(Zeta);DMT/_0(Const);"
"DMT/_1(Const)|A->B;A->C;A->D;A:control->DMT/_0:control;"
"A:control->DMT/_1:control;B->C:1;B:1->C:3;C->D:1;DMT/_0->B:1;"
"DMT/_1->C:2");
}
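// Rewrite test for AvgPool into _MklAvgPool with one dummy Mkl tensor
// metadata input.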
TEST_F(MklLayoutPassTest, NodeRewrite_AvgPool_Positive) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'AvgPool'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }"
" attr { key: 'padding' value { s: 'VALID' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }"
" input: ['A'] }"
"node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(_MklAvgPool);C(Zeta);DMT/_0(Const)|A->B;A->C;"
"A:control->DMT/_0:control;B->C:1;DMT/_0->B:1");
}
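// Rewrite test for AvgPoolGrad into _MklAvgPoolGrad with two dummy Mkl
// tensor metadata inputs.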
TEST_F(MklLayoutPassTest, NodeRewrite_AvgPoolGrad_Positive) {
InitGraph(
"node { name: 'A' op: 'Int32Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'C' op: 'AvgPoolGrad' "
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }"
" attr { key: 'padding' value { s: 'VALID' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }"
" input: ['A', 'B'] }"
"node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['B', 'C'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Int32Input);B(Input);C(_MklAvgPoolGrad);D(Zeta);DMT/_0(Const);"
"DMT/_1(Const)|A->C;A:control->DMT/_0:control;"
"A:control->DMT/_1:control;B->C:1;B->D;C->D:1;DMT/_0->C:2;"
"DMT/_1->C:3");
}
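// AvgPool feeding AvgPoolGrad: the grad op reuses AvgPool's Mkl metadata
// output (B:1->C:3) instead of a second dummy tensor.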
TEST_F(MklLayoutPassTest, NodeRewrite_AvgPoolAvgPoolGrad_Positive) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'I' op: 'Int32Input'}"
"node { name: 'B' op: 'AvgPool'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }"
" attr { key: 'padding' value { s: 'VALID' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }"
" input: ['A'] }"
"node { name: 'C' op: 'AvgPoolGrad' "
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }"
" attr { key: 'padding' value { s: 'VALID' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }"
" input: ['I', 'B'] }"
"node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'C'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(_MklAvgPool);C(_MklAvgPoolGrad);D(Zeta);DMT/_0(Const);"
"DMT/_1(Const);I(Int32Input)|A->B;A->D;A:control->DMT/_0:control;"
"B->C:1;B:1->C:3;C->D:1;DMT/_0->B:1;DMT/_1->C:2;I->C;"
"I:control->DMT/_1:control");
}
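// Rewrite test for FusedBatchNormGrad: rewritten into
// _MklFusedBatchNormGrad; each of the five data inputs gets a dummy Mkl
// tensor metadata input (DMT/_0 through DMT/_4).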
TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNormGrad_Positive) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'C' op: 'Input'}"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: 'Input'}"
"node { name: 'F' op: 'FusedBatchNormGrad'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'epsilon' value { f: 0.0001 } }"
" attr { key: 'is_training' value { b: true } }"
" input: ['A', 'B', 'C', 'D', 'E'] }"
"node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'F'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
"DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(Input);"
"F(_MklFusedBatchNormGrad);G(Zeta)|A->F;A->G;"
"A:control->DMT/_0:control;A:control->DMT/_1:control;"
"A:control->DMT/_2:control;A:control->DMT/_3:control;"
"A:control->DMT/_4:control;B->F:1;C->F:2;D->F:3;"
"DMT/_0->F:5;DMT/_1->F:6;DMT/_2->F:7;DMT/_3->F:8;DMT/_4->F:9;"
"E->F:4;F->G:1");
}
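// FusedBatchNormGradV2 is rewritten the same way when the T, U type
// combination is supported; see the negative test below.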
TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNormGradV2_Positive) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'C' op: 'Input'}"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: 'Input'}"
"node { name: 'F' op: 'FusedBatchNormGradV2'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'U' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'epsilon' value { f: 0.0001 } }"
" attr { key: 'is_training' value { b: true } }"
" input: ['A', 'B', 'C', 'D', 'E'] }"
"node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'F'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
"DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(Input);"
"F(_MklFusedBatchNormGradV2);G(Zeta)|A->F;A->G;"
"A:control->DMT/_0:control;A:control->DMT/_1:control;"
"A:control->DMT/_2:control;A:control->DMT/_3:control;"
"A:control->DMT/_4:control;B->F:1;C->F:2;D->F:3;"
"DMT/_0->F:5;DMT/_1->F:6;DMT/_2->F:7;DMT/_3->F:8;DMT/_4->F:9;"
"E->F:4;F->G:1");
}
// The T, U type combination is not supported by MKL. The node will not be
// rewritten into an MKL node.
TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNormGradV2_Negative) {
InitGraph(
"node { name: 'A' op: 'HalfInput'}"
"node { name: 'B' op: 'HalfInput'}"
"node { name: 'C' op: 'Input'}"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: 'Input'}"
"node { name: 'F' op: 'FusedBatchNormGradV2'"
" attr { key: 'T' value { type: DT_HALF } }"
" attr { key: 'U' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'epsilon' value { f: 0.0001 } }"
" attr { key: 'is_training' value { b: true } }"
" input: ['A', 'B', 'C', 'D', 'E'] }"
"node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: DT_HALF } }"
" input: ['A', 'F'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(HalfInput);B(HalfInput);C(Input);D(Input);E(Input);"
"F(FusedBatchNormGradV2);G(Zeta)|A->F;A->G;"
"B->F:1;C->F:2;D->F:3;E->F:4;F->G:1");
}
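// Rewrite tests for the forward FusedBatchNorm ops. V1, V2, and V3 follow
// the same pattern as the gradient tests above; unsupported T, U type
// combinations (the negative tests) are left unchanged.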
TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNorm_Positive) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'C' op: 'Input'}"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: 'Input'}"
"node { name: 'F' op: 'FusedBatchNorm'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'epsilon' value { f: 0.0001 } }"
" attr { key: 'is_training' value { b: true } }"
" input: ['A', 'B', 'C', 'D', 'E'] }"
"node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'F'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
"DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(Input);"
"F(_MklFusedBatchNorm);G(Zeta)|A->F;A->G;"
"A:control->DMT/_0:control;A:control->DMT/_1:control;"
"A:control->DMT/_2:control;A:control->DMT/_3:control;"
"A:control->DMT/_4:control;B->F:1;C->F:2;D->F:3;"
"DMT/_0->F:5;DMT/_1->F:6;DMT/_2->F:7;DMT/_3->F:8;DMT/_4->F:9;"
"E->F:4;F->G:1");
}
TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNormV2_Positive) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'C' op: 'Input'}"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: 'Input'}"
"node { name: 'F' op: 'FusedBatchNormV2'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'U' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'epsilon' value { f: 0.0001 } }"
" attr { key: 'is_training' value { b: true } }"
" input: ['A', 'B', 'C', 'D', 'E'] }"
"node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'F'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
"DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(Input);"
"F(_MklFusedBatchNormV2);G(Zeta)|A->F;A->G;"
"A:control->DMT/_0:control;A:control->DMT/_1:control;"
"A:control->DMT/_2:control;A:control->DMT/_3:control;"
"A:control->DMT/_4:control;B->F:1;C->F:2;D->F:3;"
"DMT/_0->F:5;DMT/_1->F:6;DMT/_2->F:7;DMT/_3->F:8;DMT/_4->F:9;"
"E->F:4;F->G:1");
}
// The T, U type combination is not supported by MKL. The node will not be
// rewritten into an MKL node.
TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNormV2_Negative) {
InitGraph(
"node { name: 'A' op: 'HalfInput'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'C' op: 'Input'}"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: 'Input'}"
"node { name: 'F' op: 'FusedBatchNormV2'"
" attr { key: 'T' value { type: DT_HALF } }"
" attr { key: 'U' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'epsilon' value { f: 0.0001 } }"
" attr { key: 'is_training' value { b: true } }"
" input: ['A', 'B', 'C', 'D', 'E'] }"
"node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: DT_HALF } }"
" input: ['A', 'F'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(HalfInput);B(Input);C(Input);D(Input);E(Input);"
"F(FusedBatchNormV2);G(Zeta)|A->F;A->G;"
"B->F:1;C->F:2;D->F:3;E->F:4;F->G:1");
}
TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNormV3_Positive) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'C' op: 'Input'}"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: 'Input'}"
"node { name: 'F' op: 'FusedBatchNormV3'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'U' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'epsilon' value { f: 0.0001 } }"
" attr { key: 'is_training' value { b: true } }"
" input: ['A', 'B', 'C', 'D', 'E'] }"
"node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'F'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
"DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(Input);"
"F(_MklFusedBatchNormV3);G(Zeta)|A->F;A->G;"
"A:control->DMT/_0:control;A:control->DMT/_1:control;"
"A:control->DMT/_2:control;A:control->DMT/_3:control;"
"A:control->DMT/_4:control;B->F:1;C->F:2;D->F:3;"
"DMT/_0->F:5;DMT/_1->F:6;DMT/_2->F:7;DMT/_3->F:8;DMT/_4->F:9;"
"E->F:4;F->G:1");
}
TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNormV3_Negative) {
InitGraph(
"node { name: 'A' op: 'HalfInput'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'C' op: 'Input'}"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: 'Input'}"
"node { name: 'F' op: 'FusedBatchNormV3'"
" attr { key: 'T' value { type: DT_HALF } }"
" attr { key: 'U' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'epsilon' value { f: 0.0001 } }"
" attr { key: 'is_training' value { b: true } }"
" input: ['A', 'B', 'C', 'D', 'E'] }"
"node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: DT_HALF } }"
" input: ['A', 'F'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(HalfInput);B(Input);C(Input);D(Input);E(Input);"
"F(FusedBatchNormV3);G(Zeta)|A->F;A->G;"
"B->F:1;C->F:2;D->F:3;E->F:4;F->G:1");
}
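// Rewrite test for QuantizedDepthwiseConv2D: rewritten into
// _MklQuantizedDepthwiseConv2D, with one dummy Mkl tensor metadata input
// per data input (DMT/_0 through DMT/_5).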
TEST_F(MklLayoutPassTest, NodeRewrite_QuantizedDepthwiseConv2D_Positive) {
InitGraph(
"node { name: 'A' op: 'QuantizedUnsignedInt8Input'}"
"node { name: 'B' op: 'QuantizedSignedInt8Input'}"
"node { name: 'C' op: 'Input'}"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: 'Input'}"
"node { name: 'F' op: 'Input'}"
"node { name: 'G' op: 'QuantizedSignedInt32Input'}"
"node { name: 'H' op: 'QuantizedDepthwiseConv2D'"
" attr { key: 'Tinput' value { type: DT_QUINT8 } }"
" attr { key: 'Tfilter' value { type: DT_QINT8 } }"
" attr { key: 'out_type' value { type: DT_QINT32 } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A', 'B', 'C', 'D', 'E', 'F'] }"
"node { name: 'I' op: 'Zeta' attr { key: 'T' value { type: DT_QINT32 } }"
" input: ['G', 'H'] }");
EXPECT_EQ(
DoMklLayoutOptimizationPass(),
"A(QuantizedUnsignedInt8Input);B(QuantizedSignedInt8Input);C(Input);"
"D(Input);DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);DMT/_3(Const);"
"DMT/_4(Const);DMT/_5(Const);E(Input);F(Input);"
"G(QuantizedSignedInt32Input);H(_MklQuantizedDepthwiseConv2D);I(Zeta)"
"|A->H;A:control->DMT/_0:control;A:control->DMT/_1:control;"
"A:control->DMT/_2:control;A:control->DMT/_3:control;"
"A:control->DMT/_4:control;A:control->DMT/_5:control;B->H:1;C->H:2;"
"D->H:3;DMT/_0->H:6;DMT/_1->H:7;DMT/_2->H:8;DMT/_3->H:9;DMT/_4->H:10;"
"DMT/_5->H:11;E->H:4;F->H:5;G->I;H->I:1");
}
/////////////////////////////////////////////////////////////////////
// Unit tests related to context-based node rewrite
/////////////////////////////////////////////////////////////////////
// If any of the inputs is an Mkl op, then Slice is rewritten into an Mkl op.
TEST_F(MklLayoutPassTest, NodeRewrite_Ctxbased_Slice_Positive) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'M' op: '_MklInput'}"
"node { name: 'N' op: '_MklInput'}"
"node { name: 'C' op: '_MklConv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" input: ['A', 'B', 'M', 'N']}"
"node { name: 'D' op: 'Int32Input'}"
"node { name: 'E' op: 'Int32Input'}"
"node { name: 'F' op: 'Slice'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'Index' value { type: DT_INT32 } }"
" input: ['C', 'D', 'E'] }"
"node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'C'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(_MklConv2D);D(Int32Input);"
"DMT/_0(Const);DMT/_1(Const);"
"E(Int32Input);F(_MklSlice);G(Zeta);M(_MklInput);N(_MklInput)|"
"A->C;A->G;B->C:1;C->F;C->G:1;C:2->F:3;"
"C:control->DMT/_0:control;C:control->DMT/"
"_1:control;"
"D->F:1;DMT/_0->F:4;DMT/_1->F:5;"
"E->F:2;M->C:2;N->C:3");
}
// If none of the inputs is an MKL op, then Slice should not be rewritten.
TEST_F(MklLayoutPassTest, NodeRewrite_Ctxbased_Slice_Negative) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Int32Input'}"
"node { name: 'C' op: 'Int32Input'}"
"node { name: 'D' op: 'Slice'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'Index' value { type: DT_INT32 } }"
" input: ['A', 'B', 'C'] }"
"node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'D'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Int32Input);C(Int32Input);"
"D(Slice);E(Zeta)|A->D;A->E;B->D:1;C->D:2;D->E:1");
}
/////////////////////////////////////////////////////////////////////
// Unit tests related to rewriting node for workspace edges
/////////////////////////////////////////////////////////////////////
/* Test LRN->MaxPool->MaxPoolGrad->LRNGrad replacement by workspace nodes. */
TEST_F(MklLayoutPassTest, MaxPoolLRN_Positive) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'LRN'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'alpha' value { f: 0.001 } }"
" attr { key: 'beta' value { f: 0.75 } }"
" attr { key: 'bias' value { f: 1.0 } }"
" attr { key: 'depth_radius' value { i: 2 } }"
" input: ['A'] }"
"node { name: 'C' op: 'MaxPool'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }"
" attr { key: 'padding' value { s: 'VALID' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }"
" input: ['B'] }"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: 'MaxPoolGrad'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }"
" attr { key: 'padding' value { s: 'VALID' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }"
" input: ['B', 'C', 'D'] }"
"node { name: 'F' op: 'Input'}"
"node { name: 'G' op: 'LRNGrad'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'alpha' value { f: 0.001 } }"
" attr { key: 'beta' value { f: 0.75 } }"
" attr { key: 'bias' value { f: 1.0 } }"
" attr { key: 'depth_radius' value { i: 2 } }"
" input: ['E', 'F', 'B'] }"
"node { name: 'H' op: 'Input'}"
"node { name: 'I' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['H', 'G'] }");
EXPECT_EQ(
DoMklLayoutOptimizationPass(),
"A(Input);B(_MklLRN);C(_MklMaxPool);D(Input);DMT/_0(Const);DMT/_1(Const);"
"DMT/_2(Const);E(_MklMaxPoolGrad);F(Input);G(_MklLRNGrad);H(Input);"
"I(Zeta)|A->B;A:control->DMT/_0:control;B->C;B->E;B->G:2;B:1->G:3;"
"B:2->C:1;B:2->E:4;B:2->G:6;B:3->G:7;B:control->DMT/_1:control;C->E:1;"
"C:1->E:3;C:2->E:5;C:3->E:7;D->E:2;DMT/_0->B:1;DMT/_1->E:6;DMT/_2->G:5;"
"E->G;E:1->G:4;E:control->DMT/_2:control;F->G:1;G->I:1;H->I");
}
/* Test LRN->LRNGrad replacement by workspace nodes. */
TEST_F(MklLayoutPassTest, LRN_Positive) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'LRN'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'alpha' value { f: 0.001 } }"
" attr { key: 'beta' value { f: 0.75 } }"
" attr { key: 'bias' value { f: 1.0 } }"
" attr { key: 'depth_radius' value { i: 2 } }"
" input: ['A'] }"
"node { name: 'C' op: 'Input'}"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: 'LRNGrad'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'alpha' value { f: 0.001 } }"
" attr { key: 'beta' value { f: 0.75 } }"
" attr { key: 'bias' value { f: 1.0 } }"
" attr { key: 'depth_radius' value { i: 2 } }"
" input: ['C', 'D', 'B'] }"
"node { name: 'F' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['C', 'E'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(_MklLRN);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
"DMT/_2(Const);E(_MklLRNGrad);F(Zeta)|"
"A->B;A:control->DMT/_0:control;B->E:2;B:1->E:3;B:2->E:6;B:3->E:7;"
"C->E;C->F;C:control->DMT/_1:control;C:control->DMT/_2:control;"
"D->E:1;DMT/_0->B:1;DMT/_1->E:4;DMT/_2->E:5;E->F:1");
}
/* Test LRN->LRNGrad replacement when only one of them is present. */
TEST_F(MklLayoutPassTest, LRN_Negative1) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'LRN'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'alpha' value { f: 0.001 } }"
" attr { key: 'beta' value { f: 0.75 } }"
" attr { key: 'bias' value { f: 1.0 } }"
" attr { key: 'depth_radius' value { i: 2 } }"
" input: ['A'] }"
"node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(_MklLRN);C(Zeta);DMT/_0(Const)|"
"A->B;A->C;A:control->DMT/_0:control;B->C:1;DMT/_0->B:1");
}
/* Test LRN->LRNGrad replacement when only one of them is present. */
TEST_F(MklLayoutPassTest, LRN_Negative2) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'C' op: 'Input'}"
"node { name: 'D' op: 'LRNGrad'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'alpha' value { f: 0.001 } }"
" attr { key: 'beta' value { f: 0.75 } }"
" attr { key: 'bias' value { f: 1.0 } }"
" attr { key: 'depth_radius' value { i: 2 } }"
" input: ['A', 'B', 'C'] }"
"node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'D'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(LRNGrad);"
"E(Zeta)|A->D;A->E;B->D:1;C->D:2;D->E:1");
}
/* Test LRN->LRNGrad negative case, where a single LRN feeds
   2 LRNGrad nodes at different slots. */
TEST_F(MklLayoutPassTest, LRN_Negative3) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'LRN'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'alpha' value { f: 0.001 } }"
" attr { key: 'beta' value { f: 0.75 } }"
" attr { key: 'bias' value { f: 1.0 } }"
" attr { key: 'depth_radius' value { i: 2 } }"
" input: ['A'] }"
"node { name: 'C' op: 'Input'}"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: 'LRNGrad'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'alpha' value { f: 0.001 } }"
" attr { key: 'beta' value { f: 0.75 } }"
" attr { key: 'bias' value { f: 1.0 } }"
" attr { key: 'depth_radius' value { i: 2 } }"
" input: ['C', 'D', 'B'] }"
"node { name: 'F' op: 'LRNGrad'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'alpha' value { f: 0.001 } }"
" attr { key: 'beta' value { f: 0.75 } }"
" attr { key: 'bias' value { f: 1.0 } }"
" attr { key: 'depth_radius' value { i: 2 } }"
" input: ['C', 'B', 'D'] }"
"node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['E', 'F'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(_MklLRN);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
"DMT/_2(Const);E(_MklLRNGrad);F(LRNGrad);G(Zeta)|A->B;"
"A:control->DMT/_0:control;B->E:2;B->F:1;B:1->E:3;B:2->E:6;"
"B:3->E:7;C->E;C->F;C:control->DMT/_1:control;"
"C:control->DMT/_2:control;D->E:1;D->F:2;DMT/_0->B:1;"
"DMT/_1->E:4;DMT/_2->E:5;E->G;F->G:1");
}
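// In the graph above, E consumes B's primary output at input slot 2 (the
// slot where LRNGrad expects the original LRN output), so E is rewritten and
// picks up the workspace edges. F consumes B at input slot 1 instead, so it
// cannot be paired with B's workspace and stays a vanilla LRNGrad.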
/* Test rewrite of a MaxPool->MaxPoolGrad pair, with workspace edges added. */
TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Positive) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'MaxPool'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }"
" attr { key: 'padding' value { s: 'VALID' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }"
" input: ['A'] }"
"node { name: 'C' op: 'Input'}"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: 'MaxPoolGrad'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }"
" attr { key: 'padding' value { s: 'VALID' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }"
" input: ['C', 'B', 'D'] }"
"node { name: 'F' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['C', 'E'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(_MklMaxPool);C(Input);D(Input);DMT/_0(Const);"
"DMT/_1(Const);DMT/_2(Const);E(_MklMaxPoolGrad);F(Zeta)|"
"A->B;A:control->DMT/_0:control;B->E:1;B:1->E:3;B:2->E:5;B:3->E:7;"
"C->E;C->F;C:control->DMT/_1:control;C:control->DMT/_2:control;"
"D->E:2;DMT/_0->B:1;DMT/_1->E:4;DMT/_2->E:6;E->F:1");
}
// Test MaxPool->MaxPoolGrad replacement when only the forward op (MaxPool)
// is present. In this case, the MaxPool node is rewritten, but no workspace
// edges are added.
TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative1) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'MaxPool'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }"
" attr { key: 'padding' value { s: 'VALID' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }"
" input: ['A'] }"
"node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(_MklMaxPool);C(Zeta);DMT/_0(Const)|"
"A->B;A->C;A:control->DMT/_0:control;B->C:1;DMT/_0->B:1");
}
// Test MaxPool->MaxPoolGrad replacement when only the backward op
// (MaxPoolGrad) is present. Without the forward MaxPool there is no
// workspace tensor available, so MaxPoolGrad is left unrewritten.
TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative2) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'C' op: 'Input'}"
"node { name: 'D' op: 'MaxPoolGrad'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }"
" attr { key: 'padding' value { s: 'VALID' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }"
" input: ['A', 'B', 'C'] }"
"node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'D'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(MaxPoolGrad);"
"E(Zeta)|A->D;A->E;B->D:1;C->D:2;D->E:1");
}
// Test MaxPool handling for batch-wise pooling (NCHW, kernel over the
// batch dimension). No rewrite should take place in this case.
TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative3) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'MaxPool'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'ksize' value { list: {i: 2, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'VALID' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A'] }"
"node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1");
}
// Test MaxPool handling for batch-wise pooling (NCHW, stride over the
// batch dimension). No rewrite should take place in this case.
TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative4) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'MaxPool'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'ksize' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'VALID' } }"
" attr { key: 'strides' value { list: {i: 2, i:1, i:1, i:1} } }"
" input: ['A'] }"
"node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1");
}
// Test MaxPool handling for depth-wise pooling (NCHW, kernel over the
// depth dimension). No rewrite should take place in this case.
TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative5) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'MaxPool'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'ksize' value { list: {i: 1, i:2, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'VALID' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A'] }"
"node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1");
}
// Test MaxPool handling for depth-wise pooling (NCHW, stride over the
// depth dimension). No rewrite should take place in this case.
TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative6) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'MaxPool'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'ksize' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'VALID' } }"
" attr { key: 'strides' value { list: {i: 1, i:2, i:1, i:1} } }"
" input: ['A'] }"
"node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1");
}
// Test MaxPool handling for batch-wise pooling (NHWC, kernel over the
// batch dimension). No rewrite should take place in this case.
TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative7) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'MaxPool'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NHWC' } }"
" attr { key: 'ksize' value { list: {i: 2, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'VALID' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A'] }"
"node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1");
}
// Test MaxPool handling for batch-wise pooling (NHWC, stride over the
// batch dimension). No rewrite should take place in this case.
TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative8) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'MaxPool'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NHWC' } }"
" attr { key: 'ksize' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'VALID' } }"
" attr { key: 'strides' value { list: {i: 2, i:1, i:1, i:1} } }"
" input: ['A'] }"
"node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1");
}
// Test MaxPool handling for depth-wise pooling (NHWC, kernel over the
// depth dimension). No rewrite should take place in this case.
TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative9) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'MaxPool'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NHWC' } }"
" attr { key: 'ksize' value { list: {i: 1, i:1, i:1, i:2} } }"
" attr { key: 'padding' value { s: 'VALID' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A'] }"
"node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1");
}
// Test MaxPool handling for depth-wise pooling (NHWC, stride over the
// depth dimension). No rewrite should take place in this case.
TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative10) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'MaxPool'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NHWC' } }"
" attr { key: 'ksize' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'VALID' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:2} } }"
" input: ['A'] }"
"node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1");
}
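// Taken together, Negative3 through Negative10 show that the MaxPool rewrite
// is restricted to purely spatial pooling: any kernel size or stride over
// the batch or depth dimension, in either NCHW or NHWC format, leaves the
// vanilla MaxPool in place.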
/////////////////////////////////////////////////////////////////////
// Device placement tests: nodes assigned to a GPU device should not
// be rewritten or merged. The first case is a single Conv2D op on GPU.
/////////////////////////////////////////////////////////////////////
TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_DeviceTest) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'C' op: 'Conv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A', 'B']}"
"node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['B', 'C'] }",
kGPUDevice);
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Conv2D);D(Zeta)|A->C;B->C:1;B->D;C->D:1");
}
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_DeviceTest) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'C' op: 'Input'}"
"node { name: 'M' op: '_MklInput'}"
"node { name: 'N' op: '_MklInput'}"
"node { name: 'O' op: '_MklInput'}"
"node { name: 'D' op: '_MklConv2DWithBias'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A', 'B', 'C', 'M', 'N', 'O']}"
"node { name: 'E' op: 'Zeta'"
" attr {key: 'T' value { type: DT_FLOAT } }"
" input: ['D', 'A']}"
"node { name: 'F' op: 'BiasAddGrad'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" input: ['E'] }",
kGPUDevice);
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(_MklConv2DWithBias);"
"E(Zeta);F(BiasAddGrad);M(_MklInput);N(_MklInput);"
"O(_MklInput)|A->D;A->E:1;B->D:1;C->D:2;D->E;E->F;"
"M->D:3;N->D:4;O->D:5");
}
TEST_F(MklLayoutPassTest, NodeRewrite_Conv2DGradFilter_DeviceTest) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Int32Input'}"
"node { name: 'C' op: 'Input'}"
"node { name: 'D' op: 'Conv2DBackpropFilter'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A', 'B', 'C']}"
"node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'D'] }",
kGPUDevice);
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Int32Input);C(Input);D(Conv2DBackpropFilter);E(Zeta)|"
"A->D;A->E;B->D:1;C->D:2;D->E:1");
}
TEST_F(MklLayoutPassTest,
NodeRewrite_DepthwiseConv2dNativeGradFilter_DeviceTest) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Int32Input'}"
"node { name: 'C' op: 'Input'}"
"node { name: 'D' op: 'DepthwiseConv2dNativeBackpropFilter'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A', 'B', 'C']}"
"node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'D'] }",
kGPUDevice);
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Int32Input);C(Input);D("
"DepthwiseConv2dNativeBackpropFilter);E(Zeta)|"
"A->D;A->E;B->D:1;C->D:2;D->E:1");
}
TEST_F(MklLayoutPassTest, NodeRewrite_Relu_DeviceTest) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Relu'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A'] }"
"node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }",
kGPUDevice);
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Relu);C(Zeta)|A->B;A->C;B->C:1");
}
TEST_F(MklLayoutPassTest, NodeRewrite_ReluGrad_DeviceTest) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'C' op: 'ReluGrad'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }"
"node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'C'] }",
kGPUDevice);
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(ReluGrad);D(Zeta)|A->C;A->D;B->C:1;C->D:1");
}
TEST_F(MklLayoutPassTest, NodeRewrite_Relu6_DeviceTest) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Relu6'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A'] }"
"node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }",
kGPUDevice);
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Relu6);C(Zeta)|A->B;A->C;B->C:1");
}
TEST_F(MklLayoutPassTest, NodeRewrite_Relu6Grad_DeviceTest) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'C' op: 'Relu6Grad'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }"
"node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'C'] }",
kGPUDevice);
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Relu6Grad);D(Zeta)|A->C;A->D;B->C:1;C->D:1");
}
TEST_F(MklLayoutPassTest, NodeRewrite_MaxPool_DeviceTest) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'MaxPool'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NHWC' } }"
" attr { key: 'ksize' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'VALID' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A'] }"
"node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }",
kGPUDevice);
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1");
}
TEST_F(MklLayoutPassTest, NodeRewrite_AvgPool_DeviceTest) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'AvgPool'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NHWC' } }"
" attr { key: 'ksize' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'VALID' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A'] }"
"node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }",
kGPUDevice);
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(AvgPool);C(Zeta)|A->B;A->C;B->C:1");
}
// Concat Op test: Concat with no Mkl layer feeding it
TEST_F(MklLayoutPassTest, NodeRewrite_Concat_DeviceTest) {
InitGraph(
"node { name: 'A' op: 'Const' "
" attr { key: 'dtype' value { type: DT_INT32 } }"
" attr { key: 'value' value { "
" tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
" int_val: 0 } } } }"
"node { name: 'B' op: 'InputList'"
" attr { key: 'N' value { i: 2 } }}"
"node { name: 'C' op: 'Input'}"
"node { name: 'D' op: 'Concat'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'N' value { i: 2 } }"
" input: ['A', 'B:0', 'B:1']}"
"node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['C', 'D'] }",
kGPUDevice);
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Const);B(InputList);C(Input);D(Concat);E(Zeta)|A->D;"
"B->D:1;B:1->D:2;C->E;D->E:1");
}
TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_DeviceTest) {
InitGraph(
"node { name: 'A' op: 'Const' "
" attr { key: 'dtype' value { type: DT_INT32 } }"
" attr { key: 'value' value { "
" tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
" int_val: 0 } } } }"
"node { name: 'B' op: 'InputList'"
" attr { key: 'N' value { i: 2 } }}"
"node { name: 'C' op: 'Input'}"
"node { name: 'D' op: 'ConcatV2'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'Tidx' value { type: DT_INT32 } }"
" attr { key: 'N' value { i: 2 } }"
" input: ['B:0', 'B:1', 'A']}"
"node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['C', 'D'] }",
kGPUDevice);
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Const);B(InputList);C(Input);D(ConcatV2);E(Zeta)|"
"A->D:2;B->D;B:1->D:1;C->E;D->E:1");
}
TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNorm_DeviceTest) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'C' op: 'Input'}"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: 'Input'}"
"node { name: 'F' op: 'FusedBatchNorm'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'epsilon' value { f: 0.0001 } }"
" attr { key: 'is_training' value { b: true } }"
" input: ['A', 'B', 'C', 'D', 'E'] }"
"node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'F'] }",
kGPUDevice);
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(Input);E(Input);"
"F(FusedBatchNorm);G(Zeta)|A->F;A->G;B->F:1;C->F:2;D->F:3;"
"E->F:4;F->G:1");
}
TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNormV2_DeviceTest) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'C' op: 'Input'}"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: 'Input'}"
"node { name: 'F' op: 'FusedBatchNorm'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'U' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'epsilon' value { f: 0.0001 } }"
" attr { key: 'is_training' value { b: true } }"
" input: ['A', 'B', 'C', 'D', 'E'] }"
"node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'F'] }",
kGPUDevice);
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(Input);E(Input);"
"F(FusedBatchNorm);G(Zeta)|A->F;A->G;B->F:1;C->F:2;D->F:3;"
"E->F:4;F->G:1");
}
TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNormV3_DeviceTest) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'C' op: 'Input'}"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: 'Input'}"
"node { name: 'F' op: 'FusedBatchNorm'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'U' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'epsilon' value { f: 0.0001 } }"
" attr { key: 'is_training' value { b: true } }"
" input: ['A', 'B', 'C', 'D', 'E'] }"
"node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'F'] }",
kGPUDevice);
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(Input);E(Input);"
"F(FusedBatchNorm);G(Zeta)|A->F;A->G;B->F:1;C->F:2;D->F:3;"
"E->F:4;F->G:1");
}
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_DeviceTest) {
CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'M' op: '_MklInput'}"
"node { name: 'N' op: '_MklInput'}"
"node { name: 'C' op: '_MklConv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" input: ['A', 'B', 'M', 'N']}"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: 'BiasAdd'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" input: ['C', 'D'] }"
"node { name: 'Y' op: 'Input'}"
"node { name: 'Z' op: 'Zeta'"
" attr {key: 'T' value { type: DT_FLOAT } }"
" input: ['E', 'Y']}",
kGPUDevice);
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(_MklConv2D);D(Input);E(BiasAdd);"
"M(_MklInput);N(_MklInput);Y(Input);Z(Zeta)|A->C;"
"B->C:1;C->E;D->E:1;E->Z;M->C:2;N->C:3;Y->Z:1");
}
TEST_F(MklLayoutPassTest, NodeRewrite_Slice_DeviceTest) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Int32Input'}"
"node { name: 'C' op: 'Int32Input'}"
"node { name: 'D' op: 'Slice'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'Index' value { type: DT_INT32 } }"
" input: ['A', 'B', 'C'] }"
"node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'D'] }",
kGPUDevice);
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Int32Input);C(Int32Input);D(Slice);E(Zeta)|A->D;A->E;"
"B->D:1;C->D:2;D->E:1");
}
/////////////////////////////////////////////////////////////////////
// Post-rewrite fixup pass test
/////////////////////////////////////////////////////////////////////
TEST_F(MklLayoutPassTest, PostRewriteFixUpPass) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'M' op: '_MklInput'}"
"node { name: 'N' op: '_MklInput'}"
"node { name: 'C' op: '_MklConv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A', 'B', 'M', 'N']}"
"node { name: 'D' op: 'Const' "
" attr { key: 'dtype' value { type: DT_UINT8 } }"
" attr { key: 'value' value { "
" tensor { dtype: DT_UINT8 tensor_shape { dim { size: 1 } } "
" int_val: 0 } } } }"
"node { name: 'E' op: '_MklAdd'"
" attr {key: 'T' value { type: DT_FLOAT } }"
" input: ['C', 'A', 'D', 'D']}");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(_MklConv2D);D(Const);E(_MklAdd);"
"M(_MklInput);N(_MklInput)|A->C;A->E:1;B->C:1;C->E;C:2->E:2;"
"D->E:3;M->C:2;N->C:3");
}
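// The fixup pass rewires E's Mkl metadata inputs: slot 2 now comes from C's
// own metadata output (C:2->E:2) instead of the dummy constant D, while
// slot 3 keeps the dummy D because A is not an Mkl op and has no metadata
// output to offer.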
/////////////////////////////////////////////////////////////////////
// Unit tests related to filter caching.
//
// These tests check that the attribute `is_filter_const` is set to
// true when the filter is a constant and to false otherwise, for
// various operators such as Conv2D, Conv2DWithBias, and Conv3D.
/////////////////////////////////////////////////////////////////////
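// Rationale: a constant filter can be converted to the Mkl layout once and
// the converted copy cached across invocations, which is only safe when the
// filter cannot change between steps.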
// Conv2D op where filter is a constant.
TEST_F(MklLayoutPassTest, Conv2D_FilterCaching_Positive) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Const' " // Filter
" attr { key: 'dtype' value { type: DT_FLOAT } }"
" attr { key: 'value' value { "
" tensor { dtype: DT_FLOAT tensor_shape { dim { size: 1 } } "
" int_val: 0 } } } }"
"node { name: 'C' op: 'Conv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A', 'B']}"
"node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['B', 'C'] }");
EXPECT_TRUE(DoMklLayoutOptimizationPassGetAttrVal<bool>("is_filter_const",
"_MklConv2D"));
}
// Conv2D op where filter is NOT a constant.
TEST_F(MklLayoutPassTest, Conv2D_FilterCaching_Negative) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}" // Filter
"node { name: 'C' op: 'Conv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A', 'B']}"
"node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['B', 'C'] }");
EXPECT_FALSE(DoMklLayoutOptimizationPassGetAttrVal<bool>("is_filter_const",
"_MklConv2D"));
}
// Conv2D + BiasAdd fusion where filter is a constant.
TEST_F(MklLayoutPassTest, Conv2DWithBias_FilterCaching_Positive) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Const'" // Filter
" attr { key: 'dtype' value { type: DT_FLOAT } }"
" attr { key: 'value' value { "
" tensor { dtype: DT_FLOAT tensor_shape { dim { size: 1 } } "
" int_val: 0 } } } }"
"node { name: 'C' op: 'Conv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A', 'B']}"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: 'BiasAdd'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" input: ['C', 'D'] }"
"node { name: 'Y' op: 'Input'}"
"node { name: 'Z' op: 'Zeta'"
" attr {key: 'T' value { type: DT_FLOAT } }"
" input: ['E', 'Y']}");
EXPECT_TRUE(DoMklLayoutOptimizationPassGetAttrVal<bool>(
"is_filter_const", "_MklConv2DWithBias"));
}
// Conv2D + BiasAdd fusion where filter is NOT a constant.
TEST_F(MklLayoutPassTest, Conv2DWithBias_FilterCaching_Negative) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}" // Filter
"node { name: 'C' op: 'Conv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A', 'B']}"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: 'BiasAdd'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" input: ['C', 'D'] }"
"node { name: 'Y' op: 'Input'}"
"node { name: 'Z' op: 'Zeta'"
" attr {key: 'T' value { type: DT_FLOAT } }"
" input: ['E', 'Y']}");
EXPECT_FALSE(DoMklLayoutOptimizationPassGetAttrVal<bool>(
"is_filter_const", "_MklConv2DWithBias"));
}
// Conv3D op where filter is a constant.
TEST_F(MklLayoutPassTest, Conv3D_FilterCaching_Positive) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Const' " // Filter
" attr { key: 'dtype' value { type: DT_FLOAT } }"
" attr { key: 'value' value { "
" tensor { dtype: DT_FLOAT tensor_shape { dim { size: 1 } } "
" int_val: 0 } } } }"
"node { name: 'C' op: 'Conv3D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCDHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1, "
"i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1, "
"i:1} } }"
" input: ['A', 'B']}"
"node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['B', 'C'] }");
EXPECT_TRUE(DoMklLayoutOptimizationPassGetAttrVal<bool>("is_filter_const",
"_MklConv3D"));
}
// Conv3D op where filter is NOT a constant.
TEST_F(MklLayoutPassTest, Conv3D_FilterCaching_Negative) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}" // Filter
"node { name: 'C' op: 'Conv3D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCDHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1, "
"i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1, "
"i:1} } }"
" input: ['A', 'B']}"
"node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['B', 'C'] }");
EXPECT_FALSE(DoMklLayoutOptimizationPassGetAttrVal<bool>("is_filter_const",
"_MklConv3D"));
}
// Pad + Conv2D fusion where filter is a constant.
TEST_F(MklLayoutPassTest, PadWithConv2D_FilterCaching_Positive) {
DCHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Int32Input'}"
"node { name: 'C' op: 'Pad'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'Tpaddings' value { type: DT_INT32 } }"
" input: ['A', 'B']}"
"node { name: 'D' op: 'Const'" // Filter
" attr { key: 'dtype' value { type: DT_FLOAT } }"
" attr { key: 'value' value { "
" tensor { dtype: DT_FLOAT tensor_shape { dim { size: 1 } } "
" int_val: 0 } } } }"
"node { name: 'E' op: 'Conv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NHWC' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'VALID' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['C', 'D'] }"
"node { name: 'Y' op: 'Input'}"
"node { name: 'Z' op: 'Zeta'"
" attr {key: 'T' value { type: DT_FLOAT } }"
" input: ['E', 'Y']}");
EXPECT_TRUE(DoMklLayoutOptimizationPassGetAttrVal<bool>("is_filter_const",
"_MklPadWithConv2D"));
}
// Pad + Conv2D fusion where filter is NOT a constant.
TEST_F(MklLayoutPassTest, PadWithConv2D_FilterCaching_Negative) {
DCHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Int32Input'}"
"node { name: 'C' op: 'Pad'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'Tpaddings' value { type: DT_INT32 } }"
" input: ['A', 'B']}"
"node { name: 'D' op: 'Input'}" // Filter
"node { name: 'E' op: 'Conv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NHWC' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'VALID' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['C', 'D'] }"
"node { name: 'Y' op: 'Input'}"
"node { name: 'Z' op: 'Zeta'"
" attr {key: 'T' value { type: DT_FLOAT } }"
" input: ['E', 'Y']}");
EXPECT_FALSE(DoMklLayoutOptimizationPassGetAttrVal<bool>(
"is_filter_const", "_MklPadWithConv2D"));
}
// _FusedConv2D + BiasAdd fusion where filter is a constant.
TEST_F(MklLayoutPassTest, FusedConv2DWithBias_FilterCaching_Positive) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Const'" // Filter
" attr { key: 'dtype' value { type: DT_FLOAT } }"
" attr { key: 'value' value { "
" tensor { dtype: DT_FLOAT tensor_shape { dim { size: 1 } } "
" int_val: 0 } } } }"
"node { name: 'C' op: 'Input'}"
"node { name: 'D' op: '_FusedConv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'num_args' value { i: 1 } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'fused_ops' value { list: {s: 'BiasAdd'} } }"
" attr { key: 'epsilon' value { f: 0.001 }}"
" input: ['A', 'B', 'C']}"
"node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['D', 'C'] }");
EXPECT_TRUE(DoMklLayoutOptimizationPassGetAttrVal<bool>("is_filter_const",
"_MklFusedConv2D"));
}
// _FusedConv2D + BiasAdd fusion where filter is NOT a constant.
TEST_F(MklLayoutPassTest, FusedConv2DWithBias_FilterCaching_Negative) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}" // Filter
"node { name: 'C' op: 'Input'}"
"node { name: 'D' op: '_FusedConv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'num_args' value { i: 1 } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'fused_ops' value { list: {s: 'BiasAdd'} } }"
" attr { key: 'epsilon' value { f: 0.001 }}"
" input: ['A', 'B', 'C']}"
"node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['D', 'C'] }");
EXPECT_FALSE(DoMklLayoutOptimizationPassGetAttrVal<bool>("is_filter_const",
"_MklFusedConv2D"));
}
// Depthwise Conv2D op where filter is a constant.
TEST_F(MklLayoutPassTest, DepthwiseConv2dNative_FilterCaching_Positive) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Const'" // Filter
" attr { key: 'dtype' value { type: DT_FLOAT } }"
" attr { key: 'value' value { "
" tensor { dtype: DT_FLOAT tensor_shape { dim { size: 1 } } "
" int_val: 0 } } } }"
"node { name: 'C' op: 'DepthwiseConv2dNative'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A', 'B']}"
"node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['B', 'C'] }");
EXPECT_TRUE(DoMklLayoutOptimizationPassGetAttrVal<bool>(
"is_filter_const", "_MklDepthwiseConv2dNative"));
}
// Depthwise Conv2D op where filter is NOT a constant.
TEST_F(MklLayoutPassTest, DepthwiseConv2dNative_FilterCaching_Negative) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}" // Filter
"node { name: 'C' op: 'DepthwiseConv2dNative'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A', 'B']}"
"node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['B', 'C'] }");
EXPECT_FALSE(DoMklLayoutOptimizationPassGetAttrVal<bool>(
"is_filter_const", "_MklDepthwiseConv2dNative"));
}
// Fused QuantizedMatMulWithBias op rewrite test:
// QuantizedMatMulWithBias should be rewritten to _MklQuantizedMatMulWithBias.
TEST_F(MklLayoutPassTest, NodeRewrite_QuantizedMatMulWithBias_Positive) {
InitGraph(
"node { name: 'A' op: 'QUInt8Input' }"
"node { name: 'B' op: 'QInt8Input' }"
"node { name: 'C' op: 'QInt32Input' }"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: 'Input'}"
"node { name: 'F' op: 'Input'}"
"node { name: 'G' op: 'Input'}"
"node { name: 'H' op: 'QInt32Input'}"
"node { name: 'I' op: 'QuantizedMatMulWithBias'"
" attr { key: 'T1' value { type: DT_QUINT8 } }"
" attr { key: 'T2' value { type: DT_QINT8 } }"
" attr { key: 'Tbias' value { type: DT_QINT32 } }"
" attr { key: 'Toutput' value { type: DT_QINT32 } }"
" input: ['A', 'B', 'C', 'D', 'E', 'F', 'G']}"
"node { name: 'J' op: 'Zeta' attr { key: 'T' value { type: DT_QINT32 } }"
" input: ['I', 'H'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(QUInt8Input);B(QInt8Input);C(QInt32Input);D(Input);"
"DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);DMT/_3(Const);"
"DMT/_4(Const);DMT/_5(Const);DMT/_6(Const);E(Input);F(Input);"
"G(Input);H(QInt32Input);I(_MklQuantizedMatMulWithBias);"
"J(Zeta)|A->I;"
"A:control->DMT/_0:control;A:control->DMT/_1:control;"
"A:control->DMT/_2:control;A:control->DMT/_3:control;"
"A:control->DMT/_4:control;A:control->DMT/_5:control;"
"A:control->DMT/_6:control;B->I:1;C->I:2;D->I:3;DMT/_0->I:7;"
"DMT/_1->I:8;DMT/_2->I:9;DMT/_3->I:10;DMT/_4->I:11;DMT/_5->I:12;"
"DMT/_6->I:13;E->I:4;F->I:5;G->I:6;H->J:1;I->J");
}
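// With contiguous tensor ordering, the rewritten node's seven data inputs
// keep slots 0-6 and the seven DMT/_* metadata inputs follow at slots 7-13,
// which is exactly the wiring asserted above.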
// Rewrite test for QuantizedMatMulWithBias op with an unsupported input
// type (T2 is DT_QUINT8). No rewrite should happen.
TEST_F(MklLayoutPassTest, NodeRewrite_QuantizedMatMulWithBias_Negative) {
InitGraph(
"node { name: 'A' op: 'QUInt8Input' }"
"node { name: 'B' op: 'QUInt8Input' }"
"node { name: 'C' op: 'QInt32Input' }"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: 'Input'}"
"node { name: 'F' op: 'Input'}"
"node { name: 'G' op: 'Input'}"
"node { name: 'H' op: 'QInt32Input'}"
"node { name: 'I' op: 'QuantizedMatMulWithBias'"
" attr { key: 'T1' value { type: DT_QUINT8 } }"
" attr { key: 'T2' value { type: DT_QUINT8 } }"
" attr { key: 'Tbias' value { type: DT_QINT32 } }"
" attr { key: 'Toutput' value { type: DT_QINT32 } }"
" input: ['A', 'B', 'C', 'D', 'E', 'F', 'G']}"
"node { name: 'J' op: 'Zeta' attr { key: 'T' value { type: DT_QINT32 } }"
" input: ['I', 'H'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(QUInt8Input);B(QUInt8Input);C(QInt32Input);D(Input);E(Input);"
"F(Input);G(Input);H(QInt32Input);I(QuantizedMatMulWithBias);"
"J(Zeta)|A->I;B->I:1;C->I:2;D->I:3;"
"E->I:4;F->I:5;G->I:6;H->J:1;I->J");
}
// Fused QuantizedMatMulWithBiasAndRelu op rewrite test:
// QuantizedMatMulWithBiasAndRelu should be rewritten to
// _MklQuantizedMatMulWithBiasAndRelu.
TEST_F(MklLayoutPassTest, NodeRewrite_QuantizedMatMulWithBiasAndRelu_Positive) {
InitGraph(
"node { name: 'A' op: 'QUInt8Input' }"
"node { name: 'B' op: 'QInt8Input' }"
"node { name: 'C' op: 'Input' }"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: 'Input'}"
"node { name: 'F' op: 'Input'}"
"node { name: 'G' op: 'Input'}"
"node { name: 'H' op: 'QInt32Input'}"
"node { name: 'I' op: 'QuantizedMatMulWithBiasAndRelu'"
" attr { key: 'T1' value { type: DT_QUINT8 } }"
" attr { key: 'T2' value { type: DT_QINT8 } }"
" attr { key: 'Toutput' value { type: DT_QINT32 } }"
" input: ['A', 'B', 'C', 'D', 'E', 'F', 'G']}"
"node { name: 'J' op: 'Zeta' attr { key: 'T' value { type: DT_QINT32 } }"
" input: ['I', 'H'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(QUInt8Input);B(QInt8Input);C(Input);D(Input);"
"DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);DMT/_3(Const);"
"DMT/_4(Const);DMT/_5(Const);DMT/_6(Const);E(Input);F(Input);"
"G(Input);H(QInt32Input);I(_MklQuantizedMatMulWithBiasAndRelu);"
"J(Zeta)|A->I;"
"A:control->DMT/_0:control;A:control->DMT/_1:control;"
"A:control->DMT/_2:control;A:control->DMT/_3:control;"
"A:control->DMT/_4:control;A:control->DMT/_5:control;"
"A:control->DMT/_6:control;B->I:1;C->I:2;D->I:3;DMT/_0->I:7;"
"DMT/_1->I:8;DMT/_2->I:9;DMT/_3->I:10;DMT/_4->I:11;DMT/_5->I:12;"
"DMT/_6->I:13;E->I:4;F->I:5;G->I:6;H->J:1;I->J");
}
// Rewrite test for QuantizedMatMulWithBiasAndRelu op with an unsupported
// input type (T2 is DT_QUINT8). No rewrite should happen.
TEST_F(MklLayoutPassTest, NodeRewrite_QuantizedMatMulWithBiasAndRelu_Negative) {
InitGraph(
"node { name: 'A' op: 'QUInt8Input' }"
"node { name: 'B' op: 'QUInt8Input' }"
"node { name: 'C' op: 'Input' }"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: 'Input'}"
"node { name: 'F' op: 'Input'}"
"node { name: 'G' op: 'Input'}"
"node { name: 'H' op: 'QInt32Input'}"
"node { name: 'I' op: 'QuantizedMatMulWithBiasAndRelu'"
" attr { key: 'T1' value { type: DT_QUINT8 } }"
" attr { key: 'T2' value { type: DT_QUINT8 } }"
" attr { key: 'Toutput' value { type: DT_QINT32 } }"
" input: ['A', 'B', 'C', 'D', 'E', 'F', 'G']}"
"node { name: 'J' op: 'Zeta' attr { key: 'T' value { type: DT_QINT32 } }"
" input: ['I', 'H'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(QUInt8Input);B(QUInt8Input);C(Input);D(Input);"
"E(Input);F(Input);G(Input);H(QInt32Input);"
"I(QuantizedMatMulWithBiasAndRelu);J(Zeta)|A->I;"
"B->I:1;C->I:2;D->I:3;E->I:4;F->I:5;"
"G->I:6;H->J:1;I->J");
}
// Fused QuantizedMatMulWithBiasAndReluAndRequantize op rewrite test:
// QuantizedMatMulWithBiasAndReluAndRequantize should be rewritten to
// _MklQuantizedMatMulWithBiasAndReluAndRequantize.
TEST_F(MklLayoutPassTest,
NodeRewrite_QuantizedMatMulWithBiasAndReluAndRequantize_Positive) {
InitGraph(
"node { name: 'A' op: 'QUInt8Input' }"
"node { name: 'B' op: 'QInt8Input' }"
"node { name: 'C' op: 'QInt32Input' }"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: 'Input'}"
"node { name: 'F' op: 'Input'}"
"node { name: 'G' op: 'Input'}"
"node { name: 'H' op: 'Input'}"
"node { name: 'I' op: 'Input'}"
"node { name: 'J' op: 'QUInt8Input'}"
"node { name: 'K' op: 'QuantizedMatMulWithBiasAndReluAndRequantize'"
" attr { key: 'T1' value { type: DT_QUINT8 } }"
" attr { key: 'T2' value { type: DT_QINT8 } }"
" attr { key: 'Tbias' value { type: DT_QINT32 } }"
" attr { key: 'Toutput' value { type: DT_QUINT8 } }"
" input: ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I']}"
"node { name: 'L' op: 'Zeta' attr { key: 'T' value { type: DT_QUINT8 } }"
" input: ['K', 'J'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(QUInt8Input);B(QInt8Input);C(QInt32Input);"
"D(Input);DMT/_0(Const);"
"DMT/_1(Const);DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);"
"DMT/_5(Const);DMT/_6(Const);DMT/_7(Const);DMT/_8(Const);E(Input);"
"F(Input);G(Input);H(Input);I(Input);J(QUInt8Input);"
"K(_MklQuantizedMatMulWithBiasAndReluAndRequantize);L(Zeta)|A->K;"
"A:control->DMT/_0:control;A:control->DMT/_1:control;"
"A:control->DMT/_2:control;A:control->DMT/_3:control;"
"A:control->DMT/_4:control;A:control->DMT/_5:control;"
"A:control->DMT/_6:control;A:control->DMT/_7:control;"
"A:control->DMT/_8:control;B->K:1;C->K:2;D->K:3;DMT/_0->K:9;"
"DMT/_1->K:10;DMT/_2->K:11;DMT/_3->K:12;DMT/_4->K:13;DMT/_5->K:14;"
"DMT/_6->K:15;DMT/_7->K:16;DMT/_8->K:17;E->K:4;F->K:5;G->K:6;"
"H->K:7;I->K:8;J->L:1;K->L");
}
// Rewrite test for QuantizedMatMulWithBiasAndReluAndRequantize op with an
// unsupported input type (T2 is DT_QUINT8). No rewrite should happen.
TEST_F(MklLayoutPassTest,
NodeRewrite_QuantizedMatMulWithBiasAndReluAndRequantize_Negative) {
InitGraph(
"node { name: 'A' op: 'QUInt8Input' }"
"node { name: 'B' op: 'QUInt8Input' }"
"node { name: 'C' op: 'QInt32Input' }"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: 'Input'}"
"node { name: 'F' op: 'Input'}"
"node { name: 'G' op: 'Input'}"
"node { name: 'H' op: 'Input'}"
"node { name: 'I' op: 'Input'}"
"node { name: 'J' op: 'QUInt8Input'}"
"node { name: 'K' op: 'QuantizedMatMulWithBiasAndReluAndRequantize'"
" attr { key: 'T1' value { type: DT_QUINT8 } }"
" attr { key: 'T2' value { type: DT_QUINT8 } }"
" attr { key: 'Tbias' value { type: DT_QINT32 } }"
" attr { key: 'Toutput' value { type: DT_QUINT8 } }"
" input: ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I']}"
"node { name: 'L' op: 'Zeta' attr { key: 'T' value { type: DT_QUINT8 } }"
" input: ['K', 'J'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(QUInt8Input);B(QUInt8Input);C(QInt32Input);"
"D(Input);E(Input);F(Input);G(Input);H(Input);I(Input);"
"J(QUInt8Input);"
"K(QuantizedMatMulWithBiasAndReluAndRequantize);L(Zeta)|A->K;"
"B->K:1;C->K:2;D->K:3;E->K:4;F->K:5;G->K:6;"
"H->K:7;I->K:8;J->L:1;K->L");
}
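// A minimal illustrative sketch (not part of the original suite): a graph
// containing no Mkl-eligible ops should come out of the pass unchanged.
// 'Zeta' is the test-only binary op used throughout this file and is never
// rewritten, so both nodes and edges are expected to be untouched.
TEST_F(MklLayoutPassTest, Sketch_NoMklOpsPassThrough) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }");
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Input);C(Zeta)|A->C;B->C:1");
}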
static void BM_MklLayoutRewritePass(int iters, int op_nodes) {
testing::StopTiming();
string s;
for (int in = 0; in < 10; in++) {
s += strings::Printf("node { name: 'in%04d' op: 'Input'}", in);
}
random::PhiloxRandom philox(301, 17);
random::SimplePhilox rnd(&philox);
for (int op = 0; op < op_nodes; op++) {
s += strings::Printf(
"node { name: 'op%04d' op: 'Zeta' attr { key: 'T' value { "
"type: DT_FLOAT } } input: ['in%04d', 'in%04d' ] }",
op, rnd.Uniform(10), rnd.Uniform(10));
}
bool first = true;
while (iters > 0) {
Graph* graph = new Graph(OpRegistry::Global());
InitGraph(s, graph);
int N = graph->num_node_ids();
if (first) {
testing::SetLabel(strings::StrCat("Per graph node. Nodes: ", N));
first = false;
}
{
testing::StartTiming();
std::unique_ptr<Graph> ug(graph);
RunMklLayoutRewritePass(&ug);
testing::StopTiming();
}
iters -= N; // Our benchmark units are individual graph nodes,
// not whole graphs
}
}
BENCHMARK(BM_MklLayoutRewritePass)->Arg(1000)->Arg(10000);
} // namespace
} // namespace tensorflow
#endif // INTEL_MKL && ENABLE_MKL