[C10D] Fix pointToPoint op Flight Recording (#120270)
Fix and test issues with both coalesced and individual send/recv ops
Considered an alternate approach and then ditched it
- alternate approach: #119757
- reason ditched: prefer recording individual collective events inside
coalescing region instead of just the event at the end of the region,
which also would not have tensor sizes or opnames without additional
state variables added
Another approach also ditched
- record events on workEnqueue instead of initWork
- reason ditched: too messy to get input/output shapes tagged on
recording when recording in workEnqueue. Adding the info onto the
Work obj would be possible, but adds to overhead of copying Works
which we do on every collective. We can get info off the input/output
tensors directly in initWork, but we don't want to keep refs to those
tensors alive while the work is Enqueued, so we'd have to specifically
copy size lists or something.
This PR instead avoids creating a work inside pointToPoint when
coalescing is active. Instead, only at endCoalescing() is a work finally
intialized and enqueued. But it adds a record() call inside
pointToPoint() instead of creating a work, during coalescing. This
record() call picks up tensor shapes and op names.
It ALSO changes initWork to accept a 'record' argument. This defaults to
false, and should only be set to true if the caller ensures the work
will be enqueued by workEnqueue, ensuring its cuda events are live when
used by flight recorder's update_state().
The testing uncovers some odd pre-existing behavior and leaves them
alone for now. We could change some of these
- seq starts off at 1, not 0 for first op (but this is inconistent)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120270
Approved by: https://github.com/shuqiangzhang
ghstack dependencies: #120724
diff --git a/test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp b/test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp
index 2201b6a..991a8c5 100644
--- a/test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp
+++ b/test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp
@@ -66,7 +66,8 @@
c10d::OpType opType,
const char* profilingTitle,
const std::vector<at::Tensor>& inputs = {},
- const std::vector<at::Tensor>& outputs = {}) override {
+ const std::vector<at::Tensor>& outputs = {},
+ bool record = false) override {
return c10::make_intrusive<WorkNCCLSimulateErrors>(
device, simulateError_, rank, opType, seq_);
}
@@ -127,7 +128,8 @@
c10d::OpType opType,
const char* profilingTitle,
const std::vector<at::Tensor>& inputs = {},
- const std::vector<at::Tensor>& outputs = {}) override {
+ const std::vector<at::Tensor>& outputs = {},
+ bool record = false) override {
return c10::make_intrusive<WorkNCCLTimedoutErrors>(
device, setTimedoutError_, rank, opType, seq_);
}
diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py
index b06d638..0724ad4 100644
--- a/test/distributed/test_c10d_nccl.py
+++ b/test/distributed/test_c10d_nccl.py
@@ -3997,7 +3997,7 @@
def setUp(self):
super().setUp()
os.environ["TORCH_NCCL_ENABLE_TIMING"] = '0' # see 'timing_enabled' parametrized tests
- os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = '10'
+ os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = '1000'
os.environ["TORCH_NCCL_DUMP_ON_TIMEOUT"] = '1'
self.tempdir = tempfile.TemporaryDirectory()
os.environ["TORCH_NCCL_DEBUG_INFO_TEMP_FILE"] = self._trace_basename()
@@ -4159,6 +4159,7 @@
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
def test_long(self):
+ os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = '10'
if self.rank == self.MAIN_PROCESS_RANK:
return
pg = self._create_process_group_nccl()
@@ -4287,6 +4288,132 @@
pg.allreduce(a).wait()
torch.cuda.synchronize(device=device)
+ @requires_nccl()
+ @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
+ @parametrize("op_sizes_per_coalesce", [
+ [(2, 3)],
+ [(2, 3), (5, 5), (1,)],
+ ])
+ @parametrize("timing_enabled", [True, False])
+ def test_batched_send_recv(self, op_sizes_per_coalesce, timing_enabled):
+ """
+ 'WorkEnqueue' was skipped for isendirecv, leading to segfault on dump_entries when update_state tried to use
+ a destructed Work obj's cuda events
+ """
+
+ if self.rank == self.MAIN_PROCESS_RANK:
+ return
+ pg = self._create_process_group_nccl()
+ if timing_enabled:
+ pg._enable_collectives_timing()
+
+ num_coalesced_ops = 20
+ ops_per_coalesce = len(op_sizes_per_coalesce)
+ for i in range(num_coalesced_ops):
+ ops = []
+ for input_sizes in op_sizes_per_coalesce:
+ tensor = torch.zeros(input_sizes).to(self.local_device)
+ if self.rank == 0:
+ ops.append(dist.P2POp(dist.irecv, tensor, 1))
+ elif self.rank == 1:
+ tensor *= 2
+ ops.append(dist.P2POp(dist.isend, tensor, 0))
+
+ dist.batch_isend_irecv(ops).pop().wait()
+
+ torch.cuda.synchronize()
+
+ if timing_enabled:
+ # wait for watchdog thread to process the queue of works
+ time.sleep(1)
+
+ t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
+ self.assertEqual(len(t['entries']), num_coalesced_ops * (ops_per_coalesce + 1))
+ expected_record_id = 0
+ for seq in range(num_coalesced_ops):
+ first_op = seq * (ops_per_coalesce + 1)
+ coalesced_op = first_op + ops_per_coalesce
+ expected_seq = seq + 1
+ for p2p_op_idx, input_sizes in zip(range(first_op, coalesced_op, 1), op_sizes_per_coalesce):
+ # the indivudal ops inside the coalescing group the individual op metadata,
+ # but not the timing info coming from the actual coalesced kernel
+ profiling_name = 'nccl:recv 0<-1' if self.rank == 0 else 'nccl:send 1->0'
+ self.assertEqual(t['entries'][p2p_op_idx]['record_id'], expected_record_id)
+ expected_record_id += 1
+ self.assertEqual(t['entries'][p2p_op_idx]['profiling_name'], profiling_name)
+ self.assertEqual(t['entries'][p2p_op_idx]['seq_id'], expected_seq)
+ self.assertEqual(t['entries'][p2p_op_idx]['input_sizes'], [input_sizes])
+ self.assertEqual(t['entries'][p2p_op_idx]['output_sizes'], [input_sizes])
+ # duration doesn't get tagged onto individual ops yet, nor is their state updated
+ self.assertEqual(t['entries'][p2p_op_idx]['state'], 'scheduled')
+ self.assertTrue('duration_ms' not in t['entries'][p2p_op_idx])
+
+ # the coalesced op has no metadata but indicates that coalescing was used,
+ # and accurately reflects the timing and state info for the whole group
+ self.assertEqual(t['entries'][coalesced_op]['record_id'], expected_record_id)
+ expected_record_id += 1
+ self.assertEqual(t['entries'][coalesced_op]['profiling_name'], 'nccl:coalesced')
+ self.assertEqual(t['entries'][coalesced_op]['seq_id'], expected_seq)
+ self.assertEqual(t['entries'][coalesced_op]['state'], 'completed')
+ self.assertEqual(t['entries'][coalesced_op]['input_sizes'], [])
+ self.assertEqual(t['entries'][coalesced_op]['output_sizes'], [])
+
+ if timing_enabled:
+ duration = t['entries'][coalesced_op]['duration_ms']
+ self.assertTrue(0.001 < duration < 10000, duration)
+ else:
+ self.assertTrue('duration_ms' not in t['entries'][coalesced_op])
+
+ @parametrize("op_sizes", [
+ [(2, 3)],
+ [(2, 3), (5, 5), (1,)],
+ ])
+ @parametrize("timing_enabled", [True, False])
+ def test_individual_send_recv(self, op_sizes, timing_enabled):
+ """
+ 'WorkEnqueue' was skipped for isendirecv, leading to segfault on dump_entries when update_state tried to use
+ a destructed Work obj's cuda events
+ """
+
+ if self.rank == self.MAIN_PROCESS_RANK:
+ return
+ pg = self._create_process_group_nccl()
+ if timing_enabled:
+ pg._enable_collectives_timing()
+ num_repeats = 10
+ ops_per_repeat = len(op_sizes)
+ for i in range(num_repeats):
+ for input_sizes in op_sizes:
+ tensor = torch.zeros(input_sizes).to(self.local_device)
+ if self.rank == 0:
+ dist.recv(tensor, 1)
+ elif self.rank == 1:
+ tensor *= 2
+ dist.send(tensor, 0)
+
+ torch.cuda.synchronize()
+ if timing_enabled:
+ # wait for watchdog thread to process the queue of works
+ time.sleep(1)
+
+ t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
+ self.assertEqual(len(t['entries']), num_repeats * (ops_per_repeat))
+ for seq in range(num_repeats * ops_per_repeat):
+ input_sizes = op_sizes[seq % ops_per_repeat]
+ profiling_name = 'nccl:recv 0<-1' if self.rank == 0 else 'nccl:send 1->0'
+ expected_seq = seq + 1
+ self.assertEqual(t['entries'][seq]['profiling_name'], profiling_name)
+ self.assertEqual(t['entries'][seq]['seq_id'], expected_seq)
+ self.assertEqual(t['entries'][seq]['input_sizes'], [input_sizes])
+ self.assertEqual(t['entries'][seq]['output_sizes'], [input_sizes])
+ self.assertEqual(t['entries'][seq]['state'], 'completed')
+
+ if timing_enabled:
+ duration = t['entries'][seq]['duration_ms']
+ self.assertTrue(0.001 < duration < 10000, duration)
+ else:
+ self.assertTrue('duration_ms' not in t['entries'][seq])
+
class NCCLTraceTestDumpOnTimeoutBase(NCCLTraceTestBase):
timeout_sec = 1
diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
index 9b20076..1151163 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
@@ -2155,7 +2155,8 @@
OpType opType,
const char* profilingTitle,
const std::vector<at::Tensor>& inputs,
- const std::vector<at::Tensor>& outputs) { // TODO(kwen2501): necessary?
+ const std::vector<at::Tensor>& outputs, // TODO(kwen2501): necessary?
+ bool record) {
auto r = c10::make_intrusive<ProcessGroupNCCL::WorkNCCL>(
device,
rank,
@@ -2167,15 +2168,29 @@
desyncDebug_,
enableTiming_.load(),
dist_debug_level_);
- r->trace_id_ = NCCLTraceBuffer::get()->record(
- uid_,
- seq_,
- // create a string copy of profilingTitle
- profilingTitle ? profilingTitle : "",
- inputs,
- outputs,
- r->ncclStartEvent_.get(),
- r->ncclEndEvent_.get());
+ if (record) {
+ // Ideally record every work that we enqueue, rather than every work we
+ // create.
+ // - at the time of this PR we do not currently enqueue every created work
+ // - but it is unsafe to steal refs to start/end cuda events from Works that
+ // may go out of scope before flight recorder has retired them,
+ // so we must ensure that any work that is initialized via initWork will
+ // be enqueued
+ // - initially, moved record() into workEnqueue(), but found that makes it
+ // hard to get access to profilingTitle,
+ // inputs, and outputs for metadata recording, and we don't want to attach
+ // these objects to the Work becuase it has implications for keeping those
+ // tensors alive longer and adds overhead when copying Work objects
+ // between threads
+ r->trace_id_ = NCCLTraceBuffer::get()->record(
+ uid_,
+ seq_,
+ profilingTitle ? profilingTitle : "",
+ inputs,
+ outputs,
+ r->ncclStartEvent_.get(),
+ r->ncclEndEvent_.get());
+ }
return r;
}
@@ -2229,6 +2244,15 @@
coalescedComms_.clear();
coalescing_state_ |= CoalActive;
groupStart();
+ // Other collective ops bump seq_ before creating a work. Thus, if coalesced
+ // ops bump seq_ only after initing a work they will collide with (reuse) the
+ // seq_ of the last non-coalesced collective. Previously, seq_ was bumped
+ // inside endCoalescing, but before initWork. Since we now record individual
+ // ops from a coalesce group into the flight recorder, we want to have the
+ // same seq_ for those ops and its 'endCoalescing' op. Hence we bump during
+ // start, which has one minor downside- we burn a seq_ if someone ever does a
+ // 'start' and 'end' coalescing region without doing an operation inbetween.
+ seq_++;
}
// `optype` is for specifying a composite optype, such as ALLGATHER and
@@ -2237,6 +2261,7 @@
if (coalescedComms_.size() == 0) {
// There is no actual work being coalesced, return here
groupEnd();
+ coalescing_state_ = 0;
return nullptr;
}
@@ -2249,16 +2274,19 @@
const auto key = getKeyFromDevice(device);
auto ncclStream = ncclStreams_.at(key);
- // Bump collective counter
- seq_++;
-
// Create Work object
- auto work = initWork(device, rank_, optype, "nccl:coalesced");
+ c10::cuda::CaptureStatus capture_status =
+ c10::cuda::currentStreamCaptureStatusMayInitCtx();
+ bool enqueue =
+ (coalescing_state_) && capture_status == c10::cuda::CaptureStatus::None;
+ auto work =
+ initWork(device, rank_, optype, "nccl:coalesced", {}, {}, enqueue);
work->ncclComm_ = comm;
work->blockingWait_ = blockingWait_;
work->avoidRecordStreams_ = avoidRecordStreams_;
work->opTimeout_ = options_->timeout;
work->store_ = store_;
+
// Record start before ncclGroupEnd
if (work->timingEnabled_) {
work->ncclStartEvent_->record(ncclStream);
@@ -2279,18 +2307,12 @@
work->stashed_for_allocator_safety_ =
std::make_shared<std::vector<at::Tensor>>();
}
- c10::cuda::CaptureStatus capture_status =
- c10::cuda::currentStreamCaptureStatusMayInitCtx();
// Notify graphs before we check the capture status preemptively
at::cuda::CUDAGraph::inc_pending_event_queries();
- if ((coalescing_state_ & CoalColl) &&
- capture_status == c10::cuda::CaptureStatus::None) {
+ if (enqueue) {
workEnqueue(work);
- // TODO: it seems we never enqueue work for single send/recv or batch P2P,
- // see the `pointToPoint` function. This should be fixed. Otherwise, we risk
- // not being able to abort hanged P2P ops.
} else {
at::cuda::CUDAGraph::dec_pending_event_queries();
}
@@ -2342,7 +2364,10 @@
std::vector<at::Tensor> inputs{input};
std::vector<at::Tensor> outputs{output};
- auto work = initWork(device, rank_, opType, profilingTitle, inputs, outputs);
+ bool enqueue =
+ !coalescing_state_ && capture_status == c10::cuda::CaptureStatus::None;
+ auto work =
+ initWork(device, rank_, opType, profilingTitle, inputs, outputs, enqueue);
// Store references to outputs to be used by WorkNCCL::result and operator<<.
work->outputs_ =
@@ -2439,8 +2464,7 @@
// Notify graphs before we check the capture status preemptively
at::cuda::CUDAGraph::inc_pending_event_queries();
-
- if (!coalescing_state_ && capture_status == c10::cuda::CaptureStatus::None) {
+ if (enqueue) {
workEnqueue(work);
} else {
at::cuda::CUDAGraph::dec_pending_event_queries();
@@ -2483,7 +2507,8 @@
// First let NCCL streams wait for input tensors allocation streams
syncStream(device, ncclEvents_[key], ncclStream);
- auto work = initWork(device, rank_, opType, nullptr, inputs, outputs);
+ auto work = initWork(
+ device, rank_, opType, nullptr, inputs, outputs, /*record=*/true);
// Store references to outputs to be used by WorkNCCL::result and operator<<.
work->outputs_ = std::make_shared<std::vector<at::Tensor>>(outputs);
@@ -2643,10 +2668,14 @@
p2pRank = rank_ <= peer ? 0 : 1;
isSendRecvSelf = rank_ == peer;
p2pTargetRank = isSendRecvSelf ? 0 : 1 - p2pRank;
- // Bump sequence number. Don't do so if it's a batch P2P, it will be bumped
- // in `endCoalescing`.
- seq_++;
+
+ if (!coalescing_state_) {
+ // Bump sequence number. Don't do so if it's a batch P2P, it will be
+ // bumped in `endCoalescing`.
+ seq_++;
+ }
}
+
auto ncclComm = getNCCLComm(key, device, opType, p2pRank, isSendRecvSelf);
if (coalescing_state_ & CoalActive) {
@@ -2661,23 +2690,57 @@
syncStream(device, ncclEvents_[key], ncclStream);
// Work itself will create the CUDA events on all GPUs of tensors
- auto work = initWork(device, rank_, opType, profilingTitle, {tensor}, {});
+ c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> work;
+ if (coalescing_state_) {
+ // When coalescing, we record events per op that lack timing/state
+ // information becuase there is no 'work' associated with them, and then
+ // later in endCoalescing we record a 'coalesced' Work which has
+ // timing/state updates via watchdog thread, but lacks op metadata such as
+ // input/output sizes and profilingTitle per-op in the group.
+ auto trace_id = NCCLTraceBuffer::get()->record(
+ uid_, seq_, profilingTitle, {tensor}, {tensor}, nullptr, nullptr);
+ // TODO(whc) if we want to make the per-p2p-op flightrecorder entries get
+ // their timings/states updated by proxy when the Work obj representing the
+ // coalesce group gets its update, we could accumulate these trace_ids
+ // together and ask FlightRecorder to take the update from one Work and
+ // apply it to multiple entries
+ (void)trace_id;
+ } else {
+ // Store references to outputs to be used by WorkNCCL::result and
+ // operator<<. Note that these outputs are only valid for recv(), as send()
+ // does not modify the inputs but we still create these outputs for use
+ // cases such as profiling.
- // Store references to outputs to be used by WorkNCCL::result and operator<<.
- // Note that these outputs are only valid for recv(), as send() does not
- // modify the inputs but we still create these outputs for use cases such as
- // profiling.
- work->outputs_ = std::make_shared<std::vector<at::Tensor>>();
- work->outputs_->push_back(tensor);
-
- at::cuda::OptionalCUDAGuard gpuGuard;
-
- // Start event should only be recorded before the ncclGroupStart()
- if (work->timingEnabled_) {
- work->ncclStartEvent_->record(ncclStream);
+ work = initWork(
+ device, rank_, opType, profilingTitle, {tensor}, {}, /*record=*/false);
+ // This bypasses something in Work() that crashes if {tensor} is given as
+ // output, not sure what
+ work->outputs_ = std::make_shared<std::vector<at::Tensor>>();
+ work->outputs_->push_back(tensor);
+ // TODO(whc) becuase we don't pass output {tensor} to initWork, we tell
+ // initWork to not record, and then we manually call record passing all the
+ // information it wants.
+ work->trace_id_ = NCCLTraceBuffer::get()->record(
+ uid_,
+ seq_,
+ profilingTitle,
+ {tensor},
+ {tensor},
+ work->ncclStartEvent_.get(),
+ work->ncclEndEvent_.get());
}
- pre(ncclStream, work);
+ // is gpuGuard needed for the if block below, or can i swap them
+ at::cuda::OptionalCUDAGuard gpuGuard;
+
+ if (!coalescing_state_) {
+ // Start event should only be recorded before the ncclGroupStart()
+ if (work->timingEnabled_) {
+ work->ncclStartEvent_->record(ncclStream);
+ }
+
+ pre(ncclStream, work);
+ }
// Both send tensor and recv tensor are created on a worker stream and used
// in different ncclStreams. Hence, both must record the ncclStream to
@@ -2687,7 +2750,9 @@
c10::cuda::CUDACachingAllocator::recordStream(
tensor.storage().data_ptr(), ncclStream);
+ // This part seems common to both p2p and coalesced-p2p usage?
ncclComm_t comm_ = ncclComm->getNcclComm();
+
#ifndef NCCL_HAS_COMM_NONBLOCKING
C10D_NCCL_CHECK(
fn(tensor, comm_, ncclStream, p2pTargetRank),
@@ -2699,43 +2764,43 @@
ncclComm->getNcclCommFailureReason());
#endif
- post(ncclStream);
-
- // End event should only be recorded after the ncclGroupEnd()
if (!coalescing_state_) {
+ post(ncclStream);
+
+ // End event should only be recorded after the ncclGroupEnd()
work->ncclEndEvent_->record(ncclStream);
- }
- work->ncclComm_ = ncclComm;
- work->blockingWait_ = blockingWait_;
- work->opTimeout_ = options_->timeout;
- work->store_ = store_;
- // Record size info for debug. We only record the size on the first device as
- // multi-device per process is deprecated
- work->numelIn_ = work->numelOut_ = tensor.numel();
+ work->ncclComm_ = ncclComm;
+ work->blockingWait_ = blockingWait_;
+ work->opTimeout_ = options_->timeout;
+ work->store_ = store_;
+ // Record size info for debug. We only record the size on the first device
+ // as multi-device per process is deprecated
+ work->numelIn_ = work->numelOut_ = tensor.numel();
- // Future only needs to be created and marked completed with outputs for
- // recv(), but still create future for use cases such as profiling even for
- // send().
- {
- c10::cuda::CUDAMultiStreamGuard streamGuard(ncclStream);
- std::vector<at::Device> devices{device};
- work->future_ = c10::make_intrusive<at::ivalue::Future>(
- c10::ListType::create(c10::TensorType::get()), devices);
- work->future_->markCompleted(at::IValue(*work->outputs_));
- }
+ // Future only needs to be created and marked completed with outputs for
+ // recv(), but still create future for use cases such as profiling even for
+ // send().
+ {
+ c10::cuda::CUDAMultiStreamGuard streamGuard(ncclStream);
+ std::vector<at::Device> devices{device};
+ work->future_ = c10::make_intrusive<at::ivalue::Future>(
+ c10::ListType::create(c10::TensorType::get()), devices);
+ work->future_->markCompleted(at::IValue(*work->outputs_));
+ }
- // Add a callback that runs profiling end callbacks. wrapCallback() in CUDA
- // future blocks the stream this callback runs on the corresponding
- // ncclEndEvents_ ensuring appropriate synchronization.
- if (work->recordFunctionEndCallback_) {
- work->future_->addCallback(
- [work](at::ivalue::Future& /* unused */) {
- work->recordFunctionEndCallback_();
- },
- // uses_future = false allows us to skip synchronization in
- // ivalue::Future, but is only valid as long as the lambda doesn't use
- // the "Future" argument.
- /*uses_future=*/false);
+ // Add a callback that runs profiling end callbacks. wrapCallback() in CUDA
+ // future blocks the stream this callback runs on the corresponding
+ // ncclEndEvents_ ensuring appropriate synchronization.
+ if (work->recordFunctionEndCallback_) {
+ work->future_->addCallback(
+ [work](at::ivalue::Future& /* unused */) {
+ work->recordFunctionEndCallback_();
+ },
+ // uses_future = false allows us to skip synchronization in
+ // ivalue::Future, but is only valid as long as the lambda doesn't use
+ // the "Future" argument.
+ /*uses_future=*/false);
+ }
}
// Enqueue P2P op so that it can be cancelled by NCCL watchdog
@@ -2747,11 +2812,11 @@
if (!coalescing_state_ && capture_status == c10::cuda::CaptureStatus::None) {
workEnqueue(work);
+ return work;
} else {
at::cuda::CUDAGraph::dec_pending_event_queries();
+ return nullptr;
}
-
- return work;
}
template <typename Fn>
diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
index 6199a1a..8485441 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
@@ -594,13 +594,16 @@
virtual std::exception_ptr checkForNCCLErrors(
std::shared_ptr<NCCLComm>& ncclComm);
+ // Ensure thaht if record is True, the work obj will be enqueued via
+ // workEnqueue
virtual c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> initWork(
at::Device& device,
int rank,
OpType opType,
const char* profilingTitle = nullptr,
const std::vector<at::Tensor>& inputs = {},
- const std::vector<at::Tensor>& outputs = {});
+ const std::vector<at::Tensor>& outputs = {},
+ bool record = false);
// In the timeout case and we will dump debug info such as the NCCL flight
// recorder to storage. Down the road, if we have more complicated or blocking