[jit] [shape analysis] Move constant tensors out of fused subgraphs during generalization (#70320)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70320
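
In effect, the new `moveConstantTensorsOutOfSubgraph` step hoists tensor-valued `prim::Constant` nodes out of a `prim::TensorExprGroup` subgraph and passes them back in as ordinary inputs, so symbolic shape generalization never has to reason about a broadcast against an embedded constant. A rough before/after sketch of the IR (value names and the printed subgraph form are illustrative only):

```
# before: the constant tensor lives inside the fusion subgraph
graph(%x : Tensor):
  %1 : Tensor = prim::TensorExprGroup_0(%x)
  return (%1)
with prim::TensorExprGroup_0 = graph(%x : Tensor):
  %ones : Tensor = prim::Constant[value=<Tensor>]()
  %y : Tensor = aten::mul(%x, %ones)
  return (%y)

# after: the constant is cloned into the outer graph and becomes a subgraph input
graph(%x : Tensor):
  %ones : Tensor = prim::Constant[value=<Tensor>]()
  %1 : Tensor = prim::TensorExprGroup_0(%x, %ones)
  return (%1)
with prim::TensorExprGroup_0 = graph(%x : Tensor, %ones : Tensor):
  %y : Tensor = aten::mul(%x, %ones)
  return (%y)
```

With the constant hoisted, `GenerateGuard` only needs the single `TensorExprDynamicGuard` / `prim::If` around the fusion group, with no extra conditionals for the constant; the new C++ test checks exactly that.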

ghstack-source-id: 146514368

Test Plan: `buck test mode/dev-nosan //caffe2/test/cpp/jit:jit`

Reviewed By: eellison

Differential Revision: D33280508

fbshipit-source-id: fe4291d7c49f0a498b330de96b698e99f6f6a505
diff --git a/test/cpp/jit/test_shape_analysis.cpp b/test/cpp/jit/test_shape_analysis.cpp
index d087767..65b7d73 100644
--- a/test/cpp/jit/test_shape_analysis.cpp
+++ b/test/cpp/jit/test_shape_analysis.cpp
@@ -7,6 +7,7 @@
 #include <torch/csrc/jit/ir/ir.h>
 #include <torch/csrc/jit/ir/ir_views.h>
 #include <torch/csrc/jit/ir/irparser.h>
+#include <torch/csrc/jit/passes/constant_propagation.h>
 #include <torch/csrc/jit/passes/symbolic_shape_runtime_fusion.h>
 #include <torch/csrc/jit/passes/utils/subgraph_utils.h>
 #include <torch/csrc/jit/runtime/graph_iterator.h>
@@ -246,5 +247,46 @@
   }
 }
 
+TEST(ShapeAnalysisTest, MovingConstantOutOfFusionGroups) {
+  std::shared_ptr<Graph> subgraph = std::make_shared<Graph>();
+  const auto graph_string = R"IR(
+      graph(%x.1 : Tensor):
+        %none : NoneType = prim::Constant()
+        %size1 : int = prim::Constant[value=1]()
+        %size10 : int = prim::Constant[value=10]()
+        %sizes : int[] = prim::ListConstruct(%size10, %size1)
+        %device : Device = prim::Constant[value="cpu"]()
+        %10 : Tensor = aten::ones(%sizes, %none, %none, %device, %none)
+        %3 : Tensor = aten::tanh(%x.1)
+        %29 : Tensor = aten::mul(%3, %10)
+        return (%29))IR";
+  torch::jit::parseIR(graph_string, subgraph.get());
+  ConstantPropagation(subgraph);
+
+  std::shared_ptr<Graph> g = std::make_shared<Graph>();
+  auto x_inp = g->addInput("x_inp");
+  auto x_type = TensorType::create(at::rand({10, 5}));
+  x_inp->setType(x_type);
+  subgraph->inputs().at(0)->setType(x_type);
+  auto output = g->insertNode(g->create(prim::TensorExprGroup))->output();
+  output->node()->addInput(x_inp);
+  output->node()->g_(attr::Subgraph, subgraph);
+
+  auto success = GenerateGuard(output->node());
+  TORCH_INTERNAL_ASSERT(success);
+
+  // Check that the constants have been moved out of the fused graph.
+  // This should result in not having any conditionals other than the one
+  // checking the result of TensorExprDynamicGuard.
+  testing::FileCheck()
+      .check("TensorExprDynamicGuard")
+      ->check_next("prim::If")
+      ->check_not("prim::If") // no other IFs due to constants.
+      ->check("TensorExprGroup")
+      ->check("block1")
+      ->check("FallbackGraph")
+      ->run(*g);
+}
+
 } // namespace jit
 } // namespace torch
diff --git a/torch/csrc/jit/passes/symbolic_shape_runtime_fusion.cpp b/torch/csrc/jit/passes/symbolic_shape_runtime_fusion.cpp
index 0a0d127..b90fa28 100644
--- a/torch/csrc/jit/passes/symbolic_shape_runtime_fusion.cpp
+++ b/torch/csrc/jit/passes/symbolic_shape_runtime_fusion.cpp
@@ -106,9 +106,58 @@
   return true;
 }
 
+void moveConstantTensorsOutOfSubgraph(
+    Node* tensorexpr_graph_node,
+    std::shared_ptr<Graph> tensorexpr_graph) {
+  auto parent = tensorexpr_graph_node->owningGraph();
+
+  auto env = [&](Value* v) {
+    TORCH_INTERNAL_ASSERT(
+        false,
+        "this should never happen since constant nodes do not have any inputs",
+        v->debugName());
+    return v;
+  };
+
+  WithInsertPoint wip(tensorexpr_graph_node);
+  std::vector<Node*> to_destroy;
+  for (auto node : tensorexpr_graph->nodes()) {
+    if (node->kind() == prim::Constant) {
+      if (!node->output()->type()->cast<TensorType>()) {
+        continue;
+      }
+
+      // copy the constant and insert that copy into the parent graph.
+      auto copy = parent->createClone(node, env);
+      parent->insertNode(copy);
+
+      // add a new input to the te subgraph and replace the uses of the
+      // constant with this input.
+      auto new_const = tensorexpr_graph->addInput();
+      new_const->setType(node->output()->type());
+      node->output()->replaceAllUsesWith(new_const);
+
+      // add the copy as input to the te node
+      tensorexpr_graph_node->addInput(copy->output());
+
+      to_destroy.push_back(node);
+    }
+  }
+
+  for (auto n : to_destroy) {
+    n->destroy();
+  }
+}
+
 bool GenerateGuard(Node* tensorexpr_graph_node, bool add_composed_op) {
   auto tensorexpr_graph = SubgraphUtils::getSubgraph(tensorexpr_graph_node);
 
+  // Move constant tensors from the subgraph to the outer scope.
+  // This is necessary because symbolic shape analysis does not handle the
+  // case of broadcast(constant, symbolic_shape) well, which results in poor
+  // performance.
+  moveConstantTensorsOutOfSubgraph(tensorexpr_graph_node, tensorexpr_graph);
+
   // Generalize Inputs
   if (!TryGeneralizeInputDimensionsToSymbolicShapes(tensorexpr_graph)) {
     return false;