Unify async execution for JIT functions (#57852)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/57852

Another great example of the benefits of Futures. Thanks to the "right abstraction" (i.e., the `thenAsync` method), adding support for async execution becomes trivial, and the code much simpler than what it used to be.
ghstack-source-id: 129567063

Test Plan: CI

Reviewed By: mrshenli

Differential Revision: D28253842

fbshipit-source-id: b660151ca300f3d6078db0f3e380c80a4d8f5190
diff --git a/torch/csrc/distributed/rpc/request_callback_impl.cpp b/torch/csrc/distributed/rpc/request_callback_impl.cpp
index a0091f1..0de69e0 100644
--- a/torch/csrc/distributed/rpc/request_callback_impl.cpp
+++ b/torch/csrc/distributed/rpc/request_callback_impl.cpp
@@ -152,32 +152,15 @@
     return;
   }
 
-  auto jitFuture = runJitFunction(scriptCall.qualifiedName(), stack);
+  auto jitFuture = runJitFunction(
+      scriptCall.qualifiedName(), stack, scriptCall.isAsyncExecution());
 
-  jitFuture->addCallback([responseFuture,
-                          isAsyncExecution = scriptCall.isAsyncExecution(),
-                          markComplete](JitFuture& jitFutureCaptured) {
-    try {
-      JitFuture& jitFuture = isAsyncExecution
-          ? *jitFutureCaptured.value().toFuture()
-          : jitFutureCaptured;
-
-      // Setup response callback appropriately.
-      auto responseCb = [responseFuture](JitFuture& jitFuture) {
-        try {
-          Message m = ScriptResp(jitFuture.value()).toMessage();
-          responseFuture->markCompleted(
-              IValue(c10::make_intrusive<Message>(std::move(m))));
-        } catch (const std::exception& /* unused */) {
-          responseFuture->setError(std::current_exception());
-        }
-      };
-
-      // Call inline if we don't have async execution.
-      isAsyncExecution ? jitFuture.addCallback(responseCb)
-                       : responseCb(jitFuture);
-    } catch (const std::exception& /* unused */) {
-      responseFuture->setError(std::current_exception());
+  jitFuture->addCallback([responseFuture, markComplete](JitFuture& jitFuture) {
+    if (jitFuture.hasError()) {
+      responseFuture->setError(jitFuture.exception_ptr());
+    } else {
+      responseFuture->markCompleted(c10::make_intrusive<Message>(
+          ScriptResp(jitFuture.value()).toMessage()));
     }
   });
 }
@@ -240,36 +223,19 @@
     return;
   }
 
-  auto isAsyncExecution = scriptRemoteCall.isAsyncExecution();
-  auto asyncPostProcessing = [ownerRRef, postProcessing, isAsyncExecution](
-                                 c10::ivalue::Future& jitFuture) mutable {
-    // The user function will return a JIT future, install
-    // setRRefValue and postProcessing to that valueFuture
-    try {
-      JitFuture& valueJitFuture =
-          isAsyncExecution ? *jitFuture.value().toFuture() : jitFuture;
-
-      // Setup callback.
-      auto setRRefValue = [ownerRRef,
-                           postProcessing](JitFuture& valueJitFuture) mutable {
-        try {
-          ownerRRef->setValue(valueJitFuture.value());
-        } catch (const std::exception& e) {
-          ownerRRef->setError(std::current_exception());
-        }
-        postProcessing();
-      };
-
-      // Call inline if not async execution.
-      isAsyncExecution ? valueJitFuture.addCallback(setRRefValue)
-                       : setRRefValue(valueJitFuture);
-    } catch (std::exception& e) {
-      ownerRRef->setError(std::current_exception());
-      postProcessing();
+  auto asyncPostProcessing = [ownerRRef, postProcessing](JitFuture& jitFuture) {
+    if (jitFuture.hasError()) {
+      ownerRRef->setError(jitFuture.exception_ptr());
+    } else {
+      ownerRRef->setValue(jitFuture.value());
     }
+    postProcessing();
   };
 
-  auto jitFuture = runJitFunction(scriptRemoteCall.qualifiedName(), stack);
+  auto jitFuture = runJitFunction(
+      scriptRemoteCall.qualifiedName(),
+      stack,
+      scriptRemoteCall.isAsyncExecution());
 
   jitFuture->addCallback(asyncPostProcessing);
 }
@@ -488,19 +454,35 @@
 
 c10::intrusive_ptr<JitFuture> RequestCallbackImpl::runJitFunction(
     const c10::QualifiedName& name,
-    std::vector<at::IValue>& stack) const {
+    std::vector<at::IValue>& stack,
+    bool isAsyncExecution) const {
+  c10::intrusive_ptr<JitFuture> future;
   try {
     // runAsync() starts in the calling thread, but may return an uncompleted
     // future (though for non-async code, it will typically be completed).
     // If it was async, our callback will typically be invoked by the
     // continuation on an at::launch() thread.
-    return PythonRpcHandler::getInstance()
-        .jitCompilationUnit()
-        ->get_function(name)
-        .runAsync(stack);
+    future = PythonRpcHandler::getInstance()
+                 .jitCompilationUnit()
+                 ->get_function(name)
+                 .runAsync(stack);
   } catch (const std::exception&) {
     return asFuture(std::current_exception());
   }
+
+  if (isAsyncExecution) {
+    at::TypePtr type = future->elementType();
+    if (type->kind() != at::FutureType::Kind) {
+      return asFuture(std::make_exception_ptr(std::runtime_error(c10::str(
+          "Async functions must return an IValue of Future type, but got ",
+          type->str()))));
+    }
+    future = future->thenAsync(
+        [](JitFuture& future) { return future.value().toFuture(); },
+        type->cast<at::FutureType>()->getElementType());
+  }
+
+  return future;
 }
 
 } // namespace rpc
diff --git a/torch/csrc/distributed/rpc/request_callback_impl.h b/torch/csrc/distributed/rpc/request_callback_impl.h
index a186280..dbccd73 100644
--- a/torch/csrc/distributed/rpc/request_callback_impl.h
+++ b/torch/csrc/distributed/rpc/request_callback_impl.h
@@ -61,7 +61,8 @@
 
   c10::intrusive_ptr<JitFuture> runJitFunction(
       const c10::QualifiedName& name,
-      std::vector<at::IValue>& stack) const;
+      std::vector<at::IValue>& stack,
+      bool isAsyncExecution) const;
 };
 
 } // namespace rpc
diff --git a/torch/testing/_internal/distributed/rpc/jit/rpc_test.py b/torch/testing/_internal/distributed/rpc/jit/rpc_test.py
index 45740f8..098f265 100644
--- a/torch/testing/_internal/distributed/rpc/jit/rpc_test.py
+++ b/torch/testing/_internal/distributed/rpc/jit/rpc_test.py
@@ -1308,7 +1308,10 @@
 
     @dist_init
     def test_async_function_wrong_return_type(self):
-        with self.assertRaisesRegex(RuntimeError, "Expected Future but got Tensor"):
+        with self.assertRaisesRegex(
+            RuntimeError,
+            "Async functions must return an IValue of Future type, but got Tensor",
+        ):
             rpc.rpc_sync(
                 worker_name((self.rank + 1) % self.world_size), async_wrong_type
             )
@@ -1368,5 +1371,8 @@
             worker_name((self.rank + 1) % self.world_size), async_wrong_type
         )
 
-        with self.assertRaisesRegex(RuntimeError, "Expected Future but got Tensor"):
+        with self.assertRaisesRegex(
+            RuntimeError,
+            "Async functions must return an IValue of Future type, but got Tensor",
+        ):
             rref.to_here()