Improve ProcessGroup RpcBackendOptions Constructor API (#34081)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34081
Before this commit, applications had to do the following to configure the
number of threads in the ProcessGroup RPC backend:
```
op = ProcessGroupRpcBackendOptions()
op.rpc_timeout = rpc_timeout
op.init_method = init_method
op.num_send_recv_threads = 32
init_rpc(...., rpc_backend_options=op)
```
After this commit, the same configuration can be simplified to:
```
init_rpc(...., rpc_backend_options=ProcessGroupRpcBackendOptions(num_send_recv_threads=32))
```
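For reference, the new constructor also accepts rpc_timeout and init_method as keyword arguments. Below is a minimal sketch of a full configuration; the worker name, rank, world size, and the specific option values are placeholders, and it assumes MASTER_ADDR/MASTER_PORT are set in the environment:
```
import datetime
from torch.distributed import rpc

# All three options in one constructor call (values are illustrative only).
options = rpc.ProcessGroupRpcBackendOptions(
    num_send_recv_threads=32,                    # thread-pool size for ProcessGroupAgent
    rpc_timeout=datetime.timedelta(seconds=60),  # per-RPC timeout
    init_method="env://",                        # URL used to initialize the process group
)
rpc.init_rpc("worker0", rank=0, world_size=2, rpc_backend_options=options)
```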
Fixes #34075
Test Plan: Imported from OSS
Differential Revision: D20227344
Pulled By: mrshenli
fbshipit-source-id: def4318e987179b8c8ecca44d7ff935702c8a6e7
diff --git a/docs/source/rpc.rst b/docs/source/rpc.rst
index 6b78062..5abe798 100644
--- a/docs/source/rpc.rst
+++ b/docs/source/rpc.rst
@@ -101,9 +101,9 @@
.. autofunction:: shutdown
.. autoclass:: WorkerInfo
:members:
-.. autoclass:: RpcBackendOptions
+.. autoclass:: ProcessGroupRpcBackendOptions
:members:
-
+ :inherited-members:
.. _rref:
diff --git a/torch/csrc/distributed/rpc/init.cpp b/torch/csrc/distributed/rpc/init.cpp
index 1fa5b1e..3303b0e 100644
--- a/torch/csrc/distributed/rpc/init.cpp
+++ b/torch/csrc/distributed/rpc/init.cpp
@@ -39,10 +39,11 @@
shared_ptr_class_<RpcBackendOptions>(
module,
"RpcBackendOptions",
- R"(A structure encapsulating the options passed into the RPC backend.
- An instance of this class can be passed in to :meth:`~torch.distributed.rpc.init_rpc`
- in order to initialize RPC with specific configurations, such as the
- RPC timeout and init_method to be used. )")
+ R"(An abstract structure encapsulating the options passed into the RPC
+ backend. An instance of this class can be passed in to
+ :meth:`~torch.distributed.rpc.init_rpc` in order to initialize RPC
+ with specific configurations, such as the RPC timeout and
+ ``init_method`` to be used. )")
.def_readwrite(
"rpc_timeout",
&RpcBackendOptions::rpcTimeout,
@@ -53,7 +54,10 @@
"init_method",
&RpcBackendOptions::initMethod,
R"(URL specifying how to initialize the process group.
- Default is env://)");
+ Default is ``env://``)");
+
+ module.attr("_DEFAULT_RPC_TIMEOUT") = py::cast(kDefaultRpcTimeout);
+ module.attr("_DEFAULT_INIT_METHOD") = py::cast(kDefaultInitMethod);
auto workerInfo =
shared_ptr_class_<WorkerInfo>(
@@ -207,11 +211,54 @@
)");
shared_ptr_class_<ProcessGroupRpcBackendOptions>(
- module, "ProcessGroupRpcBackendOptions", rpcBackendOptions)
- .def(py::init<>())
+ module,
+ "ProcessGroupRpcBackendOptions",
+ rpcBackendOptions,
+ R"(
+ The backend options class for ``ProcessGroupAgent``, which is derived
+ from ``RpcBackendOptions``.
+
+ Arguments:
+ num_send_recv_threads (int, optional): The number of threads in
+ the thread-pool used by ``ProcessGroupAgent`` (default: 4).
+ rpc_timeout (datetime.timedelta, optional): The timeout for RPC
+ requests (default: ``timedelta(seconds=60)``).
+ init_method (str, optional): The URL to initialize
+ ``ProcessGroupGloo`` (default: ``env://``).
+
+
+ Example::
+ >>> import datetime, os
+ >>> from torch.distributed import rpc
+ >>> os.environ['MASTER_ADDR'] = 'localhost'
+ >>> os.environ['MASTER_PORT'] = '29500'
+ >>>
+ >>> rpc.init_rpc(
+ >>> "worker1",
+ >>> rank=0,
+ >>> world_size=2,
+ >>> rpc_backend_options=rpc.ProcessGroupRpcBackendOptions(
+ >>> num_send_recv_threads=16,
+ >>> rpc_timeout=datetime.timedelta(seconds=20)
+ >>> )
+ >>> )
+ >>>
+ >>> # omitting init_rpc invocation on worker2
+ )")
+ .def(
+ py::init<int, std::chrono::milliseconds, std::string>(),
+ py::arg("num_send_recv_threads") = kDefaultNumSendRecvThreads,
+ py::arg("rpc_timeout") = kDefaultRpcTimeout,
+ py::arg("init_method") = kDefaultInitMethod)
.def_readwrite(
"num_send_recv_threads",
- &ProcessGroupRpcBackendOptions::numSendRecvThreads);
+ &ProcessGroupRpcBackendOptions::numSendRecvThreads,
+ R"(
+ The number of threads in the thread-pool used by ``ProcessGroupAgent``.
+ )");
+
+ module.attr("_DEFAULT_NUM_SEND_RECV_THREADS") =
+ py::cast(kDefaultNumSendRecvThreads);
shared_ptr_class_<ProcessGroupAgent>(module, "ProcessGroupAgent", rpcAgent)
.def(
diff --git a/torch/csrc/distributed/rpc/process_group_agent.h b/torch/csrc/distributed/rpc/process_group_agent.h
index 1d910b3..141ff4a 100644
--- a/torch/csrc/distributed/rpc/process_group_agent.h
+++ b/torch/csrc/distributed/rpc/process_group_agent.h
@@ -12,8 +12,22 @@
namespace distributed {
namespace rpc {
+constexpr auto kDefaultNumSendRecvThreads = 4;
+
struct ProcessGroupRpcBackendOptions : public RpcBackendOptions {
- ProcessGroupRpcBackendOptions() = default;
+ ProcessGroupRpcBackendOptions(
+ int num_send_recv_threads,
+ std::chrono::milliseconds rpc_timeout,
+ std::string init_method)
+ : RpcBackendOptions(rpc_timeout, init_method),
+ numSendRecvThreads(num_send_recv_threads) {
+ TORCH_CHECK(
+ num_send_recv_threads > 0,
+ "Cannot create ProcessGroup RPC backend with ",
+ num_send_recv_threads,
+ " threads in the thread-pool.");
+ }
+
int numSendRecvThreads;
};
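With the constructor above, a non-positive thread count is rejected at construction time. A minimal sketch of what that looks like from the Python side (assuming a build with the RPC module available; the message text comes from the TORCH_CHECK above):
```
import torch.distributed.rpc as rpc

try:
    # 0 threads is invalid; the TORCH_CHECK in the constructor fires.
    rpc.ProcessGroupRpcBackendOptions(num_send_recv_threads=0)
except RuntimeError as e:
    # Expected to mention: "Cannot create ProcessGroup RPC backend with 0 threads in the thread-pool."
    print(e)
```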
diff --git a/torch/csrc/distributed/rpc/rpc_agent.h b/torch/csrc/distributed/rpc/rpc_agent.h
index 07d660e..b02bb3d 100644
--- a/torch/csrc/distributed/rpc/rpc_agent.h
+++ b/torch/csrc/distributed/rpc/rpc_agent.h
@@ -11,6 +11,9 @@
namespace distributed {
namespace rpc {
+constexpr auto kDefaultRpcTimeout = std::chrono::seconds(60);
+constexpr auto kDefaultInitMethod = "env://";
+
using steady_clock_time_point =
std::chrono::time_point<std::chrono::steady_clock>;
// Input is qualified name string, output is JIT StrongTypePtr
@@ -20,7 +23,14 @@
std::function<c10::StrongTypePtr(const c10::QualifiedName&)>;
struct RpcBackendOptions {
- RpcBackendOptions() = default;
+ RpcBackendOptions()
+ : RpcBackendOptions(kDefaultRpcTimeout, kDefaultInitMethod) {}
+
+ RpcBackendOptions(
+ std::chrono::milliseconds rpcTimeout,
+ std::string initMethod)
+ : rpcTimeout(rpcTimeout), initMethod(initMethod) {}
+
std::chrono::milliseconds rpcTimeout;
std::string initMethod;
};
diff --git a/torch/distributed/rpc/__init__.py b/torch/distributed/rpc/__init__.py
index 7b76395..75b8d4b 100644
--- a/torch/distributed/rpc/__init__.py
+++ b/torch/distributed/rpc/__init__.py
@@ -6,8 +6,6 @@
import torch
import torch.distributed as dist
-from . import backend_registry
-
def is_available():
return sys.version_info >= (3, 0) and hasattr(torch._C, "_rpc_init")
@@ -18,7 +16,7 @@
if is_available():
- from . import api
+ from . import api, backend_registry
from .api import _rpc_sync_torchscript, _rpc_async_torchscript, _remote_torchscript
from .api import * # noqa: F401
import torch.distributed.autograd as dist_autograd
@@ -58,7 +56,9 @@
``rpc_backend_options``, RPC would initialize the underlying
process group backend using ``init_method = "env://"``,
meaning that environment variables ``MASTER_ADDRESS`` and
- ``MASTER_PORT`` needs to be set properly.
+ ``MASTER_PORT`` need to be set properly. See
+ :class:`~torch.distributed.rpc.ProcessGroupRpcBackendOptions`
+ for examples.
"""
if not rpc_backend_options:
diff --git a/torch/distributed/rpc/backend_registry.py b/torch/distributed/rpc/backend_registry.py
index 482c092..679a860 100644
--- a/torch/distributed/rpc/backend_registry.py
+++ b/torch/distributed/rpc/backend_registry.py
@@ -83,11 +83,11 @@
):
from . import ProcessGroupRpcBackendOptions
- rpc_backend_options = ProcessGroupRpcBackendOptions()
- rpc_backend_options.rpc_timeout = rpc_timeout
- rpc_backend_options.init_method = init_method
- rpc_backend_options.num_send_recv_threads = num_send_recv_threads
- return rpc_backend_options
+ return ProcessGroupRpcBackendOptions(
+ rpc_timeout=rpc_timeout,
+ init_method=init_method,
+ num_send_recv_threads=num_send_recv_threads
+ )
def _process_group_init_backend_handler(
diff --git a/torch/distributed/rpc/constants.py b/torch/distributed/rpc/constants.py
index 993644c..7677e1d 100644
--- a/torch/distributed/rpc/constants.py
+++ b/torch/distributed/rpc/constants.py
@@ -1,12 +1,16 @@
-from datetime import timedelta
from torch.distributed.constants import default_pg_timeout
-# For any RpcAgent.
-DEFAULT_RPC_TIMEOUT = timedelta(seconds=60)
-DEFAULT_INIT_METHOD = "env://"
+from . import (
+ _DEFAULT_RPC_TIMEOUT,
+ _DEFAULT_INIT_METHOD,
+ _DEFAULT_NUM_SEND_RECV_THREADS
+)
+# For any RpcAgent.
+DEFAULT_RPC_TIMEOUT = _DEFAULT_RPC_TIMEOUT
+DEFAULT_INIT_METHOD = _DEFAULT_INIT_METHOD
# For ProcessGroupAgent.
-DEFAULT_NUM_SEND_RECV_THREADS = 4
+DEFAULT_NUM_SEND_RECV_THREADS = _DEFAULT_NUM_SEND_RECV_THREADS
# Same default timeout as in c10d.
DEFAULT_PROCESS_GROUP_TIMEOUT = default_pg_timeout
diff --git a/torch/testing/_internal/distributed/rpc/rpc_test.py b/torch/testing/_internal/distributed/rpc/rpc_test.py
index 2b9e4df..d7cf166 100644
--- a/torch/testing/_internal/distributed/rpc/rpc_test.py
+++ b/torch/testing/_internal/distributed/rpc/rpc_test.py
@@ -1542,6 +1542,26 @@
self.assertEqual(timeout, set_timeout)
rpc.shutdown()
+ @dist_init(setup_rpc=False)
+ @requires_process_group_agent("PROCESS_GROUP rpc backend specific test, skip")
+ def test_set_and_get_num_send_recv_threads(self):
+ NUM_THREADS = 27
+ rpc_backend_options = rpc.ProcessGroupRpcBackendOptions(
+ init_method=self.rpc_backend_options.init_method,
+ num_send_recv_threads=NUM_THREADS
+ )
+ rpc.init_rpc(
+ name="worker{}".format(self.rank),
+ backend=self.rpc_backend,
+ rank=self.rank,
+ world_size=self.world_size,
+ rpc_backend_options=rpc_backend_options,
+ )
+
+ info = rpc.api._get_current_rpc_agent().get_debug_info()
+ self.assertEqual(int(info["agent.thread_pool_size"]), NUM_THREADS)
+ rpc.shutdown()
+
@dist_init
@requires_process_group_agent("PROCESS_GROUP rpc backend specific test, skip")
def test_rpc_timeouts(self):