Lift rpc_timeout to RpcAgent, for other RpcAgents to reuse. (#29341)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/29341

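Move the RPC timeout (`rpcTimeout_` and `getRpcTimeout()`) from ProcessGroupAgent into the RpcAgent base class so that other RpcAgent implementations can use this timeout setting as well.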

ghstack-source-id: 93481902

Differential Revision: D5681951

fbshipit-source-id: 569c768dc342e8a2d9faf142ceccf696e12e41dc
diff --git a/test/rpc_test.py b/test/rpc_test.py
index 18f0518..456067f 100644
--- a/test/rpc_test.py
+++ b/test/rpc_test.py
@@ -962,13 +962,11 @@
 
         self.assertEqual(result, sum(vals))
 
-    @requires_process_group_agent("PROCESS_GROUP rpc backend specific test, skip")
     @dist_init
     def test_get_default_rpc_timeout(self):
         timeout = rpc.get_rpc_timeout()
         self.assertEqual(timeout, rpc.constants.DEFAULT_RPC_TIMEOUT)
 
-    @requires_process_group_agent("PROCESS_GROUP rpc backend specific test, skip")
     @dist_init(setup_model_parallel=False)
     def test_set_rpc_timeout(self):
         timeout = timedelta(seconds=1)
diff --git a/torch/csrc/distributed/rpc/init.cpp b/torch/csrc/distributed/rpc/init.cpp
index c2771a5..2629c36 100644
--- a/torch/csrc/distributed/rpc/init.cpp
+++ b/torch/csrc/distributed/rpc/init.cpp
@@ -42,8 +42,10 @@
           .def(
               "join", &RpcAgent::join, py::call_guard<py::gil_scoped_release>())
           .def(
-              "sync",
-              &RpcAgent::sync,
+              "sync", &RpcAgent::sync, py::call_guard<py::gil_scoped_release>())
+          .def(
+              "_get_rpc_timeout",
+              &RpcAgent::getRpcTimeout,
               py::call_guard<py::gil_scoped_release>());
 
   auto pyFuture = shared_ptr_class_<PyFuture>(module, "Future")
@@ -118,10 +120,6 @@
       .def(
           "sync",
           &ProcessGroupAgent::sync,
-          py::call_guard<py::gil_scoped_release>())
-      .def(
-          "_get_rpc_timeout",
-          &ProcessGroupAgent::getRpcTimeout,
           py::call_guard<py::gil_scoped_release>());
 
   module.def("_start_rpc_agent", [](const std::shared_ptr<RpcAgent>& agent) {
diff --git a/torch/csrc/distributed/rpc/process_group_agent.cpp b/torch/csrc/distributed/rpc/process_group_agent.cpp
index 6f7b6e3..75fb9b3 100644
--- a/torch/csrc/distributed/rpc/process_group_agent.cpp
+++ b/torch/csrc/distributed/rpc/process_group_agent.cpp
@@ -122,14 +122,14 @@
     std::chrono::milliseconds rpcTimeout)
     : RpcAgent(
           WorkerInfo(std::move(workerName), pg->getRank()),
-          c10::guts::make_unique<RequestCallbackImpl>()),
+          c10::guts::make_unique<RequestCallbackImpl>(),
+          rpcTimeout),
       pg_(std::move(pg)),
       sendCounts_(pg_->getSize()),
       recvCounts_(pg_->getSize()),
       nextId_(0),
       sendMutexes_(pg_->getSize()),
-      threadPool_(numSendRecvThreads),
-      rpcTimeout_(rpcTimeout) {
+      threadPool_(numSendRecvThreads) {
   collectNames();
   TORCH_CHECK(
       nameMap_.size() > 1,
@@ -175,10 +175,6 @@
   return allWorkerInfo_[id];
 }
 
-const std::chrono::milliseconds& ProcessGroupAgent::getRpcTimeout() const {
-  return rpcTimeout_;
-}
-
 void ProcessGroupAgent::join() {
   // Every process i sends a SHUTDOWN message to process i + 1. This is
   // necessary for now because:
diff --git a/torch/csrc/distributed/rpc/process_group_agent.h b/torch/csrc/distributed/rpc/process_group_agent.h
index a55f4ac..ad4c70e 100644
--- a/torch/csrc/distributed/rpc/process_group_agent.h
+++ b/torch/csrc/distributed/rpc/process_group_agent.h
@@ -51,9 +51,6 @@
 
   void start() override;
 
-  // retrieves the timeout for all RPCs
-  const std::chrono::milliseconds& getRpcTimeout() const;
-
  protected:
   // This method wraps the destination information and the message into a
   // SendWork object, and put the SendWork into a queue. Another thread will
@@ -148,7 +145,6 @@
   std::map<std::chrono::milliseconds, std::vector<int64_t>> futureTimeouts_;
   mutable std::mutex futureMutex_;
   mutable std::condition_variable futureCV_;
-  std::chrono::milliseconds rpcTimeout_;
 };
 
 } // namespace rpc
diff --git a/torch/csrc/distributed/rpc/rpc_agent.cpp b/torch/csrc/distributed/rpc/rpc_agent.cpp
index 97cbabb..bf2e427 100644
--- a/torch/csrc/distributed/rpc/rpc_agent.cpp
+++ b/torch/csrc/distributed/rpc/rpc_agent.cpp
@@ -6,8 +6,13 @@
 
 constexpr size_t WorkerInfo::MAX_NAME_LEN;
 
-RpcAgent::RpcAgent(WorkerInfo workerId, std::unique_ptr<RequestCallback> cb)
-    : workerInfo_(std::move(workerId)), cb_(std::move(cb)) {}
+RpcAgent::RpcAgent(
+    WorkerInfo workerId,
+    std::unique_ptr<RequestCallback> cb,
+    std::chrono::milliseconds rpcTimeout)
+    : workerInfo_(std::move(workerId)),
+      cb_(std::move(cb)),
+      rpcTimeout_(rpcTimeout) {}
 
 RpcAgent::~RpcAgent() = default;
 
diff --git a/torch/csrc/distributed/rpc/rpc_agent.h b/torch/csrc/distributed/rpc/rpc_agent.h
index 7d62f27..a53bd59 100644
--- a/torch/csrc/distributed/rpc/rpc_agent.h
+++ b/torch/csrc/distributed/rpc/rpc_agent.h
@@ -66,7 +66,10 @@
   // NB: RpcAgent implementations should not start serving requests until
   // ``start()`` is called, as there could be other contexts that have not been
   // initialized yet at this time.
-  RpcAgent(WorkerInfo id, std::unique_ptr<RequestCallback> cb);
+  RpcAgent(
+      WorkerInfo id,
+      std::unique_ptr<RequestCallback> cb,
+      std::chrono::milliseconds rpcTimeout);
 
   virtual ~RpcAgent();
 
@@ -93,6 +96,11 @@
 
   virtual const WorkerInfo& getWorkerInfo(worker_id_t id) const = 0;
 
+  // Retrieve the timeout for all RPCs.
+  inline const std::chrono::milliseconds& getRpcTimeout() const {
+    return rpcTimeout_;
+  }
+
   // Call sync and join all internal threads. This method should be called
   // before every RPC process exits.
   virtual void join() = 0;
@@ -114,6 +122,7 @@
   const WorkerInfo workerInfo_;
   const std::string workerName_;
   const std::unique_ptr<RequestCallback> cb_;
+  const std::chrono::milliseconds rpcTimeout_;
 
  private:
   static std::shared_ptr<RpcAgent> defaultRpcAgent_;
diff --git a/torch/distributed/rpc/__init__.py b/torch/distributed/rpc/__init__.py
index 63a022b..b5d9dd1 100644
--- a/torch/distributed/rpc/__init__.py
+++ b/torch/distributed/rpc/__init__.py
@@ -54,6 +54,7 @@
             init_method(str): backend specific init arguments.
             num_send_recv_threads(int): Number of threads for send/recv work.
             rpc_timeout (datetime.timedelta): Timeout for RPCs. Defaults to 10 seconds.
+                0 means infinity.
         """
         # Rendezvous.
         world_size = len(worker_name_to_id)
diff --git a/torch/distributed/rpc/api.py b/torch/distributed/rpc/api.py
index e3fa15e..4c3714d 100644
--- a/torch/distributed/rpc/api.py
+++ b/torch/distributed/rpc/api.py
@@ -7,6 +7,7 @@
 from .constants import DEFAULT_RPC_TIMEOUT, DEFAULT_NUM_SEND_RECV_THREADS
 from .internal import _internal_rpc_pickler, PythonUDF
 
+import datetime
 import functools
 import sys
 import torch
@@ -78,6 +79,11 @@
         raise RuntimeError("RPC is already initialized")
 
     # Initialize RPC.
+    if not isinstance(rpc_timeout, datetime.timedelta):
+        raise RuntimeError(
+            "`rpc_timeout` must be a `datetime.timedelta`."
+        )
+
     _agent = backend_registry.init_backend(
         backend,
         store=store,
diff --git a/torch/distributed/rpc/backend_registry.py b/torch/distributed/rpc/backend_registry.py
index c762f2f..ea8a741 100644
--- a/torch/distributed/rpc/backend_registry.py
+++ b/torch/distributed/rpc/backend_registry.py
@@ -45,6 +45,7 @@
     self_rank,
     worker_name_to_id,
     num_send_recv_threads,
+    rpc_timeout,
     *args,
     **kwargs
 ):
@@ -78,7 +79,12 @@
                 )
             )
         # TODO: add try-except and destroy _agent in all processes if any fails.
-        return ProcessGroupAgent(self_name, group, num_send_recv_threads, kwargs["rpc_timeout"])
+        return ProcessGroupAgent(
+            self_name,
+            group,
+            num_send_recv_threads,
+            rpc_timeout,
+        )
     except Exception as ex:
         dist.destroy_process_group()
         raise ex