[3/N] Set correct device to CUDA guards (#134357)

In `collective()`, `pointToPoint()` and `collectiveCoalesced()`, the CUDA guards were created with an unset (default) CUDA device. This is the cause of the illegal memory access (IMA) hit by the NaN checker in issue https://github.com/pytorch/pytorch/issues/134062.

With this fix, `torch.cuda.set_device(device)` is no longer needed as a workaround for the IMA.
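
For illustration, a minimal sketch of the user-facing effect (rank/world-size setup comes from the launcher, e.g. torchrun; names here are assumptions, not part of this PR): with `TORCH_NCCL_NAN_CHECK=1`, code like the following now runs without an explicit `torch.cuda.set_device(device)`.

```python
import os
import torch
import torch.distributed as dist

os.environ["TORCH_NCCL_NAN_CHECK"] = "1"
dist.init_process_group(backend="nccl")  # rank/world size taken from the launcher env
rank = dist.get_rank()
device = torch.device("cuda:%d" % rank)

# Previously an explicit torch.cuda.set_device(device) was needed here to
# avoid the IMA in the NaN checker; with this fix it is not.
t = torch.ones(3, 4, dtype=torch.bfloat16, device=device)
dist.all_reduce(t)
dist.destroy_process_group()
```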

Also refactored a couple of places where the guard is created: preferably, we construct the guard with a known device rather than setting the device afterwards.
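
A minimal sketch of the before/after guard pattern (mirrors the change in `ProcessGroupNCCL.cpp`; the function name and elided body are placeholders, not the actual code):

```cpp
#include <ATen/cuda/CUDAGuard.h>

// Illustrative only.
void collectiveBody(at::Device device) {
  // Before this PR the guard was default-constructed and the device was set
  // later via gpuGuard.set_index(device.index()); any kernel launched before
  // that point (e.g. the NaN check) ran on whatever device happened to be
  // current, which is the source of the IMA.
  //
  // After this PR the guard is constructed with the collective's device, so
  // the correct device is current for the entire scope.
  at::cuda::OptionalCUDAGuard gpuGuard(device);
  // ... NaN check and NCCL calls are issued here ...
}
```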

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134357
Approved by: https://github.com/wconstab, https://github.com/shuqiangzhang
ghstack dependencies: #134300, #134345
diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py
index 55b785a..3a57572 100644
--- a/test/distributed/test_c10d_nccl.py
+++ b/test/distributed/test_c10d_nccl.py
@@ -350,6 +350,7 @@
     @parametrize("type", [torch.float16, torch.float32, torch.float64, torch.bfloat16])
     @skip_if_rocm
     def test_nan_assert(self, type):
+        # Expecting a device-side error when NaN is detected
         os.environ["TORCH_NCCL_NAN_CHECK"] = "1"
         store = c10d.FileStore(self.file_name, self.world_size)
         pg = self._create_process_group_nccl(store, self.opts())
@@ -389,6 +390,24 @@
         os.environ["TORCH_NCCL_NAN_CHECK"] = "0"
 
     @requires_nccl()
+    @skip_if_lt_x_gpu(2)
+    def test_nan_check(self):
+        # Not expecting an error; the NaN check should not make legit code fail
+        os.environ["TORCH_NCCL_NAN_CHECK"] = "1"
+        store = c10d.FileStore(self.file_name, self.world_size)
+        device = torch.device("cuda:%d" % self.rank)
+        c10d.init_process_group(
+            backend="nccl", store=store, rank=self.rank, world_size=self.world_size
+        )
+        x = torch.ones((10,), dtype=torch.bfloat16, device=device) * self.rank
+        t = torch.ones(3, 4, dtype=torch.bfloat16, device=device)
+        c10d.broadcast(x, src=0)
+        c10d.all_reduce(t)
+        c10d.destroy_process_group()
+        # reset env
+        os.environ["TORCH_NCCL_NAN_CHECK"] = "0"
+
+    @requires_nccl()
     @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
     def test_destruct_before_terminate_pg(self):
         # Disable ASYNC_ERROR_HANDLING for this test to ensure we can programmatically
diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
index 291eb79..8f74118 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
@@ -1148,6 +1148,10 @@
     at::cuda::OptionalCUDAGuard gpuGuard;
     at::DeviceIndex deviceIndex = getIndexFromDeviceKey(devName);
     if (deviceIndex >= 0) {
+      // For P2P comms, deviceIndex could be -1 (invalid), because the keys
+      // in the map may not be device indices but rank-to-rank pairs, so we
+      // do need the deviceIndex >= 0 check above.
+      // TODO: fix `getIndexFromDeviceKey` or fix `DeviceKey`
       gpuGuard.set_index(deviceIndex);
     }
     LOG(INFO) << logPrefix() << "ProcessGroupNCCL destroying ncclComm_ "
@@ -2141,7 +2145,9 @@
   bool batchP2P = ncclActiveGroupCounter_ > 0;
   bool singleP2POp = isP2POp(opType, batchP2P);
 
-  at::cuda::OptionalCUDAGuard gpuGuard;
+  // Get the device index
+  auto deviceIndex = device.index();
+  at::cuda::OptionalCUDAGuard gpuGuard(device);
 
   // [Group Start/End Note] This is used to ensure that nccl communicator will
   // be created before communication primitives are called. Let's look at this
@@ -2181,10 +2187,6 @@
     rank = p2pRank;
   }
 
-  // Get the device index
-  auto deviceIndex = device.index();
-  gpuGuard.set_index(deviceIndex);
-
 #ifdef NCCL_HAS_COMM_SPLIT
   if (options_->split_from) {
     TORCH_CHECK(
@@ -2694,7 +2696,7 @@
     work->stashed_for_allocator_safety_->push_back(input);
   }
 
-  at::cuda::OptionalCUDAGuard gpuGuard;
+  at::cuda::OptionalCUDAGuard gpuGuard(device);
 
   if (nanCheck) {
     checkForNan(input, ncclStream);
@@ -2859,7 +2861,7 @@
         std::make_shared<std::vector<at::Tensor>>(inputs);
   }
 
-  at::cuda::OptionalCUDAGuard gpuGuard;
+  at::cuda::OptionalCUDAGuard gpuGuard(device);
 
   // Start event should only be recorded before the ncclGroupStart() (which
   // happens inside AutoNcclGroup guard below)
@@ -3127,7 +3129,7 @@
   }
 
   // is gpuGuard needed for the if block below, or can i swap them
-  at::cuda::OptionalCUDAGuard gpuGuard;
+  at::cuda::OptionalCUDAGuard gpuGuard(device);
 
   // Only check for NaN for send ops, for recv ops `tensor` can be a random
   // placeholder