Implementation for Pattern Net Transforms, which is a transform initialized by a Pattern NetDef and a Replace NetDef.

Summary: Split this into its own file for ease of reviewing. This is a simple interface for someone to create a Transform - by simply providing their own Pattern and Replace NetDefs.

Reviewed By: akyrola

Differential Revision: D5440426

fbshipit-source-id: dc643226f40ffe4ec5c86d56cfea374bd6a4e0e5
diff --git a/caffe2/contrib/transform/CMakeLists.txt b/caffe2/contrib/transform/CMakeLists.txt
index 91eb627..a9fc090 100644
--- a/caffe2/contrib/transform/CMakeLists.txt
+++ b/caffe2/contrib/transform/CMakeLists.txt
@@ -3,6 +3,7 @@
   message(STATUS "Include Graph Transformations")
   set(Caffe2_CONTRIB_TRANSFORMS_CPU_SRC
     "${CMAKE_CURRENT_SOURCE_DIR}/transform.cc"
+    "${CMAKE_CURRENT_SOURCE_DIR}/transforms/pattern_net_transform.cc"
     "${CMAKE_CURRENT_SOURCE_DIR}/graph.cc"
   )
 
diff --git a/caffe2/contrib/transform/graph.cc b/caffe2/contrib/transform/graph.cc
index 7ee7753..b7be937 100644
--- a/caffe2/contrib/transform/graph.cc
+++ b/caffe2/contrib/transform/graph.cc
@@ -27,8 +27,8 @@
       auto it = edge_parent.find(blob);
       if (it != edge_parent.end()) {
         int j = it->second;
-        node(i).parents[j] = blob;
-        node(j).children[i] = blob;
+        node(i).parents[j].push_back(blob);
+        node(j).children[i].push_back(blob);
       } else {
         external_input_.insert(blob);
       }
@@ -84,9 +84,11 @@
       const auto& list = from_children ? node(x).children : node(x).parents;
       for (const auto& edge : list) {
         int parent = edge.first;
-        const string& blob = edge.second;
+        const auto& blobs = edge.second;
         if (match_set.count(parent)) { // but has a parent that is in subgraph
-          edge_list.push_back({blob, x});
+          for (const string& blob : blobs) {
+            edge_list.push_back({blob, x});
+          }
         }
       }
     }
@@ -190,4 +192,23 @@
 
 } // namespace transform
 
+OperatorDef* AddOp(
+    NetDef* netdef_ptr,
+    string op_type,
+    std::vector<string> inputs,
+    std::vector<string> outputs) {
+  CHECK(netdef_ptr);
+  // Append a fresh operator to the net, then fill in its type and blobs.
+  OperatorDef* const new_op = netdef_ptr->add_op();
+  new_op->set_type(op_type);
+  for (const string& input_blob : inputs) {
+    new_op->add_input(input_blob);
+  }
+  for (const string& output_blob : outputs) {
+    new_op->add_output(output_blob);
+  }
+  // Handed back so that callers can set extra fields on the new operator.
+  return new_op;
+}
+
 } // namespace caffe2
diff --git a/caffe2/contrib/transform/graph.h b/caffe2/contrib/transform/graph.h
index 75c0ed7..e33bc5a 100644
--- a/caffe2/contrib/transform/graph.h
+++ b/caffe2/contrib/transform/graph.h
@@ -24,8 +24,8 @@
   Node(
       const OperatorDef& op,
       bool active,
-      std::map<int, string> parents,
-      std::map<int, string> children)
+      std::map<int, std::vector<string>> parents,
+      std::map<int, std::vector<string>> children)
       : op(op), active(active), parents(parents), children(children) {}
 
   // The OperatorDef which this node represents.
@@ -34,9 +34,11 @@
   // Keeps track of if an operator has been deleted through a transformation.
   bool active = true;
 
