Add support for multi output nodes in partial eval graph stitching (#66097)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66097
Adds logic to generate runtime shapes for nodes with multiple outputs. This generalizes the existing flow (look at a node, get its shape compute graph, inline it, and add a mapping from the node's output to the new value in the stitched shape compute graph) to loop over all of a node's outputs.
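
For context, the stitched graph can be exercised end to end; this is a condensed sketch of the flow the new test below uses (aten::max_pool2d_with_indices is a two-output op, and the tuple conversion is needed because a JIT function can only return a single value):

    import torch

    pool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, return_indices=True)
    mod = torch.jit.freeze(torch.jit.trace(pool, (torch.rand(1, 3, 224, 224),)).eval())

    # Erase concrete sizes so shape propagation has to reason symbolically.
    inp = list(mod.graph.inputs())[1]
    inp.setType(inp.type().with_sizes([None, None, None, None]))

    # Propagate shapes and stitch the per-node shape compute graphs together.
    shape_compute_graph = torch._C._jit_pass_propagate_shapes_on_graph_and_build_compute(
        mod.graph, next(mod.graph.nodes()), mod.graph.findNode("prim::TupleConstruct"))

    # The partial eval graph maps an input shape to the symbolic output dims.
    g = shape_compute_graph.partial_eval_shape_graph()
    g.makeMultiOutputIntoTuple()  # a JIT function cannot have multiple outputs
    func = torch._C._create_function_from_graph("partial_eval_graph", g)
    print(func([1, 3, 224, 224]))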
Test Plan: Imported from OSS
Reviewed By: navahgar
Differential Revision: D31732418
Pulled By: eellison
fbshipit-source-id: 767698d031b1daf002678a025b270e0ede429061
diff --git a/test/jit/test_symbolic_shape_analysis.py b/test/jit/test_symbolic_shape_analysis.py
index 2fb4edb..720c392 100644
--- a/test/jit/test_symbolic_shape_analysis.py
+++ b/test/jit/test_symbolic_shape_analysis.py
@@ -364,3 +364,25 @@
out1 = outs[0].type().symbolic_sizes()
out2 = outs[1].type().symbolic_sizes()
self.assertEqual(out1, out2)
+
+ def test_stitching_multi_output(self):
+ max_pool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False, return_indices=True)
+ tensor = torch.rand(1, 3, 224, 224)
+ mod = torch.jit.trace(max_pool, (tensor,))
+ mod = torch.jit.freeze(mod.eval())
+ inp = list(mod.graph.inputs())[1]
+ inp.setType(inp.type().with_sizes([None, None, None, None]))
+ shape_compute_graph = torch._C._jit_pass_propagate_shapes_on_graph_and_build_compute(
+ mod.graph,
+ next(mod.graph.nodes()),
+ mod.graph.findNode("prim::TupleConstruct")
+ )
+ max_pool_node = mod.graph.findNode("aten::max_pool2d_with_indices")
+ outs = list(max_pool_node.outputs())
+ self.assertEqual(outs[0].type().symbolic_sizes(), outs[1].type().symbolic_sizes())
+ g = shape_compute_graph.partial_eval_shape_graph()
+ # to make into a jit function, the graph can't have multiple outputs
+ g.makeMultiOutputIntoTuple()
+ func = torch._C._create_function_from_graph("partial_eval_graph", g)
+ output_shape = func(tensor.size())
+ self.assertEqual(list(output_shape), list(mod(tensor)[0].size()))
diff --git a/torch/csrc/jit/passes/symbolic_shape_analysis.cpp b/torch/csrc/jit/passes/symbolic_shape_analysis.cpp
index b9ce451..66cd770 100644
--- a/torch/csrc/jit/passes/symbolic_shape_analysis.cpp
+++ b/torch/csrc/jit/passes/symbolic_shape_analysis.cpp
@@ -635,20 +635,24 @@
if (curr->kind() == prim::Constant) {
continue;
}
- if (curr->outputs().size() != 1) {
- GRAPH_DEBUG("Multi output node ", getHeader(curr));
+ if (!partial_evaluated_graphs.count(curr)) {
+ GRAPH_DEBUG("No graph ", getHeader(curr));
return c10::nullopt;
}
- auto tt = curr->output()->type()->cast<TensorType>();
- if (!tt || !partial_evaluated_graphs.count(curr)) {
- GRAPH_DEBUG("Non tensor node or no graph ", getHeader(curr));
- return c10::nullopt;
- }
- auto symbolic_sizes = tt->symbolic_sizes();
- // TODO: dont require # of dimensions of tensors set ?
- if (!symbolic_sizes.rank()) {
- GRAPH_DEBUG("No rank on output ", getHeader(curr));
- return c10::nullopt;
+
+ auto outputs = curr->outputs();
+ for (Value* v : outputs) {
+ auto tt = v->type()->cast<TensorType>();
+ if (!tt) {
+ GRAPH_DEBUG("Non tensor node", getHeader(curr));
+ return c10::nullopt;
+ }
+ auto symbolic_sizes = tt->symbolic_sizes();
+ // TODO: don't require # of dimensions of tensors to be set?
+ if (!symbolic_sizes.rank()) {
+ GRAPH_DEBUG("No rank on output ", getHeader(curr));
+ return c10::nullopt;
+ }
}
auto partial_eval_graph = partial_evaluated_graphs[curr];
joinPartialEvaluatedShapeGraphToLargeShapeGraph(
@@ -708,26 +712,27 @@
std::unordered_map<int64_t, int64_t>& sym_shape_equalities) {
for (auto it = beg_->iterator(); it != end_->iterator(); it++) {
auto curr = *it;
- if (curr->outputs().size() != 1) {
- continue;
- }
- auto tt = curr->output()->type()->cast<TensorType>();
- if (!tt || !tt->symbolic_sizes().rank()) {
- continue;
- }
- bool changed = false;
- std::vector<at::ShapeSymbol> shape_vec = *tt->symbolic_sizes().sizes();
- auto new_sizes = c10::fmap(shape_vec, [&](const at::ShapeSymbol& shape) {
- auto value = shape.value();
- if (sym_shape_equalities.count(value)) {
- changed = true;
- return sym_shape_equalities[value];
+ for (size_t i = 0; i < curr->outputs().size(); ++i) {
+ auto output = curr->output(i);
+ auto tt = output->type()->cast<TensorType>();
+ if (!tt || !tt->symbolic_sizes().rank()) {
+ continue;
}
- return value;
- });
- if (changed) {
- curr->output()->setType(
- tt->withSymbolicShapes(c10::SymbolicShape(new_sizes)));
+ bool changed = false;
+ std::vector<at::ShapeSymbol> shape_vec = *tt->symbolic_sizes().sizes();
+ auto new_sizes =
+ c10::fmap(shape_vec, [&](const at::ShapeSymbol& shape) {
+ auto value = shape.value();
+ if (sym_shape_equalities.count(value)) {
+ changed = true;
+ return sym_shape_equalities[value];
+ }
+ return value;
+ });
+ if (changed) {
+ output->setType(
+ tt->withSymbolicShapes(c10::SymbolicShape(new_sizes)));
+ }
}
}
}
@@ -778,34 +783,36 @@
insertGraph(
*stitched_shape_compute_graph, *partial_eval_graph, inputs, value_map);
- TORCH_INTERNAL_ASSERT(partial_eval_graph->outputs().size() == 1);
- Value* new_list_output = value_map[partial_eval_graph->outputs().at(0)];
- enclosing_graph_value_to_shape_graph_input_[curr->output()] =
- new_list_output;
+ for (size_t i = 0; i < curr->outputs().size(); ++i) {
+ Value* new_list_output = value_map[partial_eval_graph->outputs().at(i)];
+ enclosing_graph_value_to_shape_graph_input_[curr->output(i)] =
+ new_list_output;
- TORCH_INTERNAL_ASSERT(
- new_list_output->node()->kind() == prim::ListConstruct);
- TORCH_INTERNAL_ASSERT(!new_list_output->node()->hasUses());
+ TORCH_INTERNAL_ASSERT(
+ new_list_output->node()->kind() == prim::ListConstruct);
+ TORCH_INTERNAL_ASSERT(!new_list_output->node()->hasUses());
- auto symbolic_sizes =
- curr->output()->type()->expect<TensorType>()->symbolic_sizes();
- TORCH_INTERNAL_ASSERT(symbolic_sizes.rank());
+ auto symbolic_sizes =
+ curr->output(i)->type()->expect<TensorType>()->symbolic_sizes();
+ TORCH_INTERNAL_ASSERT(symbolic_sizes.rank());
- for (size_t i = 0; i < *symbolic_sizes.rank(); i++) {
- if (symbolic_sizes[i].is_static()) {
- continue;
+ for (size_t i = 0; i < *symbolic_sizes.rank(); i++) {
+ if (symbolic_sizes[i].is_static()) {
+ continue;
+ }
+ int64_t symbolic_shape = symbolic_sizes[i].value();
+ if (symbolic_shape_value_to_graph_output_.count(symbolic_shape)) {
+ continue;
+ }
+ stitched_shape_compute_graph->registerOutput(
+ new_list_output->node()->input(i));
+ output_index_to_symbolic_shape_
+ [stitched_shape_compute_graph->outputs().size() - 1] =
+ symbolic_shape;
+ symbolic_shape_value_to_graph_output_[symbolic_shape] =
+ stitched_shape_compute_graph->outputs().at(
+ stitched_shape_compute_graph->outputs().size() - 1);
}
- int64_t symbolic_shape = symbolic_sizes[i].value();
- if (symbolic_shape_value_to_graph_output_.count(symbolic_shape)) {
- continue;
- }
- stitched_shape_compute_graph->registerOutput(
- new_list_output->node()->input(i));
- output_index_to_symbolic_shape_
- [stitched_shape_compute_graph->outputs().size() - 1] = symbolic_shape;
- symbolic_shape_value_to_graph_output_[symbolic_shape] =
- stitched_shape_compute_graph->outputs().at(
- stitched_shape_compute_graph->outputs().size() - 1);
}
}