[c10d] Make ProcessGroupNCCL use globalStore for coordination (#117075)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117075
Approved by: https://github.com/wconstab
ghstack dependencies: #117074
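
Each ProcessGroupNCCL wraps the store it receives in a PrefixStore whose prefix differs per process group, so keys written through store_ are namespaced per PG. The watchdog's timeout-dump signal must be visible to every PG, so the constructor now unwraps the PrefixStore (via getUnderlyingNonPrefixStore() from #117074) and keeps the result as globalStore_; the watchdog then sets and checks the dump key on that global store. Below is a minimal standalone sketch of the pattern, not the actual ProcessGroupNCCL code; the helper names and the literal key string are illustrative only.

#include <torch/csrc/distributed/c10d/PrefixStore.hpp>
#include <torch/csrc/distributed/c10d/Store.hpp>

#include <string>
#include <vector>

// Unwrap a possibly prefix-wrapped store so that keys written through the
// returned store are shared by every process group.
c10::intrusive_ptr<c10d::Store> getGlobalStore(
    const c10::intrusive_ptr<c10d::Store>& store) {
  auto* prefixStore = dynamic_cast<c10d::PrefixStore*>(store.get());
  return prefixStore ? prefixStore->getUnderlyingNonPrefixStore() : store;
}

// One watchdog writes the dump signal; watchdogs of other PGs poll for it.
// "timeout_signal_key" is a placeholder standing in for the TIMEOUT_DUMP
// constant used in ProcessGroupNCCL.cpp.
void signalTimeoutDump(const c10::intrusive_ptr<c10d::Store>& globalStore) {
  std::vector<uint8_t> vec(1);
  globalStore->set("timeout_signal_key", vec);
}

bool timeoutDumpRequested(const c10::intrusive_ptr<c10d::Store>& globalStore) {
  return globalStore->check({std::string("timeout_signal_key")});
}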
diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
index c1c86cd..9ce4d92 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
@@ -746,6 +746,13 @@
getCvarInt(TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC, 60 * 10 /*10 Mins*/);
ncclTraceBufferSize_ = getCvarInt(TORCH_NCCL_TRACE_BUFFER_SIZE, 0);
enableCollecticeHashDebug_ = (dist_debug_level_ >= DebugLevel::Detail);
+ // store_ is usually wrapped with PrefixStore and the prefix is different
+ // across different ProcessGroupNCCL (PG) instances. We need to get the
+ // underlying non-PrefixStore so that global information can be shared
+ // across different PGs.
+ PrefixStore* prefixStore = dynamic_cast<PrefixStore*>(store_.get());
+ globalStore_ =
+ prefixStore ? prefixStore->getUnderlyingNonPrefixStore() : store_;
#ifdef ENABLE_NCCL_ERROR_CHECKING
enableTiming_.store(
getCvarBool(TORCH_NCCL_ENABLE_TIMING, false) || desyncDebug_);
@@ -1488,7 +1495,8 @@
if (timeSinceLastWorkListUpdate >= kWatchdogThreadSleepMillis &&
timeSinceLastPollStore >= heartbeatTimeoutInSec_ * 1000) {
lastTimePollStore = currentTime;
- if (store_->check({std::string(TIMEOUT_DUMP)}) && !optAsyncDebugDump) {
+ if (globalStore_->check({std::string(TIMEOUT_DUMP)}) &&
+ !optAsyncDebugDump) {
optAsyncDebugDump = launchAsyncDebugDump();
waitForDumpOrTimeout(*optAsyncDebugDump);
const auto exitMsg = c10::str(
@@ -1524,7 +1532,7 @@
// abort process immediately.
collectiveDebugInfoMode_.store(true);
std::vector<uint8_t> vec(1);
- store_->set(std::string(TIMEOUT_DUMP), vec);
+ globalStore_->set(std::string(TIMEOUT_DUMP), vec);
}
if (dumpOnTimeout_ && !optAsyncDebugDump) {
diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
index eef677e..09c65c7 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
@@ -13,6 +13,7 @@
#include <torch/csrc/distributed/c10d/Backend.hpp>
#include <torch/csrc/distributed/c10d/NCCLUtils.hpp>
+#include <torch/csrc/distributed/c10d/PrefixStore.hpp>
#include <torch/csrc/distributed/c10d/Store.hpp>
#include <torch/csrc/distributed/c10d/intra_node_comm.hpp>
@@ -756,9 +757,16 @@
static const int64_t kWatchdogThreadSleepMillis;
- // The store is used to broadcast the NCCL unique ID of rank 0.
+ // The store is used to broadcast the NCCL unique ID of rank 0. This store
+ // comes with a prefix and is different across ProcessGroupNCCL instances
+ // (i.e., different ProcessGroups).
c10::intrusive_ptr<Store> store_;
+ // Reference to the store without prefix so that keys are the same across all
+ // ProcessGroupNCCL instances and (key, value) pairs written to the store are
+ // global.
+ c10::intrusive_ptr<Store> globalStore_;
+
bool storeError_{false};
const c10::intrusive_ptr<Options> options_;