[XLA:GPU] Fix deadlock in NCCL clique code when multi-host collectives with replica groups are used.
The creation of NCCL communicators requires a barrier across all of the devices in the clique. Holding the cache mutex for the duration of that barrier caused a deadlock when different hosts tried to create their communicators in different orders. This fix releases the lock during creation so that cliques can be created in parallel.
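The pattern introduced by the patch can be illustrated with a minimal, self-contained C++ sketch. It uses std::mutex in place of absl::Mutex and hypothetical names (CliqueCache, GetOrCreate); the key idea is the same: check the map under the lock, release the lock while the (potentially blocking) factory runs, then reacquire it to insert the result.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <memory>
#include <mutex>
#include <string>

// Simplified stand-in for NcclCliqueMap, illustrating the fix: the mutex is
// held only while consulting or updating the map, never while the factory
// runs, so creation of different cliques can proceed in parallel.
class CliqueCache {
 public:
  // Returns the cached value for `key`, creating it with `factory` if absent.
  const std::string* GetOrCreate(
      const std::string& key,
      const std::function<std::unique_ptr<std::string>()>& factory) {
    {
      std::lock_guard<std::mutex> lock(mu_);
      auto it = map_.find(key);
      if (it != map_.end()) return it->second.get();
    }
    // Lock released here: the factory may block on a cross-host barrier
    // without serializing the creation of unrelated cliques.
    std::unique_ptr<std::string> value = factory();
    std::lock_guard<std::mutex> lock(mu_);
    auto result = map_.emplace(key, std::move(value));
    // Safe because only one thread (the rendezvous primary) creates each key;
    // the real code returns an error via TF_RET_CHECK instead of asserting.
    assert(result.second);
    return result.first->second.get();
  }

 private:
  std::mutex mu_;
  std::map<std::string, std::unique_ptr<std::string>> map_;
};
```

This is the classic check/release/recompute/reinsert variant of caching under a lock; it is only correct because the surrounding rendezvous guarantees a single creator per key, which is why the real change guards the insertion with TF_RET_CHECK.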
PiperOrigin-RevId: 377925635
Change-Id: I74a679f9335b15ea5498ecb01a95f71b51ab3acb
diff --git a/tensorflow/compiler/xla/service/gpu/nccl_utils.cc b/tensorflow/compiler/xla/service/gpu/nccl_utils.cc
index fc19f09..63ca1d6 100644
--- a/tensorflow/compiler/xla/service/gpu/nccl_utils.cc
+++ b/tensorflow/compiler/xla/service/gpu/nccl_utils.cc
@@ -315,13 +315,24 @@
const NcclCliqueKey& key,
const std::function<StatusOr<std::unique_ptr<NcclClique>>(
const NcclCliqueKey&)>& value_factory) {
- absl::MutexLock lock(&mu_);
- auto it = map_.find(key);
- if (it == map_.end()) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<NcclClique> value, value_factory(key));
- it = map_.emplace(key, std::move(value)).first;
+ {
+ absl::MutexLock lock(&mu_);
+ auto it = map_.find(key);
+ if (it != map_.end()) {
+ return it->second.get();
+ }
}
- return it->second.get();
+ // We release the lock to allow different cliques to be created in parallel
+ // (avoiding a potential deadlock in multi-host settings). This is safe
+ // provided that there aren't two threads trying to create cliques with the
+ // same key - which we know will not happen as this method is only called by
+ // the primary thread from the clique rendezvous. If this assumption is not
+ // valid, the method will return an error.
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<NcclClique> value, value_factory(key));
+ absl::MutexLock lock(&mu_);
+ auto result = map_.emplace(key, std::move(value));
+ TF_RET_CHECK(result.second) << "Clique already in cache.";
+ return result.first->second.get();
}
void NcclCliqueMap::ForEach(