| #include "graphmatcher.h" | 
 | #include "ast.h" | 
 | #include "nomnigraph/Graph/Algorithms.h" | 
 |  | 
 | #include <mutex> | 
 |  | 
 | static std::mutex mtx_; | 
 |  | 
 | namespace nom { | 
 | namespace nql { | 
 | using namespace nom::repr; | 
 |  | 
 | NNGraph::NodeRef MatchedSubgraph::operator[](const std::string& key) const { | 
 |   auto search = matchMap.find(key); | 
 |   CAFFE_ENFORCE( | 
 |       search != matchMap.end(), "Could not find key in map of matches:", key); | 
 |   return search->second; | 
 | } | 
 |  | 
 | TestMatchGraph::NodeRef GraphMatcher::genMatcherFromASTExpr( | 
 |     ASTExpr* expr, | 
 |     bool insertTemp = false) { | 
 |   if (!expr->isCall()) { | 
 |     if (expr->starInputs()) { | 
 |       return matchGraph_.createNode(std::move( | 
 |           testMatchPredicate(Criteria("*")).starCount().nonTerminal())); | 
 |     } | 
 |     if (!varMap_.count(expr->name)) { | 
 |       varMap_[expr->name] = matchGraph_.createNode( | 
 |           std::move(testMatchPredicate(Criteria("*")).nonTerminal())); | 
 |     } | 
 |     return varMap_[expr->name]; | 
 |   } | 
 |  | 
 |   std::vector<TestMatchGraph::NodeRef> children; | 
 |   for (auto child : expr->children) { | 
 |     children.push_back(genMatcherFromASTExpr(child, true)); | 
 |   } | 
 |  | 
 |   auto res = matchGraph_.createNode(testMatchPredicate(Criteria(expr->name))); | 
 |   callMap_[expr->name] = res; | 
 |   for (auto child : children) { | 
 |     matchGraph_.createEdge(child, res); | 
 |   } | 
 |  | 
 |   if (insertTemp) { | 
 |     auto temp = matchGraph_.createNode(testMatchPredicate(Criteria("*"))); | 
 |     matchGraph_.createEdge(res, temp); | 
 |     res = temp; | 
 |   } | 
 |  | 
 |   return res; | 
 | } | 
 |  | 
 | TestMatchGraph::NodeRef GraphMatcher::genMatcherFromASTStmt(ASTStmt* stmt) { | 
 |   auto right = genMatcherFromASTExpr(stmt->rhs); | 
 |   auto res = right; | 
 |   /* For cases like | 
 |    %x, %y = Foo(%z) | 
 |    for now we just say that both %x and %y are defined by node Foo, we don't | 
 |    distinguish them (i.e. we don't keep any information about their order. */ | 
 |   for (auto v : stmt->lhs) { | 
 |     res = matchGraph_.createNode(testMatchPredicate(Criteria("*"))); | 
 |     matchGraph_.createEdge(right, res); | 
 |     varMap_[v] = res; | 
 |   } | 
 |   return res; | 
 | } | 
 |  | 
 | void deallocTokenStrings() { | 
 |   for (auto p : tokens) { | 
 |     delete (std::string*)p; | 
 |   } | 
 |   tokens.clear(); | 
 |  | 
 |   for (auto p : tokenVectors) { | 
 |     delete (std::vector<void*>*)p; | 
 |   } | 
 |   tokenVectors.clear(); | 
 | } | 
 |  | 
 | TestMatchGraph::NodeRef GraphMatcher::genMatcherFromASTGraph(ASTGraph* ast) { | 
 |   matchGraph_ = TestMatchGraph(); | 
 |   // TODO: Cleanup this. | 
 |   TestMatchGraph::NodeRef last = nullptr; | 
 |   if (ast->stmts.empty()) { | 
 |     syntaxIsValid_ = false; // Temporary solution, which works because we don't | 
 |                             // allow empty graphs. | 
 |   } | 
 |  | 
 |   for (auto stmt : ast->stmts) { | 
 |     auto r = genMatcherFromASTStmt(stmt); | 
 |     if (r) { | 
 |       last = r; | 
 |     } | 
 |   } | 
 |  | 
 |   return last; | 
 | } | 
 |  | 
 | TestMatchGraph::NodeRef GraphMatcher::genMatcherFromIRFile(const char* fname) { | 
 |   std::lock_guard<std::mutex> lock(mtx_); | 
 |   ASTGraph g; | 
 |   parseFile(fname, &g); | 
 |   matchGraphRootNode_ = genMatcherFromASTGraph(&g); | 
 |   deallocTokenStrings(); | 
 |   return matchGraphRootNode_; | 
 | } | 
 |  | 
 | TestMatchGraph::NodeRef GraphMatcher::genMatcherFromIRStr(const char* str) { | 
 |   std::lock_guard<std::mutex> lock(mtx_); | 
 |   ASTGraph g; | 
 |   parseString(str, &g); | 
 |   matchGraphRootNode_ = genMatcherFromASTGraph(&g); | 
 |   deallocTokenStrings(); | 
 |   return matchGraphRootNode_; | 
 | } | 
 |  | 
 | TestMatchPredicate testMatchPredicate(const Criteria& criteria) { | 
 |   auto predicate = | 
 |       TestMatchPredicate([criteria](nom::repr::NNGraph::NodeRef nodeRef) { | 
 |         std::string nodeLabel = getNodeName(nodeRef); | 
 |         return (criteria == "*" || criteria == nodeLabel); | 
 |       }); | 
 |   predicate.setDebugString(criteria); | 
 |   return predicate; | 
 | } | 
 |  | 
 | // Helper function for convertToNQLString function. | 
 | // Given a node and a renameMap return the unique name for this node. | 
 | static std::string getNameForBlob( | 
 |     NNGraph::NodeRef node, | 
 |     const std::unordered_map<NNGraph::NodeRef, std::string>& renameMap) { | 
 |   if (renameMap.count(node)) { | 
 |     return renameMap.at(node); | 
 |   } | 
 |   return getNodeName(node); | 
 | } | 
 |  | 
 | // Helper function for convertToNQLString function. | 
 | // Given a node and a renameMap return a string representing the node, which | 
 | // looks something like: | 
 | //   %a = Op(%b, %c, %d) | 
 | static const std::string getNQLStringForBlob( | 
 |     NNGraph::NodeRef node, | 
 |     const std::unordered_map<NNGraph::NodeRef, std::string>& renameMap) { | 
 |   if (!nn::is<Data>(node) || !nn::hasProducer(node)) { | 
 |     return ""; | 
 |   } | 
 |   NNGraph::NodeRef defOp = nn::getProducer(node); | 
 |  | 
 |   std::string result = | 
 |       getNameForBlob(node, renameMap) + " = " + getNodeName(defOp) + "("; | 
 |   int i = 0; | 
 |   for (auto inputTensor : nn::getInputs(defOp)) { | 
 |     if (i) { | 
 |       result += ", "; | 
 |     } | 
 |     result += getNameForBlob(inputTensor, renameMap); | 
 |     i++; | 
 |   } | 
 |   result += ")"; | 
 |   return result; | 
 | } | 
 |  | 
 | // Helper function for convertToNQLString function. | 
 | // It takes a list of nodes and returns a map node->unique_name. The new names | 
 | // are based on the existing ones, but are also unique. | 
 | static std::unordered_map<NNGraph::NodeRef, std::string> computeDedupRenameMap( | 
 |     const std::vector<NNGraph::NodeRef>& nodes) { | 
 |   std::unordered_map<NNGraph::NodeRef, std::string> renameMap; | 
 |   std::unordered_set<std::string> takenNames; | 
 |   takenNames.clear(); | 
 |   for (auto node : nodes) { | 
 |     std::string name = getNodeName(node); | 
 |     if (!isa<Data>(node->data())) { | 
 |       continue; | 
 |     } | 
 |     std::string newName = name; | 
 |     int dedupCounter = 0; | 
 |     while (takenNames.count(newName)) { | 
 |       newName = name + "_" + caffe2::to_string(dedupCounter); | 
 |       dedupCounter++; | 
 |     } | 
 |     renameMap[node] = newName; | 
 |     takenNames.insert(newName); | 
 |   } | 
 |   return renameMap; | 
 | } | 
 |  | 
 | std::vector<MatchedSubgraph> GraphMatcher::getMatches( | 
 |     nom::repr::NNGraph& df) const { | 
 |   std::vector<MatchedSubgraph> matches; | 
 |   if (!syntaxIsValid_) { | 
 |     return matches; | 
 |   } | 
 |   // Attempt to match at each node | 
 |   for (const auto& node : df.getMutableNodes()) { | 
 |     auto match = matchGraph_.isSubgraphMatch(node, matchGraphRootNode_, true); | 
 |     if (match.isMatch()) { | 
 |       MatchedSubgraph ms; | 
 |       ms.subgraph = *match.getMatchedSubgraph(); | 
 |       // This is a map from the the internal TestMatchGraph to the nodes in the | 
 |       // NNGraph | 
 |       auto match_graph_map = match.getMatchNodeMap(); | 
 |       // We iterate through the "varMap_" map (string -> | 
 |       // TestMatchGraph::NodeRef) to generate string -> NNGraph::NodeRef | 
 |       for (auto p : varMap_) { | 
 |         auto iter = match_graph_map->find(p.second); | 
 |         if (iter != match_graph_map->end()) { | 
 |           ms.matchMap[p.first] = iter->second; | 
 |         } | 
 |       } | 
 |       for (auto p : callMap_) { | 
 |         auto iter = match_graph_map->find(p.second); | 
 |         if (iter != match_graph_map->end()) { | 
 |           ms.matchMap[p.first] = iter->second; | 
 |         } | 
 |       } | 
 |       matches.emplace_back(ms); | 
 |     } | 
 |   } | 
 |   return matches; | 
 | } | 
 |  | 
 | // \brief Return a short string name for the given \param node. | 
 | // The function works with both tensors and operators. | 
 | std::string getNodeName(const NNGraph::NodeRef node) { | 
 |   if (!node) { | 
 |     return ""; | 
 |   } | 
 |   if (nn::is<NeuralNetOperator>(node)) { | 
 |     if (auto* op = nn::get<NeuralNetOperator>(node)) { | 
 |       return op->getName(); | 
 |     } | 
 |   } | 
 |   if (nn::is<NeuralNetData>(node)) { | 
 |     if (auto tensor = nn::get<NeuralNetData>(node)) { | 
 |       return "%" + tensor->getName(); | 
 |     } | 
 |   } | 
 |   return ""; | 
 | } | 
 |  | 
 | // \brief Return a string representing the given graph \param g. | 
 | // The returned string is a valid NQL query. | 
 | std::string convertToNQLString(NNGraph& g) { | 
 |   // Order nodes in a topological order. | 
 |   // TODO: Currently tarjans mutates the graph, and that's the only reason we | 
 |   // are not using const reference for `g`. We need to fix tarjans so that it | 
 |   // doesn't mutate the graph and use const reference in this function too. | 
 |   auto topoMatch = nom::algorithm::tarjans(&g); | 
 |   std::vector<NNGraph::NodeRef> nodes; | 
 |   int sccNum = 0; | 
 |   for (auto scc : topoMatch) { | 
 |     sccNum++; | 
 |     for (auto node : scc.getNodes()) { | 
 |       nodes.emplace_back(node); | 
 |     } | 
 |   } | 
 |   std::reverse(nodes.begin(), nodes.end()); | 
 |  | 
 |   // Different nodes might have the same name. We want to change that so that | 
 |   // they are distinguishable by the name. NQL assumes that names are unique. | 
 |   std::unordered_map<NNGraph::NodeRef, std::string> renameMap = | 
 |       computeDedupRenameMap(nodes); | 
 |  | 
 |   // Going from top to bottom (nodes are in topological order), print all | 
 |   // nodes. | 
 |   std::string result = "def nn {\n"; | 
 |   for (auto node : nodes) { | 
 |     std::string r = getNQLStringForBlob(node, renameMap); | 
 |     if (!r.empty()) { | 
 |       result += "  " + r + "\n"; | 
 |     } | 
 |   } | 
 |   result += "}\n"; | 
 |   return result; | 
 | } | 
 | }; // namespace nql | 
 | }; // namespace nom |