[Static Runtime] support forked subgraph execution on parent graph's executor (#80381)
Summary:
- support async excecution of forked nodes on custom executor
- fork subgraph execution was performed on inter-op thread pool executor by default
- Handle forked graph async execution on custom executor when the parent graph is executed with runAsync() API passing the executor for async ops
Differential Revision: D37466525
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80381
Approved by: https://github.com/mikeiovine
diff --git a/torch/csrc/jit/runtime/static/native_ops.cpp b/torch/csrc/jit/runtime/static/native_ops.cpp
index 313f893..6124c91 100644
--- a/torch/csrc/jit/runtime/static/native_ops.cpp
+++ b/torch/csrc/jit/runtime/static/native_ops.cpp
@@ -967,16 +967,19 @@
ForkedSubgraphSRLauncher(
std::shared_ptr<StaticModule> smodule,
std::vector<IValue> args,
- c10::intrusive_ptr<Future> future)
+ c10::intrusive_ptr<Future> future,
+ TaskLauncher launcher)
: smodule_(std::move(smodule)),
args_(std::move(args)),
- future_(std::move(future)) {}
+ future_(std::move(future)),
+ launcher_(std::move(launcher)) {}
void operator()() {
try {
StaticRuntime runtime(*smodule_);
- auto output = runtime(args_, {});
- future_->markCompleted(output);
+ auto future_subgraph = runtime.runAsync(args_, {}, launcher_);
+ future_subgraph->waitAndThrow();
+ future_->markCompleted(future_subgraph->value());
} catch (const std::exception& e) {
future_->setErrorIfNeeded(
std::make_exception_ptr(c10::ivalue::Future::FutureError(e.what())));
@@ -987,6 +990,7 @@
std::shared_ptr<StaticModule> smodule_;
std::vector<IValue> args_;
c10::intrusive_ptr<Future> future_;
+ torch::jit::TaskLauncher launcher_;
};
/*
@@ -1040,11 +1044,12 @@
createFutureTypeFromGraphOutput(forkedGraph);
p_node->Output(0) = future;
- ForkedSubgraphSRLauncher runtime_launcher(smodule, args, future);
auto* metadata = p_node->metadata();
DCHECK(metadata);
auto* launcher = metadata->launcher();
DCHECK(launcher);
+ ForkedSubgraphSRLauncher runtime_launcher(
+ smodule, args, future, *launcher);
(*launcher)(std::move(runtime_launcher));
};
});