[RpcAgent] Metrics for current num active/async rpc calls. (#34398)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34398

As part of PR 34109, it was suggested that we track the number of outstanding
async calls for RPC DebugInfo, particularly since we may move towards
occasionally using at::launch() threads for continuations.

This particular aspect of the change was distinct from the main purpose of that
diff and started getting bigger, so we split this functionality out as a
separate diff. For completeness, we track client_active_calls,
server_active_calls, and server_active_async_calls, and add some very basic
unit test coverage.
ghstack-source-id: 99708836
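
The diff below implements the tracking with lock-free atomic counters that are
snapshotted into the agent's string metrics map on demand. A minimal standalone
sketch of that pattern (illustrative class and method names only; the real code
uses c10::to_string and lives in ProcessGroupAgent):

```cpp
#include <atomic>
#include <cstdint>
#include <string>
#include <unordered_map>

// Illustrative sketch of the counter pattern: bump on send, drop on
// completion, and snapshot into a string map for debug output.
class ActiveCallCounterDemo {
 public:
  void onSendRequest() { ++clientActiveCalls_; }   // request enqueued
  void onCallResolved() { --clientActiveCalls_; }  // response/error/timeout

  std::unordered_map<std::string, std::string> getMetrics() const {
    return {{"agent.client_active_calls",
             std::to_string(clientActiveCalls_.load())}};
  }

 private:
  // The value is an advisory snapshot; strict cross-thread ordering is not
  // required for debug metrics, so a plain atomic counter is sufficient.
  std::atomic<int32_t> clientActiveCalls_{0};
};
```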

Test Plan: buck test mode/dev-nosan caffe2/torch/fb/distributed/thriftRpcBackend/...
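
A minimal gtest-style sketch of the invariant the coverage checks (hypothetical
test name; the actual tests live in the suite above):

```cpp
#include <atomic>
#include <cstdint>

#include <gtest/gtest.h>

// Hypothetical sketch: every increment of an active-call gauge must be
// matched by a decrement once the call resolves, so an idle agent reads zero.
TEST(RpcAgentMetrics, ActiveCallsReturnToZero) {
  std::atomic<int32_t> serverActiveCalls{0};
  ++serverActiveCalls;                     // request arrives
  EXPECT_EQ(1, serverActiveCalls.load());  // visible while in flight
  --serverActiveCalls;                     // response sent (or errored)
  EXPECT_EQ(0, serverActiveCalls.load());  // back to idle
}
```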

Differential Revision: D20314994

fbshipit-source-id: 2f7c75d5c511b27ed0c09c7b8a67b6fb49df31a5
diff --git a/torch/csrc/distributed/rpc/process_group_agent.cpp b/torch/csrc/distributed/rpc/process_group_agent.cpp
index 92b51b9a..10b3464 100644
--- a/torch/csrc/distributed/rpc/process_group_agent.cpp
+++ b/torch/csrc/distributed/rpc/process_group_agent.cpp
@@ -54,6 +54,9 @@
 const std::string kThreadPoolSize = "agent.thread_pool_size";
 const std::string kNumIdleThreads = "agent.num_idle_threads";
 const std::string kGilAverageWaitTime = "agent.gil_average_wait_time_us";
+const std::string kClientActiveCalls = "agent.client_active_calls";
+const std::string kServerActiveCalls = "agent.server_active_calls";
+const std::string kServerActiveAsyncCalls = "agent.server_active_async_calls";
 
 void ProcessGroupAgent::collectNames() {
   const std::string& workerName = workerInfo_.name_;
@@ -315,6 +318,7 @@
       futureTimeoutCV_.notify_one();
     }
     message.setId(requestId);
+    ++clientActiveCalls_; // outstanding until response, error, or timeout
   } else {
     future->markCompleted(Message());
   }
@@ -424,8 +428,16 @@
             work.type_,
             work.id_);
         if (message.isRequest()) {
-          auto futureResponse = cb_->operator()(message);
+          ++serverActiveCalls_; // in flight until the response future resolves
+          std::shared_ptr<FutureMessage> futureResponse;
+          try {
+            futureResponse = cb_->operator()(message);
+          } catch (const std::exception& e) {
+            futureResponse = std::make_shared<FutureMessage>();
+            futureResponse->setError(e.what());
+          }
           if (futureResponse->completed()) {
+            --serverActiveCalls_; // completed synchronously
             if (!futureResponse->hasError()) {
               send(work.from_, std::move(*futureResponse).moveValue());
             } else {
@@ -435,12 +447,15 @@
                       message, futureResponse->error()->what()));
             }
           } else {
+            ++serverActiveAsyncCalls_; // response future not yet complete
             auto fromId = work.from_.id_;
             auto requestId = work.id_;
             futureResponse->addCallback(
                 [this, fromId, requestId, futureResponse](
                     const Message& /* unused */,
                     const c10::optional<utils::FutureError>& err) {
+                  --serverActiveCalls_;      // call fully resolved
+                  --serverActiveAsyncCalls_; // async continuation done
                   if (!err) {
                     send(
                         getWorkerInfo(fromId),
@@ -487,6 +502,7 @@
             }
           }
           futureCV_.notify_all();
+          --clientActiveCalls_;
           if (message.type() == MessageType::EXCEPTION) {
             fm->setError(std::string(
                 message.payload().begin(), message.payload().end()));
@@ -536,6 +552,7 @@
     }
   }
 
+  --clientActiveCalls_; // call resolved with an error
   fm->setError(std::string(message.payload().begin(), message.payload().end()));
   futureCV_.notify_all();
 }
@@ -654,6 +671,7 @@
       const auto exceptionMsg = createExceptionResponse(
           Message({}, {}, MessageType::EXCEPTION), ss.str());
       if (!timedOutFuture.future_->hasError()) {
+        --clientActiveCalls_; // call timed out; future is errored below
         timedOutFuture.future_->setError(std::string(
             exceptionMsg.payload().begin(), exceptionMsg.payload().end()));
       }
@@ -701,6 +719,10 @@
   }
   metrics[kThreadPoolSize] = c10::to_string(threadPool_.size());
   metrics[kNumIdleThreads] = c10::to_string(threadPool_.numAvailable());
+  metrics[kClientActiveCalls] = c10::to_string(clientActiveCalls_.load());
+  metrics[kServerActiveCalls] = c10::to_string(serverActiveCalls_.load());
+  metrics[kServerActiveAsyncCalls] =
+      c10::to_string(serverActiveAsyncCalls_.load());
   if (isGILProfilingEnabled()) {
     // Add time-series based metrics, just GIL wait times for now.
     {
diff --git a/torch/csrc/distributed/rpc/process_group_agent.h b/torch/csrc/distributed/rpc/process_group_agent.h
index 141ff4a..32aaa57 100644
--- a/torch/csrc/distributed/rpc/process_group_agent.h
+++ b/torch/csrc/distributed/rpc/process_group_agent.h
@@ -264,6 +264,10 @@
   std::mutex metricsMutex_;
   std::vector<std::unique_ptr<AverageMetricsTracker>> metrics_;
   void addGilWaitTime(const std::chrono::microseconds gilWaitTime) override;
+
+  std::atomic<int32_t> clientActiveCalls_{0};
+  std::atomic<int32_t> serverActiveCalls_{0};
+  std::atomic<int32_t> serverActiveAsyncCalls_{0};
 };
 
 } // namespace rpc
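
With these changes, anything that already polls the agent's metrics map picks
up the new gauges automatically. A hedged caller-side sketch (assuming only
that getMetrics() returns the string map populated above; the helper name is
made up):

```cpp
#include <iostream>
#include <string>
#include <unordered_map>

// Hypothetical helper: print the new active-call gauges from the map
// returned by ProcessGroupAgent::getMetrics().
void printActiveCallMetrics(
    const std::unordered_map<std::string, std::string>& metrics) {
  for (const char* key : {"agent.client_active_calls",
                          "agent.server_active_calls",
                          "agent.server_active_async_calls"}) {
    auto it = metrics.find(key);
    if (it != metrics.end()) {
      std::cout << it->first << " = " << it->second << "\n";
    }
  }
}
```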