[SR] Add EliminateTrivialEquallySplit graph pass (#67166)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67166

This optimization is not really the same thing as `FuseListUnpack`, and mixing the logic in that pass is confusing and error-prone. It should really be its own pass.

It's slower since we have to do another pass over the graph, but this is not perf critical code; readability is more important.

Test Plan: Unit tests: `buck test caffe2/benchmarks/static_runtime/...`

Reviewed By: hlu1

Differential Revision: D31887458

fbshipit-source-id: 289e281d512435861fccfe19f017751ad015688c
diff --git a/torch/csrc/jit/runtime/static/impl.cpp b/torch/csrc/jit/runtime/static/impl.cpp
index 0b3f0dd..3b3f0ca 100644
--- a/torch/csrc/jit/runtime/static/impl.cpp
+++ b/torch/csrc/jit/runtime/static/impl.cpp
@@ -99,6 +99,7 @@
   FuseInferenceOpsForSparseNN(graph);
   UseVariadicCat(graph);
   UseVariadicStack(graph);
+  EliminateTrivialEquallySplit(graph);
 
   if (opts.enable_out_variant) {
     UseVariadicOp(
diff --git a/torch/csrc/jit/runtime/static/passes.cpp b/torch/csrc/jit/runtime/static/passes.cpp
index 1ac31f9..15a2949 100644
--- a/torch/csrc/jit/runtime/static/passes.cpp
+++ b/torch/csrc/jit/runtime/static/passes.cpp
@@ -524,6 +524,39 @@
 #endif
 }
 
+void EliminateTrivialEquallySplit(std::shared_ptr<torch::jit::Graph>& graph) {
+  const auto equally_split = fromQualString("fb::equally_split");
+  std::vector<Node*> to_remove;
+  for (auto* node : graph->nodes()) {
+    if (node->kind() != equally_split) {
+      continue;
+    }
+
+    const Value* value_out = node->outputs()[0];
+    if (value_out->uses().size() != 1) {
+      continue;
+    }
+
+    Node* list_unpack_node = value_out->uses()[0].user;
+    if (list_unpack_node->kind() != prim::ListUnpack) {
+      continue;
+    }
+
+    auto list_unpack_outputs = list_unpack_node->outputs();
+    if (list_unpack_outputs.size() != 1) {
+      continue;
+    }
+
+    list_unpack_node->output()->replaceAllUsesWith(node->input(0));
+    list_unpack_node->destroy();
+    to_remove.push_back(node);
+  }
+
+  for (Node* node : to_remove) {
+    node->destroy();
+  }
+}
+
 // NB: The alias type of the fused op needs to be changed to
 // c10::AliasAnalysisKind::PURE_FUNCTION to make alias analysis work.
 void FuseListUnpack(std::shared_ptr<torch::jit::Graph>& graph) {
@@ -589,16 +622,6 @@
       }
     }
 
-    if (is_equally_split && list_unpack_outputs.size() == 1) {
-      // This captures a case of `y = fb::equally_split(x, 1, _)` where y
-      // becomes just an alias of x.
-      // If this case is found, replace y with x to avoid executing this op.
-      list_unpack_node->output()->replaceAllUsesWith(node->input(0));
-      list_unpack_node->destroy();
-      to_remove.push_back(node);
-      continue;
-    }
-
     const auto& new_sym = unfused_to_fused_it->second;
     auto* new_node = graph->create(new_sym, 0);
 
diff --git a/torch/csrc/jit/runtime/static/passes.h b/torch/csrc/jit/runtime/static/passes.h
index dc343b7..e771fe4 100644
--- a/torch/csrc/jit/runtime/static/passes.h
+++ b/torch/csrc/jit/runtime/static/passes.h
@@ -5,6 +5,10 @@
 
 TORCH_API void FuseInferenceOpsForSparseNN(
     std::shared_ptr<torch::jit::Graph>& graph);
+
+TORCH_API void EliminateTrivialEquallySplit(
+    std::shared_ptr<torch::jit::Graph>& graph);
+
 TORCH_API void FuseListUnpack(std::shared_ptr<torch::jit::Graph>& graph);
 
 // If outputs_are_immutable is set to false, don't replace the view ops that