Create a common subexpression elimination (CSE) pass with very naive support.
diff --git a/setup.py b/setup.py
index 02acb17..4eda29c 100644
--- a/setup.py
+++ b/setup.py
@@ -381,6 +381,7 @@
     "torch/csrc/jit/passes/graph_fuser.cpp",
     "torch/csrc/jit/passes/onnx.cpp",
     "torch/csrc/jit/passes/dead_code_elimination.cpp",
+    "torch/csrc/jit/passes/common_subexpression_elimination.cpp",
     "torch/csrc/autograd/init.cpp",
     "torch/csrc/autograd/engine.cpp",
     "torch/csrc/autograd/function.cpp",
diff --git a/test/expect/TestJit.test_cse.expect b/test/expect/TestJit.test_cse.expect
new file mode 100644
index 0000000..c7c0160
--- /dev/null
+++ b/test/expect/TestJit.test_cse.expect
@@ -0,0 +1,6 @@
+graph(%1 : Double(2)
+      %2 : Double(2)) {
+  %3 : Double(2) = Add(%1, %2), uses = [%5.i0, %5.i1];
+  %5 : Double(2) = Mul(%3, %3), uses = [%0.i0];
+  return (%5);
+}
diff --git a/test/test_jit.py b/test/test_jit.py
index de9720a..a8f7f86 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -98,6 +98,20 @@
         torch._C._jit_pass_lint(trace)
         self.assertExpected(str(trace))
 
+    def test_cse(self):
+        x = Variable(torch.Tensor([0.4, 0.3]), requires_grad=True)
+        y = Variable(torch.Tensor([0.7, 0.5]), requires_grad=True)
+
+        trace = torch._C._tracer_enter((x, y), 0)
+        z = (x + y) * (x + y)
+        torch._C._tracer_exit((z,))
+        torch._C._jit_pass_lint(trace)
+        torch._C._jit_pass_onnx(trace)
+        torch._C._jit_pass_lint(trace)
+        torch._C._jit_pass_cse(trace)
+
+        self.assertExpected(str(trace))
+
     def test_verify(self):
         x = Variable(torch.Tensor([0.4]), requires_grad=True)
         y = Variable(torch.Tensor([0.7]), requires_grad=True)
diff --git a/torch/csrc/jit/init.cpp b/torch/csrc/jit/init.cpp
index b7778a4..916f6d0 100644
--- a/torch/csrc/jit/init.cpp
+++ b/torch/csrc/jit/init.cpp
@@ -6,6 +6,7 @@
 #include "torch/csrc/jit/passes/graph_fuser.h"
 #include "torch/csrc/jit/passes/onnx.h"
 #include "torch/csrc/jit/passes/dead_code_elimination.h"
+#include "torch/csrc/jit/passes/common_subexpression_elimination.h"
 
 
 
@@ -39,6 +40,7 @@
    .def("_jit_pass_onnx", ToONNX)
    .def("_jit_pass_fuse", graph_pass<FuseGraph>)
    .def("_jit_pass_dce", graph_pass<EliminateDeadCode>)
+   .def("_jit_pass_cse", graph_pass<EliminateCommonSubexpression>)
    .def("_jit_pass_lint", graph_pass<LintGraph>)
    .def("_jit_run_cpp_tests", runJITCPPTests);
 