-  // Stores a pair (idx, blob), the index of the child, and the blob of edge.
-  std::map<int, string> parents;
-  std::map<int, string> children;
+  // Stores a pair (idx, blob_list),
+  //  idx = index of the neighboring node (the parent or the child)
+  //  blob_list = a list of strings, containing the blobs that connect the nodes
+  std::map<int, std::vector<string>> parents;
+  std::map<int, std::vector<string>> children;
 };
 
 /**
@@ -150,4 +152,12 @@
 
 } // namespace transform
 
+// Appends an operator with the given type, inputs and outputs to a netdef.
+// Returns a pointer to it, so callers can set extras (such as device_option).
+OperatorDef* AddOp(
+    NetDef* netdef_ptr,
+    string op_type,
+    std::vector<string> inputs,
+    std::vector<string> outputs);
+
 } // namespace caffe2
diff --git a/caffe2/contrib/transform/graph_test.cc b/caffe2/contrib/transform/graph_test.cc
index 67d937b..e25cb86 100644
--- a/caffe2/contrib/transform/graph_test.cc
+++ b/caffe2/contrib/transform/graph_test.cc
@@ -42,27 +42,6 @@
     .NumOutputs(0, INT_MAX)
     .AllowInplace({{0, 0}, {1, 1}});
 
-// Adds an operator def to a netdef.
-// Returns the ptr, if you want to add anything extra (such as device_option)
-OperatorDef* AddOp(
-    NetDef* netdef_ptr,
-    string op_type,
-    std::vector<string> inputs,
-    std::vector<string> outputs) {
-  CHECK(netdef_ptr);
-  auto& netdef = *netdef_ptr;
-  auto op_ptr = netdef.add_op();
-  auto& op = *op_ptr;
-  op.set_type(op_type);
-  for (const string& inp : inputs) {
-    op.add_input(inp);
-  }
-  for (const string& outp : outputs) {
-    op.add_output(outp);
-  }
-  return op_ptr;
-}
-
 // Checks if two netdefs are  in terms of type, input, and output.
 void compare_netdefs(const NetDef& net_a, const NetDef& net_b) {
   EXPECT_EQ(net_a.op_size(), net_b.op_size());
diff --git a/caffe2/contrib/transform/transform.cc b/caffe2/contrib/transform/transform.cc
index 329f28a..b6d2271 100644
--- a/caffe2/contrib/transform/transform.cc
+++ b/caffe2/contrib/transform/transform.cc
@@ -39,7 +39,7 @@
 
 void Transform::TryNeighbors(
     const Graph& graph,
-    const std::map<int, string>& neighbors,
+    const std::map<int, std::vector<string>>& neighbors,
     std::vector<int>* subgraph_ptr,
     std::vector<int>* best_subgraph_ptr) {
   auto& subgraph = *subgraph_ptr;
@@ -83,9 +83,17 @@
 void Transform::ReplacePattern(
     const std::vector<vector<int>>& matches,
     Graph* graph) {
-  // Simply try to apply the replace rule upon every match.
   for (const auto& match : matches) {
-    if (!ReplaceRule(match, graph)) {
+    // Make sure each matched node is still active (not overwritten)
+    bool is_match_active = true;
+    for (int idx : match) {
+      if (!graph->is_node_active(idx)) {
+        is_match_active = false;
+      }
+    }
+
+    // Simply try to apply the replace rule upon every match.
+    if (is_match_active && !ReplaceRule(match, graph)) {
       CAFFE_THROW("Replace failed!");
     }
   }
diff --git a/caffe2/contrib/transform/transform.h b/caffe2/contrib/transform/transform.h
index cbc3f1c..72e65fb 100644
--- a/caffe2/contrib/transform/transform.h
+++ b/caffe2/contrib/transform/transform.h
@@ -103,7 +103,7 @@
    */
   void TryNeighbors(
       const transform::Graph& graph,
-      const std::map<int, string>& neighbors,
+      const std::map<int, std::vector<string>>& neighbors,
       std::vector<int>* subgraph_ptr,
       std::vector<int>* best_subgraph_ptr);
 };
diff --git a/caffe2/contrib/transform/transform_test.cc b/caffe2/contrib/transform/transform_test.cc
index 37fe88e..726914f 100644
--- a/caffe2/contrib/transform/transform_test.cc
+++ b/caffe2/contrib/transform/transform_test.cc
@@ -61,15 +61,16 @@
     new_op.set_type("DummyOp3");
     int new_idx = g.size();
 
