blob: 0ebda30e2fe3a4085f51a7897d7fda0fe02e7ce7 [file] [log] [blame]
#pragma once
#include <atomic>
#include <chrono>
#include <condition_variable>
#include <cstddef>
#include <cstdint>
#include <exception>
#include <memory>
#include <mutex>
#include <set>
#include <string>
#include <thread>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include <c10d/NCCLUtils.hpp>
#include <c10d/ProcessGroup.hpp>
#include <c10d/Store.hpp>
#include <ATen/Parallel.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAEvent.h>
namespace c10d {
// Name of the environment variable that controls whether wait() is blocking
// or non-blocking. ProcessGroupNCCL::parseNcclBlockingWait() reads this
// variable and sets blockingWait_ accordingly.
constexpr const char* NCCL_BLOCKING_WAIT = "NCCL_BLOCKING_WAIT";
// ProcessGroupNCCL implements NCCL bindings for c10d.
//
// All functions of the class are expected to be called in the same order
// across all processes in the process group. This is the only way that we
// can guarantee to match up the same calls among all processes.
//
// All NCCL functions provided by this class are asynchronous functions. More
// specifically, each NCCL call is scheduled on a separate CUDA stream that is
// different from the current CUDA stream. This is for the purpose of
// achieving potential concurrency and better performance. As a result,
// it is the callers' responsibility to make sure that the CUDA stream their
// code works on waits for the NCCL operation from this class.
//
// This can be done by calling either WorkNCCL::wait() or
// WorkNCCL::synchronize(); both achieve the same functionality and are
// synonyms.
//
// Also note that WorkNCCL::finishedGPUExecution() is a helper function only
// provided by ProcessGroupNCCL to check if the NCCL operation of WorkNCCL has
// finished execution on the GPU (not just scheduled).
//
// Example on using the NCCL process group
//
// ProcessGroupNCCL pg(store, rank, size);
// std::shared_ptr<WorkNCCL> work = pg.allreduce(tensors);
//
// // At this point, the NCCL kernel has already been queued successfully.
// // Now, let the current stream wait for the NCCL operation to finish; this
// // function is an async operation as well.
//
// work->wait()
//
// // Now continue on other work in the current stream.
class ProcessGroupNCCL : public ProcessGroup {
public:
// Handle for one asynchronous NCCL operation issued by ProcessGroupNCCL.
// It derives from ProcessGroup::Work and keeps the devices, CUDA events and
// NCCL communicators that belong to the operation.
class WorkNCCL : public ProcessGroup::Work {
public:
// Constructor takes the list of CUDA devices the work item operates on.
WorkNCCL(const std::vector<at::Device>& devices);
virtual ~WorkNCCL();
// Checks if request has completed. In this specific case of NCCL, it checks
// if the NCCL operation has completed on the GPU in its own NCCL stream.
// Non-blocking operation.
bool isCompleted() override;
// Overrides ProcessGroup::Work::isSuccess(). Const, so it does not modify
// this object.
bool isSuccess() const override;
// Same as calling synchronize() for NCCL work. `timeout` defaults to
// kNoTimeout.
bool wait(std::chrono::milliseconds timeout = kNoTimeout) override;
// Overrides ProcessGroup::Work::abort().
// NOTE(review): presumably aborts the NCCL communicators of this work item;
// confirm against the implementation.
void abort() override;
// Lets the current stream wait on the completion of the NCCL work.
// Throws on exceptions. Blocking operation, which will wait for work
// completion.
void synchronize() override;
// Synchronize streams by blocking each on the NCCL stream. Also called by
// CheckFutureWork::markFutureCompleted() before the outputs are passed to
// the Future.
void synchronizeStreams();
// Helper function that checks if the NCCL kernels have finished
// execution on the GPUs (not just been scheduled).
bool finishedGPUExecution();
// Get the Future object associated with this work. It is marked as
// completed internally (see CheckFutureWork::markFutureCompleted()).
c10::intrusive_ptr<c10::ivalue::Future> getFuture() override;
protected:
// The cached list of CUDA devices to operate on
std::vector<at::Device> devices_;
// The CUDA events tracking this work item on multiple CUDA devices
std::vector<at::cuda::CUDAEvent> cudaEvents_;
// The NCCL communicators used for this work item.
std::vector<std::shared_ptr<NCCLComm>> ncclComms_;
// Tensors used for barrier op
std::vector<at::Tensor> barrierTensors_;
// Clone of blockingWait_ from ProcessGroupNCCL.
bool blockingWait_ = false;
// Clone of opTimeout_ from ProcessGroupNCCL. There is no in-class
// initializer, so it has to be assigned before it is read.
std::chrono::milliseconds opTimeout_;
// Time point representing when the work started.
std::chrono::time_point<std::chrono::steady_clock> workStartTime_;
// Wrapper method for the static checkForNCCLErrors which can be overridden
// for tests. Returns an exception_ptr describing the error state of the
// given communicators.
virtual std::exception_ptr checkForNCCLErrors(
const std::vector<std::shared_ptr<NCCLComm>>& ncclComms) const;
private:
// Helper function for synchronize; receives the timeout to apply.
void synchronizeInternal(std::chrono::milliseconds timeout);
// Checks for NCCL errors and sets an appropriate exception_ptr.
void checkAndSetException();
// Checks for NCCL errors and throws an appropriate exception.
void checkAndThrowException();
// Just checks whether GPU execution has completed, without modifying
// exception_ptr.
bool finishedGPUExecutionInternal() const;
// Reference to the store so that we can write aborted communicators
// to the store.
std::shared_ptr<Store> store_;
// The Future associated with this WorkNCCL.
c10::intrusive_ptr<c10::ivalue::Future> futureWork_;
// ProcessGroupNCCL is granted access to the protected and private members
// of this class.
friend class ProcessGroupNCCL;
};
// CheckFutureWork is used by a cudaStremCallback function called
// ncclKernelCompletionCallback. The purpose of CheckFutureWork is to mark
// workNCCL's Future work when, all NCCL streams created in collective
// operation are finished. We use enable_shared_from_this to safely call
// shared_ptr<CheckFutureWork> (this) to remove its pointer from
// ProcessGroupNCCL's checkFutObjs after at::launch operation of
// markFutureCompleted is done. markFutureCompleted increments streamCounter
// by one and once it is called by number of replicas times, it marks Future
// work completed by setting outputs vector as its value.
struct CheckFutureWork : std::enable_shared_from_this<CheckFutureWork> {
public:
CheckFutureWork(
std::shared_ptr<ProcessGroupNCCL::WorkNCCL> work,
std::vector<at::Tensor>& outputs,
std::shared_ptr<std::unordered_set<std::shared_ptr<CheckFutureWork>>>
checkFutObjs,
std::mutex& checkFutObjMutex_)
: work_(work),
outputs_(std::make_shared<std::vector<at::Tensor>>(outputs)),
streamCounter_(0),
checkFutObjs_(checkFutObjs),
checkFutObjMutex_(checkFutObjMutex_) {}
std::shared_ptr<CheckFutureWork> getPtr() {
return shared_from_this();
}
void markFutureCompleted() {
// Passing `this` here as the capture is fine because we store a
// `shared_ptr` to `this` in ProcessGroupNCCL's `checkFutObjs_` to
// ensure that the object is alive when the lambda function is called.
at::launch(([this]() {
if (++streamCounter_ == (*outputs_).size()) {
// Need to synchronize before passing outputs to Future because
// operations using the outputs might be running on different streams
work_->synchronizeStreams();
TORCH_CHECK(
!work_->getFuture()->completed(),
"Future work of workNCCL can only be marked as "
"completed by ncclKernelCompletionCallback.")
work_->getFuture()->markCompleted(at::IValue(*outputs_));
std::unique_lock<std::mutex> lock(checkFutObjMutex_);
checkFutObjs_->erase(getPtr());
}
}));
}
private:
std::shared_ptr<ProcessGroupNCCL::WorkNCCL> work_;
std::shared_ptr<std::vector<at::Tensor>> outputs_;
std::atomic<int> streamCounter_;
std::shared_ptr<std::unordered_set<std::shared_ptr<CheckFutureWork>>>
checkFutObjs_;
std::mutex& checkFutObjMutex_;
};
// If you wish to create multiple process groups, each with a potentially
// different rank and size, you can do so by passing a new store instance
// to each one. If you have only a single store object, you can
// use the `c10d::PrefixStore` to derive scoped instances.
// This is also what the Python API in torch.distributed does.
//
// The process group instance keeps a reference to the store because
// it may be used long after the constructor runs. In fact, the constructor
// doesn't create any NCCL communicators. A single NCCL communicator can
// only be used on a specific set of devices, and communicators are therefore
// created on-demand when a collective runs. If another collective is executed
// later, against a different set of devices, the process group creates
// another NCCL communicator. These NCCL communicators are cached and reused
// if possible.
//
// `opTimeout` defaults to kProcessGroupNCCLOpTimeoutMillis milliseconds.
ProcessGroupNCCL(
const std::shared_ptr<Store>& store,
int rank,
int size,
const std::chrono::milliseconds& opTimeout =
std::chrono::milliseconds(kProcessGroupNCCLOpTimeoutMillis));
// This constructor includes the deprecated `groupName` argument.
// If you have existing code that uses the `groupName`, you can replace
// it by specifying a `c10d::PrefixStore(groupName, store)` for store.
// `groupName` itself is ignored here: the constructor only delegates to the
// constructor above with the remaining arguments.
C10_DEPRECATED ProcessGroupNCCL(
const std::shared_ptr<Store>& store,
int rank,
int size,
const std::string& groupName,
const std::chrono::milliseconds& opTimeout =
std::chrono::milliseconds(kProcessGroupNCCLOpTimeoutMillis))
: ProcessGroupNCCL(store, rank, size, opTimeout) {}
// Virtual destructor; defined out of line.
virtual ~ProcessGroupNCCL();
// Collective operations. Every function below overrides the corresponding
// ProcessGroup virtual and returns a ProcessGroup::Work handle for the
// operation. The options argument of each one defaults to a
// default-constructed options object.
std::shared_ptr<ProcessGroup::Work> broadcast(
std::vector<at::Tensor>& tensors,
const BroadcastOptions& opts = BroadcastOptions()) override;
std::shared_ptr<ProcessGroup::Work> allreduce(
std::vector<at::Tensor>& tensors,
const AllreduceOptions& opts = AllreduceOptions()) override;
std::shared_ptr<ProcessGroup::Work> allreduce_coalesced(
std::vector<at::Tensor>& tensors,
const AllreduceCoalescedOptions& opts =
AllreduceCoalescedOptions()) override;
std::shared_ptr<ProcessGroup::Work> reduce(
std::vector<at::Tensor>& tensors,
const ReduceOptions& opts = ReduceOptions()) override;
// Takes one list of output tensors per input tensor.
std::shared_ptr<ProcessGroup::Work> allgather(
std::vector<std::vector<at::Tensor>>& outputTensors,
std::vector<at::Tensor>& inputTensors,
const AllgatherOptions& opts = AllgatherOptions()) override;
// Variant of allgather that takes a single output tensor and a single input
// tensor instead of tensor lists.
std::shared_ptr<ProcessGroup::Work> allgather_base(
at::Tensor& outputbuffer,
at::Tensor& inputbuffer,
const AllgatherOptions& opts = AllgatherOptions()) override;
std::shared_ptr<ProcessGroup::Work> allgather_coalesced(
std::vector<std::vector<at::Tensor>>& outputTensorLists,
std::vector<at::Tensor>& inputTensors,
const AllgatherOptions& opts = AllgatherOptions()) override;
// Takes one list of input tensors per output tensor.
std::shared_ptr<ProcessGroup::Work> reduce_scatter(
std::vector<at::Tensor>& outputTensors,
std::vector<std::vector<at::Tensor>>& inputTensors,
const ReduceScatterOptions& opts = ReduceScatterOptions()) override;
// Takes no tensors, only options.
std::shared_ptr<ProcessGroup::Work> barrier(
const BarrierOptions& opts = BarrierOptions()) override;
// Unsupported Ops
//
// The following overrides of the ProcessGroup interface are declared here
// but, per the heading above, are not supported by this process group.
// NOTE(review): presumably the implementations throw when called; confirm
// against the implementation file.
std::shared_ptr<ProcessGroup::Work> gather(
std::vector<std::vector<at::Tensor>>& outputTensors,
std::vector<at::Tensor>& inputTensors,
const GatherOptions& opts = GatherOptions()) override;
std::shared_ptr<ProcessGroup::Work> scatter(
std::vector<at::Tensor>& outputTensors,
std::vector<std::vector<at::Tensor>>& inputTensors,
const ScatterOptions& opts = ScatterOptions()) override;
std::shared_ptr<ProcessGroup::Work> send(
std::vector<at::Tensor>& tensors,
int dstRank,
int tag) override;
std::shared_ptr<ProcessGroup::Work> recv(
std::vector<at::Tensor>& tensors,
int srcRank,
int tag) override;
std::shared_ptr<ProcessGroup::Work> recvAnysource(
std::vector<at::Tensor>& tensors,
int tag) override;
// Default operation timeout in milliseconds. It is the default value of the
// `opTimeout` argument of both constructors. The value itself is defined
// out of line.
static const int64_t kProcessGroupNCCLOpTimeoutMillis;
protected:
// Helper that broadcasts nccl unique ID to all ranks through the store
void broadcastUniqueNCCLID(ncclUniqueId* ncclID);
// Helper that either looks up the cached NCCL communicators or creates
// a new set of NCCL communicators as a cache entry. Returns a reference to
// the vector of communicators for the given devices.
// NOTE(review): presumably `devicesKey` is the device-sequence string used as
// the key of devNCCLCommMap_; confirm against the implementation.
std::vector<std::shared_ptr<NCCLComm>>& getNCCLComm(
const std::string& devicesKey,
const std::vector<at::Device>& devices);
// Wrapper method which can be overridden for tests. Returns an exception_ptr
// for the given communicators.
virtual std::exception_ptr checkForNCCLErrors(
const std::vector<std::shared_ptr<NCCLComm>>& ncclComms);
// Returns a WorkNCCL for the given devices. Virtual, so subclasses can
// override it.
// NOTE(review): presumably the factory for the work objects handed out by
// the collectives; confirm against the implementation.
virtual std::shared_ptr<ProcessGroupNCCL::WorkNCCL> initWork(
std::vector<at::Device> devices);
private:
// Helper that encapsulates work shared across all collective communication
// primitives. The callbacks have the following signatures:
//
// ncclResult_t fn(at::Tensor& input, at::Tensor& output,
// ncclComm_t, at::cuda::CUDAStream&);
// void {pre,post}(std::vector<at::cuda::CUDAStream&>);
//
// Overload that takes only the NCCL callback `fn`.
template <typename Fn>
std::shared_ptr<ProcessGroup::Work> collective(
std::vector<at::Tensor>& input,
std::vector<at::Tensor>& output,
Fn fn);
// Overload that additionally takes the `pre` and `post` callbacks described
// above.
template <typename Fn, typename PreProcess, typename PostProcess>
std::shared_ptr<ProcessGroup::Work> collective(
std::vector<at::Tensor>& input,
std::vector<at::Tensor>& output,
Fn fn,
PreProcess pre,
PostProcess post);
// Checks for NCCL errors on each of the communicators and returns an
// appropriate exception_ptr (nullptr if no errors).
static std::exception_ptr checkForNCCLErrorsInternal(
const std::vector<std::shared_ptr<NCCLComm>>& ncclComms);
// Function that runs as part of a separate thread and checks for errors on
// NCCL communicators. We need a separate thread to check for NCCL errors
// since we can't rely on the user calling certain methods like wait(),
// isCompleted() etc. to detect and remediate errors. In addition to this, we
// need a mechanism to safely abort and remove NCCL communicators from our
// cache. This can be done cleanly by having a thread for the ProcessGroupNCCL
// class. Attempting to modify the communicator cache from the WorkNCCL class
// might run into issues with object lifetime since the ProcessGroupNCCL
// object might get destroyed before the WorkNCCL object.
void ncclCommWatchdog();
// NOTE(review): presumably the body of the watchdog that ncclCommWatchdog()
// wraps; confirm against the implementation.
void ncclCommWatchdogInternal();
// Reads the NCCL_BLOCKING_WAIT environment variable and sets blockingWait_
// accordingly.
void parseNcclBlockingWait();
protected:
// NOTE(review): per its name, the sleep interval of the watchdog thread in
// milliseconds; the value is defined out of line. Confirm its use there.
static const int64_t kWatchdogThreadSleepMillis;
// The store is used to broadcast the NCCL unique ID of rank 0.
std::shared_ptr<Store> store_;
// The number of NCCL communicators that have been created during
// the lifetime of this process group. This sequence number is
// used to scope keys used in the store.
uint64_t ncclCommCounter_{0};
// The NCCL communicators that the process group has cached.
// The key is the list of GPU devices that an operation is operating on.
// The GPU devices are stored as a device sequence, and the cached NCCL
// communicators are associated with this GPU device sequence.
//
// e.g. If the process group op only uses device 0, then the used device
// string stored (the key of the hashmap) would be "0".
//
// If the process group op uses devices 0 - 7 and the tensors of the
// input tensor list are on devices 0, 1, 2, 3, 4, 5, 6, 7 respectively,
// then the used device string (key) stored would be
// "0,1,2,3,4,5,6,7"
//
// If the process group op uses devices 0 - 7 and the tensors of the
// input tensor list are on devices 0, 4, 5, 6, 7, 1, 2, 3 respectively,
// then the used device string (key) stored would be
// "0,4,5,6,7,1,2,3"
//
// Note that the order of the devices for the tensor list matters.
std::unordered_map<std::string, std::vector<std::shared_ptr<NCCLComm>>>
devNCCLCommMap_;
// Map from ncclUniqueId to the appropriate communicators. The key is a
// string.
std::unordered_map<std::string, std::vector<std::shared_ptr<NCCLComm>>>
ncclIdToCommMap_;
// Mutex to guard maps like devNCCLCommMap_ and ncclIdToCommMap_.
std::mutex mutex_;
// Watchdog thread which looks for errors on the cached NCCL communicators.
std::thread ncclCommWatchdogThread_;
// Whether or not we should terminate the watchdog thread. Atomic, so it can
// be read and written from different threads. It has no in-class
// initializer.
std::atomic<bool> terminateWatchdog_;
// Condition variable to control how long the watchdog thread waits.
std::condition_variable watchdogCV_;
// Mutex for the watchdog.
// NOTE(review): presumably the mutex paired with watchdogCV_; confirm
// against the implementation.
std::mutex watchdogCVMutex_;
// The CUDA streams used by NCCL kernels, keyed by a string.
// NOTE(review): presumably the same device-sequence key as devNCCLCommMap_;
// confirm against the implementation.
std::unordered_map<std::string, std::vector<at::cuda::CUDAStream>>
ncclStreams_;
// The CUDA events used to sync NCCL streams, keyed by a string.
std::unordered_map<std::string, std::vector<at::cuda::CUDAEvent>> ncclEvents_;
// Device indexes used for all collectives in this group
std::set<int> usedDeviceIdxs_;
// Map from the key "group name + pg counter (ID)" to the
// unique NCCL ID count. This needs to be group and pg specific.
//
// For each process group, we need a uniform unique NCCL ID counter to ensure
// that NCCL operations in this process group can be completed successfully.
// Since each process group ID belongs to a group name, the key to this map
// is a combination of group name and ProcessGroupNCCL ID.
// Static, so the map is shared by all ProcessGroupNCCL instances.
static std::unordered_map<std::string, ssize_t> pgUniqueNCCLIDCnt_;
// Map from group name to the pg counter (ID) within that group.
//
// For each group with the "group name" (which is the key), we need to
// keep track of a unique process group ID when creating a new
// ProcessGroupNCCL for this "group name". Therefore, the value of this
// map keeps the unique ProcessGroupNCCL's ID for a specific group with
// the "group name". The reason we need a per-group process group ID counter
// is that different groups can have different ranks and we need to ensure
// that each group has its own uniform process group ID for all its ranks.
// Static, so the map is shared by all ProcessGroupNCCL instances.
static std::unordered_map<std::string, ssize_t> processGroupCounterMap_;
// Whether or not wait() and synchronize() are blocking operations that wait
// for the operation to complete. Set by parseNcclBlockingWait().
bool blockingWait_ = false;
// Timeout for operations. This is only used when blockingWait_ is enabled.
std::chrono::milliseconds opTimeout_;
// Set of communicators that this process group has aborted and whose
// ncclUniqueId has been written to the store. We don't need a lock
// for this set since only the watchdog thread accesses it. The
// set contains the string representation of ncclUniqueId.
std::unordered_set<std::string> abortedComms_;
// The set of CheckFutureWork pointers ensures that a CheckFutureWork object
// is not deleted before the at::launch operation inside CheckFutureWork's
// markFutureCompleted, which is called by the cudaStreamCallback function,
// has run. Each CheckFutureWork erases itself from this set once it has
// marked its Future as completed.
std::shared_ptr<std::unordered_set<std::shared_ptr<CheckFutureWork>>>
checkFutObjs_;
// Mutex to guard the unordered set checkFutObjs_. Mutable, so it can be
// locked from const member functions.
mutable std::mutex checkFutObjMutex_;
};
} // namespace c10d