[XLA:GPU] Add a background thread that periodically checks for asynchronous NCCL errors.

NCCL can encounter asynchronous errors, e.g., from socket failures. If such an error occurs, the GPU may deadlock. When the background thread detects an asynchronous error, it aborts the affected communicator and logs an error.
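
For reference, a minimal standalone sketch of the polling pattern (an illustration only, assuming raw NCCL with std::thread/iostream; the actual change uses tensorflow::Env::StartThread, NcclCliqueCache(), and the XLA_CUDA_* status macros shown in the diff below):

  #include <chrono>
  #include <iostream>
  #include <thread>
  #include <vector>

  #include <nccl.h>

  // Polls each communicator for an asynchronous error and aborts it if one
  // is observed, so dependent GPU work fails instead of hanging.
  void PollNcclAsyncErrors(const std::vector<ncclComm_t>& comms) {
    while (true) {
      std::this_thread::sleep_for(std::chrono::seconds(30));
      for (ncclComm_t comm : comms) {
        ncclResult_t async_err = ncclSuccess;
        ncclResult_t query = ncclCommGetAsyncError(comm, &async_err);
        if (query != ncclSuccess) {
          std::cerr << "ncclCommGetAsyncError failed: "
                    << ncclGetErrorString(query) << "\n";
          continue;
        }
        if (async_err != ncclSuccess) {
          std::cerr << "Async NCCL error: " << ncclGetErrorString(async_err)
                    << ". Aborting communicator "
                    << static_cast<const void*>(comm) << "\n";
          ncclCommAbort(comm);
        }
      }
    }
  }

A caller would launch this once, e.g. std::thread(PollNcclAsyncErrors, comms).detach(); in the real change the thread is started lazily the first time communicators are acquired.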

PiperOrigin-RevId: 413425827
Change-Id: I7565c660895c7348792033d842c3efe7869e45af
diff --git a/tensorflow/compiler/xla/service/gpu/nccl_utils.cc b/tensorflow/compiler/xla/service/gpu/nccl_utils.cc
index 14aede1..776e02f 100644
--- a/tensorflow/compiler/xla/service/gpu/nccl_utils.cc
+++ b/tensorflow/compiler/xla/service/gpu/nccl_utils.cc
@@ -28,6 +28,7 @@
 #include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
 #include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/errors.h"
 
 namespace xla {
@@ -272,6 +273,31 @@
   absl::BlockingCounter* counter_;
 };
 
+// Periodically checks all NCCL communicators for asynchronous errors.
+// If an asynchronous error is observed, the communicator is aborted and an
+// error message logged.
+void CheckNcclAsyncErrors() {
+  while (true) {
+    absl::SleepFor(absl::Seconds(30));
+
+    NcclCliqueCache().ForEach([](const auto&, const NcclClique& clique) {
+      for (const auto& it : clique.GetComms()) {
+        ncclComm_t comm = it.second.get();
+        Status status = [comm] {
+          ncclResult_t async_err;
+          XLA_CUDA_RETURN_IF_ERROR(ncclCommGetAsyncError(comm, &async_err));
+          if (async_err != ncclSuccess) {
+            LOG(ERROR) << "Async NCCL error. Aborting communicator: " << comm;
+            XLA_CUDA_RETURN_IF_ERROR(ncclCommAbort(comm));
+          }
+          return XLA_CUDA_STATUS(async_err);
+        }();
+        if (!status.ok()) LOG(ERROR) << status.ToString();
+      }
+    });
+  }
+}
+
 }  // namespace
 
 StatusOr<std::vector<LocalParticipant>> GetLocalParticipants(
@@ -360,6 +386,13 @@
     const NcclCliqueParticipantData& participant,
     const std::vector<LocalParticipant>& local_participants,
     const NcclUniqueIdCallback* callback) {
+  // Launch a thread to check for async NCCL errors.
+  static auto check_async_error_thread =
+      tensorflow::Env::Default()->StartThread(tensorflow::ThreadOptions(),
+                                              "nccl_async_error_thread",
+                                              CheckNcclAsyncErrors);
+  (void)check_async_error_thread;  // Silence unused variable warning.
+
   VLOG(2) << "Rendezvous key: " << participant.rendezvous_key.ToString()
           << ", local participants: "
           << LocalParticipantsToString(local_participants);
diff --git a/tensorflow/compiler/xla/service/gpu/nccl_utils.h b/tensorflow/compiler/xla/service/gpu/nccl_utils.h
index abe1471..9c96d42 100644
--- a/tensorflow/compiler/xla/service/gpu/nccl_utils.h
+++ b/tensorflow/compiler/xla/service/gpu/nccl_utils.h
@@ -99,14 +99,16 @@
 // GPUs, you'll need a different clique.
 class NcclClique {
  public:
-  explicit NcclClique(
-      absl::flat_hash_map<int, NcclComm> comms_by_device_ordinal);
+  using CommMap = absl::flat_hash_map<int, NcclComm>;
+
+  explicit NcclClique(CommMap comms_by_device_ordinal);
 
   ncclComm_t GetCommForDeviceOrdinal(int device_ordinal) const;
+  const CommMap& GetComms() const { return comms_by_device_ordinal_; }
   absl::Mutex* mu() { return &mu_; }
 
  private:
-  absl::flat_hash_map<int, NcclComm> comms_by_device_ordinal_;
+  CommMap comms_by_device_ordinal_;
   absl::Mutex mu_;
 };