add lock for ncclCommAbort (#31901)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/31901
ncclCommAbort is not thread-safe, so this change adds a mutex to NCCLComm and holds it in every method that reads or writes the communicator state.
ghstack-source-id: 96829715
Test Plan: unit tests
Differential Revision: D19293869
fbshipit-source-id: 711b4a07605d6e5a81577247d2f90a78041c1809
diff --git a/torch/lib/c10d/NCCLUtils.hpp b/torch/lib/c10d/NCCLUtils.hpp
index 79dd0e2..bac37f4 100644
--- a/torch/lib/c10d/NCCLUtils.hpp
+++ b/torch/lib/c10d/NCCLUtils.hpp
@@ -4,6 +4,7 @@
#include <stdlib.h>
#include <memory>
+#include <mutex>
#include <nccl.h>
@@ -57,6 +58,9 @@
NCCLComm() : NCCLComm(nullptr) {}
~NCCLComm() noexcept {
+ // Take the lock in the destructor as well, so that aborted_ is read after
+ // a memory barrier.
+ std::unique_lock<std::mutex> lock(mutex_);
if (ncclComm_ && !aborted_) {
#ifdef ENABLE_NCCL_ERROR_CHECKING
// Use ncclCommAbort instead of ncclCommDestroy here since
@@ -83,22 +87,21 @@
NCCLComm(const NCCLComm&) = delete;
NCCLComm& operator=(const NCCLComm&) = delete;
+ // Do not support move assignment as there is no valid use case
+ NCCLComm& operator=(NCCLComm&& other) = delete;
+
// Move constructable
NCCLComm(NCCLComm&& other) {
+ // Hold other's lock, since this constructor reads other's state.
+ // this->mutex_ cannot be used, as this object is still being constructed.
+ std::unique_lock<std::mutex> lock(other.mutex_);
std::swap(ncclComm_, other.ncclComm_);
std::swap(aborted_, other.aborted_);
std::swap(ncclAsyncErr_, other.ncclAsyncErr_);
}
- // Move assignable
- NCCLComm& operator=(NCCLComm&& other) {
- std::swap(ncclComm_, other.ncclComm_);
- std::swap(aborted_, other.aborted_);
- std::swap(ncclAsyncErr_, other.ncclAsyncErr_);
- return *this;
- }
-
ncclComm_t getNcclComm() {
+ std::unique_lock<std::mutex> lock(mutex_);
if (aborted_) {
throw std::runtime_error("NCCL communicator was aborted.");
}
@@ -106,6 +109,7 @@
}
void ncclCommAbort() {
+ std::unique_lock<std::mutex> lock(mutex_);
#ifdef ENABLE_NCCL_ERROR_CHECKING
if (aborted_) {
// Should not abort twice.
@@ -127,10 +131,12 @@
}
bool isAborted() const {
+ std::unique_lock<std::mutex> lock(mutex_);
return aborted_;
}
ncclResult_t checkForNcclError() {
+ std::unique_lock<std::mutex> lock(mutex_);
#ifdef ENABLE_NCCL_ERROR_CHECKING
if (ncclAsyncErr_ != ncclSuccess) {
return ncclAsyncErr_;
@@ -147,6 +153,7 @@
ncclComm_t ncclComm_;
bool aborted_;
ncclResult_t ncclAsyncErr_;
+ mutable std::mutex mutex_;
};
} // namespace c10d