diff --git a/torch/csrc/jit/passes/common_subexpression_elimination.cpp b/torch/csrc/jit/passes/common_subexpression_elimination.cpp
new file mode 100644
index 0000000..d9dd6b1
--- /dev/null
+++ b/torch/csrc/jit/passes/common_subexpression_elimination.cpp
@@ -0,0 +1,105 @@
+#include <algorithm>
+#include <unordered_map>
+#include <unordered_set>
+
+#include "torch/csrc/jit/interned_strings.h"
+#include "torch/csrc/jit/ir.h"
+#include "torch/csrc/jit/passes/common_subexpression_elimination.h"
+
+namespace torch { namespace jit {
+
+// Hash functor for CSE: folds a node's kind, stage and input ids into one value.
+struct HashNodeCSE {
+  std::size_t operator()(const Node* node) const {
+    JIT_ASSERT(node != nullptr);
+    const std::size_t prime = 31; // Multiplier of the polynomial hash.
+    std::size_t hash = node->kind() * prime + node->stage();
+    for (auto input : node->inputs())
+      hash = hash * prime + input->unique();
+    return hash;
+  }
+};
+
+// Equality functor for CSE: same kind, same stage, same inputs, and no attributes.
+struct EqualNodeCSE {
+  bool operator()(const Node* lhs, const Node* rhs) const {
+    // Identity (which includes both being null) is always equal; this keeps the
+    // predicate reflexive even for nodes that carry attributes.
+    if (lhs == rhs) return true;
+    if (lhs == nullptr || rhs == nullptr) return false;
+
+    // Check whether two nodes are the same kind.
+    if (lhs->kind() != rhs->kind()) return false;
+
+    // Check the stage.
+    if (lhs->stage() != rhs->stage()) return false;
+
+    // TODO check the device.
+
+    // Check whether the inputs are the same (element-wise, in order).
+    if (lhs->inputs().size() != rhs->inputs().size()) return false;
+    if (!std::equal(lhs->inputs().begin(), lhs->inputs().end(), rhs->inputs().begin())) return false;
+    // TODO support attributes comparison; until then, two distinct nodes are
+    // conservatively treated as different if either one has attributes.
+    if (lhs->hasAttributes() || rhs->hasAttributes()) return false;
+    return true;
+  }
+};
+
+// CSE pass: merges nodes equal under HashNodeCSE/EqualNodeCSE into the first one seen.
+void EliminateCommonSubexpression(std::shared_ptr<Graph>& graph) {
+  // After each elimination the scan restarts with a fresh set (presumably because
+  // destroying nodes makes continuing the traversal unsafe); stop when a scan removes nothing.
+  bool reach_fixed = false;
+  while (!reach_fixed) {
+    reach_fixed = true;
+    auto nodes = graph->nodes();
+    std::unordered_set<Node*, HashNodeCSE, EqualNodeCSE> subexprs;
+    for (auto it = nodes.begin(); it != nodes.end(); ++ it) {
+      auto node = *it;
+      if (node->kind() != kAdd
+          && node->kind() != kMul
+          && node->kind() != kNeg
+          && node->kind() != kSigmoid
+          && node->kind() != kTanh
+          && node->kind() != kSplit
+          && node->kind() != kAddConstant
+         ) {
+        // TODO support more kinds of nodes; only the kinds above are handled.
+        continue;
+      }
+
+      // A single insert both records a new subexpression and, when an
+      // equivalent node is already present, yields an iterator to it.
+      auto inserted = subexprs.insert(node);
+      if (inserted.second) continue; // First occurrence: nothing to merge.
+
+      // Subexpression exists, replace the uses of node, and destroy it.
+      auto existing = *inserted.first;
+      const use_list & uses = node->uses();
+      const use_list & reuses = existing->uses();
+      if (node->hasMultipleOutputs()) {
+        // For Multi-Output nodes, all its uses should be Select nodes.
+        // NOTE(review): Selects are paired by use-list position; confirm that
+        // position i refers to the same output on both nodes.
+        JIT_ASSERT(uses.size() == reuses.size());
+        for (size_t i = 0; i < uses.size(); ++ i) {
+          JIT_ASSERT(uses[i].user->kind() == kSelect);
+          JIT_ASSERT(reuses[i].user->kind() == kSelect);
+          uses[i].user->replaceAllUsesWith(reuses[i].user);
+        }
+        // Destroy Select nodes; each destroy() is expected to shrink `uses`.
+        while (uses.size() > 0) {
+          uses[0].user->destroy();
+        }
+      } else {
+        node->replaceAllUsesWith(existing);
+      }
+      node->destroy();
+      reach_fixed = false;
+      break;
+    }
+  }
+}
+
+}}
diff --git a/torch/csrc/jit/passes/common_subexpression_elimination.h b/torch/csrc/jit/passes/common_subexpression_elimination.h
new file mode 100644
index 0000000..483c573
--- /dev/null
+++ b/torch/csrc/jit/passes/common_subexpression_elimination.h
@@ -0,0 +1,9 @@
+#pragma once
+
+#include "torch/csrc/jit/ir.h"
+
+namespace torch { namespace jit {
+// Runs common subexpression elimination (CSE) on `graph`, modifying it in place.
+void EliminateCommonSubexpression(std::shared_ptr<Graph>& graph);
+
+}}