[RPC Reliability] Implemented retries for RPCs with exponential backoff (#32602)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/32602
This adds functionality for re-trying RPC's that are sent with the function `sendWithRetries()`. It adds RPC's that will potentially need to be retried to a sorted map that contains the timeout at which to retry the RPC and associated metadata. A separate thread iteratively removes the earliest retry-able RPC from the map, sleeps until the corresponding time point, re-tries the RPC, and adds to the map again with a future timeout.
GitHub Issue: https://github.com/pytorch/pytorch/issues/32124
Per the first 3 milestones, the following will be addressed in future PR's:
* enabling RPC Retries for RRef internal messages
Differential Revision: D19560159
fbshipit-source-id: 40cd86f9a25dc24367624d279a3b9720b20824cf
diff --git a/torch/csrc/distributed/rpc/process_group_agent.cpp b/torch/csrc/distributed/rpc/process_group_agent.cpp
index 9d2745e..d34f983 100644
--- a/torch/csrc/distributed/rpc/process_group_agent.cpp
+++ b/torch/csrc/distributed/rpc/process_group_agent.cpp
@@ -241,7 +241,7 @@
std::thread(&ProcessGroupAgent::pollTimedOutRPCs, this);
}
-void ProcessGroupAgent::shutdown() {
+void ProcessGroupAgent::shutdownImpl() {
LOG(INFO) << "Shutting down ProcessGroupAgent on rank " << pg_->getRank()
<< ".";
std::unique_lock<std::mutex> lock{futureMutex_};
@@ -653,8 +653,10 @@
<< " milliseconds and timed out.";
const auto exceptionMsg = createExceptionResponse(
Message({}, {}, MessageType::EXCEPTION), ss.str());
- timedOutFuture.future_->setError(std::string(
- exceptionMsg.payload().begin(), exceptionMsg.payload().end()));
+ if (!timedOutFuture.future_->hasError()) {
+ timedOutFuture.future_->setError(std::string(
+ exceptionMsg.payload().begin(), exceptionMsg.payload().end()));
+ }
const int dst = timedOutFuture.dstRank_;
recvCounts_.increment(dst);
diff --git a/torch/csrc/distributed/rpc/process_group_agent.h b/torch/csrc/distributed/rpc/process_group_agent.h
index 1d910b3..49f89a2 100644
--- a/torch/csrc/distributed/rpc/process_group_agent.h
+++ b/torch/csrc/distributed/rpc/process_group_agent.h
@@ -63,7 +63,7 @@
void start() override;
- void shutdown() override;
+ void shutdownImpl() override;
~ProcessGroupAgent() override;
diff --git a/torch/csrc/distributed/rpc/rpc_agent.cpp b/torch/csrc/distributed/rpc/rpc_agent.cpp
index 6f2f137..31e3ada 100644
--- a/torch/csrc/distributed/rpc/rpc_agent.cpp
+++ b/torch/csrc/distributed/rpc/rpc_agent.cpp
@@ -6,6 +6,12 @@
constexpr size_t WorkerInfo::MAX_NAME_LEN;
+// Large Time Duration for waiting on the condition variable until the map is
+// population. Cannot use
+// std::chrono::time_point<std::chrono::steady_clock>::max() due to a known
+// overflow-related bug.
+constexpr auto kLargeTimeDuration = std::chrono::hours(10000);
+
RpcAgent::RpcAgent(
WorkerInfo workerId,
std::unique_ptr<RequestCallback> cb,
@@ -13,9 +19,170 @@
: workerInfo_(std::move(workerId)),
cb_(std::move(cb)),
rpcTimeout_(rpcTimeout),
- profilingEnabled_(false) {}
+ profilingEnabled_(false),
+ rpcAgentRunning_(true) {
+ rpcRetryThread_ = std::thread(&RpcAgent::retryExpiredRpcs, this);
+}
-RpcAgent::~RpcAgent() = default;
+RpcAgent::~RpcAgent() {
+ shutdown();
+}
+
+void RpcAgent::shutdown() {
+ if (!rpcAgentRunning_.exchange(false)) {
+ return;
+ }
+ rpcRetryMapCV_.notify_one();
+ rpcRetryThread_.join();
+ shutdownImpl();
+}
+
+std::shared_ptr<FutureMessage> RpcAgent::sendWithRetries(
+ const WorkerInfo& to,
+ Message&& message,
+ RpcRetryOptions retryOptions) {
+ TORCH_CHECK(retryOptions.maxRetries >= 0, "maxRetries cannot be negative.");
+ TORCH_CHECK(
+ retryOptions.retryBackoff >= 1,
+ "maxRetries cannot be exponentially decaying.");
+ TORCH_CHECK(
+ retryOptions.rpcRetryDuration.count() >= 0,
+ "rpcRetryDuration cannot be negative.");
+
+ auto originalFuture = std::make_shared<FutureMessage>();
+ steady_clock_time_point newTime =
+ computeNewRpcRetryTime(retryOptions, /* retryCount */ 0);
+ // Making a copy of the message so it can be retried after the first send.
+ Message msgCopy = message;
+ auto fm = send(to, std::move(message));
+ auto firstRetryRpc = std::make_shared<RpcRetryInfo>(
+ to,
+ std::move(msgCopy),
+ originalFuture,
+ /* retryCount */ 0,
+ retryOptions);
+
+ fm->addCallback([this, newTime, firstRetryRpc](
+ const rpc::Message& lambdaMessage,
+ const c10::optional<utils::FutureError>& futErr) {
+ rpcRetryCallback(lambdaMessage, futErr, newTime, firstRetryRpc);
+ });
+
+ return originalFuture;
+}
+
+void RpcAgent::retryExpiredRpcs() {
+ while (rpcAgentRunning_.load()) {
+ std::unique_lock<std::mutex> lock(rpcRetryMutex_);
+
+ // We must continue sleeping as long as the RPC Agent is running and when
+ // either the Retry Map is empty, or when the Retry Map's earliest expiring
+ // RPC is set to be retried in the future.
+ steady_clock_time_point earliestTimeout =
+ std::chrono::steady_clock::now() + kLargeTimeDuration;
+
+ for (;;) {
+ if (!rpcAgentRunning_.load())
+ return;
+ if (std::chrono::steady_clock::now() >= earliestTimeout)
+ break;
+ if (!rpcRetryMap_.empty()) {
+ earliestTimeout = rpcRetryMap_.begin()->first;
+ }
+ rpcRetryMapCV_.wait_until(lock, earliestTimeout);
+ }
+
+ // Updating these since something may have been added to the map while this
+ // thread was sleeping.
+ earliestTimeout = rpcRetryMap_.begin()->first;
+ auto& earliestRpcList = rpcRetryMap_.begin()->second;
+
+ // We iterate through all the RPC's set to be retried at the current
+ // timepoint, resend those RPC's, and add the RPC's and their futures to
+ // a list to later attach callbacks. These callbacks either schedule
+ // the RPC for a future retry or marks it with success/error depending on
+ // the outcome of the current send. Then, we clean up the rpcRetryMap_.
+ for (auto it = earliestRpcList.begin(); it != earliestRpcList.end();
+ /* no increment */) {
+ auto& earliestRpc = *it;
+ // Making a copy of the message so it can be retried in the future.
+ Message msgCopy = earliestRpc->message_;
+ auto fm = send(earliestRpc->to_, std::move(msgCopy));
+ futures.emplace_back(fm, earliestRpc);
+
+ // A callback will be attached to all futures for the retries in this
+ // list. Thus they will either be rescheduled for future retries or they
+ // will be marked as complete. We can safely delete them from the retry
+ // Map for the current timepoint.
+ it = earliestRpcList.erase(it);
+ }
+
+ lock.unlock();
+ // We attach callbacks to the futures outside of the lock to prevent
+ // potential deadlocks.
+ for (const auto& it : futures) {
+ auto fm = it.first;
+ auto earliestRpc = it.second;
+ steady_clock_time_point newTime = computeNewRpcRetryTime(
+ earliestRpc->options_, earliestRpc->retryCount_);
+ earliestRpc->retryCount_++;
+
+ fm->addCallback([this, newTime, earliestRpc](
+ const rpc::Message& message,
+ const c10::optional<utils::FutureError>& futErr) {
+ rpcRetryCallback(message, futErr, newTime, earliestRpc);
+ });
+ }
+
+ // If there are no more RPC's set to be retried at the current timepoint,
+ // we can remove the corresponsing unordered_set from the retry map. We
+ // must also clear the futures vector.
+ {
+ std::lock_guard<std::mutex> retryMapLock(rpcRetryMutex_);
+ futures.clear();
+ if (earliestRpcList.empty()) {
+ rpcRetryMap_.erase(earliestTimeout);
+ }
+ }
+ }
+}
+
+void RpcAgent::rpcRetryCallback(
+ const rpc::Message& message,
+ const c10::optional<utils::FutureError>& futErr,
+ steady_clock_time_point newTime,
+ std::shared_ptr<RpcRetryInfo> earliestRpc) {
+ if (futErr) {
+ // Adding one since we want to include the original send as well and not
+ // just the retry count.
+ LOG(INFO) << "Send try " << std::to_string(earliestRpc->retryCount_ + 1)
+ << " failed";
+ if (earliestRpc->retryCount_ < earliestRpc->options_.maxRetries) {
+ // If the previous future completed with an error and we haven't
+ // completed maxRetries send attempts, we move the earliestRpc
+ // struct to a new time point in the retry map (effectively
+ // scheduling it for a future retry.)
+ {
+ std::lock_guard<std::mutex> retryMapLock(rpcRetryMutex_);
+ rpcRetryMap_[newTime].emplace(std::move(earliestRpc));
+ }
+ // The retry thread waits for the map to be populated. Thus we notify
+ // once an item has been added.
+ rpcRetryMapCV_.notify_one();
+ } else {
+ // We have completed maxRetries send attempts. We're now marking
+ // the future with an error.
+ std::string errorMessage = c10::str(
+ "The RPC has not succeeded after the specified number of max retries (",
+ earliestRpc->options_.maxRetries,
+ ").");
+ earliestRpc->originalFuture_->setError(errorMessage);
+ }
+ } else {
+ // This try succeeded, so we can make the original future as complete.
+ earliestRpc->originalFuture_->markCompleted(message);
+ }
+}
const WorkerInfo& RpcAgent::getWorkerInfo() const {
return workerInfo_;
diff --git a/torch/csrc/distributed/rpc/rpc_agent.h b/torch/csrc/distributed/rpc/rpc_agent.h
index aa7d853..d37e5bf 100644
--- a/torch/csrc/distributed/rpc/rpc_agent.h
+++ b/torch/csrc/distributed/rpc/rpc_agent.h
@@ -11,6 +11,9 @@
namespace distributed {
namespace rpc {
+using steady_clock_time_point =
+ std::chrono::time_point<std::chrono::steady_clock>;
+
struct RpcBackendOptions {
RpcBackendOptions() = default;
std::chrono::milliseconds rpcTimeout;
@@ -55,6 +58,44 @@
const worker_id_t id_;
};
+// Struct for options to configure the RPC Retry protocol.
+struct TORCH_API RpcRetryOptions {
+ // Using a default constructor like all other Options structs in the RPC
+ // codebase. TORCH_CHECKs for input validation are done in the
+ // sendWithRetries function.
+ RpcRetryOptions() = default;
+ // Maximum number of times we will retry the RPC
+ int maxRetries{3};
+ // Initial duration between consecutive RPC send attempts
+ std::chrono::milliseconds rpcRetryDuration{std::chrono::milliseconds(1000)};
+ // Constant for exponential backoff used while calculating future wait
+ // durations
+ float retryBackoff{1.5};
+};
+
+// Struct that stores all the metadata needed to retry a given RPC.
+struct TORCH_API RpcRetryInfo {
+ RpcRetryInfo(
+ const WorkerInfo& to,
+ Message&& message,
+ std::shared_ptr<FutureMessage> originalFuture,
+ int retryCount,
+ RpcRetryOptions options)
+ : to_(to),
+ message_(message),
+ originalFuture_(std::move(originalFuture)),
+ retryCount_(retryCount),
+ options_(options) {}
+
+ const WorkerInfo& to_;
+ Message message_;
+ // Future that is returned to the caller of sendWithRetries().
+ std::shared_ptr<FutureMessage> originalFuture_;
+ // Number of send attempts completed so far.
+ int retryCount_;
+ RpcRetryOptions options_;
+};
+
// ``RpcAgent`` is the base class for sending and receiving RPC messages. It
// provides a unified ``send`` API for both request and response messages, and
// will invoke the given ``RequestCallback`` to process received requests. It
@@ -93,6 +134,23 @@
const WorkerInfo& to,
Message&& message) = 0;
+ // Retries sending the message up to maxRetries times until an ACK is
+ // receieved. The duration between consecutive sends is increased over
+ // time using an exponential backoff algorithm.
+ //
+ // Sends ``message`` to the ``RpcAgent`` of id ``to`` and returns a
+ // ``FutureMessage`` ptr, just like send(). Caller can specify the maximum
+ // number of retries for this RPC (default is 3), initial duration between
+ // sends (default is 1000ms), and backoff constant (default is 1.5) by
+ // passing in the RpcRetryOptions struct. This API might end up
+ // executing a method twice on the remote end (it does not guarantee
+ // exactly-once semantics). Therefore, the user must ensure their requests
+ // are idempotent.
+ std::shared_ptr<FutureMessage> sendWithRetries(
+ const WorkerInfo& to,
+ Message&& message,
+ RpcRetryOptions retryOptions = RpcRetryOptions());
+
// Return a reference to the ``WorkerInfo`` of this RpcAgent.
// NB: not using ``c10::optional<const std::string&>`` here because we might
// need to create a separate RPC API lib and avoid forcing all ``RpcAgent``
@@ -130,7 +188,7 @@
// Stop accepting requests and shutdown the RPC framework as soon as possible
// by terminating all RPC threads.
- virtual void shutdown() = 0;
+ void shutdown();
// Check if current RPC agent is set.
static bool isCurrentRpcAgentSet();
@@ -160,11 +218,70 @@
std::atomic<std::chrono::milliseconds> rpcTimeout_;
std::atomic<bool> profilingEnabled_;
+ // Pure virtual function for shutting down the RPC framework.
+ virtual void shutdownImpl() = 0;
+
private:
static std::shared_ptr<RpcAgent> currentRpcAgent_;
// Add GIL wait time data point to metrics
virtual void addGilWaitTime(const std::chrono::microseconds gilWaitTime) = 0;
friend class PythonRpcHandler;
+
+ // Map that stores metadata for RPC's that may need to be re-tried as well as
+ // the timepoint at which we should re-try them.
+ std::map<
+ steady_clock_time_point,
+ std::unordered_set<std::shared_ptr<RpcRetryInfo>>>
+ rpcRetryMap_;
+
+ // Thread that checks for retryable RPC's in the rpcRetryMap_ and sleeps until
+ // the next unACKed RPC's timeout has expired.
+ std::thread rpcRetryThread_;
+
+ // Function that rpcRetryThread_ calls in a loop as long as RpcAgent is
+ // running.
+ void retryExpiredRpcs();
+
+ // This is the callback attached to futures corresponding to send retries.
+ // This handles 3 cases: 1). send was completed, 2). send failed with an
+ // error and we've done maxRetries failed send attempts, and 3). send
+ // failed with an error and we have more retries to go. In case 1, we mark
+ // the original future as complete. In case 2, we mark the future with an
+ // error and do not retry again. In case 3, we move the RpcRetryInfo struct
+ // to another time point in the map to schedule the RPC for a future send.
+ void rpcRetryCallback(
+ const rpc::Message& message,
+ const c10::optional<utils::FutureError>& futErr,
+ steady_clock_time_point newTime,
+ std::shared_ptr<RpcRetryInfo> earliestRpc);
+
+ // Function that uses the exponential backoff algorithm to compute the next
+ // time point to retry a given RPC.
+ inline steady_clock_time_point computeNewRpcRetryTime(
+ RpcRetryOptions& options,
+ int retryCount) {
+ // The exponential backoff algorithm being used here is:
+ // newTime = timeNow + (retryDuration * (backoffConstant ^ retryCount)).
+ std::chrono::milliseconds timedelta =
+ std::chrono::duration_cast<std::chrono::milliseconds>(
+ options.rpcRetryDuration * pow(options.retryBackoff, retryCount));
+ return std::chrono::time_point_cast<std::chrono::milliseconds>(
+ std::chrono::steady_clock::now() + timedelta);
+ }
+
+ // Boolean that indicates whether RpcAgent is running.
+ std::atomic<bool> rpcAgentRunning_;
+
+ // storing futures before adding callback
+ std::vector<
+ std::pair<std::shared_ptr<FutureMessage>, std::shared_ptr<RpcRetryInfo>>>
+ futures;
+
+ // Condition Variable to signal when the rpcRetryMap_ has been populated.
+ std::condition_variable rpcRetryMapCV_;
+
+ // Mutex to protect RpcRetryMap_.
+ std::mutex rpcRetryMutex_;
};
} // namespace rpc