blob: 9d83ec79ee722adc3af3a7315ce52dbf6e942c66 [file] [log] [blame]
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_PATTERN_HELPER_H_
#define TENSORFLOW_CORE_GRAPPLER_UTILS_PATTERN_HELPER_H_
#include "tensorflow/core/grappler/utils/graph_view.h"
namespace tensorflow {
namespace grappler {
namespace utils {
//------------------------------------------------------------------------------
// A pattern can be defined by the following grammar. Here, op_type is any valid
// op name in TensorFlow.
//
// leaf_pattern ::= `{` op_type `}`
// pattern ::= leaf_pattern |
// `{` op_type `,` `{` pattern `,` ... `,` pattern `}` `}`
//
// (1) For example, the following pattern syntax describes a pattern for
// _FusedConv2D (Conv2D + BiasAdd + Relu). Note that "*" means any type of op.
//
// {"Relu",
// {
// "BiasAdd",
// {
// {"Conv2D"},
// {"*"}
// }
// }
// }
//
// The syntax above has a root ("Relu") and children (inputs), where each child
// is a sub-pattern. Graph pattern matcher finds a match for the given pattern
// syntax in a graph and returns a set of matched nodes.
//
// (2) In order to match a DAG with a given root, we extend pattern syntax with
// labels. For example, a frequently found pattern in Deep Learning models is a
// residual block like below.
//
// Placeholder Const
// | |
// +-----+-----+ |
// | | |
// | v v
// | Conv2D Const
// | | |
// | v v-----+
// | BiasAdd
// | |
// v v----------+
// AddV2
//
// As shown above, the same input node (Placeholder) is consumed by both
// AddV2 and Conv2D. This constraint can be expressed with labels in the
// following augmented pattern syntax.
//
// {"AddV2", "my_add",
// {
// {"*", "my_residual_input"},
// {"BiasAdd", "my_bias_add",
// {
// {"Conv2D", "my_conv",
// {
// {"*", "my_residual_input"},
// {"*", "my_filter"}
// }
// },
// {"*", "my_bias"}
// }
// }
// }
// }
//
// Note that the same label "my_residual_input" is used to tell that it is a
// child of both "AddV2" and "Conv2D". Labels are arbitrary strings to associate
// with the nodes to be matched as well as to uniquely identify those nodes.
//
// (3) The motivation for grammar-based pattern matching in grappler is to
// make it easy to find fusion patterns in the remapper. A subgraph that
// matches a given pattern, however, is not fusable if any of the matched nodes
// that will be removed as part of the fusion has a consumer outside the
// matched subgraph. In order to check for such external dependencies, we
// further extend the pattern syntax with a prospective action (NodeStatus) on
// the matched nodes, as shown below. This helps cross-check the nodes to be
// removed against the nodes matched initially.
//
// {"AddV2", "my_add", NodeStatus::kReplace,
// {
// {"*", "my_residual_input", NodeStatus::kRemain},
// {"BiasAdd", "my_bias_add", NodeStatus::kRemove,
// {
// {"Conv2D", "my_conv", NodeStatus::kRemove,
// {
// {"*", "my_residual_input", NodeStatus::kRemain},
// {"*", "my_filter", NodeStatus::kRemain}
// }
// },
// {"*", "my_bias", NodeStatus::kRemain}
// }
// }
// }
// }
//------------------------------------------------------------------------------
// Direction in which the pattern matcher recursively descends from a matched
// node to the nodes its child sub-patterns are matched against.
enum class MatchingDirection {
  kFollowInputs,   // children are matched against the node's fanins (inputs)
  kFollowOutputs,  // children are matched against the node's fanouts (outputs)
};
// Prospective action attached to each node in the set of nodes matched by a
// pattern.
enum class NodeStatus {
  kRemain,   // node is kept as-is
  kRemove,   // node is to be removed as part of the rewrite (e.g. a fusion)
  kReplace,  // node is to be replaced (presumably by the fused node)
};
// TODO(intel-tf): Support multiple roots by making them children of a single
// virtual root.
//
// One node of a pattern tree: the op type to match, a label identifying the
// matched node, the prospective action on that node, and the sub-patterns
// that the node's children must match.
struct OpTypePattern {
  // Op type to match; "*" matches any op.
  string op;
  // Label associated with the matched node. Reusing the same label in several
  // places of a pattern requires those places to match the same node.
  string label;
  // Prospective action for the matched node. Defaulted so that a
  // default-constructed pattern never holds an indeterminate value; kRemain is
  // also what aggregate initialization yields when this member is omitted.
  NodeStatus node_status = NodeStatus::kRemain;
  // Sub-patterns for the node's children (its inputs or outputs, depending on
  // the matcher's MatchingDirection).
  std::vector<OpTypePattern> children;

