[StaticRuntime] Fuse SigridTransforms + ListUnpack (#53920)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/53920

Fusing SigridTransforms + ListUnpack makes it possible to enable the out variant for SigridTransforms, so that its output tensors can be managed by the MemoryPlanner in Static Runtime.
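
For background, an out variant writes its results into caller-provided tensors instead of allocating fresh ones, which is what lets the MemoryPlanner own and reuse the buffers. A minimal sketch of the distinction, using `at::sigmoid` as a stand-in for the internal op:

```cpp
#include <ATen/ATen.h>

// Functional variant: allocates a new output tensor on every call.
at::Tensor functional(const at::Tensor& x) {
  return at::sigmoid(x);
}

// Out variant: writes into a preallocated buffer, so the Static Runtime
// MemoryPlanner can hand out and recycle the memory across iterations.
void out_variant(const at::Tensor& x, at::Tensor& buf) {
  at::sigmoid_out(buf, x);
}
```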

The speedup comes from three parts: 1) eliminating the memory allocation inside SigridTransforms itself, 2) avoiding the memory deallocation cost (outside SigridTransforms, inside the MemoryPlanner), and 3) getting rid of ListUnpack. However, for 3) we still pay the cost of constructing a `vector<Tensor>` for the outputs and a round of refcount bumps on all the output TensorImpls.
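
Roughly, the pass performs the following graph rewrite (schematic IR; the real `fb::sigrid_transforms` schema takes more arguments than shown):

```
Before:
  %ts : Tensor[] = fb::sigrid_transforms(%inputs)
  %t0 : Tensor, %t1 : Tensor = prim::ListUnpack(%ts)

After:
  %t0 : Tensor, %t1 : Tensor = fb::sigrid_transforms(%inputs)
```

With the list and the unpack gone, each fused output is an ordinary tensor output that the out-variant machinery can allocate in place.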

Reviewed By: ajyu

Differential Revision: D26220546

fbshipit-source-id: 651bdfb850225511c43b8f50083b13e8dec46bcc
diff --git a/torch/csrc/jit/runtime/static/impl.cpp b/torch/csrc/jit/runtime/static/impl.cpp
index c3189c8..ad624b7 100644
--- a/torch/csrc/jit/runtime/static/impl.cpp
+++ b/torch/csrc/jit/runtime/static/impl.cpp
@@ -37,6 +37,7 @@
 #ifdef FBCODE_CAFFE2
   if (opts.enable_out_variant) {
     ReplaceWithCopy(graph);
+    FuseSigridTransformsListUnpack(graph);
   }
 #endif
   ConstantPropagation(graph);
@@ -1115,13 +1116,7 @@
     : node_(node), inputs_(std::move(inputs)) {
   // TODO leverage type information
   outputs_.resize(node->outputs().size());
-  if (node->kind() != prim::ListConstruct &&
-      node->kind() != prim::TupleConstruct &&
-      node->kind() != prim::ListUnpack) {
-    const Operator& op = node->getOperator();
-    TORCH_CHECK(op.hasOperation());
-    op_ = op.getOperation(node);
-  }
+
   if (enable_out_variants && canRunOutOfPlace(node)) {
     fn_ = getOutOfPlaceOperation(node);
     std::ostringstream ss;
@@ -1132,7 +1127,14 @@
     std::ostringstream ss;
     node->print(ss, 0, nullptr, false);
     VLOG(1) << "Switch to native impl for node: " << ss.str();
-  } else {
+  } else if (
+      node->kind() != prim::ListConstruct &&
+      node->kind() != prim::TupleConstruct &&
+      node->kind() != prim::ListUnpack) {
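+    // The full Operation is only needed for the interpreter fallback;
+    // the out-variant and native paths above never consult it.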
+    const Operator& op = node->getOperator();
+    TORCH_CHECK(op.hasOperation());
+    op_ = op.getOperation(node);
+
     std::ostringstream ss;
     node->print(ss, 0, nullptr, false);
     VLOG(1) << "Fallback interpreter for node: " << ss.str();
diff --git a/torch/csrc/jit/runtime/static/passes.cpp b/torch/csrc/jit/runtime/static/passes.cpp
index e93d107..e60875c 100644
--- a/torch/csrc/jit/runtime/static/passes.cpp
+++ b/torch/csrc/jit/runtime/static/passes.cpp
@@ -401,5 +401,43 @@
   }
 }
 
+void FuseSigridTransformsListUnpack(std::shared_ptr<torch::jit::Graph>& graph) {
+  auto nodes = graph->nodes();
+  for (auto it = nodes.begin(); it != nodes.end(); ++it) {
+    Node* sigrid_node = *it;
+    auto kind = sigrid_node->kind();
+    // TODO: make this work with the TorchBind version
+    if (strcmp(kind.toQualString(), "fb::sigrid_transforms") == 0) {
+      const Value* sigrid_out = sigrid_node->outputs()[0];
+      // Fuse only when the list output has exactly one use; this also
+      // guards the uses()[0] access below against zero uses.
+      if (sigrid_out->uses().size() != 1) {
+        continue;
+      }
+
+      Node* list_unpack_node = sigrid_out->uses()[0].user;
+      if (list_unpack_node->kind() != prim::ListUnpack) {
+        continue;
+      }
+
+      auto list_unpack_outputs = list_unpack_node->outputs();
+      if (list_unpack_outputs.empty()) {
+        continue;
+      }
+
+      // handle outputs
+      for (Value* out : list_unpack_outputs) {
+        Value* new_out = sigrid_node->addOutput();
+        new_out->copyMetadata(out);
+        out->replaceAllUsesWith(new_out);
+      }
+
+      auto it_next = it;
+      ++it_next; // it_next points to list_unpack
+      it_next.destroyCurrent(); // remove list_unpack
+
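+      // All uses of the original list output were rewired above and its
+      // sole consumer (the ListUnpack) is gone, so drop the dead output;
+      // the fused tensor outputs now start at index 0.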
+      sigrid_node->eraseOutput(0);
+    }
+  }
+}
+
 } // namespace jit
 } // namespace torch
diff --git a/torch/csrc/jit/runtime/static/passes.h b/torch/csrc/jit/runtime/static/passes.h
index ee19d00..d3823ae 100644
--- a/torch/csrc/jit/runtime/static/passes.h
+++ b/torch/csrc/jit/runtime/static/passes.h
@@ -4,6 +4,8 @@
 namespace jit {
 
 void FuseInferenceOpsForSparseNN(std::shared_ptr<torch::jit::Graph>& graph);
+void FuseSigridTransformsListUnpack(std::shared_ptr<torch::jit::Graph>& graph);
+
 void ReplaceWithCopy(std::shared_ptr<torch::jit::Graph>& graph);
 
 void SplitOutPrecomputeOpsForSparseNN(