// blob: 4ab1f5757691f702e59df4b9fed633a14084b3c3 [file] [log] [blame]
#pragma once
// NCCL_MAJOR / NCCL_MINOR are provided by nccl.h, so the header has to be
// included before the version gate below. Testing the macros first would find
// them undefined and the gate would pass for every NCCL version.
#include <nccl.h>

// This file relies on ncclCommAbort() and ncclCommGetAsyncError(); reject
// NCCL releases older than 2.4 with a clear message.
#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \
    (NCCL_MINOR < 4)
#error "Need NCCL version 2.4+"
#elif defined(NCCL_MAJOR) && (NCCL_MAJOR < 2)
#error "Need NCCL version 2.4+"
#endif

#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
// Runs an NCCL call and converts a failure into a C++ exception.
//
// `cmd` is evaluated exactly once. If it yields anything other than
// ncclSuccess, a std::runtime_error is thrown whose message names the file and
// line of the macro use, followed by the text from ncclGetErrorString().
#define C10D_NCCL_CHECK(cmd)                                   \
  do {                                                         \
    const ncclResult_t c10dNcclStatus = (cmd);                 \
    if (c10dNcclStatus != ncclSuccess) {                       \
      std::string c10dNcclMessage("NCCL error in: ");          \
      c10dNcclMessage += __FILE__;                             \
      c10dNcclMessage += ":";                                  \
      c10dNcclMessage += std::to_string(__LINE__);             \
      c10dNcclMessage += ", ";                                 \
      c10dNcclMessage += ncclGetErrorString(c10dNcclStatus);   \
      throw std::runtime_error(c10dNcclMessage);               \
    }                                                          \
  } while (0)
namespace c10d {

// Owning wrapper around an NCCL communicator handle (ncclComm_t).
//
// The destructor aborts the communicator if one is held and it has not been
// aborted already. Copying is disabled; moving transfers the handle together
// with its aborted flag and cached asynchronous error.
class NCCLComm {
 public:
  // Wraps an existing communicator handle. `ncclComm` may be nullptr, in which
  // case the destructor does nothing.
  explicit NCCLComm(ncclComm_t ncclComm)
      : ncclComm_(ncclComm), aborted_(false), ncclAsyncErr_(ncclSuccess) {}

  // Creates an empty wrapper that holds no communicator.
  NCCLComm() : NCCLComm(nullptr) {}

  // Aborts the held communicator, if any. Declared noexcept(false) because
  // ncclCommAbort() throws std::runtime_error when NCCL reports a failure.
  ~NCCLComm() noexcept(false) {
    if (ncclComm_ && !aborted_) {
      // Use ncclCommAbort instead of ncclCommDestroy here since ncclCommDestroy
      // could block forever waiting for work to complete on the communicator.
      ncclCommAbort();
    }
  }

  // Builds a communicator for `rank` out of `numRanks` using `commId` and
  // returns it wrapped in a shared_ptr. Throws std::runtime_error (through
  // C10D_NCCL_CHECK) if ncclCommInitRank does not return ncclSuccess.
  static std::shared_ptr<NCCLComm> create(
      int numRanks,
      int rank,
      ncclUniqueId commId) {
    auto comm = std::make_shared<NCCLComm>();
    C10D_NCCL_CHECK(
        ncclCommInitRank(&(comm->ncclComm_), numRanks, commId, rank));
    return comm;
  }

  // Must not be copyable
  NCCLComm(const NCCLComm&) = delete;
  NCCLComm& operator=(const NCCLComm&) = delete;

  // Move constructable. Delegating to the default constructor first gives this
  // object a well-defined empty state; swapping with uninitialized members
  // would read indeterminate values and hand them to `other`, whose destructor
  // could then try to abort a garbage handle.
  NCCLComm(NCCLComm&& other) noexcept : NCCLComm() {
    std::swap(ncclComm_, other.ncclComm_);
    std::swap(aborted_, other.aborted_);
    std::swap(ncclAsyncErr_, other.ncclAsyncErr_);
  }

  // Move assignable. State is exchanged rather than overwritten, so whatever
  // this object held before ends up in `other` and is released when `other`
  // is destroyed.
  NCCLComm& operator=(NCCLComm&& other) noexcept {
    std::swap(ncclComm_, other.ncclComm_);
    std::swap(aborted_, other.aborted_);
    std::swap(ncclAsyncErr_, other.ncclAsyncErr_);
    return *this;
  }

  // Returns the raw communicator handle. Throws std::runtime_error if the
  // communicator has been aborted.
  ncclComm_t getNcclComm() {
    if (aborted_) {
      throw std::runtime_error("NCCL communicator was aborted.");
    }
    return ncclComm_;
  }

  // Aborts the communicator via ::ncclCommAbort and marks this wrapper as
  // aborted. A second call is a no-op. Throws std::runtime_error if NCCL
  // reports a failure, in which case the wrapper's state is left unchanged.
  void ncclCommAbort() {
    if (aborted_) {
      // Should not abort twice.
      return;
    }

    C10D_NCCL_CHECK(::ncclCommAbort(ncclComm_));
    aborted_ = true;
    ncclComm_ = nullptr;

    // Set an appropriate error so that we avoid using the communicator.
    if (ncclAsyncErr_ == ncclSuccess) {
      ncclAsyncErr_ = ncclSystemError;
    }
  }

  // True once ncclCommAbort() has completed successfully on this wrapper.
  bool isAborted() const {
    return aborted_;
  }

  // Returns the asynchronous error state of the communicator. A previously
  // recorded error is returned as-is without querying NCCL again; otherwise
  // ncclCommGetAsyncError is called and its result is stored and returned.
  // Throws std::runtime_error if the query itself fails.
  ncclResult_t checkForNcclError() {
    if (ncclAsyncErr_ != ncclSuccess) {
      return ncclAsyncErr_;
    }
    C10D_NCCL_CHECK(ncclCommGetAsyncError(ncclComm_, &ncclAsyncErr_));
    return ncclAsyncErr_;
  }

 protected:
  // Raw communicator handle; nullptr when empty, moved-from, or aborted.
  ncclComm_t ncclComm_;
  // Set by ncclCommAbort() after the communicator has been aborted.
  bool aborted_;
  // First non-success asynchronous error observed (ncclSuccess if none).
  ncclResult_t ncclAsyncErr_;
};

} // namespace c10d