[static runtime] Split out graph preparation from runtime (#44131)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/44131
Test Plan: Imported from OSS
Reviewed By: hlu1
Differential Revision: D23604305
Pulled By: bwasti
fbshipit-source-id: 7b47da4961d99074199417ef1407a788c7d80ee6
diff --git a/benchmarks/static_runtime/deep_wide_pt_bench.cc b/benchmarks/static_runtime/deep_wide_pt_bench.cc
index ef960d2..21c2923 100644
--- a/benchmarks/static_runtime/deep_wide_pt_bench.cc
+++ b/benchmarks/static_runtime/deep_wide_pt_bench.cc
@@ -60,7 +60,8 @@
static void BM_deep_wide_static(benchmark::State& state) {
auto mod = getDeepAndWideSciptModel();
- torch::jit::StaticRuntime runtime(mod);
+ auto g = torch::jit::PrepareForStaticRuntime(mod);
+ torch::jit::StaticRuntime runtime(g);
const int batch_size = state.range(0);
auto ad_emb_packed = torch::randn({batch_size, 1, embedding_size});
@@ -75,6 +76,28 @@
}
}
+// Returns the deep-and-wide model's graph after PrepareForStaticRuntime.
+// The graph is built once, on the first call, and the same shared instance
+// is handed back on every later call.
+const std::shared_ptr<torch::jit::Graph>& getStaticGraph() {
+  static const auto prepared_graph =
+      torch::jit::PrepareForStaticRuntime(getDeepAndWideSciptModel());
+  return prepared_graph;
+}
+
+// Benchmarks StaticRuntime::run on the shared prepared graph. Every
+// invocation of this function builds its own StaticRuntime from that graph
+// and feeds it freshly generated random inputs.
+static void BM_deep_wide_static_threaded(benchmark::State& state) {
+  const auto& shared_graph = getStaticGraph();
+  torch::jit::StaticRuntime static_rt(shared_graph);
+
+  // Batch size is fixed at 1 here rather than taken from state.range(0).
+  const int batch_size = 1;
+  auto ad_emb = torch::randn({batch_size, 1, embedding_size});
+  auto user_emb_in = torch::randn({batch_size, 1, embedding_size});
+  auto wide_in = torch::randn({batch_size, num_features});
+
+  std::vector<at::Tensor> run_inputs{ad_emb, user_emb_in, wide_in};
+
+  for (auto _ : state) {
+    static_rt.run(run_inputs);
+  }
+}
+
BENCHMARK(BM_deep_wide_base)->RangeMultiplier(8)->Ranges({{1, 20}});
BENCHMARK(BM_deep_wide_jit_graph_executor)
@@ -86,5 +109,6 @@
->Ranges({{1, 20}});
BENCHMARK(BM_deep_wide_static)->RangeMultiplier(8)->Ranges({{1, 20}});
+BENCHMARK(BM_deep_wide_static_threaded)->Threads(8);
BENCHMARK_MAIN();
diff --git a/benchmarks/static_runtime/test_static_runtime.cc b/benchmarks/static_runtime/test_static_runtime.cc
index 6b9b1dd..211839f 100644
--- a/benchmarks/static_runtime/test_static_runtime.cc
+++ b/benchmarks/static_runtime/test_static_runtime.cc
@@ -14,7 +14,8 @@
// run static runtime
std::vector<at::Tensor> input_tensors({a, b, c});
- torch::jit::StaticRuntime runtime(mod);
+ auto g = torch::jit::PrepareForStaticRuntime(mod);
+ torch::jit::StaticRuntime runtime(g);
at::Tensor output_2 = runtime.run(input_tensors)[0];
EXPECT_TRUE(output_1.equal(output_2));
}
@@ -23,7 +24,8 @@
const int embedding_size = 32;
const int num_features = 50;
torch::jit::Module mod = getDeepAndWideSciptModel();
- torch::jit::StaticRuntime runtime(mod);
+ auto g = torch::jit::PrepareForStaticRuntime(mod);
+ torch::jit::StaticRuntime runtime(g);
for (int batch_size : {1, 8, 32}) {
for (int i = 0; i < 5; ++i) {
diff --git a/test/test_static_runtime.py b/test/test_static_runtime.py
index 582cb32..86dafa3 100644
--- a/test/test_static_runtime.py
+++ b/test/test_static_runtime.py
@@ -150,13 +150,13 @@
acc_top = top_l_acc(top_inp)[0]
torch.testing.assert_allclose(acc_top, ref_top)
- # def test_trivial_graph(self):
- # s = torch.full((2, 2), 2)
- # tg = torch.jit.script(trivial_graph)
- # o_ref = tg(s, s, s)
- # tg_a = StaticRuntime(tg)
- # o_test = tg_a(s, s, s)[0]
- # torch.testing.assert_allclose(o_ref, o_test)
+ def test_trivial_graph(self):
+     """StaticRuntime must reproduce the scripted trivial_graph result."""
+     inp = torch.full((2, 2), 2)
+     scripted = torch.jit.script(trivial_graph)
+     expected = scripted(inp, inp, inp)
+     static_rt = StaticRuntime(scripted)
+     # The static runtime returns a sequence of outputs; compare the first.
+     actual = static_rt(inp, inp, inp)[0]
+     torch.testing.assert_allclose(expected, actual)
if __name__ == "__main__":
diff --git a/torch/csrc/jit/runtime/static/impl.cpp b/torch/csrc/jit/runtime/static/impl.cpp
index 42abeec..1ebd7cc 100644
--- a/torch/csrc/jit/runtime/static/impl.cpp
+++ b/torch/csrc/jit/runtime/static/impl.cpp
@@ -14,20 +14,16 @@
using c10::DispatchKey;
using c10::RegisterOperators;
-StaticRuntime::StaticRuntime(const torch::jit::Module& m)
- : module_(m.copy()), graph_(nullptr) {
- module_.eval();
- module_ = freeze_module(module_);
- graph_ = module_.get_method("forward").graph();
+std::shared_ptr<torch::jit::Graph> PrepareForStaticRuntime(
+ std::shared_ptr<torch::jit::Graph> g) {
+ Inline(*g);
+ ConstantPropagation(g);
+ Canonicalize(g);
+ ConstantPropagation(g);
+ RemoveTensorMutation(g);
+ ConstantPropagation(g);
- Inline(*graph_);
- ConstantPropagation(graph_);
- Canonicalize(graph_);
- ConstantPropagation(graph_);
- RemoveTensorMutation(graph_);
- ConstantPropagation(graph_);
-
- for (auto n : graph_->nodes()) {
+ for (auto n : g->nodes()) {
if (n->kind() == c10::Symbol::fromQualString("prim::GetAttr")) {
throw std::runtime_error("Cannot accelerate unfrozen graphs");
}
@@ -45,12 +41,25 @@
}
// remove unused input 0 from graph
- if (graph_->inputs().at(0)->type()->is_module()) {
- if (!graph_->inputs().at(0)->hasUses()) {
- graph_->eraseInput(0);
+ if (g->inputs().at(0)->type()->is_module()) {
+ if (!g->inputs().at(0)->hasUses()) {
+ g->eraseInput(0);
}
}
+ return g;
+}
+
+// Module overload. Works on a copy of `m` (via m.copy()): the copy is put in
+// eval mode and frozen, and the graph of its "forward" method is then handed
+// to the graph overload of PrepareForStaticRuntime.
+std::shared_ptr<torch::jit::Graph> PrepareForStaticRuntime(
+    const torch::jit::Module& m) {
+  auto frozen = m.copy();
+  frozen.eval();
+  frozen = freeze_module(frozen);
+  return PrepareForStaticRuntime(frozen.get_method("forward").graph());
+}
+
+StaticRuntime::StaticRuntime(std::shared_ptr<torch::jit::Graph> g) : graph_(g) {
// fill workspace_ with constants
for (Node* node : graph_->nodes()) {
if (node->kind() == prim::Constant) {
@@ -63,19 +72,13 @@
}
std::vector<at::Tensor> StaticRuntime::run(
- const std::vector<at::Tensor>& inps) {
+ const std::vector<at::Tensor>& inps) const {
// Container for inputs, outputs, and activations (excluding parameters)
- int start = 0;
- if (graph_->inputs().size() != inps.size()) {
- start = 1;
- CHECK_EQ(graph_->inputs().size(), inps.size() + 1);
- CHECK((graph_->inputs().at(0)->type()->is_module()));
- workspace_[graph_->inputs()[0]] = module_._ivalue();
- }
+ TORCH_INTERNAL_ASSERT(graph_->inputs().size() == inps.size());
for (size_t i = 0; i < inps.size(); i++) {
- workspace_[graph_->inputs()[i + start]] = inps[i];
+ workspace_[graph_->inputs()[i]] = inps[i];
}
for (const auto& n : nodes_) {
diff --git a/torch/csrc/jit/runtime/static/impl.h b/torch/csrc/jit/runtime/static/impl.h
index 3fc8a2e..8fe8c80 100644
--- a/torch/csrc/jit/runtime/static/impl.h
+++ b/torch/csrc/jit/runtime/static/impl.h
@@ -14,15 +14,16 @@
namespace torch {
namespace jit {
+// Runs the static-runtime preparation passes (inlining, constant
+// propagation, canonicalization, tensor-mutation removal) on `g` and returns
+// the same graph; `g` is modified in place. Throws std::runtime_error when
+// the graph still contains prim::GetAttr nodes ("unfrozen" graph). An unused
+// module-typed first input is erased from the graph.
+TORCH_API std::shared_ptr<torch::jit::Graph> PrepareForStaticRuntime(
+ std::shared_ptr<torch::jit::Graph> g);
+// Copies `m`, switches the copy to eval mode, freezes it, and prepares the
+// graph of its "forward" method with the graph overload above.
+TORCH_API std::shared_ptr<torch::jit::Graph> PrepareForStaticRuntime(
+ const torch::jit::Module& m);
+
class ProcessedNode;
class TORCH_API StaticRuntime {
public:
- explicit StaticRuntime(std::shared_ptr<torch::jit::Graph> g)
- : graph_(std::move(g)) {}
-
- explicit StaticRuntime(const torch::jit::Module& m);
-
- std::vector<at::Tensor> run(const std::vector<at::Tensor>& inps);
+ explicit StaticRuntime(std::shared_ptr<torch::jit::Graph> g);
+ std::vector<at::Tensor> run(const std::vector<at::Tensor>& inps) const;
#ifdef FBCODE_CAFFE2
using ConstantMap = folly::F14FastMap<Value*, IValue>;
@@ -31,12 +32,11 @@
#endif
private:
- torch::jit::Module module_;
std::shared_ptr<torch::jit::Graph> graph_;
// Static runtime states
// Value table (including weights)
- ConstantMap workspace_;
+ mutable ConstantMap workspace_;
// The nodes we need to run
std::vector<ProcessedNode> nodes_;
diff --git a/torch/csrc/jit/runtime/static/init.cpp b/torch/csrc/jit/runtime/static/init.cpp
index d57242d..86292f4 100644
--- a/torch/csrc/jit/runtime/static/init.cpp
+++ b/torch/csrc/jit/runtime/static/init.cpp
@@ -10,10 +10,10 @@
m.def(
"_jit_to_static_runtime",
[](const std::shared_ptr<torch::jit::Graph>& g) {
- return StaticRuntime(g);
+ return StaticRuntime(PrepareForStaticRuntime(g));
})
.def("_jit_to_static_runtime", [](const torch::jit::Module& m) {
- return StaticRuntime(m);
+ return StaticRuntime(PrepareForStaticRuntime(m));
});
}