/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/collective_param_resolver_local.h"
#include <stddef.h>
#include <algorithm>
#include <unordered_set>
#include <utility>
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/util/device_name_utils.h"
namespace tensorflow {
CollectiveParamResolverLocal::CollectiveParamResolverLocal(
const ConfigProto& config, const DeviceMgr* dev_mgr,
DeviceResolverInterface* dev_resolver, const string& task_name)
: nccl_(config.experimental().collective_nccl()),
dev_mgr_(dev_mgr),
dev_resolver_(dev_resolver),
task_name_(task_name) {}
void CollectiveParamResolverLocal::CompleteGroupAsync(
const CompleteGroupRequest* request, CompleteGroupResponse* response,
CancellationManager* cancel_mgr, const StatusCallback& done) {
done(
errors::Internal("CompleteGroup is not implemented by "
"CollectiveParamResolverLocal which is "
"intended only for non-distributed deployment."));
}
namespace {
const char* GetCollectiveName(const CollectiveParams* cp, bool nccl) {
switch (cp->instance.type) {
case BROADCAST_COLLECTIVE:
return nccl ? "NcclBroadcast" : "HierarchicalTreeBroadcast";
case REDUCTION_COLLECTIVE:
return nccl ? "NcclReduce" : "RingReduce";
case GATHER_COLLECTIVE:
return nccl ? "NcclGather" : "RingGather";
case PERMUTE_COLLECTIVE:
return "Permute";
default:
return "undef";
}
}
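// Returns the task prefix of a fully qualified device name, e.g.
// "/job:worker/replica:0/task:1/device:GPU:0" yields
// "/job:worker/replica:0/task:1". CHECK-fails if the device name or its task
// component cannot be parsed.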
string TaskNameFromDeviceName(const string& device_name) {
DeviceNameUtils::ParsedName parsed_device;
CHECK(DeviceNameUtils::ParseFullName(device_name, &parsed_device));
string task_name;
CHECK(DeviceNameUtils::GetTaskName(parsed_device, &task_name));
return task_name;
}
} // namespace
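// Completes the group portion of *cp by registering `device` with the GroupRec
// for cp->group.group_key (creating the GroupRec on first use). The `done`
// callback is invoked immediately on error or cancellation, invoked with the
// full GroupRec once group_size devices have joined, or queued on gr->waiting
// until the final member arrives.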
void CollectiveParamResolverLocal::CompleteGroupLocal(
const DeviceAttributes& device, CollectiveParams* cp,
const GroupRecCallback& done, CancellationManager* cancel_mgr) {
VLOG(1) << "CompleteGroupLocal device=" << device.name() << " cp: " << cp
<< ": " << cp->ToString();
std::vector<StatusCallback> to_be_called;
// Keep a reference to `cp` to avoid racing with deletion due to cancellation.
cp->Ref();
core::ScopedUnref cp_unref(cp);
std::function<void(const Status& s, GroupRec* gr)> done_with_cleanup;
if (cancel_mgr != nullptr) {
const CancellationToken token = cancel_mgr->get_cancellation_token();
const bool already_cancelled = !cancel_mgr->RegisterCallback(
token, [done]() { done(errors::Cancelled("op cancelled"), nullptr); });
if (already_cancelled) {
done(errors::Cancelled("op cancelled"), nullptr);
return;
}
done_with_cleanup = [cancel_mgr, done, token](const Status& s,
GroupRec* gr) {
if (cancel_mgr == nullptr || cancel_mgr->TryDeregisterCallback(token)) {
        // The cancellation callback was never invoked (or was successfully
        // deregistered), so it is safe to invoke `done` here.
done(s, gr);
}
};
} else {
done_with_cleanup = done;
}
GroupRec* gr = nullptr;
Status status;
{
mutex_lock l(group_mu_);
auto it = group_table_.find(cp->group.group_key);
if (it == group_table_.end()) {
gr = new GroupRec;
mutex_lock grl(gr->mu);
gr->group.group_key = cp->group.group_key;
gr->group.group_size = cp->group.group_size;
gr->group.device_type = cp->group.device_type;
gr->group.gpu_ring_order = cp->group.gpu_ring_order;
// Initialize group runtime details.
CollectiveImplementationInterface* col_impl;
      // Try to look up a NCCL collective kernel. This returns an error status
      // if the `NcclReduce` kernel is not present in the registry, e.g. in an
      // environment that does not support NCCL.
status = CollectiveRegistry::LookupParamResolverInstance("NcclReduce",
&col_impl);
if (!status.ok()) {
      // Fall back to a non-NCCL collective.
status = CollectiveRegistry::LookupParamResolverInstance(
GetCollectiveName(cp, /*nccl=*/false), &col_impl);
}
if (status.ok()) {
status = col_impl->InitializeCollectiveGroupRuntimeDetails(
&gr->group.runtime_details);
}
if (!status.ok()) {
done_with_cleanup(status, gr);
return;
}
// Store GroupRec in group_table_ which is shared between all devices on
// this worker.
group_table_[gr->group.group_key].reset(gr);
VLOG(2) << "New group_key=" << gr->group.group_key
<< " group_size=" << gr->group.group_size
<< " runtime_details=" << gr->group.runtime_details.ToString();
} else {
gr = it->second.get();
}
}
{
mutex_lock l(status_mu_);
status = status_;
}
if (!status.ok()) {
done_with_cleanup(status, nullptr);
return;
}
{
mutex_lock gr_lock(gr->mu);
// If there is ever an error associated with a group key, we store the error
// status and invoke all waiting and future callbacks with this error
// status.
VLOG(2) << "gr device_type=" << gr->group.device_type
<< " cp device_type=" << cp->group.device_type
<< " current device=" << device.name();
if (gr->status.ok()) {
// Check for consistency with existing GroupRec.
if (cp->group.device_type != gr->group.device_type) {
gr->status = errors::Internal(
"Collective Op ", cp->name, " is assigned to device ",
device.name(), " with type ", cp->group.device_type.type_string(),
" and group_key ", cp->group.group_key, " but that group has type ",
gr->group.device_type.type_string());
} else if (cp->group.group_size != gr->group.group_size) {
gr->status = errors::Internal(
"Collective Op ", cp->name, " has group_size ",
cp->group.group_size, " and group_key ", cp->group.group_key,
" but that group has size ", gr->group.group_size);
}
}
bool new_device = false;
if (gr->status.ok()) {
// Insert device if not already present.
auto it = gr->devices.find(device.name());
if (it == gr->devices.end()) {
if (gr->devices.size() == gr->group.group_size) {
// The group is already full.
gr->status = errors::Internal(
"Collective Op ", cp->name, " is assigned to device ",
device.name(), " and group_key ", cp->group.group_key,
" but that group doesn't contain that device.");
} else {
// This is a new device that has not yet joined the group.
gr->devices[device.name()] = device;
new_device = true;
if (VLOG_IS_ON(1)) {
string dev_buf;
for (const auto& d : gr->devices) {
strings::StrAppend(&dev_buf, ",", d.first);
}
VLOG(1) << "CompleteGroupLocal group_key=" << gr->group.group_key
<< " group_size=" << gr->group.group_size << " (current"
<< " devices)=(" << dev_buf << ") (number of"
<< " devices pending)="
<< (gr->group.group_size - gr->devices.size());
}
}
} else {
// If the device already exists, check if the incarnation matches.
if (it->second.incarnation() != device.incarnation()) {
gr->status = errors::FailedPrecondition(
"Device ", device.name(),
" current incarnation doesn't match with one in the group. This "
"usually means this worker has restarted but the collective "
"leader hasn't, or this worker connects to a wrong cluster.");
}
}
}
if (gr->status.ok()) {
// If the group is not yet complete, queue to wait for it.
VLOG(2) << "group_size " << gr->group.group_size << " set size "
<< gr->devices.size() << " gr " << gr;
if (gr->devices.size() < gr->group.group_size) {
gr->waiting.push_back(
std::bind(done_with_cleanup, std::placeholders::_1, gr));
return;
}
CHECK_EQ(gr->devices.size(), gr->group.group_size);
// We get a full group. Fill in remaining fields in gr->group.
if (new_device) {
FinishGroup(gr);
}
}
// At this point, we either have a full group, or an error status. Ensure
// that all callbacks are invoked with the appropriate status.
if (!gr->waiting.empty()) {
std::swap(to_be_called, gr->waiting);
}
status = gr->status;
}
done_with_cleanup(status, gr);
for (int i = 0; i < to_be_called.size(); ++i) {
to_be_called[i](status);
}
}
namespace {
struct DevRec {
string task;
string device;
int original_rank;
int local_rank;
int global_rank;
const DeviceLocality* locality;
};
typedef std::unordered_map<string, DevRec> TaskDeviceMap;
typedef std::unordered_map<string, TaskDeviceMap> GlobalDeviceMap;
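// A GlobalDeviceMap is keyed by task name and each TaskDeviceMap by full
// device name, so gdm["/job:worker/replica:0/task:0"] holds one DevRec per
// device on that task.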
// Create a populated GlobalDeviceMap from CollGroupParams and device
// localities.
GlobalDeviceMap BuildDevRecs(const CollGroupParams& gp,
const std::vector<DeviceAttributes>& attributes) {
GlobalDeviceMap gdm;
CHECK_EQ(gp.device_names.size(), gp.task_names.size());
CHECK_EQ(gp.device_names.size(), attributes.size());
for (int i = 0; i < gp.device_names.size(); ++i) {
TaskDeviceMap& tdm = gdm[gp.task_names[i]];
DevRec* dr = &tdm[gp.device_names[i]];
dr->task = gp.task_names[i];
dr->device = gp.device_names[i];
dr->original_rank = i;
dr->local_rank = 0; // Will be populated later by OrderTaskDeviceMap.
dr->global_rank = 0; // Will be populated later by EstablishGlobalRank.
dr->locality = &attributes[i].locality();
}
return gdm;
}
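// Parses a comma-separated gpu_ring_order string (e.g. "3,2,1,0") into local
// ranks: the GPU id at position k in the string gets local rank k. Returns
// false if the string does not split into exactly tdm->size() valid integers,
// or if some device in *tdm has a GPU id that the string does not mention; the
// caller then falls back to the locality-based ordering in OrderTaskDeviceMap.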
bool ParseRingOrder(const string& gpu_ring_order_str, TaskDeviceMap* tdm) {
std::vector<string> split_gpu_ring_order_str =
str_util::Split(gpu_ring_order_str, ',');
if (split_gpu_ring_order_str.size() != tdm->size()) return false;
// gpu id -> local rank
gtl::FlatMap<int32, int32> gpu_ranks;
for (int32 rank = 0;
rank < static_cast<int32>(split_gpu_ring_order_str.size()); ++rank) {
int32 tmp;
if (strings::safe_strto32(split_gpu_ring_order_str[rank], &tmp)) {
gpu_ranks[tmp] = rank;
} else {
return false;
}
}
for (auto& tdm_it : *tdm) {
DeviceNameUtils::ParsedName parsed_name;
DevRec* dr = &tdm_it.second;
if (!DeviceNameUtils::ParseFullName(dr->device, &parsed_name)) {
return false;
}
auto rank_it = gpu_ranks.find(parsed_name.id);
if (rank_it == gpu_ranks.end()) return false;
dr->local_rank = rank_it->second;
}
VLOG(2) << "Assigned local ranks based on ring order " << gpu_ring_order_str;
return true;
}
void OrderTaskDeviceMap(const string& gpu_ring_order, TaskDeviceMap* tdm) {
CHECK_GT(tdm->size(), 0); // Should never be called with 0 devices
// If a valid ring order has been passed in via ConfigProto, use that.
if (ParseRingOrder(gpu_ring_order, tdm)) return;
// Either no ring order was passed in, or the format was unexpected.
// We now assign a ring order based on link strengths. Note that this
// algorithm is not optimal and may not always find the best ring order.
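  // Greedy sketch: start at the device with the lowest original rank, then
  // repeatedly follow the strongest InterconnectLink to a not-yet-selected
  // participating device; when no such link exists, jump to the unselected
  // device with the lowest original rank.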
int least_rank = -1;
string next_device;
std::set<string> selected;
  // The starting device is the one with the lowest initial rank.
for (const auto& it : *tdm) {
if (least_rank < 0 || it.second.original_rank < least_rank) {
least_rank = it.second.original_rank;
next_device = it.second.device;
}
}
CHECK_GE(least_rank, 0);
DeviceNameUtils::ParsedName parsed_name;
CHECK(DeviceNameUtils::ParseFullName(next_device, &parsed_name));
// NOTE: InterconnectLink has only a device_id, nothing more, so for
// the time being if there's more than one device at a task we
// assume they're all GPUs.
int next_rank = 0;
while (true) {
selected.insert(next_device);
auto next_dev_it = tdm->find(next_device);
CHECK(next_dev_it != tdm->end());
DevRec* dr = &next_dev_it->second;
dr->local_rank = next_rank;
++next_rank;
if (selected.size() == tdm->size()) {
break;
}
// For the present time we assume Locality links only cover GPUs.
// For multiple CPUs, just take them in order.
const InterconnectLink* best_link = nullptr;
if (parsed_name.type == "GPU") {
for (const InterconnectLink& il : dr->locality->links().link()) {
parsed_name.id = il.device_id();
string endpoint_device =
DeviceNameUtils::ParsedNameToString(parsed_name);
// Skip the device if we've already seen it.
if (selected.find(endpoint_device) != selected.end()) {
continue;
}
// Skip the device if it is not participating in this collective
// instance.
if (tdm->find(endpoint_device) == tdm->end()) {
continue;
}
if (best_link == nullptr || il.strength() > best_link->strength()) {
best_link = &il;
}
}
}
if (best_link != nullptr) {
// Follow the best edge
parsed_name.id = best_link->device_id();
next_device = DeviceNameUtils::ParsedNameToString(parsed_name);
} else {
// No good edges, alas. Pick the lowest initial rank among remaining
// devices.
least_rank = -1;
for (const auto& it : *tdm) {
if (selected.find(it.second.device) != selected.end()) {
continue;
}
if (least_rank < 0 || it.second.original_rank < least_rank) {
least_rank = it.second.original_rank;
next_device = it.second.device;
}
}
CHECK_GE(least_rank, 0);
}
}
}
// The first time a CollGroupParams is established for a group we compute a
// rank order for all the devices in the group that is appropriate for a ring
// algorithm.
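// Each task's devices keep the local ring order computed by OrderTaskDeviceMap;
// tasks are then concatenated in order of first appearance, so with two tasks
// of four devices each, the first task's devices get global ranks 0..3 and the
// second task's get 4..7.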
GlobalDeviceMap EstablishGlobalRank(
const CollGroupParams& gp,
const std::vector<DeviceAttributes>& attributes) {
VLOG(1) << "EstablishGlobalRank";
GlobalDeviceMap gdm = BuildDevRecs(gp, attributes);
for (auto& iter : gdm) {
TaskDeviceMap& tdm = iter.second;
OrderTaskDeviceMap(gp.gpu_ring_order, &tdm);
}
// Connect the global rank order by the order in which tasks first appear.
std::set<string> ordered_tasks;
int next_rank = 0;
for (int i = 0; i < gp.task_names.size(); ++i) {
const string& task_name = gp.task_names[i];
if (ordered_tasks.find(task_name) != ordered_tasks.end()) {
continue;
}
ordered_tasks.insert(task_name);
TaskDeviceMap* tdm = &gdm[task_name];
for (auto& it : *tdm) {
it.second.global_rank = it.second.local_rank + next_rank;
}
next_rank += tdm->size();
}
return gdm;
}
// Count the devices associated with each task and set
// gp->same_num_devices_per_task. Requires gp->task_names to be sorted so that
// equal task names are contiguous.
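// For example, task_names = {A, A, B, B, B} yields num_devices_per_task =
// {A: 2, B: 3} and leaves same_num_devices_per_task false.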
void SetDevPerTask(CollGroupParams* gp) {
gp->num_devices_per_task.clear();
const string* last_task_name = &gp->task_names[0];
int count = 0;
for (const string& task_name : gp->task_names) {
if (task_name == *last_task_name) {
++count;
} else {
gp->num_devices_per_task[*last_task_name] = count;
count = 1;
last_task_name = &task_name;
}
}
gp->num_devices_per_task[*last_task_name] = count;
gp->same_num_devices_per_task = false;
int dev_per_task = -1;
for (const auto& task_dev : gp->num_devices_per_task) {
if (dev_per_task == -1) {
dev_per_task = task_dev.second;
} else if (dev_per_task != task_dev.second) {
return;
}
}
gp->same_num_devices_per_task = true;
CHECK_EQ((gp->group_size % gp->num_tasks), 0);
}
// Sort gp->device_names lexicographically, but do so by first computing a
// reordering permutation so that gp->task_names can be kept in corresponding
// order.
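// For example, device_names = {"D1", "D0"} with task_names = {"T1", "T0"}
// becomes device_names = {"D0", "D1"} with task_names = {"T0", "T1"}, so each
// task name stays paired with its original device.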
void SortDevicesAndTasks(CollGroupParams* gp) {
VLOG(1) << "SortDevicesAndTasks " << gp << " " << gp;
CHECK(gp);
CHECK_EQ(gp->group_size, gp->device_names.size());
CHECK_EQ(gp->group_size, gp->task_names.size());
std::vector<int> perm(gp->group_size);
// TODO(tucker): substitute std::iota when the windows build supports it.
// std::iota(perm.begin(), perm.end(), 0);
for (int i = 0; i < perm.size(); ++i) {
perm[i] = i;
}
std::sort(perm.begin(), perm.end(), [gp](int a, int b) {
return gp->device_names[a] < gp->device_names[b];
});
std::vector<string> new_devs;
std::vector<string> new_tasks;
new_devs.reserve(gp->group_size);
new_tasks.reserve(gp->group_size);
for (int pi : perm) {
new_devs.push_back(gp->device_names[pi]);
new_tasks.push_back(gp->task_names[pi]);
}
gp->device_names = std::move(new_devs);
gp->task_names = std::move(new_tasks);
VLOG(1) << "Modified device_names on " << gp;
SetDevPerTask(gp);
}
} // namespace
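// Called with gr->mu held once all group_size devices have joined. Fills in
// the remaining CollGroupParams fields: device_names, task_names, num_tasks,
// per-task device counts, and the locality-based rank order.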
void CollectiveParamResolverLocal::FinishGroup(GroupRec* gr) {
gr->group.device_names.reserve(gr->devices.size());
gr->group.task_names.reserve(gr->devices.size());
std::vector<DeviceAttributes> attributes;
  // Unique task names; used to compute num_tasks.
std::unordered_set<string> tasks;
attributes.reserve(gr->devices.size());
for (const auto& item : gr->devices) {
gr->group.device_names.push_back(item.first);
string task_name = TaskNameFromDeviceName(item.first);
gr->group.task_names.push_back(task_name);
tasks.insert(task_name);
attributes.push_back(item.second);
}
gr->group.num_tasks = static_cast<int32>(tasks.size());
// Sort device_names lexicographically, keeping task_names in corresponding
// order. Also set number of devices per task.
SortDevicesAndTasks(&gr->group);
// Establish the final order of gp->device_names and gp->task_names by
// considering localities of all devices.
CompleteDefaultRanking(attributes, &gr->group);
}
void CollectiveParamResolverLocal::CompleteTaskIsLocal(const string& task_name,
CollectiveParams* cp) {
cp->task.is_local.resize(cp->group.group_size, false);
for (int i = 0; i < cp->group.group_size; ++i) {
cp->task.is_local[i] = (cp->group.task_names[i] == task_name);
}
}
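// Sets cp->default_rank to this device's index in the already-ranked
// cp->group.device_names; leaves it unchanged if the device is not found.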
void CollectiveParamResolverLocal::SetDefaultRank(const string& device,
CollectiveParams* cp) {
CHECK_EQ(cp->group.group_size, cp->group.device_names.size()) << cp;
for (int i = 0; i < cp->group.group_size; ++i) {
if (cp->group.device_names[i] == device) {
cp->default_rank = i;
break;
}
}
}
void CollectiveParamResolverLocal::InitInstanceSharedParams(
const GroupRec* gr, const CollectiveParams* cp, InstanceRec* ir) {
ir->shared->instance = cp->instance;
ir->shared->default_rank = -1;
  // Set is_local and task_names in *shared prior to invoking
  // GetDeviceAttributesAsync. In a distributed context this function can be
  // called by a derived class; some of the devices may then be non-local, and
  // GetDeviceAttributesAsync will use those fields to launch RPCs.
CompleteTaskIsLocal(task_name_, ir->shared);
}
// NOTE(ayushd): The DeviceLocality objects in `attributes` will have
// LocalLinks to all devices to which they are physically connected and which
// are visible to the TensorFlow runtime. This set of devices may be a
// superset of the devices participating in this instance of collectives.
void CollectiveParamResolverLocal::CompleteDefaultRanking(
const std::vector<DeviceAttributes>& attributes, CollGroupParams* gp) {
// Establish an instance-specific default rank order for devices
// based on localities. This rank order should be a good ring
// order, if possible.
GlobalDeviceMap gdm = EstablishGlobalRank(*gp, attributes);
// Reflect the new global ranking on shared
size_t num_devices = gp->group_size;
std::vector<string> new_device_names(num_devices, "");
std::vector<string> new_task_names(num_devices, "");
for (const auto& git : gdm) {
const TaskDeviceMap& tdm = git.second;
for (const auto& tit : tdm) {
const DevRec& dr = tit.second;
new_device_names[dr.global_rank] = gp->device_names[dr.original_rank];
new_task_names[dr.global_rank] = gp->task_names[dr.original_rank];
}
}
gp->device_names = new_device_names;
gp->task_names = new_task_names;
if (VLOG_IS_ON(2)) {
string buf;
for (const auto& d : new_device_names) strings::StrAppend(&buf, "\n", d);
VLOG(2) << "Optimized device order for group " << gp->group_key << ": "
<< buf;
}
}
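// Looks up, or creates and initializes, the InstanceRec for
// cp->instance.instance_key inside the two-level instance_table_, which is
// keyed first by group_key and then by instance_key. Sets *created to true
// only when a new record is made, and stamps any already-recorded abort
// status onto the record.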
CollectiveParamResolverLocal::InstanceRec*
CollectiveParamResolverLocal::GetOrCreateInstanceRec(const GroupRec* gr,
CollectiveParams* cp,
bool* created) {
*created = false;
InstanceRec* irec = nullptr;
{
mutex_lock l(instance_mu_);
auto group_it = instance_table_.find(gr->group.group_key);
if (group_it != instance_table_.end()) {
auto instance_it = group_it->second.find(cp->instance.instance_key);
if (instance_it != group_it->second.end()) {
irec = instance_it->second.get();
}
}
if (irec == nullptr) {
// Create new InstanceRec.
irec = new InstanceRec;
*created = true;
{
mutex_lock il(irec->mu);
irec->known.resize(cp->group.group_size, false);
}
InitInstanceSharedParams(gr, cp, irec);
instance_table_[gr->group.group_key][cp->instance.instance_key].reset(
irec);
}
}
Status status;
{
mutex_lock l(status_mu_);
status = status_;
}
if (!status.ok()) {
mutex_lock l(irec->mu);
irec->status = status;
}
return irec;
}
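// Completes both portions of *cp: the group via CompleteGroupLocal, then the
// instance via CompleteInstanceLocal once the group callback fires with a
// fully populated GroupRec.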
void CollectiveParamResolverLocal::CompleteParamsAsync(
const DeviceAttributes& device, CollectiveParams* cp,
CancellationManager* cancel_mgr, const StatusCallback& done) {
VLOG(1) << "CompleteParams local " << device.name() << " for " << cp << ": "
<< cp->ToString();
CompleteGroupLocal(
device, cp,
[this, device, cp, done](const Status& s, const GroupRec* gr) {
if (s.ok()) {
CompleteInstanceLocal(device.name(), gr, cp, cp->is_source, done);
} else {
done(s);
}
},
cancel_mgr);
}
void CollectiveParamResolverLocal::CompleteInstanceAsync(
const CompleteInstanceRequest* request, CompleteInstanceResponse* response,
CancellationManager* cancel_mgr, const StatusCallback& done) {
done(
errors::Internal("CompleteInstance is not implemented by "
"CollectiveParamResolverLocal which is "
"intended only for non-distributed deployment."));
}
// TODO(b/111897089): we need a better way to pick the collective
// implementation. Ideally the choice would take the topology and link
// strengths into account before picking a particular implementation.
void CollectiveParamResolverLocal::AssignCollectiveType(CollectiveParams* cp) {
  // We use the NCCL implementation only if the environment supports NCCL,
  // i.e. `LookupParamResolverInstance` for `NcclReduce` returns OK, and NCCL
  // was requested either via `ConfigProto` or via `communication_hint`.
//
// After enough testing, we may simplify this logic to use NCCL whenever
// available.
CollectiveImplementationInterface* col_impl;
bool use_nccl =
(nccl_ || cp->instance.impl_details.communication_hint == "nccl") &&
CollectiveRegistry::LookupParamResolverInstance("NcclReduce", &col_impl)
.ok();
cp->instance.impl_details.collective_name = GetCollectiveName(cp, use_nccl);
VLOG(1) << "AssignCollectiveType "
<< cp->instance.impl_details.collective_name;
}
void CollectiveParamResolverLocal::CompleteInstanceLocal(
const string& device, const GroupRec* gr, CollectiveParams* cp,
bool is_source, const StatusCallback& done) {
VLOG(1) << "CompleteInstanceLocal " << device
<< " instance_key: " << cp->instance.instance_key << " gr " << gr;
// Populate the group portion of *cp from *gr. Most of it should already
// match.
{
mutex_lock l(gr->mu);
DCHECK_EQ(cp->group.group_key, gr->group.group_key);
DCHECK_EQ(cp->group.group_size, gr->group.group_size);
DCHECK_EQ(cp->group.device_type, gr->group.device_type);
cp->group = gr->group;
}
bool created_irec;
InstanceRec* ir = GetOrCreateInstanceRec(gr, cp, &created_irec);
if (!created_irec) {
// Check that the preexisting IRec is consistent with the params passed into
// this invocation.
if (ir->shared->instance.type != cp->instance.type ||
ir->shared->instance.data_type != cp->instance.data_type) {
done(errors::Internal("Collective instance ", cp->instance.instance_key,
" expected type ", ir->shared->instance.type,
" and data_type ", ir->shared->instance.data_type,
" but got type ", cp->instance.type,
" and data_type ", cp->instance.data_type));
return;
}
}
CompleteInstanceFromInitializedIRec(device, gr, cp, ir, is_source, done);
}
void CollectiveParamResolverLocal::CompleteInstanceFromInitializedIRec(
const string& device, const GroupRec* gr, CollectiveParams* cp,
InstanceRec* ir, bool is_source, const StatusCallback& done) {
auto expected_shape = cp->instance.shape;
Status status;
// Populate the fields common across instance.
{
mutex_lock l(ir->mu);
status = ir->status;
if (status.ok()) {
// custom operator= does a deep copy.
cp->instance = ir->shared->instance;
}
}
if (!status.ok()) {
done(status);
return;
}
if (expected_shape != cp->instance.shape) {
done(errors::InvalidArgument(
"Shape mismatch in the collective instance ", cp->instance.instance_key,
". Op at device ", device, " expected shape ",
expected_shape.DebugString(), " but another member in the group ",
"expected shape ", cp->instance.shape.DebugString(), ". This is likely",
" due to different input shapes at different members of the collective",
" op."));
return;
}
// Populate the fields common across task.
AssignCollectiveType(cp);
SetDefaultRank(device, cp);
CompleteTaskIsLocal(task_name_, cp);
CollectiveImplementationInterface* col_impl;
status = CollectiveRegistry::LookupParamResolverInstance(
cp->instance.impl_details.collective_name, &col_impl);
if (!status.ok()) {
done(status);
return;
}
  // If this is a broadcast we may need to wait for the rest of the group in
  // order to discover the source rank.
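  // The participant running BcastSend arrives with is_source=true and claims
  // the source rank; callbacks queue in WaitForGroup until all group_size
  // members have checked in or an error is recorded.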
if (cp->instance.type == BROADCAST_COLLECTIVE) {
WaitForGroup(ir, cp, is_source,
[col_impl, ir, device, cp, done](InstanceRec* irec) {
Status s;
if (ir != irec) {
s = errors::Internal("Expected ir ", ir, " and irec ",
irec, " to be equal");
} else {
mutex_lock l(irec->mu);
s = irec->status;
cp->source_rank = irec->source_rank;
}
if (s.ok()) {
s = col_impl->InitializeCollectiveParams(cp);
}
done(s);
});
} else {
done(col_impl->InitializeCollectiveParams(cp));
}
}
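// Marks cp->default_rank as known in *ir and, if is_source, claims the source
// rank. Unless *ir already carries an error, the consumer `f` is queued on
// ir->known_waiters while fewer than group_size participants are known;
// otherwise `f` and any queued waiters are invoked with *ir, whose status
// reflects any duplicate-source or missing-source error.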
void CollectiveParamResolverLocal::WaitForGroup(InstanceRec* ir,
CollectiveParams* cp,
bool is_source,
const IRConsumer& f) {
std::vector<IRConsumer> ready_waiters;
do {
mutex_lock l(ir->mu);
if (!ir->status.ok()) {
break;
}
CHECK_EQ(cp->group.group_size, ir->known.size());
CHECK_GE(cp->default_rank, 0);
if (!ir->known[cp->default_rank]) {
ir->known[cp->default_rank] = true;
++ir->known_count;
if (is_source) {
// Initialize source rank.
if (ir->source_rank >= 0) {
ir->status = errors::Internal("Instance ", cp->instance.instance_key,
" already has source ", ir->source_rank,
", received second claim from ",
cp->default_rank);
} else {
ir->source_rank = cp->default_rank;
}
}
}
if (ir->known_count < cp->group.group_size) {
ir->known_waiters.push_back(f);
return;
}
CHECK_EQ(ir->known_count, cp->group.group_size);
if (ir->source_rank < 0) {
// NOTE(ayushd): changing the error message below would also require
// updating CompleteParamsBroadcastForgotSend test in
// CollectiveParamResolverLocalTest.
ir->status =
errors::Internal("Instance ", cp->instance.instance_key,
" found no source for broadcast. This "
"could mean that there were group_size=",
ir->known_count, " BcastRecvs but no BcastSend.");
}
if (!ir->known_waiters.empty()) {
ready_waiters = std::move(ir->known_waiters);
}
} while (false);
f(ir);
for (auto& f : ready_waiters) {
f(ir);
}
}
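// Records the abort status (only the first abort is kept) and propagates it
// to every callback currently queued on a GroupRec or InstanceRec so that no
// pending group or instance resolution is left waiting.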
void CollectiveParamResolverLocal::StartAbort(const Status& s) {
{
mutex_lock l(status_mu_);
if (!status_.ok()) {
VLOG(2) << "CollectiveParamResolverLocal already aborted. Ignoring "
"subsequent abortion with status: "
<< s;
return;
}
status_ = s;
}
StartAbortLocal(s);
}
void CollectiveParamResolverLocal::StartAbortLocal(const Status& s) {
{
mutex_lock l(group_mu_);
for (const auto& item : group_table_) {
GroupRec* gr = item.second.get();
std::vector<StatusCallback> waiting;
{
mutex_lock gl(gr->mu);
gr->status = s;
waiting.swap(gr->waiting);
}
for (const StatusCallback& done : waiting) {
done(s);
}
}
}
std::vector<InstanceRec*> instances;
{
mutex_lock l(instance_mu_);
for (const auto& group_entry : instance_table_) {
for (const auto& item : group_entry.second) {
instances.push_back(item.second.get());
}
}
}
for (InstanceRec* ir : instances) {
std::vector<IRConsumer> known_waiters;
{
mutex_lock il(ir->mu);
ir->status = s;
known_waiters.swap(ir->known_waiters);
}
for (const IRConsumer& done : known_waiters) {
done(ir);
}
}
}
} // namespace tensorflow