#pragma once
#include <iostream>
#include <list>
#include <mutex>
#include <thread>
#include <unordered_map>
#include <c10d/NCCLUtils.hpp>
#include <c10d/ProcessGroup.hpp>
#include <c10d/Store.hpp>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAEvent.h>
#include <c10/core/Stream.h>
#include <c10/core/StreamGuard.h>
namespace c10d {
// Environment variable which controls whether wait() is blocking or
// non-blocking.
constexpr const char* NCCL_BLOCKING_WAIT = "NCCL_BLOCKING_WAIT";
// Environment variable which controls whether or not we perform Async Error
// Handling with NCCL.
constexpr const char* NCCL_ASYNC_ERROR_HANDLING = "NCCL_ASYNC_ERROR_HANDLING";
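//
// A minimal sketch of enabling both behaviors from the launching shell
// (assuming the usual convention that setting the variable to 1 enables it):
//
//   export NCCL_BLOCKING_WAIT=1
//   export NCCL_ASYNC_ERROR_HANDLING=1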
// ProcessGroupNCCL implements NCCL bindings for c10d.
//
// All functions of the class are expected to be called in the same order
// across all processes in the process group. This is the only way that we
// can guarantee to match up the same calls among all processes.
//
// All NCCL functions provided by this class are asynchronous functions. More
// specifically, each NCCL call is scheduled on a separate CUDA stream that is
// different from the current CUDA stream. This is for the purpose of
// potentially achieving concurrency and better performance. As a result, it
// is the caller's responsibility to make sure that the CUDA stream their
// code works on waits for the NCCL operation from this class.
//
// This can be done by calling:
//
// either WorkNCCL::wait() or WorkNCCL::synchronize(); both achieve the same
// functionality and are synonyms.
//
// Also note that WorkNCCL::finishedGPUExecution() is a helper function only
// provided by ProcessGroupNCCL to check if the NCCL operation of WorkNCCL has
// finished execution on the GPU (not just scheduled).
//
// Example of using the NCCL process group
//
// ProcessGroupNCCL pg(store, rank, size);
// std::shared_ptr<WorkNCCL> work = pg.allreduce(tensors);
//
// // At this point, the NCCL kernel has already been queued successfully.
// // Now, let the current stream wait for NCCL to finish; this function is
// // an async operation as well.
//
// work->wait();
//
// // Now continue on other work in the current stream.
class ProcessGroupNCCL : public ProcessGroup {
public:
class WorkNCCL : public ProcessGroup::Work,
public std::enable_shared_from_this<WorkNCCL> {
public:
// Constructor takes a list of CUDA devices
WorkNCCL(const std::vector<at::Device>& devices, int rank, OpType opType);
// Copy constructor that does a partial copy without outputs_. The cleanup
// thread monitors and removes finished works; however, it would deadlock
// when destructing outputs_ tensors that are view tensors in the autograd
// graph.
WorkNCCL(const WorkNCCL& w);
virtual ~WorkNCCL();
// Checks if request has completed. In this specific case of NCCL, it checks
// if the NCCL operation has completed on the GPU in its own NCCL stream.
// Non-blocking operation.
bool isCompleted() override;
bool isSuccess() const override;
// Same as calling synchronize() for NCCL work.
bool wait(std::chrono::milliseconds timeout = kNoTimeout) override;
void abort() override;
// Let the current stream wait on the completion of the NCCL work.
// Throws on exceptions. Blocking operation, which will wait for work
// completion.
void synchronize() override;
// Synchronize streams by blocking each on the NCCL stream
void synchronizeStreams();
// Helper function used in CUDA Stream callbacks to complete WorkNCCL
// objects and throw exceptions when needed.
void handleNCCLGuard();
// Helper function that checks if the NCCL kernels have finished
// execution on the GPUs
bool finishedGPUExecution();
// Get a Future object that will be marked as completed internally.
// It actually returns a FutureNCCL object, which is a subclass of Future.
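//
// A minimal usage sketch (assuming `work` was returned by a collective on
// this process group):
//
//   auto fut = work->getFuture();
//   fut->wait();  // makes the current stream wait on the NCCL stream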
c10::intrusive_ptr<c10::ivalue::Future> getFuture() override;
// Helper function that sets an exception_ptr on the WorkNCCL object.
void setException(std::exception_ptr exception_ptr);
// Helper function that returns true if the WorkNCCL object has timed out
// and false otherwise.
bool timedOut();
std::vector<at::Tensor> result() override;
protected:
// The cached list of CUDA devices to operate on
std::vector<at::Device> devices_;
// The CUDA events tracking this work item on multiple CUDA devices
std::shared_ptr<std::vector<at::cuda::CUDAEvent>> cudaEvents_;
// The NCCL communicators used for this work item.
std::vector<std::shared_ptr<NCCLComm>> ncclComms_;
// Tensors used for barrier op
std::vector<at::Tensor> barrierTensors_;
// Clone of blockingWait_ from ProcessGroupNCCL.
bool blockingWait_ = false;
// Clone of opTimeout_ from ProcessGroupNCCL.
std::chrono::milliseconds opTimeout_;
// Time point representing when the work started.
std::chrono::time_point<std::chrono::steady_clock> workStartTime_;
// Wrapper method for the static checkForNCCLErrors which can be overridden
// for tests.
virtual std::exception_ptr checkForNCCLErrors(
const std::vector<std::shared_ptr<NCCLComm>>& ncclComms) const;
friend std::ostream& operator<<(
std::ostream& output,
const WorkNCCL& workNCCL);
private:
// Helper function for synchronize
void synchronizeInternal(std::chrono::milliseconds timeout);
// Checks for NCCL errors and sets an appropriate exception_ptr.
void checkAndSetException();
// Checks for NCCL errors and throws an appropriate exception.
void checkAndThrowException();
// Just checks whether GPU execution has completed, without modifying
// exception_ptr.
bool finishedGPUExecutionInternal() const;
// Reference to the store so that we can write aborted communicators
// to the store.
std::shared_ptr<Store> store_;
// Store a reference to NCCL collective's outputs to be used by getFuture.
std::shared_ptr<std::vector<at::Tensor>> outputs_;
// Store streams that run FutureNCCL then callbacks.
std::vector<std::shared_ptr<at::cuda::CUDAStream>>
futureNCCLCallbackStreams_;
friend class ProcessGroupNCCL;
};
struct Options {
explicit Options();
std::chrono::milliseconds opTimeout;
bool isHighPriorityStream;
};
// FutureNCCL is a subclass of ivalue's Future. The goal is to use
// this class in the getFuture API of WorkNCCL. This Future is mostly a
// wrapper that synchronizes streams appropriately and enables the async
// programming model of CUDA while trying to adhere to the Future
// interface. FutureNCCL does not support the NCCL_BLOCKING_WAIT flag
// or NCCL's barrier().
//
// If created by WorkNCCL's getFuture API, FutureNCCL has a reference to
// WorkNCCL's cudaEvents, NCCL collective's outputs, device index of
// outputs' device, and the ProcessGroupNCCL's dedicated
// futureNCCLCallbackStream for outputs' device that runs all the then
// callbacks called from this FutureNCCL. Its value is NCCL collective's
// outputs. FutureNCCL only supports single-process single-device mode where
// the size of outputs is equal to 1.
//
// If created by FutureNCCL's then callback, its value becomes the value of
// callback() and its cudaEvents will record the NCCL stream that runs that
// callback. Before invoking the callback, FutureNCCL will synchronize its
// own cudaEvents with the stream that runs the callback. This design
// enables synchronizing the appropriate streams and avoids stalling PyTorch's
// default stream while running the callback. In case of multiple then
// callbacks, the design will work like a chain such that FutureNCCL n will
// wait on the cudaEvents from FutureNCCL n - 1. All callbacks are executed on
// outputs' device's dedicated futureNCCLCallbackStream.
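//
// A minimal sketch of chaining a callback (assuming `work` came from a
// single-device collective on this process group; the callback's return
// value becomes the chained future's value):
//
//   auto fut = work->getFuture();
//   auto chained = fut->then(
//       []() -> at::IValue { return at::IValue(1); },
//       c10::ListType::create(c10::TensorType::get()));
//   chained->wait();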
struct FutureNCCL : at::ivalue::Future {
public:
explicit FutureNCCL(
at::IValue value,
c10::DeviceIndex deviceIndex,
std::shared_ptr<std::vector<at::cuda::CUDAEvent>> cudaEvents,
std::shared_ptr<at::cuda::CUDAStream> futureNCCLCallbackStream)
: at::ivalue::Future(c10::ListType::create(c10::TensorType::get())),
value_(std::move(value)),
deviceIndex_(deviceIndex),
cudaEvents_(cudaEvents),
futureNCCLCallbackStream_(futureNCCLCallbackStream) {
TORCH_INTERNAL_ASSERT(
cudaEvents_->size() == 1,
"FutureNCCL only supports single-process single-device mode.");
}
// This constructor is used by the then callback; it skips setting a value
// at construction. Later, the value will be set using markCompleted with
// the return value of the callback.
explicit FutureNCCL(
c10::DeviceIndex deviceIndex,
std::shared_ptr<std::vector<at::cuda::CUDAEvent>> cudaEvents,
std::shared_ptr<at::cuda::CUDAStream> futureNCCLCallbackStream)
: at::ivalue::Future(c10::ListType::create(c10::TensorType::get())),
deviceIndex_(deviceIndex),
cudaEvents_(cudaEvents),
futureNCCLCallbackStream_(futureNCCLCallbackStream) {
TORCH_INTERNAL_ASSERT(
cudaEvents_->size() == 1,
"FutureNCCL only supports single-process single-device mode.");
}
// Gets the current stream of the device and makes it wait on the recorded
// CUDA events. It returns after enqueueing that synchronization, which
// preserves async CUDA execution; it does not wait for the entire
// operation to complete on the GPU.
void wait() override {
if (error_) {
throw *error_;
}
auto stream = at::cuda::getCurrentCUDAStream(deviceIndex_);
(*cudaEvents_)[0].block(stream);
}
// If FutureNCCL was created by FutureNCCL::then, its value would be empty
// initially. FutureNCCL::then will later use this method to set its value
// to the return value of the callback.
void markCompleted(at::IValue value) override {
TORCH_INTERNAL_ASSERT(
value_.isNone(),
"Attempting to set value of a FutureNCCL which has a value."
"FutureNCCL's value was internally set to NCCL collective's "
"outputs or the return value of the callback.");
value_ = std::move(value);
}
// Just returns FutureNCCL's value after wait returns.
at::IValue value() override {
TORCH_INTERNAL_ASSERT(hasValue(), "FutureNCCL's value is None.")
wait();
return value_;
}
const at::IValue& constValue() override {
TORCH_INTERNAL_ASSERT(hasValue(), "FutureNCCL's value is None.")
wait();
return value_;
}
// Adds a callback to FutureNCCL. It invokes the callback inline on the
// dedicated callback stream after synchronizing FutureNCCL's own
// cudaEvents with that stream.
void addCallback(std::function<void(void)> callback) override {
(*cudaEvents_)[0].block(*futureNCCLCallbackStream_);
c10::OptionalStreamGuard streamGuard{
c10::Stream(*futureNCCLCallbackStream_)};
callback();
}
// Adds a callback to FutureNCCL, and returns another FutureNCCL to hold
// the return value of the callback and new cudaEvents that recorded the
// stream that runs this callback.
c10::intrusive_ptr<Future> then(
std::function<at::IValue(void)> callback,
at::TypePtr /* unused */) override {
// Create a new cudaEvents object of size 1 that will record
// futureNCCLCallbackStream_ after callback and will be passed to the new
// FutureNCCL.
auto thenFutCudaEvents =
std::make_shared<std::vector<at::cuda::CUDAEvent>>(1);
// Create a FutureNCCL without setting a value.
auto fut = c10::make_intrusive<FutureNCCL>(
deviceIndex_, thenFutCudaEvents, futureNCCLCallbackStream_);
// Do not free the underlying data storage of value_ before its
// usage on futureNCCLCallbackStream_ finishes.
TORCH_INTERNAL_ASSERT(record_stream_cb_);
record_stream_cb_(value_, futureNCCLCallbackStream_->unwrap());
// Use the dedicated callback stream to run callback.
// We cannot move-capture the std::function into the lambda, because the
// template type for std::function cannot be deduced. Hence use std::bind
// to explicitly specify the types.
addCallback(std::bind(
[&](std::function<at::IValue(void)> cb) {
try {
fut->markCompleted(at::IValue(cb()));
// In case of chained then callback calls, thenFutCudaEvents
// records callback's stream.
(*thenFutCudaEvents)[0].record(*futureNCCLCallbackStream_);
} catch (const std::exception& e) {
fut->setError(std::current_exception());
}
},
std::move(callback)));
return fut;
}
// Queries the recorded CUDA events. Returns true if a FutureError was
// recorded or the entire operation has completed on the GPU.
bool completed() const override {
if (error_) {
return true;
}
// Check the work's corresponding CUDA event's status. cudaEventQuery
// returns cudaSuccess once the work is done and cudaErrorNotReady while it
// is still running; any other error is also treated as completed here.
auto ret = cudaEventQuery((*cudaEvents_)[0]);
return ret != cudaErrorNotReady;
}
bool hasValue() const override {
return !value_.isNone();
}
void setRecordStreamCallback(
std::function<void(const at::IValue&, const c10::Stream&)>
record_stream_cb) override {
record_stream_cb_ = std::move(record_stream_cb);
}
private:
at::IValue value_;
c10::DeviceIndex deviceIndex_;
std::shared_ptr<std::vector<at::cuda::CUDAEvent>> cudaEvents_;
std::shared_ptr<at::cuda::CUDAStream> futureNCCLCallbackStream_;
std::function<void(const at::IValue&, const c10::Stream&)>
record_stream_cb_;
c10::optional<FutureError> error_;
};
// If you wish to create multiple process groups, each with a potentially
// different rank and size, you can do so by passing a new store instance
// to each one. If you have only a single store object, you can
// use the `c10d::PrefixStore` to derive scoped instances.
// This is also what the Python API in torch.distributed does.
//
// The process group instance keeps a reference to the store because
// it may be used long after the constructor runs. In fact, the constructor
// doesn't create any NCCL communicators. A single NCCL communicator can
// only be used on a specific set of devices, so communicators are created
// on demand when a collective runs. If another collective is executed later,
// against a different set of devices, the process group creates another NCCL
// communicator. These NCCL communicators are cached and reused if possible.
//
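// A minimal sketch of deriving two scoped process groups from one store
// (assuming `store`, `rank`, and `size` are already set up):
//
//   auto pg1Store = std::make_shared<c10d::PrefixStore>("pg1", store);
//   ProcessGroupNCCL pg1(pg1Store, rank, size);
//   auto pg2Store = std::make_shared<c10d::PrefixStore>("pg2", store);
//   ProcessGroupNCCL pg2(pg2Store, rank, size);
//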
ProcessGroupNCCL(
const std::shared_ptr<Store>& store,
int rank,
int size,
Options options = Options());
// This constructor includes the deprecated `groupName` argument.
// If you have existing code that uses the `groupName`, you can replace
// it by specifying a `c10d::PrefixStore(groupName, store)` for store.
C10_DEPRECATED ProcessGroupNCCL(
const std::shared_ptr<Store>& store,
int rank,
int size,
const std::string& groupName,
Options options = Options())
: ProcessGroupNCCL(store, rank, size, options) {}
virtual ~ProcessGroupNCCL();
std::shared_ptr<ProcessGroup::Work> broadcast(
std::vector<at::Tensor>& tensors,
const BroadcastOptions& opts = BroadcastOptions()) override;
std::shared_ptr<ProcessGroup::Work> allreduce(
std::vector<at::Tensor>& tensors,
const AllreduceOptions& opts = AllreduceOptions()) override;
std::shared_ptr<ProcessGroup::Work> allreduce_coalesced(
std::vector<at::Tensor>& tensors,
const AllreduceCoalescedOptions& opts =
AllreduceCoalescedOptions()) override;
std::shared_ptr<ProcessGroup::Work> reduce(
std::vector<at::Tensor>& tensors,
const ReduceOptions& opts = ReduceOptions()) override;
std::shared_ptr<ProcessGroup::Work> allgather(
std::vector<std::vector<at::Tensor>>& outputTensors,
std::vector<at::Tensor>& inputTensors,
const AllgatherOptions& opts = AllgatherOptions()) override;
std::shared_ptr<ProcessGroup::Work> allgather_base(
at::Tensor& outputbuffer,
at::Tensor& inputbuffer,
const AllgatherOptions& opts = AllgatherOptions()) override;
std::shared_ptr<ProcessGroup::Work> allgather_coalesced(
std::vector<std::vector<at::Tensor>>& outputTensorLists,
std::vector<at::Tensor>& inputTensors,
const AllgatherOptions& opts = AllgatherOptions()) override;
std::shared_ptr<ProcessGroup::Work> reduce_scatter(
std::vector<at::Tensor>& outputTensors,
std::vector<std::vector<at::Tensor>>& inputTensors,
const ReduceScatterOptions& opts = ReduceScatterOptions()) override;
std::shared_ptr<ProcessGroup::Work> barrier(
const BarrierOptions& opts = BarrierOptions()) override;
std::shared_ptr<ProcessGroup::Work> alltoall_base(
at::Tensor& outputTensor,
at::Tensor& inputTensor,
std::vector<int64_t>& outputSplitSizes,
std::vector<int64_t>& inputSplitSizes,
const AllToAllOptions& opts = AllToAllOptions()) override;
std::shared_ptr<ProcessGroup::Work> alltoall(
std::vector<at::Tensor>& outputTensors,
std::vector<at::Tensor>& inputTensors,
const AllToAllOptions& opts = AllToAllOptions()) override;
std::shared_ptr<ProcessGroup::Work> send(
std::vector<at::Tensor>& tensors,
int dstRank,
int tag) override;
std::shared_ptr<ProcessGroup::Work> recv(
std::vector<at::Tensor>& tensors,
int srcRank,
int tag) override;
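// Batch multiple NCCL operations into a single NCCL group call. A minimal
// sketch of batching a send and a receive (the tensor and rank variables
// are assumed to be set up by the caller):
//
//   ProcessGroupNCCL::groupStart();
//   pg.send(sendTensors, dstRank, /*tag=*/0);
//   pg.recv(recvTensors, srcRank, /*tag=*/0);
//   ProcessGroupNCCL::groupEnd();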
static void groupStart();
static void groupEnd();
// Unsupported Ops
std::shared_ptr<ProcessGroup::Work> gather(
std::vector<std::vector<at::Tensor>>& outputTensors,
std::vector<at::Tensor>& inputTensors,
const GatherOptions& opts = GatherOptions()) override;
std::shared_ptr<ProcessGroup::Work> scatter(
std::vector<at::Tensor>& outputTensors,
std::vector<std::vector<at::Tensor>>& inputTensors,
const ScatterOptions& opts = ScatterOptions()) override;
std::shared_ptr<ProcessGroup::Work> recvAnysource(
std::vector<at::Tensor>& tensors,
int tag) override;
static const int64_t kProcessGroupNCCLOpTimeoutMillis;
protected:
// Helper that broadcasts nccl unique ID to all ranks through the store
void broadcastUniqueNCCLID(
ncclUniqueId* ncclID,
OpType opType,
const std::string& devicesKey,
int p2pRank);
// Helper that either looks up the cached NCCL communicators or creates
// a new set of NCCL communicators as a cache entry
std::vector<std::shared_ptr<NCCLComm>>& getNCCLComm(
const std::string& devicesKey,
const std::vector<at::Device>& devices,
OpType opType,
int p2pRank = 0,
bool isSendRecvSelf = false);
// Wrapper method which can be overridden for tests.
virtual std::exception_ptr checkForNCCLErrors(
const std::vector<std::shared_ptr<NCCLComm>>& ncclComms);
virtual std::shared_ptr<ProcessGroupNCCL::WorkNCCL> initWork(
std::vector<at::Device> devices,
int rank,
OpType opType);
private:
// Helper that encapsulates work shared across all collective communication
// primitives. The callbacks have the following signatures:
//
// ncclResult_t fn(at::Tensor& input, at::Tensor& output,
// ncclComm_t, at::cuda::CUDAStream&);
// void {pre,post}(std::vector<at::cuda::CUDAStream&>);
template <typename Fn>
std::shared_ptr<ProcessGroup::Work> collective(
std::vector<at::Tensor>& input,
std::vector<at::Tensor>& output,
Fn fn,
OpType opType);
template <typename Fn, typename PreProcess, typename PostProcess>
std::shared_ptr<ProcessGroup::Work> collective(
std::vector<at::Tensor>& input,
std::vector<at::Tensor>& output,
Fn fn,
PreProcess pre,
PostProcess post,
OpType opType);
// Helper that encapsulates work shared across point-to-point communication
// primitives. It has the same structure as the helper used for collective
// communication primitives.
template <typename Fn>
std::shared_ptr<ProcessGroup::Work> pointToPoint(
std::vector<at::Tensor>& tensor,
Fn fn,
int peer,
OpType opType);
template <typename Fn, typename PreProcess, typename PostProcess>
std::shared_ptr<ProcessGroup::Work> pointToPoint(
std::vector<at::Tensor>& tensor,
Fn fn,
int peer,
OpType opType,
PreProcess pre,
PostProcess post);
// Checks for NCCL errors on each of the communicators and returns an
// appropriate exception_ptr (nullptr if no errors).
static std::exception_ptr checkForNCCLErrorsInternal(
const std::vector<std::shared_ptr<NCCLComm>>& ncclComms);
// Function that runs as part of a separate thread and checks for errors on
// NCCL communicators. We need a separate thread to check for NCCL errors
// since we can't rely on the user calling certain methods like wait(),
// isCompleted() etc. to detect and remediate errors. In addition to this, we
// need a mechanism to safely abort and remove NCCL communicators from our
// cache. This can be done cleanly by having a thread for the ProcessGroupNCCL
// class. Attempting to modify the communicator cache from the WorkNCCL class
// might run into issues with object lifetime since the ProcessGroupNCCL
// object might get destroyed before the WorkNCCL object.
void ncclCommWatchdog();
void ncclCommWatchdogInternal();
// This function iterates through the list of WorkNCCL objects in the
// workList_ corresponding to incomplete collectives and then aborts NCCL
// communicators associated with timed out collectives.
void abortTimedOutCollectives(std::unordered_set<std::string>& abortedCommIds);
void workCleanupLoop();
protected:
static const int64_t kWatchdogThreadSleepMillis;
static const int64_t kWorkCleanupThreadSleepMillis;
// The store is used to broadcast the NCCL unique ID of rank 0.
std::shared_ptr<Store> store_;
// The number of NCCL communicators that have been created during
// the lifetime of this process group. This sequence number is
// used to scope keys used in the store.
uint64_t ncclCommCounter_{0};
// The NCCL communicator that the process group has cached.
//
// For collective operations:
// The key is a list of GPU devices that an operation is operating on.
// The GPU devices are stored as a device sequence and the cached NCCL
// communicator is associated with this GPU device sequence.
//
// e.g. If the process group op only uses device 0, then the device
// string stored as the key of the hashmap would be "0".
//
// If the process group op uses devices 0-7 and each tensor of the
// input tensor list is on devices 0, 1, 2, 3, 4, 5, 6, 7 respectively,
// then the device string (key) stored would be
// "0,1,2,3,4,5,6,7"
//
// If the process group op uses devices 0-7 and each tensor of the
// input tensor list is on devices 0, 4, 5, 6, 7, 1, 2, 3 respectively,
// then the device string stored would be
// "0,4,5,6,7,1,2,3"
//
// Note that the order of the device for the tensor list matters.
//
// For point-to-point operations:
// The key is a string of my current rank and the peer process rank.
// e.g. If process 1 and process 2 are involved in a point-to-point
// communication, the key will be "1:2" on both processes. Note: this is for
// the scenario where there is only 1 GPU per process. When it comes to
// multiple GPUs per process, this part may need to be redesigned.
std::unordered_map<std::string, std::vector<std::shared_ptr<NCCLComm>>>
devNCCLCommMap_;
// Map from ncclUniqueId to appropriate communicator.
std::unordered_map<std::string, std::vector<std::shared_ptr<NCCLComm>>>
ncclIdToCommMap_;
// Mutex to guard maps like devNCCLCommMap_ and ncclIdToCommMap_.
std::mutex mutex_;
// Watchdog thread which looks for errors on the cached NCCL communicators.
std::thread ncclCommWatchdogThread_;
// Whether or not we should terminate the watchdog and workCleanup threads.
std::atomic<bool> terminateProcessGroup_;
// Condition variable to control how long the watchdog thread waits.
std::condition_variable watchdogCV_;
// Mutex for watchdog.
std::mutex watchdogCVMutex_;
// Thread that removes NCCL Work upon timeout
std::thread workCleanupThread_;
// Mutex to guard workMetaList_
std::mutex workMetaListMutex_;
// Condition variable for timeout thread sleep
std::condition_variable workMetaListCV_;
// List of in-progress WorkNCCL objects monitored by the cleanup thread
std::list<ProcessGroupNCCL::WorkNCCL> workMetaList_;
// Add a WorkNCCL object to workMetaList_
void workEnqueue(std::shared_ptr<ProcessGroupNCCL::WorkNCCL>);
// The CUDA streams used by NCCL kernels
std::unordered_map<std::string, std::vector<at::cuda::CUDAStream>>
ncclStreams_;
// The CUDA events used to sync NCCL streams
std::unordered_map<std::string, std::vector<at::cuda::CUDAEvent>> ncclEvents_;
// Device Indexes used for all collectives in this group
std::set<int> usedDeviceIdxs_;
// Map from the key "group name + pg counter (ID)" to the unique NCCL ID
// count. This needs to be group- and pg-specific.
//
// For each process group, we need a uniform unique NCCL ID counter to
// ensure that NCCL operations in this process group can be completed
// successfully. Since each process group ID belongs to a group name, the
// key of this map is a combination of the group name and the
// ProcessGroupNCCL ID.
static std::unordered_map<std::string, ssize_t> pgUniqueNCCLIDCnt_;
// map from group name to the pg counter (ID) within that group
//
// For each group with the "group name" (which is the key), we need to
// keep track of a unique process group ID when creating a new
// ProcessGroupNCCL for this "group name". Therefore, the value of this
// map keeps the unique ProcessGroupNCCL's ID for a specific group with
// the "group name". The reason we need a per-group process group ID counter
// is that different groups can have different ranks and we need to ensure that
// each group has its own uniform process group ID for all its ranks.
static std::unordered_map<std::string, ssize_t> processGroupCounterMap_;
// Whether or not wait() and synchronize() are blocking operations that wait
// for the operation to complete.
bool blockingWait_ = false;
// Whether or not the workCleanupThread is used to perform async error
// handling.
bool asyncErrorHandling_ = false;
// Timeout for operations. This is only used when blockingWait_ is enabled.
std::chrono::milliseconds opTimeout_;
// Set of communicators that this process group has aborted and their
// ncclUniqueId has been written to the store. We don't need a lock
// for this map since only the watchdog thread accesses this set. The
// set contains the string representation of ncclUniqueId.
std::unordered_set<std::string> abortedComms_;
// In single-process single-device mode, WorkNCCL::getFuture is supported.
// Depending on the device index of collective outputs, WorkNCCL will pass
// the corresponding device's then callback stream to FutureNCCL.
// We initialize futureNCCLCallbackStreams_ inside the constructor and set
// its size to the total number of available devices. Depending on the
// device of the NCCL collective's outputs, we later set the callback
// stream of the corresponding device inside ProcessGroupNCCL::getNCCLComm,
// if not set before.
std::vector<std::shared_ptr<at::cuda::CUDAStream>> futureNCCLCallbackStreams_;
// Schedule NCCL operations on high priority CUDA streams.
bool isHighPriorityStream_ = false;
// The number of active ncclGroupStart() calls. This counter will be increased
// by 1 when ncclGroupStart() is called and decreased by 1 when ncclGroupEnd()
// is called.
static thread_local uint64_t ncclActiveGroupCounter_;
};
} // namespace c10d