| #include <c10d/NCCLUtils.hpp> |
| |
| #include <c10/util/CallOnce.h> |
| |
| #ifdef USE_C10D_NCCL |
| |
| #include <mutex> |
| |
| namespace c10d { |
| |
| ncclComm_t NCCLComm::getNcclComm() { |
| std::unique_lock<std::mutex> lock(mutex_); |
| if (aborted_) { |
| auto commFailureMsg = commFailureReason_ != c10::nullopt |
| ? c10::str(" Original reason for failure was: ", *commFailureReason_) |
| : ""; |
| TORCH_CHECK( |
| false, |
| c10::str( |
| "NCCL communicator was aborted on rank ", |
| rank_, |
| ". ", |
| commFailureMsg)); |
| } |
| return ncclComm_; |
| } |
| |
| std::string getNcclVersion() { |
| static c10::once_flag ncclGetVersionFlag; |
| static std::string versionString; |
| |
| c10::call_once(ncclGetVersionFlag, []() { |
| int version; |
| ncclResult_t status = ncclGetVersion(&version); |
| // can't compute the version if call did not return successfully or version |
| // code < 100 (corresponding to 0.1.0) |
| if (status != ncclSuccess || version < 100) { |
| versionString = "Unknown NCCL version"; |
| } else { |
| // NCCL changed version coding starting 2.9 |
| const int majorBase = version < 2900 ? 1000 : 10000; |
| const int minorBase = 100; |
| auto ncclMajor = version / majorBase; |
| auto ncclMinor = (version % majorBase) / minorBase; |
| auto ncclPatch = |
| version % (ncclMajor * majorBase + ncclMinor * minorBase); |
| versionString = std::to_string(ncclMajor) + "." + |
| std::to_string(ncclMinor) + "." + std::to_string(ncclPatch); |
| } |
| }); |
| |
| return versionString; |
| } |
| |
| std::string ncclGetErrorWithVersion(ncclResult_t error) { |
| return std::string(ncclGetErrorString(error)) + ", NCCL version " + |
| getNcclVersion(); |
| } |
| |
| } // namespace c10d |
| |
| #endif // USE_C10D_NCCL |