SubgraphMatcher: add attributes support. (#20602)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/20602
ghimport-source-id: fa3225bb5d70729d6a1bcf88295d031707d986a1

Differential Revision: D15377635

Pulled By: ZolotukhinM

fbshipit-source-id: ebd385e7b9436429d0ad76ed3d932925a29f6456
diff --git a/test/cpp/jit/test_subgraph_matcher.h b/test/cpp/jit/test_subgraph_matcher.h
index ee157de..f33e19e 100644
--- a/test/cpp/jit/test_subgraph_matcher.h
+++ b/test/cpp/jit/test_subgraph_matcher.h
@@ -361,6 +361,80 @@
   AT_ASSERT(findPatternMatches(pattern1, graph).size() == 0);
 }
 
+void testAttributes() {
+  Graph graph;
+  script::parseIR(
+      R"IR(
+graph(%0):
+  %a = a::a[isattr=[1,2]](%0)
+  %b = a::b[intattr=10, floatattr=3.14](%0)
+  %c = a::c[myattr="qqq"](%a, %b)
+  return (%c))IR",
+      &graph);
+
+  {
+    Graph pattern;
+    script::parseIR(
+        R"IR(
+graph(%a, %b):
+  %c = a::c[myattr="qqq"](%a, %b)
+  return (%c))IR",
+        &pattern);
+    AT_ASSERT(!findPatternMatches(pattern, graph).empty());
+  }
+  {
+    Graph pattern;
+    script::parseIR(
+        R"IR(
+graph(%a, %b):
+  %c = a::c[myattr="zzz"](%a, %b)
+  return (%c))IR",
+        &pattern);
+    AT_ASSERT(findPatternMatches(pattern, graph).empty());
+  }
+  {
+    Graph pattern;
+    script::parseIR(
+        R"IR(
+graph(%0):
+  %b = a::b[extraattr=10](%0)
+  return (%b))IR",
+        &pattern);
+    AT_ASSERT(findPatternMatches(pattern, graph).empty());
+  }
+  {
+    Graph pattern;
+    script::parseIR(
+        R"IR(
+graph(%0):
+  %b = a::b[intattr=10, floatattr=3.14](%0)
+  return (%b))IR",
+        &pattern);
+    AT_ASSERT(!findPatternMatches(pattern, graph).empty());
+  }
+  {
+    Graph pattern;
+    script::parseIR(
+        R"IR(
+graph(%0):
+  %b = a::b[intattr=10, floatattr=3.14, strattr="rrr"](%0)
+  return (%b))IR",
+        &pattern);
+    AT_ASSERT(findPatternMatches(pattern, graph).empty());
+  }
+  {
+    Graph pattern;
+    script::parseIR(
+        R"IR(
+graph(%0):
+  %a = a::a[isattr=[1,2]](%0)
+  return (%a))IR",
+        &pattern);
+    // Lists are not supported yet, thus we shouldn't match for now.
+    AT_ASSERT(findPatternMatches(pattern, graph).empty());
+  }
+}
+
 void testBadPattern() {
   Graph graph, pattern1, pattern2;
   script::parseIR(
@@ -405,6 +479,7 @@
   testOverlappingMatches();
   testMatchInBasicBlocks1();
   testMatchInBasicBlocks2();
+  testAttributes();
   testBadPattern();
 }
 
diff --git a/torch/csrc/jit/subgraph_matcher.cpp b/torch/csrc/jit/subgraph_matcher.cpp
index 9b5e15a..58bd6b0 100644
--- a/torch/csrc/jit/subgraph_matcher.cpp
+++ b/torch/csrc/jit/subgraph_matcher.cpp
@@ -123,7 +123,8 @@
 
   if (n1->kind() != n2->kind() ||
       n1->outputs().size() != n2->outputs().size() ||
-      n1->inputs().size() != n2->inputs().size()) {
+      n1->inputs().size() != n2->inputs().size() ||
+      n1->numAttributes() != n2->numAttributes()) {
     return false;
   }
 
@@ -140,6 +141,31 @@
       return false;
     }
   }
+  for (const Symbol& attr_name : n1->attributeNames()) {
+    if (n1->kindOf(attr_name) != n2->kindOf(attr_name)) {
+      return false;
+    }
+    switch (n1->kindOf(attr_name)) {
+      case AttributeKind::s:
+        if (n1->s(attr_name) != n2->s(attr_name)) {
+          return false;
+        }
+        break;
+      case AttributeKind::f:
+        if (n1->f(attr_name) != n2->f(attr_name)) {
+          return false;
+        }
+        break;
+      case AttributeKind::i:
+        if (n1->i(attr_name) != n2->i(attr_name)) {
+          return false;
+        }
+        break;
+      default:
+        // Other attributes types not supported yet
+        return false;
+    }
+  }
 
   return true;
 }
@@ -167,9 +193,7 @@
 } // unnamed namespace
 
 // Main entry point for the subgraph matching.
-std::vector<Match> findPatternMatches(
-    const Graph& pattern,
-    Graph& graph) {
+std::vector<Match> findPatternMatches(const Graph& pattern, Graph& graph) {
   AT_ASSERT(patternGraphIsValid(pattern));
 
   SubgraphMatcher m(pattern);