-    std::map<int, string> new_op_children;
-    std::map<int, string> new_op_parents;
+    std::map<int, std::vector<string>> new_op_children;
+    std::map<int, std::vector<string>> new_op_parents;
 
     // for each node parent in the head of the match, connect it to our new node
     for (const auto& edge : g.node(match[0]).parents) {
       int parent = edge.first;
-      string blob = edge.second;
-      g.node(parent).children[new_idx] = blob;
-      new_op_parents[parent] = blob;
+      for (const auto& blob : edge.second) {
+        g.node(parent).children[new_idx].push_back(blob);
+        new_op_parents[parent].push_back(blob);
+      }
     }
     for (const string& blob : g.node(match[0]).op.input()) {
       new_op.add_input(blob);
@@ -78,9 +79,10 @@
     // for each child in the tail of the match, connect it to our new node
     for (const auto& edge : g.node(match[1]).children) {
       int child = edge.first;
-      string blob = edge.second;
-      g.node(child).parents[new_idx] = blob;
-      new_op_children[child] = blob;
+      for (const auto& blob : edge.second) {
+        g.node(child).parents[new_idx].push_back(blob);
+        new_op_children[child].push_back(blob);
+      }
     }
     for (const string& blob : g.node(match[1]).op.output()) {
       new_op.add_output(blob);
@@ -98,27 +100,6 @@
 
 REGISTER_TRANSFORM(DummySwap, DummyTransform)
 
-// Adds an operator def to a netdef.
-// Returns the ptr, if you want to add anything extra (such as device_option)
-OperatorDef* AddOp(
-    NetDef* netdef_ptr,
-    string op_type,
-    std::vector<string> inputs,
-    std::vector<string> outputs) {
-  CHECK(netdef_ptr);
-  auto& netdef = *netdef_ptr;
-  auto op_ptr = netdef.add_op();
-  auto& op = *op_ptr;
-  op.set_type(op_type);
-  for (const string& inp : inputs) {
-    op.add_input(inp);
-  }
-  for (const string& outp : outputs) {
-    op.add_output(outp);
-  }
-  return op_ptr;
-}
-
 TEST(TransformTest, TestPatternMatch) {
   Workspace ws;
   ws.CreateBlob("in");
diff --git a/caffe2/contrib/transform/transforms/conv_to_nnpack_transform_test.cc b/caffe2/contrib/transform/transforms/conv_to_nnpack_transform_test.cc
index 5e5f691..f17485e 100644
--- a/caffe2/contrib/transform/transforms/conv_to_nnpack_transform_test.cc
+++ b/caffe2/contrib/transform/transforms/conv_to_nnpack_transform_test.cc
@@ -11,27 +11,6 @@
 
 using transform::Graph;
 
-// Adds an operator def to a netdef.
-// Returns the ptr, if you want to add anything extra (such as device_option)
-OperatorDef* AddOp(
-    NetDef* netdef_ptr,
-    string op_type,
-    std::vector<string> inputs,
-    std::vector<string> outputs) {
-  CHECK(netdef_ptr);
-  auto& netdef = *netdef_ptr;
-  auto op_ptr = netdef.add_op();
-  auto& op = *op_ptr;
-  op.set_type(op_type);
-  for (const string& inp : inputs) {
-    op.add_input(inp);
-  }
-  for (const string& outp : outputs) {
-    op.add_output(outp);
-  }
-  return op_ptr;
-}
-
 TEST(ConvToNNPackTest, TestSimple) {
   NetDef netdef;
   OperatorDef* op;
diff --git a/caffe2/contrib/transform/transforms/pattern_net_transform.cc b/caffe2/contrib/transform/transforms/pattern_net_transform.cc
new file mode 100644
index 0000000..cf63af2
--- /dev/null
+++ b/caffe2/contrib/transform/transforms/pattern_net_transform.cc
@@ -0,0 +1,253 @@
+#include "caffe2/contrib/transform/transforms/pattern_net_transform.h"
+
+#include <algorithm>
+#include <queue>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "caffe2/core/common.h"
+#include "caffe2/core/logging.h"
+#include "caffe2/core/net.h"
+#include "caffe2/proto/caffe2.pb.h"
+
+namespace caffe2 {
+
+// Breadth-first walk over the pattern graph, starting at operator 0 and
+// following edges in both directions (children first, then parents).
+// Returns the visiting order: a permutation of the operators in which every
+// operator after the first is a neighbor of one that precedes it.
+std::vector<int> PatternNetTransform::GetPatternTraversalOrder(
+    const transform::Graph& graph) {
+  std::vector<bool> seen(graph.size(), false);
+  std::vector<int> order;
+  std::queue<int> frontier;
+
+  // Records a node the first time it is reached and schedules its expansion.
+  const auto enqueue_if_new = [&](int node_idx) {
+    if (seen[node_idx]) {
+      return;
+    }
+    seen[node_idx] = true;
+    order.push_back(node_idx);
+    frontier.push(node_idx);
+  };
+
+  if (graph.size() > 0) {
+    enqueue_if_new(0);
+  }
+  while (!frontier.empty()) {
+    const int current = frontier.front();
+    frontier.pop();
+    for (const auto& edge : graph.node(current).children) {
+      enqueue_if_new(edge.first);
+    }
+    for (const auto& edge : graph.node(current).parents) {
+      enqueue_if_new(edge.first);
+    }
+  }
+  CAFFE_ENFORCE(
+      order.size() == graph.size(), "Pattern graph must be connected.");
+  return order;
+}
+
+// Returns true if operator y (from the graph) can match pattern operator x:
+// same type, same number of inputs and outputs, same engine (if x sets one)
+// and, when arg_match is set, the same number of arguments.
+bool compare_ops(const OperatorDef& x, const OperatorDef& y, bool arg_match) {
+  // The operator type and the input/output counts always have to agree.
+  const bool same_type = x.type() == y.type();
+  const bool same_input_count = x.input_size() == y.input_size();
+  const bool same_output_count = x.output_size() == y.output_size();
+  if (!same_type || !same_input_count || !same_output_count) {
+    return false;
+  }
+
+  // TODO(benz): make a comparison for device_option.
+  // The engine only has to agree when the pattern operator specifies one.
+  const bool engine_mismatch = x.has_engine() && x.engine() != y.engine();
+  if (engine_mismatch) {
+    return false;
+  }
+
+  // TODO(benz): Create arg equality operator.
+  // Until then, argument matching only compares the number of arguments.
+  if (arg_match && x.arg_size() != y.arg_size()) {
+    return false;
+  }
+  return true;
+}
+
+// Decides whether g.node(g_idx) may be appended to the partial match.
+// Invariant: g.node(subgraph[i]) matches p_.node(ordered_ops_[i]), so the
+// candidate is compared against p_.node(ordered_ops_[subgraph.size()]).
+bool PatternNetTransform::PatternRule(
+    const transform::Graph& g,
+    const std::vector<int>& subgraph,
+    int g_idx) {
+  const auto matched_count = subgraph.size();
+  if (matched_count >= ordered_ops_.size()) {
+    return false; // every pattern operator is already matched
+  }
+
+  const int p_idx = ordered_ops_[matched_count];
+  const auto& pattern_node = p_.node(p_idx);
+  const auto& graph_node = g.node(g_idx);
+
+  if (!compare_ops(pattern_node.op, graph_node.op, argument_match_)) {
+    return false;
+  }
+
+  // Example: let ordered_ops_ be [0, 2, 1], with pattern edges 0 -> 2 and
+  // 1 -> 2. Growing the match [4] into [4, 5] requires 4 to be a direct
+  // parent of 5 in g, because pattern node 0 is a parent of pattern node 2.
+  // Growing [4, 5] into [4, 5, 7] likewise requires an edge 7 -> 5 in g.
+  // inverse_ops_ gives the position of a pattern node within ordered_ops_;
+  // pattern neighbors at positions >= matched_count are not matched yet.
+  // Every already-matched pattern parent must be a parent of g_idx in g.
+  for (const auto& edge : pattern_node.parents) {
+    const int pos = inverse_ops_[edge.first];
+    if (pos < matched_count && !graph_node.parents.count(subgraph[pos])) {
+      return false;
+    }
+  }
+
+  // Every already-matched pattern child must be a child of g_idx in g.
+  for (const auto& edge : pattern_node.children) {
+    const int pos = inverse_ops_[edge.first];
+    if (pos < matched_count && !graph_node.children.count(subgraph[pos])) {
+      return false;
+    }
+  }
+  return true;
+}
+
+bool PatternNetTransform::ValidatorRule(
+    const transform::Graph& g,
+    const std::vector<int>& subgraph) {
+  // PatternRule is strict, so a match is complete once it spans all of p_.
+  return p_.size() == subgraph.size();
+}
+
+bool PatternNetTransform::ReplaceRule(
+    const std::vector<int>& match,
+    transform::Graph* g_ptr) {
+  CHECK(g_ptr);
+  auto& g = *g_ptr;
+
+  ssa_id_++; // new suffix for this replacement's internal blob names
+
+  // Map of PatternNet blob name to Matched blob name.
+  // Only the pattern's external inputs and outputs are recorded here.
+  std::unordered_map<string, string> external_renaming;
+
+  // Figure out blob renamings (match[i] corresponds to ordered_ops_[i])
+  for (int i = 0; i < match.size(); i++) {
+    int g_idx = match[i];
+    int p_idx = ordered_ops_[i];
+    for (int j = 0; j < p_.node(p_idx).op.input().size(); j++) {
+      string p_blob = p_.node(p_idx).op.input(j);
+      string g_blob = g.node(g_idx).op.input(j);
+      if (p_.external_input().count(p_blob)) {
+        external_renaming[p_blob] = g_blob;
+      }
+    }
+    for (int j = 0; j < p_.node(p_idx).op.output().size(); j++) {
+      string p_blob = p_.node(p_idx).op.output(j);
+      string g_blob = g.node(g_idx).op.output(j);
+      if (p_.external_output().count(p_blob)) {
+        external_renaming[p_blob] = g_blob;
+      }
+    }
+  }
+
+  auto input_list = g.GetSubgraphInput(match);
+  auto output_list = g.GetSubgraphOutput(match);
+
+  g.DeactivateSubgraph(match);
+
+  int offset = g.size();
+
+  g.resize_nodes(offset + r_.size());
+
+  // Append all the new operators.
+  for (int i = 0; i < r_.size(); i++) {
+    int new_node_idx = offset + i;
+
+    OperatorDef new_op = r_.node(i).op;
+
+    new_op.clear_input();
+    new_op.clear_output();
+    // Stitch Input from external graph into replaced subgraph
+    for (const auto& blob : r_.node(i).op.input()) {
+      if (external_renaming.count(blob)) {
+        string new_blob = external_renaming[blob];
+        new_op.add_input(new_blob);
+
+        // binary search for new_blob in input_list (presumably sorted; verify)
+        auto it = std::lower_bound(
+            input_list.begin(), input_list.end(), std::make_pair(new_blob, -1));
+
+        // if the input came from the graph (instead of G's external input)
+        for (; it < input_list.end() && it->first == new_blob; it++) {
+          int parent = it->second;
+          g.node(parent).children[new_node_idx].push_back(new_blob);
+          g.node(new_node_idx).parents[parent].push_back(new_blob);
+        }
+      } else {
+        new_op.add_input(TransformBlobWrapper(blob));
+      }
+    }
+    // Stitch Output from replaced subgraph to external graph.
+    for (const auto& blob : r_.node(i).op.output()) {
+      if (external_renaming.count(blob)) {
+        string new_blob = external_renaming[blob];
+        new_op.add_output(new_blob);
+
+        // binary search for new_blob in output_list (presumably sorted; verify)
+        auto it = std::lower_bound(
+            output_list.begin(),
+            output_list.end(),
+            std::make_pair(new_blob, -1));
+
+        // if the output goes to the graph (instead of G's external output)
+        for (; it < output_list.end() && it->first == new_blob; it++) {
+          int child = it->second;
+          g.node(child).parents[new_node_idx].push_back(new_blob);
+          g.node(new_node_idx).children[child].push_back(new_blob);
+        }
+      } else {
+        new_op.add_output(TransformBlobWrapper(blob));
+      }
+    }
+
+    // Connect all internal edges within replace graph
+    for (const auto& edge : r_.node(i).parents) {
+      int parent = edge.first;
+      int new_node_parent = offset + parent;
+      const auto& blobs = edge.second;
+      for (const string& blob : blobs) {
+        g.node(new_node_idx)
+            .parents[new_node_parent]
+            .push_back(TransformBlobWrapper(blob));
+      }
+    }
+
+    for (const auto& edge : r_.node(i).children) {
+      int child = edge.first;
+      int new_node_child = offset + child;
+      const auto& blobs = edge.second;
+      for (const string& blob : blobs) {
+        g.node(offset + i)
+            .children[new_node_child]
+            .push_back(TransformBlobWrapper(blob));
+      }
+    }
+
+    g.node(new_node_idx).op = new_op;
+    g.node(new_node_idx).active = true;
+  }
+  return true;
+}
+
+} // namespace caffe2
diff --git a/caffe2/contrib/transform/transforms/pattern_net_transform.h b/caffe2/contrib/transform/transforms/pattern_net_transform.h
new file mode 100644
index 0000000..e89a67e
--- /dev/null
+++ b/caffe2/contrib/transform/transforms/pattern_net_transform.h
@@ -0,0 +1,126 @@
+#pragma once
+
+#include "caffe2/contrib/transform/transform.h"
+#include "caffe2/core/common.h"
+#include "caffe2/proto/caffe2.pb.h"
+#include "caffe2/utils/proto_utils.h"
+
+namespace caffe2 {
+
+/**
+ * PatternNetTransform lets you create transforms using a simple interface.
+ *
+ * Simply provide a Pattern NetDef and a Replace NetDef,
+ * and this Transform will find subgraphs which fit the pattern net,
+ * and replace them with the replace net.
+ */
+class PatternNetTransform : public Transform {
+ public:
+  PatternNetTransform(const NetDef& pattern_net, const NetDef& replace_net)
+      : p_(pattern_net), r_(replace_net) {
+    // external input and output must match!
+    CAFFE_ENFORCE(
+        p_.external_input() == r_.external_input(),
+        "External inputs do not match!");
+    CAFFE_ENFORCE(
+        p_.external_output() == r_.external_output(),
+        "External outputs do not match!");
+    ordered_ops_ = GetPatternTraversalOrder(p_);
+    inverse_ops_.resize(ordered_ops_.size());
+    for (size_t i = 0; i < ordered_ops_.size(); ++i) {
+      inverse_ops_[ordered_ops_[i]] = static_cast<int>(i);
+    }
+  }
+
+ protected:
+  /**
+   * We want the final result of subgraph to match the PatternNet in the
+   * order of ordered_ops, operator by operator.
+   *
+   * [[[ ie. g.node(subgraph[i]) should match p.node(ordered_ops[i]) ]]]
+   *
+   * PatternRule for PatternNetTransform does the following:
+   *
+   * When trying to insert node idx as subgraph[k] (k = subgraph.size()),
+   * we need to see if the edges between g.node(idx) and the
+   * subgraph match the edges between p.node(ordered_ops[k])
+   * and p.node(ordered_ops[0]) ... p.node(ordered_ops[k-1]).
+   */
+  bool PatternRule(
+      const transform::Graph& g,
+      const std::vector<int>& subgraph,
+      int idx) override;
+  /**
+   * ValidatorRule only checks that subgraph.size() == p.size(). That's it!
+   */
+  bool ValidatorRule(
+      const transform::Graph& g,
+      const std::vector<int>& subgraph) override;
+  /**
+   * ReplaceRule for PatternNet Transform does the following:
+   *
+   * 1) Figure out edge renamings for edges going into/out of the subgraph.
+   * That is, for each blob in the pattern graph, what is it called in the
+   * matched subgraph?
+   *
+   * 2) Remove the matched subgraph.
+   *
+   * 3) Append the replace graph's operators to the graph's operators, and use
+   *    the renamings to rename the blob names.
+   *
+   * 4) Create all the children/parent relationships within the replaced graph,
+   *    and stitch together the inputs and outputs into the rest of the graph,
+   *    matching the removed subgraph.
+   */
+  bool ReplaceRule(const std::vector<int>& subgraph, transform::Graph* g_ptr)
+      override;
+
+ private:
+  /**
+   * This returns a permutation of the Pattern Net's operators.
+   * The permutation satisfies this property:
+   *    - For any index i > 0, order(i) is a neighbor of some node from
+   *      {order(0), ..., order(i-1)}.
+   *
+   * Why is this important? Consider the following case:
+   * PatternNet: 0 ---> 2 <--- 1
+   *
+   * When we have matched onto [0], and trying to add [1] to our subgraph,
+   * we cannot, since PatternMatch only considers neighbors of the current
+   * subgraph as a candidate next node.
+   *
+   * Therefore, we must present the subgraph in an order such that each node is
+   * a neighbor of its prefix subgraph. One ordering for the above example is
+   * [0, 2, 1].
+   */
+  std::vector<int> GetPatternTraversalOrder(const transform::Graph& g);
+
+  // Graph of Pattern NetDef
+  transform::Graph p_;
+
+  // The Traversal Order of the Pattern Net's Operators
+  // This is a permutation of the numbers from {0, ..., p.size()-1}
+  std::vector<int> ordered_ops_;
+
+  // The Inverse of the Traversal Order of the Pattern Net's Operators
+  // That is, inverse_ops[ordered_ops[i]] == i is always true.
+  std::vector<int> inverse_ops_;
+
+  // Graph of Replace NetDef
+  transform::Graph r_;
+
+  // This flag determines if the transform will match operator arguments.
+  // TODO(benz): Write a good argument comparator
+  bool argument_match_ = false;
+
+  // Maps a blob internal to the replace net to a name unique to the current
+  // replacement: "transform/<blob_name>_<ssa_id_>".
+  string TransformBlobWrapper(const string& blob_name) const {
+    return "transform/" + blob_name + "_" + caffe2::to_string(ssa_id_);
+  }
+
+  // Bumped at the start of every ReplaceRule call.
+  int ssa_id_ = 0;
+};
+
+} // namespace caffe2
diff --git a/caffe2/contrib/transform/transforms/pattern_net_transform_test.cc b/caffe2/contrib/transform/transforms/pattern_net_transform_test.cc
new file mode 100644
index 0000000..32d36e9
--- /dev/null
+++ b/caffe2/contrib/transform/transforms/pattern_net_transform_test.cc
@@ -0,0 +1,348 @@
+#include <google/protobuf/text_format.h>
+#include <gtest/gtest.h>
+#include "caffe2/contrib/transform/transforms/pattern_net_transform.h"
+#include "caffe2/core/net.h"
+#include "caffe2/core/operator.h"
+
+namespace caffe2 {
+
+namespace {
+
+using transform::Graph;
+
+static std::atomic<int> counter;
+
+class DummyCounterOp final : public OperatorBase {
+ public:
+  using OperatorBase::OperatorBase;
+  bool Run(int /* unused */) override {
+    counter.fetch_add(1);
+    return true;
+  }
+};
+
+REGISTER_CPU_OPERATOR(DummyCounterOp1, DummyCounterOp);
+REGISTER_CUDA_OPERATOR(DummyCounterOp1, DummyCounterOp);
+
+OPERATOR_SCHEMA(DummyCounterOp1)
+    .NumInputs(0, INT_MAX)
+    .NumOutputs(0, INT_MAX)
+    .AllowInplace({{0, 0}, {1, 1}});
+
+REGISTER_CPU_OPERATOR(DummyCounterOp2, DummyCounterOp);
+REGISTER_CUDA_OPERATOR(DummyCounterOp2, DummyCounterOp);
+
+OPERATOR_SCHEMA(DummyCounterOp2)
+    .NumInputs(0, INT_MAX)
+    .NumOutputs(0, INT_MAX)
+    .AllowInplace({{0, 0}, {1, 1}});
+
+REGISTER_CPU_OPERATOR(DummyCounterOp3, DummyCounterOp);
+REGISTER_CUDA_OPERATOR(DummyCounterOp3, DummyCounterOp);
+
+OPERATOR_SCHEMA(DummyCounterOp3)
+    .NumInputs(0, INT_MAX)
+    .NumOutputs(0, INT_MAX)
+    .AllowInplace({{0, 0}, {1, 1}});
+
+/**
+ * P = ---> (Op1) ---> (Op2) --->
+ *
+ * R = ---> (Op3) ---> (Op3) --->
+ */
+TEST(PatternNetTransformTest, TestGenerateTransform) {
+  Workspace ws;
+  ws.CreateBlob("in");
+
+  NetDef netdef;
+  AddOp(&netdef, "DummyCounterOp1", {"in"}, {"mid1"});
+  AddOp(&netdef, "DummyCounterOp2", {"mid1"}, {"mid2"});
+  AddOp(&netdef, "DummyCounterOp1", {"mid2"}, {"mid3"});
+  AddOp(&netdef, "DummyCounterOp2", {"mid3"}, {"out"});
+
+  NetDef pdef;
+  AddOp(&pdef, "DummyCounterOp1", {"in"}, {"mid"});
+  AddOp(&pdef, "DummyCounterOp2", {"mid"}, {"out"});
+
+  NetDef rdef;
+  AddOp(&rdef, "DummyCounterOp3", {"in"}, {"new_mid"});
+  AddOp(&rdef, "DummyCounterOp3", {"new_mid"}, {"out"});
+
+  PatternNetTransform t(pdef, rdef);
+
+  // test pattern match
+  Graph g(netdef);
+
+  auto matches = t.PatternMatch(g);
+  EXPECT_EQ(matches.size(), 2);
+
+  t.ReplacePattern(matches, &g);
+
+  EXPECT_EQ(g.size(), 8);
+  for (int i = 0; i < 4; i++) {
+    EXPECT_FALSE(g.is_node_active(i));
+  }
+  for (int i = 4; i < 8; i++) {
+    EXPECT_TRUE(g.is_node_active(i));
+  }
+
+  EXPECT_TRUE(g.node(4).children.count(5));
+  EXPECT_TRUE(g.node(5).children.count(6));
+  EXPECT_TRUE(g.node(6).children.count(7));
+
+  for (int i = 4; i < 8; i++) {
+    EXPECT_EQ(g.node(i).op.input().size(), 1);
+    EXPECT_EQ(g.node(i).op.output().size(), 1);
+  }
+
+  NetDef replaced_netdef = g.GetNetDef();
+
+  EXPECT_EQ(replaced_netdef.op().size(), 4);
+  EXPECT_EQ(replaced_netdef.op(0).type(), "DummyCounterOp3");
+  EXPECT_EQ(replaced_netdef.op(1).type(), "DummyCounterOp3");
+  EXPECT_EQ(replaced_netdef.op(2).type(), "DummyCounterOp3");
+  EXPECT_EQ(replaced_netdef.op(3).type(), "DummyCounterOp3");
+}
+
+/**
+ * P = ---> (Op1) ---> (Op2) --->
+ *
+ * R = ---> (Op3) ---> (Op3) --->
+ */
+TEST(PatternNetTransformTest, TestRepeatedTransform) {
+  Workspace ws;
+  ws.CreateBlob("in");
+
+  NetDef netdef;
+  AddOp(&netdef, "DummyCounterOp1", {"in"}, {"out"});
+  AddOp(&netdef, "DummyCounterOp2", {"out"}, {"out"});
+  for (int i = 0; i < 99; i++) {
+    AddOp(&netdef, "DummyCounterOp1", {"out"}, {"out"});
+    AddOp(&netdef, "DummyCounterOp2", {"out"}, {"out"});
+  }
+
+  NetDef pdef;
+  AddOp(&pdef, "DummyCounterOp1", {"in"}, {"mid"});
+  AddOp(&pdef, "DummyCounterOp2", {"mid"}, {"out"});
+
+  NetDef rdef;
+  AddOp(&rdef, "DummyCounterOp3", {"in"}, {"new_mid"});
+  AddOp(&rdef, "DummyCounterOp3", {"new_mid"}, {"out"});
+
+  PatternNetTransform t(pdef, rdef);
+
+  // test pattern match
+  Graph g(netdef);
+
+  auto matches = t.PatternMatch(g);
+  EXPECT_EQ(matches.size(), 100);
+
+  t.ReplacePattern(matches, &g);
+  NetDef replaced_netdef = g.GetNetDef();
+
+  EXPECT_EQ(replaced_netdef.op_size(), 200);
+  for (int i = 0; i < 200; i++) {
+    EXPECT_EQ(replaced_netdef.op(i).type(), "DummyCounterOp3");
+  }
+
+  unique_ptr<NetBase> net = CreateNet(replaced_netdef, &ws);
+  counter.exchange(0);
+  net.get()->Run();
+  EXPECT_EQ(200, counter.load());
+}
+
+/**
+ * P = ---> (Op1) ---> (Op3) ---> (Op2) --->
+ *            |------> (Op3) -------|
+ *
+ * R = ---> (Op1) --------------> (Op3) --->
+ *          |_(Op3)-->(Op3)-->(Op2)_|
+ *
+ */
+TEST(PatternNetTransformTest, TestHardTransform) {
+  Workspace ws;
+  ws.CreateBlob("in");
+
+  NetDef netdef;
+  // Segment 1 (differs from P because of type)
+  AddOp(&netdef, "DummyCounterOp1", {"in"}, {"mid1a_1", "mid1b_1"});
+  AddOp(&netdef, "DummyCounterOp2", {"mid1a_1"}, {"mid2a_1"});
+  AddOp(&netdef, "DummyCounterOp3", {"mid1b_1"}, {"mid2b_1"});
+  AddOp(&netdef, "DummyCounterOp3", {"mid2a_1", "mid2b_1"}, {"out_1"});
+
+  // Segment 2 (differs from P because of structure)
+  AddOp(
+      &netdef, "DummyCounterOp1", {"out_1"}, {"mid1a_2", "mid1b_2", "mid1c_2"});
+  AddOp(&netdef, "DummyCounterOp3", {"mid1a_2"}, {"mid2a_2"});
+  AddOp(&netdef, "DummyCounterOp3", {"mid1b_2"}, {"mid2b_2"});
+  AddOp(&netdef, "DummyCounterOp3", {"mid1c_2"}, {"mid2c_2"});
+  AddOp(
+      &netdef, "DummyCounterOp2", {"mid2a_2", "mid2b_2", "mid2c_2"}, {"out_2"});
+
+  // Segment 3
+  AddOp(&netdef, "DummyCounterOp1", {"out_2"}, {"mid1a_3", "mid1b_3"});
+  AddOp(&netdef, "DummyCounterOp3", {"mid1a_3"}, {"mid2a_3"});
+  AddOp(&netdef, "DummyCounterOp3", {"mid1b_3"}, {"mid2b_3"});
+  AddOp(&netdef, "DummyCounterOp2", {"mid2a_3", "mid2b_3"}, {"out"});
+
+  NetDef pdef;
+  // Should only match Segment 3
+  AddOp(&pdef, "DummyCounterOp1", {"sub_in"}, {"mid1a", "mid1b"});
+  AddOp(&pdef, "DummyCounterOp3", {"mid1a"}, {"mid2a"});
+  AddOp(&pdef, "DummyCounterOp3", {"mid1b"}, {"mid2b"});
+  AddOp(&pdef, "DummyCounterOp2", {"mid2a", "mid2b"}, {"sub_out"});
+
+  NetDef rdef;
+  AddOp(&rdef, "DummyCounterOp1", {"sub_in"}, {"mid1a", "mid1b"});
+  AddOp(&rdef, "DummyCounterOp3", {"mid1b"}, {"mid2b"});
+  AddOp(&rdef, "DummyCounterOp3", {"mid2b"}, {"mid3b"});
+  AddOp(&rdef, "DummyCounterOp2", {"mid3b"}, {"mid4b"});
+  AddOp(&rdef, "DummyCounterOp3", {"mid1a", "mid4b"}, {"sub_out"});
+
+  PatternNetTransform t(pdef, rdef);
+  Graph g(netdef);
+  EXPECT_EQ(g.size(), 13);
+
+  auto matches = t.PatternMatch(g);
+  EXPECT_EQ(matches.size(), 1);
+
+  t.ReplacePattern(matches, &g);
+  EXPECT_EQ(g.size(), 18);
+
+  NetDef replaced_netdef = g.GetNetDef();
+  EXPECT_EQ(replaced_netdef.op_size(), 14);
+  unique_ptr<NetBase> net = CreateNet(replaced_netdef, &ws);
+  counter.exchange(0);
+  net.get()->Run();
+  EXPECT_EQ(14, counter.load());
+}
+
+/**
+ *           |--(Op2)--|
+ * P = --->(Op1)----->(Op3)--->
+ *           |--(Op2)--|
+ *
+ * R = ---> (Op2) --->
+ *
+ *                |--(Op2)--|
+ *           -->(Op1)----->(Op3)---
+ *           |    |--(Op2)--|     |
+ * G = ---> (Op1)                (Op3) --->
+ *           |    |--(Op2)--|     |
+ *           -->(Op1)----->(Op3)--
+ *                |--(Op2)--|
+ *
+ * In this test, the two "parallel" modules have intersecting execution orders.
+ * We wish to test that the pattern match can still detect the two modules,
+ * separately.
+ *
+ * Furthermore, we will apply the transform to G, TWICE.
+ * It should reduce G to a single operator.
+ */
+TEST(PatternNetTransformTest, TestNonStrictTopographicTransform) {
+  Workspace ws;
+  ws.CreateBlob("in");
+
+  NetDef netdef;
+  // Head
+  AddOp(&netdef, "DummyCounterOp1", {"in"}, {"in_1", "in_2"});
+
+  // 2 intertwined segments, each matching P. No strict ordering.
+  AddOp(&netdef, "DummyCounterOp1", {"in_1"}, {"m1_1", "m2_1"});
+  AddOp(&netdef, "DummyCounterOp1", {"in_2"}, {"m1_2", "m2_2"});
+  AddOp(&netdef, "DummyCounterOp2", {"m1_1"}, {"out1_1"});
+  AddOp(&netdef, "DummyCounterOp2", {"m1_2"}, {"out1_2"});
+  AddOp(&netdef, "DummyCounterOp2", {"m2_1"}, {"out2_1"});
+  AddOp(&netdef, "DummyCounterOp2", {"m2_2"}, {"out2_2"});
+  AddOp(&netdef, "DummyCounterOp3", {"out1_1", "out2_1"}, {"out1"});
+  AddOp(&netdef, "DummyCounterOp3", {"out1_2", "out2_2"}, {"out2"});
+
+  // Tail
+  AddOp(&netdef, "DummyCounterOp3", {"out1", "out2"}, {"out"});
+
+  NetDef pdef;
+  AddOp(&pdef, "DummyCounterOp1", {"myin"}, {"mid1a", "mid1b"});
+  AddOp(&pdef, "DummyCounterOp2", {"mid1a"}, {"mid2a"});
+  AddOp(&pdef, "DummyCounterOp2", {"mid1b"}, {"mid2b"});
+  AddOp(&pdef, "DummyCounterOp3", {"mid2a", "mid2b"}, {"myout"});
+
+  NetDef rdef;
+  AddOp(&rdef, "DummyCounterOp2", {"myin"}, {"myout"});
+
+  PatternNetTransform t(pdef, rdef);
+
+  NetDef replaced_netdef = t.ApplyTo(netdef);
+  EXPECT_EQ(replaced_netdef.op_size(), 4);
+  unique_ptr<NetBase> net = CreateNet(replaced_netdef, &ws);
+  counter.exchange(0);
+  net.get()->Run();
+  EXPECT_EQ(4, counter.load());
+
+  // apply the transform again
+  // the entire net should get transformed this time
+  NetDef double_transformed_net = t.ApplyTo(replaced_netdef);
+  EXPECT_EQ(double_transformed_net.op_size(), 1);
+}
+
+/**
+ *      --->(Op1)----->(Op2)--->
+ *            |          ^
+ * P =        |----------|
+ *            |          v
+ *      --->(Op1)----->(Op2)--->
+ *
+ * R =  ---> (Op3) --->
+ *
+ * G = P -> P
+ *
+ * In this test, we fuse a subgraph with two inputs and two outputs, into one
+ * operator.
+ *
+ * This will ensure that we can allow a single edge to represent
+ * multiple blob names (the input and output of R are both 2 blobs).
+ *
+ * This will also ensure that patternmatch can traverse "backwards", from a node
+ * to its parent.
+ *
+ * Furthermore, this tests for repeat matches, since matching on either of the
+ * first two Op1 nodes will produce a match, but they are identical.
+ * So, the pattern should match 4 times, but only be replaced twice.
+ */
+TEST(PatternNetTransformTest, TestMultiInputOutputTransform) {
+  Workspace ws;
+  ws.CreateBlob("in1");
+  ws.CreateBlob("in2");
+
+  NetDef netdef;
+  AddOp(&netdef, "DummyCounterOp1", {"in1"}, {"in1"}); // has 2 children
+  AddOp(&netdef, "DummyCounterOp1", {"in2"}, {"in2"}); // has 2 children
+  AddOp(&netdef, "DummyCounterOp2", {"in1", "in2"}, {"mid1"});
+  AddOp(&netdef, "DummyCounterOp2", {"in1", "in2"}, {"mid2"});
+  AddOp(&netdef, "DummyCounterOp1", {"mid1"}, {"mid1"}); // has 2 children
+  AddOp(&netdef, "DummyCounterOp1", {"mid2"}, {"mid2"}); // has 2 children
+  AddOp(&netdef, "DummyCounterOp2", {"mid1", "mid2"}, {"out1"});
+  AddOp(&netdef, "DummyCounterOp2", {"mid1", "mid2"}, {"out2"});
+
+  NetDef pdef;
+  AddOp(&pdef, "DummyCounterOp1", {"subin1"}, {"subin1"}); // has 2 children
+  AddOp(&pdef, "DummyCounterOp1", {"subin2"}, {"subin2"}); // has 2 children
+  AddOp(&pdef, "DummyCounterOp2", {"subin1", "subin2"}, {"subout1"});
+  AddOp(&pdef, "DummyCounterOp2", {"subin1", "subin2"}, {"subout2"});
+
+  NetDef rdef;
+  AddOp(&rdef, "DummyCounterOp3", {"subin1", "subin2"}, {"subout1", "subout2"});
+
+  PatternNetTransform t(pdef, rdef);
+  Graph g(netdef);
+
+  NetDef replaced_netdef = t.ApplyTo(netdef);
+  EXPECT_EQ(replaced_netdef.op_size(), 2);
+  unique_ptr<NetBase> net = CreateNet(replaced_netdef, &ws);
+  counter.exchange(0);
+  net.get()->Run();
+  EXPECT_EQ(2, counter.load());
+}
+
+} // namespace
+
+} // namespace caffe2