NCCL process group: avoid workEnqueue when capturing cuda graph (#102542)

Summary:
In torch.distributed, we make ProcessGroupNCCL not call workEnqueue when the cuda stream is capturing. I.e., when capturing a CUDA graph, we do not enqueue anything for the watchdog thread to consider. This allows capturing NCCL operations in a CUDA Graph.

This is followup to an internal discussion [1] where the watchdog thread was observed to crash when using cuda graphs containing an all_reduce. The watchdog thread wants to query events pertaining to enqueued work items, but this can't be done for "events" created during cuda graph capture.

[1] https://fb.workplace.com/groups/1405155842844877/posts/6975201909173548/

Test Plan: Test added. Also, the repro mentioned in https://fb.workplace.com/groups/1405155842844877/posts/7003002339726838/ runs successfully after this change.

Differential Revision: D46274814

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102542
Approved by: https://github.com/kwen2501
diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py
index f9b54b6..3ab3a03 100644
--- a/test/distributed/test_c10d_nccl.py
+++ b/test/distributed/test_c10d_nccl.py
@@ -410,6 +410,28 @@
             work.wait()
         torch.cuda.synchronize(local_device)
 
+    @requires_nccl()
+    @skip_but_pass_in_sandcastle_if(torch.cuda.device_count() < 2, "NCCL test requires 2+ GPUs")
+    def test_allreduce_in_cudagraph(self):
+        store = c10d.FileStore(self.file_name, self.world_size)
+        pg = self._create_process_group_nccl(store, self.opts())
+        local_device_idx = self.rank_to_GPU[self.rank][0]
+
+        xs = [torch.FloatTensor([1]).cuda(local_device_idx)]
+        ys = [torch.FloatTensor([8]).cuda(local_device_idx)]
+
+        # single warmup
+        pg.allreduce(xs).wait()
+        self.assertEqual(2, xs[0].item())
+
+        graph = torch.cuda.CUDAGraph()
+        with torch.cuda.graph(graph):
+            pg.allreduce(xs).wait()
+        self.assertEqual(2, xs[0].item())
+
+        graph.replay()
+        graph.replay()
+        self.assertEqual(xs, ys)
 
     @requires_nccl()
     @skip_but_pass_in_sandcastle_if(torch.cuda.device_count() < 2, "NCCL test requires 2+ GPUs")
diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
index 5787515..1d8db19 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
@@ -268,8 +268,7 @@
   }
 }
 
-inline void errorIfCapturingNonCapturableNCCL() {
-  auto status = c10::cuda::currentStreamCaptureStatusMayInitCtx();
+inline void errorIfCapturingNonCapturableNCCL(c10::cuda::CaptureStatus status) {
   // parentheses avoid some compiler warnings
   static const uint64_t min_version =
       (((uint64_t)2) << 32) + (((uint64_t)9) << 16) + ((uint64_t)6);
@@ -1441,7 +1440,11 @@
   work->opTimeout_ = options_->timeout;
   work->store_ = store_;
 
-  if (coalescing_state_ & CoalColl) {
+  c10::cuda::CaptureStatus capture_status =
+      c10::cuda::currentStreamCaptureStatusMayInitCtx();
+
+  if ((coalescing_state_ & CoalColl) &&
+      capture_status == c10::cuda::CaptureStatus::None) {
     workEnqueue(work);
     // TODO: it seems we never enqueue work for single send/recv or batch P2P,
     // see the `pointToPoint` function. This should be fixed. Otherwise, we risk
@@ -1460,7 +1463,9 @@
     PostProcess post,
     OpType opType,
     const char* profilingTitle) {
-  errorIfCapturingNonCapturableNCCL();
+  c10::cuda::CaptureStatus capture_status =
+      c10::cuda::currentStreamCaptureStatusMayInitCtx();
+  errorIfCapturingNonCapturableNCCL(capture_status);
 
   // Bump collective counter
   seq_++;
@@ -1603,7 +1608,7 @@
   work->numelIn_ = inputs[0].numel();
   work->numelOut_ = outputs[0].numel();
 
-  if (!coalescing_state_) {
+  if (!coalescing_state_ && capture_status == c10::cuda::CaptureStatus::None) {
     workEnqueue(work);
   }