NCCL process group: avoid workEnqueue when capturing cuda graph (#102542)
Summary:
In torch.distributed, we make ProcessGroupNCCL not call workEnqueue when the CUDA stream is capturing, i.e., while capturing a CUDA graph we do not enqueue anything for the watchdog thread to consider. This allows NCCL operations to be captured in a CUDA graph.
This is a follow-up to an internal discussion [1] in which the watchdog thread was observed to crash when using CUDA graphs containing an all_reduce. The watchdog thread queries events pertaining to enqueued work items, but this cannot be done for "events" created during CUDA graph capture.
[1] https://fb.workplace.com/groups/1405155842844877/posts/6975201909173548/
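For reference, this is roughly the usage pattern the change enables (a minimal sketch, not taken from this PR; it assumes a torchrun-style multi-process launch with NCCL available and one GPU per rank):

    import os
    import torch
    import torch.distributed as dist

    # Assumes the usual launcher-provided environment (RANK, WORLD_SIZE, LOCAL_RANK).
    dist.init_process_group("nccl")
    device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
    torch.cuda.set_device(device)

    x = torch.ones(1, device=device)

    # Warm up / initialize the NCCL communicator before capture (as in the added test).
    dist.all_reduce(x)

    # With this change, no work item is handed to the watchdog thread while the
    # stream is capturing, so the collective can be recorded into the graph.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        dist.all_reduce(x)

    # The captured all_reduce only executes when the graph is replayed.
    graph.replay()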
Test Plan: Test added. Also, the repro mentioned in https://fb.workplace.com/groups/1405155842844877/posts/7003002339726838/ runs successfully after this change.
Differential Revision: D46274814
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102542
Approved by: https://github.com/kwen2501
diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py
index f9b54b6..3ab3a03 100644
--- a/test/distributed/test_c10d_nccl.py
+++ b/test/distributed/test_c10d_nccl.py
@@ -410,6 +410,28 @@
         work.wait()
         torch.cuda.synchronize(local_device)
 
+    @requires_nccl()
+    @skip_but_pass_in_sandcastle_if(torch.cuda.device_count() < 2, "NCCL test requires 2+ GPUs")
+    def test_allreduce_in_cudagraph(self):
+        store = c10d.FileStore(self.file_name, self.world_size)
+        pg = self._create_process_group_nccl(store, self.opts())
+        local_device_idx = self.rank_to_GPU[self.rank][0]
+
+        xs = [torch.FloatTensor([1]).cuda(local_device_idx)]
+        ys = [torch.FloatTensor([8]).cuda(local_device_idx)]
+
+        # single warmup
+        pg.allreduce(xs).wait()
+        self.assertEqual(2, xs[0].item())
+
+        graph = torch.cuda.CUDAGraph()
+        with torch.cuda.graph(graph):
+            pg.allreduce(xs).wait()
+        self.assertEqual(2, xs[0].item())
+
+        graph.replay()
+        graph.replay()
+        self.assertEqual(xs, ys)
 
     @requires_nccl()
     @skip_but_pass_in_sandcastle_if(torch.cuda.device_count() < 2, "NCCL test requires 2+ GPUs")
diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
index 5787515..1d8db19 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
@@ -268,8 +268,7 @@
   }
 }
 
-inline void errorIfCapturingNonCapturableNCCL() {
-  auto status = c10::cuda::currentStreamCaptureStatusMayInitCtx();
+inline void errorIfCapturingNonCapturableNCCL(c10::cuda::CaptureStatus status) {
   // parentheses avoid some compiler warnings
   static const uint64_t min_version =
       (((uint64_t)2) << 32) + (((uint64_t)9) << 16) + ((uint64_t)6);
@@ -1441,7 +1440,11 @@
   work->opTimeout_ = options_->timeout;
   work->store_ = store_;
 
-  if (coalescing_state_ & CoalColl) {
+  c10::cuda::CaptureStatus capture_status =
+      c10::cuda::currentStreamCaptureStatusMayInitCtx();
+
+  if ((coalescing_state_ & CoalColl) &&
+      capture_status == c10::cuda::CaptureStatus::None) {
     workEnqueue(work);
     // TODO: it seems we never enqueue work for single send/recv or batch P2P,
     // see the `pointToPoint` function. This should be fixed. Otherwise, we risk
@@ -1460,7 +1463,9 @@
     PostProcess post,
     OpType opType,
     const char* profilingTitle) {
-  errorIfCapturingNonCapturableNCCL();
+  c10::cuda::CaptureStatus capture_status =
+      c10::cuda::currentStreamCaptureStatusMayInitCtx();
+  errorIfCapturingNonCapturableNCCL(capture_status);
   // Bump collective counter
   seq_++;
@@ -1603,7 +1608,7 @@
   work->numelIn_ = inputs[0].numel();
   work->numelOut_ = outputs[0].numel();
 
-  if (!coalescing_state_) {
+  if (!coalescing_state_ && capture_status == c10::cuda::CaptureStatus::None) {
     workEnqueue(work);
   }