Add support for multi output nodes in partial eval graph stitching (#66097)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66097

Adds logic to generate runtime shapes for nodes with multiple outputs. It generalizes the existing flow — looking at a node, getting its shape graph, inlining it, and adding a mapping from the output to the new value in the stitched shape compute graph — to loop over multiple outputs.

Test Plan: Imported from OSS

Reviewed By: navahgar

Differential Revision: D31732418

Pulled By: eellison

fbshipit-source-id: 767698d031b1daf002678a025b270e0ede429061
diff --git a/test/jit/test_symbolic_shape_analysis.py b/test/jit/test_symbolic_shape_analysis.py
index 2fb4edb..720c392 100644
--- a/test/jit/test_symbolic_shape_analysis.py
+++ b/test/jit/test_symbolic_shape_analysis.py
@@ -364,3 +364,25 @@
         out1 = outs[0].type().symbolic_sizes()
         out2 = outs[1].type().symbolic_sizes()
         self.assertEqual(out1, out2)
+
+    def test_stitching_multi_output(self):
+        max_pool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False, return_indices=True)
+        tensor = torch.rand(1, 3, 224, 224)
+        mod = torch.jit.trace(max_pool, (tensor,))
+        mod = torch.jit.freeze(mod.eval())
+        inp = list(mod.graph.inputs())[1]
+        inp.setType(inp.type().with_sizes([None, None, None, None]))
+        shape_compute_graph = torch._C._jit_pass_propagate_shapes_on_graph_and_build_compute(
+            mod.graph,
+            next(mod.graph.nodes()),
+            mod.graph.findNode("prim::TupleConstruct")
+        )
+        max_pool_node = mod.graph.findNode("aten::max_pool2d_with_indices")
+        outs = list(max_pool_node.outputs())
+        self.assertEqual(outs[0].type().symbolic_sizes(), outs[1].type().symbolic_sizes())
+        g = shape_compute_graph.partial_eval_shape_graph()
+        # to make this into a jit function, the graph can't have multiple outputs
+        g.makeMultiOutputIntoTuple()
+        func = torch._C._create_function_from_graph("partial_eval_graph", g)
+        output_shape = func(tensor.size())
+        self.assertEqual(list(output_shape), list(mod(tensor)[0].size()))
diff --git a/torch/csrc/jit/passes/symbolic_shape_analysis.cpp b/torch/csrc/jit/passes/symbolic_shape_analysis.cpp
index b9ce451..66cd770 100644
--- a/torch/csrc/jit/passes/symbolic_shape_analysis.cpp
+++ b/torch/csrc/jit/passes/symbolic_shape_analysis.cpp
@@ -635,20 +635,24 @@
       if (curr->kind() == prim::Constant) {
         continue;
       }
