Make remaining autograd methods return futures (#57861)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/57861

The last methods that still didn't return Futures were the autograd ones (`processBackwardAutogradReq` and `processCleanupAutogradContextReq`), and they are very easy to port.

We've now finished the conversion of RequestCallback to be fully Future-based!
ghstack-source-id: 129567055

Test Plan: CI

Reviewed By: mrshenli

Differential Revision: D28286173

fbshipit-source-id: 1de58cee1b4513fb25b7e089eb9c45e2dda69fcb
diff --git a/torch/csrc/distributed/rpc/request_callback_no_python.cpp b/torch/csrc/distributed/rpc/request_callback_no_python.cpp
index f62c706..80ee0dd 100644
--- a/torch/csrc/distributed/rpc/request_callback_no_python.cpp
+++ b/torch/csrc/distributed/rpc/request_callback_no_python.cpp
@@ -370,9 +370,8 @@
   return responseFuture;
 }
 
-void RequestCallbackNoPython::processBackwardAutogradReq(
-    RpcCommandBase& rpc,
-    const c10::intrusive_ptr<JitFuture>& responseFuture) const {
+c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::
+    processBackwardAutogradReq(RpcCommandBase& rpc) const {
   auto& gradientsCall = static_cast<PropagateGradientsReq&>(rpc);
   const auto& autogradMetadata = gradientsCall.getAutogradMetadata();
 
@@ -392,20 +391,20 @@
       autogradContext, sendFunction, gradientsCall.retainGraph());
 
   // Our response is satisfied when the rpcs come back.
-  execFuture->addCallback([responseFuture](JitFuture& execFuture) {
-    if (!execFuture.hasError()) {
-      Message m = std::move(PropagateGradientsResp()).toMessage();
-      responseFuture->markCompleted(
-          IValue(c10::make_intrusive<Message>(std::move(m))));
-    } else {
-      responseFuture->setError(execFuture.exception_ptr());
-    }
-  });
+  return execFuture->then(
+      [](JitFuture& execFuture) {
+        if (execFuture.hasError()) {
+          std::rethrow_exception(execFuture.exception_ptr());
+        } else {
+          return c10::make_intrusive<Message>(
+              PropagateGradientsResp().toMessage());
+        }
+      },
+      c10::getCustomClassType<c10::intrusive_ptr<Message>>());
 }
 
-void RequestCallbackNoPython::processCleanupAutogradContextReq(
-    RpcCommandBase& rpc,
-    const std::function<void(Message)>& markComplete) const {
+c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::
+    processCleanupAutogradContextReq(RpcCommandBase& rpc) const {
   auto& cleanupContextReq = static_cast<CleanupAutogradContextReq&>(rpc);
   auto cleanupContextId = cleanupContextReq.getContextId();
   // release the context if it still exists on this thread. We need to
@@ -414,7 +413,7 @@
   // notified to clean up their context.
   DistAutogradContainer::getInstance().releaseContextIfPresent(
       cleanupContextId);
-  markComplete(std::move(CleanupAutogradContextResp()).toMessage());
+  return asFuture(CleanupAutogradContextResp().toMessage());
 }
 
 c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::
@@ -510,14 +509,6 @@
     RpcCommandBase& rpc,
     const MessageType& messageType,
     std::shared_ptr<LazyStreamContext> ctx) const {
-  // TODO Avoid creating a future here and passing it down, and instead allow
-  // each method to create the future however it wants and pass it back up.
-  auto responseFuture = c10::make_intrusive<JitFuture>(
-      c10::getCustomClassType<c10::intrusive_ptr<Message>>());
-  auto markComplete = [&responseFuture](Message m) {
-    responseFuture->markCompleted(
-        IValue(c10::make_intrusive<Message>(std::move(m))));
-  };
   // TODO: RpcCommandBase should have an abstract execute() method that we can
   // call here instead of having another switch statement here. Even better we
   // could have abstract classes RpcRequest and RpcResp which inherit from
@@ -556,12 +547,10 @@
       return processForwardAutogradReq(rpc, std::move(ctx));
     }
     case MessageType::BACKWARD_AUTOGRAD_REQ: {
-      processBackwardAutogradReq(rpc, responseFuture);
-      return responseFuture;
+      return processBackwardAutogradReq(rpc);
     };
     case MessageType::CLEANUP_AUTOGRAD_CONTEXT_REQ: {
-      processCleanupAutogradContextReq(rpc, markComplete);
-      return responseFuture;
+      return processCleanupAutogradContextReq(rpc);
     }
     case MessageType::RUN_WITH_PROFILING_REQ: {
       return processRunWithProfilingReq(rpc);
diff --git a/torch/csrc/distributed/rpc/request_callback_no_python.h b/torch/csrc/distributed/rpc/request_callback_no_python.h
index e8fd7d8..4bf1099 100644
--- a/torch/csrc/distributed/rpc/request_callback_no_python.h
+++ b/torch/csrc/distributed/rpc/request_callback_no_python.h
@@ -66,13 +66,11 @@
       RpcCommandBase& rpc,
       std::shared_ptr<LazyStreamContext> ctx) const;
 
-  void processBackwardAutogradReq(
-      RpcCommandBase& rpc,
-      const c10::intrusive_ptr<JitFuture>& responseFuture) const;
+  c10::intrusive_ptr<JitFuture> processBackwardAutogradReq(
+      RpcCommandBase& rpc) const;
 
-  void processCleanupAutogradContextReq(
-      RpcCommandBase& rpc,
-      const std::function<void(Message)>& markComplete) const;
+  c10::intrusive_ptr<JitFuture> processCleanupAutogradContextReq(
+      RpcCommandBase& rpc) const;
 
   c10::intrusive_ptr<JitFuture> processRunWithProfilingReq(
       RpcCommandBase& rpc) const;