[2/N] Add flag to control which rank should perform NaN check (#134345)
Fixes https://github.com/pytorch/pytorch/issues/134062.
For example, in the case of broadcast / scatter, only the root rank should perform the NaN check: on the other ranks the tensor is only a receive destination and may hold arbitrary placeholder data.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134345
Approved by: https://github.com/shuqiangzhang, https://github.com/wconstab
ghstack dependencies: #134300
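As an illustration (not part of the patch): with TORCH_NCCL_NAN_CHECK=1, a non-root rank may now hold arbitrary, even NaN, data in a buffer that is only a receive destination. A minimal sketch of the broadcast case, modeled on the new test_nan_p2p test below and assuming a 2-rank launch (e.g. torchrun --nproc-per-node=2):

    import os
    import torch
    import torch.distributed as c10d

    os.environ["TORCH_NCCL_NAN_CHECK"] = "1"
    c10d.init_process_group(backend="nccl")  # rank / world size come from the launcher env
    rank = c10d.get_rank()
    device = torch.device("cuda:%d" % rank)

    t = torch.ones(3, 4, dtype=torch.bfloat16, device=device)
    if rank != 0:
        # The non-root buffer is only a destination; a NaN placeholder here
        # should not abort, since only the root rank runs the NaN check.
        t[1, 1] = float("nan")
    c10d.broadcast(t, src=0)

    c10d.destroy_process_group()
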
diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py
index 64c51f7..55b785a 100644
--- a/test/distributed/test_c10d_nccl.py
+++ b/test/distributed/test_c10d_nccl.py
@@ -367,6 +367,28 @@
os.environ["TORCH_NCCL_NAN_CHECK"] = "0"
@requires_nccl()
+ @skip_if_lt_x_gpu(2)
+ def test_nan_p2p(self):
+ # Put NaN in the recv buffer; the program should not fail, as the NaN
+ # checker should not run on receive buffers
+ os.environ["TORCH_NCCL_NAN_CHECK"] = "1"
+ store = c10d.FileStore(self.file_name, self.world_size)
+ device = torch.device("cuda:%d" % self.rank)
+ c10d.init_process_group(
+ backend="nccl", store=store, rank=self.rank, world_size=self.world_size
+ )
+ t = torch.ones(3, 4, dtype=torch.bfloat16, device=device)
+ if self.rank == 0:
+ c10d.send(t, 1)
+ elif self.rank == 1:
+ # Put NaN in the recv buffer
+ t[1, 1] = float("nan")
+ c10d.recv(t, 0)
+ c10d.destroy_process_group()
+ # reset env
+ os.environ["TORCH_NCCL_NAN_CHECK"] = "0"
+
+ @requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
def test_destruct_before_terminate_pg(self):
# Disable ASYNC_ERROR_HANDLING for this test to ensure we can programmatically
diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
index abcf493..291eb79 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
@@ -2637,9 +2637,12 @@
PostProcess post,
OpType opType,
const char* profilingTitle,
- bool avoidRecordStreams) {
+ bool avoidRecordStreams,
+ bool nanCheck) {
// Environment setting by the user may add onto collective call's option
avoidRecordStreams |= avoidRecordStreams_;
+ nanCheck &= enableNanCheck_;
+
c10::cuda::CaptureStatus capture_status =
c10::cuda::currentStreamCaptureStatusMayInitCtx();
errorIfCapturingNonCapturableNCCL(capture_status);
@@ -2693,7 +2696,7 @@
at::cuda::OptionalCUDAGuard gpuGuard;
- if (enableNanCheck_) {
+ if (nanCheck) {
checkForNan(input, ncclStream);
}
@@ -3126,7 +3129,9 @@
// is gpuGuard needed for the if block below, or can i swap them
at::cuda::OptionalCUDAGuard gpuGuard;
- if (enableNanCheck_) {
+ // Only check for NaN on send ops; for recv ops, `tensor` can be a random
+ // placeholder
+ if (enableNanCheck_ && opType == OpType::SEND) {
checkForNan(tensor, ncclStream);
}
@@ -3223,7 +3228,8 @@
Fn fn,
OpType opType,
const char* profilingTitle,
- bool avoidRecordStreams) {
+ bool avoidRecordStreams,
+ bool nanCheck) {
return collective(
input,
output,
@@ -3234,7 +3240,8 @@
c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {},
opType,
profilingTitle,
- avoidRecordStreams);
+ avoidRecordStreams,
+ nanCheck);
}
template <typename Fn>
@@ -3484,6 +3491,9 @@
// avoidRecordStreams_ note: collective() will stash tensors.
bool avoidRecordStreams = avoidRecordStreams_ || (!opts.asyncOp);
+ const auto root = opts.rootRank + opts.rootTensor;
+ bool nanCheck = (root == rank_);
+
return collective(
tensor,
tensor,
@@ -3491,7 +3501,6 @@
at::Tensor& output,
ncclComm_t comm,
at::cuda::CUDAStream& stream) {
- const auto root = opts.rootRank + opts.rootTensor;
return ncclBcast(
input.data_ptr(),
input.numel(),
@@ -3502,7 +3511,8 @@
},
OpType::BROADCAST,
"nccl:broadcast",
- avoidRecordStreams);
+ avoidRecordStreams,
+ nanCheck);
}
// _broadcast_oop adds an out-of-place broadcast in PGNCCL
@@ -3522,6 +3532,9 @@
"Tensor input and output of _broadcast_oop must have the same number of elements ");
}
+ const auto root = opts.rootRank + opts.rootTensor;
+ bool nanCheck = (root == rank_);
+
return collective(
inputTensor,
outputTensor,
@@ -3529,7 +3542,6 @@
at::Tensor& output,
ncclComm_t comm,
at::cuda::CUDAStream& stream) {
- const auto root = opts.rootRank + opts.rootTensor;
return ncclBroadcast(
input.data_ptr(),
output.data_ptr(),
@@ -3540,7 +3552,9 @@
stream.stream());
},
OpType::BROADCAST,
- "nccl:_broadcast_oop");
+ "nccl:_broadcast_oop",
+ /*avoidRecordStreams=*/false,
+ nanCheck);
}
c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce(
@@ -4491,6 +4505,9 @@
// inputs, which == inputTensors[0] on the root rank where it matters.
bool avoidRecordStreams = avoidRecordStreams_ || (!opts.asyncOp);
+ const auto root = opts.rootRank;
+ bool nanCheck = (rank_ == root);
+
return collective(
outputTensor,
inputs[0], // just to fit the collective interface
@@ -4498,7 +4515,6 @@
at::Tensor& /* unused */,
ncclComm_t comm,
at::cuda::CUDAStream& stream) {
- const auto root = opts.rootRank;
if (getRank() == root) {
if (!avoidRecordStreams) {
for (auto input : inputs) {
@@ -4512,7 +4528,8 @@
},
OpType::SCATTER,
"nccl:scatter",
- avoidRecordStreams);
+ avoidRecordStreams,
+ nanCheck);
}
c10::intrusive_ptr<Work> ProcessGroupNCCL::recvAnysource(
diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
index 2ba68cd..b663515 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
@@ -762,7 +762,8 @@
Fn fn,
OpType opType,
const char* profilingTitle = nullptr,
- bool avoidRecordStreams = false);
+ bool avoidRecordStreams = false,
+ bool nanCheck = true);
template <typename Fn, typename PreProcess, typename PostProcess>
c10::intrusive_ptr<Work> collective(
@@ -773,7 +774,8 @@
PostProcess post,
OpType opType,
const char* profilingTitle = nullptr,
- bool avoidRecordStreams = false);
+ bool avoidRecordStreams = false,
+ bool nanCheck = true);
template <typename Fn>
c10::intrusive_ptr<Work> collectiveCoalesced(
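
For completeness, a similar sketch of the scatter case from the description (same 2-rank assumptions as above; only the root rank now runs the NaN check, so NaN placeholders in non-root output buffers no longer abort the run):

    import os
    import torch
    import torch.distributed as c10d

    os.environ["TORCH_NCCL_NAN_CHECK"] = "1"
    c10d.init_process_group(backend="nccl")
    rank, world = c10d.get_rank(), c10d.get_world_size()
    device = torch.device("cuda:%d" % rank)

    out = torch.ones(3, 4, dtype=torch.bfloat16, device=device)
    scatter_list = None
    if rank == 0:
        scatter_list = [
            torch.ones(3, 4, dtype=torch.bfloat16, device=device)
            for _ in range(world)
        ]
    else:
        # Non-root output buffer is just a placeholder; NaN here must not
        # trip the checker because the check is gated to the root rank.
        out[1, 1] = float("nan")
    c10d.scatter(out, scatter_list, src=0)

    c10d.destroy_process_group()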