/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/gpu/nccl_utils.h"
#include <cstdlib>
#include <cstring>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/blocking_counter.h"
#include "absl/synchronization/mutex.h"
#include "tensorflow/compiler/xla/refcounting_hash_map.h"
#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
#include "tensorflow/compiler/xla/service/global_device_id.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/platform/errors.h"
namespace xla {
namespace gpu {
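// Converts an XLA reduction kind to the corresponding NCCL reduction op.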
ncclRedOp_t ToNcclReduction(ReductionKind kind) {
switch (kind) {
case ReductionKind::SUM:
return ncclSum;
case ReductionKind::PRODUCT:
return ncclProd;
case ReductionKind::MIN:
return ncclMin;
case ReductionKind::MAX:
return ncclMax;
}
}
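// Maps an XLA primitive type to the NCCL data type used for the transfer.
// PRED is sent as an unsigned 8-bit value; unsupported types return an
// InvalidArgument error.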
StatusOr<ncclDataType_t> ToNcclDataType(PrimitiveType element_type) {
switch (element_type) {
case S8:
return ncclInt8;
case PRED:
case U8:
return ncclUint8;
case S32:
return ncclInt32;
case U32:
return ncclUint32;
case S64:
return ncclInt64;
case U64:
return ncclUint64;
case F16:
return ncclFloat16;
case F32:
return ncclFloat32;
case F64:
return ncclFloat64;
default:
return tensorflow::errors::InvalidArgument(absl::StrFormat(
"Unsupported data type: %s", PrimitiveType_Name(element_type)));
}
}
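// Returns true if the NCCL_COMM_ID environment variable is set, i.e. a global
// NCCL configuration is in use.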
bool IsGlobalNcclConfig() {
static const bool global_nccl_config = std::getenv("NCCL_COMM_ID") != nullptr;
return global_nccl_config;
}
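// Returns true if the NCCL_LAUNCH_MODE environment variable is set to
// "PARALLEL".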
bool IsNcclLaunchModeParallel() {
  // std::getenv may return nullptr; guard before constructing a string_view.
  static const bool is_launch_mode_parallel = [] {
    const char* launch_mode = std::getenv("NCCL_LAUNCH_MODE");
    return launch_mode != nullptr &&
           absl::string_view(launch_mode) == "PARALLEL";
  }();
  return is_launch_mode_parallel;
}
Status ToStatus(ncclResult_t s, const char* file, int64 line,
const char* expr) {
if (s == ncclSuccess) {
return Status::OK();
}
return tensorflow::errors::Internal(
absl::StrFormat("%s:%d: NCCL operation %s failed: %s", file, line, expr,
ncclGetErrorString(s)));
}
Status ToStatus(cudaError_t s, const char* file, int64 line, const char* expr) {
if (s == cudaSuccess) {
return Status::OK();
}
return tensorflow::errors::Internal(
absl::StrFormat("%s:%d: CUDA operation %s failed: %s", file, line, expr,
cudaGetErrorString(s)));
}
NcclClique::NcclClique(
absl::flat_hash_map<int, NcclComm> comms_by_device_ordinal)
: comms_by_device_ordinal_(std::move(comms_by_device_ordinal)) {}
ncclComm_t NcclClique::GetCommForDeviceOrdinal(int device_ordinal) const {
return comms_by_device_ordinal_.at(device_ordinal).get();
}
NcclCliqueMap& NcclCliqueCache() {
// Global cache of NCCL cliques. An entry in this map is always kept alive.
//
// A consequence of the fact that this is process-global is that we'll only
// ever have one clique alive for a given set of GPUs. This means that a
// process will never do two collective operations concurrently on the same
// set of GPUs.
static auto& cache = *new NcclCliqueMap();
return cache;
}
namespace {
void DestroyNcclComm(ncclComm_t comm) {
VLOG(3) << absl::StreamFormat("Destroying comm %p", comm);
XLA_CUDA_WARN_IF_ERROR(ncclCommDestroy(comm));
}
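// Deserializes a unique-id string (e.g. one returned by the client's
// NcclUniqueIdCallback) into an ncclUniqueId. The string must be exactly
// NCCL_UNIQUE_ID_BYTES long.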
Status ToNcclUniqueId(const std::string& str_id, ncclUniqueId* nccl_id) {
if (str_id.size() != NCCL_UNIQUE_ID_BYTES) {
    return InvalidArgument(
        "ncclUniqueId string must have %d bytes, got %d bytes",
        NCCL_UNIQUE_ID_BYTES, str_id.size());
}
  // ncclUniqueId is internally just a char[].
static_assert(sizeof(ncclUniqueId) == NCCL_UNIQUE_ID_BYTES,
"NCCL_UNIQUE_ID_BYTES");
std::memcpy(static_cast<void*>(nccl_id), str_id.data(), NCCL_UNIQUE_ID_BYTES);
return Status::OK();
}
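// Formats the local participants as "<device_ordinal>/rank=<rank>" pairs for
// logging.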
std::string LocalParticipantsToString(
const std::vector<LocalParticipant>& local_participants) {
std::vector<std::string> parts;
for (const LocalParticipant& local_participant : local_participants) {
parts.push_back(absl::StrFormat("%d/rank=%d",
local_participant.device_ordinal,
local_participant.rank));
}
return absl::StrJoin(parts, ",");
}
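// Creates the NCCL communicators for all local participants of a clique. For
// multi-host cliques the ncclUniqueId is obtained from the client-provided
// callback; otherwise it is generated locally with ncclGetUniqueId. The raw
// communicators are wrapped in RAII handles so that they are destroyed even
// if initialization fails part-way through.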
StatusOr<std::unique_ptr<NcclClique>> CreateNcclClique(
const NcclCliqueKey& key,
const std::vector<LocalParticipant>& local_participants,
const NcclUniqueIdCallback* callback) {
int num_participants = key.devices().size();
ncclUniqueId unique_id;
if (callback) { // Multi-host collective.
TF_ASSIGN_OR_RETURN(std::string id_string, (*callback)(key));
TF_RETURN_IF_ERROR(ToNcclUniqueId(id_string, &unique_id));
} else {
TF_RET_CHECK((num_participants == local_participants.size()) ||
IsGlobalNcclConfig())
<< "If non-local devices are taking part of a collective API on GPU, "
"the nccl_unique_id_callback must be provided by the client.";
XLA_CUDA_RETURN_IF_ERROR(ncclGetUniqueId(&unique_id));
}
VLOG(3) << "Initializing nccl comms for local participants: "
<< LocalParticipantsToString(local_participants);
// Restore CUDA device after running this. XLA shouldn't care, but maybe
// another consumer does.
int initial_cuda_device;
XLA_CUDA_RETURN_IF_ERROR(cudaGetDevice(&initial_cuda_device));
auto cuda_device_restorer = MakeCleanup(
[&] { XLA_CUDA_WARN_IF_ERROR(cudaSetDevice(initial_cuda_device)); });
// When using ncclGroupStart/End it seems that the ncclComm_t's are not
// populated until the End() call.
std::vector<ncclComm_t> raw_comms(local_participants.size(), nullptr);
XLA_CUDA_RETURN_IF_ERROR(ncclGroupStart());
Status status = [&] {
for (int i = 0; i < local_participants.size(); ++i) {
XLA_CUDA_RETURN_IF_ERROR(
cudaSetDevice(local_participants[i].device_ordinal));
XLA_CUDA_RETURN_IF_ERROR(ncclCommInitRank(&raw_comms[i], num_participants,
unique_id,
local_participants[i].rank));
}
return Status::OK();
}();
// Always call ncclGroupEnd().
status.Update(XLA_CUDA_STATUS(ncclGroupEnd()));
// Always copy raw comms to RAII type, so they are cleaned up properly.
absl::flat_hash_map<int, NcclComm> comms_by_device_ordinal(raw_comms.size());
for (int i = 0; i < raw_comms.size(); ++i) {
int device_ordinal = local_participants[i].device_ordinal;
VLOG(3) << absl::StreamFormat("Device ordinal %d assigned ncclComm %p",
device_ordinal, raw_comms[i]);
CHECK(raw_comms[i] != nullptr || !status.ok());
comms_by_device_ordinal.emplace(device_ordinal,
NcclComm(raw_comms[i], &DestroyNcclComm));
}
// Now we can check if there was an error creating the communicators.
TF_RETURN_IF_ERROR(status);
return std::make_unique<NcclClique>(std::move(comms_by_device_ordinal));
}
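// Participant data for the clique rendezvous; no state is needed beyond what
// ParticipantData already carries.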
struct NcclCliqueParticipantData : public ParticipantData {
using ParticipantData::ParticipantData;
std::string ToString() const override { return ""; }
};
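// Rendezvous through which all local participants acquire the same clique.
// The first participant to arrive (the primary) looks the clique up in the
// global cache (creating it if necessary), takes the clique lock, and
// allocates a BlockingCounter sized to the number of local participants.
// Each participant then receives a LockedNcclClique referring to that clique.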
class NcclCliqueRendezvous
: public Rendezvous<NcclCliqueParticipantData, LockedNcclClique> {
public:
NcclCliqueRendezvous(const RendezvousKey& rendezvous_key,
const std::vector<LocalParticipant>& local_participants,
const NcclUniqueIdCallback* callback)
: Rendezvous(rendezvous_key),
        key_(rendezvous_key.global_devices),
local_participants_(local_participants),
callback_(callback),
counter_(nullptr) {}
StatusOr<LockedNcclClique> RunCollectiveOp(
const NcclCliqueParticipantData&) override {
tensorflow::mutex_lock lock(mu_);
bool primary = !initialized_;
if (primary) {
maybe_clique_ = NcclCliqueCache().GetOrTryCreateIfAbsent(
key_, [&](const NcclCliqueKey& key) {
return CreateNcclClique(key, local_participants_, callback_);
});
initialized_ = true;
}
TF_ASSIGN_OR_RETURN(NcclClique * clique, maybe_clique_);
std::unique_ptr<absl::MutexLock> clique_lock;
if (primary) {
clique_lock = std::make_unique<absl::MutexLock>(clique->mu());
counter_ = new absl::BlockingCounter(local_participants_.size());
}
return LockedNcclClique(*clique, std::move(clique_lock), counter_);
}
private:
NcclCliqueKey key_;
const std::vector<LocalParticipant>& local_participants_;
const NcclUniqueIdCallback* callback_;
StatusOr<NcclClique*> maybe_clique_;
absl::BlockingCounter* counter_;
};
} // namespace
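// Computes the (device_ordinal, rank) pairs for the devices of this process
// that take part in the collective. If `local_devices` is null the collective
// is single-host and the device ordinal is assumed to equal the global device
// id.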
StatusOr<std::vector<LocalParticipant>> GetLocalParticipants(
const std::vector<GlobalDeviceId>& participants,
const std::vector<GlobalDeviceId>* local_devices) {
std::vector<LocalParticipant> local_participants;
if (local_devices) {
absl::flat_hash_map<GlobalDeviceId, int> device_ranks(participants.size());
for (int rank = 0; rank < participants.size(); ++rank) {
auto result = device_ranks.emplace(participants[rank], rank);
TF_RET_CHECK(result.second) << "Duplicate device found";
}
local_participants.reserve(local_devices->size());
for (int device_ordinal = 0; device_ordinal < local_devices->size();
++device_ordinal) {
auto it = device_ranks.find((*local_devices)[device_ordinal]);
if (it != device_ranks.end()) {
local_participants.push_back({device_ordinal, /*rank=*/it->second});
}
}
} else { // Single host, so use identity mapping (device ordinal == id).
local_participants.reserve(participants.size());
for (int rank = 0; rank < participants.size(); ++rank) {
int device_ordinal = participants[rank].value();
local_participants.push_back({device_ordinal, rank});
}
}
return local_participants;
}
LockedNcclClique::LockedNcclClique(NcclClique& clique,
std::unique_ptr<absl::MutexLock> lock,
absl::BlockingCounter* counter)
: clique(clique), lock_(std::move(lock)), counter_(counter) {}
LockedNcclClique::LockedNcclClique(LockedNcclClique&& other)
: clique(other.clique),
lock_(std::move(other.lock_)),
counter_(std::exchange(other.counter_, nullptr)) {}
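// Every participant decrements the counter when it releases its
// LockedNcclClique; the primary participant (the one holding the clique lock)
// waits for all others to finish before deleting the counter and releasing
// the lock.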
LockedNcclClique::~LockedNcclClique() {
if (counter_) {
counter_->DecrementCount();
if (lock_) {
counter_->Wait(); // Don't release lock until all threads are finished.
delete counter_;
}
}
}
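// Returns the cached clique for `key`, creating it with `value_factory` if it
// is not already present.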
StatusOr<NcclClique*> NcclCliqueMap::GetOrTryCreateIfAbsent(
const NcclCliqueKey& key,
const std::function<StatusOr<std::unique_ptr<NcclClique>>(
const NcclCliqueKey&)>& value_factory) {
{
absl::MutexLock lock(&mu_);
auto it = map_.find(key);
if (it != map_.end()) {
return it->second.get();
}
}
// We release the lock to allow different cliques to be created in parallel
// (avoiding a potential deadlock in multi-host settings). This is safe
// provided that there aren't two threads trying to create cliques with the
// same key - which we know will not happen as this method is only called by
// the primary thread from the clique rendezvous. If this assumption is not
// valid, the method will return an error.
TF_ASSIGN_OR_RETURN(std::unique_ptr<NcclClique> value, value_factory(key));
absl::MutexLock lock(&mu_);
auto result = map_.emplace(key, std::move(value));
TF_RET_CHECK(result.second) << "Clique already in cache.";
return result.first->second.get();
}
void NcclCliqueMap::ForEach(
const std::function<void(const NcclCliqueKey&, const NcclClique&)>& fn) {
absl::MutexLock lock(&mu_);
for (const auto& kv : map_) {
fn(kv.first, *kv.second);
}
}
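// Blocks until all local participants have joined the rendezvous for
// `rendezvous_key` and returns a handle to the shared NCCL clique; the clique
// stays locked until every participant has released its handle.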
StatusOr<LockedNcclClique> AcquireNcclClique(
const RendezvousKey& rendezvous_key, int local_device_ordinal,
se::Stream* stream, const std::vector<LocalParticipant>& local_participants,
const NcclUniqueIdCallback* callback) {
VLOG(2) << "Rendezvous key: " << rendezvous_key.ToString()
<< ", local participants: "
<< LocalParticipantsToString(local_participants);
static auto& rendezvous_map =
*new RefcountingHashMap<RendezvousKey, NcclCliqueRendezvous>();
NcclCliqueParticipantData participant(rendezvous_key, local_device_ordinal,
stream);
return NcclCliqueRendezvous::SubmitParticipant(
/*rendezvous_getter=*/
[&] {
return rendezvous_map.GetOrCreateIfAbsent(
rendezvous_key, [&](const RendezvousKey& rendezvous_key) {
return std::make_unique<NcclCliqueRendezvous>(
rendezvous_key, local_participants, callback);
});
},
participant);
}
} // namespace gpu
} // namespace xla