[XLA:GPU] Adds a barrier that requires all local participants to call ncclCommInitRank() before allowing any participant to proceed.
Without a barrier, we can experience deadlocks. As best I understand it, the deadlock scenario looks like this:
Thread A:
* calls ncclCommInitRank(), which succeeds,
* issues the collective operation,
* calls an operation that manipulates the device page tables, e.g., copying a device buffer to an unpinned host buffer.
* Since this action manipulates the device page tables, it seems that it blocks waiting for the device stream.
Thread B:
* calls ncclCommInitRank(), which calls cudaMalloc().
* cudaMalloc() also manipulates device page tables, and cannot proceed without acquiring an internal lock around the device page table state.
Thread A already holds this lock, but thread A cannot make progress until thread B issues its collective.
This is a deadlock: neither thread can make progress. We can avoid the problem by requiring a barrier after the calls to ncclCommInitRank(), requiring all GPUs to finish initialization before any of them can issue their collective operation.
Fixes https://github.com/google/jax/issues/11637
PiperOrigin-RevId: 464164328
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 1bea110..e718767 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -544,6 +544,7 @@
deps = if_gpu_is_configured([
":gpu_executable_run_options",
"@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/time",
"//tensorflow/compiler/xla:debug_options_flags",
diff --git a/tensorflow/compiler/xla/service/gpu/nccl_utils.cc b/tensorflow/compiler/xla/service/gpu/nccl_utils.cc
index ffa1578..4f7e762 100644
--- a/tensorflow/compiler/xla/service/gpu/nccl_utils.cc
+++ b/tensorflow/compiler/xla/service/gpu/nccl_utils.cc
@@ -19,7 +19,9 @@
#include <string_view>
#include <utility>
+#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_format.h"
+#include "absl/synchronization/notification.h"
#include "absl/time/time.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/global_device_id.h"
@@ -122,6 +124,14 @@
struct NcclCliqueState {
ncclUniqueId unique_id;
int64_t run_id = -1;
+
+ // mu guards ready, status, and communicators during initialization.
+ // Once 'ready' has been notified, the communicators may be accessed without
+ // synchronization.
+ absl::Mutex mu;
+ absl::Notification ready;
+ Status status;
+ absl::flat_hash_map<int, std::unique_ptr<NcclComm>> communicators;
};
using NcclClique = Lockable<NcclCliqueState>;
@@ -221,9 +231,13 @@
run_id, op_id, clique_key, unique_id_callback, num_local_participants);
if (!clique->ok()) return clique->status();
+ NcclCliqueState& state = ***clique;
- auto comm_key = std::make_pair(std::move(clique_key), rank);
- static auto& comms = *new ThreadSafeMap<decltype(comm_key), NcclComm>;
+ struct AllCommunicators {
+ absl::Mutex mu;
+ std::vector<NcclComm*> communicators ABSL_GUARDED_BY(mu);
+ };
+ static auto& all_communicators = *new AllCommunicators;
// Launch a thread that periodically checks all NCCL communicators for
// asynchronous errors. If an asynchronous error is observed, the communicator
@@ -233,19 +247,53 @@
tensorflow::ThreadOptions(), "nccl_async_error_thread", [&] {
while (true) {
absl::SleepFor(absl::Seconds(30));
- comms.ForEachValue(CheckNcclAsyncError);
+ absl::MutexLock lock(&all_communicators.mu);
+ for (NcclComm* comm : all_communicators.communicators) {
+ CheckNcclAsyncError(*comm);
+ }
}
});
(void)check_async_error_thread; // Silence unused variable warning.
- NcclComm::Lock comm = comms[comm_key].Acquire();
- if (*comm == nullptr) {
- int nranks = comm_key.first.devices().size();
- const ncclUniqueId& id = (**clique)->unique_id;
- XLA_CUDA_RETURN_IF_ERROR(ncclCommInitRank(comm.get(), nranks, id, rank));
+ NcclComm::Lock comm;
+ if (state.ready.HasBeenNotified()) {
+ comm = state.communicators[rank]->Acquire();
+ } else {
+ auto comm_ptr = std::make_unique<NcclComm>();
+ comm = comm_ptr->Acquire();
+ int nranks = clique_key.devices().size();
+ const ncclUniqueId& id = state.unique_id;
+ Status status =
+ XLA_CUDA_STATUS(ncclCommInitRank(comm.get(), nranks, id, rank));
+
+ // Add the communicator to the all_communicators list.
+ {
+ absl::MutexLock lock(&all_communicators.mu);
+ all_communicators.communicators.push_back(comm_ptr.get());
+ }
+
+ absl::MutexLock lock(&state.mu);
+ state.status.Update(status);
+ state.communicators[rank] = std::move(comm_ptr);
+
+ // Wait for all communicators to initialize before allowing any progress.
+ // Otherwise we may get deadlocks, because ncclCommInitRank may allocate,
+ // which may block on the completion of device activity on a peer device,
+ // which may depend on the completion of this collective if we do not have a
+ // barrier to prevent it.
+ auto all_initialized = [&]() ABSL_EXCLUSIVE_LOCKS_REQUIRED(state.mu) {
+ return state.communicators.size() == num_local_participants;
+ };
+ state.mu.Await(absl::Condition(&all_initialized));
+ status = state.status;
+ if (!state.ready.HasBeenNotified()) {
+ state.ready.Notify();
+ }
+ }
+ if (!state.status.ok()) {
+ return state.status;
}
return comm;
}
-
} // namespace gpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/nccl_utils.h b/tensorflow/compiler/xla/service/gpu/nccl_utils.h
index 39d7fba..93fa99a 100644
--- a/tensorflow/compiler/xla/service/gpu/nccl_utils.h
+++ b/tensorflow/compiler/xla/service/gpu/nccl_utils.h
@@ -84,7 +84,12 @@
// RAII type that will release the exclusive lock when it is destroyed.
using Lock = std::unique_ptr<T, std::function<void(T*)>>;
- explicit Lockable(T value = T()) : value_(std::move(value)) {}
+ Lockable() = default;
+ explicit Lockable(T value) : value_(std::move(value)) {}
+ Lockable(const Lockable&) = delete;
+ Lockable(Lockable&&) = delete;
+ Lockable& operator=(const Lockable&) = delete;
+ Lockable& operator=(Lockable&&) = delete;
Lock Acquire() {
absl::MutexLock lock(&mutex_);