-      if (curr->outputs().size() != 1) {
-        GRAPH_DEBUG("Multi output node ", getHeader(curr));
+      if (!partial_evaluated_graphs.count(curr)) {
+        GRAPH_DEBUG("No graph ", getHeader(curr));
         return c10::nullopt;
       }
-      auto tt = curr->output()->type()->cast<TensorType>();
-      if (!tt || !partial_evaluated_graphs.count(curr)) {
-        GRAPH_DEBUG("Non tensor node or no graph ", getHeader(curr));
-        return c10::nullopt;
-      }
-      auto symbolic_sizes = tt->symbolic_sizes();
-      // TODO: dont require # of dimensions of tensors set ?
-      if (!symbolic_sizes.rank()) {
-        GRAPH_DEBUG("No rank on output ", getHeader(curr));
-        return c10::nullopt;
+
+      auto outputs = curr->outputs();
+      for (Value* v : outputs) {
+        auto tt = v->type()->cast<TensorType>();
+        if (!tt) {
+          GRAPH_DEBUG("Non tensor node", getHeader(curr));
+          return c10::nullopt;
+        }
+        auto symbolic_sizes = tt->symbolic_sizes();
+        // TODO: don't require # of dimensions of tensors to be set?
+        if (!symbolic_sizes.rank()) {
+          GRAPH_DEBUG("No rank on output ", getHeader(curr));
+          return c10::nullopt;
+        }
       }
       auto partial_eval_graph = partial_evaluated_graphs[curr];
       joinPartialEvaluatedShapeGraphToLargeShapeGraph(
@@ -708,26 +712,27 @@
       std::unordered_map<int64_t, int64_t>& sym_shape_equalities) {
     for (auto it = beg_->iterator(); it != end_->iterator(); it++) {
       auto curr = *it;
-      if (curr->outputs().size() != 1) {
-        continue;
-      }
-      auto tt = curr->output()->type()->cast<TensorType>();
-      if (!tt || !tt->symbolic_sizes().rank()) {
-        continue;
-      }
-      bool changed = false;
-      std::vector<at::ShapeSymbol> shape_vec = *tt->symbolic_sizes().sizes();
-      auto new_sizes = c10::fmap(shape_vec, [&](const at::ShapeSymbol& shape) {
-        auto value = shape.value();
-        if (sym_shape_equalities.count(value)) {
-          changed = true;
-          return sym_shape_equalities[value];
+      for (size_t i = 0; i < curr->outputs().size(); ++i) {
+        auto output = curr->output(i);
+        auto tt = output->type()->cast<TensorType>();
+        if (!tt || !tt->symbolic_sizes().rank()) {
+          continue;
         }
-        return value;
-      });
-      if (changed) {
-        curr->output()->setType(
-            tt->withSymbolicShapes(c10::SymbolicShape(new_sizes)));
+        bool changed = false;
+        std::vector<at::ShapeSymbol> shape_vec = *tt->symbolic_sizes().sizes();
+        auto new_sizes =
+            c10::fmap(shape_vec, [&](const at::ShapeSymbol& shape) {
+              auto value = shape.value();
+              if (sym_shape_equalities.count(value)) {
+                changed = true;
+                return sym_shape_equalities[value];
+              }
+              return value;
+            });
+        if (changed) {
+          output->setType(
+              tt->withSymbolicShapes(c10::SymbolicShape(new_sizes)));
+        }
       }
     }
   }
@@ -778,34 +783,36 @@
     insertGraph(
         *stitched_shape_compute_graph, *partial_eval_graph, inputs, value_map);
 
-    TORCH_INTERNAL_ASSERT(partial_eval_graph->outputs().size() == 1);
-    Value* new_list_output = value_map[partial_eval_graph->outputs().at(0)];
-    enclosing_graph_value_to_shape_graph_input_[curr->output()] =
-        new_list_output;
+    for (size_t i = 0; i < curr->outputs().size(); ++i) {
+      Value* new_list_output = value_map[partial_eval_graph->outputs().at(i)];
+      enclosing_graph_value_to_shape_graph_input_[curr->output(i)] =
+          new_list_output;
 
-    TORCH_INTERNAL_ASSERT(
-        new_list_output->node()->kind() == prim::ListConstruct);
-    TORCH_INTERNAL_ASSERT(!new_list_output->node()->hasUses());
+      TORCH_INTERNAL_ASSERT(
+          new_list_output->node()->kind() == prim::ListConstruct);
+      TORCH_INTERNAL_ASSERT(!new_list_output->node()->hasUses());
 
-    auto symbolic_sizes =
-        curr->output()->type()->expect<TensorType>()->symbolic_sizes();
-    TORCH_INTERNAL_ASSERT(symbolic_sizes.rank());
+      auto symbolic_sizes =
+          curr->output(i)->type()->expect<TensorType>()->symbolic_sizes();
+      TORCH_INTERNAL_ASSERT(symbolic_sizes.rank());
 
-    for (size_t i = 0; i < *symbolic_sizes.rank(); i++) {
-      if (symbolic_sizes[i].is_static()) {
-        continue;
+      for (size_t i = 0; i < *symbolic_sizes.rank(); i++) {
+        if (symbolic_sizes[i].is_static()) {
+          continue;
+        }
+        int64_t symbolic_shape = symbolic_sizes[i].value();
+        if (symbolic_shape_value_to_graph_output_.count(symbolic_shape)) {
+          continue;
+        }
+        stitched_shape_compute_graph->registerOutput(
+            new_list_output->node()->input(i));
+        output_index_to_symbolic_shape_
+            [stitched_shape_compute_graph->outputs().size() - 1] =
+                symbolic_shape;
+        symbolic_shape_value_to_graph_output_[symbolic_shape] =
+            stitched_shape_compute_graph->outputs().at(
+                stitched_shape_compute_graph->outputs().size() - 1);
       }
-      int64_t symbolic_shape = symbolic_sizes[i].value();
-      if (symbolic_shape_value_to_graph_output_.count(symbolic_shape)) {
-        continue;
-      }
-      stitched_shape_compute_graph->registerOutput(
-          new_list_output->node()->input(i));
-      output_index_to_symbolic_shape_
-          [stitched_shape_compute_graph->outputs().size() - 1] = symbolic_shape;
-      symbolic_shape_value_to_graph_output_[symbolic_shape] =
-          stitched_shape_compute_graph->outputs().at(
-              stitched_shape_compute_graph->outputs().size() - 1);
     }
   }