[RPC Reliability] Implemented retries for RPCs with exponential backoff (#32602)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/32602

This adds functionality for retrying RPCs that are sent with the function `sendWithRetries()`. RPCs that may need to be retried are added to a sorted map keyed by the timeout at which to retry them, along with the associated metadata. A separate thread iteratively removes the earliest retryable RPC from the map, sleeps until the corresponding time point, retries the RPC, and, if the send fails again, adds it back to the map with a later timeout.
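
A minimal caller-side sketch of the new API is shown below (the worker name, message variable, and option values here are hypothetical and not taken from this PR):

```cpp
using namespace torch::distributed::rpc;

// Sketch only: assumes the RPC agent has already been initialized, that a
// worker named "worker1" exists, and that `msg` is a Message the caller built
// for an idempotent request.
auto agent = RpcAgent::getCurrentRpcAgent();

RpcRetryOptions options;
options.maxRetries = 5;                                    // allow up to 5 retries
options.rpcRetryDuration = std::chrono::milliseconds(500); // wait 500ms before the first retry
options.retryBackoff = 2.0;                                // then ~1000ms, ~2000ms, ...

auto fut = agent->sendWithRetries(
    agent->getWorkerInfo("worker1"), std::move(msg), options);

// The returned future completes once any send succeeds, or is set to an error
// after maxRetries failed attempts.
fut->wait();
```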

GitHub Issue: https://github.com/pytorch/pytorch/issues/32124

Per the first 3 milestones, the following will be addressed in future PRs:
* enabling RPC Retries for RRef internal messages

Differential Revision: D19560159

fbshipit-source-id: 40cd86f9a25dc24367624d279a3b9720b20824cf
diff --git a/torch/csrc/distributed/rpc/process_group_agent.cpp b/torch/csrc/distributed/rpc/process_group_agent.cpp
index 9d2745e..d34f983 100644
--- a/torch/csrc/distributed/rpc/process_group_agent.cpp
+++ b/torch/csrc/distributed/rpc/process_group_agent.cpp
@@ -241,7 +241,7 @@
       std::thread(&ProcessGroupAgent::pollTimedOutRPCs, this);
 }
 
-void ProcessGroupAgent::shutdown() {
+void ProcessGroupAgent::shutdownImpl() {
   LOG(INFO) << "Shutting down ProcessGroupAgent on rank " << pg_->getRank()
             << ".";
   std::unique_lock<std::mutex> lock{futureMutex_};
@@ -653,8 +653,10 @@
          << " milliseconds and timed out.";
       const auto exceptionMsg = createExceptionResponse(
           Message({}, {}, MessageType::EXCEPTION), ss.str());
-      timedOutFuture.future_->setError(std::string(
-          exceptionMsg.payload().begin(), exceptionMsg.payload().end()));
+      if (!timedOutFuture.future_->hasError()) {
+        timedOutFuture.future_->setError(std::string(
+            exceptionMsg.payload().begin(), exceptionMsg.payload().end()));
+      }
 
       const int dst = timedOutFuture.dstRank_;
       recvCounts_.increment(dst);
diff --git a/torch/csrc/distributed/rpc/process_group_agent.h b/torch/csrc/distributed/rpc/process_group_agent.h
index 1d910b3..49f89a2 100644
--- a/torch/csrc/distributed/rpc/process_group_agent.h
+++ b/torch/csrc/distributed/rpc/process_group_agent.h
@@ -63,7 +63,7 @@
 
   void start() override;
 
-  void shutdown() override;
+  void shutdownImpl() override;
 
   ~ProcessGroupAgent() override;
 
diff --git a/torch/csrc/distributed/rpc/rpc_agent.cpp b/torch/csrc/distributed/rpc/rpc_agent.cpp
index 6f2f137..31e3ada 100644
--- a/torch/csrc/distributed/rpc/rpc_agent.cpp
+++ b/torch/csrc/distributed/rpc/rpc_agent.cpp
@@ -6,6 +6,12 @@
 
 constexpr size_t WorkerInfo::MAX_NAME_LEN;
 
+// Large time duration for waiting on the condition variable until the map is
+// populated. Cannot use
+// std::chrono::time_point<std::chrono::steady_clock>::max() due to a known
+// overflow-related bug.
+constexpr auto kLargeTimeDuration = std::chrono::hours(10000);
+
 RpcAgent::RpcAgent(
     WorkerInfo workerId,
     std::unique_ptr<RequestCallback> cb,
@@ -13,9 +19,170 @@
     : workerInfo_(std::move(workerId)),
       cb_(std::move(cb)),
       rpcTimeout_(rpcTimeout),
-      profilingEnabled_(false) {}
+      profilingEnabled_(false),
+      rpcAgentRunning_(true) {
+  rpcRetryThread_ = std::thread(&RpcAgent::retryExpiredRpcs, this);
+}
 
-RpcAgent::~RpcAgent() = default;
+RpcAgent::~RpcAgent() {
+  shutdown();
+}
+
+void RpcAgent::shutdown() {
+  if (!rpcAgentRunning_.exchange(false)) {
+    return;
+  }
+  rpcRetryMapCV_.notify_one();
+  rpcRetryThread_.join();
+  shutdownImpl();
+}
+
+std::shared_ptr<FutureMessage> RpcAgent::sendWithRetries(
+    const WorkerInfo& to,
+    Message&& message,
+    RpcRetryOptions retryOptions) {
+  TORCH_CHECK(retryOptions.maxRetries >= 0, "maxRetries cannot be negative.");
+  TORCH_CHECK(
+      retryOptions.retryBackoff >= 1,
+      "maxRetries cannot be exponentially decaying.");
+  TORCH_CHECK(
+      retryOptions.rpcRetryDuration.count() >= 0,
+      "rpcRetryDuration cannot be negative.");
+
+  auto originalFuture = std::make_shared<FutureMessage>();
+  steady_clock_time_point newTime =
+      computeNewRpcRetryTime(retryOptions, /* retryCount */ 0);
+  // Making a copy of the message so it can be retried after the first send.
+  Message msgCopy = message;
+  auto fm = send(to, std::move(message));
+  auto firstRetryRpc = std::make_shared<RpcRetryInfo>(
+      to,
+      std::move(msgCopy),
+      originalFuture,
+      /* retryCount */ 0,
+      retryOptions);
+
+  fm->addCallback([this, newTime, firstRetryRpc](
+                      const rpc::Message& lambdaMessage,
+                      const c10::optional<utils::FutureError>& futErr) {
+    rpcRetryCallback(lambdaMessage, futErr, newTime, firstRetryRpc);
+  });
+
+  return originalFuture;
+}
+
+void RpcAgent::retryExpiredRpcs() {
+  while (rpcAgentRunning_.load()) {
+    std::unique_lock<std::mutex> lock(rpcRetryMutex_);
+
+    // We must continue sleeping as long as the RPC agent is running and
+    // either the retry map is empty or the retry map's earliest expiring RPC
+    // is still scheduled to be retried in the future.
+    steady_clock_time_point earliestTimeout =
+        std::chrono::steady_clock::now() + kLargeTimeDuration;
+
+    for (;;) {
+      if (!rpcAgentRunning_.load())
+        return;
+      if (std::chrono::steady_clock::now() >= earliestTimeout)
+        break;
+      if (!rpcRetryMap_.empty()) {
+        earliestTimeout = rpcRetryMap_.begin()->first;
+      }
+      rpcRetryMapCV_.wait_until(lock, earliestTimeout);
+    }
+
+    // Updating these since something may have been added to the map while this
+    // thread was sleeping.
+    earliestTimeout = rpcRetryMap_.begin()->first;
+    auto& earliestRpcList = rpcRetryMap_.begin()->second;
+
+    // We iterate through all the RPCs set to be retried at the current
+    // timepoint, resend those RPCs, and add the RPCs and their futures to
+    // a list to later attach callbacks. These callbacks either schedule
+    // the RPC for a future retry or mark it with success/error depending on
+    // the outcome of the current send. Then, we clean up the rpcRetryMap_.
+    for (auto it = earliestRpcList.begin(); it != earliestRpcList.end();
+         /* no increment */) {
+      auto& earliestRpc = *it;
+      // Making a copy of the message so it can be retried in the future.
+      Message msgCopy = earliestRpc->message_;
+      auto fm = send(earliestRpc->to_, std::move(msgCopy));
+      futures.emplace_back(fm, earliestRpc);
+
+      // A callback will be attached to all futures for the retries in this
+      // list. Thus they will either be rescheduled for future retries or they
+      // will be marked as complete. We can safely delete them from the retry
+      // map for the current timepoint.
+      it = earliestRpcList.erase(it);
+    }
+
+    lock.unlock();
+    // We attach callbacks to the futures outside of the lock to prevent
+    // potential deadlocks.
+    for (const auto& it : futures) {
+      auto fm = it.first;
+      auto earliestRpc = it.second;
+      steady_clock_time_point newTime = computeNewRpcRetryTime(
+          earliestRpc->options_, earliestRpc->retryCount_);
+      earliestRpc->retryCount_++;
+
+      fm->addCallback([this, newTime, earliestRpc](
+                          const rpc::Message& message,
+                          const c10::optional<utils::FutureError>& futErr) {
+        rpcRetryCallback(message, futErr, newTime, earliestRpc);
+      });
+    }
+
+    // If there are no more RPCs set to be retried at the current timepoint,
+    // we can remove the corresponding unordered_set from the retry map. We
+    // must also clear the futures vector.
+    {
+      std::lock_guard<std::mutex> retryMapLock(rpcRetryMutex_);
+      futures.clear();
+      if (earliestRpcList.empty()) {
+        rpcRetryMap_.erase(earliestTimeout);
+      }
+    }
+  }
+}
+
+void RpcAgent::rpcRetryCallback(
+    const rpc::Message& message,
+    const c10::optional<utils::FutureError>& futErr,
+    steady_clock_time_point newTime,
+    std::shared_ptr<RpcRetryInfo> earliestRpc) {
+  if (futErr) {
+    // Adding one since we want to include the original send as well and not
+    // just the retry count.
+    LOG(INFO) << "Send try " << std::to_string(earliestRpc->retryCount_ + 1)
+              << " failed";
+    if (earliestRpc->retryCount_ < earliestRpc->options_.maxRetries) {
+      // If the previous future completed with an error and we haven't
+      // completed maxRetries send attempts, we move the earliestRpc
+      // struct to a new time point in the retry map (effectively
+      // scheduling it for a future retry).
+      {
+        std::lock_guard<std::mutex> retryMapLock(rpcRetryMutex_);
+        rpcRetryMap_[newTime].emplace(std::move(earliestRpc));
+      }
+      // The retry thread waits for the map to be populated. Thus we notify
+      // once an item has been added.
+      rpcRetryMapCV_.notify_one();
+    } else {
+      // We have completed maxRetries send attempts. We're now marking
+      // the future with an error.
+      std::string errorMessage = c10::str(
+          "The RPC has not succeeded after the specified number of max retries (",
+          earliestRpc->options_.maxRetries,
+          ").");
+      earliestRpc->originalFuture_->setError(errorMessage);
+    }
+  } else {
+    // This try succeeded, so we can mark the original future as complete.
+    earliestRpc->originalFuture_->markCompleted(message);
+  }
+}
 
 const WorkerInfo& RpcAgent::getWorkerInfo() const {
   return workerInfo_;
diff --git a/torch/csrc/distributed/rpc/rpc_agent.h b/torch/csrc/distributed/rpc/rpc_agent.h
index aa7d853..d37e5bf 100644
--- a/torch/csrc/distributed/rpc/rpc_agent.h
+++ b/torch/csrc/distributed/rpc/rpc_agent.h
@@ -11,6 +11,9 @@
 namespace distributed {
 namespace rpc {
 
+using steady_clock_time_point =
+    std::chrono::time_point<std::chrono::steady_clock>;
+
 struct RpcBackendOptions {
   RpcBackendOptions() = default;
   std::chrono::milliseconds rpcTimeout;
@@ -55,6 +58,44 @@
   const worker_id_t id_;
 };
 
+// Struct for options to configure the RPC Retry protocol.
+struct TORCH_API RpcRetryOptions {
+  // Using a default constructor like all other Options structs in the RPC
+  // codebase. TORCH_CHECKs for input validation are done in the
+  // sendWithRetries function.
+  RpcRetryOptions() = default;
+  // Maximum number of times we will retry the RPC
+  int maxRetries{3};
+  // Initial duration between consecutive RPC send attempts
+  std::chrono::milliseconds rpcRetryDuration{std::chrono::milliseconds(1000)};
+  // Constant for exponential backoff used while calculating future wait
+  // durations
+  float retryBackoff{1.5};
+};
+
+// Struct that stores all the metadata needed to retry a given RPC.
+struct TORCH_API RpcRetryInfo {
+  RpcRetryInfo(
+      const WorkerInfo& to,
+      Message&& message,
+      std::shared_ptr<FutureMessage> originalFuture,
+      int retryCount,
+      RpcRetryOptions options)
+      : to_(to),
+        message_(message),
+        originalFuture_(std::move(originalFuture)),
+        retryCount_(retryCount),
+        options_(options) {}
+
+  const WorkerInfo& to_;
+  Message message_;
+  // Future that is returned to the caller of sendWithRetries().
+  std::shared_ptr<FutureMessage> originalFuture_;
+  // Number of send attempts completed so far.
+  int retryCount_;
+  RpcRetryOptions options_;
+};
+
 // ``RpcAgent`` is the base class for sending and receiving RPC messages. It
 // provides a unified ``send`` API for both request and response messages, and
 // will invoke the given ``RequestCallback`` to process received requests. It
@@ -93,6 +134,23 @@
       const WorkerInfo& to,
       Message&& message) = 0;
 
+  // Retries sending the message up to maxRetries times until an ACK is
+  // received. The duration between consecutive sends is increased over
+  // time using an exponential backoff algorithm.
+  //
+  // Sends ``message`` to the ``RpcAgent`` of id ``to`` and returns a
+  // ``FutureMessage`` ptr, just like send(). Caller can specify the maximum
+  // number of retries for this RPC (default is 3), initial duration between
+  // sends (default is 1000ms), and backoff constant (default is 1.5) by
+  // passing in the RpcRetryOptions struct. This API might end up
+  // executing a method twice on the remote end (it does not guarantee
+  // exactly-once semantics). Therefore, the user must ensure their requests
+  // are idempotent.
+  std::shared_ptr<FutureMessage> sendWithRetries(
+      const WorkerInfo& to,
+      Message&& message,
+      RpcRetryOptions retryOptions = RpcRetryOptions());
+
   // Return a reference to the ``WorkerInfo`` of this RpcAgent.
   // NB: not using ``c10::optional<const std::string&>`` here because we might
   // need to create a separate RPC API lib and avoid forcing all ``RpcAgent``
@@ -130,7 +188,7 @@
 
   // Stop accepting requests and shutdown the RPC framework as soon as possible
   // by terminating all RPC threads.
-  virtual void shutdown() = 0;
+  void shutdown();
 
   // Check if current RPC agent is set.
   static bool isCurrentRpcAgentSet();
@@ -160,11 +218,70 @@
   std::atomic<std::chrono::milliseconds> rpcTimeout_;
   std::atomic<bool> profilingEnabled_;
 
+  // Pure virtual function for shutting down the RPC framework.
+  virtual void shutdownImpl() = 0;
+
  private:
   static std::shared_ptr<RpcAgent> currentRpcAgent_;
   // Add GIL wait time data point to metrics
   virtual void addGilWaitTime(const std::chrono::microseconds gilWaitTime) = 0;
   friend class PythonRpcHandler;
+
+  // Map that stores metadata for RPCs that may need to be retried, as well as
+  // the timepoint at which we should retry them.
+  std::map<
+      steady_clock_time_point,
+      std::unordered_set<std::shared_ptr<RpcRetryInfo>>>
+      rpcRetryMap_;
+
+  // Thread that checks for retryable RPCs in the rpcRetryMap_ and sleeps until
+  // the next unACKed RPC's timeout has expired.
+  std::thread rpcRetryThread_;
+
+  // Function that rpcRetryThread_ calls in a loop as long as RpcAgent is
+  // running.
+  void retryExpiredRpcs();
+
+  // This is the callback attached to futures corresponding to send retries.
+  // This handles 3 cases: (1) the send completed, (2) the send failed with an
+  // error and we've done maxRetries failed send attempts, and (3) the send
+  // failed with an error and we have more retries to go. In case 1, we mark
+  // the original future as complete. In case 2, we mark the future with an
+  // error and do not retry again. In case 3, we move the RpcRetryInfo struct
+  // to another time point in the map to schedule the RPC for a future send.
+  void rpcRetryCallback(
+      const rpc::Message& message,
+      const c10::optional<utils::FutureError>& futErr,
+      steady_clock_time_point newTime,
+      std::shared_ptr<RpcRetryInfo> earliestRpc);
+
+  // Function that uses the exponential backoff algorithm to compute the next
+  // time point to retry a given RPC.
+  inline steady_clock_time_point computeNewRpcRetryTime(
+      RpcRetryOptions& options,
+      int retryCount) {
+    // The exponential backoff algorithm being used here is:
+    // newTime = timeNow + (retryDuration * (backoffConstant ^ retryCount)).
+    std::chrono::milliseconds timedelta =
+        std::chrono::duration_cast<std::chrono::milliseconds>(
+            options.rpcRetryDuration * pow(options.retryBackoff, retryCount));
+    return std::chrono::time_point_cast<std::chrono::milliseconds>(
+        std::chrono::steady_clock::now() + timedelta);
+  }
+
+  // Boolean that indicates whether RpcAgent is running.
+  std::atomic<bool> rpcAgentRunning_;
+
+  // Stores futures for retried sends until their callbacks are attached.
+  std::vector<
+      std::pair<std::shared_ptr<FutureMessage>, std::shared_ptr<RpcRetryInfo>>>
+      futures;
+
+  // Condition variable to signal when the rpcRetryMap_ has been populated.
+  std::condition_variable rpcRetryMapCV_;
+
+  // Mutex to protect rpcRetryMap_.
+  std::mutex rpcRetryMutex_;
 };
 
 } // namespace rpc