[Dist profiling] Fix ProcessGroupNCCL collective profiling (#55204)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/55204
Implements a fix discussed offline with pritamdamia87: profiling end callbacks now run only after `CUDAFuture`'s wrapCallback() has ensured appropriate stream synchronization. Also re-enables the relevant distributed profiling tests that were previously disabled for ProcessGroupNCCL.
Note that the profiling infrastructure has moved toward primarily encouraging the use of torch.profiler and CUPTI to trace CUDA kernels; supporting distributed collectives in that flow will require further discussion with ilia-cher. In the meantime, this PR improves the usability of torch.autograd.profiler with respect to distributed collectives.
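For context, a minimal usage sketch of the flow this change targets (assumes a NCCL process group is already initialized and each rank has bound its CUDA device; the tensor shape is illustrative):

```python
import torch
import torch.distributed as dist
from torch.autograd import profiler

# Assumes dist.init_process_group("nccl", ...) has already run and
# torch.cuda.set_device(rank) has been called on this rank.
tensor = torch.ones(1024, device="cuda")
with profiler.profile(use_cuda=True) as prof:
    dist.all_reduce(tensor)
    torch.cuda.synchronize()

# The collective should now report sensible CUDA timings, since the profiling
# end callback fires only after the work's CUDAFuture has synchronized.
print(prof.key_averages().table(sort_by="cuda_time_total"))
```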
ghstack-source-id: 127357995
Test Plan: CI
Reviewed By: mrshenli
Differential Revision: D27491711
fbshipit-source-id: cec7703a4c5d59b5023b0aa8fef4c2e3fb8d37d0
diff --git a/torch/lib/c10d/ProcessGroupNCCL.cpp b/torch/lib/c10d/ProcessGroupNCCL.cpp
index 396ab00..2f823ed 100644
--- a/torch/lib/c10d/ProcessGroupNCCL.cpp
+++ b/torch/lib/c10d/ProcessGroupNCCL.cpp
@@ -1133,6 +1133,15 @@
work->future_ = c10::make_intrusive<at::cuda::CUDAFuture>(
c10::ListType::create(c10::TensorType::get()),
getIndicesOfDevices(devices));
+
+ // Add a callback that runs the profiling end callbacks. wrapCallback() in
+ // CUDAFuture blocks the stream this callback runs on with the corresponding
+ // cudaEvents_, ensuring appropriate synchronization before the callbacks run.
+ if (work->recordFunctionEndCallback_) {
+ work->future_->addCallback([work]() {
+ work->recordFunctionEndCallback_();
+ });
+ }
work->future_->markCompleted(at::IValue(*work->outputs_));
}
@@ -1141,17 +1150,6 @@
work->opTimeout_ = options_->timeout;
work->store_ = store_;
- if (work->recordFunctionEndCallback_) {
- // recordFunctionEndCallback_ is normally called in fininsh() function by
- // base class, but since finish is not called by WorkNCCL, we schedule this
- // function to be run when work is done. Note that addCallback() onto the
- // Work's CUDAFuture is not useful here, as it would just run the callback
- // inline.
- // Note when can_profile is false, profilingTitle is not provided and so,
- // recordFunctionEndCallback_ is not set.
- work->recordFunctionEndCallback_();
- }
-
if (asyncErrorHandling_) {
workEnqueue(work);
}
diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py
index 37fdfb9..832d1ab 100644
--- a/torch/testing/_internal/distributed/distributed_test.py
+++ b/torch/testing/_internal/distributed/distributed_test.py
@@ -97,6 +97,7 @@
CUDA_PROFILING_SUPPORTED_BACKENDS = [
dist.Backend.GLOO,
dist.Backend.MPI,
+ dist.Backend.NCCL,
]
# Allowlist of distributed backends where profiling is supported for p2p ops
@@ -1670,6 +1671,7 @@
)
@skip_if_no_gpu
def test_all_reduce_sum_cuda(self):
+ torch.cuda.set_device(self.rank)
group, group_id, rank = self._init_global_test()
rank_to_GPU = self._init_multigpu_helper()
self._test_all_reduce_helper(
@@ -1690,6 +1692,7 @@
)
@skip_if_no_gpu
def test_all_reduce_sum_cuda_async(self):
+ torch.cuda.set_device(self.rank)
group, group_id, rank = self._init_global_test()
rank_to_GPU = self._init_multigpu_helper()
self._test_all_reduce_helper(
@@ -1734,6 +1737,7 @@
)
@skip_if_no_gpu
def test_all_reduce_sum_cuda_complex(self):
+ torch.cuda.set_device(self.rank)
group, group_id, rank = self._init_global_test()
rank_to_GPU = self._init_multigpu_helper()
self._test_all_reduce_helper(
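For reference, a hedged sketch of the kind of check the re-enabled tests perform for NCCL; the event-name filter below is illustrative and not the exact assertion used by `_test_all_reduce_helper`:

```python
import torch
import torch.distributed as dist
from torch.autograd import profiler

# Mirrors the torch.cuda.set_device(self.rank) calls added above: each rank
# must bind to its own GPU before issuing NCCL collectives.
rank = dist.get_rank()
torch.cuda.set_device(rank)

tensor = torch.ones(10, device="cuda")
with profiler.profile(use_cuda=True) as prof:
    dist.all_reduce(tensor)
    torch.cuda.synchronize()

# Illustrative assertion: the collective shows up as a profiler event now that
# its end callback runs only after the work's CUDAFuture has synchronized.
events = [e for e in prof.function_events if "all_reduce" in e.name]
assert events, "expected an all_reduce event in the profiler output"
```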