[2/N] Add flag to control which rank should perform NaN check (#134345)

Fixes https://github.com/pytorch/pytorch/issues/134062.
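For example, in the case of broadcast / scatter, only the root rank should perform the NaN check, because the buffers on the other ranks are receive placeholders whose contents may be arbitrary. Likewise, for point-to-point ops only the send side is checked.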

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134345
Approved by: https://github.com/shuqiangzhang, https://github.com/wconstab
ghstack dependencies: #134300
diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py
index 64c51f7..55b785a 100644
--- a/test/distributed/test_c10d_nccl.py
+++ b/test/distributed/test_c10d_nccl.py
@@ -367,6 +367,28 @@
         os.environ["TORCH_NCCL_NAN_CHECK"] = "0"
 
     @requires_nccl()
+    @skip_if_lt_x_gpu(2)
+    def test_nan_p2p(self):
+        # Put a NaN in the recv buffer: the program should not fail, because
+        # the NaN checker must not inspect receive buffers
+        os.environ["TORCH_NCCL_NAN_CHECK"] = "1"
+        store = c10d.FileStore(self.file_name, self.world_size)
+        device = torch.device("cuda:%d" % self.rank)
+        c10d.init_process_group(
+            backend="nccl", store=store, rank=self.rank, world_size=self.world_size
+        )
+        t = torch.ones(3, 4, dtype=torch.bfloat16, device=device)
+        if self.rank == 0:
+            c10d.send(t, 1)
+        elif self.rank == 1:
+            # Putting NaN at recv buffer
+            t[1, 1] = float("nan")
+            c10d.recv(t, 0)
+        c10d.destroy_process_group()
+        # reset env
+        os.environ["TORCH_NCCL_NAN_CHECK"] = "0"
+
+    @requires_nccl()
     @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
     def test_destruct_before_terminate_pg(self):
         # Disable ASYNC_ERROR_HANDLING for this test to ensure we can programmatically
diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
index abcf493..291eb79 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
@@ -2637,9 +2637,12 @@
     PostProcess post,
     OpType opType,
     const char* profilingTitle,
-    bool avoidRecordStreams) {
+    bool avoidRecordStreams,
+    bool nanCheck) {
   // Environment setting by the user may add onto collective call's option
   avoidRecordStreams |= avoidRecordStreams_;
+  nanCheck &= enableNanCheck_;
+
   c10::cuda::CaptureStatus capture_status =
       c10::cuda::currentStreamCaptureStatusMayInitCtx();
   errorIfCapturingNonCapturableNCCL(capture_status);
@@ -2693,7 +2696,7 @@
 
   at::cuda::OptionalCUDAGuard gpuGuard;
 
-  if (enableNanCheck_) {
+  if (nanCheck) {
     checkForNan(input, ncclStream);
   }
 
@@ -3126,7 +3129,9 @@
   // is gpuGuard needed for the if block below, or can i swap them
   at::cuda::OptionalCUDAGuard gpuGuard;
 
-  if (enableNanCheck_) {
+  // Only check for NaN for send ops, for recv ops `tensor` can be a random
+  // placeholder
+  if (enableNanCheck_ && opType == OpType::SEND) {
     checkForNan(tensor, ncclStream);
   }
 
@@ -3223,7 +3228,8 @@
     Fn fn,
     OpType opType,
     const char* profilingTitle,
-    bool avoidRecordStreams) {
+    bool avoidRecordStreams,
+    bool nanCheck) {
   return collective(
       input,
       output,
@@ -3234,7 +3240,8 @@
          c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {},
       opType,
       profilingTitle,
-      avoidRecordStreams);
+      avoidRecordStreams,
+      nanCheck);
 }
 
 template <typename Fn>
@@ -3484,6 +3491,9 @@
   // avoidRecordStreams_ note: collective() will stash tensors.
   bool avoidRecordStreams = avoidRecordStreams_ || (!opts.asyncOp);
 
+  const auto root = opts.rootRank + opts.rootTensor;
+  bool nanCheck = (root == rank_);
+
   return collective(
       tensor,
       tensor,
@@ -3491,7 +3501,6 @@
           at::Tensor& output,
           ncclComm_t comm,
           at::cuda::CUDAStream& stream) {
-        const auto root = opts.rootRank + opts.rootTensor;
         return ncclBcast(
             input.data_ptr(),
             input.numel(),
@@ -3502,7 +3511,8 @@
       },
       OpType::BROADCAST,
       "nccl:broadcast",
-      avoidRecordStreams);
+      avoidRecordStreams,
+      nanCheck);
 }
 
 // _broadcast_oop adds an out-of-place broadcast in PGNCCL
@@ -3522,6 +3532,9 @@
         "Tensor input and output of _broadcast_oop must have the same number of elements ");
   }
 
+  const auto root = opts.rootRank + opts.rootTensor;
+  bool nanCheck = (root == rank_);
+
   return collective(
       inputTensor,
       outputTensor,
@@ -3529,7 +3542,6 @@
           at::Tensor& output,
           ncclComm_t comm,
           at::cuda::CUDAStream& stream) {
-        const auto root = opts.rootRank + opts.rootTensor;
         return ncclBroadcast(
             input.data_ptr(),
             output.data_ptr(),
@@ -3540,7 +3552,9 @@
             stream.stream());
       },
       OpType::BROADCAST,
-      "nccl:_broadcast_oop");
+      "nccl:_broadcast_oop",
+      /*avoidRecordStreams=*/false,
+      nanCheck);
 }
 
 c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce(
@@ -4491,6 +4505,9 @@
   // inputs, which == inputTensors[0] on the root rank where it matters.
   bool avoidRecordStreams = avoidRecordStreams_ || (!opts.asyncOp);
 
+  const auto root = opts.rootRank;
+  bool nanCheck = (rank_ == root);
+
   return collective(
       outputTensor,
       inputs[0], // just to fit the collective interface
@@ -4498,7 +4515,6 @@
           at::Tensor& /* unused */,
           ncclComm_t comm,
           at::cuda::CUDAStream& stream) {
-        const auto root = opts.rootRank;
         if (getRank() == root) {
           if (!avoidRecordStreams) {
             for (auto input : inputs) {
@@ -4512,7 +4528,8 @@
       },
       OpType::SCATTER,
       "nccl:scatter",
-      avoidRecordStreams);
+      avoidRecordStreams,
+      nanCheck);
 }
 
 c10::intrusive_ptr<Work> ProcessGroupNCCL::recvAnysource(
diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
index 2ba68cd..b663515 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
@@ -762,7 +762,8 @@
       Fn fn,
       OpType opType,
       const char* profilingTitle = nullptr,
-      bool avoidRecordStreams = false);
+      bool avoidRecordStreams = false,
+      bool nanCheck = true);
 
   template <typename Fn, typename PreProcess, typename PostProcess>
   c10::intrusive_ptr<Work> collective(
@@ -773,7 +774,8 @@
       PostProcess post,
       OpType opType,
       const char* profilingTitle = nullptr,
-      bool avoidRecordStreams = false);
+      bool avoidRecordStreams = false,
+      bool nanCheck = true);
 
   template <typename Fn>
   c10::intrusive_ptr<Work> collectiveCoalesced(