  // Returns a human-readable rendering of the pattern tree of the form
  // "{(op: <op>, label: <label>), {<child>,<child>,...}}".
  // node_status is not included in the output.
  string DebugString() const {
    // Append in place rather than chaining operator+ to avoid temporaries.
    string result = "{(op: ";
    result += op;
    result += ", label: ";
    result += label;
    result += "), {";
    for (const OpTypePattern& child : children) {
      result += child.DebugString();
      result += ",";
    }
    result += "}}";
    return result;
  }
};
// Recursive record that mirrors the pattern tree one-to-one: each entry holds
// the node matched for the corresponding sub-pattern (nullptr if nothing has
// been matched there) together with the entries for its children. Calling
// DebugString() shows what has been matched so far and, through the
// "Non-Matched-Node" marker, where matching failed.
struct NodeViewMatch {
  MutableNodeView* node_view = nullptr;
  std::vector<NodeViewMatch> children;

  // Renders the match tree as "{<node debug string>, {<child>,<child>,...}}",
  // or "{Non-Matched-Node}" when no node is recorded at this position.
  string DebugString() const {
    if (node_view == nullptr) return "{Non-Matched-Node}";
    string text = "{";
    text += node_view->node()->DebugString();
    text += ", {";
    for (const auto& sub_match : children) {
      text += sub_match.DebugString();
      text += ",";
    }
    text += "}}";
    return text;
  }

  // Resets the record: clears every nested child, drops the children, and
  // forgets the matched node.
  void Clear() {
    for (auto& sub_match : children) sub_match.Clear();
    children.clear();
    node_view = nullptr;
  }
};
// Matches an OpTypePattern against the graph held by a MutableGraphView.
// DIRECTION selects whether child sub-patterns are matched against a node's
// inputs (the default) or its outputs.
template <MatchingDirection DIRECTION = MatchingDirection::kFollowInputs>
class SubGraphMatcher {
 public:
  // `graph_view` is not owned by the matcher; it is dereferenced by the
  // matcher's methods, so it must outlive this object.
  SubGraphMatcher(MutableGraphView* graph_view) : graph_view_(graph_view) {}

  // If `pattern` is matched starting at `node_view`, returns true and
  // populates `matched_nodes_map` and `remove_node_indices` with the matched
  // nodes and the nodes to be removed (presumably label -> node index and
  // node indices respectively; see the out-of-line definition).
  bool GetMatchedNodes(const OpTypePattern& pattern, MutableNodeView* node_view,
                       std::map<string, int>* matched_nodes_map,
                       std::set<int>* remove_node_indices);

 private:
  MutableGraphView* graph_view_;
  std::map<string, int> node_label_to_index_;
  std::set<int> matched_node_indices_;
  std::set<int> remove_node_indices_;
  std::unique_ptr<NodeViewMatch> match_ = nullptr;

  bool DoesOpTypePatternMatch(const OpTypePattern& pattern,
                              MutableNodeView* node_view, NodeViewMatch* match);

  // Should be called after the pattern matcher has found potential matched
  // nodes (i.e. when DoesOpTypePatternMatch returns true). Returns true if any
  // node scheduled for removal (remove_node_indices_) has a regular fanout
  // outside the matched nodes (matched_node_indices_). Removing such a node
  // would break an external consumer, so the fusion is not safe.
  // NOTE(review): only regular fanouts are inspected; controlled (control
  // dependency) fanouts of to-be-removed nodes are not checked here.
  bool HasRemoveNodeExternalDependents() {
    for (const int node_idx : remove_node_indices_) {
      auto* node_view = graph_view_->GetNode(node_idx);
      // Regular fanouts are grouped by output port as a nested vector of
      // MutableFaninView. Bind by reference so the nested vector is not
      // copied for every node scheduled for removal.
      const auto& fanouts_by_port = node_view->GetRegularFanouts();
      for (const auto& fanouts : fanouts_by_port) {
        for (const auto& fanout : fanouts) {
          if (matched_node_indices_.count(fanout.node_index()) == 0) {
            return true;
          }
        }
      }
    }
    return false;
  }
};
} // namespace utils
} // namespace grappler
} // namespace tensorflow
#endif // TENSORFLOW_CORE_GRAPPLER_UTILS_PATTERN_HELPER_H_