|  | #include "test/cpp/jit/test_base.h" | 
|  | #include "test/cpp/jit/test_utils.h" | 
|  | #include "torch/csrc/jit/subgraph_matcher.h" | 
|  |  | 
|  | namespace torch { | 
|  | namespace jit { | 
|  |  | 
|  | void testTrivial1() { | 
|  | Graph graph, pattern; | 
|  | script::parseIR( | 
|  | R"IR( | 
|  | graph(%0): | 
|  | %a = a::aaa(%0) | 
|  | return (%a))IR", | 
|  | &graph); | 
|  | script::parseIR( | 
|  | R"IR( | 
|  | graph(%0): | 
|  | %x = a::aaa(%0) | 
|  | return (%x))IR", | 
|  | &pattern); | 
|  | AT_ASSERT(!findPatternMatches(pattern, graph).empty()); | 
|  | } | 
|  |  | 
|  | void testTrivial2() { | 
|  | Graph graph; | 
|  | auto* g_in = graph.addInput(); | 
|  | auto* g_tanh = graph.insertNode(graph.create(aten::tanh, /*num_outputs =*/ 1)); | 
|  | g_tanh->addInput(g_in); | 
|  | graph.registerOutput(g_tanh->output()); | 
|  |  | 
|  | Graph pattern; | 
|  | auto* p_in = pattern.addInput(); | 
|  | auto* p_tanh = pattern.insertNode(pattern.create(aten::tanh, /*num_outputs =*/ 1)); | 
|  | p_tanh->addInput(p_in); | 
|  | pattern.registerOutput(p_tanh->output()); | 
|  |  | 
|  | auto matches = findPatternMatches(pattern, graph); | 
|  | AT_ASSERT(matches.size() == 1); | 
|  | for (const Match& m : matches) { | 
|  | AT_ASSERT(m.values_map.at(p_in) == g_in); | 
|  | AT_ASSERT(m.values_map.at(p_tanh->output()) == g_tanh->output()); | 
|  | AT_ASSERT(m.nodes_map.at(p_tanh) == g_tanh); | 
|  | } | 
|  | } | 
|  |  | 
|  | void testTrivial3() { | 
|  | Graph graph, pattern; | 
|  | script::parseIR( | 
|  | R"IR( | 
|  | graph(%0): | 
|  | %a = a::a(%0) | 
|  | %b = a::b(%0) | 
|  | %c = a::c(%a, %b) | 
|  | return (%c))IR", | 
|  | &graph); | 
|  | script::parseIR( | 
|  | R"IR( | 
|  | graph(%a, %b): | 
|  | %c = a::c(%a, %b) | 
|  | return (%c))IR", | 
|  | &pattern); | 
|  | AT_ASSERT(!findPatternMatches(pattern, graph).empty()); | 
|  | } | 
|  |  | 
|  | void testTrivial4() { | 
|  | Graph graph; | 
|  | auto* g_in0 = graph.addInput(); | 
|  | auto* g_in1 = graph.addInput(); | 
|  | auto* g_mul = graph.insertNode(graph.create(aten::mul, /*num_outputs =*/ 1)); | 
|  | g_mul->addInput(g_in0); | 
|  | g_mul->addInput(g_in1); | 
|  | graph.registerOutput(g_mul->output()); | 
|  |  | 
|  | Graph pattern; | 
|  | auto* p_in0 = pattern.addInput(); | 
|  | auto* p_in1 = pattern.addInput(); | 
|  | auto* p_mul = pattern.insertNode(pattern.create(aten::mul, /*num_outputs =*/ 1)); | 
|  | p_mul->addInput(p_in0); | 
|  | p_mul->addInput(p_in1); | 
|  | pattern.registerOutput(p_mul->output()); | 
|  |  | 
|  | auto matches = findPatternMatches(pattern, graph); | 
|  | AT_ASSERT(matches.size() == 1); | 
|  | for (const Match& m : matches) { | 
|  | AT_ASSERT(m.values_map.at(p_in0) == g_in0); | 
|  | AT_ASSERT(m.values_map.at(p_in1) == g_in1); | 
|  | AT_ASSERT(m.values_map.at(p_mul->output()) == g_mul->output()); | 
|  | AT_ASSERT(m.nodes_map.at(p_mul) == g_mul); | 
|  | } | 
|  | } | 
|  |  | 
|  | void testLinear1() { | 
|  | Graph graph, pattern; | 
|  | script::parseIR( | 
|  | R"IR( | 
|  | graph(%0): | 
|  | %a = a::aaa(%0) | 
|  | %b = b::bbb(%a) | 
|  | %c = c::ccc(%b) | 
|  | %d = d::ddd(%c) | 
|  | %a = a::aaa(%0) | 
|  | return (%d))IR", | 
|  | &graph); | 
|  | script::parseIR( | 
|  | R"IR( | 
|  | graph(%0): | 
|  | %x = b::bbb(%0) | 
|  | %y = c::ccc(%x) | 
|  | return (%y))IR", | 
|  | &pattern); | 
|  | AT_ASSERT(!findPatternMatches(pattern, graph).empty()); | 
|  | } | 
|  |  | 
|  | void testLinear2() { | 
|  | Graph graph; | 
|  | auto* g_in = graph.addInput(); | 
|  |  | 
|  | auto* g_tanh = graph.insertNode(graph.create(aten::tanh, /*num_outputs =*/ 1)); | 
|  | g_tanh->addInput(g_in); | 
|  |  | 
|  | auto* g_tanh2 = graph.insertNode(graph.create(aten::tanh, /*num_outputs =*/ 1)); | 
|  | g_tanh2->addInput(g_tanh->output()); | 
|  |  | 
|  | graph.registerOutput(g_tanh2->output()); | 
|  |  | 
|  | Graph pattern; | 
|  | auto* p_in = pattern.addInput(); | 
|  |  | 
|  | auto* p_tanh = pattern.insertNode(pattern.create(aten::tanh, /*num_outputs =*/ 1)); | 
|  | p_tanh->addInput(p_in); | 
|  |  | 
|  | auto* p_tanh2 = pattern.insertNode(pattern.create(aten::tanh, /*num_outputs =*/ 1)); | 
|  | p_tanh2->addInput(p_tanh->output()); | 
|  |  | 
|  | pattern.registerOutput(p_tanh2->output()); | 
|  |  | 
|  | auto matches = findPatternMatches(pattern, graph); | 
|  | AT_ASSERT(matches.size() == 1); | 
|  | for (const Match& m : matches) { | 
|  | AT_ASSERT(m.values_map.at(p_in) == g_in); | 
|  | AT_ASSERT(m.values_map.at(p_tanh->output()) == g_tanh->output()); | 
|  | AT_ASSERT(m.values_map.at(p_tanh2->output()) == g_tanh2->output()); | 
|  | AT_ASSERT(m.nodes_map.at(p_tanh) == g_tanh); | 
|  | AT_ASSERT(m.nodes_map.at(p_tanh2) == g_tanh2); | 
|  | } | 
|  | } | 
|  |  | 
|  | /** | 
|  | * Test diamond pattern: | 
|  | * | 
|  | *     ooo | 
|  | *      | | 
|  | *     aaa | 
|  | *    /   \ | 
|  | *  bbb   ccc | 
|  | *     \ / | 
|  | *     ddd | 
|  | *      | | 
|  | *     eee | 
|  | */ | 
|  | void testDiamond1() { | 
|  | Graph graph, pattern1, pattern2; | 
|  | script::parseIR( | 
|  | R"IR( | 
|  | graph(%0): | 
|  | %o = o::ooo(%0) | 
|  | %a = a::aaa(%o) | 
|  | %b = b::bbb(%a) | 
|  | %c = c::ccc(%a) | 
|  | %d = d::ddd(%b, %c) | 
|  | %e = e::eee(%d) | 
|  | return (%e))IR", | 
|  | &graph); | 
|  |  | 
|  | script::parseIR( | 
|  | R"IR( | 
|  | graph(%0): | 
|  | %a = a::aaa(%0) | 
|  | %b = b::bbb(%a) | 
|  | %c = c::ccc(%a) | 
|  | %d = d::ddd(%b, %c) | 
|  | return (%d))IR", | 
|  | &pattern1); | 
|  | AT_ASSERT(!findPatternMatches(pattern1, graph).empty()); | 
|  |  | 
|  | // Check that order of nodes inside the diamond does not affect the result | 
|  | script::parseIR( | 
|  | R"IR( | 
|  | graph(%0): | 
|  | %a = a::aaa(%0) | 
|  | %c = c::ccc(%a) | 
|  | %b = b::bbb(%a) | 
|  | %d = d::ddd(%b, %c) | 
|  | return (%d))IR", | 
|  | &pattern2); | 
|  | AT_ASSERT(!findPatternMatches(pattern2, graph).empty()); | 
|  | } | 
|  |  | 
|  | /** | 
|  | * Test diamond pattern: | 
|  | * | 
|  | *     i0 | 
|  | *      | | 
|  | *    chunk | 
|  | *    /   \ | 
|  | * os[0] os[1] | 
|  | *     \ / | 
|  | *      * | 
|  | *      | | 
|  | *      o1 | 
|  | */ | 
|  | void testDiamond2() { | 
|  | Graph graph; | 
|  | auto* g_in = graph.addInput(); | 
|  |  | 
|  | auto* g_chunk = graph.insertNode(graph.create(prim::ConstantChunk, /*num_outputs =*/ 2)); | 
|  | g_chunk->i_(attr::chunks, 2)->i_(attr::dim, 0); | 
|  | g_chunk->addInput(g_in); | 
|  |  | 
|  | auto* g_mul = graph.insertNode(graph.create(aten::mul, /*num_outputs =*/ 1)); | 
|  | g_mul->addInput(g_chunk->outputs()[0]); | 
|  | g_mul->addInput(g_chunk->outputs()[1]); | 
|  | graph.registerOutput(g_mul->output()); | 
|  |  | 
|  | Graph pattern; | 
|  | auto* p_in = pattern.addInput(); | 
|  | auto* p_chunk = pattern.insertNode(pattern.create(prim::ConstantChunk, /*num_outputs =*/ 2)); | 
|  | p_chunk->i_(attr::chunks, 2)->i_(attr::dim, 0); | 
|  | p_chunk->addInput(p_in); | 
|  |  | 
|  | auto* p_mul = pattern.insertNode(pattern.create(aten::mul, /*num_outputs =*/ 1)); | 
|  | p_mul->addInput(p_chunk->outputs()[0]); | 
|  | p_mul->addInput(p_chunk->outputs()[1]); | 
|  | pattern.registerOutput(p_mul->output()); | 
|  |  | 
|  | auto matches = findPatternMatches(pattern, graph); | 
|  | AT_ASSERT(matches.size() == 1); | 
|  | for (const Match& m : matches) { | 
|  | AT_ASSERT(m.values_map.at(p_in) == g_in); | 
|  | AT_ASSERT(m.values_map.at(p_chunk->outputs()[0]) == g_chunk->outputs()[0]); | 
|  | AT_ASSERT(m.values_map.at(p_chunk->outputs()[1]) == g_chunk->outputs()[1]); | 
|  | AT_ASSERT(m.values_map.at(p_mul->output()) == g_mul->output()); | 
|  | AT_ASSERT(m.nodes_map.at(p_mul) == g_mul); | 
|  | } | 
|  | } | 
|  |  | 
|  | void testXPattern() { | 
|  | Graph graph, pattern; | 
|  | script::parseIR( | 
|  | R"IR( | 
|  | graph(%0, %1): | 
|  | %b = b::bbb(%0) | 
|  | %c = c::ccc(%1) | 
|  | %x = x::xxx(%b, %c) | 
|  | %e = e::eee(%x) | 
|  | %f = f::fff(%x) | 
|  | %g = g::ggg(%e, %f) | 
|  | return (%g))IR", | 
|  | &graph); | 
|  | script::parseIR( | 
|  | R"IR( | 
|  | graph(%0, %1): | 
|  | %b = b::bbb(%0) | 
|  | %c = c::ccc(%1) | 
|  | %x = x::xxx(%b, %c) | 
|  | %e = e::eee(%x) | 
|  | %f = f::fff(%x) | 
|  | %g = g::ggg(%e, %f) | 
|  | return (%g))IR", | 
|  | &pattern); | 
|  | AT_ASSERT(!findPatternMatches(pattern, graph).empty()); | 
|  | } | 
|  |  | 
|  | void testMultipleMatches() { | 
|  | Graph graph, pattern; | 
|  | script::parseIR( | 
|  | R"IR( | 
|  | graph(%t0): | 
|  | %t1 = a::aaa(%t0) | 
|  | %t2 = a::aaa(%t1) | 
|  | %t3 = a::aaa(%t2) | 
|  | %t4 = a::aaa(%t3) | 
|  | return (%t4))IR", | 
|  | &graph); | 
|  | script::parseIR( | 
|  | R"IR( | 
|  | graph(%t0): | 
|  | %t1 = a::aaa(%t0) | 
|  | return (%t1))IR", | 
|  | &pattern); | 
|  | auto matches = findPatternMatches(pattern, graph); | 
|  | AT_ASSERT(matches.size() == 4); | 
|  | } | 
|  |  | 
|  | void testOverlappingMatches() { | 
|  | Graph graph, pattern; | 
|  | script::parseIR( | 
|  | R"IR( | 
|  | graph(%t0): | 
|  | %t1 = a::aaa(%t0) | 
|  | %t2 = a::aaa(%t1) | 
|  | %t3 = a::aaa(%t2) | 
|  | %t4 = a::aaa(%t3) | 
|  | return (%t4))IR", | 
|  | &graph); | 
|  | script::parseIR( | 
|  | R"IR( | 
|  | graph(%t0): | 
|  | %t1 = a::aaa(%t0) | 
|  | %t2 = a::aaa(%t1) | 
|  | return (%t2))IR", | 
|  | &pattern); | 
|  | auto matches = findPatternMatches(pattern, graph); | 
|  | AT_ASSERT(matches.size() == 3); | 
|  | } | 
|  |  | 
|  | void testMatchInBasicBlocks1() { | 
|  | Graph graph; | 
|  | script::parseIR( | 
|  | R"IR( | 
|  | graph(%a, %b, %c): | 
|  | %d = aten::mul(%a, %b) | 
|  | %x = prim::If(%c) | 
|  | block0(): | 
|  | %x1 = aten::mul(%a, %d) | 
|  | -> (%x1) | 
|  | block1(): | 
|  | %x2 = aten::mul(%b, %d) | 
|  | -> (%x2) | 
|  | return (%x))IR", | 
|  | &graph); | 
|  |  | 
|  | // Ensure the matches don't cross basic block boundaries | 
|  | Graph pattern0; | 
|  | script::parseIR( | 
|  | R"IR( | 
|  | graph(%x, %y): | 
|  | %z = aten::mul(%x, %y) | 
|  | return (%z))IR", | 
|  | &pattern0); | 
|  | AT_ASSERT(findPatternMatches(pattern0, graph).size() == 3); | 
|  |  | 
|  | Graph pattern1; | 
|  | script::parseIR( | 
|  | R"IR( | 
|  | graph(%x, %y): | 
|  | %z1 = aten::mul(%x, %y) | 
|  | %z2 = aten::mul(%y, %z1) | 
|  | return (%z2))IR", | 
|  | &pattern1); | 
|  | AT_ASSERT(findPatternMatches(pattern1, graph).size() == 0); | 
|  | } | 
|  |  | 
|  | void testMatchInBasicBlocks2() { | 
|  | Graph graph; | 
|  | script::parseIR( | 
|  | R"IR( | 
|  | graph(%a, %b): | 
|  | %x = my::mul(%a, %b) | 
|  | %y = my::node_with_subblock() | 
|  | block0(): | 
|  | %z = my::mul(%b, %x) | 
|  | -> (%z) | 
|  | return (%y))IR", | 
|  | &graph); | 
|  |  | 
|  | // Check that we can match both mul ops | 
|  | Graph pattern0; | 
|  | script::parseIR( | 
|  | R"IR( | 
|  | graph(%x, %y): | 
|  | %z = my::mul(%x, %y) | 
|  | return (%z))IR", | 
|  | &pattern0); | 
|  | AT_ASSERT(findPatternMatches(pattern0, graph).size() == 2); | 
|  |  | 
|  | // Ensure the matches don't cross basic block boundaries | 
|  | Graph pattern1; | 
|  | script::parseIR( | 
|  | R"IR( | 
|  | graph(%x, %y): | 
|  | %u = my::mul(%x, %y) | 
|  | %v = my::mul(%y, %u) | 
|  | return (%v))IR", | 
|  | &pattern1); | 
|  | AT_ASSERT(findPatternMatches(pattern1, graph).size() == 0); | 
|  | } | 
|  |  | 
|  | void testMatchesAttributes() { | 
|  | Graph graph; | 
|  | script::parseIR( | 
|  | R"IR( | 
|  | graph(%0): | 
|  | %a = a::a[isattr=[1,2]](%0) | 
|  | %b = a::b[intattr=10, floatattr=3.14](%0) | 
|  | %c = a::c[myattr="qqq"](%a, %b) | 
|  | return (%c))IR", | 
|  | &graph); | 
|  |  | 
|  | { | 
|  | Graph pattern; | 
|  | script::parseIR( | 
|  | R"IR( | 
|  | graph(%a, %b): | 
|  | %c = a::c[myattr="qqq"](%a, %b) | 
|  | return (%c))IR", | 
|  | &pattern); | 
|  | AT_ASSERT(!findPatternMatches(pattern, graph).empty()); | 
|  | } | 
|  | { | 
|  | Graph pattern; | 
|  | script::parseIR( | 
|  | R"IR( | 
|  | graph(%a, %b): | 
|  | %c = a::c[myattr="zzz"](%a, %b) | 
|  | return (%c))IR", | 
|  | &pattern); | 
|  | AT_ASSERT(findPatternMatches(pattern, graph).empty()); | 
|  | } | 
|  | { | 
|  | Graph pattern; | 
|  | script::parseIR( | 
|  | R"IR( | 
|  | graph(%0): | 
|  | %b = a::b[extraattr=10](%0) | 
|  | return (%b))IR", | 
|  | &pattern); | 
|  | AT_ASSERT(findPatternMatches(pattern, graph).empty()); | 
|  | } | 
|  | { | 
|  | Graph pattern; | 
|  | script::parseIR( | 
|  | R"IR( | 
|  | graph(%0): | 
|  | %b = a::b[intattr=10, floatattr=3.14](%0) | 
|  | return (%b))IR", | 
|  | &pattern); | 
|  | AT_ASSERT(!findPatternMatches(pattern, graph).empty()); | 
|  | } | 
|  | { | 
|  | Graph pattern; | 
|  | script::parseIR( | 
|  | R"IR( | 
|  | graph(%0): | 
|  | %b = a::b[intattr=10, floatattr=3.14, strattr="rrr"](%0) | 
|  | return (%b))IR", | 
|  | &pattern); | 
|  | AT_ASSERT(findPatternMatches(pattern, graph).empty()); | 
|  | } | 
|  | { | 
|  | Graph pattern; | 
|  | script::parseIR( | 
|  | R"IR( | 
|  | graph(%0): | 
|  | %a = a::a[isattr=[1,2]](%0) | 
|  | return (%a))IR", | 
|  | &pattern); | 
|  | // Lists are not supported yet, thus we shouldn't match for now. | 
|  | AT_ASSERT(findPatternMatches(pattern, graph).empty()); | 
|  | } | 
|  | } | 
|  |  | 
|  | void testBadPattern() { | 
|  | Graph graph, pattern1, pattern2; | 
|  | script::parseIR( | 
|  | R"IR( | 
|  | graph(%0): | 
|  | %a = a::aaa(%0) | 
|  | return (%a))IR", | 
|  | &graph); | 
|  |  | 
|  | script::parseIR( | 
|  | R"IR( | 
|  | graph(%x): | 
|  | %y = my::node_with_subblock() | 
|  | block0(): | 
|  | %z = my::op(%x) | 
|  | -> (%z) | 
|  | return (%y))IR", | 
|  | &pattern1); | 
|  | ASSERT_ANY_THROW(findPatternMatches(pattern1, graph)); | 
|  |  | 
|  | script::parseIR( | 
|  | R"IR( | 
|  | graph(%x): | 
|  | %y = my::op1(%x) | 
|  | %z = my::op2(%x) | 
|  | return (%y, %z))IR", | 
|  | &pattern2); | 
|  | ASSERT_ANY_THROW(findPatternMatches(pattern2, graph)); | 
|  | } | 
|  |  | 
|  | void testSubgraphMatching() { | 
|  | testTrivial1(); | 
|  | testTrivial2(); | 
|  | testTrivial3(); | 
|  | testTrivial4(); | 
|  | testLinear1(); | 
|  | testLinear2(); | 
|  | testDiamond1(); | 
|  | testDiamond2(); | 
|  | testXPattern(); | 
|  | testMultipleMatches(); | 
|  | testOverlappingMatches(); | 
|  | testMatchInBasicBlocks1(); | 
|  | testMatchInBasicBlocks2(); | 
|  | testMatchesAttributes(); | 
|  | testBadPattern(); | 
|  | } | 
|  |  | 
|  | } // namespace jit | 
|  | } // namespace torch |