[flight recorder] record process group configuration (#120262)

Summary: Record process group configuration (i.e. ranks involved in a process group) to facilitate NCCL related debugging.

Differential Revision: D53792087

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120262
Approved by: https://github.com/shuqiangzhang
diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py
index fc6ea3f..81e403b 100644
--- a/test/distributed/test_c10d_nccl.py
+++ b/test/distributed/test_c10d_nccl.py
@@ -4091,7 +4091,10 @@
 
         t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
         ver = t['version']
-        self.assertEqual(ver, "1.1")
+        self.assertEqual(ver, "1.2")
+        pg_config = t['pg_config']
+        self.assertEqual(len(pg_config), 1)
+        self.assertEqual(len(pg_config[0]), self.world_size)
         t = t['entries']
         self.assertEqual(len(t), 2)
         last = t[-1]
diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
index 5a66038..9b20076 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
@@ -733,6 +733,7 @@
       getCvarInt(TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC, 60 * 1000 /*60 Sec*/);
   coordCheckIntervalMilSec_ = getCvarInt(TORCH_NCCL_COORD_CHECK_MILSEC, 1000);
   ncclTraceBufferSize_ = getCvarInt(TORCH_NCCL_TRACE_BUFFER_SIZE, 0);
+  NCCLTraceBuffer::get()->record_pg_ranks(uid_, groupRanks());
   enableCollecticeHashDebug_ = (dist_debug_level_ >= DebugLevel::Detail);
   // store_ usually is wrapped with PrefixStore and the prefix is different
   // across different ProcessGroupNCCL(PG) instances. We need to get the
@@ -1472,6 +1473,15 @@
   return globalRank;
 }
 
+const std::vector<uint64_t>& ProcessGroupNCCL::groupRanks() const {
+  if (options_->global_ranks_in_group.empty() && uid_ == 0) {
+    static std::vector<uint64_t> globalRanks(size_);
+    std::iota(globalRanks.begin(), globalRanks.end(), 0);
+    return globalRanks;
+  }
+  return options_->global_ranks_in_group;
+}
+
 void ProcessGroupNCCL::watchdogHandler() {
   bool done = false;
   lastWorkListUpdateTime_ = std::chrono::steady_clock::now();
diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
index 1217362..6199a1a 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
@@ -726,6 +726,9 @@
   // return the rank_ of the the very first PG created, aka, default global PG.
   const int& globalRank() const;
 
+  // Returns the global ranks of a PG.
+  const std::vector<uint64_t>& groupRanks() const;
+
  protected:
   // Function that runs as part of a separate thread aside from watchdog
   // thread because we need to check the heartbeat from watchdog thread
diff --git a/torch/csrc/distributed/c10d/TraceUtils.h b/torch/csrc/distributed/c10d/TraceUtils.h
index 965d8a3..b5552f5 100644
--- a/torch/csrc/distributed/c10d/TraceUtils.h
+++ b/torch/csrc/distributed/c10d/TraceUtils.h
@@ -378,6 +378,7 @@
     max_entries_ = getCvarInt({"TORCH_NCCL_TRACE_BUFFER_SIZE"}, 0);
     capture_cpp_stack_ = getCvarBool({"TORCH_NCCL_TRACE_CPP_STACK"}, false);
     enabled_ = max_entries_ > 0;
+    pg_id_to_ranks_ = {};
   }
   using Event = at::cuda::CUDAEvent;
   struct Entry {
@@ -426,6 +427,7 @@
   size_t max_entries_ = 0;
   size_t next_ = 0;
   size_t id_ = 0;
+  std::map<size_t, std::vector<uint64_t>> pg_id_to_ranks_;
 
   c10::optional<size_t> record(
       size_t pg_id,
@@ -475,6 +477,14 @@
     return id_++;
   }
 
+  void record_pg_ranks(size_t pg_id, std::vector<uint64_t> ranks) {
+    if (!enabled_) {
+      return;
+    }
+    std::lock_guard<std::mutex> guard(mutex_);
+    pg_id_to_ranks_[pg_id] = ranks;
+  }
+
   void update_state(Entry& r) {
     if (r.start_ != nullptr) {
       bool started = r.start_->query();
@@ -576,7 +586,8 @@
     c10::IValue version_key = "version";
     // Update whenever changing contents or formatting of the dump
     // (minor when adding fields, major when changing existing fields)
-    c10::IValue version_val = "1.1";
+    c10::IValue version_val = "1.2";
+    c10::IValue pg_config_key = "pg_config";
 
     c10::IValue pg_id_key = "pg_id";
     c10::IValue seq_id_key = "seq_id";
@@ -664,6 +675,14 @@
       dict.insert(frames_key, frames);
       entries.push_back(dict);
     }
+    auto pg_config = new_dict();
+    for (const auto& [pg_id, ranks] : pg_id_to_ranks_) {
+      auto pg_ranks = new_list();
+      for (const auto& rank : ranks) {
+        pg_ranks.push_back(static_cast<int>(rank));
+      }
+      pg_config.insert(static_cast<int>(pg_id), pg_ranks);
+    }
 
     // convert ncclDumpMap into a dictionary
     auto per_comm_dict = new_dict();
@@ -683,6 +702,7 @@
     if (per_comm_dict.size() > 0) {
       dict.insert(nccl_comm_key, per_comm_dict);
     }
+    dict.insert(pg_config_key, pg_config);
 
     return pickle_str(dict);
   }