[jit] [shape analysis] Move constant tensors out of fused subgraphs during generalization (#70320)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70320
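Symbolic shape analysis does not handle broadcast(constant, symbolic_shape) well, so tensor constants left inside a fused TensorExprGroup subgraph end up producing extra guarding conditionals (and poor performance) when the runtime guard is generated. This diff hoists tensor constants out of the subgraph before the inputs are generalized to symbolic shapes: each constant is cloned into the parent graph, the subgraph gets a fresh input in its place, and the clone is passed in through the fusion node.

Roughly, the transformation looks like this (hand-written, illustrative IR; value names and numbering are not actual pass output):

  Before:
    graph(%x : Tensor):
      %out : Tensor = prim::TensorExprGroup_0(%x)
      return (%out)
    with prim::TensorExprGroup_0 = graph(%x : Tensor):
      %ones : Tensor = prim::Constant[value=<Tensor>]()
      %t : Tensor = aten::tanh(%x)
      %out : Tensor = aten::mul(%t, %ones)
      return (%out)

  After:
    graph(%x : Tensor):
      %ones : Tensor = prim::Constant[value=<Tensor>]()
      %out : Tensor = prim::TensorExprGroup_0(%x, %ones)
      return (%out)
    with prim::TensorExprGroup_0 = graph(%x : Tensor, %ones : Tensor):
      %t : Tensor = aten::tanh(%x)
      %out : Tensor = aten::mul(%t, %ones)
      return (%out)

With the constants hoisted, GenerateGuard emits only the single prim::If that checks TensorExprDynamicGuard (see the new FileCheck test below).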
ghstack-source-id: 146514368
Test Plan: `buck test mode/dev-nosan //caffe2/test/cpp/jit:jit`
Reviewed By: eellison
Differential Revision: D33280508
fbshipit-source-id: fe4291d7c49f0a498b330de96b698e99f6f6a505
diff --git a/test/cpp/jit/test_shape_analysis.cpp b/test/cpp/jit/test_shape_analysis.cpp
index d087767..65b7d73 100644
--- a/test/cpp/jit/test_shape_analysis.cpp
+++ b/test/cpp/jit/test_shape_analysis.cpp
@@ -7,6 +7,7 @@
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/ir_views.h>
#include <torch/csrc/jit/ir/irparser.h>
+#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/symbolic_shape_runtime_fusion.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
#include <torch/csrc/jit/runtime/graph_iterator.h>
@@ -246,5 +247,46 @@
}
}
+TEST(ShapeAnalysisTest, MovingConstantOutOfFusionGroups) {
+ std::shared_ptr<Graph> subgraph = std::make_shared<Graph>();
+ const auto graph_string = R"IR(
+ graph(%x.1 : Tensor):
+ %none : NoneType = prim::Constant()
+ %size1 : int = prim::Constant[value=1]()
+ %size10 : int = prim::Constant[value=10]()
+ %sizes : int[] = prim::ListConstruct(%size10, %size1)
+ %device : Device = prim::Constant[value="cpu"]()
+ %10 : Tensor = aten::ones(%sizes, %none, %none, %device, %none)
+ %3 : Tensor = aten::tanh(%x.1)
+ %29 : Tensor = aten::mul(%3, %10)
+ return (%29))IR";
+ torch::jit::parseIR(graph_string, subgraph.get());
+ ConstantPropagation(subgraph);
+
+ std::shared_ptr<Graph> g = std::make_shared<Graph>();
+ auto x_inp = g->addInput("x_inp");
+ auto x_type = TensorType::create(at::rand({10, 5}));
+ x_inp->setType(x_type);
+ subgraph->inputs().at(0)->setType(x_type);
+ auto output = g->insertNode(g->create(prim::TensorExprGroup))->output();
+ output->node()->addInput(x_inp);
+ output->node()->g_(attr::Subgraph, subgraph);
+
+ auto success = GenerateGuard(output->node());
+ TORCH_INTERNAL_ASSERT(success);
+
+ // Check that the constants have been moved out of the fused graph.
+ // As a result, the only conditional in the graph should be the prim::If
+ // that checks the result of TensorExprDynamicGuard.
+ testing::FileCheck()
+ .check("TensorExprDynamicGuard")
+ ->check_next("prim::If")
+ ->check_not("prim::If") // no other prim::If nodes, since the constants were moved out.
+ ->check("TensorExprGroup")
+ ->check("block1")
+ ->check("FallbackGraph")
+ ->run(*g);
+}
+
} // namespace jit
} // namespace torch
diff --git a/torch/csrc/jit/passes/symbolic_shape_runtime_fusion.cpp b/torch/csrc/jit/passes/symbolic_shape_runtime_fusion.cpp
index 0a0d127..b90fa28 100644
--- a/torch/csrc/jit/passes/symbolic_shape_runtime_fusion.cpp
+++ b/torch/csrc/jit/passes/symbolic_shape_runtime_fusion.cpp
@@ -106,9 +106,58 @@
return true;
}
+void moveConstantTensorsOutOfSubgraph(
+ Node* tensorexpr_graph_node,
+ std::shared_ptr<Graph> tensorexpr_graph) {
+ auto parent = tensorexpr_graph_node->owningGraph();
+
+ auto env = [&](Value* v) {
+ TORCH_INTERNAL_ASSERT(
+ false,
+ "this should never happen since constant nodes do not have any inputs",
+ v->debugName());
+ return v;
+ };
+
+ WithInsertPoint wip(tensorexpr_graph_node);
+ std::vector<Node*> to_destroy;
+ for (auto node : tensorexpr_graph->nodes()) {
+ if (node->kind() == prim::Constant) {
+ if (!node->output()->type()->cast<TensorType>()) {
+ continue;
+ }
+
+ // copy the constant and insert that copy into the parent graph.
+ auto copy = parent->createClone(node, env);
+ parent->insertNode(copy);
+
+ // add a new input to the te subgraph and replace the uses of the
+ // constant with this input.
+ auto new_const = tensorexpr_graph->addInput();
+ new_const->setType(node->output()->type());
+ node->output()->replaceAllUsesWith(new_const);
+
+ // add the copy as an input to the TE node
+ tensorexpr_graph_node->addInput(copy->output());
+
+ to_destroy.push_back(node);
+ }
+ }
+
+ for (auto n : to_destroy) {
+ n->destroy();
+ }
+}
+
bool GenerateGuard(Node* tensorexpr_graph_node, bool add_composed_op) {
auto tensorexpr_graph = SubgraphUtils::getSubgraph(tensorexpr_graph_node);
+ // Move constant tensors from the subgraph to the outer scope.
+ // This is necessary because symbolic shape analysis does not handle the
+ // broadcast(constant, symbolic_shape) case well, which results in poor
+ // performance.
+ moveConstantTensorsOutOfSubgraph(tensorexpr_graph_node, tensorexpr_graph);
+
// Generalize Inputs
if (!TryGeneralizeInputDimensionsToSymbolicShapes(tensorexpr_graph)) {
return false;