| #include <torch/csrc/jit/subgraph_matcher.h> |
| #include <stack> |
| |
| namespace torch { |
| namespace jit { |
| namespace { |
| |
| /** |
| * \brief A class implementing an API for comparing subgraphs. |
| */ |
| class SubgraphMatcher { |
| public: |
| explicit SubgraphMatcher(const Graph& pattern) : pattern_(pattern) {} |
| |
| /** |
| * \brief Compare matchGraph with the part of the graph denoted by a node \p |
| * ANCHOR. |
| * |
| * The anchor node would be compared against the deepest node in the |
| * match-graph. A node is considered matching if its number of inputs/outputs |
| * is the same as in the corresponding matchGraph node, its type is the same, |
| * and all nodes producing input-values also match. |
| */ |
| bool matchesSubgraphFromAnchorNode(Node* anchor); |
| |
| /** \brief Return match map for nodes. */ |
| std::unordered_map<const Node*, Node*> nodes_map() const { |
| return nodes_map_; |
| } |
| |
| /** \brief Return match map for values. */ |
| std::unordered_map<const Value*, Value*> values_map() const { |
| return values_map_; |
| } |
| |
| private: |
| bool matchValues(const Value* v1, Value* v2); |
| bool matchNodes(const Node* n1, Node* n2); |
| |
| std::unordered_map<const Node*, Node*> nodes_map_; |
| std::unordered_map<const Value*, Value*> values_map_; |
| |
| const Graph& pattern_; |
| const Node* anchor_ = nullptr; |
| }; |
| |
| /** |
| * \brief A function to verify that \p PATTERN is valid. Concrete requirements |
| * for validity can be found in subgraph_matcher.h. |
| */ |
| bool patternGraphIsValid(const Graph& pattern) { |
| // Verify that pattern graph has a single block. |
| for (const Node* n : pattern.nodes()) { |
| if (!n->blocks().empty()) { |
| return false; |
| } |
| } |
| |
| // Verify that pattern graph returns only one value. |
| const Node* bottom_node = *(pattern.nodes().end()); |
| if (bottom_node->inputs().size() != 1) { |
| return false; |
| } |
| |
| // TODO: Verify that nodes in the pattern don't alias. |
| return true; |
| } |
| |
| /** |
| * Compare two Values. V1 is from pattern, V2 is from the actual graph. |
| * |
| * The values are considered matching if: |
| * 1) the nodes defining them match |
| * 2) they have the same number of uses, except they are entry or exit nodes. |
| */ |
| bool SubgraphMatcher::matchValues(const Value* v1, Value* v2) { |
| // Check if we've already visited these values. |
| if (values_map_.count(v1)) { |
| return values_map_.at(v1) == v2; |
| } |
| |
| // When V2 is ANCHOR, we're comparing exiting values, and when V1->node is |
| // PARAM, we're comparing entering values - in these two cases the number of |
| // uses don't need to be the same. |
| if (v1->uses().size() != v2->uses().size() && v2->node() != anchor_ && |
| v1->node()->kind() != prim::Param) { |
| return false; |
| } |
| |
| // Add the values to the map before calling matchNodes to avoid infinite |
| // recursion. |
| values_map_[v1] = v2; |
| return matchNodes(v1->node(), v2->node()); |
| } |
| |
| /** |
| * Compare two Nodes. N1 is from pattern, N2 is from the actual graph. |
| * |
| * The nodes are considered matching if: |
| * 1) N1 and N2 are of the same kind. |
| * 2) Number of inputs and outputs is the same. |
| * 3) All input and output values match. |
| * |
| * A special case is when N1 is PARAM - this is considered outside the pattern, |
| * so it matches everything. |
| */ |
| bool SubgraphMatcher::matchNodes(const Node* n1, Node* n2) { |
| // Check if we've already visited these nodes. |
| if (nodes_map_.count(n1)) { |
| return nodes_map_.at(n1) == n2; |
| } |
| |
| // Param node in pattern graph matches everything. |
| if (n1->kind() == prim::Param) { |
| return true; |
| } |
| |
| // We don't allow matches to span across blocks, so check if N2 is in the same |
| // block as the first (anchor) node. |
| if (n2->owningBlock() != anchor_->owningBlock()) { |
| return false; |
| } |
| |
| if (n1->kind() != n2->kind() || |
| n1->outputs().size() != n2->outputs().size() || |
| n1->inputs().size() != n2->inputs().size() || |
| n1->numAttributes() != n2->numAttributes()) { |
| return false; |
| } |
| |
| // Add nodes to the map before calling matchValues to avoid infinite |
| // recursion. |
| nodes_map_[n1] = n2; |
| for (size_t i = 0; i < n1->outputs().size(); i++) { |
| if (!matchValues(n1->outputs()[i], n2->outputs()[i])) { |
| return false; |
| } |
| } |
| for (size_t i = 0; i < n1->inputs().size(); i++) { |
| if (!matchValues(n1->inputs()[i], n2->inputs()[i])) { |
| return false; |
| } |
| } |
| for (const Symbol& attr_name : n1->attributeNames()) { |
| if (n1->kindOf(attr_name) != n2->kindOf(attr_name)) { |
| return false; |
| } |
| switch (n1->kindOf(attr_name)) { |
| case AttributeKind::s: |
| if (n1->s(attr_name) != n2->s(attr_name)) { |
| return false; |
| } |
| break; |
| case AttributeKind::f: |
| if (n1->f(attr_name) != n2->f(attr_name)) { |
| return false; |
| } |
| break; |
| case AttributeKind::i: |
| if (n1->i(attr_name) != n2->i(attr_name)) { |
| return false; |
| } |
| break; |
| default: |
| // Other attributes types not supported yet |
| return false; |
| } |
| } |
| |
| return true; |
| } |
| |
| /** |
| * Recursively try to match pattern with the actual graph starting from the |
| * exiting node in the pattern and anchor node in the actual graph. |
| */ |
| bool SubgraphMatcher::matchesSubgraphFromAnchorNode(Node* anchor) { |
| nodes_map_.clear(); |
| values_map_.clear(); |
| anchor_ = anchor; |
| |
| const Node* bottom_node = *(pattern_.nodes().end()); |
| AT_ASSERT(bottom_node->inputs().size() == 1); |
| bottom_node = bottom_node->input()->node(); |
| |
| if (!matchNodes(bottom_node, anchor)) { |
| return false; |
| } |
| |
| return true; |
| } |
| |
| } // unnamed namespace |
| |
| // Main entry point for the subgraph matching. |
| std::vector<Match> findPatternMatches(const Graph& pattern, Graph& graph) { |
| AT_ASSERT(patternGraphIsValid(pattern)); |
| |
| SubgraphMatcher m(pattern); |
| std::vector<Match> matches; |
| std::stack<Block*> blocks_to_visit; |
| |
| // Iterate over all nodes in the graph (including nodes in subblocks) trying |
| // to match the pattern each node. |
| blocks_to_visit.push(graph.block()); |
| while (!blocks_to_visit.empty()) { |
| Block* block = blocks_to_visit.top(); |
| blocks_to_visit.pop(); |
| for (Node* n : block->nodes()) { |
| if (m.matchesSubgraphFromAnchorNode(n)) { |
| matches.push_back({n, m.nodes_map(), m.values_map()}); |
| } |
| for (Block* subblock : n->blocks()) { |
| blocks_to_visit.push(subblock); |
| } |
| } |
| } |
| return matches; |
| } |
| |
| } // namespace jit |
| } // namespace torch |