// blob: 2051d45df41751d41cf5bf2c2a9c204b08962f80
#include "ast.h"
#include "caffe2/opt/converter.h"
#include "nomnigraph/Transformations/SubgraphMatcher.h"
namespace nom {
namespace nql {
// Pattern graph used to match subgraphs of an NNGraph.
using TestMatchGraph = nom::matcher::MatchGraph<nom::repr::NNGraph>;
// Predicate type used by the NNGraph matcher.
using TestMatchPredicate = nom::matcher::MatchPredicate<nom::repr::NNGraph>;
// Textual criterion a node is compared against (see testMatchPredicate).
using Criteria = std::string;
// Each match is a struct holding a subgraph and a map from the names used in
// the query to NodeRefs in that subgraph. Note: the map is injective but not
// necessarily bijective -- if you use the same name twice in the query, only
// one of the nodes will be mapped.
//
// See `getMatches` to generate these structs.
struct MatchedSubgraph {
  // The canonical result of a match: a subgraph containing (at least) every
  // node referenced by matchMap. matchMap itself is only a convenience view.
  nom::repr::NNGraph::SubgraphType subgraph;

  // Looks up `key` in matchMap; meant as a safer alternative to indexing
  // matchMap directly. Defined out of line.
  nom::repr::NNGraph::NodeRef operator[](const std::string& key) const;

  // Query variable name -> matched node in the dataflow graph.
  std::map<std::string, nom::repr::NNGraph::NodeRef> matchMap;
};
/// \brief Main graph matcher interface.
///
/// Finds a subgraph of an NNGraph that matches a pattern specified in
/// textual (IR) form.
class GraphMatcher {
 public:
  /// \brief Initialize the subgraph pattern from the IR string \p str.
  void initFromString(const char* str) {
    genMatcherFromIRStr(str);
  }

  /// \brief Initialize the subgraph pattern from the IR stored in file
  /// \p fname.
  void initFromFile(const char* fname) {
    genMatcherFromIRFile(fname);
  }

  /// \brief Try to find the pattern in the given graph \p df.
  /// \return true if a match was found. The variable bindings of the first
  /// match found are then available via getMatchMap().
  bool findSubgraph(nom::repr::NNGraph& df) {
    return doesMatch(df);
  }

  /// \brief Replace the found subgraph with another one.
  /// \note Not implemented: always throws via CAFFE_THROW.
  void replaceSubgraphWith() {
    CAFFE_THROW("Subgraph replacement is not implemented yet.");
  }

  /// \brief Return the matcher graph.
  TestMatchGraph* getMatcherGraph() {
    return &matchGraph_;
  }

  /// \brief Return the root node of the matcher pattern (nullptr if no root
  /// has been set yet).
  // TODO: Do we need this, or can we get it from getMatcherGraph?
  TestMatchGraph::NodeRef getMatcher() {
    return matchGraphRootNode_;
  }

  /// \brief Return a mapping from IR variable name (std::string) to the Node
  /// it was bound to by the most recent findSubgraph() call. The map is empty
  /// if that call found no match.
  std::unordered_map<std::string, nom::repr::NNGraph::NodeRef> getMatchMap()
      const {
    return matchMap_;
  }

  /// \brief Return every match of the pattern in \p df.
  std::vector<MatchedSubgraph> getMatches(nom::repr::NNGraph& df) const;

 private:
  // Variable name -> matched dataflow node; rebuilt by every doesMatch() call.
  std::unordered_map<std::string, nom::repr::NNGraph::NodeRef> matchMap_;
  // Query names -> pattern nodes in matchGraph_ (populated elsewhere,
  // presumably by the genMatcherFrom* functions).
  std::unordered_map<std::string, TestMatchGraph::NodeRef> varMap_;
  std::unordered_map<std::string, TestMatchGraph::NodeRef> callMap_;
  TestMatchGraph matchGraph_;
  // Explicitly null so that matching before a pattern root exists is a clean
  // "no match" instead of a read of an indeterminate pointer.
  TestMatchGraph::NodeRef matchGraphRootNode_ = nullptr;
  bool syntaxIsValid_ = true;

  // Finds the first node of `df` at which the pattern rooted at
  // matchGraphRootNode_ matches, and records the name -> node bindings in
  // matchMap_. Returns false (leaving matchMap_ empty) if the pattern is
  // invalid, has no root, or is not found.
  bool doesMatch(nom::repr::NNGraph& df) {
    // Clear up front so no early return can leave bindings from an earlier
    // match visible through getMatchMap().
    matchMap_.clear();
    if (!syntaxIsValid_ || matchGraphRootNode_ == nullptr) {
      return false;
    }
    std::vector<nom::repr::NNGraph::NodeRef> nodes = df.getMutableNodes();
    for (auto node : nodes) {
      auto match =
          matchGraph_.isSubgraphMatch(node, matchGraphRootNode_, true, true);
      if (!match.isMatch()) {
        continue;
      }
      auto subgraphMatcherMap = match.getMatchNodeMap();
      // Translate each named pattern node into the dataflow node it matched.
      // Iterate by const reference to avoid copying every map entry.
      auto bindNames =
          [&](const std::unordered_map<std::string, TestMatchGraph::NodeRef>&
                  names) {
            for (const auto& entry : names) {
              auto iter = subgraphMatcherMap->find(entry.second);
              if (iter != subgraphMatcherMap->end()) {
                matchMap_[entry.first] = iter->second;
              }
            }
          };
      bindNames(varMap_);
      bindNames(callMap_);
      return true;
    }
    return false;
  }

  TestMatchGraph::NodeRef genMatcherFromIRFile(const char* fname);
  TestMatchGraph::NodeRef genMatcherFromIRStr(const char* str);
  TestMatchGraph::NodeRef genMatcherFromASTGraph(ASTGraph* ast);
  TestMatchGraph::NodeRef genMatcherFromASTStmt(ASTStmt* stmt);
  TestMatchGraph::NodeRef genMatcherFromASTExpr(ASTExpr* expr, bool insertTemp);
};
/// \brief Build a predicate that accepts a node whose data string equals
/// \p criteria. The wildcard criteria "*" accepts any node.
TestMatchPredicate testMatchPredicate(const Criteria& criteria);
// \brief Return a short string name for the given \param node.
// The function works with both tensors and operators.
std::string getNodeName(const nom::repr::NNGraph::NodeRef);
// \brief Return a string representing the given graph \param g.
// The returned string is a valid NQL query.
std::string convertToNQLString(nom::repr::NNGraph&);
// NOTE(review): undocumented here. Judging by the name it presumably releases
// string storage allocated for parser tokens -- confirm against the definition
// before relying on this.
void deallocTokenStrings();
} // namespace nql
} // namespace nom