[SR] Add EliminateTrivialEquallySplit graph pass (#67166)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67166
This optimization is not really the same thing as `FuseListUnpack`, and mixing the logic in that pass is confusing and error-prone. It should really be its own pass.
It's slower since we have to do another pass over the graph, but this is not perf-critical code; readability is more important.
Test Plan: Unit tests: `buck test caffe2/benchmarks/static_runtime/...`
Reviewed By: hlu1
Differential Revision: D31887458
fbshipit-source-id: 289e281d512435861fccfe19f017751ad015688c
diff --git a/torch/csrc/jit/runtime/static/impl.cpp b/torch/csrc/jit/runtime/static/impl.cpp
index 0b3f0dd..3b3f0ca 100644
--- a/torch/csrc/jit/runtime/static/impl.cpp
+++ b/torch/csrc/jit/runtime/static/impl.cpp
@@ -99,6 +99,7 @@
FuseInferenceOpsForSparseNN(graph);
UseVariadicCat(graph);
UseVariadicStack(graph);
+ EliminateTrivialEquallySplit(graph);
if (opts.enable_out_variant) {
UseVariadicOp(
diff --git a/torch/csrc/jit/runtime/static/passes.cpp b/torch/csrc/jit/runtime/static/passes.cpp
index 1ac31f9..15a2949 100644
--- a/torch/csrc/jit/runtime/static/passes.cpp
+++ b/torch/csrc/jit/runtime/static/passes.cpp
@@ -524,6 +524,39 @@
#endif
}
+// Removes trivial `fb::equally_split` nodes from the top-level node list of
+// `graph`.
+//
+// This captures the case `y = fb::equally_split(x, 1, _)`, where the list
+// produced by the op is consumed only by a `prim::ListUnpack` with exactly
+// one output. There `y` becomes just an alias of `x`, so every use of `y` is
+// replaced with `x` and both nodes are deleted to avoid executing the op.
+void EliminateTrivialEquallySplit(std::shared_ptr<torch::jit::Graph>& graph) {
+  const auto equally_split = fromQualString("fb::equally_split");
+  // Nodes are only queued during the traversal and destroyed afterwards, so
+  // the node list is never mutated while it is being iterated.
+  std::vector<Node*> to_remove;
+  for (auto* node : graph->nodes()) {
+    if (node->kind() != equally_split) {
+      continue;
+    }
+
+    // The produced list must have exactly one use...
+    const Value* value_out = node->outputs()[0];
+    if (value_out->uses().size() != 1) {
+      continue;
+    }
+
+    // ...and that use must be a prim::ListUnpack...
+    Node* list_unpack_node = value_out->uses()[0].user;
+    if (list_unpack_node->kind() != prim::ListUnpack) {
+      continue;
+    }
+
+    // ...which unpacks the list into a single value.
+    if (list_unpack_node->outputs().size() != 1) {
+      continue;
+    }
+
+    list_unpack_node->output()->replaceAllUsesWith(node->input(0));
+    // Queue the ListUnpack first: it is still a user of the split's output,
+    // so it has to be destroyed before the split node itself.
+    to_remove.push_back(list_unpack_node);
+    to_remove.push_back(node);
+  }
+
+  for (Node* node : to_remove) {
+    node->destroy();
+  }
+}
+
// NB: The alias type of the fused op needs to be changed to
// c10::AliasAnalysisKind::PURE_FUNCTION to make alias analysis work.
void FuseListUnpack(std::shared_ptr<torch::jit::Graph>& graph) {
@@ -589,16 +622,6 @@
}
}
- if (is_equally_split && list_unpack_outputs.size() == 1) {
- // This captures a case of `y = fb::equally_split(x, 1, _)` where y
- // becomes just an alias of x.
- // If this case is found, replace y with x to avoid executing this op.
- list_unpack_node->output()->replaceAllUsesWith(node->input(0));
- list_unpack_node->destroy();
- to_remove.push_back(node);
- continue;
- }
-
const auto& new_sym = unfused_to_fused_it->second;
auto* new_node = graph->create(new_sym, 0);
diff --git a/torch/csrc/jit/runtime/static/passes.h b/torch/csrc/jit/runtime/static/passes.h
index dc343b7..e771fe4 100644
--- a/torch/csrc/jit/runtime/static/passes.h
+++ b/torch/csrc/jit/runtime/static/passes.h
@@ -5,6 +5,10 @@
TORCH_API void FuseInferenceOpsForSparseNN(
std::shared_ptr<torch::jit::Graph>& graph);
+
+TORCH_API void EliminateTrivialEquallySplit(
+ std::shared_ptr<torch::jit::Graph>& graph);
+
TORCH_API void FuseListUnpack(std::shared_ptr<torch::jit::Graph>& graph);
// If outputs_are_immutable is set to false, don't replace the view ops that