Graph pattern matcher for grappler.
diff --git a/tensorflow/core/grappler/utils/BUILD b/tensorflow/core/grappler/utils/BUILD
index e87ec9d..3195867 100644
--- a/tensorflow/core/grappler/utils/BUILD
+++ b/tensorflow/core/grappler/utils/BUILD
@@ -393,6 +393,29 @@
)
cc_library(
+ name = "pattern_utils",
+ srcs = ["pattern_utils.cc"],
+ hdrs = ["pattern_utils.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":graph_view",
+ ],
+)
+
+tf_cc_test(
+ name = "pattern_utils_test",
+ srcs = ["pattern_utils_test.cc"],
+ deps = [
+ ":pattern_utils",
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/cc:cc_ops_internal",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ ],
+)
+
+cc_library(
name = "transitive_fanin",
srcs = ["transitive_fanin.cc"],
hdrs = ["transitive_fanin.h"],
diff --git a/tensorflow/core/grappler/utils/pattern_utils.cc b/tensorflow/core/grappler/utils/pattern_utils.cc
new file mode 100644
index 0000000..4b3f845
--- /dev/null
+++ b/tensorflow/core/grappler/utils/pattern_utils.cc
@@ -0,0 +1,129 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/utils/pattern_utils.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace utils {
+
+// A subgraph pattern syntax implicitly defines a DAG having a single root. We
+// traverse the syntax DAG in DFS manner. This function finds a match for
+// current root of the pattern with the current node and recursively matches
+// children subpatterns with the children of current node.
+template <>
+bool SubGraphMatcher<MatchingDirection::kFollowInputs>::DoesOpTypePatternMatch(
+    const OpTypePattern& pattern, MutableNodeView* node_view,
+    NodeViewMatch* match) {
+  // Currently no control inputs and outputs are allowed.
+  if (node_view->NumControllingFanins() > 0 ||
+      node_view->NumControlledFanouts() > 0)
+    return false;
+
+  bool op_type_matched = false;
+  if (pattern.op == "*") {
+    // Wildcard pattern matches any op type.
+    op_type_matched = true;
+  } else {
+    // The op field string of current pattern might express an op among multiple
+    // op types (mutually exclusive) separated by '|'.
+    std::vector<string> op_list = str_util::Split(pattern.op, '|');
+    for (const string& op : op_list) {
+      if (node_view->node()->op() == op) {
+        op_type_matched = true;
+        break;
+      }
+    }
+  }
+  if (op_type_matched) {
+    // If op type matches and current node is visited first time, insert current
+    // node to node_label_to_index_ map with the current label as the key.
+    // Multiple occurrences of same label in the pattern syntax indicates that
+    // the same node needs to be visited for each of such occurrences. Hence
+    // subsequent visits should find the corresponding label in the map as a key
+    // and the current node should be the value for that key.
+    if (node_label_to_index_.find(pattern.label) ==
+        node_label_to_index_.end()) {
+      node_label_to_index_[pattern.label] = node_view->node_index();
+      // Bookkeeping: remember every matched node, and separately those the
+      // pattern marks for removal (used later for the external-dependents
+      // safety check).
+      matched_node_indices_.insert(node_view->node_index());
+      if (pattern.node_status == NodeStatus::kRemove) {
+        remove_node_indices_.insert(node_view->node_index());
+      }
+    } else if (node_label_to_index_[pattern.label] != node_view->node_index()) {
+      return false; // label constraint could not be satisfied.
+    } else {
+      DCHECK(node_label_to_index_[pattern.label] == node_view->node_index());
+    }
+  } else {
+    return false;
+  }
+  // Current root of the pattern syntax is matched with the current node.
+  match->node_view = node_view;
+
+  // Go for matching child subpattern.
+  if (!pattern.children.empty()) {
+    // Currently only direction toward inputs is implemented. A leaf pattern
+    // (no children) matches regardless of the node's fanin count; a non-leaf
+    // pattern requires an exact, positional fanin match.
+    auto node_view_children = node_view->GetRegularFanins();
+    if (node_view_children.size() != pattern.children.size()) {
+      return false;
+    } else {
+      for (int i = 0; i < pattern.children.size(); ++i) {
+        auto child_node_index = node_view_children[i].node_index();
+        // TODO (mdfaijul): Is it guaranteed that GetNode will return a
+        // non-null pointer?
+        MutableNodeView* child_node_view =
+            graph_view_->GetNode(child_node_index);
+        const OpTypePattern& child_pattern = pattern.children[i];
+        match->children.push_back(NodeViewMatch());
+        NodeViewMatch* child_match = &(match->children.back());
+        if (!DoesOpTypePatternMatch(child_pattern, child_node_view,
+                                    child_match)) {
+          return false;
+        }
+      }
+    }
+  }
+  return true;
+}
+
+// Current implementation supports pattern matching toward node's inputs only.
+//
+// Returns true and populates matched_nodes_map (pattern label -> node index)
+// and remove_node_indices iff the pattern matches AND none of the
+// remove-candidate nodes has a consumer outside the matched subgraph.
+template <>
+bool SubGraphMatcher<MatchingDirection::kFollowInputs>::GetMatchedNodes(
+    const OpTypePattern& pattern, MutableNodeView* node_view,
+    std::map<string, int>* matched_nodes_map,
+    std::set<int>* remove_node_indices) {
+  bool found_match = false;
+  match_.reset(new NodeViewMatch());
+  if (DoesOpTypePatternMatch(pattern, node_view, match_.get())) {
+    // A structural match alone is not sufficient: a node slated for removal
+    // must not feed any node outside the matched subgraph, otherwise removing
+    // it would break that external consumer.
+    if (!HasRemoveNodeExternalDependents()) {
+      found_match = true;
+      matched_nodes_map->swap(this->node_label_to_index_);
+      remove_node_indices->swap(this->remove_node_indices_);
+    }
+  }
+  // Clear all bookkeeping data unconditionally so this matcher can be reused.
+  // Clearing only on a failed structural match would leave stale state behind
+  // both after a successful match (matched_node_indices_, match_) and after a
+  // rejection due to external dependents (all four members), corrupting any
+  // subsequent GetMatchedNodes call on the same SubGraphMatcher.
+  match_->Clear();
+  match_.reset(nullptr);
+  node_label_to_index_.clear();
+  matched_node_indices_.clear();
+  remove_node_indices_.clear();
+  return found_match;
+}
+
+} // namespace utils
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/utils/pattern_utils.h b/tensorflow/core/grappler/utils/pattern_utils.h
new file mode 100644
index 0000000..9d83ec7
--- /dev/null
+++ b/tensorflow/core/grappler/utils/pattern_utils.h
@@ -0,0 +1,227 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_PATTERN_HELPER_H_
+#define TENSORFLOW_CORE_GRAPPLER_UTILS_PATTERN_HELPER_H_
+
+#include "tensorflow/core/grappler/utils/graph_view.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace utils {
+
+//------------------------------------------------------------------------------
+// A pattern can be defined by the following grammar. Here, op_type is any valid
+// op name in the TensorFlow.
+//
+// leaf_pattern ::= `{` op_type `}`
+// pattern ::= leaf_pattern |
+// `{` op_type `,` `{` pattern `,` ... `,` pattern `}` `}`
+//
+// (1) For example, the following pattern syntax describes a pattern for
+// _FusedConv2D (Conv2D + BiasAdd + Relu). Note that "*" means any type of op.
+//
+// {"Relu",
+// {
+// "BiasAdd",
+// {
+// {"Conv2D"},
+// {"*"}
+// }
+// }
+// }
+//
+// The syntax above has a root ("Relu") and children (inputs), where each child
+// is a sub-pattern. Graph pattern matcher finds a match for the given pattern
+// syntax in a graph and returns a set of matched nodes.
+//
+// (2) In order to match a DAG with a given root, we extend pattern syntax with
+// labels. For example, a frequently found pattern in Deep Learning models is a
+// residual block like below.
+//
+// Placeholder Const
+// | |
+// +-----+-----+ |
+// | | |
+// | v v
+// | Conv2D Const
+// | | |
+// | v v-----+
+// | BiasAdd
+// | |
+// v v----------+
+// AddV2
+//
+// As shown above, the same input node (Placeholder) is consumed by both
+// AddV2 and Conv2D. This constraint can be expressed as labels in the
+// following augmented pattern syntax.
+//
+// {"AddV2", "my_add",
+// {
+// {"*", "my_residual_input"},
+// {"BiasAdd", "my_bias_add",
+// {
+// {"Conv2D", "my_conv",
+// {
+// {"*", "my_residual_input"},
+// {"*", "my_filter"}
+// }
+// },
+//       {"*", "my_bias"}
+// }
+// }
+// }
+// }
+//
+// Note that the same label "my_residual_input" is used to tell that it is a
+// child of both "AddV2" and "Conv2D". Labels are arbitrary strings to associate
+// with the nodes to be matched as well as to uniquely identify those nodes.
+//
+// (3) The motivation for grammar based pattern matching in grappler is to
+// make it easy to find fusion patterns in the remapper. A subgraph that
+// matches a given pattern, however, is not fusable if any of the matched
+// nodes, which will be removed as a part of the fusion, has a consumer
+// outside the matched subgraph. In order to check for such external
+// dependencies, we further extend the pattern syntax with a prospective
+// action (NodeStatus) on the matched nodes as shown below. This helps cross
+// checking the nodes to be removed against the nodes matched initially.
+//
+// {"AddV2", "my_add", NodeStatus::kReplace,
+// {
+// {"*", "my_residual_input", NodeStatus::kRemain},
+// {"BiasAdd", "my_bias_add", NodeStatus::kRemove,
+// {
+// {"Conv2D", "my_conv", NodeStatus::kRemove,
+// {
+// {"*", "my_residual_input", NodeStatus::kRemain},
+//           {"*", "my_filter", NodeStatus::kRemain}
+// }
+// },
+//       {"*", "my_bias", NodeStatus::kRemain}
+// }
+// }
+// }
+// }
+//------------------------------------------------------------------------------
+
+// Pattern matcher recursively matches child subpatterns. The direction
+// for children could be toward node's input (fanins) or outputs (fanouts).
+enum class MatchingDirection { kFollowInputs, kFollowOutputs };
+
+// Action for each node in the set of matched nodes for a given pattern.
+enum class NodeStatus { kRemain, kRemove, kReplace };
+
+// TODO (intel-tf): Support multiple roots by making them children of a single
+// virtual root.
+//
+// One node of the pattern-syntax DAG.
+struct OpTypePattern {
+  string op;     // Op type, "*" (any op), or '|'-separated alternatives.
+  string label;  // Arbitrary name uniquely identifying this pattern node.
+  NodeStatus node_status;  // Prospective action on the matched graph node.
+  std::vector<OpTypePattern> children;  // Subpatterns for the node's fanins.
+
+  // Renders op and label recursively for this subtree; node_status is not
+  // included in the output.
+  string DebugString() const {
+    string result = "{(op: " + op + ", " + "label: " + label + "), {";
+    for (const OpTypePattern& child : children) {
+      result += child.DebugString() + ",";
+    }
+    result += "}}";
+    return result;
+  }
+};
+
+// This is a helpful recursive structure that keeps one-to-one mapping of
+// pattern syntax to the matched nodes. User can call DebugString to see what
+// has been matched so far and where is the failing point.
+struct NodeViewMatch {
+  MutableNodeView* node_view = nullptr;  // Non-owning; null means "no match".
+  std::vector<NodeViewMatch> children;   // Matches for the child subpatterns.
+
+  // Renders this subtree of matches; an unmatched position prints as
+  // "Non-Matched-Node".
+  string DebugString() const {
+    string result = "{";
+    if (node_view == nullptr) {
+      result += "Non-Matched-Node}";
+      return result;
+    } else {
+      result += node_view->node()->DebugString();
+      result += ", {";
+      for (const NodeViewMatch& child : children) {
+        result += child.DebugString() + ",";
+      }
+      result += "}}";
+      return result;
+    }
+  }
+
+  // Resets to the default (empty) state. Destroying the children vector
+  // already tears down the whole subtree recursively, so the explicit
+  // per-child recursion and the null check of the original were redundant.
+  void Clear() {
+    children.clear();
+    node_view = nullptr;
+  }
+};
+
+template <MatchingDirection DIRECTION = MatchingDirection::kFollowInputs>
+class SubGraphMatcher {
+ public:
+  // Does not take ownership of graph_view, which must outlive this matcher.
+  // explicit: a graph view should never silently convert into a matcher.
+  explicit SubGraphMatcher(MutableGraphView* graph_view)
+      : graph_view_(graph_view) {}
+
+  // If a given pattern is matched, this function returns true as well as the
+  // matched node and remove node info is populated.
+  bool GetMatchedNodes(const OpTypePattern& pattern, MutableNodeView* node_view,
+                       std::map<string, int>* matched_nodes_map,
+                       std::set<int>* remove_node_indices);
+
+ private:
+  MutableGraphView* graph_view_;               // Not owned.
+  std::map<string, int> node_label_to_index_;  // Pattern label -> node index.
+  std::set<int> matched_node_indices_;         // All matched node indices.
+  std::set<int> remove_node_indices_;          // Matched nodes to be removed.
+  std::unique_ptr<NodeViewMatch> match_ = nullptr;
+
+  bool DoesOpTypePatternMatch(const OpTypePattern& pattern,
+                              MutableNodeView* node_view, NodeViewMatch* match);
+
+  // This function should be called after the pattern matcher has found
+  // potential matched nodes (i.e. when DoesOpTypePatternMatch returns "true").
+  // It performs a sanity check if the candidate nodes for removal in subgraph
+  // fusion is indeed safe to remove.
+  bool HasRemoveNodeExternalDependents() {
+    for (const auto& node_idx : remove_node_indices_) {
+      auto node_view = graph_view_->GetNode(node_idx);
+      // Traverse all the regular fanouts. Fanouts are stored as a nested
+      // vector of MutableFaninView type,
+      // i.e. std::vector<std::vector<MutableFaninView>>. Bind by const
+      // reference; a plain `auto` here would deep-copy that nested vector on
+      // every iteration.
+      const auto& fanouts_by_ports = node_view->GetRegularFanouts();
+      for (const auto& fanouts : fanouts_by_ports) {
+        for (const auto& fanout : fanouts) {
+          // Any consumer outside the matched set is an external dependent.
+          if (!matched_node_indices_.count(fanout.node_index())) {
+            return true;
+          }
+        }
+      }
+    }
+    return false;
+  }
+};
+
+} // namespace utils
+} // namespace grappler
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_UTILS_PATTERN_HELPER_H_
diff --git a/tensorflow/core/grappler/utils/pattern_utils_test.cc b/tensorflow/core/grappler/utils/pattern_utils_test.cc
new file mode 100644
index 0000000..f2ea0b6
--- /dev/null
+++ b/tensorflow/core/grappler/utils/pattern_utils_test.cc
@@ -0,0 +1,475 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/utils/pattern_utils.h"
+#include "tensorflow/core/util/dump_graph.h"
+
+#include "tensorflow/cc/ops/nn_ops_internal.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace utils {
+namespace {
+
+using ::tensorflow::ops::Placeholder;
+
+// Builds a graph: input -> MatMul(weight) -> BiasAdd(bias) -> Gelu, where
+// Gelu is decomposed into smaller ops (Mul, Erf, AddV2, Mul, Mul) following
+//   gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))).
+// When add_external_dependent is true, an extra Identity consumer of bias_add
+// is added so that a node inside the would-be fused subgraph has a fanout
+// outside of it.
+void GetMatMulBiasAddGeluGraph(GraphDef* graph,
+                               bool add_external_dependent = false) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  auto input_shape = ops::Placeholder::Shape({8, 32});
+  auto weight_shape = ops::Placeholder::Shape({32, 64});
+  auto bias_shape = ops::Placeholder::Shape({64});
+
+  auto input = Placeholder(s.WithOpName("input"), DT_FLOAT, input_shape);
+  auto weight = Placeholder(s.WithOpName("weight"), DT_FLOAT, weight_shape);
+  auto bias = Placeholder(s.WithOpName("bias"), DT_FLOAT, bias_shape);
+
+  auto matmul = ops::MatMul(s.WithOpName("matmul"), input, weight);
+  auto bias_add = ops::BiasAdd(s.WithOpName("bias_add"), matmul, bias);
+  if (add_external_dependent) {
+    // Registering the op in the scope is all that is needed; not binding the
+    // result to a local avoids an unused-variable warning.
+    ops::Identity(s.WithOpName("external_dependent"), bias_add);
+  }
+  // Gelu with smaller ops
+  auto one_over_square_root_two =
+      ops::Const(s.WithOpName("one_over_square_root_two"), {0.707f}, {});
+  auto bias_add_times_const = ops::Mul(s.WithOpName("bias_add_times_const"),
+                                       bias_add, one_over_square_root_two);
+  auto erf = ops::Erf(s.WithOpName("erf"), bias_add_times_const);
+  auto one = ops::Const(s.WithOpName("one"), {1.0f}, {});
+  auto erf_plus_one = ops::AddV2(s.WithOpName("erf_plus_one"), erf, one);
+  auto one_half = ops::Const(s.WithOpName("one_half"), {0.5f}, {});
+  auto one_half_times_erf_plus_one = ops::Mul(
+      s.WithOpName("one_half_times_erf_plus_one"), one_half, erf_plus_one);
+  auto gelu =
+      ops::Mul(s.WithOpName("gelu"), one_half_times_erf_plus_one, bias_add);
+  // "fetch" marks the graph output; the returned handle itself is unused.
+  ops::Identity(s.WithOpName("fetch"), gelu);
+
+  TF_ASSERT_OK(s.ToGraphDef(graph));
+}
+
+// Returns the pattern syntax for the decomposed-Gelu subgraph built by
+// GetMatMulBiasAddGeluGraph, rooted at the final Mul node ("gelu"). The label
+// "my_bias_add" appears twice to require both consumers to resolve to the
+// same BiasAdd node.
+OpTypePattern GetMatMulBiasAddGeluPattern() {
+  // Although labels are arbitrary, for the convenience of check they are
+  // prefixed with "my_" to the original node names in the global graph.
+  // clang-format off
+  OpTypePattern pattern_syntax{"Mul", "my_gelu", NodeStatus::kReplace,
+    {
+      {"Mul", "my_one_half_times_erf_plus_one", NodeStatus::kRemove,
+        {
+          {"Const", "my_one_half", NodeStatus::kRemain},
+          {"AddV2", "my_erf_plus_one", NodeStatus::kRemove,
+            {
+              {"Erf", "my_erf", NodeStatus::kRemove,
+                {
+                  {"Mul", "my_bias_add_times_const", NodeStatus::kRemove,
+                    {
+                      {"BiasAdd", "my_bias_add", NodeStatus::kRemove},
+                      {"Const", "my_one_over_square_root_two", NodeStatus::kRemain}
+                    }
+                  }
+                }
+              },
+              {"Const", "my_one", NodeStatus::kRemain}
+            }
+          }
+        }
+      },
+      {"BiasAdd", "my_bias_add", NodeStatus::kRemove,
+        {
+          {"MatMul", "my_matmul", NodeStatus::kRemove},
+          {"*", "my_bias", NodeStatus::kRemain}
+        }
+      }
+    }
+  };  // clang-format on
+
+  return pattern_syntax;
+}
+
+// Test fixture with a helper to build a GraphDef directly from
+// (name, op, inputs) triples, bypassing the C++ ops API.
+class PatternMatcherTest : public ::testing::Test {
+ protected:
+  // Plain description of one graph node: its name, op type, and input names.
+  struct NodeConfig {
+    NodeConfig(string name, string op, std::vector<string> inputs)
+        : name(std::move(name)), op(std::move(op)), inputs(std::move(inputs)) {}
+
+    string name;
+    string op;
+    std::vector<string> inputs;
+  };
+
+  // Assembles a GraphDef containing one NodeDef per NodeConfig, in the given
+  // order.
+  static GraphDef CreateGraph(const std::vector<NodeConfig>& nodes) {
+    GraphDef graph;
+
+    for (const NodeConfig& node : nodes) {
+      NodeDef node_def;
+      node_def.set_name(node.name);
+      node_def.set_op(node.op);
+      for (const string& input : node.inputs) {
+        node_def.add_input(input);
+      }
+      *graph.add_node() = std::move(node_def);
+    }
+
+    return graph;
+  }
+};
+
+TEST_F(PatternMatcherTest, Tree) {
+  // A data flow graph. Data flows from top to bottom. Here A, B, C, D, and E
+  // are ops.
+  //
+  //    Input graph              Subgraph for pattern matcher
+  //
+  //     A                          C   D
+  //     |                           \ /
+  //     B                            E
+  //    /
+  //   C   D
+  //    \ /
+  //     E
+  //
+  // E is the root of pattern syntax as shown below that the pattern matcher
+  // would match.
+  //  {"E", "my_e", NodeStatus::kReplace,
+  //    {
+  //      {"C", "my_c", NodeStatus::kRemove}
+  //      {"D", "my_d", NodeStatus::kRemove}
+  //    }
+  //  }
+
+  ::tensorflow::Status status;
+  GraphDef graph = CreateGraph({{"e", "E", {"c", "d"}},
+                                {"c", "C", {"b"}},
+                                {"d", "D", {}},
+                                {"b", "B", {"a"}},
+                                {"a", "A", {}}});
+  // clang-format off
+  OpTypePattern pattern{"E", "my_e", NodeStatus::kReplace,
+    {
+      {"C", "my_c", NodeStatus::kRemove},
+      {"D", "my_d", NodeStatus::kRemove}
+    }
+  };  // clang-format on
+
+  MutableGraphView graph_view(&graph, &status);
+  TF_ASSERT_OK(status);
+  TF_ASSERT_OK(graph_view.SortTopologically(/*ignore_cycles=*/false, {}));
+  auto root_node_view = graph_view.GetNode("e");
+
+  SubGraphMatcher<MatchingDirection::kFollowInputs> graph_matcher(&graph_view);
+  std::map<string, int> matched_nodes_map;  // label to node index map
+  std::set<int> remove_node_indices;
+  bool found_match = graph_matcher.GetMatchedNodes(
+      pattern, root_node_view, &matched_nodes_map, &remove_node_indices);
+
+  EXPECT_TRUE(found_match);
+  EXPECT_FALSE(matched_nodes_map.empty());
+  EXPECT_FALSE(remove_node_indices.empty());
+
+  // Each matched label must map to the node index of the correspondingly
+  // named graph node. The loop must iterate until map.end(); comparing
+  // against map.begin() would make the body unreachable and the check vacuous.
+  bool all_indices_matched = true;
+  for (auto it = matched_nodes_map.begin(); it != matched_nodes_map.end();
+       it++) {
+    auto label = str_util::StripPrefix(it->first, "my_");
+    int matched_node_idx = it->second;
+    int expected_node_idx = graph_view.GetNode(label)->node_index();
+    if (matched_node_idx != expected_node_idx) {
+      all_indices_matched = false;
+      break;
+    }
+  }
+  EXPECT_TRUE(all_indices_matched);
+}
+
+TEST_F(PatternMatcherTest, DAG) {
+  // A data flow graph. Data flows from top to bottom. Here A, B, C, D, and E
+  // are ops.
+  //
+  //    Input graph              Subgraph for pattern matcher
+  //
+  //     A
+  //     |                          B
+  //     B                         / \
+  //    / \                       C   D
+  //   C   D                       \ /
+  //    \ /                         E
+  //     E
+  //
+  // E is the root of pattern syntax as shown below that the pattern matcher
+  // would match. The label "my_b" occurs twice, constraining both C's and D's
+  // input to resolve to the same node B.
+  //  {"E", "my_e", NodeStatus::kReplace,
+  //    {
+  //      {"C", "my_c", NodeStatus::kRemove,
+  //        {
+  //          {"B", "my_b", NodeStatus::kRemove}
+  //        }
+  //      },
+  //      {"D", "my_d", NodeStatus::kRemove,
+  //        {
+  //          {"B", "my_b", NodeStatus::kRemove}
+  //        }
+  //      }
+  //    }
+  //  }
+
+  ::tensorflow::Status status;
+  GraphDef graph = CreateGraph({{"e", "E", {"c", "d"}},
+                                {"c", "C", {"b"}},
+                                {"d", "D", {"b"}},
+                                {"b", "B", {"a"}},
+                                {"a", "A", {}}});
+  // clang-format off
+  OpTypePattern pattern{"E", "my_e", NodeStatus::kReplace,
+    {
+      {"C", "my_c", NodeStatus::kRemove,
+        {
+          {"B", "my_b", NodeStatus::kRemove}
+        }
+      },
+      {"D", "my_d", NodeStatus::kRemove,
+        {
+          {"B", "my_b", NodeStatus::kRemove}
+        }
+      }
+    }
+  };  // clang-format on
+
+  MutableGraphView graph_view(&graph, &status);
+  TF_ASSERT_OK(status);
+  TF_ASSERT_OK(graph_view.SortTopologically(/*ignore_cycles=*/false, {}));
+  auto root_node_view = graph_view.GetNode("e");
+
+  SubGraphMatcher<MatchingDirection::kFollowInputs> graph_matcher(&graph_view);
+  std::map<string, int> matched_nodes_map;  // label to node index map
+  std::set<int> remove_node_indices;
+  bool found_match = graph_matcher.GetMatchedNodes(
+      pattern, root_node_view, &matched_nodes_map, &remove_node_indices);
+
+  EXPECT_TRUE(found_match);
+  EXPECT_FALSE(matched_nodes_map.empty());
+  EXPECT_FALSE(remove_node_indices.empty());
+
+  // Each matched label must map to the node index of the correspondingly
+  // named graph node. The loop must iterate until map.end(); comparing
+  // against map.begin() would make the body unreachable and the check vacuous.
+  bool all_indices_matched = true;
+  for (auto it = matched_nodes_map.begin(); it != matched_nodes_map.end();
+       it++) {
+    auto label = str_util::StripPrefix(it->first, "my_");
+    int matched_node_idx = it->second;
+    int expected_node_idx = graph_view.GetNode(label)->node_index();
+    if (matched_node_idx != expected_node_idx) {
+      all_indices_matched = false;
+      break;
+    }
+  }
+  EXPECT_TRUE(all_indices_matched);
+}
+
+// Pattern should not be matched if any of candidate remove nodes has external
+// dependent.
+TEST_F(PatternMatcherTest, DAGExternalDependent) {
+  // A Data flow graph. Data flows from top to bottom. Here A, B, C, D, E, and F
+  // are ops.
+  //
+  //    Input graph              Subgraph for pattern matcher
+  //
+  //     A
+  //     |                          B
+  //     B                         / \
+  //    / \                       C   D
+  //   C   D                       \ /
+  //    \ / \                       E
+  //     E   F
+  //
+  // E is the root of pattern syntax as shown below that the pattern matcher
+  // would match. Note D is a candidate for remove node as mentioned in the
+  // syntax, but D also feeds F which is outside the subgraph. So the pattern
+  // matcher should not find a match.
+  //  {"E", "my_e", NodeStatus::kReplace,
+  //    {
+  //      {"C", "my_c", NodeStatus::kRemove,
+  //        {
+  //          {"B", "my_b", NodeStatus::kRemove}
+  //        }
+  //      },
+  //      {"D", "my_d", NodeStatus::kRemove,
+  //        {
+  //          {"B", "my_b", NodeStatus::kRemove}
+  //        }
+  //      }
+  //    }
+  //  }
+
+  ::tensorflow::Status status;
+  GraphDef graph = CreateGraph({{"f", "F", {"d"}},
+                                {"e", "E", {"c", "d"}},
+                                {"c", "C", {"b"}},
+                                {"d", "D", {"b"}},
+                                {"b", "B", {"a"}},
+                                {"a", "A", {}}});
+  // clang-format off
+  OpTypePattern pattern{"E", "my_e", NodeStatus::kReplace,
+    {
+      {"C", "my_c", NodeStatus::kRemove,
+        {
+          {"B", "my_b", NodeStatus::kRemove}
+        }
+      },
+      {"D", "my_d", NodeStatus::kRemove,
+        {
+          {"B", "my_b", NodeStatus::kRemove}
+        }
+      }
+    }
+  };  // clang-format on
+
+  MutableGraphView graph_view(&graph, &status);
+  TF_ASSERT_OK(status);
+  graph_view.SortTopologically(/*ignore_cycles=*/false, {});
+  auto root_node_view = graph_view.GetNode("e");
+
+  SubGraphMatcher<MatchingDirection::kFollowInputs> graph_matcher(&graph_view);
+  std::map<string, int> matched_nodes_map;  // label to node index map
+  std::set<int> remove_node_indices;
+  bool found_match = graph_matcher.GetMatchedNodes(
+      pattern, root_node_view, &matched_nodes_map, &remove_node_indices);
+
+  // Rejected matches must leave the output containers untouched.
+  EXPECT_FALSE(found_match);
+  EXPECT_TRUE(matched_nodes_map.empty());
+  EXPECT_TRUE(remove_node_indices.empty());
+}
+
+// Matches the decomposed-Gelu pattern against the graph built by
+// GetMatMulBiasAddGeluGraph and verifies the label -> node-index mapping.
+TEST_F(PatternMatcherTest, MatMulBiasAddGelu) {
+  ::tensorflow::Status status;
+  GraphDef graph;
+  GetMatMulBiasAddGeluGraph(&graph);
+  OpTypePattern pattern = GetMatMulBiasAddGeluPattern();
+  MutableGraphView graph_view(&graph, &status);
+  TF_ASSERT_OK(status);
+  TF_ASSERT_OK(graph_view.SortTopologically(/*ignore_cycles=*/false, {}));
+  auto root_node_view = graph_view.GetNode("gelu");
+
+  SubGraphMatcher<MatchingDirection::kFollowInputs> graph_matcher(&graph_view);
+  std::map<string, int> matched_nodes_map;  // label to node index map
+  std::set<int> remove_node_indices;
+  bool found_match = graph_matcher.GetMatchedNodes(
+      pattern, root_node_view, &matched_nodes_map, &remove_node_indices);
+
+  EXPECT_TRUE(found_match);
+  EXPECT_FALSE(matched_nodes_map.empty());
+  EXPECT_FALSE(remove_node_indices.empty());
+
+  // Each matched label must map to the node index of the correspondingly
+  // named graph node. The loop must iterate until map.end(); comparing
+  // against map.begin() would make the body unreachable and the check vacuous.
+  bool all_indices_matched = true;
+  for (auto it = matched_nodes_map.begin(); it != matched_nodes_map.end();
+       it++) {
+    auto label = str_util::StripPrefix(it->first, "my_");
+    int matched_node_idx = it->second;
+    int expected_node_idx = graph_view.GetNode(label)->node_index();
+    if (matched_node_idx != expected_node_idx) {
+      all_indices_matched = false;
+      break;
+    }
+  }
+  EXPECT_TRUE(all_indices_matched);
+}
+
+// Pattern should not be matched if any of candidate remove nodes has external
+// dependent. Here bias_add (a kRemove candidate) gains an extra Identity
+// consumer outside the subgraph, so the match must be rejected.
+TEST_F(PatternMatcherTest, MatMulBiasAddGeluExternalDependent) {
+  ::tensorflow::Status status;
+  GraphDef graph;
+  GetMatMulBiasAddGeluGraph(&graph, /*add_external_dependent=*/true);
+  OpTypePattern pattern = GetMatMulBiasAddGeluPattern();
+  MutableGraphView graph_view(&graph, &status);
+  TF_ASSERT_OK(status);
+  graph_view.SortTopologically(/*ignore_cycles=*/false, {});
+  auto root_node_view = graph_view.GetNode("gelu");
+
+  SubGraphMatcher<MatchingDirection::kFollowInputs> graph_matcher(&graph_view);
+  std::map<string, int> matched_nodes_map;  // label to node index map
+  std::set<int> remove_node_indices;
+  bool found_match = graph_matcher.GetMatchedNodes(
+      pattern, root_node_view, &matched_nodes_map, &remove_node_indices);
+
+  // Rejected matches must leave the output containers untouched.
+  EXPECT_FALSE(found_match);
+  EXPECT_TRUE(matched_nodes_map.empty());
+  EXPECT_TRUE(remove_node_indices.empty());
+}
+
+// End-to-end: match the Gelu pattern, then mutate the graph by adding a fused
+// op and removing the kRemove-marked nodes, and verify the resulting graph.
+TEST_F(PatternMatcherTest, MatMulBiasAddGeluMutation) {
+  ::tensorflow::Status status;
+  GraphDef graph;
+  GetMatMulBiasAddGeluGraph(&graph);
+  OpTypePattern pattern = GetMatMulBiasAddGeluPattern();
+  MutableGraphView graph_view(&graph, &status);
+  TF_ASSERT_OK(status);
+  TF_ASSERT_OK(graph_view.SortTopologically(/*ignore_cycles=*/false, {}));
+  auto root_node_view = graph_view.GetNode("gelu");
+
+  SubGraphMatcher<MatchingDirection::kFollowInputs> graph_matcher(&graph_view);
+  std::map<string, int> matched_nodes_map;  // label to node index map
+  std::set<int> remove_node_indices;
+  bool found_match = graph_matcher.GetMatchedNodes(
+      pattern, root_node_view, &matched_nodes_map, &remove_node_indices);
+  EXPECT_TRUE(found_match);
+  EXPECT_FALSE(matched_nodes_map.empty());
+  EXPECT_FALSE(remove_node_indices.empty());
+
+  // Before mutation number of nodes.
+  int num_nodes_before = graph_view.NumNodes();
+  // Before mutation node_names of the remove candidate nodes.
+  std::vector<string> remove_node_names;
+  for (auto const& node_idx : remove_node_indices) {
+    remove_node_names.push_back(graph_view.GetNode(node_idx)->GetName());
+  }
+
+  Mutation* mutation = graph_view.GetMutationBuilder();
+  // Replace the matched subgraph with a fused op that reuses the root's name,
+  // so existing consumers of "gelu" keep resolving.
+  NodeDef fused_node;
+  fused_node.set_name("gelu");
+  fused_node.set_op("_FusedMatMul");
+  fused_node.add_input(graph_view.GetNode("matmul")->node()->input(0));
+  fused_node.add_input(graph_view.GetNode("matmul")->node()->input(1));
+  fused_node.add_input(graph_view.GetNode("bias_add")->node()->input(1));
+  mutation->AddNode(std::move(fused_node), &status);
+  TF_ASSERT_OK(status);
+  // Apply returns a Status that the original code silently discarded; a
+  // failed mutation here would make the remaining checks meaningless.
+  TF_ASSERT_OK(mutation->Apply());
+  // Remove nodes that are marked as NodeStatus::kRemove.
+  for (auto const& node_idx : remove_node_indices) {
+    mutation->RemoveNode(graph_view.GetNode(node_idx));
+  }
+  TF_ASSERT_OK(mutation->Apply());
+
+  // After mutation number of nodes.
+  int num_nodes_after = graph_view.NumNodes();
+  EXPECT_EQ(num_nodes_before - remove_node_indices.size(), num_nodes_after);
+
+  // All remove-candidate nodes must be gone from the graph view.
+  bool remove_nodes_deleted = true;
+  for (auto const& node_name : remove_node_names) {
+    if (graph_view.GetNode(node_name) != nullptr) {
+      remove_nodes_deleted = false;
+      break;
+    }
+  }
+  EXPECT_TRUE(remove_nodes_deleted);
+
+  // The replacement (fused) node must still exist under the root's name.
+  bool replace_node_exist = graph_view.HasNode("gelu");
+  EXPECT_TRUE(replace_node_exist);
+}
+
+} // namespace
+} // namespace utils
+} // namespace grappler
+} // namespace tensorflow