[c10d] To make ProcessGroupNCCL to use globalStore for coordination (#117075)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117075
Approved by: https://github.com/wconstab
ghstack dependencies: #117074
diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
index c1c86cd..9ce4d92 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
@@ -746,6 +746,13 @@
       getCvarInt(TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC, 60 * 10 /*10 Mins*/);
   ncclTraceBufferSize_ = getCvarInt(TORCH_NCCL_TRACE_BUFFER_SIZE, 0);
   enableCollecticeHashDebug_ = (dist_debug_level_ >= DebugLevel::Detail);
+  // store_ usually is wrapped with PrefixStore and the prefix is different
+  // across different ProcessGroupNCCL(PG) instances. We need to get the
+  // underlying non-PrefixStore for sharing global information shared across
+  // different PGs.
+  PrefixStore* prefixStore = dynamic_cast<PrefixStore*>(store_.get());
+  globalStore_ =
+      prefixStore ? prefixStore->getUnderlyingNonPrefixStore() : store_;
 #ifdef ENABLE_NCCL_ERROR_CHECKING
   enableTiming_.store(
       getCvarBool(TORCH_NCCL_ENABLE_TIMING, false) || desyncDebug_);
@@ -1488,7 +1495,8 @@
       if (timeSinceLastWorkListUpdate >= kWatchdogThreadSleepMillis &&
           timeSinceLastPollStore >= heartbeatTimeoutInSec_ * 1000) {
         lastTimePollStore = currentTime;
-        if (store_->check({std::string(TIMEOUT_DUMP)}) && !optAsyncDebugDump) {
+        if (globalStore_->check({std::string(TIMEOUT_DUMP)}) &&
+            !optAsyncDebugDump) {
           optAsyncDebugDump = launchAsyncDebugDump();
           waitForDumpOrTimeout(*optAsyncDebugDump);
           const auto exitMsg = c10::str(
@@ -1524,7 +1532,7 @@
               // abort process immediately.
               collectiveDebugInfoMode_.store(true);
               std::vector<uint8_t> vec(1);
-              store_->set(std::string(TIMEOUT_DUMP), vec);
+              globalStore_->set(std::string(TIMEOUT_DUMP), vec);
             }
 
             if (dumpOnTimeout_ && !optAsyncDebugDump) {
diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
index eef677e..09c65c7 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
@@ -13,6 +13,7 @@
 
 #include <torch/csrc/distributed/c10d/Backend.hpp>
 #include <torch/csrc/distributed/c10d/NCCLUtils.hpp>
+#include <torch/csrc/distributed/c10d/PrefixStore.hpp>
 #include <torch/csrc/distributed/c10d/Store.hpp>
 #include <torch/csrc/distributed/c10d/intra_node_comm.hpp>
 
@@ -756,9 +757,16 @@
 
   static const int64_t kWatchdogThreadSleepMillis;
 
-  // The store is used to broadcast the NCCL unique ID of rank 0.
+  // The store is used to broadcast the NCCL unique ID of rank 0. This store
+  // comes with prefix and it is different across ProcessGroup NCCL instances
+  // (aka, different ProcessGroups).
   c10::intrusive_ptr<Store> store_;
 
+  // Reference to the store without prefix so that keys are same across all
+  // ProcessGroup NCCL instances and (key, value) pairs written to the store are
+  // global.
+  c10::intrusive_ptr<Store> globalStore_;
+
   bool storeError_{false};
 
   const c10::intrusive_ptr<Options> options_;