[XLA:GPU] Adds a barrier that requires all local participants to call ncclCommInitRank() before allowing any participant to proceed.

Without a barrier, we can experience deadlocks. As best I understand it, the deadlock scenario looks like this:

Thread A:
* calls ncclCommInitRank(), which succeeds,
* issues the collective operation,
* calls an operation that manipulates the device page tables, e.g., copying a device buffer to an unpinned host buffer.
* Because this operation manipulates the device page tables, it takes an internal lock around the device page-table state, and it seems to block waiting for the device stream while still holding that lock.

Thread B:
* calls ncclCommInitRank(), which calls cudaMalloc().
* cudaMalloc() also manipulates the device page tables, and cannot proceed without acquiring that same internal lock around the device page-table state.

Thread A already holds this lock, but thread A cannot make progress until thread B issues its collective operation.

This is a deadlock: neither thread can make progress. We avoid the problem by adding a barrier after the calls to ncclCommInitRank(), so that all local GPUs must finish initialization before any of them can issue its collective operation.
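
The barrier itself is an absl::Mutex condition wait: each participant records that it has finished initializing and then blocks until the count of initialized participants reaches the number of local participants. Below is a minimal standalone sketch of that pattern; InitBarrier, Participant, and the std::thread driver are illustrative names for this example only and are not part of the change in the diff.

    #include <cstdio>
    #include <functional>
    #include <thread>
    #include <vector>

    #include "absl/synchronization/mutex.h"

    // Shared state for one group of local participants.
    struct InitBarrier {
      absl::Mutex mu;
      int num_initialized ABSL_GUARDED_BY(mu) = 0;
    };

    // One participant: "initialize", then wait until every peer has initialized.
    void Participant(InitBarrier& barrier, int rank, int num_participants) {
      // ... ncclCommInitRank() would run here ...
      absl::MutexLock lock(&barrier.mu);
      ++barrier.num_initialized;
      auto all_initialized = [&]() ABSL_EXCLUSIVE_LOCKS_REQUIRED(barrier.mu) {
        return barrier.num_initialized == num_participants;
      };
      // Releases the mutex while blocked; reacquires it once the condition holds.
      barrier.mu.Await(absl::Condition(&all_initialized));
      // Only now is it safe to issue the collective operation.
      std::printf("rank %d released\n", rank);
    }

    int main() {
      constexpr int kNumParticipants = 4;
      InitBarrier barrier;
      std::vector<std::thread> threads;
      threads.reserve(kNumParticipants);
      for (int rank = 0; rank < kNumParticipants; ++rank) {
        threads.emplace_back(Participant, std::ref(barrier), rank,
                             kNumParticipants);
      }
      for (std::thread& t : threads) t.join();
      return 0;
    }

In the actual change, the same Await pattern is applied to the per-clique NcclCliqueState, and an absl::Notification marks the clique as ready so that participants arriving after initialization can skip the wait entirely.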

Fixes https://github.com/google/jax/issues/11637

PiperOrigin-RevId: 464164328
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 1bea110..e718767 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -544,6 +544,7 @@
     deps = if_gpu_is_configured([
         ":gpu_executable_run_options",
         "@com_google_absl//absl/strings:str_format",
+        "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/synchronization",
         "@com_google_absl//absl/time",
         "//tensorflow/compiler/xla:debug_options_flags",
diff --git a/tensorflow/compiler/xla/service/gpu/nccl_utils.cc b/tensorflow/compiler/xla/service/gpu/nccl_utils.cc
index ffa1578..4f7e762 100644
--- a/tensorflow/compiler/xla/service/gpu/nccl_utils.cc
+++ b/tensorflow/compiler/xla/service/gpu/nccl_utils.cc
@@ -19,7 +19,9 @@
 #include <string_view>
 #include <utility>
 
+#include "absl/container/flat_hash_map.h"
 #include "absl/strings/str_format.h"
+#include "absl/synchronization/notification.h"
 #include "absl/time/time.h"
 #include "tensorflow/compiler/xla/debug_options_flags.h"
 #include "tensorflow/compiler/xla/service/global_device_id.h"
@@ -122,6 +124,14 @@
 struct NcclCliqueState {
   ncclUniqueId unique_id;
   int64_t run_id = -1;
+
+  // mu guards ready, status, and communicators during initialization.
+  // Once 'ready' has been notified, the communicators may be accessed without
+  // synchronization.
+  absl::Mutex mu;
+  absl::Notification ready;
+  Status status;
+  absl::flat_hash_map<int, std::unique_ptr<NcclComm>> communicators;
 };
 
 using NcclClique = Lockable<NcclCliqueState>;
@@ -221,9 +231,13 @@
       run_id, op_id, clique_key, unique_id_callback, num_local_participants);
 
   if (!clique->ok()) return clique->status();
+  NcclCliqueState& state = ***clique;
 
-  auto comm_key = std::make_pair(std::move(clique_key), rank);
-  static auto& comms = *new ThreadSafeMap<decltype(comm_key), NcclComm>;
+  struct AllCommunicators {
+    absl::Mutex mu;
+    std::vector<NcclComm*> communicators ABSL_GUARDED_BY(mu);
+  };
+  static auto& all_communicators = *new AllCommunicators;
 
   // Launch a thread that periodically checks all NCCL communicators for
   // asynchronous errors. If an asynchronous error is observed, the communicator
@@ -233,19 +247,53 @@
           tensorflow::ThreadOptions(), "nccl_async_error_thread", [&] {
             while (true) {
               absl::SleepFor(absl::Seconds(30));
-              comms.ForEachValue(CheckNcclAsyncError);
+              absl::MutexLock lock(&all_communicators.mu);
+              for (NcclComm* comm : all_communicators.communicators) {
+                CheckNcclAsyncError(*comm);
+              }
             }
           });
   (void)check_async_error_thread;  // Silence unused variable warning.
 
-  NcclComm::Lock comm = comms[comm_key].Acquire();
-  if (*comm == nullptr) {
-    int nranks = comm_key.first.devices().size();
-    const ncclUniqueId& id = (**clique)->unique_id;
-    XLA_CUDA_RETURN_IF_ERROR(ncclCommInitRank(comm.get(), nranks, id, rank));
+  NcclComm::Lock comm;
+  if (state.ready.HasBeenNotified()) {
+    comm = state.communicators[rank]->Acquire();
+  } else {
+    auto comm_ptr = std::make_unique<NcclComm>();
+    comm = comm_ptr->Acquire();
+    int nranks = clique_key.devices().size();
+    const ncclUniqueId& id = state.unique_id;
+    Status status =
+        XLA_CUDA_STATUS(ncclCommInitRank(comm.get(), nranks, id, rank));
+
+    // Add the communicator to the all_communicators list.
+    {
+      absl::MutexLock lock(&all_communicators.mu);
+      all_communicators.communicators.push_back(comm_ptr.get());
+    }
+
+    absl::MutexLock lock(&state.mu);
+    state.status.Update(status);
+    state.communicators[rank] = std::move(comm_ptr);
+
+    // Wait for all communicators to initialize before allowing any progress.
+    // Otherwise we may get deadlocks, because ncclCommInitRank may allocate,
+    // which may block on the completion of device activity on a peer device,
+    // which may depend on the completion of this collective if we do not have a
+    // barrier to prevent it.
+    auto all_initialized = [&]() ABSL_EXCLUSIVE_LOCKS_REQUIRED(state.mu) {
+      return state.communicators.size() == num_local_participants;
+    };
+    state.mu.Await(absl::Condition(&all_initialized));
+    status = state.status;
+    if (!state.ready.HasBeenNotified()) {
+      state.ready.Notify();
+    }
+  }
+  if (!state.status.ok()) {
+    return state.status;
   }
   return comm;
 }
-
 }  // namespace gpu
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/nccl_utils.h b/tensorflow/compiler/xla/service/gpu/nccl_utils.h
index 39d7fba..93fa99a 100644
--- a/tensorflow/compiler/xla/service/gpu/nccl_utils.h
+++ b/tensorflow/compiler/xla/service/gpu/nccl_utils.h
@@ -84,7 +84,12 @@
   // RAII type that will release the exclusive lock when it is destroyed.
   using Lock = std::unique_ptr<T, std::function<void(T*)>>;
 
-  explicit Lockable(T value = T()) : value_(std::move(value)) {}
+  Lockable() = default;
+  explicit Lockable(T value) : value_(std::move(value)) {}
+  Lockable(const Lockable&) = delete;
+  Lockable(Lockable&&) = delete;
+  Lockable& operator=(const Lockable&) = delete;
+  Lockable& operator=(Lockable&&) = delete;
 
   Lock Acquire() {
     absl::MutexLock lock(&mutex_);