blob: ba21abcbaa8bdf3f738055830d587194863ddb71 [file] [log] [blame]
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/collective_param_resolver_local.h"
#include <stddef.h>
#include <algorithm>
#include <unordered_map>
#include <utility>
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/util/device_name_utils.h"
namespace tensorflow {
// Waits on out_cv, using the caller-supplied `lock`, until out_mu_available
// is observed to be true. Returns immediately if it already is.
void CollectiveParamResolverLocal::InstanceRec::WaitForOutMu(mutex_lock& lock) {
  for (;;) {
    if (out_mu_available) {
      return;
    }
    out_cv.wait(lock);
  }
}
// Constructs a resolver from the session configuration and device services.
//   config: only config.experimental().collective_nccl() is read here; the
//     value is cached in nccl_ and later consulted by AssignCollectiveType.
//   dev_mgr, dev_resolver: the pointers are stored as given in dev_mgr_ and
//     dev_resolver_; nothing in this constructor takes ownership of them.
//   task_name: copied into task_name_, which is compared against instance task
//     names to decide which group members are local.
CollectiveParamResolverLocal::CollectiveParamResolverLocal(
const ConfigProto& config, const DeviceMgr* dev_mgr,
DeviceResolverInterface* dev_resolver, const string& task_name)
: nccl_(config.experimental().collective_nccl()),
dev_mgr_(dev_mgr),
dev_resolver_(dev_resolver),
task_name_(task_name) {}
void CollectiveParamResolverLocal::CompleteGroupAsync(
const CompleteGroupRequest* request, CompleteGroupResponse* response,
CancellationManager* cancel_mgr, const StatusCallback& done) {
done(
errors::Internal("CompleteGroup is not implemented by "
"CollectiveParamResolverLocal which is "
"intended only for non-distributed deployment."));
}
namespace {
// Maps the collective type stored in cp->instance.type to the name of the
// implementation to look up. For reductions, `nccl` selects between the NCCL
// and ring implementations. Unrecognized types map to "undef".
const char* GetCollectiveName(const CollectiveParams* cp, bool nccl) {
  const auto collective_type = cp->instance.type;
  if (collective_type == BROADCAST_COLLECTIVE) {
    return "HierarchicalTreeBroadcast";
  }
  if (collective_type == REDUCTION_COLLECTIVE) {
    if (nccl) {
      return "NcclReduce";
    }
    return "RingReduce";
  }
  if (collective_type == GATHER_COLLECTIVE) {
    return "RingGather";
  }
  return "undef";
}
} // namespace
void CollectiveParamResolverLocal::CompleteGroupLocal(
const string& device, CollectiveParams* cp, const GroupRecCallback& done) {
VLOG(1) << "CompleteGroupLocal device=" << device << " cp: " << cp << ": "
<< cp->ToString();
std::vector<StatusCallback> to_be_called;
GroupRec* gr = nullptr;
Status status;
{
mutex_lock l(group_mu_);
auto it = group_table_.find(cp->group.group_key);
if (it == group_table_.end()) {
gr = new GroupRec;
gr->group.group_key = cp->group.group_key;
gr->group.group_size = cp->group.group_size;
gr->group.device_type = cp->group.device_type;
// Initialize group runtime details.
CollectiveImplementationInterface* col_impl;
// Try to lookup a NCCL collective kernel. This will return error status
// if `NcclReduce` kernel is not present in the registry, e.g. on an
// environment that does not support NCCL.
status = CollectiveRegistry::LookupParamResolverInstance("NcclReduce",
&col_impl);
if (!status.ok()) {
// Fallback to non-NCCL collective.
status = CollectiveRegistry::LookupParamResolverInstance(
GetCollectiveName(cp, /*nccl=*/false), &col_impl);
}
if (status.ok()) {
status = col_impl->InitializeCollectiveGroupRuntimeDetails(
&gr->group.runtime_details);
}
if (!status.ok()) {
done(status, gr);
return;
}
// Store GroupRec in group_table_ which is shared between all devices on
// this worker.
group_table_[gr->group.group_key].reset(gr);
VLOG(2) << "New group_key=" << gr->group.group_key
<< " group_size=" << gr->group.group_size
<< " runtime_details=" << gr->group.runtime_details.ToString();
} else {
gr = it->second.get();
}
}
{
mutex_lock l(status_mu_);
status = status_;
}
if (!status.ok()) {
done(status, nullptr);
return;
}
{
mutex_lock gr_lock(gr->mu);
// If there is ever an error associated with a group key, we store the error
// status and invoke all waiting and future callbacks with this error
// status.
VLOG(2) << "gr device_type=" << gr->group.device_type
<< " cp device_type=" << cp->group.device_type
<< " current device=" << device;
if (gr->status.ok()) {
// Check for consistency with existing GroupRec.
if (cp->group.device_type != gr->group.device_type) {
gr->status = errors::Internal(
"Collective Op ", cp->name, " is assigned to device ", device,
" with type ", cp->group.device_type.type_string(),
" and group_key ", cp->group.group_key, " but that group has type ",
gr->group.device_type.type_string());
} else if (cp->group.group_size != gr->group.group_size) {
gr->status = errors::Internal(
"Collective Op ", cp->name, " has group_size ",
cp->group.group_size, " and group_key ", cp->group.group_key,
" but that group has size ", gr->group.group_size);
}
}
if (gr->status.ok()) {
// Insert device if not already present.
auto it = gr->device_set.find(device);
if (it == gr->device_set.end()) {
if (gr->device_set.size() == gr->group.group_size) {
// The group is already full.
gr->status = errors::Internal(
"Collective Op ", cp->name, " is assigned to device ", device,
" and group_key ", cp->group.group_key,
" but that group doesn't contain that device.");
} else {
// This is a new device that has not yet joined the group.
gr->device_set.insert(device);
gr->device_list.push_back(device);
DeviceNameUtils::ParsedName parsed_device;
DeviceNameUtils::ParseFullName(device, &parsed_device);
string task_name = strings::StrCat("/job:", parsed_device.job,
"/replica:", parsed_device.replica,
"/task:", parsed_device.task);
gr->task_set.insert(task_name);
gr->task_list.push_back(task_name);
gr->group.num_tasks = static_cast<int32>(gr->task_set.size());
if (VLOG_IS_ON(1)) {
string dev_buf;
for (const auto& d : gr->device_set) {
strings::StrAppend(&dev_buf, ",", d);
}
VLOG(1) << "CompleteGroupLocal group_key=" << gr->group.group_key
<< " group_size=" << gr->group.group_size << " (current"
<< " devices)=(" << dev_buf << ") (number of"
<< " devices pending)="
<< (gr->group.group_size - gr->device_set.size());
}
}
}
}
if (gr->status.ok()) {
cp->group.runtime_details = gr->group.runtime_details;
// If the group is not yet complete, queue to wait for it.
VLOG(2) << "group_size " << gr->group.group_size << " set size "
<< gr->device_set.size() << " gr " << gr;
if (gr->device_set.size() < gr->group.group_size) {
gr->waiting.push_back(std::bind(done, std::placeholders::_1, gr));
return;
}
CHECK_EQ(gr->device_set.size(), gr->group.group_size);
}
// At this point, we either have a full group, or an error status. Ensure
// that all callbacks are invoked with the appropriate status.
if (!gr->waiting.empty()) {
std::swap(to_be_called, gr->waiting);
}
status = gr->status;
}
done(status, gr);
for (int i = 0; i < to_be_called.size(); ++i) {
to_be_called[i](status);
}
}
namespace {
struct DevRec {
string task;
string device;
int original_rank;
int local_rank;
int global_rank;
const DeviceLocality* locality;
};
typedef std::unordered_map<string, DevRec> TaskDeviceMap;
typedef std::unordered_map<string, TaskDeviceMap> GlobalDeviceMap;
// Builds a GlobalDeviceMap holding one DevRec per entry of ip.device_names,
// grouped by the task name at the same index. attributes[i] must describe
// ip.device_names[i]; each record keeps a pointer into `attributes`, so that
// vector must outlive the returned map. Local and global ranks are left at 0
// for OrderTaskDeviceMap and EstablishGlobalRank to fill in.
GlobalDeviceMap BuildDevRecs(const CollInstanceParams& ip,
                             const std::vector<DeviceAttributes>& attributes) {
  GlobalDeviceMap gdm;
  CHECK_EQ(ip.device_names.size(), ip.task_names.size());
  CHECK_EQ(ip.device_names.size(), attributes.size());
  const int num_devices = static_cast<int>(ip.device_names.size());
  for (int rank = 0; rank < num_devices; ++rank) {
    const string& task = ip.task_names[rank];
    const string& device = ip.device_names[rank];
    DevRec& rec = gdm[task][device];
    rec.task = task;
    rec.device = device;
    rec.original_rank = rank;
    rec.local_rank = 0;   // Will be populated later by OrderTaskDeviceMap.
    rec.global_rank = 0;  // Will be populated later by EstablishGlobalRank.
    rec.locality = &attributes[rank].locality();
  }
  return gdm;
}
bool ParseRingOrder(const string& gpu_ring_order_str, TaskDeviceMap* tdm) {
std::vector<string> split_gpu_ring_order_str =
str_util::Split(gpu_ring_order_str, ',');
if (split_gpu_ring_order_str.size() != tdm->size()) return false;
// gpu id -> local rank
gtl::FlatMap<int32, int32> gpu_ranks;
for (int32 rank = 0;
rank < static_cast<int32>(split_gpu_ring_order_str.size()); ++rank) {
int32 tmp;
if (strings::safe_strto32(split_gpu_ring_order_str[rank], &tmp)) {
gpu_ranks[tmp] = rank;
} else {
return false;
}
}
for (auto& tdm_it : *tdm) {
DeviceNameUtils::ParsedName parsed_name;
DevRec* dr = &tdm_it.second;
if (!DeviceNameUtils::ParseFullName(dr->device, &parsed_name)) {
return false;
}
auto rank_it = gpu_ranks.find(parsed_name.id);
if (rank_it == gpu_ranks.end()) return false;
dr->local_rank = rank_it->second;
}
VLOG(2) << "Assigned local ranks based on ring order " << gpu_ring_order_str;
return true;
}
// Assigns DevRec::local_rank to every device in *tdm (the devices of a single
// task). If `gpu_ring_order` parses into a complete ordering it is used as
// is. Otherwise the order is built greedily: start at the device with the
// smallest original_rank, then repeatedly step to the not-yet-selected device
// in *tdm reachable over the strongest interconnect link, falling back to the
// smallest remaining original_rank when no such link exists.
// CHECK-fails if *tdm is empty or the starting device name cannot be parsed.
void OrderTaskDeviceMap(const string& gpu_ring_order, TaskDeviceMap* tdm) {
CHECK_GT(tdm->size(), 0); // Should never be called with 0 devices
// If a valid ring order has been passed in via ConfigProto, use that.
if (ParseRingOrder(gpu_ring_order, tdm)) return;
// Either no ring order was passed in, or the format was unexpected.
// We now assign a ring order based on link strengths. Note that this
// algorithm is not optimal and may not always find the best ring order.
int least_rank = -1;
string next_device;
// Names of the devices that have already been given a local rank.
std::set<string> selected;
// Starting device is one with the least initial rank.
for (const auto& it : *tdm) {
if (least_rank < 0 || it.second.original_rank < least_rank) {
least_rank = it.second.original_rank;
next_device = it.second.device;
}
}
CHECK_GE(least_rank, 0);
// parsed_name is parsed once from the starting device and then reused as a
// template: only its id field is rewritten below to name link endpoints.
DeviceNameUtils::ParsedName parsed_name;
CHECK(DeviceNameUtils::ParseFullName(next_device, &parsed_name));
// NOTE: InterconnectLink has only a device_id, nothing more, so for
// the time being if there's more than one device at a task we
// assume they're all GPUs.
int next_rank = 0;
while (true) {
selected.insert(next_device);
auto next_dev_it = tdm->find(next_device);
CHECK(next_dev_it != tdm->end());
DevRec* dr = &next_dev_it->second;
dr->local_rank = next_rank;
++next_rank;
// Stop once every device in the task has been ranked.
if (selected.size() == tdm->size()) {
break;
}
// For the present time we assume Locality links only cover GPUs.
// For multiple CPUs, just take them in order.
const InterconnectLink* best_link = nullptr;
if (parsed_name.type == "GPU") {
for (const InterconnectLink& il : dr->locality->links().link()) {
parsed_name.id = il.device_id();
string endpoint_device =
DeviceNameUtils::ParsedNameToString(parsed_name);
// Skip the device if we've already seen it.
if (selected.find(endpoint_device) != selected.end()) {
continue;
}
// Skip the device if it is not participating in this collective
// instance.
if (tdm->find(endpoint_device) == tdm->end()) {
continue;
}
// Keep the strongest eligible link; on ties the first one seen wins.
if (best_link == nullptr || il.strength() > best_link->strength()) {
best_link = &il;
}
}
}
if (best_link != nullptr) {
// Follow the best edge
parsed_name.id = best_link->device_id();
next_device = DeviceNameUtils::ParsedNameToString(parsed_name);
} else {
// No good edges, alas. Pick the lowest initial rank among remaining
// devices.
least_rank = -1;
for (const auto& it : *tdm) {
if (selected.find(it.second.device) != selected.end()) {
continue;
}
if (least_rank < 0 || it.second.original_rank < least_rank) {
least_rank = it.second.original_rank;
next_device = it.second.device;
}
}
CHECK_GE(least_rank, 0);
}
}
}
// Computes a rank order for all devices named in cp->instance that is meant to
// suit a ring algorithm. Devices are first ordered within each task by
// OrderTaskDeviceMap; tasks are then laid out one after another in the order
// in which they first appear in cp->instance.task_names, so a device's global
// rank is its local rank plus the number of devices in all earlier tasks.
// The order need not be the same across different instance groups sharing the
// same device group where there is more than one good order.
// `attributes` must be parallel to cp->instance.device_names.
GlobalDeviceMap EstablishGlobalRank(
    CollectiveParams* cp, const std::vector<DeviceAttributes>& attributes) {
  VLOG(1) << "EstablishGlobalRank";
  GlobalDeviceMap gdm = BuildDevRecs(cp->instance, attributes);
  for (auto& task_entry : gdm) {
    OrderTaskDeviceMap(cp->instance.gpu_ring_order, &task_entry.second);
  }
  // Connect the global rank order by the order in which tasks first appear.
  std::set<string> seen_tasks;
  int rank_offset = 0;
  for (const string& task_name : cp->instance.task_names) {
    // insert() reports whether the task was new; repeated tasks are skipped.
    if (!seen_tasks.insert(task_name).second) {
      continue;
    }
    TaskDeviceMap& devices = gdm[task_name];
    for (auto& dev_entry : devices) {
      DevRec& rec = dev_entry.second;
      rec.global_rank = rec.local_rank + rank_offset;
    }
    rank_offset += devices.size();
  }
  return gdm;
}
// Rebuilds cp->instance.num_devices_per_task by counting consecutive runs of
// equal names in cp->instance.task_names, then sets
// cp->instance.same_num_devices_per_task to whether every task ended up with
// the same count. Requires cp->instance.task_names to be non-empty and to have
// all entries for a given task adjacent to each other.
void SetDevPerTask(CollectiveParams* cp) {
  auto& instance = cp->instance;
  instance.num_devices_per_task.clear();
  // Task whose run is currently being counted, and the run length so far.
  const string* run_task = &instance.task_names[0];
  int run_length = 0;
  for (const string& task_name : instance.task_names) {
    if (task_name != *run_task) {
      // A new run starts here; record the one that just ended.
      instance.num_devices_per_task[*run_task] = run_length;
      run_task = &task_name;
      run_length = 0;
    }
    ++run_length;
  }
  instance.num_devices_per_task[*run_task] = run_length;

  instance.same_num_devices_per_task = false;
  auto task_it = instance.num_devices_per_task.begin();
  const auto task_end = instance.num_devices_per_task.end();
  if (task_it != task_end) {
    const int dev_per_task = task_it->second;
    for (++task_it; task_it != task_end; ++task_it) {
      if (task_it->second != dev_per_task) {
        return;
      }
    }
  }
  instance.same_num_devices_per_task = true;
  CHECK_EQ((cp->group.group_size % cp->group.num_tasks), 0);
}
// Sorts cp->instance.device_names lexicographically. The sort is done on a
// permutation of indices so that cp->instance.task_names can be reordered the
// same way and entry i of both vectors keeps describing the same device.
// Finishes by recomputing the per-task device counts via SetDevPerTask.
// CHECK-fails if cp is null or either vector's length differs from
// cp->group.group_size.
void SortDevicesAndTasks(CollectiveParams* cp) {
  // Validate the pointer before anything dereferences it, including logging.
  CHECK(cp);
  VLOG(1) << "SortDevicesAndTasks " << cp << " instance " << &cp->instance;
  CHECK_EQ(cp->group.group_size, cp->instance.device_names.size());
  CHECK_EQ(cp->group.group_size, cp->instance.task_names.size());
  std::vector<int> perm(cp->group.group_size);
  // TODO(tucker): substitute std::iota when the windows build supports it.
  // std::iota(perm.begin(), perm.end(), 0);
  const int num_entries = static_cast<int>(perm.size());
  for (int i = 0; i < num_entries; ++i) {
    perm[i] = i;
  }
  std::sort(perm.begin(), perm.end(), [cp](int a, int b) {
    return cp->instance.device_names[a] < cp->instance.device_names[b];
  });
  std::vector<string> new_devs;
  std::vector<string> new_tasks;
  new_devs.reserve(cp->group.group_size);
  new_tasks.reserve(cp->group.group_size);
  // perm is a permutation, so every source element is visited exactly once
  // and can be moved out instead of copied. The old vectors are replaced
  // wholesale right after the loop.
  for (int pi : perm) {
    new_devs.push_back(std::move(cp->instance.device_names[pi]));
    new_tasks.push_back(std::move(cp->instance.task_names[pi]));
  }
  cp->instance.device_names = std::move(new_devs);
  cp->instance.task_names = std::move(new_tasks);
  VLOG(1) << "Modified device_names on " << cp;
  SetDevPerTask(cp);
}
} // namespace
// Sizes cp->task.is_local to the group size and sets entry i to whether
// cp->instance.task_names[i] equals `task_name`.
void CollectiveParamResolverLocal::CompleteTaskIsLocal(const string& task_name,
                                                       CollectiveParams* cp) {
  cp->task.is_local.resize(cp->group.group_size, false);
  int member = 0;
  while (member < cp->group.group_size) {
    const bool same_task = (cp->instance.task_names[member] == task_name);
    cp->task.is_local[member] = same_task;
    ++member;
  }
}
// Sets cp->default_rank to the index of the first entry of
// cp->instance.device_names equal to `device`. Leaves it untouched if the
// device is not listed. CHECK-fails if the list length differs from the group
// size.
void CollectiveParamResolverLocal::SetDefaultRank(const string& device,
                                                  CollectiveParams* cp) {
  CHECK_EQ(cp->group.group_size, cp->instance.device_names.size()) << cp;
  for (int rank = 0; rank < cp->group.group_size; ++rank) {
    if (cp->instance.device_names[rank] != device) {
      continue;
    }
    cp->default_rank = rank;
    return;
  }
}
// Fills ir->shared from the instance parameters in *cp and the membership
// recorded in *gr, then asynchronously fetches the attributes of every device
// so that CompleteDefaultRanking can pick a rank order.
//
// Locking contract: the caller must hold ir->out_mu on entry. This function
// marks out_mu as unavailable and unlocks it before starting the asynchronous
// lookup; the lookup's callback re-locks it, marks it available again and
// then invokes `done` with the lock still held, so `done` is responsible for
// unlocking ir->out_mu.
void CollectiveParamResolverLocal::InitInstanceSharedParams(
    const GroupRec* gr, const CollectiveParams* cp, InstanceRec* ir,
    const StatusCallback& done) {
  ir->shared.instance = cp->instance;
  {
    mutex_lock gl(gr->mu);
    ir->shared.group = gr->group;
    ir->shared.instance.device_names.assign(gr->device_list.begin(),
                                            gr->device_list.end());
    ir->shared.instance.task_names.assign(gr->task_list.begin(),
                                          gr->task_list.end());
    VLOG(2) << "Initialized names for instance: "
            << ir->shared.instance.ToString();
  }
  ir->shared.default_rank = -1;

  // Sort device_names lexicographically, keeping task_names in corresponding
  // order. Also set number of devices per task.
  SortDevicesAndTasks(&ir->shared);

  // Get Locality data for all devices.

  // Set is_local and task_names in *shared prior to invoking
  // GetAllDeviceAttributesAsync.  In a distributed context this function can
  // be called by a derived class, some of the devices may be non-local and
  // GetAllDeviceAttributesAsync will use those fields to launch RPCs.
  CompleteTaskIsLocal(task_name_, &ir->shared);

  // Because the callback may execute in a different thread, we release
  // ir->out_mu here.  Before releasing, we mark it as unavailable for other
  // threads.
  ir->out_mu_available = false;
  ir->out_mu.unlock();
  // Heap-allocated so it outlives this frame; deleted by the callback below.
  std::vector<DeviceAttributes>* attributes = new std::vector<DeviceAttributes>;
  // Suppress linter warning about access to shared without mutex because in
  // principle the members are locked due to out_mu_available=false.
  dev_resolver_->GetAllDeviceAttributesAsync(
      ir->shared.instance.device_names,  // NOLINT
      ir->shared.instance.task_names,    // NOLINT
      attributes,
      [this, gr, cp, ir, attributes, done](const Status& s)
          TF_EXCLUSIVE_LOCK_FUNCTION(ir->out_mu) {
            // Then we recover the lock in the callback thread that will hold it
            // through the rest of the call chain.  Signal the cv now, any
            // waiting threads will wake only when out_mu is released later.
            ir->out_mu.lock();
            DCHECK(!ir->out_mu_available);
            ir->out_mu_available = true;
            ir->out_cv.notify_all();
            if (s.ok()) {
              CompleteDefaultRanking(gr, cp, ir, *attributes);
              done(Status::OK());
            } else {
              done(s);
            }
            delete attributes;
          });
}
// NOTE(ayushd): The DeviceLocality objects in attributes will have LocalLinks
// to all devices that they are physically connected to and visible to the
// TensorFlow runtime. This set of devices may be a superset of the devices
// participating in this instance of collectives.
void CollectiveParamResolverLocal::CompleteDefaultRanking(
const GroupRec* gr, const CollectiveParams* cp, InstanceRec* ir,
const std::vector<DeviceAttributes>& attributes) {
// Establish an instance-specific default rank order for devices
// based on localities. This rank order should be a good ring
// order, if possible.
GlobalDeviceMap gdm = EstablishGlobalRank(&ir->shared, attributes);
// Reflect the new global ranking on shared
size_t num_devices = ir->shared.group.group_size;
std::vector<string> new_device_names(num_devices, "");
std::vector<string> new_task_names(num_devices, "");
for (const auto& git : gdm) {
const TaskDeviceMap& tdm = git.second;
for (const auto& tit : tdm) {
const DevRec& dr = tit.second;
new_device_names[dr.global_rank] =
ir->shared.instance.device_names[dr.original_rank];
new_task_names[dr.global_rank] =
ir->shared.instance.task_names[dr.original_rank];
}
}
ir->shared.instance.device_names = new_device_names;
ir->shared.instance.task_names = new_task_names;
if (VLOG_IS_ON(2)) {
string buf;
for (const auto& d : new_device_names) strings::StrAppend(&buf, "\n", d);
VLOG(2) << "Optimized device order for " << ir->shared.name << ": " << buf;
}
}
void CollectiveParamResolverLocal::CallbackWithStatus(
const InstanceRecCallback& done, InstanceRec* irec) {
Status s;
{
mutex_lock l(irec->out_mu);
irec->WaitForOutMu(l);
s = irec->status;
}
done(s, irec);
}
// Finds or creates the InstanceRec shared by all members of the instance
// identified by (gr->group.group_key, cp->instance.instance_key) and hands it
// to `done` together with its status.
//   * Record exists and is initialized: `done` is invoked via
//     CallbackWithStatus after instance_mu_ has been released.
//   * Record exists but is still initializing: `done` is queued on
//     init_waiters and this function returns; the queued callback runs when
//     initialization finishes.
//   * No record yet: one is created, stored in instance_table_, and
//     initialization is started through CallInitInstanceSharedParams, which
//     invokes `done` when it completes.
// If the resolver-wide status_ is an error, it is copied into irec->status
// before either of the non-queued paths proceeds.
void CollectiveParamResolverLocal::FindInstanceRec(
const GroupRec* gr, CollectiveParams* cp, const InstanceRecCallback& done) {
InstanceRec* irec = nullptr;
// Set when an initialized record was found, so that the callback is issued
// only after the locks taken below have been released.
bool exit_outside_locks = false;
{
bool found_instance = false;
mutex_lock l(instance_mu_);
auto group_it = instance_table_.find(gr->group.group_key);
if (group_it != instance_table_.end()) {
auto instance_it = group_it->second.find(cp->instance.instance_key);
if (instance_it != group_it->second.end()) {
irec = instance_it->second.get();
{
// This inner lock on in_mu shadows the outer lock variable `l`; both
// instance_mu_ and in_mu are held inside this scope.
mutex_lock l(irec->in_mu);
if (irec->is_init) {
exit_outside_locks = true;
} else {
// Initialization is in progress elsewhere; wait for it to finish.
irec->init_waiters.push_back([this, done](InstanceRec* irec) {
CallbackWithStatus(done, irec);
});
return;
}
}
found_instance = true;
}
}
if (!found_instance) {
// Create new InstanceRec.
irec = new InstanceRec;
// instance_table_ takes ownership of the new record.
instance_table_[gr->group.group_key][cp->instance.instance_key].reset(
irec);
}
}
Status status;
{
mutex_lock l(status_mu_);
status = status_;
}
if (!status.ok()) {
// Propagate a resolver-wide error (see StartAbort) into this instance.
mutex_lock il(irec->out_mu);
irec->WaitForOutMu(il);
irec->status = status;
}
if (exit_outside_locks) {
CallbackWithStatus(done, irec);
return;
}
CallInitInstanceSharedParams(gr, cp, irec, done);
}
// Starts initialization of a freshly created InstanceRec: locks ir->out_mu,
// sizes ir->known to the group size, and calls InitInstanceSharedParams. When
// that completes, the callback records the resulting status, unlocks
// ir->out_mu, marks the record initialized, and then invokes `done` followed
// by every waiter that queued up on ir->init_waiters in the meantime.
void CollectiveParamResolverLocal::CallInitInstanceSharedParams(
const GroupRec* gr, const CollectiveParams* cp, InstanceRec* ir,
const InstanceRecCallback& done) TF_NO_THREAD_SAFETY_ANALYSIS {
// This function serves merely to make a function call that should
// be thread/mutex safe but violates the simple model applied by
// static analysis, so we turn off analysis only within this
// function body.
//
// A lock on ir->out_mu must be held* throughout the _bodies_ of the
// chain of function calls initiated here, each of which calls
// another as its last action, but it will be dropped within the
// callback defined below, which means that the lock can be dropped
// before all the function stack frames pop. The static analysis will
// not allow that.
//
// *the lock is dropped just before calling GetAllDeviceAttributesAsync,
// because there is no guarantee that the thread that executes the callback is
// the same as the one that locked ir->out_mu. To prevent other threads from
// grabbing ir->out_mu, we mark ir->out_mu_available as false. Hence, in
// principle, the lock is held throughout.
ir->out_mu.lock();
DCHECK(ir->out_mu_available);
ir->known.resize(cp->group.group_size, false);
InitInstanceSharedParams(
gr, cp, ir,
[this, ir, done](const Status& s) TF_UNLOCK_FUNCTION(ir->out_mu) {
// Runs with ir->out_mu held by the caller of this callback.
DCHECK(ir->out_mu_available);
ir->status.Update(s);
ir->out_mu.unlock();
// Prepare to invoke any waiters that accumulated during
// initialization.
std::vector<IRConsumer> init_waiters;
{
// instance_mu_ is taken before in_mu, the same order FindInstanceRec
// uses when it inspects is_init and queues a waiter.
mutex_lock tl(instance_mu_);
{
mutex_lock l(ir->in_mu);
ir->is_init = true;
if (!ir->init_waiters.empty()) {
std::swap(init_waiters, ir->init_waiters);
}
}
}
// Both locks are released; it is now safe to run callbacks.
CallbackWithStatus(done, ir);
for (auto& f : init_waiters) {
f(ir);
}
});
}
void CollectiveParamResolverLocal::CompleteParamsAsync(
const string& device, CollectiveParams* cp, CancellationManager* cancel_mgr,
const StatusCallback& done) {
VLOG(1) << "CompleteParams local " << device << " for " << cp << ": "
<< cp->ToString();
CompleteGroupLocal(
device, cp,
[this, device, cp, done](const Status& s, const GroupRec* gr) {
if (s.ok()) {
CompleteInstanceLocal(device, gr, cp, cp->is_source, done);
} else {
done(s);
}
});
}
void CollectiveParamResolverLocal::CompleteInstanceAsync(
const CompleteInstanceRequest* request, CompleteInstanceResponse* response,
CancellationManager* cancel_mgr, const StatusCallback& done) {
done(
errors::Internal("CompleteInstance is not implemented by "
"CollectiveParamResolverLocal which is "
"intended only for non-distributed deployment."));
}
// Chooses the collective implementation for *cp and stores its name in
// cp->instance.impl_details.collective_name.
//
// NCCL is selected only when it was requested, either through `ConfigProto`
// (nccl_) or through a `communication_hint` of "nccl", and the environment
// supports it, i.e. `LookupParamResolverInstance` for `NcclReduce` returns OK.
// After enough testing, we may simplify this logic to use NCCL whenever
// available.
//
// TODO(b/111897089): we need a better way to pick the collective
// implementation.  The ideal way would depend upon the topology and link
// strength before picking a particular implementation.
void CollectiveParamResolverLocal::AssignCollectiveType(CollectiveParams* cp) {
  const bool nccl_requested =
      nccl_ || cp->instance.impl_details.communication_hint == "nccl";
  bool use_nccl = false;
  if (nccl_requested) {
    // The registry is only consulted when NCCL was actually asked for.
    CollectiveImplementationInterface* col_impl;
    use_nccl =
        CollectiveRegistry::LookupParamResolverInstance("NcclReduce", &col_impl)
            .ok();
  }
  cp->instance.impl_details.collective_name = GetCollectiveName(cp, use_nccl);
  VLOG(1) << "AssignCollectiveType "
          << cp->instance.impl_details.collective_name;
}
void CollectiveParamResolverLocal::CompleteInstanceLocal(
const string& device, const GroupRec* gr, CollectiveParams* cp,
bool is_source, const StatusCallback& done) {
VLOG(1) << "CompleteInstanceLocal " << device
<< " instance_key: " << cp->instance.instance_key << " gr " << gr;
// Populate the group portion of *cp from *gr. Most of it should already
// match.
DCHECK_EQ(cp->group.group_key, gr->group.group_key);
DCHECK_EQ(cp->group.group_size, gr->group.group_size);
DCHECK_EQ(cp->group.device_type, gr->group.device_type);
cp->group = gr->group;
// Get the shared InstanceRec for this instance.
FindInstanceRec(gr, cp,
[this, device, gr, cp, is_source, done](const Status& s,
InstanceRec* ir) {
if (s.ok()) {
CompleteInstanceFromInitializedIRec(device, gr, cp, ir,
is_source, done);
} else {
done(s);
}
});
}
void CollectiveParamResolverLocal::CompleteInstanceFromInitializedIRec(
const string& device, const GroupRec* gr, CollectiveParams* cp,
InstanceRec* ir, bool is_source, const StatusCallback& done) {
auto expected_shape = cp->instance.shape;
// Populate the fields common across instance.
{
mutex_lock l(ir->out_mu);
ir->WaitForOutMu(l);
// custom operator= does a deep copy.
cp->instance = ir->shared.instance;
}
if (expected_shape != cp->instance.shape) {
done(errors::InvalidArgument(
"Shape mismatch in the collective instance ", cp->instance.instance_key,
". Op at device ", device, " expected shape ",
expected_shape.DebugString(), " but another member in the group ",
"expected shape ", cp->instance.shape.DebugString(), ". This is likely",
" due to different input shapes at different members of the collective",
" op."));
return;
}
// Populate the fields common across task.
AssignCollectiveType(cp);
SetDefaultRank(device, cp);
CompleteTaskIsLocal(task_name_, cp);
CollectiveImplementationInterface* col_impl;
Status status = CollectiveRegistry::LookupParamResolverInstance(
cp->instance.impl_details.collective_name, &col_impl);
if (!status.ok()) {
done(status);
return;
}
// We may need to wait for the group, if this is a broadcast, for source
// discovery.
if (cp->instance.type == BROADCAST_COLLECTIVE) {
WaitForGroup(ir, cp, is_source,
[col_impl, ir, device, cp, done](InstanceRec* irec) {
Status s;
if (ir != irec) {
s = errors::Internal("Expected ir ", ir, " and irec ",
irec, " to be equal");
} else {
mutex_lock l(irec->out_mu);
irec->WaitForOutMu(l);
s = irec->status;
cp->source_rank = irec->source_rank;
}
if (s.ok()) {
s = col_impl->InitializeCollectiveParams(cp);
}
done(s);
});
} else {
done(col_impl->InitializeCollectiveParams(cp));
}
}
// Records that the member with rank cp->default_rank has reached this
// instance and, if `is_source`, claims the broadcast source rank for it.
// While fewer than group_size members have checked in, `f` is queued on
// ir->known_waiters and this function returns without calling it. Once the
// last member arrives, or if ir->status is already an error, `f` and all
// queued waiters are invoked with `ir`, after ir->out_mu has been released.
// Errors (a second source claim, or no source at all) are recorded in
// ir->status rather than returned.
void CollectiveParamResolverLocal::WaitForGroup(InstanceRec* ir,
CollectiveParams* cp,
bool is_source,
const IRConsumer& f) {
std::vector<IRConsumer> ready_waiters;
// Single-pass do/while used so that `break` can skip to the callbacks below
// while still leaving the scope that holds the lock.
do {
mutex_lock l(ir->out_mu);
ir->WaitForOutMu(l);
if (!ir->status.ok()) {
break;
}
CHECK_EQ(cp->group.group_size, ir->known.size());
CHECK_GE(cp->default_rank, 0);
// Count each rank only once, even if it calls in repeatedly.
if (!ir->known[cp->default_rank]) {
ir->known[cp->default_rank] = true;
++ir->known_count;
if (is_source) {
// Initialize source rank.
if (ir->source_rank >= 0) {
ir->status = errors::Internal("Instance ", cp->instance.instance_key,
" already has source ", ir->source_rank,
", received second claim from ",
cp->default_rank);
} else {
ir->source_rank = cp->default_rank;
}
}
}
if (ir->known_count < ir->shared.group.group_size) {
// Not everyone has arrived yet; park this caller until they have.
ir->known_waiters.push_back(f);
return;
}
CHECK_EQ(ir->known_count, ir->shared.group.group_size);
if (ir->source_rank < 0) {
// NOTE(ayushd): changing the error message below would also require
// updating CompleteParamsBroadcastForgotSend test in
// CollectiveParamResolverLocalTest.
ir->status =
errors::Internal("Instance ", cp->instance.instance_key,
" found no source for broadcast. This "
"could mean that there were group_size=",
ir->known_count, " BcastRecvs but no BcastSend.");
}
if (!ir->known_waiters.empty()) {
ready_waiters = std::move(ir->known_waiters);
}
} while (false);
// The lock is no longer held here. Notify this caller first, then the
// members that were parked earlier.
f(ir);
// Note: the loop variable below shadows the parameter `f`.
for (auto& f : ready_waiters) {
f(ir);
}
}
// Aborts the resolver with status `s`. Only the first abort takes effect: it
// records `s` in status_ and then fails all pending work through
// StartAbortLocal. Later calls are logged and ignored.
void CollectiveParamResolverLocal::StartAbort(const Status& s) {
  {
    mutex_lock status_lock(status_mu_);
    if (status_.ok()) {
      status_ = s;
    } else {
      VLOG(1) << "CollectiveParamResolverLocal already aborted. Ignoring "
                 "subsequent abortion with status: "
              << s;
      return;
    }
  }
  StartAbortLocal(s);
}
void CollectiveParamResolverLocal::StartAbortLocal(const Status& s) {
{
mutex_lock l(group_mu_);
for (const auto& item : group_table_) {
GroupRec* gr = item.second.get();
std::vector<StatusCallback> waiting;
{
mutex_lock gl(gr->mu);
gr->status = s;
waiting.swap(gr->waiting);
}
for (const StatusCallback& done : waiting) {
done(s);
}
}
}
std::vector<InstanceRec*> instances;
{
mutex_lock l(instance_mu_);
for (const auto& group_entry : instance_table_) {
for (const auto& item : group_entry.second) {
instances.push_back(item.second.get());
}
}
}
for (InstanceRec* ir : instances) {
std::vector<IRConsumer> known_waiters;
{
mutex_lock il(ir->out_mu);
ir->WaitForOutMu(il);
ir->status = s;
known_waiters.swap(ir->known_waiters);
}
for (const IRConsumer& done : known_waiters) {
done(ir);
}
}
}
} // namespace tensorflow