blob: 0e4a7df03dc722d79069c001e576263a4f4899c1 [file] [log] [blame]
#include <c10d/NCCLUtils.hpp>
#include <c10/util/CallOnce.h>
#ifdef USE_C10D_NCCL
#include <mutex>
namespace c10d {
// Returns the underlying NCCL communicator handle.
//
// The communicator state is read while holding `mutex_`. If the communicator
// has been aborted, an error is raised through TORCH_CHECK naming this rank
// and, when `commFailureReason_` holds a value, the original failure reason.
ncclComm_t NCCLComm::getNcclComm() {
  std::lock_guard<std::mutex> guard(mutex_);
  if (aborted_) {
    // Stays empty unless a failure reason was recorded.
    std::string reasonSuffix;
    if (commFailureReason_ != c10::nullopt) {
      reasonSuffix =
          c10::str(" Original reason for failure was: ", *commFailureReason_);
    }
    TORCH_CHECK(
        false,
        c10::str(
            "NCCL communicator was aborted on rank ",
            rank_,
            ". ",
            reasonSuffix));
  }
  return ncclComm_;
}
std::string getNcclVersion() {
static c10::once_flag ncclGetVersionFlag;
static std::string versionString;
c10::call_once(ncclGetVersionFlag, []() {
int version;
ncclResult_t status = ncclGetVersion(&version);
// can't compute the version if call did not return successfully or version
// code < 100 (corresponding to 0.1.0)
if (status != ncclSuccess || version < 100) {
versionString = "Unknown NCCL version";
} else {
// NCCL changed version coding starting 2.9
const int majorBase = version < 2900 ? 1000 : 10000;
const int minorBase = 100;
auto ncclMajor = version / majorBase;
auto ncclMinor = (version % majorBase) / minorBase;
auto ncclPatch =
version % (ncclMajor * majorBase + ncclMinor * minorBase);
versionString = std::to_string(ncclMajor) + "." +
std::to_string(ncclMinor) + "." + std::to_string(ncclPatch);
}
});
return versionString;
}
// Builds a diagnostic string: NCCL's description of `error` (from
// ncclGetErrorString), followed by ", NCCL version " and the version text
// produced by getNcclVersion().
std::string ncclGetErrorWithVersion(ncclResult_t error) {
  std::string message(ncclGetErrorString(error));
  message += ", NCCL version ";
  message += getNcclVersion();
  return message;
}
} // namespace c10d
#endif // USE_C10D_NCCL