Create a common subexpression elimination (CSE) pass, with very naive support.
diff --git a/setup.py b/setup.py
index 02acb17..4eda29c 100644
--- a/setup.py
+++ b/setup.py
@@ -381,6 +381,7 @@
"torch/csrc/jit/passes/graph_fuser.cpp",
"torch/csrc/jit/passes/onnx.cpp",
"torch/csrc/jit/passes/dead_code_elimination.cpp",
+ "torch/csrc/jit/passes/common_subexpression_elimination.cpp",
"torch/csrc/autograd/init.cpp",
"torch/csrc/autograd/engine.cpp",
"torch/csrc/autograd/function.cpp",
diff --git a/test/expect/TestJit.test_cse.expect b/test/expect/TestJit.test_cse.expect
new file mode 100644
index 0000000..c7c0160
--- /dev/null
+++ b/test/expect/TestJit.test_cse.expect
@@ -0,0 +1,6 @@
+graph(%1 : Double(2)
+ %2 : Double(2)) {
+ %3 : Double(2) = Add(%1, %2), uses = [%5.i0, %5.i1];
+ %5 : Double(2) = Mul(%3, %3), uses = [%0.i0];
+ return (%5);
+}
diff --git a/test/test_jit.py b/test/test_jit.py
index de9720a..a8f7f86 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -98,6 +98,20 @@
torch._C._jit_pass_lint(trace)
self.assertExpected(str(trace))
+ def test_cse(self):
+ x = Variable(torch.Tensor([0.4, 0.3]), requires_grad=True)
+ y = Variable(torch.Tensor([0.7, 0.5]), requires_grad=True)
+
+ trace = torch._C._tracer_enter((x, y), 0)
+ z = (x + y) * (x + y)
+ torch._C._tracer_exit((z,))
+ torch._C._jit_pass_lint(trace)
+ torch._C._jit_pass_onnx(trace)
+ torch._C._jit_pass_lint(trace)
+ torch._C._jit_pass_cse(trace)
+
+ self.assertExpected(str(trace))
+
def test_verify(self):
x = Variable(torch.Tensor([0.4]), requires_grad=True)
y = Variable(torch.Tensor([0.7]), requires_grad=True)
diff --git a/torch/csrc/jit/init.cpp b/torch/csrc/jit/init.cpp
index b7778a4..916f6d0 100644
--- a/torch/csrc/jit/init.cpp
+++ b/torch/csrc/jit/init.cpp
@@ -6,6 +6,7 @@
#include "torch/csrc/jit/passes/graph_fuser.h"
#include "torch/csrc/jit/passes/onnx.h"
#include "torch/csrc/jit/passes/dead_code_elimination.h"
+#include "torch/csrc/jit/passes/common_subexpression_elimination.h"
@@ -39,6 +40,7 @@
.def("_jit_pass_onnx", ToONNX)
.def("_jit_pass_fuse", graph_pass<FuseGraph>)
.def("_jit_pass_dce", graph_pass<EliminateDeadCode>)
+ .def("_jit_pass_cse", graph_pass<EliminateCommonSubexpression>)
.def("_jit_pass_lint", graph_pass<LintGraph>)
.def("_jit_run_cpp_tests", runJITCPPTests);
diff --git a/torch/csrc/jit/passes/common_subexpression_elimination.cpp b/torch/csrc/jit/passes/common_subexpression_elimination.cpp
new file mode 100644
index 0000000..d9dd6b1
--- /dev/null
+++ b/torch/csrc/jit/passes/common_subexpression_elimination.cpp
@@ -0,0 +1,111 @@
+#include "torch/csrc/jit/ir.h"
+
+#include <algorithm>
+#include <unordered_set>
+
+#include "torch/csrc/jit/interned_strings.h"
+#include "torch/csrc/jit/passes/common_subexpression_elimination.h"
+
+namespace torch { namespace jit {
+
+// Hash a Node by its kind, stage and input uniques so that candidate
+// common subexpressions collide into the same bucket.
+struct HashNodeCSE {
+  std::size_t operator()(const Node* k) const {
+    JIT_ASSERT(k != nullptr);
+    std::size_t p = 31; // A prime.
+    std::size_t h = k->kind() * p + k->stage();
+    for (auto i : k->inputs()) {
+      h = h * p + i->unique();
+    }
+    return h;
+  }
+};
+
+// Two nodes are considered equal (i.e. compute the same value) when they
+// have the same kind, stage and inputs, and neither carries attributes.
+struct EqualNodeCSE {
+  bool operator()(const Node* lhs, const Node* rhs) const {
+    if (lhs == nullptr && rhs == nullptr) return true;
+    if (lhs == nullptr || rhs == nullptr) return false;
+
+    // Check whether two nodes are the same kind.
+    if (lhs->kind() != rhs->kind()) return false;
+
+    // Check the stage.
+    if (lhs->stage() != rhs->stage()) return false;
+
+    // TODO check the device.
+
+    // Check whether the inputs are the same.
+    if (lhs->inputs().size() != rhs->inputs().size()) return false;
+
+    if (!std::equal(lhs->inputs().begin(), lhs->inputs().end(), rhs->inputs().begin())) return false;
+
+    // Check the attributes.
+    // TODO support attributes comparison.
+    if (lhs->hasAttributes() || rhs->hasAttributes()) return false;
+
+    return true;
+  }
+};
+
+void EliminateCommonSubexpression(std::shared_ptr<Graph>& graph) {
+  // Keep iterating until reach the fixed point.
+  bool reach_fixed = false;
+  while (!reach_fixed) {
+    reach_fixed = true;
+    auto nodes = graph->nodes();
+    std::unordered_set<Node*, HashNodeCSE, EqualNodeCSE> subexprs;
+    for (auto it = nodes.begin(); it != nodes.end(); ++it) {
+      auto node = *it;
+      if (node->kind() != kAdd
+          && node->kind() != kMul
+          && node->kind() != kNeg
+          && node->kind() != kSigmoid
+          && node->kind() != kTanh
+          && node->kind() != kSplit
+          && node->kind() != kAddConstant
+          ) {
+        // TODO support more kinds of nodes.
+        // Only support CSE on these nodes.
+        continue;
+      }
+
+      // Check whether the same subexpression already exists.
+      auto subexpr_it = subexprs.find(node);
+      if (subexpr_it == subexprs.end()) {
+        // First occurrence: remember it in the set.
+        subexprs.insert(node);
+      } else {
+        // Subexpression exists, replace the uses of node, and destroy it.
+        auto existing = *subexpr_it;
+        const use_list & uses = node->uses();
+        const use_list & reuses = existing->uses();
+        if (node->hasMultipleOutputs()) {
+          // For Multi-Output nodes, all its uses should be Select nodes.
+          JIT_ASSERT(uses.size() == reuses.size());
+          // Replace the uses of Select nodes.
+          for (size_t i = 0; i < uses.size(); ++i) {
+            JIT_ASSERT(uses[i].user->kind() == kSelect);
+            JIT_ASSERT(reuses[i].user->kind() == kSelect);
+            uses[i].user->replaceAllUsesWith(reuses[i].user);
+          }
+          // Destroy Select nodes.
+          while (uses.size() > 0) {
+            uses[0].user->destroy();
+          }
+        } else {
+          node->replaceAllUsesWith(existing);
+        }
+        // Destroy the node.
+        node->destroy();
+        // The graph changed; restart the scan with a fresh table.
+        reach_fixed = false;
+        break;
+      }
+    }
+  }
+}
+
+}}
diff --git a/torch/csrc/jit/passes/common_subexpression_elimination.h b/torch/csrc/jit/passes/common_subexpression_elimination.h
new file mode 100644
index 0000000..483c573
--- /dev/null
+++ b/torch/csrc/jit/passes/common_subexpression_elimination.h
@@ -0,0 +1,9 @@
+#pragma once
+
+#include "torch/csrc/jit/ir.h"
+
+namespace torch { namespace jit {
+
+void EliminateCommonSubexpression(std::shared_ptr<Graph>& graph);
+
+}}