SubgraphMatcher: add attributes support. (#20602)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/20602
ghimport-source-id: fa3225bb5d70729d6a1bcf88295d031707d986a1
Differential Revision: D15377635
Pulled By: ZolotukhinM
fbshipit-source-id: ebd385e7b9436429d0ad76ed3d932925a29f6456
diff --git a/test/cpp/jit/test_subgraph_matcher.h b/test/cpp/jit/test_subgraph_matcher.h
index ee157de..f33e19e 100644
--- a/test/cpp/jit/test_subgraph_matcher.h
+++ b/test/cpp/jit/test_subgraph_matcher.h
@@ -361,6 +361,80 @@
AT_ASSERT(findPatternMatches(pattern1, graph).size() == 0);
}
+void testAttributes() {
+ Graph graph;
+ script::parseIR(
+ R"IR(
+graph(%0):
+ %a = a::a[isattr=[1,2]](%0)
+ %b = a::b[intattr=10, floatattr=3.14](%0)
+ %c = a::c[myattr="qqq"](%a, %b)
+ return (%c))IR",
+ &graph);
+
+ {
+ Graph pattern;
+ script::parseIR(
+ R"IR(
+graph(%a, %b):
+ %c = a::c[myattr="qqq"](%a, %b)
+ return (%c))IR",
+ &pattern);
+ AT_ASSERT(!findPatternMatches(pattern, graph).empty());
+ }
+ {
+ Graph pattern;
+ script::parseIR(
+ R"IR(
+graph(%a, %b):
+ %c = a::c[myattr="zzz"](%a, %b)
+ return (%c))IR",
+ &pattern);
+ AT_ASSERT(findPatternMatches(pattern, graph).empty());
+ }
+ {
+ Graph pattern;
+ script::parseIR(
+ R"IR(
+graph(%0):
+ %b = a::b[extraattr=10](%0)
+ return (%b))IR",
+ &pattern);
+ AT_ASSERT(findPatternMatches(pattern, graph).empty());
+ }
+ {
+ Graph pattern;
+ script::parseIR(
+ R"IR(
+graph(%0):
+ %b = a::b[intattr=10, floatattr=3.14](%0)
+ return (%b))IR",
+ &pattern);
+ AT_ASSERT(!findPatternMatches(pattern, graph).empty());
+ }
+ {
+ Graph pattern;
+ script::parseIR(
+ R"IR(
+graph(%0):
+ %b = a::b[intattr=10, floatattr=3.14, strattr="rrr"](%0)
+ return (%b))IR",
+ &pattern);
+ AT_ASSERT(findPatternMatches(pattern, graph).empty());
+ }
+ {
+ Graph pattern;
+ script::parseIR(
+ R"IR(
+graph(%0):
+ %a = a::a[isattr=[1,2]](%0)
+ return (%a))IR",
+ &pattern);
+ // Lists are not supported yet, thus we shouldn't match for now.
+ AT_ASSERT(findPatternMatches(pattern, graph).empty());
+ }
+}
+
void testBadPattern() {
Graph graph, pattern1, pattern2;
script::parseIR(
@@ -405,6 +479,7 @@
testOverlappingMatches();
testMatchInBasicBlocks1();
testMatchInBasicBlocks2();
+ testAttributes();
testBadPattern();
}
diff --git a/torch/csrc/jit/subgraph_matcher.cpp b/torch/csrc/jit/subgraph_matcher.cpp
index 9b5e15a..58bd6b0 100644
--- a/torch/csrc/jit/subgraph_matcher.cpp
+++ b/torch/csrc/jit/subgraph_matcher.cpp
@@ -123,7 +123,8 @@
if (n1->kind() != n2->kind() ||
n1->outputs().size() != n2->outputs().size() ||
- n1->inputs().size() != n2->inputs().size()) {
+ n1->inputs().size() != n2->inputs().size() ||
+ n1->numAttributes() != n2->numAttributes()) {
return false;
}
@@ -140,6 +141,31 @@
return false;
}
}
+ for (const Symbol& attr_name : n1->attributeNames()) {
+ if (n1->kindOf(attr_name) != n2->kindOf(attr_name)) {
+ return false;
+ }
+ switch (n1->kindOf(attr_name)) {
+ case AttributeKind::s:
+ if (n1->s(attr_name) != n2->s(attr_name)) {
+ return false;
+ }
+ break;
+ case AttributeKind::f:
+ if (n1->f(attr_name) != n2->f(attr_name)) {
+ return false;
+ }
+ break;
+ case AttributeKind::i:
+ if (n1->i(attr_name) != n2->i(attr_name)) {
+ return false;
+ }
+ break;
+ default:
+ // Other attributes types not supported yet
+ return false;
+ }
+ }
return true;
}
@@ -167,9 +193,7 @@
} // unnamed namespace
// Main entry point for the subgraph matching.
-std::vector<Match> findPatternMatches(
- const Graph& pattern,
- Graph& graph) {
+std::vector<Match> findPatternMatches(const Graph& pattern, Graph& graph) {
AT_ASSERT(patternGraphIsValid(pattern));
SubgraphMatcher m(pattern);