[flight recorder] record process group configuration (#120262)
Summary: Record the process group configuration (i.e., the global ranks participating in each process group) in the flight recorder to facilitate NCCL-related debugging.
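
For reference, the recorded configuration can be read back from the same dump the unit test below exercises. A minimal sketch, assuming a single-node CUDA/NCCL job launched with `torchrun` and `TORCH_NCCL_TRACE_BUFFER_SIZE` set to a positive value so the trace buffer (and `record_pg_ranks`) is enabled; `_dump_nccl_trace` is the internal hook the test uses:

```python
# Minimal sketch (not part of this change): reading the new pg_config field
# back from the NCCL flight recorder dump. Assumes a torchrun launch, so
# LOCAL_RANK is set, and TORCH_NCCL_TRACE_BUFFER_SIZE > 0.
import os
import pickle

import torch
import torch.distributed as dist

local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
dist.init_process_group("nccl")

# Run at least one collective so the trace buffer also has entries to dump.
x = torch.ones(1, device="cuda")
dist.all_reduce(x)

trace = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
print(trace["version"])    # "1.2" after this change
print(trace["pg_config"])  # {pg_id: [global ranks in that PG], ...}

dist.destroy_process_group()
```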
Differential Revision: D53792087
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120262
Approved by: https://github.com/shuqiangzhang
diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py
index fc6ea3f..81e403b 100644
--- a/test/distributed/test_c10d_nccl.py
+++ b/test/distributed/test_c10d_nccl.py
@@ -4091,7 +4091,10 @@
t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
ver = t['version']
- self.assertEqual(ver, "1.1")
+ self.assertEqual(ver, "1.2")
+ pg_config = t['pg_config']
+ self.assertEqual(len(pg_config), 1)
+ self.assertEqual(len(pg_config[0]), self.world_size)
t = t['entries']
self.assertEqual(len(t), 2)
last = t[-1]
diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
index 5a66038..9b20076 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
@@ -733,6 +733,7 @@
getCvarInt(TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC, 60 * 1000 /*60 Sec*/);
coordCheckIntervalMilSec_ = getCvarInt(TORCH_NCCL_COORD_CHECK_MILSEC, 1000);
ncclTraceBufferSize_ = getCvarInt(TORCH_NCCL_TRACE_BUFFER_SIZE, 0);
+ NCCLTraceBuffer::get()->record_pg_ranks(uid_, groupRanks());
enableCollecticeHashDebug_ = (dist_debug_level_ >= DebugLevel::Detail);
// store_ usually is wrapped with PrefixStore and the prefix is different
// across different ProcessGroupNCCL(PG) instances. We need to get the
@@ -1472,6 +1473,15 @@
return globalRank;
}
+const std::vector<uint64_t>& ProcessGroupNCCL::groupRanks() const {
+ if (options_->global_ranks_in_group.empty() && uid_ == 0) {
+ static std::vector<uint64_t> globalRanks(size_);
+ std::iota(globalRanks.begin(), globalRanks.end(), 0);
+ return globalRanks;
+ }
+ return options_->global_ranks_in_group;
+}
+
void ProcessGroupNCCL::watchdogHandler() {
bool done = false;
lastWorkListUpdateTime_ = std::chrono::steady_clock::now();
diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
index 1217362..6199a1a 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
@@ -726,6 +726,9 @@
// return the rank_ of the the very first PG created, aka, default global PG.
const int& globalRank() const;
+ // Returns the global ranks of a PG.
+ const std::vector<uint64_t>& groupRanks() const;
+
protected:
// Function that runs as part of a separate thread aside from watchdog
// thread because we need to check the heartbeat from watchdog thread
diff --git a/torch/csrc/distributed/c10d/TraceUtils.h b/torch/csrc/distributed/c10d/TraceUtils.h
index 965d8a3..b5552f5 100644
--- a/torch/csrc/distributed/c10d/TraceUtils.h
+++ b/torch/csrc/distributed/c10d/TraceUtils.h
@@ -378,6 +378,7 @@
max_entries_ = getCvarInt({"TORCH_NCCL_TRACE_BUFFER_SIZE"}, 0);
capture_cpp_stack_ = getCvarBool({"TORCH_NCCL_TRACE_CPP_STACK"}, false);
enabled_ = max_entries_ > 0;
+ pg_id_to_ranks_ = {};
}
using Event = at::cuda::CUDAEvent;
struct Entry {
@@ -426,6 +427,7 @@
size_t max_entries_ = 0;
size_t next_ = 0;
size_t id_ = 0;
+ std::map<size_t, std::vector<uint64_t>> pg_id_to_ranks_;
c10::optional<size_t> record(
size_t pg_id,
@@ -475,6 +477,14 @@
return id_++;
}
+ void record_pg_ranks(size_t pg_id, std::vector<uint64_t> ranks) {
+ if (!enabled_) {
+ return;
+ }
+ std::lock_guard<std::mutex> guard(mutex_);
+ pg_id_to_ranks_[pg_id] = ranks;
+ }
+
void update_state(Entry& r) {
if (r.start_ != nullptr) {
bool started = r.start_->query();
@@ -576,7 +586,8 @@
c10::IValue version_key = "version";
// Update whenever changing contents or formatting of the dump
// (minor when adding fields, major when changing existing fields)
- c10::IValue version_val = "1.1";
+ c10::IValue version_val = "1.2";
+ c10::IValue pg_config_key = "pg_config";
c10::IValue pg_id_key = "pg_id";
c10::IValue seq_id_key = "seq_id";
@@ -664,6 +675,14 @@
dict.insert(frames_key, frames);
entries.push_back(dict);
}
+ auto pg_config = new_dict();
+ for (const auto& [pg_id, ranks] : pg_id_to_ranks_) {
+ auto pg_ranks = new_list();
+ for (const auto& rank : ranks) {
+ pg_ranks.push_back(static_cast<int>(rank));
+ }
+ pg_config.insert(static_cast<int>(pg_id), pg_ranks);
+ }
// convert ncclDumpMap into a dictionary
auto per_comm_dict = new_dict();
@@ -683,6 +702,7 @@
if (per_comm_dict.size() > 0) {
dict.insert(nccl_comm_key, per_comm_dict);
}
+ dict.insert(pg_config_key, pg_config);
return pickle_str(dict);
}