Unify async execution for JIT functions (#57852)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/57852
Another great example of the benefits of Futures. Thanks to the "right abstraction" (i.e., the `thenAsync` method), adding support for async execution becomes trivial, and the code much simpler than what it used to be.
ghstack-source-id: 129567063
Test Plan: CI
Reviewed By: mrshenli
Differential Revision: D28253842
fbshipit-source-id: b660151ca300f3d6078db0f3e380c80a4d8f5190
diff --git a/torch/csrc/distributed/rpc/request_callback_impl.cpp b/torch/csrc/distributed/rpc/request_callback_impl.cpp
index a0091f1..0de69e0 100644
--- a/torch/csrc/distributed/rpc/request_callback_impl.cpp
+++ b/torch/csrc/distributed/rpc/request_callback_impl.cpp
@@ -152,32 +152,15 @@
return;
}
- auto jitFuture = runJitFunction(scriptCall.qualifiedName(), stack);
+ auto jitFuture = runJitFunction(
+ scriptCall.qualifiedName(), stack, scriptCall.isAsyncExecution());
- jitFuture->addCallback([responseFuture,
- isAsyncExecution = scriptCall.isAsyncExecution(),
- markComplete](JitFuture& jitFutureCaptured) {
- try {
- JitFuture& jitFuture = isAsyncExecution
- ? *jitFutureCaptured.value().toFuture()
- : jitFutureCaptured;
-
- // Setup response callback appropriately.
- auto responseCb = [responseFuture](JitFuture& jitFuture) {
- try {
- Message m = ScriptResp(jitFuture.value()).toMessage();
- responseFuture->markCompleted(
- IValue(c10::make_intrusive<Message>(std::move(m))));
- } catch (const std::exception& /* unused */) {
- responseFuture->setError(std::current_exception());
- }
- };
-
- // Call inline if we don't have async execution.
- isAsyncExecution ? jitFuture.addCallback(responseCb)
- : responseCb(jitFuture);
- } catch (const std::exception& /* unused */) {
- responseFuture->setError(std::current_exception());
+ jitFuture->addCallback([responseFuture, markComplete](JitFuture& jitFuture) {
+ if (jitFuture.hasError()) {
+ responseFuture->setError(jitFuture.exception_ptr());
+ } else {
+ responseFuture->markCompleted(c10::make_intrusive<Message>(
+ ScriptResp(jitFuture.value()).toMessage()));
}
});
}
@@ -240,36 +223,19 @@
return;
}
- auto isAsyncExecution = scriptRemoteCall.isAsyncExecution();
- auto asyncPostProcessing = [ownerRRef, postProcessing, isAsyncExecution](
- c10::ivalue::Future& jitFuture) mutable {
- // The user function will return a JIT future, install
- // setRRefValue and postProcessing to that valueFuture
- try {
- JitFuture& valueJitFuture =
- isAsyncExecution ? *jitFuture.value().toFuture() : jitFuture;
-
- // Setup callback.
- auto setRRefValue = [ownerRRef,
- postProcessing](JitFuture& valueJitFuture) mutable {
- try {
- ownerRRef->setValue(valueJitFuture.value());
- } catch (const std::exception& e) {
- ownerRRef->setError(std::current_exception());
- }
- postProcessing();
- };
-
- // Call inline if not async execution.
- isAsyncExecution ? valueJitFuture.addCallback(setRRefValue)
- : setRRefValue(valueJitFuture);
- } catch (std::exception& e) {
- ownerRRef->setError(std::current_exception());
- postProcessing();
+ auto asyncPostProcessing = [ownerRRef, postProcessing](JitFuture& jitFuture) {
+ if (jitFuture.hasError()) {
+ ownerRRef->setError(jitFuture.exception_ptr());
+ } else {
+ ownerRRef->setValue(jitFuture.value());
}
+ postProcessing();
};
- auto jitFuture = runJitFunction(scriptRemoteCall.qualifiedName(), stack);
+ auto jitFuture = runJitFunction(
+ scriptRemoteCall.qualifiedName(),
+ stack,
+ scriptRemoteCall.isAsyncExecution());
jitFuture->addCallback(asyncPostProcessing);
}
@@ -488,19 +454,35 @@
c10::intrusive_ptr<JitFuture> RequestCallbackImpl::runJitFunction(
const c10::QualifiedName& name,
- std::vector<at::IValue>& stack) const {
+ std::vector<at::IValue>& stack,
+ bool isAsyncExecution) const {
+ c10::intrusive_ptr<JitFuture> future;
try {
// runAsync() starts in the calling thread, but may return an uncompleted
// future (though for non-async code, it will typically be completed).
// If it was async, our callback will typically be invoked by the
// continuation on an at::launch() thread.
- return PythonRpcHandler::getInstance()
- .jitCompilationUnit()
- ->get_function(name)
- .runAsync(stack);
+ future = PythonRpcHandler::getInstance()
+ .jitCompilationUnit()
+ ->get_function(name)
+ .runAsync(stack);
} catch (const std::exception&) {
return asFuture(std::current_exception());
}
+
+ if (isAsyncExecution) {
+ at::TypePtr type = future->elementType();
+ if (type->kind() != at::FutureType::Kind) {
+ return asFuture(std::make_exception_ptr(std::runtime_error(c10::str(
+ "Async functions must return an IValue of Future type, but got ",
+ type->str()))));
+ }
+ future = future->thenAsync(
+ [](JitFuture& future) { return future.value().toFuture(); },
+ type->cast<at::FutureType>()->getElementType());
+ }
+
+ return future;
}
} // namespace rpc
diff --git a/torch/csrc/distributed/rpc/request_callback_impl.h b/torch/csrc/distributed/rpc/request_callback_impl.h
index a186280..dbccd73 100644
--- a/torch/csrc/distributed/rpc/request_callback_impl.h
+++ b/torch/csrc/distributed/rpc/request_callback_impl.h
@@ -61,7 +61,8 @@
c10::intrusive_ptr<JitFuture> runJitFunction(
const c10::QualifiedName& name,
- std::vector<at::IValue>& stack) const;
+ std::vector<at::IValue>& stack,
+ bool isAsyncExecution) const;
};
} // namespace rpc
diff --git a/torch/testing/_internal/distributed/rpc/jit/rpc_test.py b/torch/testing/_internal/distributed/rpc/jit/rpc_test.py
index 45740f8..098f265 100644
--- a/torch/testing/_internal/distributed/rpc/jit/rpc_test.py
+++ b/torch/testing/_internal/distributed/rpc/jit/rpc_test.py
@@ -1308,7 +1308,10 @@
@dist_init
def test_async_function_wrong_return_type(self):
- with self.assertRaisesRegex(RuntimeError, "Expected Future but got Tensor"):
+ with self.assertRaisesRegex(
+ RuntimeError,
+ "Async functions must return an IValue of Future type, but got Tensor",
+ ):
rpc.rpc_sync(
worker_name((self.rank + 1) % self.world_size), async_wrong_type
)
@@ -1368,5 +1371,8 @@
worker_name((self.rank + 1) % self.world_size), async_wrong_type
)
- with self.assertRaisesRegex(RuntimeError, "Expected Future but got Tensor"):
+ with self.assertRaisesRegex(
+ RuntimeError,
+ "Async functions must return an IValue of Future type, but got Tensor",
+ ):
rref.to_here()