Lift rpc_timeout to RpcAgent, for other RpcAgents to reuse. (#29341)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/29341
So that other RpcAgents could use this timeout setting as well.
ghstack-source-id: 93481902
Differential Revision: D5681951
fbshipit-source-id: 569c768dc342e8a2d9faf142ceccf696e12e41dc
diff --git a/test/rpc_test.py b/test/rpc_test.py
index 18f0518..456067f 100644
--- a/test/rpc_test.py
+++ b/test/rpc_test.py
@@ -962,13 +962,11 @@
self.assertEqual(result, sum(vals))
- @requires_process_group_agent("PROCESS_GROUP rpc backend specific test, skip")
@dist_init
def test_get_default_rpc_timeout(self):
timeout = rpc.get_rpc_timeout()
self.assertEqual(timeout, rpc.constants.DEFAULT_RPC_TIMEOUT)
- @requires_process_group_agent("PROCESS_GROUP rpc backend specific test, skip")
@dist_init(setup_model_parallel=False)
def test_set_rpc_timeout(self):
timeout = timedelta(seconds=1)
diff --git a/torch/csrc/distributed/rpc/init.cpp b/torch/csrc/distributed/rpc/init.cpp
index c2771a5..2629c36 100644
--- a/torch/csrc/distributed/rpc/init.cpp
+++ b/torch/csrc/distributed/rpc/init.cpp
@@ -42,8 +42,10 @@
.def(
"join", &RpcAgent::join, py::call_guard<py::gil_scoped_release>())
.def(
- "sync",
- &RpcAgent::sync,
+ "sync", &RpcAgent::sync, py::call_guard<py::gil_scoped_release>())
+ .def(
+ "_get_rpc_timeout",
+ &RpcAgent::getRpcTimeout,
py::call_guard<py::gil_scoped_release>());
auto pyFuture = shared_ptr_class_<PyFuture>(module, "Future")
@@ -118,10 +120,6 @@
.def(
"sync",
&ProcessGroupAgent::sync,
- py::call_guard<py::gil_scoped_release>())
- .def(
- "_get_rpc_timeout",
- &ProcessGroupAgent::getRpcTimeout,
py::call_guard<py::gil_scoped_release>());
module.def("_start_rpc_agent", [](const std::shared_ptr<RpcAgent>& agent) {
diff --git a/torch/csrc/distributed/rpc/process_group_agent.cpp b/torch/csrc/distributed/rpc/process_group_agent.cpp
index 6f7b6e3..75fb9b3 100644
--- a/torch/csrc/distributed/rpc/process_group_agent.cpp
+++ b/torch/csrc/distributed/rpc/process_group_agent.cpp
@@ -122,14 +122,14 @@
std::chrono::milliseconds rpcTimeout)
: RpcAgent(
WorkerInfo(std::move(workerName), pg->getRank()),
- c10::guts::make_unique<RequestCallbackImpl>()),
+ c10::guts::make_unique<RequestCallbackImpl>(),
+ rpcTimeout),
pg_(std::move(pg)),
sendCounts_(pg_->getSize()),
recvCounts_(pg_->getSize()),
nextId_(0),
sendMutexes_(pg_->getSize()),
- threadPool_(numSendRecvThreads),
- rpcTimeout_(rpcTimeout) {
+ threadPool_(numSendRecvThreads) {
collectNames();
TORCH_CHECK(
nameMap_.size() > 1,
@@ -175,10 +175,6 @@
return allWorkerInfo_[id];
}
-const std::chrono::milliseconds& ProcessGroupAgent::getRpcTimeout() const {
- return rpcTimeout_;
-}
-
void ProcessGroupAgent::join() {
// Every process i sends a SHUTDOWN message to process i + 1. This is
// necessary for now because:
diff --git a/torch/csrc/distributed/rpc/process_group_agent.h b/torch/csrc/distributed/rpc/process_group_agent.h
index a55f4ac..ad4c70e 100644
--- a/torch/csrc/distributed/rpc/process_group_agent.h
+++ b/torch/csrc/distributed/rpc/process_group_agent.h
@@ -51,9 +51,6 @@
void start() override;
- // retrieves the timeout for all RPCs
- const std::chrono::milliseconds& getRpcTimeout() const;
-
protected:
// This method wraps the destination information and the message into a
// SendWork object, and put the SendWork into a queue. Another thread will
@@ -148,7 +145,6 @@
std::map<std::chrono::milliseconds, std::vector<int64_t>> futureTimeouts_;
mutable std::mutex futureMutex_;
mutable std::condition_variable futureCV_;
- std::chrono::milliseconds rpcTimeout_;
};
} // namespace rpc
diff --git a/torch/csrc/distributed/rpc/rpc_agent.cpp b/torch/csrc/distributed/rpc/rpc_agent.cpp
index 97cbabb..bf2e427 100644
--- a/torch/csrc/distributed/rpc/rpc_agent.cpp
+++ b/torch/csrc/distributed/rpc/rpc_agent.cpp
@@ -6,8 +6,13 @@
constexpr size_t WorkerInfo::MAX_NAME_LEN;
-RpcAgent::RpcAgent(WorkerInfo workerId, std::unique_ptr<RequestCallback> cb)
- : workerInfo_(std::move(workerId)), cb_(std::move(cb)) {}
+RpcAgent::RpcAgent(
+ WorkerInfo workerId,
+ std::unique_ptr<RequestCallback> cb,
+ std::chrono::milliseconds rpcTimeout)
+ : workerInfo_(std::move(workerId)),
+ cb_(std::move(cb)),
+ rpcTimeout_(rpcTimeout) {}
RpcAgent::~RpcAgent() = default;
diff --git a/torch/csrc/distributed/rpc/rpc_agent.h b/torch/csrc/distributed/rpc/rpc_agent.h
index 7d62f27..a53bd59 100644
--- a/torch/csrc/distributed/rpc/rpc_agent.h
+++ b/torch/csrc/distributed/rpc/rpc_agent.h
@@ -66,7 +66,10 @@
// NB: RpcAgent implementations should not start serving requests until
// ``start()`` is called, as there could be other contexts that have not been
// initialized yet at this time.
- RpcAgent(WorkerInfo id, std::unique_ptr<RequestCallback> cb);
+ RpcAgent(
+ WorkerInfo id,
+ std::unique_ptr<RequestCallback> cb,
+ std::chrono::milliseconds rpcTimeout);
virtual ~RpcAgent();
@@ -93,6 +96,11 @@
virtual const WorkerInfo& getWorkerInfo(worker_id_t id) const = 0;
+ // Retrieve the timeout for all RPCs.
+ inline const std::chrono::milliseconds& getRpcTimeout() const {
+ return rpcTimeout_;
+ }
+
// Call sync and join all internal threads. This method should be called
// before every RPC process exits.
virtual void join() = 0;
@@ -114,6 +122,7 @@
const WorkerInfo workerInfo_;
const std::string workerName_;
const std::unique_ptr<RequestCallback> cb_;
+ const std::chrono::milliseconds rpcTimeout_;
private:
static std::shared_ptr<RpcAgent> defaultRpcAgent_;
diff --git a/torch/distributed/rpc/__init__.py b/torch/distributed/rpc/__init__.py
index 63a022b..b5d9dd1 100644
--- a/torch/distributed/rpc/__init__.py
+++ b/torch/distributed/rpc/__init__.py
@@ -54,6 +54,7 @@
init_method(str): backend specific init arguments.
num_send_recv_threads(int): Number of threads for send/recv work.
rpc_timeout (datetime.timedelta): Timeout for RPCs. Defaults to 10 seconds.
+ 0 means infinity.
"""
# Rendezvous.
world_size = len(worker_name_to_id)
diff --git a/torch/distributed/rpc/api.py b/torch/distributed/rpc/api.py
index e3fa15e..4c3714d 100644
--- a/torch/distributed/rpc/api.py
+++ b/torch/distributed/rpc/api.py
@@ -7,6 +7,7 @@
from .constants import DEFAULT_RPC_TIMEOUT, DEFAULT_NUM_SEND_RECV_THREADS
from .internal import _internal_rpc_pickler, PythonUDF
+import datetime
import functools
import sys
import torch
@@ -78,6 +79,11 @@
raise RuntimeError("RPC is already initialized")
# Initialize RPC.
+ if not isinstance(rpc_timeout, datetime.timedelta):
+ raise RuntimeError(
+ "`rpc_timeout` must be a `datetime.timedelta`."
+ )
+
_agent = backend_registry.init_backend(
backend,
store=store,
diff --git a/torch/distributed/rpc/backend_registry.py b/torch/distributed/rpc/backend_registry.py
index c762f2f..ea8a741 100644
--- a/torch/distributed/rpc/backend_registry.py
+++ b/torch/distributed/rpc/backend_registry.py
@@ -45,6 +45,7 @@
self_rank,
worker_name_to_id,
num_send_recv_threads,
+ rpc_timeout,
*args,
**kwargs
):
@@ -78,7 +79,12 @@
)
)
# TODO: add try-except and destroy _agent in all processes if any fails.
- return ProcessGroupAgent(self_name, group, num_send_recv_threads, kwargs["rpc_timeout"])
+ return ProcessGroupAgent(
+ self_name,
+ group,
+ num_send_recv_threads,
+ rpc_timeout,
+ )
except Exception as ex:
dist.destroy_process_group()
raise ex