[jit][subgraph_matcher] Enable regex matching for string attributes of node (#39454)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/39454
Test Plan: Imported from OSS
Differential Revision: D21876224
fbshipit-source-id: c0fdff3a4532d2a73b222353e2cad6cf52444697
diff --git a/test/cpp/jit/test_subgraph_matcher.cpp b/test/cpp/jit/test_subgraph_matcher.cpp
index ca03af0..2e398db 100644
--- a/test/cpp/jit/test_subgraph_matcher.cpp
+++ b/test/cpp/jit/test_subgraph_matcher.cpp
@@ -467,6 +467,16 @@
// Lists are not supported yet, thus we shouldn't match for now.
AT_ASSERT(findPatternMatches(pattern, graph).empty());
}
+ {
+ Graph pattern;
+ parseIR(
+ R"IR(
+graph(%a, %b):
+ %c = a::c[myattr="q.*"](%a, %b)
+ return (%c))IR",
+ &pattern);
+ AT_ASSERT(!findPatternMatches(pattern, graph).empty());
+ }
}
void testBadPattern() {
diff --git a/torch/csrc/jit/ir/subgraph_matcher.cpp b/torch/csrc/jit/ir/subgraph_matcher.cpp
index 3a74f51..8319543 100644
--- a/torch/csrc/jit/ir/subgraph_matcher.cpp
+++ b/torch/csrc/jit/ir/subgraph_matcher.cpp
@@ -1,5 +1,6 @@
#include <torch/csrc/jit/ir/subgraph_matcher.h>
#include <torch/csrc/jit/jit_log.h>
+#include <regex>
#include <stack>
namespace torch {
@@ -133,7 +134,7 @@
}
switch (n1->kindOf(attr_name)) {
case AttributeKind::s:
- if (n1->s(attr_name) != n2->s(attr_name)) {
+ if (!std::regex_match(n2->s(attr_name), std::regex(n1->s(attr_name)))) {
GRAPH_DEBUG(
"Nodes did not match because attribute '",
attr_name.toQualString(),