add lock for ncclCommAbort (#31901)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/31901

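ncclCommAbort is not thread-safe, so this change adds a mutex to NCCLComm and acquires it in ncclCommAbort and in the other members that read or modify the communicator state.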
ghstack-source-id: 96829715

Test Plan: unit tests

Differential Revision: D19293869

fbshipit-source-id: 711b4a07605d6e5a81577247d2f90a78041c1809
diff --git a/torch/lib/c10d/NCCLUtils.hpp b/torch/lib/c10d/NCCLUtils.hpp
index 79dd0e2..bac37f4 100644
--- a/torch/lib/c10d/NCCLUtils.hpp
+++ b/torch/lib/c10d/NCCLUtils.hpp
@@ -4,6 +4,7 @@
 #include <stdlib.h>
 
 #include <memory>
+#include <mutex>
 
 #include <nccl.h>
 
@@ -57,6 +58,9 @@
   NCCLComm() : NCCLComm(nullptr) {}
 
   ~NCCLComm() noexcept {
+    // Take the lock in the destructor so that aborted_ is read only after
+    // the memory barrier that acquiring the lock provides.
+    std::unique_lock<std::mutex> lock(mutex_);
     if (ncclComm_ && !aborted_) {
 #ifdef ENABLE_NCCL_ERROR_CHECKING
       // Use ncclCommAbort instead of ncclCommDestroy here since
@@ -83,22 +87,21 @@
   NCCLComm(const NCCLComm&) = delete;
   NCCLComm& operator=(const NCCLComm&) = delete;
 
+  // Do not support move assignment as there is no valid use case
+  NCCLComm& operator=(NCCLComm&& other) = delete;
+
   // Move constructable
   NCCLComm(NCCLComm&& other) {
+    // Lock other's mutex, since this constructor reads other's state.
+    // this->mutex_ cannot be used, as this object is still being constructed.
+    std::unique_lock<std::mutex> lock(other.mutex_);
     std::swap(ncclComm_, other.ncclComm_);
     std::swap(aborted_, other.aborted_);
     std::swap(ncclAsyncErr_, other.ncclAsyncErr_);
   }
 
-  // Move assignable
-  NCCLComm& operator=(NCCLComm&& other) {
-    std::swap(ncclComm_, other.ncclComm_);
-    std::swap(aborted_, other.aborted_);
-    std::swap(ncclAsyncErr_, other.ncclAsyncErr_);
-    return *this;
-  }
-
   ncclComm_t getNcclComm() {
+    std::unique_lock<std::mutex> lock(mutex_);
     if (aborted_) {
       throw std::runtime_error("NCCL communicator was aborted.");
     }
@@ -106,6 +109,7 @@
   }
 
   void ncclCommAbort() {
+    std::unique_lock<std::mutex> lock(mutex_);
 #ifdef ENABLE_NCCL_ERROR_CHECKING
     if (aborted_) {
       // Should not abort twice.
@@ -127,10 +131,12 @@
   }
 
   bool isAborted() const {
+    std::unique_lock<std::mutex> lock(mutex_);
     return aborted_;
   }
 
   ncclResult_t checkForNcclError() {
+    std::unique_lock<std::mutex> lock(mutex_);
 #ifdef ENABLE_NCCL_ERROR_CHECKING
     if (ncclAsyncErr_ != ncclSuccess) {
       return ncclAsyncErr_;
@@ -147,6 +153,7 @@
   ncclComm_t ncclComm_;
   bool aborted_;
   ncclResult_t ncclAsyncErr_;
+  mutable std::mutex mutex_;
 };
 
 } // namespace c10d