Improve ProcessGroup RpcBackendOptions Constructor API (#34081)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34081

Before this commit, applications have to do the following to configure
the number of threads in the ProcessGroup RPC backend:

```
op = ProcessGroupRpcBackendOptions()
op.rpc_timeout = rpc_timeout
op.init_method = init_method
op.num_send_recv_threads = 32
init_rpc(...., rpc_backend_options=op)
```

After this commit, it can be simplified to:

```
init_rpc(...., rpc_backend_options=ProcessGroupRpcBackendOptions(num_send_recv_threads=32))
```

Fixes #34075

Test Plan: Imported from OSS

Differential Revision: D20227344

Pulled By: mrshenli

fbshipit-source-id: def4318e987179b8c8ecca44d7ff935702c8a6e7
diff --git a/docs/source/rpc.rst b/docs/source/rpc.rst
index 6b78062..5abe798 100644
--- a/docs/source/rpc.rst
+++ b/docs/source/rpc.rst
@@ -101,9 +101,9 @@
 .. autofunction:: shutdown
 .. autoclass:: WorkerInfo
     :members:
-.. autoclass:: RpcBackendOptions
+.. autoclass:: ProcessGroupRpcBackendOptions
     :members:
-
+    :inherited-members:
 
 .. _rref:
 
diff --git a/torch/csrc/distributed/rpc/init.cpp b/torch/csrc/distributed/rpc/init.cpp
index 1fa5b1e..3303b0e 100644
--- a/torch/csrc/distributed/rpc/init.cpp
+++ b/torch/csrc/distributed/rpc/init.cpp
@@ -39,10 +39,11 @@
       shared_ptr_class_<RpcBackendOptions>(
           module,
           "RpcBackendOptions",
-          R"(A structure encapsulating the options passed into the RPC backend.
-            An instance of this class can be passed in to :meth:`~torch.distributed.rpc.init_rpc`
-            in order to initialize RPC with specific configurations, such as the
-             RPC timeout and init_method to be used. )")
+          R"(An abstract structure encapsulating the options passed into the RPC
+            backend. An instance of this class can be passed in to
+            :meth:`~torch.distributed.rpc.init_rpc` in order to initialize RPC
+            with specific configurations, such as the RPC timeout and
+            `init_method` to be used. )")
           .def_readwrite(
               "rpc_timeout",
               &RpcBackendOptions::rpcTimeout,
@@ -53,7 +54,10 @@
               "init_method",
               &RpcBackendOptions::initMethod,
               R"(URL specifying how to initialize the process group.
-                Default is env://)");
+                Default is ``env://``)");
+
+  module.attr("_DEFAULT_RPC_TIMEOUT") = py::cast(kDefaultRpcTimeout);
+  module.attr("_DEFAULT_INIT_METHOD") = py::cast(kDefaultInitMethod);
 
   auto workerInfo =
       shared_ptr_class_<WorkerInfo>(
@@ -207,11 +211,54 @@
               )");
 
   shared_ptr_class_<ProcessGroupRpcBackendOptions>(
-      module, "ProcessGroupRpcBackendOptions", rpcBackendOptions)
-      .def(py::init<>())
+      module,
+      "ProcessGroupRpcBackendOptions",
+      rpcBackendOptions,
+      R"(
+          The backend options class for ``ProcessGroupAgent``, which is derived
+          from ``RpcBackendOptions``.
+
+          Arguments:
+              num_send_recv_threads (int, optional): The number of threads in
+                  the thread-pool used by ``ProcessGroupAgent`` (default: 4).
+              rpc_timeout (datetime.timedelta, optional): The timeout for RPC
+                  requests (default: ``timedelta(seconds=60)``).
+              init_method (str, optional): The URL to initialize
+                  ``ProcessGroupGloo`` (default: ``env://``).
+
+
+          Example::
+              >>> import datetime, os
+              >>> from torch.distributed import rpc
+              >>> os.environ['MASTER_ADDR'] = 'localhost'
+              >>> os.environ['MASTER_PORT'] = '29500'
+              >>>
+              >>> rpc.init_rpc(
+              >>>     "worker1",
+              >>>     rank=0,
+              >>>     world_size=2,
+              >>>     rpc_backend_options=rpc.ProcessGroupRpcBackendOptions(
+              >>>         num_send_recv_threads=16,
+              >>>         rpc_timeout=datetime.timedelta(seconds=20)
+              >>>     )
+              >>> )
+              >>>
+              >>> # omitting init_rpc invocation on worker2
+      )")
+      .def(
+          py::init<int, std::chrono::milliseconds, std::string>(),
+          py::arg("num_send_recv_threads") = kDefaultNumSendRecvThreads,
+          py::arg("rpc_timeout") = kDefaultRpcTimeout,
+          py::arg("init_method") = kDefaultInitMethod)
       .def_readwrite(
           "num_send_recv_threads",
-          &ProcessGroupRpcBackendOptions::numSendRecvThreads);
+          &ProcessGroupRpcBackendOptions::numSendRecvThreads,
+          R"(
+              The number of threads in the thread-pool used by ProcessGroupAgent.
+          )");
+
+  module.attr("_DEFAULT_NUM_SEND_RECV_THREADS") =
+      py::cast(kDefaultNumSendRecvThreads);
 
   shared_ptr_class_<ProcessGroupAgent>(module, "ProcessGroupAgent", rpcAgent)
       .def(
diff --git a/torch/csrc/distributed/rpc/process_group_agent.h b/torch/csrc/distributed/rpc/process_group_agent.h
index 1d910b3..141ff4a 100644
--- a/torch/csrc/distributed/rpc/process_group_agent.h
+++ b/torch/csrc/distributed/rpc/process_group_agent.h
@@ -12,8 +12,22 @@
 namespace distributed {
 namespace rpc {
 
+constexpr auto kDefaultNumSendRecvThreads = 4;
+
 struct ProcessGroupRpcBackendOptions : public RpcBackendOptions {
-  ProcessGroupRpcBackendOptions() = default;
+  ProcessGroupRpcBackendOptions(
+      int num_send_recv_threads,
+      std::chrono::milliseconds rpc_timeout,
+      std::string init_method)
+      : RpcBackendOptions(rpc_timeout, init_method),
+        numSendRecvThreads(num_send_recv_threads) {
+    TORCH_CHECK(
+        num_send_recv_threads > 0,
+        "Cannot create ProcessGroup RPC backend with ",
+        num_send_recv_threads,
+        " threads in the thread-pool.");
+  }
+
   int numSendRecvThreads;
 };
 
diff --git a/torch/csrc/distributed/rpc/rpc_agent.h b/torch/csrc/distributed/rpc/rpc_agent.h
index 07d660e..b02bb3d 100644
--- a/torch/csrc/distributed/rpc/rpc_agent.h
+++ b/torch/csrc/distributed/rpc/rpc_agent.h
@@ -11,6 +11,9 @@
 namespace distributed {
 namespace rpc {
 
+constexpr auto kDefaultRpcTimeout = std::chrono::seconds(60);
+constexpr auto kDefaultInitMethod = "env://";
+
 using steady_clock_time_point =
     std::chrono::time_point<std::chrono::steady_clock>;
 // Input is qualified name string, output is JIT StrongTypePtr
@@ -20,7 +23,14 @@
     std::function<c10::StrongTypePtr(const c10::QualifiedName&)>;
 
 struct RpcBackendOptions {
-  RpcBackendOptions() = default;
+  RpcBackendOptions()
+      : RpcBackendOptions(kDefaultRpcTimeout, kDefaultInitMethod) {}
+
+  RpcBackendOptions(
+      std::chrono::milliseconds rpcTimeout,
+      std::string initMethod)
+      : rpcTimeout(rpcTimeout), initMethod(initMethod) {}
+
   std::chrono::milliseconds rpcTimeout;
   std::string initMethod;
 };
diff --git a/torch/distributed/rpc/__init__.py b/torch/distributed/rpc/__init__.py
index 7b76395..75b8d4b 100644
--- a/torch/distributed/rpc/__init__.py
+++ b/torch/distributed/rpc/__init__.py
@@ -6,8 +6,6 @@
 import torch
 import torch.distributed as dist
 
-from . import backend_registry
-
 
 def is_available():
     return sys.version_info >= (3, 0) and hasattr(torch._C, "_rpc_init")
@@ -18,7 +16,7 @@
 
 
 if is_available():
-    from . import api
+    from . import api, backend_registry
     from .api import _rpc_sync_torchscript, _rpc_async_torchscript, _remote_torchscript
     from .api import *  # noqa: F401
     import torch.distributed.autograd as dist_autograd
@@ -58,7 +56,9 @@
                 ``rpc_backend_options``, RPC would initialize the underlying
                 process group backend using ``init_method = "env://"``,
                 meaning that environment variables ``MASTER_ADDRESS`` and
-                ``MASTER_PORT`` needs to be set properly.
+                ``MASTER_PORT`` need to be set properly. See
+                :class:`~torch.distributed.rpc.ProcessGroupRpcBackendOptions`
+                for examples.
         """
 
         if not rpc_backend_options:
diff --git a/torch/distributed/rpc/backend_registry.py b/torch/distributed/rpc/backend_registry.py
index 482c092..679a860 100644
--- a/torch/distributed/rpc/backend_registry.py
+++ b/torch/distributed/rpc/backend_registry.py
@@ -83,11 +83,11 @@
 ):
     from . import ProcessGroupRpcBackendOptions
 
-    rpc_backend_options = ProcessGroupRpcBackendOptions()
-    rpc_backend_options.rpc_timeout = rpc_timeout
-    rpc_backend_options.init_method = init_method
-    rpc_backend_options.num_send_recv_threads = num_send_recv_threads
-    return rpc_backend_options
+    return ProcessGroupRpcBackendOptions(
+        rpc_timeout=rpc_timeout,
+        init_method=init_method,
+        num_send_recv_threads=num_send_recv_threads
+    )
 
 
 def _process_group_init_backend_handler(
diff --git a/torch/distributed/rpc/constants.py b/torch/distributed/rpc/constants.py
index 993644c..7677e1d 100644
--- a/torch/distributed/rpc/constants.py
+++ b/torch/distributed/rpc/constants.py
@@ -1,12 +1,16 @@
-from datetime import timedelta
 from torch.distributed.constants import default_pg_timeout
 
-# For any RpcAgent.
-DEFAULT_RPC_TIMEOUT = timedelta(seconds=60)
-DEFAULT_INIT_METHOD = "env://"
+from . import (
+    _DEFAULT_RPC_TIMEOUT,
+    _DEFAULT_INIT_METHOD,
+    _DEFAULT_NUM_SEND_RECV_THREADS
+)
 
+# For any RpcAgent.
+DEFAULT_RPC_TIMEOUT = _DEFAULT_RPC_TIMEOUT
+DEFAULT_INIT_METHOD = _DEFAULT_INIT_METHOD
 
 # For ProcessGroupAgent.
-DEFAULT_NUM_SEND_RECV_THREADS = 4
+DEFAULT_NUM_SEND_RECV_THREADS = _DEFAULT_NUM_SEND_RECV_THREADS
 # Same default timeout as in c10d.
 DEFAULT_PROCESS_GROUP_TIMEOUT = default_pg_timeout
diff --git a/torch/testing/_internal/distributed/rpc/rpc_test.py b/torch/testing/_internal/distributed/rpc/rpc_test.py
index 2b9e4df..d7cf166 100644
--- a/torch/testing/_internal/distributed/rpc/rpc_test.py
+++ b/torch/testing/_internal/distributed/rpc/rpc_test.py
@@ -1542,6 +1542,26 @@
         self.assertEqual(timeout, set_timeout)
         rpc.shutdown()
 
+    @dist_init(setup_rpc=False)
+    @requires_process_group_agent("PROCESS_GROUP rpc backend specific test, skip")
+    def test_set_and_get_num_send_recv_threads(self):
+        NUM_THREADS = 27
+        rpc_backend_options = rpc.ProcessGroupRpcBackendOptions(
+            init_method=self.rpc_backend_options.init_method,
+            num_send_recv_threads=NUM_THREADS
+        )
+        rpc.init_rpc(
+            name="worker{}".format(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=rpc_backend_options,
+        )
+
+        info = rpc.api._get_current_rpc_agent().get_debug_info()
+        self.assertEqual(int(info["agent.thread_pool_size"]), NUM_THREADS)
+        rpc.shutdown()
+
     @dist_init
     @requires_process_group_agent("PROCESS_GROUP rpc backend specific test, skip")
     def test_rpc_timeouts(self):