Implementation for Pattern Net Transforms, which is a transform initialized by a Pattern NetDef and a Replace NetDef.
Summary: Split this into its own file for ease of reviewing. This is a simple interface for someone to create a Transform - by simply providing their own Pattern and Replace NetDefs.
Reviewed By: akyrola
Differential Revision: D5440426
fbshipit-source-id: dc643226f40ffe4ec5c86d56cfea374bd6a4e0e5
diff --git a/caffe2/contrib/transform/CMakeLists.txt b/caffe2/contrib/transform/CMakeLists.txt
index 91eb627..a9fc090 100644
--- a/caffe2/contrib/transform/CMakeLists.txt
+++ b/caffe2/contrib/transform/CMakeLists.txt
@@ -3,6 +3,7 @@
message(STATUS "Include Graph Transformations")
set(Caffe2_CONTRIB_TRANSFORMS_CPU_SRC
"${CMAKE_CURRENT_SOURCE_DIR}/transform.cc"
+ "${CMAKE_CURRENT_SOURCE_DIR}/pattern_net_transform.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/graph.cc"
)
diff --git a/caffe2/contrib/transform/graph.cc b/caffe2/contrib/transform/graph.cc
index 7ee7753..b7be937 100644
--- a/caffe2/contrib/transform/graph.cc
+++ b/caffe2/contrib/transform/graph.cc
@@ -27,8 +27,8 @@
auto it = edge_parent.find(blob);
if (it != edge_parent.end()) {
int j = it->second;
- node(i).parents[j] = blob;
- node(j).children[i] = blob;
+ node(i).parents[j].push_back(blob);
+ node(j).children[i].push_back(blob);
} else {
external_input_.insert(blob);
}
@@ -84,9 +84,11 @@
const auto& list = from_children ? node(x).children : node(x).parents;
for (const auto& edge : list) {
int parent = edge.first;
- const string& blob = edge.second;
+ const auto& blobs = edge.second;
if (match_set.count(parent)) { // but has a parent that is in subgraph
- edge_list.push_back({blob, x});
+ for (const string& blob : blobs) {
+ edge_list.push_back({blob, x});
+ }
}
}
}
@@ -190,4 +192,23 @@
} // namespace transform
+OperatorDef* AddOp(
+ NetDef* netdef_ptr,
+ string op_type,
+ std::vector<string> inputs,
+ std::vector<string> outputs) {
+ CHECK(netdef_ptr);
+ auto& netdef = *netdef_ptr;
+ auto op_ptr = netdef.add_op();
+ auto& op = *op_ptr;
+ op.set_type(op_type);
+ for (const string& inp : inputs) {
+ op.add_input(inp);
+ }
+ for (const string& outp : outputs) {
+ op.add_output(outp);
+ }
+ return op_ptr;
+}
+
} // namespace caffe2
diff --git a/caffe2/contrib/transform/graph.h b/caffe2/contrib/transform/graph.h
index 75c0ed7..e33bc5a 100644
--- a/caffe2/contrib/transform/graph.h
+++ b/caffe2/contrib/transform/graph.h
@@ -24,8 +24,8 @@
Node(
const OperatorDef& op,
bool active,
- std::map<int, string> parents,
- std::map<int, string> children)
+ std::map<int, std::vector<string>> parents,
+ std::map<int, std::vector<string>> children)
: op(op), active(active), parents(parents), children(children) {}
// The OperatorDef which this node represents.
@@ -34,9 +34,11 @@
// Keeps track of if an operator has been deleted through a transformation.
bool active = true;
- // Stores a pair (idx, blob), the index of the child, and the blob of edge.
- std::map<int, string> parents;
- std::map<int, string> children;
+ // Stores a pair (idx, blob_list),
+ // idx = index of the child
+ // blob_list = a list of strings, containing the blobs that connect the nodes
+ std::map<int, std::vector<string>> parents;
+ std::map<int, std::vector<string>> children;
};
/**
@@ -150,4 +152,12 @@
} // namespace transform
+// Adds an operator def to a netdef.
+// Returns the ptr, if you want to add anything extra (such as device_option)
+OperatorDef* AddOp(
+ NetDef* netdef_ptr,
+ string op_type,
+ std::vector<string> inputs,
+ std::vector<string> outputs);
+
} // namespace caffe2
diff --git a/caffe2/contrib/transform/graph_test.cc b/caffe2/contrib/transform/graph_test.cc
index 67d937b..e25cb86 100644
--- a/caffe2/contrib/transform/graph_test.cc
+++ b/caffe2/contrib/transform/graph_test.cc
@@ -42,27 +42,6 @@
.NumOutputs(0, INT_MAX)
.AllowInplace({{0, 0}, {1, 1}});
-// Adds an operator def to a netdef.
-// Returns the ptr, if you want to add anything extra (such as device_option)
-OperatorDef* AddOp(
- NetDef* netdef_ptr,
- string op_type,
- std::vector<string> inputs,
- std::vector<string> outputs) {
- CHECK(netdef_ptr);
- auto& netdef = *netdef_ptr;
- auto op_ptr = netdef.add_op();
- auto& op = *op_ptr;
- op.set_type(op_type);
- for (const string& inp : inputs) {
- op.add_input(inp);
- }
- for (const string& outp : outputs) {
- op.add_output(outp);
- }
- return op_ptr;
-}
-
// Checks if two netdefs are in terms of type, input, and output.
void compare_netdefs(const NetDef& net_a, const NetDef& net_b) {
EXPECT_EQ(net_a.op_size(), net_b.op_size());
diff --git a/caffe2/contrib/transform/transform.cc b/caffe2/contrib/transform/transform.cc
index 329f28a..b6d2271 100644
--- a/caffe2/contrib/transform/transform.cc
+++ b/caffe2/contrib/transform/transform.cc
@@ -39,7 +39,7 @@
void Transform::TryNeighbors(
const Graph& graph,
- const std::map<int, string>& neighbors,
+ const std::map<int, std::vector<string>>& neighbors,
std::vector<int>* subgraph_ptr,
std::vector<int>* best_subgraph_ptr) {
auto& subgraph = *subgraph_ptr;
@@ -83,9 +83,17 @@
void Transform::ReplacePattern(
const std::vector<vector<int>>& matches,
Graph* graph) {
- // Simply try to apply the replace rule upon every match.
for (const auto& match : matches) {
- if (!ReplaceRule(match, graph)) {
+ // Make sure each matched node is still active (not overwritten)
+ bool is_match_active = true;
+ for (int idx : match) {
+ if (!graph->is_node_active(idx)) {
+ is_match_active = false;
+ }
+ }
+
+ // Simply try to apply the replace rule upon every match.
+ if (is_match_active && !ReplaceRule(match, graph)) {
CAFFE_THROW("Replace failed!");
}
}
diff --git a/caffe2/contrib/transform/transform.h b/caffe2/contrib/transform/transform.h
index cbc3f1c..72e65fb 100644
--- a/caffe2/contrib/transform/transform.h
+++ b/caffe2/contrib/transform/transform.h
@@ -103,7 +103,7 @@
*/
void TryNeighbors(
const transform::Graph& graph,
- const std::map<int, string>& neighbors,
+ const std::map<int, std::vector<string>>& neighbors,
std::vector<int>* subgraph_ptr,
std::vector<int>* best_subgraph_ptr);
};
diff --git a/caffe2/contrib/transform/transform_test.cc b/caffe2/contrib/transform/transform_test.cc
index 37fe88e..726914f 100644
--- a/caffe2/contrib/transform/transform_test.cc
+++ b/caffe2/contrib/transform/transform_test.cc
@@ -61,15 +61,16 @@
new_op.set_type("DummyOp3");
int new_idx = g.size();
- std::map<int, string> new_op_children;
- std::map<int, string> new_op_parents;
+ std::map<int, std::vector<string>> new_op_children;
+ std::map<int, std::vector<string>> new_op_parents;
// for each node parent in the head of the match, connect it to our new node
for (const auto& edge : g.node(match[0]).parents) {
int parent = edge.first;
- string blob = edge.second;
- g.node(parent).children[new_idx] = blob;
- new_op_parents[parent] = blob;
+ for (const auto& blob : edge.second) {
+ g.node(parent).children[new_idx].push_back(blob);
+ new_op_parents[parent].push_back(blob);
+ }
}
for (const string& blob : g.node(match[0]).op.input()) {
new_op.add_input(blob);
@@ -78,9 +79,10 @@
// for each child in the tail of the match, connect it to our new node
for (const auto& edge : g.node(match[1]).children) {
int child = edge.first;
- string blob = edge.second;
- g.node(child).parents[new_idx] = blob;
- new_op_children[child] = blob;
+ for (const auto& blob : edge.second) {
+ g.node(child).parents[new_idx].push_back(blob);
+ new_op_children[child].push_back(blob);
+ }
}
for (const string& blob : g.node(match[1]).op.output()) {
new_op.add_output(blob);
@@ -98,27 +100,6 @@
REGISTER_TRANSFORM(DummySwap, DummyTransform)
-// Adds an operator def to a netdef.
-// Returns the ptr, if you want to add anything extra (such as device_option)
-OperatorDef* AddOp(
- NetDef* netdef_ptr,
- string op_type,
- std::vector<string> inputs,
- std::vector<string> outputs) {
- CHECK(netdef_ptr);
- auto& netdef = *netdef_ptr;
- auto op_ptr = netdef.add_op();
- auto& op = *op_ptr;
- op.set_type(op_type);
- for (const string& inp : inputs) {
- op.add_input(inp);
- }
- for (const string& outp : outputs) {
- op.add_output(outp);
- }
- return op_ptr;
-}
-
TEST(TransformTest, TestPatternMatch) {
Workspace ws;
ws.CreateBlob("in");
diff --git a/caffe2/contrib/transform/transforms/conv_to_nnpack_transform_test.cc b/caffe2/contrib/transform/transforms/conv_to_nnpack_transform_test.cc
index 5e5f691..f17485e 100644
--- a/caffe2/contrib/transform/transforms/conv_to_nnpack_transform_test.cc
+++ b/caffe2/contrib/transform/transforms/conv_to_nnpack_transform_test.cc
@@ -11,27 +11,6 @@
using transform::Graph;
-// Adds an operator def to a netdef.
-// Returns the ptr, if you want to add anything extra (such as device_option)
-OperatorDef* AddOp(
- NetDef* netdef_ptr,
- string op_type,
- std::vector<string> inputs,
- std::vector<string> outputs) {
- CHECK(netdef_ptr);
- auto& netdef = *netdef_ptr;
- auto op_ptr = netdef.add_op();
- auto& op = *op_ptr;
- op.set_type(op_type);
- for (const string& inp : inputs) {
- op.add_input(inp);
- }
- for (const string& outp : outputs) {
- op.add_output(outp);
- }
- return op_ptr;
-}
-
TEST(ConvToNNPackTest, TestSimple) {
NetDef netdef;
OperatorDef* op;
diff --git a/caffe2/contrib/transform/transforms/pattern_net_transform.cc b/caffe2/contrib/transform/transforms/pattern_net_transform.cc
new file mode 100644
index 0000000..cf63af2
--- /dev/null
+++ b/caffe2/contrib/transform/transforms/pattern_net_transform.cc
@@ -0,0 +1,247 @@
+#include "caffe2/contrib/transform/transforms/pattern_net_transform.h"
+
+#include "caffe2/core/common.h"
+#include "caffe2/core/logging.h"
+#include "caffe2/core/net.h"
+#include "caffe2/proto/caffe2.pb.h"
+
+namespace caffe2 {
+
+// First, single source traverse through the netdef.
+// This ensures all newly ordered are reachable from their prefix subset
+// Outputs a permutation of the operators.
+std::vector<int> PatternNetTransform::GetPatternTraversalOrder(
+ const transform::Graph& graph) {
+ std::vector<bool> visited(graph.size(), false);
+ std::vector<int> ordered_ops;
+ std::queue<int> q;
+ if (graph.size() > 0) {
+ q.push(0);
+ ordered_ops.push_back(0);
+ visited[0] = true;
+ }
+ while (!q.empty()) {
+ int idx = q.front();
+ q.pop();
+ for (const auto& edge : graph.node(idx).children) {
+ int x = edge.first;
+ if (!visited[x]) {
+ q.push(x);
+ ordered_ops.push_back(x);
+ visited[x] = true;
+ }
+ }
+ for (const auto& edge : graph.node(idx).parents) {
+ int x = edge.first;
+ if (!visited[x]) {
+ q.push(x);
+ ordered_ops.push_back(x);
+ visited[x] = true;
+ }
+ }
+ }
+ CAFFE_ENFORCE(
+ ordered_ops.size() == graph.size(), "Pattern graph must be connected.");
+ return ordered_ops;
+}
+
+bool compare_ops(const OperatorDef& x, const OperatorDef& y, bool arg_match) {
+ // make sure types are the same
+ if (x.type() != y.type()) {
+ return false;
+ }
+ if (x.input().size() != y.input().size()) {
+ return false;
+ }
+ if (x.output().size() != y.output().size()) {
+ return false;
+ }
+
+ // TODO(benz): make a comparison for device_option.
+ // make sure engine is the same (if specified in pattern)
+ if (x.has_engine() && x.engine() != y.engine()) {
+ return false;
+ }
+ // If argument_match is specified, make sure those are the same.
+ if (arg_match) {
+ if (x.arg().size() != y.arg().size()) {
+ return false;
+ }
+ // TODO(benz): Create arg equality operator.
+ }
+ return true;
+}
+
+// g.node(subgraph[i]) should match p_.node(ordered_ops_[i])
+// g.node(g_idx) should match p_.node(p_idx)
+bool PatternNetTransform::PatternRule(
+ const transform::Graph& g,
+ const std::vector<int>& subgraph,
+ int g_idx) {
+ if (subgraph.size() >= ordered_ops_.size()) {
+ return false;
+ }
+ int p_idx = ordered_ops_[subgraph.size()];
+
+ if (!compare_ops(p_.node(p_idx).op, g.node(g_idx).op, argument_match_)) {
+ return false;
+ }
+
+ // Let's say ordered_ops_ is [0, 2, 1], with 0 -> 2 being an edge
+ // When we try to match onto the second element, let's say our
+ // subgraph so far is [4], with it trying to become [4, 5].
+ // Then, we need to show that since 0 -> 2 is an edge is ordered_ops_,
+ // 4 must be a direct parent of 5 in the subgraph
+ // (the indices must match).
+ // Similarly, assume there is an edge from 1 -> 2 in p_.
+ // When trying to match [4, 5] to [4, 5, 7], we must verify that
+ // there exists an edge from 7 -> 5 in G.
+ for (const auto& edge : p_.node(p_idx).parents) {
+ int parent = edge.first;
+ // g_idx doesn't have parent in subgraph that p_[p_idx] has
+ // inverse_ops_ gets the index of a p_idx inside of ordered_ops_.
+ if (inverse_ops_[parent] < subgraph.size() &&
+ g.node(g_idx).parents.count(subgraph[inverse_ops_[parent]]) == 0) {
+ return false;
+ }
+ }
+
+ for (const auto& edge : p_.node(p_idx).children) {
+ int child = edge.first;
+ if (inverse_ops_[child] < subgraph.size() &&
+ g.node(g_idx).children.count(subgraph[inverse_ops_[child]]) == 0) {
+ return false;
+ }
+ }
+ return true;
+}
+
+bool PatternNetTransform::ValidatorRule(
+ const transform::Graph& g,
+ const std::vector<int>& subgraph) {
+ // Due to strict PatternRule, it suffices to simply check for size
+ return subgraph.size() == p_.size();
+}
+
+bool PatternNetTransform::ReplaceRule(
+ const std::vector<int>& match,
+ transform::Graph* g_ptr) {
+ CHECK(g_ptr);
+ auto& g = *g_ptr;
+
+ ssa_id_++;
+
+ // Map of PatternNet blob name to Matched blob name.
+ // Figures out how to rename the pattern_net to make the replacement fit.
+ std::unordered_map<string, string> external_renaming;
+
+ // Figure out blob renamings
+ for (int i = 0; i < match.size(); i++) {
+ int g_idx = match[i];
+ int p_idx = ordered_ops_[i];
+ for (int j = 0; j < p_.node(p_idx).op.input().size(); j++) {
+ string p_blob = p_.node(p_idx).op.input(j);
+ string g_blob = g.node(g_idx).op.input(j);
+ if (p_.external_input().count(p_blob)) {
+ external_renaming[p_blob] = g_blob;
+ }
+ }
+ for (int j = 0; j < p_.node(p_idx).op.output().size(); j++) {
+ string p_blob = p_.node(p_idx).op.output(j);
+ string g_blob = g.node(g_idx).op.output(j);
+ if (p_.external_output().count(p_blob)) {
+ external_renaming[p_blob] = g_blob;
+ }
+ }
+ }
+
+ auto input_list = g.GetSubgraphInput(match);
+ auto output_list = g.GetSubgraphOutput(match);
+
+ g.DeactivateSubgraph(match);
+
+ int offset = g.size();
+
+ g.resize_nodes(offset + r_.size());
+
+ // Append all the new operators.
+ for (int i = 0; i < r_.size(); i++) {
+ int new_node_idx = offset + i;
+
+ OperatorDef new_op = r_.node(i).op;
+
+ new_op.clear_input();
+ new_op.clear_output();
+ // Stitch Input from external graph into replaced subgraph
+ for (const auto& blob : r_.node(i).op.input()) {
+ if (external_renaming.count(blob)) {
+ string new_blob = external_renaming[blob];
+ new_op.add_input(new_blob);
+
+ // binary searches for new_blob amongst input list.
+ auto it = std::lower_bound(
+ input_list.begin(), input_list.end(), std::make_pair(new_blob, -1));
+
+ // if the input came from the graph (instead of G's external input)
+ for (; it < input_list.end() && it->first == new_blob; it++) {
+ int parent = it->second;
+ g.node(parent).children[new_node_idx].push_back(new_blob);
+ g.node(new_node_idx).parents[parent].push_back(new_blob);
+ }
+ } else {
+ new_op.add_input(TransformBlobWrapper(blob));
+ }
+ }
+ // Stitch Output from replaced subgraph to external graph.
+ for (const auto& blob : r_.node(i).op.output()) {
+ if (external_renaming.count(blob)) {
+ string new_blob = external_renaming[blob];
+ new_op.add_output(new_blob);
+
+ // binary searches for new_blob amongst input list.
+ auto it = std::lower_bound(
+ output_list.begin(),
+ output_list.end(),
+ std::make_pair(new_blob, -1));
+
+ // if the output goes to the graph (instead of G's external output)
+ for (; it < output_list.end() && it->first == new_blob; it++) {
+ int child = it->second;
+ g.node(child).parents[new_node_idx].push_back(new_blob);
+ g.node(new_node_idx).children[child].push_back(new_blob);
+ }
+ } else {
+ new_op.add_output(TransformBlobWrapper(blob));
+ }
+ }
+
+ // Connect all internal edges within replace graph
+ for (const auto& edge : r_.node(i).parents) {
+ int parent = edge.first;
+ int new_node_parent = offset + parent;
+ const auto& blobs = edge.second;
+ for (const string& blob : blobs) {
+ g.node(new_node_idx)
+ .parents[new_node_parent]
+ .push_back(TransformBlobWrapper(blob));
+ }
+ }
+
+ for (const auto& edge : r_.node(i).children) {
+ int child = edge.first;
+ int new_node_child = offset + child;
+ const auto& blobs = edge.second;
+ for (const string& blob : blobs) {
+ g.node(offset + i)
+ .children[new_node_child]
+ .push_back(TransformBlobWrapper(blob));
+ }
+ }
+
+ g.node(new_node_idx).op = new_op;
+ g.node(new_node_idx).active = true;
+ }
+ return true;
+}
+
+} // namespace Caffe2
diff --git a/caffe2/contrib/transform/transforms/pattern_net_transform.h b/caffe2/contrib/transform/transforms/pattern_net_transform.h
new file mode 100644
index 0000000..e89a67e
--- /dev/null
+++ b/caffe2/contrib/transform/transforms/pattern_net_transform.h
@@ -0,0 +1,126 @@
+#pragma once
+
+#include "caffe2/contrib/transform/transform.h"
+#include "caffe2/core/common.h"
+#include "caffe2/proto/caffe2.pb.h"
+#include "caffe2/utils/proto_utils.h"
+
+namespace caffe2 {
+
+/**
+ * PatternNetTransform allows you to create transforms using a simple
+ * interface.
+ *
+ * Simply provide a Pattern NetDef and a Replace NetDef,
+ * and this Transform will find subgraphs which fit the pattern net,
+ * and replace it with the replace net.
+ */
+class PatternNetTransform : public Transform {
+ public:
+ PatternNetTransform(const NetDef& pattern_net, const NetDef& replace_net)
+ : p_(transform::Graph(pattern_net)), r_(transform::Graph(replace_net)) {
+ // external input and output must match!
+ CAFFE_ENFORCE(
+ p_.external_input() == r_.external_input(),
+ "External inputs do not match!");
+ CAFFE_ENFORCE(
+ p_.external_output() == r_.external_output(),
+ "External outputs do not match!");
+ ordered_ops_ = GetPatternTraversalOrder(p_);
+ inverse_ops_.resize(ordered_ops_.size());
+ for (int i = 0; i < ordered_ops_.size(); i++) {
+ inverse_ops_[ordered_ops_[i]] = i;
+ }
+ }
+
+ protected:
+ /**
+ * We want to the final result of subgraph to match the PatternNet in the
+ * order of ordered_ops, operator by operator.
+ *
+ * [[[ ie. g.node(subgraph[i]) should match p.node(ordered_ops[i]) ]]]
+ *
+ * PatternRule for PatternNetTransform does the following:
+ *
+ * When trying to insert node idx into subgraph[p_idx],
+ * we need to see if the edges between index and the
+ * subgraph match the edges between p[ordered_ops[idx]]
+ * and p[ordered_ops[0]...ordered_ops[p_idx-1]].
+ */
+ bool PatternRule(
+ const transform::Graph& g,
+ const std::vector<int>& subgraph,
+ int idx) override;
+ /**
+ * ValidatorRule for PatternNetTransform does the following:
+ *
+ * Checks if the size of subgraph and p.size() are the same. That's it!
+ */
+ bool ValidatorRule(
+ const transform::Graph& g,
+ const std::vector<int>& subgraph) override;
+ /**
+ * ReplaceRule for PatternNet Transform does the following:
+ *
+ * 1) Figure out edge renamings for edges going into/out of the subgraph.
+ * That is, for each blob in the pattern graph, what is it called in the
+ * matched subgraph?
+ *
+ * 2) Remove the matched subgraph.
+ *
+ * 3) Append the replace graph's operators to the graph's operators, and use
+ * the renamings to rename the blob names.
+ *
+ * 4) Create all the children/parent relationships within the replaced graph,
+ * and stitch together the inputs and outputs into the rest of the graph,
+ * matching the removed subgraph.
+ */
+ bool ReplaceRule(const std::vector<int>& subgraph, transform::Graph* g_ptr)
+ override;
+
+ private:
+ /**
+ * This returns a permutation of the Pattern Net's operators.
+ * The permutation satisfies this property:
+ * - For any index i, order(i) is a neighbor of some node from
+ * {order(1), ..., order(i-1)}.
+ *
+ * Why is this important? Consider the following case:
+ * PatternNet: 0 ---> 2 <--- 1
+ *
+ * When we have matched onto [0], and trying to add [1] to our subgraph,
+ * we cannot, since PatternMatch only considers neighbors of the current
+ * subgraph as a candidate next node.
+ *
+ * Therefore, we must present the subgraph in an order such that each node is
+ * a neighbor of its prefix subgraph. One ordering for the above example is
+ * [0, 2, 1].
+ */
+ std::vector<int> GetPatternTraversalOrder(const transform::Graph& g);
+
+ // Graph of Pattern NetDef
+ transform::Graph p_;
+
+ // The Traversal Order of the Pattern Net's Operators
+ // This is a permutation of the numbers from {0, ..., p.size()-1}
+ std::vector<int> ordered_ops_;
+
+ // The Inverse of the Traversal Order of the Pattern Net's Operators
+ // That is, inverse_ops[ordered_ops[i]] == i is always true.
+ std::vector<int> inverse_ops_;
+
+ // Graph of Replace NetDef
+ transform::Graph r_;
+
+ // This flag determines if the transform will match operator arguments.
+ // TODO(benz): Write a good argument comparator
+ bool argument_match_ = false;
+
+ const string TransformBlobWrapper(const string& blob_name) {
+ return "transform/" + blob_name + "_" + caffe2::to_string(ssa_id_);
+ }
+
+ int ssa_id_ = 0;
+};
+
+} // namespace caffe2
diff --git a/caffe2/contrib/transform/transforms/pattern_net_transform_test.cc b/caffe2/contrib/transform/transforms/pattern_net_transform_test.cc
new file mode 100644
index 0000000..32d36e9
--- /dev/null
+++ b/caffe2/contrib/transform/transforms/pattern_net_transform_test.cc
@@ -0,0 +1,348 @@
+#include <google/protobuf/text_format.h>
+#include <gtest/gtest.h>
+#include "caffe2/contrib/transform/transforms/pattern_net_transform.h"
+#include "caffe2/core/net.h"
+#include "caffe2/core/operator.h"
+
+namespace caffe2 {
+
+namespace {
+
+using transform::Graph;
+
+static std::atomic<int> counter;
+
+class DummyCounterOp final : public OperatorBase {
+ public:
+ using OperatorBase::OperatorBase;
+ bool Run(int /* unused */) override {
+ counter.fetch_add(1);
+ return true;
+ }
+};
+
+REGISTER_CPU_OPERATOR(DummyCounterOp1, DummyCounterOp);
+REGISTER_CUDA_OPERATOR(DummyCounterOp1, DummyCounterOp);
+
+OPERATOR_SCHEMA(DummyCounterOp1)
+ .NumInputs(0, INT_MAX)
+ .NumOutputs(0, INT_MAX)
+ .AllowInplace({{0, 0}, {1, 1}});
+
+REGISTER_CPU_OPERATOR(DummyCounterOp2, DummyCounterOp);
+REGISTER_CUDA_OPERATOR(DummyCounterOp2, DummyCounterOp);
+
+OPERATOR_SCHEMA(DummyCounterOp2)
+ .NumInputs(0, INT_MAX)
+ .NumOutputs(0, INT_MAX)
+ .AllowInplace({{0, 0}, {1, 1}});
+
+REGISTER_CPU_OPERATOR(DummyCounterOp3, DummyCounterOp);
+REGISTER_CUDA_OPERATOR(DummyCounterOp3, DummyCounterOp);
+
+OPERATOR_SCHEMA(DummyCounterOp3)
+ .NumInputs(0, INT_MAX)
+ .NumOutputs(0, INT_MAX)
+ .AllowInplace({{0, 0}, {1, 1}});
+
+/**
+ * P = ---> (Op1) ---> (Op2) --->
+ *
+ * R = ---> (Op3) ---> (Op3) --->
+ */
+TEST(PatternNetTransformTest, TestGenerateTransform) {
+ Workspace ws;
+ ws.CreateBlob("in");
+
+ NetDef netdef;
+ AddOp(&netdef, "DummyCounterOp1", {"in"}, {"mid1"});
+ AddOp(&netdef, "DummyCounterOp2", {"mid1"}, {"mid2"});
+ AddOp(&netdef, "DummyCounterOp1", {"mid2"}, {"mid3"});
+ AddOp(&netdef, "DummyCounterOp2", {"mid3"}, {"out"});
+
+ NetDef pdef;
+ AddOp(&pdef, "DummyCounterOp1", {"in"}, {"mid"});
+ AddOp(&pdef, "DummyCounterOp2", {"mid"}, {"out"});
+
+ NetDef rdef;
+ AddOp(&rdef, "DummyCounterOp3", {"in"}, {"new_mid"});
+ AddOp(&rdef, "DummyCounterOp3", {"new_mid"}, {"out"});
+
+ PatternNetTransform t(pdef, rdef);
+
+ // test pattern match
+ Graph g(netdef);
+
+ auto matches = t.PatternMatch(g);
+ EXPECT_EQ(matches.size(), 2);
+
+ t.ReplacePattern(matches, &g);
+
+ EXPECT_EQ(g.size(), 8);
+ for (int i = 0; i < 4; i++) {
+ EXPECT_FALSE(g.is_node_active(i));
+ }
+ for (int i = 4; i < 8; i++) {
+ EXPECT_TRUE(g.is_node_active(i));
+ }
+
+ EXPECT_TRUE(g.node(4).children.count(5));
+ EXPECT_TRUE(g.node(5).children.count(6));
+ EXPECT_TRUE(g.node(6).children.count(7));
+
+ for (int i = 4; i < 8; i++) {
+ EXPECT_EQ(g.node(i).op.input().size(), 1);
+ EXPECT_EQ(g.node(i).op.output().size(), 1);
+ }
+
+ NetDef replaced_netdef = g.GetNetDef();
+
+ EXPECT_EQ(replaced_netdef.op().size(), 4);
+ EXPECT_EQ(replaced_netdef.op(0).type(), "DummyCounterOp3");
+ EXPECT_EQ(replaced_netdef.op(1).type(), "DummyCounterOp3");
+ EXPECT_EQ(replaced_netdef.op(2).type(), "DummyCounterOp3");
+ EXPECT_EQ(replaced_netdef.op(3).type(), "DummyCounterOp3");
+}
+
+/**
+ * P = ---> (Op1) ---> (Op2) --->
+ *
+ * R = ---> (Op3) ---> (Op3) --->
+ */
+TEST(PatternNetTransformTest, TestRepeatedTransform) {
+ Workspace ws;
+ ws.CreateBlob("in");
+
+ NetDef netdef;
+ AddOp(&netdef, "DummyCounterOp1", {"in"}, {"out"});
+ AddOp(&netdef, "DummyCounterOp2", {"out"}, {"out"});
+ for (int i = 0; i < 99; i++) {
+ AddOp(&netdef, "DummyCounterOp1", {"out"}, {"out"});
+ AddOp(&netdef, "DummyCounterOp2", {"out"}, {"out"});
+ }
+
+ NetDef pdef;
+ AddOp(&pdef, "DummyCounterOp1", {"in"}, {"mid"});
+ AddOp(&pdef, "DummyCounterOp2", {"mid"}, {"out"});
+
+ NetDef rdef;
+ AddOp(&rdef, "DummyCounterOp3", {"in"}, {"new_mid"});
+ AddOp(&rdef, "DummyCounterOp3", {"new_mid"}, {"out"});
+
+ PatternNetTransform t(pdef, rdef);
+
+ // test pattern match
+ Graph g(netdef);
+
+ auto matches = t.PatternMatch(g);
+ EXPECT_EQ(matches.size(), 100);
+
+ t.ReplacePattern(matches, &g);
+ NetDef replaced_netdef = g.GetNetDef();
+
+ EXPECT_EQ(replaced_netdef.op_size(), 200);
+ for (int i = 0; i < 200; i++) {
+ EXPECT_EQ(replaced_netdef.op(i).type(), "DummyCounterOp3");
+ }
+
+ unique_ptr<NetBase> net = CreateNet(replaced_netdef, &ws);
+ counter.exchange(0);
+ net.get()->Run();
+ EXPECT_EQ(200, counter.load());
+}
+
+/**
+ * P = ---> (Op1) ---> (Op3) ---> (Op2) --->
+ * |------> (Op3) -------|
+ *
+ * R = ---> (Op1) --------------> (Op3) --->
+ * |_(Op3)-->(Op3)-->(Op2)_|
+ *
+ */
+TEST(PatternNetTransformTest, TestHardTransform) {
+ Workspace ws;
+ ws.CreateBlob("in");
+
+ NetDef netdef;
+ // Segment 1 (differs from P because of type)
+ AddOp(&netdef, "DummyCounterOp1", {"in"}, {"mid1a_1", "mid1b_1"});
+ AddOp(&netdef, "DummyCounterOp2", {"mid1a_1"}, {"mid2a_1"});
+ AddOp(&netdef, "DummyCounterOp3", {"mid1b_1"}, {"mid2b_1"});
+ AddOp(&netdef, "DummyCounterOp3", {"mid2a_1", "mid2b_1"}, {"out_1"});
+
+ // Segment 2 (differs from P because of structure)
+ AddOp(
+ &netdef, "DummyCounterOp1", {"out_1"}, {"mid1a_2", "mid1b_2", "mid1c_2"});
+ AddOp(&netdef, "DummyCounterOp3", {"mid1a_2"}, {"mid2a_2"});
+ AddOp(&netdef, "DummyCounterOp3", {"mid1b_2"}, {"mid2b_2"});
+ AddOp(&netdef, "DummyCounterOp3", {"mid1c_2"}, {"mid2c_2"});
+ AddOp(
+ &netdef, "DummyCounterOp2", {"mid2a_2", "mid2b_2", "mid2c_2"}, {"out_2"});
+
+ // Segment 3
+ AddOp(&netdef, "DummyCounterOp1", {"out_2"}, {"mid1a_3", "mid1b_3"});
+ AddOp(&netdef, "DummyCounterOp3", {"mid1a_3"}, {"mid2a_3"});
+ AddOp(&netdef, "DummyCounterOp3", {"mid1b_3"}, {"mid2b_3"});
+ AddOp(&netdef, "DummyCounterOp2", {"mid2a_3", "mid2b_3"}, {"out"});
+
+ NetDef pdef;
+ // Should only match Segment 3
+ AddOp(&pdef, "DummyCounterOp1", {"sub_in"}, {"mid1a", "mid1b"});
+ AddOp(&pdef, "DummyCounterOp3", {"mid1a"}, {"mid2a"});
+ AddOp(&pdef, "DummyCounterOp3", {"mid1b"}, {"mid2b"});
+ AddOp(&pdef, "DummyCounterOp2", {"mid2a", "mid2b"}, {"sub_out"});
+
+ NetDef rdef;
+ AddOp(&rdef, "DummyCounterOp1", {"sub_in"}, {"mid1a", "mid1b"});
+ AddOp(&rdef, "DummyCounterOp3", {"mid1b"}, {"mid2b"});
+ AddOp(&rdef, "DummyCounterOp3", {"mid2b"}, {"mid3b"});
+ AddOp(&rdef, "DummyCounterOp2", {"mid3b"}, {"mid4b"});
+ AddOp(&rdef, "DummyCounterOp3", {"mid1a", "mid4b"}, {"sub_out"});
+
+ PatternNetTransform t(pdef, rdef);
+ Graph g(netdef);
+ EXPECT_EQ(g.size(), 13);
+
+ auto matches = t.PatternMatch(g);
+ EXPECT_EQ(matches.size(), 1);
+
+ t.ReplacePattern(matches, &g);
+ EXPECT_EQ(g.size(), 18);
+
+ NetDef replaced_netdef = g.GetNetDef();
+ EXPECT_EQ(replaced_netdef.op_size(), 14);
+ unique_ptr<NetBase> net = CreateNet(replaced_netdef, &ws);
+ counter.exchange(0);
+ net.get()->Run();
+ EXPECT_EQ(14, counter.load());
+}
+
+/**
+ * |--(Op2)--|
+ * P = --->(Op1)----->(Op3)--->
+ * |--(Op2)--|
+ *
+ * R = ---> (Op2) --->
+ *
+ * |--(Op2)--|
+ * -->(Op1)----->(Op3)---
+ * | |--(Op2)--| |
+ * G = ---> (Op1) (Op3) --->
+ * | |--(Op2)--| |
+ * -->(Op1)----->(Op3)--
+ * |--(Op2)--|
+ *
+ * In this test, the two "parallel" modules have intersecting execution orders.
+ * We wish to test that the pattern match can still detect the two modules,
+ * separately.
+ *
+ * Furthermore, we will apply the transform to G, TWICE.
+ * It should reduce G to a single operator.
+ */
+TEST(PatternNetTransformTest, TestNonStrictTopographicTransform) {
+ Workspace ws;
+ ws.CreateBlob("in");
+
+ NetDef netdef;
+ // Head
+ AddOp(&netdef, "DummyCounterOp1", {"in"}, {"in_1", "in_2"});
+
+ // 2 intertwined segments, each matching P. No strict ordering.
+ AddOp(&netdef, "DummyCounterOp1", {"in_1"}, {"m1_1", "m2_1"});
+ AddOp(&netdef, "DummyCounterOp1", {"in_2"}, {"m1_2", "m2_2"});
+ AddOp(&netdef, "DummyCounterOp2", {"m1_1"}, {"out1_1"});
+ AddOp(&netdef, "DummyCounterOp2", {"m1_2"}, {"out1_2"});
+ AddOp(&netdef, "DummyCounterOp2", {"m2_1"}, {"out2_1"});
+ AddOp(&netdef, "DummyCounterOp2", {"m2_2"}, {"out2_2"});
+ AddOp(&netdef, "DummyCounterOp3", {"out1_1", "out2_1"}, {"out1"});
+ AddOp(&netdef, "DummyCounterOp3", {"out1_2", "out2_2"}, {"out2"});
+
+ // Tail
+ AddOp(&netdef, "DummyCounterOp3", {"out1", "out2"}, {"out"});
+
+ NetDef pdef;
+ AddOp(&pdef, "DummyCounterOp1", {"myin"}, {"mid1a", "mid1b"});
+ AddOp(&pdef, "DummyCounterOp2", {"mid1a"}, {"mid2a"});
+ AddOp(&pdef, "DummyCounterOp2", {"mid1b"}, {"mid2b"});
+ AddOp(&pdef, "DummyCounterOp3", {"mid2a", "mid2b"}, {"myout"});
+
+ NetDef rdef;
+ AddOp(&rdef, "DummyCounterOp2", {"myin"}, {"myout"});
+
+ PatternNetTransform t(pdef, rdef);
+
+ NetDef replaced_netdef = t.ApplyTo(netdef);
+ EXPECT_EQ(replaced_netdef.op_size(), 4);
+ unique_ptr<NetBase> net = CreateNet(replaced_netdef, &ws);
+ counter.exchange(0);
+ net.get()->Run();
+ EXPECT_EQ(4, counter.load());
+
+ // apply the transform again
+ // the entire net should get transformed this time
+ NetDef double_transformed_net = t.ApplyTo(replaced_netdef);
+ EXPECT_EQ(double_transformed_net.op_size(), 1);
+}
+
+/**
+ * --->(Op1)----->(Op2)--->
+ * | ^
+ * P = |----------|
+ * | v
+ * --->(Op1)----->(Op2)--->
+ *
+ * R = ---> (Op3) --->
+ *
+ * G = P -> P
+ *
+ * In this test, we fuse a subgraph with two inputs and two outputs, into one
+ * operator.
+ *
+ * This will ensure that we can allow a single edge to represent
+ * multiple blob names (the input and output of R are both 2 blobs).
+ *
+ * This will also ensure that patternmatch can traverse "backwards", from a node
+ * to its parent.
+ *
+ * Furthermore, this tests for repeat matches, since matching on either of the
+ * first two Op1 nodes will produce a match, but they are identical.
+ * So, the pattern should match 4 times, but only be replaced twice.
+ */
+TEST(PatternNetTransformTest, TestMultiInputOutputTransform) {
+ Workspace ws;
+ ws.CreateBlob("in1");
+ ws.CreateBlob("in2");
+
+ NetDef netdef;
+ AddOp(&netdef, "DummyCounterOp1", {"in1"}, {"in1"}); // has 2 children
+ AddOp(&netdef, "DummyCounterOp1", {"in2"}, {"in2"}); // has 2 children
+ AddOp(&netdef, "DummyCounterOp2", {"in1", "in2"}, {"mid1"});
+ AddOp(&netdef, "DummyCounterOp2", {"in1", "in2"}, {"mid2"});
+ AddOp(&netdef, "DummyCounterOp1", {"mid1"}, {"mid1"}); // has 2 children
+ AddOp(&netdef, "DummyCounterOp1", {"mid2"}, {"mid2"}); // has 2 children
+ AddOp(&netdef, "DummyCounterOp2", {"mid1", "mid2"}, {"out1"});
+ AddOp(&netdef, "DummyCounterOp2", {"mid1", "mid2"}, {"out2"});
+
+ NetDef pdef;
+ AddOp(&pdef, "DummyCounterOp1", {"subin1"}, {"subin1"}); // has 2 children
+ AddOp(&pdef, "DummyCounterOp1", {"subin2"}, {"subin2"}); // has 2 children
+ AddOp(&pdef, "DummyCounterOp2", {"subin1", "subin2"}, {"subout1"});
+ AddOp(&pdef, "DummyCounterOp2", {"subin1", "subin2"}, {"subout2"});
+
+ NetDef rdef;
+ AddOp(&rdef, "DummyCounterOp3", {"subin1", "subin2"}, {"subout1", "subout2"});
+
+ PatternNetTransform t(pdef, rdef);
+ Graph g(netdef);
+
+ NetDef replaced_netdef = t.ApplyTo(netdef);
+ EXPECT_EQ(replaced_netdef.op_size(), 2);
+ unique_ptr<NetBase> net = CreateNet(replaced_netdef, &ws);
+ counter.exchange(0);
+ net.get()->Run();
+ EXPECT_EQ(2, counter.load());
+}
+
+} // namespace
+
+} // namespace Caffe2