add mutex getter/setter to synchronize CUDA and NCCL ops

Summary: Allow gloo consumers to assign a mutex to synchronize CUDA malloc/free and NCCL operations.

Reviewed By: pietern

Differential Revision: D4622135

fbshipit-source-id: 60acd7c01a677a0df5415fe38e6ef5a2e7c8606a
diff --git a/gloo/cuda.cu b/gloo/cuda.cu
index fa52e64..12612ef 100644
--- a/gloo/cuda.cu
+++ b/gloo/cuda.cu
@@ -14,6 +14,10 @@
 
 const cudaStream_t kStreamNotSet = (cudaStream_t)(-1);
 
+// Default mutex to synchronize contentious CUDA and NCCL operations
+static std::mutex defaultCudaMutex;
+std::atomic<std::mutex*> CudaShared::mutex_(&defaultCudaMutex);
+
 template<typename T>
 CudaDevicePointer<T>
 CudaDevicePointer<T>::create(
diff --git a/gloo/cuda.h b/gloo/cuda.h
index d5ecc40..9d96196 100644
--- a/gloo/cuda.h
+++ b/gloo/cuda.h
@@ -11,11 +11,29 @@
 
 #include <cuda.h>
 #include <cuda_runtime.h>
+#include <atomic>
+#include <mutex>
 
 namespace gloo {
 
 extern const cudaStream_t kStreamNotSet;
 
+class CudaShared {
+ public:
+  // Get the mutex used to synchronize CUDA and NCCL operations
+  static std::mutex& getMutex() {
+    return *mutex_;
+  }
+
+  // Set the mutex used to synchronize CUDA and NCCL operations
+  static void setMutex(std::mutex* m) {
+    mutex_ = m;
+  }
+
+ private:
+  static std::atomic<std::mutex*> mutex_;
+};
+
 template<typename T>
 class CudaDevicePointer {
  public:
diff --git a/gloo/cuda_allreduce_ring_chunked.cc b/gloo/cuda_allreduce_ring_chunked.cc
index 9f62b3b..0609d92 100644
--- a/gloo/cuda_allreduce_ring_chunked.cc
+++ b/gloo/cuda_allreduce_ring_chunked.cc
@@ -90,7 +90,7 @@
   // Setup host and device memory
   {
     // Synchronize memory allocation with NCCL operations
-    std::lock_guard<std::mutex> lock(gCudaMutex);
+    std::lock_guard<std::mutex> lock(CudaShared::getMutex());
     CUDA_CHECK(cudaMallocHost(&hostPtr_, bytes_));
   }
   for (auto offset = 0; offset < count_; offset += chunkSize_) {
@@ -157,7 +157,7 @@
 CudaAllreduceRingChunked<T>::~CudaAllreduceRingChunked() {
   {
     // Synchronize memory allocation with NCCL operations
-    std::lock_guard<std::mutex> lock(gCudaMutex);
+    std::lock_guard<std::mutex> lock(CudaShared::getMutex());
     CUDA_CHECK(cudaFreeHost(hostPtr_));
   }
   for (auto i = 0; i < 2; i++) {
diff --git a/gloo/cuda_nccl.cu b/gloo/cuda_nccl.cu
index 55bf864..be1df52 100644
--- a/gloo/cuda_nccl.cu
+++ b/gloo/cuda_nccl.cu
@@ -30,7 +30,7 @@
   }
   {
     // Initialze comms. Synchronize with conflicting CUDA and NCCL operations.
-    std::lock_guard<std::mutex> lock(gCudaMutex);
+    std::lock_guard<std::mutex> lock(CudaShared::getMutex());
     comms_.resize(elements_.size());
     NCCL_CHECK(ncclCommInitAll(comms_.data(), devices.size(), devices.data()));
   }
@@ -63,7 +63,7 @@
     CUDA_CHECK(cudaEventDestroy(events_[i]));
     {
       // Synchronize memory allocation with NCCL operations
-      std::lock_guard<std::mutex> lock(gCudaMutex);
+      std::lock_guard<std::mutex> lock(CudaShared::getMutex());
       ncclCommDestroy(comms_[i]);
     }
   }
@@ -102,7 +102,7 @@
   // Kick off the NCCL operation on each device
   {
     // Synchronize memory allocation with NCCL operations
-    std::lock_guard<std::mutex> lock(gCudaMutex);
+    std::lock_guard<std::mutex> lock(CudaShared::getMutex());
 
     const auto& elements = context_.elements_;
     for (auto i = 0; i < elements.size(); i++) {
diff --git a/gloo/cuda_private.cu b/gloo/cuda_private.cu
index 5843f65..daf09ea 100644
--- a/gloo/cuda_private.cu
+++ b/gloo/cuda_private.cu
@@ -13,8 +13,6 @@
 
 namespace gloo {
 
-std::mutex gCudaMutex;
-
 template<typename T>
 __global__ void initializeMemory(T* ptr, const T val, const size_t n) {
   int i = blockIdx.x * blockDim.x + threadIdx.x;
@@ -27,7 +25,7 @@
 CudaMemory<T>::CudaMemory(size_t n): n_(n), bytes_(n * sizeof(T)) {
   CUDA_CHECK(cudaGetDevice(&device_));
   // Sychronize memory allocation with NCCL operations
-  std::lock_guard<std::mutex> lock(gCudaMutex);
+  std::lock_guard<std::mutex> lock(CudaShared::getMutex());
   CUDA_CHECK(cudaMalloc(&ptr_, bytes_));
 }
 
@@ -46,7 +44,7 @@
   CudaDeviceScope scope(device_);
   if (ptr_ != nullptr) {
     // Sychronize memory allocation with NCCL operations
-    std::lock_guard<std::mutex> lock(gCudaMutex);
+    std::lock_guard<std::mutex> lock(CudaShared::getMutex());
     CUDA_CHECK(cudaFree(ptr_));
   }
 }
diff --git a/gloo/cuda_private.h b/gloo/cuda_private.h
index 0eeb531..4ebc001 100644
--- a/gloo/cuda_private.h
+++ b/gloo/cuda_private.h
@@ -17,9 +17,6 @@
 
 namespace gloo {
 
-// Default mutex to synchronize contentious CUDA and NCCL operations
-extern std::mutex gCudaMutex;
-
 #define CUDA_CHECK(condition)                   \
   do {                                          \
     cudaError_t error = condition;              \