[RpcAgent] Metrics for current num active/async rpc calls. (#34398)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34398
As part of PR 34109, it was suggested that we track the number of outstanding
async calls for RPC DebugInfo, particularly if we move towards occasionally using
at::launch() threads for continuations.
This aspect of the change was distinct from the main purpose of that diff and was
growing in scope, so we split it out into a separate diff.
For completeness, we track client_active_calls, server_active_calls, and
server_active_async_calls, and add some basic unit-test coverage.
ghstack-source-id: 99708836
Test Plan: buck test mode/dev-nosan caffe2/torch/fb/distributed/thriftRpcBackend/...
Differential Revision: D20314994
fbshipit-source-id: 2f7c75d5c511b27ed0c09c7b8a67b6fb49df31a5
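For illustration only (not part of the patch below): a minimal, standalone C++ sketch of the client-side counter pattern this diff adds, assuming plain std::atomic counters and std::to_string in place of c10::to_string. The metric key mirrors the one added in the diff; the class and function names here are hypothetical and are not taken from the PyTorch sources.

```cpp
#include <atomic>
#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_map>

// Client-side active-call tracking: bump the counter when a request is
// enqueued, drop it when the matching response (or an error/timeout)
// completes the corresponding future.
class ActiveCallMetrics {
 public:
  void onRequestSent() { ++clientActiveCalls_; }
  void onResponseOrError() { --clientActiveCalls_; }

  // Export a string map keyed like the metric added in the diff.
  std::unordered_map<std::string, std::string> getMetrics() const {
    return {
        {"agent.client_active_calls",
         std::to_string(clientActiveCalls_.load())},
    };
  }

 private:
  std::atomic<std::int32_t> clientActiveCalls_{0};
};

int main() {
  ActiveCallMetrics metrics;
  metrics.onRequestSent();      // first request goes out
  metrics.onRequestSent();      // second outstanding request
  metrics.onResponseOrError();  // first response arrives
  for (const auto& kv : metrics.getMetrics()) {
    std::cout << kv.first << " = " << kv.second << "\n";  // prints 1
  }
  return 0;
}
```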
diff --git a/torch/csrc/distributed/rpc/process_group_agent.cpp b/torch/csrc/distributed/rpc/process_group_agent.cpp
index 92b51b9a..10b3464 100644
--- a/torch/csrc/distributed/rpc/process_group_agent.cpp
+++ b/torch/csrc/distributed/rpc/process_group_agent.cpp
@@ -54,6 +54,9 @@
const std::string kThreadPoolSize = "agent.thread_pool_size";
const std::string kNumIdleThreads = "agent.num_idle_threads";
const std::string kGilAverageWaitTime = "agent.gil_average_wait_time_us";
+const std::string kClientActiveCalls = "agent.client_active_calls";
+const std::string kServerActiveCalls = "agent.server_active_calls";
+const std::string kServerActiveAsyncCalls = "agent.server_active_async_calls";
void ProcessGroupAgent::collectNames() {
const std::string& workerName = workerInfo_.name_;
@@ -315,6 +318,7 @@
futureTimeoutCV_.notify_one();
}
message.setId(requestId);
+ ++clientActiveCalls_;
} else {
future->markCompleted(Message());
}
@@ -424,8 +428,16 @@
work.type_,
work.id_);
if (message.isRequest()) {
- auto futureResponse = cb_->operator()(message);
+ ++serverActiveCalls_;
+ std::shared_ptr<FutureMessage> futureResponse;
+ try {
+ futureResponse = cb_->operator()(message);
+ } catch (const std::exception& e) {
+ futureResponse = std::make_shared<FutureMessage>();
+ futureResponse->setError(e.what());
+ }
if (futureResponse->completed()) {
+ --serverActiveCalls_;
if (!futureResponse->hasError()) {
send(work.from_, std::move(*futureResponse).moveValue());
} else {
@@ -435,12 +447,15 @@
message, futureResponse->error()->what()));
}
} else {
+ ++serverActiveAsyncCalls_;
auto fromId = work.from_.id_;
auto requestId = work.id_;
futureResponse->addCallback(
[this, fromId, requestId, futureResponse](
const Message& /* unused */,
const c10::optional<utils::FutureError>& err) {
+ --serverActiveCalls_;
+ --serverActiveAsyncCalls_;
if (!err) {
send(
getWorkerInfo(fromId),
@@ -487,6 +502,7 @@
}
}
futureCV_.notify_all();
+ --clientActiveCalls_;
if (message.type() == MessageType::EXCEPTION) {
fm->setError(std::string(
message.payload().begin(), message.payload().end()));
@@ -536,6 +552,7 @@
}
}
+ --clientActiveCalls_;
fm->setError(std::string(message.payload().begin(), message.payload().end()));
futureCV_.notify_all();
}
@@ -654,6 +671,7 @@
const auto exceptionMsg = createExceptionResponse(
Message({}, {}, MessageType::EXCEPTION), ss.str());
if (!timedOutFuture.future_->hasError()) {
+ --clientActiveCalls_;
timedOutFuture.future_->setError(std::string(
exceptionMsg.payload().begin(), exceptionMsg.payload().end()));
}
@@ -701,6 +719,10 @@
}
metrics[kThreadPoolSize] = c10::to_string(threadPool_.size());
metrics[kNumIdleThreads] = c10::to_string(threadPool_.numAvailable());
+ metrics[kClientActiveCalls] = c10::to_string(clientActiveCalls_.load());
+ metrics[kServerActiveCalls] = c10::to_string(serverActiveCalls_.load());
+ metrics[kServerActiveAsyncCalls] =
+ c10::to_string(serverActiveAsyncCalls_.load());
if (isGILProfilingEnabled()) {
// Add time-series based metrics, just GIL wait times for now.
{
diff --git a/torch/csrc/distributed/rpc/process_group_agent.h b/torch/csrc/distributed/rpc/process_group_agent.h
index 141ff4a..32aaa57 100644
--- a/torch/csrc/distributed/rpc/process_group_agent.h
+++ b/torch/csrc/distributed/rpc/process_group_agent.h
@@ -264,6 +264,10 @@
std::mutex metricsMutex_;
std::vector<std::unique_ptr<AverageMetricsTracker>> metrics_;
void addGilWaitTime(const std::chrono::microseconds gilWaitTime) override;
+
+ std::atomic<int32_t> clientActiveCalls_{0};
+ std::atomic<int32_t> serverActiveCalls_{0};
+ std::atomic<int32_t> serverActiveAsyncCalls_{0};
};
} // namespace rpc
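Again purely illustrative and outside the patch: a standalone sketch of the server-side bookkeeping shape from the request branch above, using a toy future type (ToyFuture and handleRequest are hypothetical, not PyTorch API). Every request is counted; requests whose response future is not yet complete are additionally counted as async and uncounted in the future's completion callback, mirroring where the diff places the decrements.

```cpp
#include <atomic>
#include <cstdint>
#include <functional>
#include <iostream>
#include <utility>
#include <vector>

// Toy stand-in for a response future: callbacks run when markCompleted() fires.
struct ToyFuture {
  bool completed = false;
  std::vector<std::function<void()>> callbacks;
  void addCallback(std::function<void()> cb) {
    callbacks.push_back(std::move(cb));
  }
  void markCompleted() {
    completed = true;
    for (auto& cb : callbacks) {
      cb();
    }
  }
};

std::atomic<std::int32_t> serverActiveCalls{0};
std::atomic<std::int32_t> serverActiveAsyncCalls{0};

// Count every request; for not-yet-complete responses, also count the call as
// async and decrement both counters once the future completes.
void handleRequest(ToyFuture& futureResponse) {
  ++serverActiveCalls;
  if (futureResponse.completed) {
    --serverActiveCalls;  // synchronous path: done immediately
  } else {
    ++serverActiveAsyncCalls;
    futureResponse.addCallback([] {
      --serverActiveCalls;
      --serverActiveAsyncCalls;
    });
  }
}

int main() {
  ToyFuture f;
  handleRequest(f);  // async path: both counters go to 1
  std::cout << serverActiveCalls << " " << serverActiveAsyncCalls << "\n";
  f.markCompleted();  // callback fires, both drop back to 0
  std::cout << serverActiveCalls << " " << serverActiveAsyncCalls << "\n";
  return 0;
}
```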