[static runtime] Split out graph preparation from runtime (#44131)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/44131

This splits graph preparation out of the StaticRuntime constructor into a
standalone PrepareForStaticRuntime function, overloaded for both Graph and
Module. Preparation covers inlining, constant propagation, canonicalization,
tensor-mutation removal, and erasing the unused module (self) input.
StaticRuntime now accepts only an already-prepared graph and no longer holds
a Module; run() becomes const (over a mutable workspace), so a single
prepared graph can be shared, with each thread constructing its own runtime,
as exercised by the new BM_deep_wide_static_threaded benchmark. The
previously disabled test_trivial_graph is re-enabled.
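
Callers now prepare the graph before constructing the runtime. A minimal
sketch of the new call pattern (mirroring deep_wide_pt_bench.cc; mod and the
input tensors stand in for any scripted module and its inputs):

  // One-time preparation: freeze the module and run the graph passes.
  auto g = torch::jit::PrepareForStaticRuntime(mod);
  // The runtime consumes only prepared graphs; run() is const.
  torch::jit::StaticRuntime runtime(g);
  std::vector<at::Tensor> outputs = runtime.run({ad_emb_packed, user_emb, wide});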

Test Plan: Imported from OSS

Reviewed By: hlu1

Differential Revision: D23604305

Pulled By: bwasti

fbshipit-source-id: 7b47da4961d99074199417ef1407a788c7d80ee6
diff --git a/benchmarks/static_runtime/deep_wide_pt_bench.cc b/benchmarks/static_runtime/deep_wide_pt_bench.cc
index ef960d2..21c2923 100644
--- a/benchmarks/static_runtime/deep_wide_pt_bench.cc
+++ b/benchmarks/static_runtime/deep_wide_pt_bench.cc
@@ -60,7 +60,8 @@
 
 static void BM_deep_wide_static(benchmark::State& state) {
   auto mod = getDeepAndWideSciptModel();
-  torch::jit::StaticRuntime runtime(mod);
+  auto g = torch::jit::PrepareForStaticRuntime(mod);
+  torch::jit::StaticRuntime runtime(g);
 
   const int batch_size = state.range(0);
   auto ad_emb_packed = torch::randn({batch_size, 1, embedding_size});
@@ -75,6 +76,28 @@
   }
 }
 
+const std::shared_ptr<torch::jit::Graph>& getStaticGraph() {
+  static const std::shared_ptr<torch::jit::Graph> g =
+      torch::jit::PrepareForStaticRuntime(getDeepAndWideSciptModel());
+  return g;
+}
+
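+// Threads(8) runs this benchmark body concurrently; each thread builds its
+// own StaticRuntime over the single shared graph from getStaticGraph().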
+static void BM_deep_wide_static_threaded(benchmark::State& state) {
+  auto g = getStaticGraph();
+  torch::jit::StaticRuntime runtime(g);
+
+  const int batch_size = 1; // fixed batch size; this benchmark varies thread count instead
+  auto ad_emb_packed = torch::randn({batch_size, 1, embedding_size});
+  auto user_emb = torch::randn({batch_size, 1, embedding_size});
+  auto wide = torch::randn({batch_size, num_features});
+
+  std::vector<at::Tensor> inputs({ad_emb_packed, user_emb, wide});
+
+  for (auto _ : state) {
+    runtime.run(inputs);
+  }
+}
+
 BENCHMARK(BM_deep_wide_base)->RangeMultiplier(8)->Ranges({{1, 20}});
 
 BENCHMARK(BM_deep_wide_jit_graph_executor)
@@ -86,5 +109,6 @@
     ->Ranges({{1, 20}});
 
 BENCHMARK(BM_deep_wide_static)->RangeMultiplier(8)->Ranges({{1, 20}});
+BENCHMARK(BM_deep_wide_static_threaded)->Threads(8);
 
 BENCHMARK_MAIN();
diff --git a/benchmarks/static_runtime/test_static_runtime.cc b/benchmarks/static_runtime/test_static_runtime.cc
index 6b9b1dd..211839f 100644
--- a/benchmarks/static_runtime/test_static_runtime.cc
+++ b/benchmarks/static_runtime/test_static_runtime.cc
@@ -14,7 +14,8 @@
 
   // run static runtime
   std::vector<at::Tensor> input_tensors({a, b, c});
-  torch::jit::StaticRuntime runtime(mod);
+  auto g = torch::jit::PrepareForStaticRuntime(mod);
+  torch::jit::StaticRuntime runtime(g);
   at::Tensor output_2 = runtime.run(input_tensors)[0];
   EXPECT_TRUE(output_1.equal(output_2));
 }
@@ -23,7 +24,8 @@
   const int embedding_size = 32;
   const int num_features = 50;
   torch::jit::Module mod = getDeepAndWideSciptModel();
-  torch::jit::StaticRuntime runtime(mod);
+  auto g = torch::jit::PrepareForStaticRuntime(mod);
+  torch::jit::StaticRuntime runtime(g);
 
   for (int batch_size : {1, 8, 32}) {
     for (int i = 0; i < 5; ++i) {
diff --git a/test/test_static_runtime.py b/test/test_static_runtime.py
index 582cb32..86dafa3 100644
--- a/test/test_static_runtime.py
+++ b/test/test_static_runtime.py
@@ -150,13 +150,13 @@
             acc_top = top_l_acc(top_inp)[0]
             torch.testing.assert_allclose(acc_top, ref_top)
 
-    # def test_trivial_graph(self):
-    #     s = torch.full((2, 2), 2)
-    #     tg = torch.jit.script(trivial_graph)
-    #     o_ref = tg(s, s, s)
-    #     tg_a = StaticRuntime(tg)
-    #     o_test = tg_a(s, s, s)[0]
-    #     torch.testing.assert_allclose(o_ref, o_test)
+    def test_trivial_graph(self):
+        s = torch.full((2, 2), 2)
+        tg = torch.jit.script(trivial_graph)
+        o_ref = tg(s, s, s)
+        tg_a = StaticRuntime(tg)
+        o_test = tg_a(s, s, s)[0]
+        torch.testing.assert_allclose(o_ref, o_test)
 
 
 if __name__ == "__main__":
diff --git a/torch/csrc/jit/runtime/static/impl.cpp b/torch/csrc/jit/runtime/static/impl.cpp
index 42abeec..1ebd7cc 100644
--- a/torch/csrc/jit/runtime/static/impl.cpp
+++ b/torch/csrc/jit/runtime/static/impl.cpp
@@ -14,20 +14,16 @@
 using c10::DispatchKey;
 using c10::RegisterOperators;
 
-StaticRuntime::StaticRuntime(const torch::jit::Module& m)
-    : module_(m.copy()), graph_(nullptr) {
-  module_.eval();
-  module_ = freeze_module(module_);
-  graph_ = module_.get_method("forward").graph();
+std::shared_ptr<torch::jit::Graph> PrepareForStaticRuntime(
+    std::shared_ptr<torch::jit::Graph> g) {
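+  // Lower the graph to a static form: inline calls, fold constants,
+  // canonicalize, and rewrite in-place tensor ops as out-of-place ones.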
+  Inline(*g);
+  ConstantPropagation(g);
+  Canonicalize(g);
+  ConstantPropagation(g);
+  RemoveTensorMutation(g);
+  ConstantPropagation(g);
 
-  Inline(*graph_);
-  ConstantPropagation(graph_);
-  Canonicalize(graph_);
-  ConstantPropagation(graph_);
-  RemoveTensorMutation(graph_);
-  ConstantPropagation(graph_);
-
-  for (auto n : graph_->nodes()) {
+  for (auto n : g->nodes()) {
     if (n->kind() == c10::Symbol::fromQualString("prim::GetAttr")) {
       throw std::runtime_error("Cannot accelerate unfrozen graphs");
     }
@@ -45,12 +41,25 @@
   }
 
   // remove unused input 0 from graph
-  if (graph_->inputs().at(0)->type()->is_module()) {
-    if (!graph_->inputs().at(0)->hasUses()) {
-      graph_->eraseInput(0);
+  if (g->inputs().at(0)->type()->is_module()) {
+    if (!g->inputs().at(0)->hasUses()) {
+      g->eraseInput(0);
     }
   }
 
+  return g;
+}
+
+std::shared_ptr<torch::jit::Graph> PrepareForStaticRuntime(
+    const torch::jit::Module& m) {
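+  // Copy so the caller's module is untouched; freezing folds attributes
+  // and submodules into the graph as constants.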
+  auto module = m.copy();
+  module.eval();
+  module = freeze_module(module);
+  auto g = module.get_method("forward").graph();
+  return PrepareForStaticRuntime(g);
+}
+
+StaticRuntime::StaticRuntime(std::shared_ptr<torch::jit::Graph> g)
+    : graph_(std::move(g)) {
   // fill workspace_ with constants
   for (Node* node : graph_->nodes()) {
     if (node->kind() == prim::Constant) {
@@ -63,19 +72,13 @@
 }
 
 std::vector<at::Tensor> StaticRuntime::run(
-    const std::vector<at::Tensor>& inps) {
+    const std::vector<at::Tensor>& inps) const {
   // Container for inputs, outputs, and activations (excluding parameters)
 
-  int start = 0;
-  if (graph_->inputs().size() != inps.size()) {
-    start = 1;
-    CHECK_EQ(graph_->inputs().size(), inps.size() + 1);
-    CHECK((graph_->inputs().at(0)->type()->is_module()));
-    workspace_[graph_->inputs()[0]] = module_._ivalue();
-  }
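+  // Preparation already erased the unused module (self) input, so graph
+  // inputs must match the caller's tensors one-to-one.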
+  TORCH_INTERNAL_ASSERT(graph_->inputs().size() == inps.size());
 
   for (size_t i = 0; i < inps.size(); i++) {
-    workspace_[graph_->inputs()[i + start]] = inps[i];
+    workspace_[graph_->inputs()[i]] = inps[i];
   }
 
   for (const auto& n : nodes_) {
diff --git a/torch/csrc/jit/runtime/static/impl.h b/torch/csrc/jit/runtime/static/impl.h
index 3fc8a2e..8fe8c80 100644
--- a/torch/csrc/jit/runtime/static/impl.h
+++ b/torch/csrc/jit/runtime/static/impl.h
@@ -14,15 +14,16 @@
 namespace torch {
 namespace jit {
 
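+// Runs the preparation passes and returns a graph that StaticRuntime can
+// consume; overloads accept either a raw Graph or a scripted Module.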
+TORCH_API std::shared_ptr<torch::jit::Graph> PrepareForStaticRuntime(
+    std::shared_ptr<torch::jit::Graph> g);
+TORCH_API std::shared_ptr<torch::jit::Graph> PrepareForStaticRuntime(
+    const torch::jit::Module& m);
+
 class ProcessedNode;
 class TORCH_API StaticRuntime {
  public:
-  explicit StaticRuntime(std::shared_ptr<torch::jit::Graph> g)
-      : graph_(std::move(g)) {}
-
-  explicit StaticRuntime(const torch::jit::Module& m);
-
-  std::vector<at::Tensor> run(const std::vector<at::Tensor>& inps);
+  explicit StaticRuntime(std::shared_ptr<torch::jit::Graph> g);
+  std::vector<at::Tensor> run(const std::vector<at::Tensor>& inps) const;
 
 #ifdef FBCODE_CAFFE2
   using ConstantMap = folly::F14FastMap<Value*, IValue>;
@@ -31,12 +32,11 @@
 #endif
 
  private:
-  torch::jit::Module module_;
   std::shared_ptr<torch::jit::Graph> graph_;
 
   // Static runtime states
   // Value table (including weights)
-  ConstantMap workspace_;
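+  // mutable: run() is const but writes inputs and activations here.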
+  mutable ConstantMap workspace_;
 
   // The nodes we need to run
   std::vector<ProcessedNode> nodes_;
diff --git a/torch/csrc/jit/runtime/static/init.cpp b/torch/csrc/jit/runtime/static/init.cpp
index d57242d..86292f4 100644
--- a/torch/csrc/jit/runtime/static/init.cpp
+++ b/torch/csrc/jit/runtime/static/init.cpp
@@ -10,10 +10,10 @@
   m.def(
        "_jit_to_static_runtime",
        [](const std::shared_ptr<torch::jit::Graph>& g) {
-         return StaticRuntime(g);
+         return StaticRuntime(PrepareForStaticRuntime(g));
        })
       .def("_jit_to_static_runtime", [](const torch::jit::Module& m) {
-        return StaticRuntime(m);
+        return StaticRuntime(PrepareForStaticRuntime(m));
       });
 }