add mutex getter/setter to synchronize CUDA and NCCL ops
Summary: Allow gloo consumers to supply their own mutex for serializing CUDA malloc/free with NCCL operations, via the new CudaShared::getMutex()/setMutex() accessors.
Reviewed By: pietern
Differential Revision: D4622135
fbshipit-source-id: 60acd7c01a677a0df5415fe38e6ef5a2e7c8606a
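
Usage sketch (illustrative, not part of this diff; the mutex name and the
installSharedMutex() helper below are hypothetical): a consumer that issues
its own NCCL or cudaMalloc/cudaFree calls from other threads can hand gloo
the mutex it already holds around those calls, so both sides serialize on
the same lock.

  #include <mutex>
  #include "gloo/cuda.h"

  // Mutex the application already uses around its own CUDA/NCCL calls.
  static std::mutex appCudaNcclMutex;

  void installSharedMutex() {
    // After this call, gloo locks appCudaNcclMutex around cudaMallocHost/
    // cudaFree and ncclCommInitAll/ncclCommDestroy instead of its internal
    // default mutex.
    gloo::CudaShared::setMutex(&appCudaNcclMutex);
  }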
diff --git a/gloo/cuda.cu b/gloo/cuda.cu
index fa52e64..12612ef 100644
--- a/gloo/cuda.cu
+++ b/gloo/cuda.cu
@@ -14,6 +14,10 @@
const cudaStream_t kStreamNotSet = (cudaStream_t)(-1);
+// Default mutex to synchronize contentious CUDA and NCCL operations
+static std::mutex defaultCudaMutex;
+std::atomic<std::mutex*> CudaShared::mutex_(&defaultCudaMutex);
+
template<typename T>
CudaDevicePointer<T>
CudaDevicePointer<T>::create(
diff --git a/gloo/cuda.h b/gloo/cuda.h
index d5ecc40..9d96196 100644
--- a/gloo/cuda.h
+++ b/gloo/cuda.h
@@ -11,11 +11,28 @@
#include <cuda.h>
#include <cuda_runtime.h>
+#include <mutex>
namespace gloo {
extern const cudaStream_t kStreamNotSet;
+class CudaShared {
+ public:
+ // Get the mutex used to synchronize CUDA and NCCL operations
+ static std::mutex& getMutex() {
+ return *mutex_;
+ }
+
+ // Set the mutex used to synchronize CUDA and NCCL operations
+ static void setMutex(std::mutex* m) {
+ mutex_ = m;
+ }
+
+ private:
+ static std::atomic<std::mutex*> mutex_;
+};
+
template<typename T>
class CudaDevicePointer {
public:
diff --git a/gloo/cuda_allreduce_ring_chunked.cc b/gloo/cuda_allreduce_ring_chunked.cc
index 9f62b3b..0609d92 100644
--- a/gloo/cuda_allreduce_ring_chunked.cc
+++ b/gloo/cuda_allreduce_ring_chunked.cc
@@ -90,7 +90,7 @@
// Setup host and device memory
{
// Synchronize memory allocation with NCCL operations
- std::lock_guard<std::mutex> lock(gCudaMutex);
+ std::lock_guard<std::mutex> lock(CudaShared::getMutex());
CUDA_CHECK(cudaMallocHost(&hostPtr_, bytes_));
}
for (auto offset = 0; offset < count_; offset += chunkSize_) {
@@ -157,7 +157,7 @@
CudaAllreduceRingChunked<T>::~CudaAllreduceRingChunked() {
{
// Synchronize memory allocation with NCCL operations
- std::lock_guard<std::mutex> lock(gCudaMutex);
+ std::lock_guard<std::mutex> lock(CudaShared::getMutex());
CUDA_CHECK(cudaFreeHost(hostPtr_));
}
for (auto i = 0; i < 2; i++) {
diff --git a/gloo/cuda_nccl.cu b/gloo/cuda_nccl.cu
index 55bf864..be1df52 100644
--- a/gloo/cuda_nccl.cu
+++ b/gloo/cuda_nccl.cu
@@ -30,7 +30,7 @@
}
{
// Initialize comms. Synchronize with conflicting CUDA and NCCL operations.
- std::lock_guard<std::mutex> lock(gCudaMutex);
+ std::lock_guard<std::mutex> lock(CudaShared::getMutex());
comms_.resize(elements_.size());
NCCL_CHECK(ncclCommInitAll(comms_.data(), devices.size(), devices.data()));
}
@@ -63,7 +63,7 @@
CUDA_CHECK(cudaEventDestroy(events_[i]));
{
// Synchronize memory allocation with NCCL operations
- std::lock_guard<std::mutex> lock(gCudaMutex);
+ std::lock_guard<std::mutex> lock(CudaShared::getMutex());
ncclCommDestroy(comms_[i]);
}
}
@@ -102,7 +102,7 @@
// Kick off the NCCL operation on each device
{
// Synchronize memory allocation with NCCL operations
- std::lock_guard<std::mutex> lock(gCudaMutex);
+ std::lock_guard<std::mutex> lock(CudaShared::getMutex());
const auto& elements = context_.elements_;
for (auto i = 0; i < elements.size(); i++) {
diff --git a/gloo/cuda_private.cu b/gloo/cuda_private.cu
index 5843f65..daf09ea 100644
--- a/gloo/cuda_private.cu
+++ b/gloo/cuda_private.cu
@@ -13,8 +13,6 @@
namespace gloo {
-std::mutex gCudaMutex;
-
template<typename T>
__global__ void initializeMemory(T* ptr, const T val, const size_t n) {
int i = blockIdx.x * blockDim.x + threadIdx.x;
@@ -27,7 +25,7 @@
CudaMemory<T>::CudaMemory(size_t n): n_(n), bytes_(n * sizeof(T)) {
CUDA_CHECK(cudaGetDevice(&device_));
// Synchronize memory allocation with NCCL operations
- std::lock_guard<std::mutex> lock(gCudaMutex);
+ std::lock_guard<std::mutex> lock(CudaShared::getMutex());
CUDA_CHECK(cudaMalloc(&ptr_, bytes_));
}
@@ -46,7 +44,7 @@
CudaDeviceScope scope(device_);
if (ptr_ != nullptr) {
// Synchronize memory allocation with NCCL operations
- std::lock_guard<std::mutex> lock(gCudaMutex);
+ std::lock_guard<std::mutex> lock(CudaShared::getMutex());
CUDA_CHECK(cudaFree(ptr_));
}
}
diff --git a/gloo/cuda_private.h b/gloo/cuda_private.h
index 0eeb531..4ebc001 100644
--- a/gloo/cuda_private.h
+++ b/gloo/cuda_private.h
@@ -17,9 +17,6 @@
namespace gloo {
-// Default mutex to synchronize contentious CUDA and NCCL operations
-extern std::mutex gCudaMutex;
-
#define CUDA_CHECK(condition) \
do { \
cudaError_t error = condition; \