| /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| #include "tensorflow/core/common_runtime/collective_param_resolver_local.h" |
| |
| #include <stddef.h> |
| |
| #include <algorithm> |
| #include <unordered_set> |
| #include <utility> |
| |
| #include "tensorflow/core/common_runtime/device_mgr.h" |
| #include "tensorflow/core/framework/cancellation.h" |
| #include "tensorflow/core/framework/collective.h" |
| #include "tensorflow/core/framework/device_attributes.pb.h" |
| #include "tensorflow/core/framework/types.h" |
| #include "tensorflow/core/lib/core/errors.h" |
| #include "tensorflow/core/lib/core/status.h" |
| #include "tensorflow/core/lib/gtl/flatmap.h" |
| #include "tensorflow/core/lib/strings/numbers.h" |
| #include "tensorflow/core/lib/strings/str_util.h" |
| #include "tensorflow/core/lib/strings/strcat.h" |
| #include "tensorflow/core/platform/errors.h" |
| #include "tensorflow/core/platform/status.h" |
| #include "tensorflow/core/platform/types.h" |
| #include "tensorflow/core/protobuf/config.pb.h" |
| #include "tensorflow/core/util/device_name_utils.h" |
| |
| namespace tensorflow { |
| |
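| // A minimal sketch (illustrative only) of setting the two experimental |
| // ConfigProto fields consumed by the constructor below: |
| // |
| //   ConfigProto config; |
| //   config.mutable_experimental()->set_collective_nccl(true); |
| //   config.mutable_gpu_options()->mutable_experimental() |
| //       ->set_collective_ring_order("0,1,3,2"); |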
| CollectiveParamResolverLocal::CollectiveParamResolverLocal( |
| const ConfigProto& config, const DeviceMgr* dev_mgr, |
| DeviceResolverInterface* dev_resolver, |
| NcclCommunicatorInterface* nccl_communicator, const string& task_name) |
| : nccl_(config.experimental().collective_nccl()), |
| dev_mgr_(dev_mgr), |
| dev_resolver_(dev_resolver), |
| nccl_communicator_(nccl_communicator), |
| task_name_(task_name), |
| gpu_ring_order_( |
| config.gpu_options().experimental().collective_ring_order()) {} |
| |
| void CollectiveParamResolverLocal::CompleteGroupAsync( |
| const DeviceAttributes& device, CollGroupParams* group_params, |
| CancellationManager* cancel_mgr, const StatusCallback& done) { |
| CompleteGroupLocal(device, group_params, cancel_mgr, done); |
| } |
| |
| namespace { |
| const char* GetCollectiveName(const CollectiveParams* cp, bool nccl) { |
| switch (cp->instance.type) { |
| case BROADCAST_COLLECTIVE: |
| return nccl ? "NcclBroadcast" : "HierarchicalTreeBroadcast"; |
| |
| case REDUCTION_COLLECTIVE: |
| return nccl ? "NcclReduce" : "RingReduce"; |
| |
| case GATHER_COLLECTIVE: |
| return nccl ? "NcclGather" : "RingGather"; |
| |
| case PERMUTE_COLLECTIVE: |
| return "Permute"; |
| |
| case ALL_TO_ALL_COLLECTIVE: |
| return "AllToAll"; |
| |
| default: |
| return "undef"; |
| } |
| } |
| |
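| // Example (illustrative): the device name |
| // "/job:worker/replica:0/task:1/device:GPU:0" yields the task name |
| // "/job:worker/replica:0/task:1". |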
| string TaskNameFromDeviceName(const string& device_name) { |
| DeviceNameUtils::ParsedName parsed_device; |
| CHECK(DeviceNameUtils::ParseFullName(device_name, &parsed_device)); |
| string task_name; |
| CHECK(DeviceNameUtils::GetTaskName(parsed_device, &task_name)); |
| return task_name; |
| } |
| } // namespace |
| |
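| // Example flow (illustrative): for a group of size 2, the first device to |
| // call in creates the GroupRec and has its callback queued in pending_done; |
| // the second device completes the membership, FinishGroup() fills in the |
| // remaining fields, and all callbacks fire with the finished group params. |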
| void CollectiveParamResolverLocal::CompleteGroupLocal( |
| const DeviceAttributes& device, CollGroupParams* group_params, |
| CancellationManager* cancel_mgr, StatusCallback done) { |
| VLOG(1) << "CompleteGroup device=" << device.name() << ": " |
| << group_params->ToString(); |
| std::vector<StatusCallback> to_be_called; |
| |
| GroupRec* gr = nullptr; |
| Status status; |
| { |
| mutex_lock l(group_mu_); |
| auto it = group_table_.find(group_params->group_key); |
| if (it == group_table_.end()) { |
| gr = new GroupRec; |
| mutex_lock grl(gr->mu); |
| gr->group.group_key = group_params->group_key; |
| gr->group.group_size = group_params->group_size; |
| gr->group.device_type = group_params->device_type; |
| if (nccl_communicator_ != nullptr) { |
| gr->group.runtime_details.communicator_key = |
| nccl_communicator_->GenerateCommunicatorKey(); |
| } |
| // Store the new GroupRec in group_table_, which is shared among all |
| // devices on this worker. |
| group_table_[gr->group.group_key].reset(gr); |
| VLOG(2) << "New group_key=" << gr->group.group_key |
| << " group_size=" << gr->group.group_size |
| << " runtime_details=" << gr->group.runtime_details.ToString(); |
| } else { |
| gr = it->second.get(); |
| } |
| } |
| { |
| mutex_lock l(status_mu_); |
| status = status_; |
| } |
| if (!status.ok()) { |
| done(status); |
| return; |
| } |
| |
| if (cancel_mgr != nullptr) { |
| CancellationToken token = cancel_mgr->get_cancellation_token(); |
| bool is_cancelled = !cancel_mgr->RegisterCallback( |
| token, std::bind(&CollectiveParamResolverLocal::CancelGroup, this, |
| group_params->group_key)); |
| if (is_cancelled) { |
| done(errors::Cancelled("CompleteGroup is cancelled before it starts")); |
| return; |
| } |
| done = [cancel_mgr, token, |
| original_done = std::move(done)](const Status& status) { |
| cancel_mgr->TryDeregisterCallback(token); |
| original_done(status); |
| }; |
| } |
| |
| { |
| mutex_lock gr_lock(gr->mu); |
| // If there is ever an error associated with a group key, we store the error |
| // status and invoke all waiting and future callbacks with this error |
| // status. |
| VLOG(2) << "gr device_type=" << gr->group.device_type |
| << " cp device_type=" << group_params->device_type |
| << " current device=" << device.name(); |
| if (gr->status.ok()) { |
| // Check for consistency with existing GroupRec. |
| if (group_params->device_type != gr->group.device_type) { |
| gr->status = errors::Internal( |
| "Device ", device.name(), |
| " is joining a group with incompatible device type", |
| gr->group.device_type.type_string(), |
| " (group_key=", gr->group.group_key, ")"); |
| } else if (group_params->group_size != gr->group.group_size) { |
| gr->status = errors::Internal( |
| "Device ", device.name(), " is joining a group with size", |
| group_params->group_size, ", but that group has size ", |
| gr->group.group_size, " (group_key=", gr->group.group_key, ")"); |
| } |
| } |
| bool new_device = false; |
| if (gr->status.ok()) { |
| // Insert device if not already present. |
| auto it = gr->incarnations_by_device_name.find(device.name()); |
| if (it == gr->incarnations_by_device_name.end()) { |
| if (gr->group.members.size() == gr->group.group_size) { |
| // The group is already full. |
| gr->status = |
| errors::Internal("Device ", device.name(), |
| " is joining a group that is already full", |
| " (group_key=", gr->group.group_key, ")"); |
| } else { |
| // This is a new device that has not yet joined the group. |
| gr->incarnations_by_device_name[device.name()] = device.incarnation(); |
| CollGroupMember member; |
| member.device = device; |
| gr->group.members.push_back(std::move(member)); |
| new_device = true; |
| if (VLOG_IS_ON(1)) { |
| string dev_buf; |
| for (const auto& m : gr->group.members) { |
| strings::StrAppend(&dev_buf, ",", m.device.name()); |
| } |
| VLOG(1) << "CompleteGroupLocal group_key=" << gr->group.group_key |
| << " group_size=" << gr->group.group_size << " (current" |
| << " devices)=(" << dev_buf << ") (number of" |
| << " devices pending)=" |
| << (gr->group.group_size - gr->group.members.size()); |
| } |
| } |
| } else { |
| // If the device already exists, check if the incarnation matches. |
| if (it->second != device.incarnation()) { |
| gr->status = errors::FailedPrecondition( |
| "Device ", device.name(), |
| " current incarnation doesn't match with one in the group. This " |
| "usually means this worker has restarted but the collective " |
| "leader hasn't, or this worker connects to a wrong cluster."); |
| } |
| } |
| } |
| |
| if (gr->status.ok()) { |
| // If the group is not yet complete, queue to wait for it. |
| VLOG(2) << "group_size " << gr->group.group_size << " set size " |
| << gr->group.members.size() << " gr " << gr; |
| |
| if (gr->group.members.size() < gr->group.group_size) { |
| gr->pending_done.push_back(std::move(done)); |
| gr->pending_params.push_back(group_params); |
| return; |
| } |
| CHECK_EQ(gr->group.members.size(), gr->group.group_size); |
| // We get a full group. Fill in remaining fields in gr->group. |
| if (new_device) { |
| FinishGroup(gr); |
| } |
| // Copy the completed group params to this and all pending CollGroupParams. |
| *group_params = gr->group; |
| for (auto* params : gr->pending_params) { |
| *params = gr->group; |
| } |
| } |
| // At this point, we either have a full group, or an error status. Ensure |
| // that all callbacks are invoked with the appropriate status. |
| to_be_called.swap(gr->pending_done); |
| gr->pending_params.clear(); |
| status = gr->status; |
| } |
| done(status); |
| for (const StatusCallback& cb : to_be_called) { |
| cb(status); |
| } |
| } |
| |
| namespace { |
| struct DevRec { |
| string task; |
| string device; |
| int original_rank; |
| int local_rank; |
| int global_rank; |
| const DeviceLocality* locality; |
| }; |
| typedef std::unordered_map<string, DevRec> TaskDeviceMap; |
| typedef std::unordered_map<string, TaskDeviceMap> GlobalDeviceMap; |
| |
| // Create a populated GlobalDeviceMap from CollGroupParams and device localities. |
| GlobalDeviceMap BuildDevRecs(const CollGroupParams& gp) { |
| GlobalDeviceMap gdm; |
| for (int i = 0; i < gp.members.size(); ++i) { |
| TaskDeviceMap& tdm = gdm[gp.members[i].task]; |
| DevRec* dr = &tdm[gp.members[i].device.name()]; |
| dr->task = gp.members[i].task; |
| dr->device = gp.members[i].device.name(); |
| dr->original_rank = i; |
| dr->local_rank = 0; // Will be populated later by OrderTaskDeviceMap. |
| dr->global_rank = 0; // Will be populated later by EstablishGlobalRank. |
| dr->locality = &gp.members[i].device.locality(); |
| } |
| return gdm; |
| } |
| |
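| // Example (illustrative): given gpu_ring_order "2,0,1" and participating |
| // devices .../device:GPU:0, GPU:1, and GPU:2, GPU:2 receives local_rank 0, |
| // GPU:0 receives local_rank 1, and GPU:1 receives local_rank 2. |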
| bool ParseRingOrder(const string& gpu_ring_order_str, TaskDeviceMap* tdm) { |
| std::vector<string> split_gpu_ring_order_str = |
| str_util::Split(gpu_ring_order_str, ','); |
| if (split_gpu_ring_order_str.size() != tdm->size()) return false; |
| |
| // gpu id -> local rank |
| gtl::FlatMap<int32, int32> gpu_ranks; |
| for (int32_t rank = 0; |
| rank < static_cast<int32>(split_gpu_ring_order_str.size()); ++rank) { |
| int32_t tmp; |
| if (strings::safe_strto32(split_gpu_ring_order_str[rank], &tmp)) { |
| gpu_ranks[tmp] = rank; |
| } else { |
| return false; |
| } |
| } |
| |
| for (auto& tdm_it : *tdm) { |
| DeviceNameUtils::ParsedName parsed_name; |
| DevRec* dr = &tdm_it.second; |
| if (!DeviceNameUtils::ParseFullName(dr->device, &parsed_name)) { |
| return false; |
| } |
| auto rank_it = gpu_ranks.find(parsed_name.id); |
| if (rank_it == gpu_ranks.end()) return false; |
| dr->local_rank = rank_it->second; |
| } |
| VLOG(2) << "Assigned local ranks based on ring order " << gpu_ring_order_str; |
| return true; |
| } |
| |
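| // Example (illustrative): suppose GPU:0 has the least original_rank and its |
| // strongest link goes to GPU:2. The walk assigns GPU:0 local_rank 0, follows |
| // the link to GPU:2 (local_rank 1), and, finding no unvisited link from |
| // GPU:2, falls back to the lowest remaining original_rank, GPU:1 |
| // (local_rank 2). |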
| void OrderTaskDeviceMap(const string& gpu_ring_order, TaskDeviceMap* tdm) { |
| CHECK_GT(tdm->size(), 0); // Should never be called with 0 devices |
| |
| // If a valid ring order has been passed in via ConfigProto, use that. |
| if (ParseRingOrder(gpu_ring_order, tdm)) return; |
| |
| // Either no ring order was passed in, or the format was unexpected. |
| // We now assign a ring order based on link strengths. Note that this |
| // algorithm is not optimal and may not always find the best ring order. |
| int least_rank = -1; |
| string next_device; |
| std::set<string> selected; |
| // Starting device is one with the least initial rank. |
| for (const auto& it : *tdm) { |
| if (least_rank < 0 || it.second.original_rank < least_rank) { |
| least_rank = it.second.original_rank; |
| next_device = it.second.device; |
| } |
| } |
| CHECK_GE(least_rank, 0); |
| DeviceNameUtils::ParsedName parsed_name; |
| CHECK(DeviceNameUtils::ParseFullName(next_device, &parsed_name)); |
| // NOTE: InterconnectLink has only a device_id, nothing more, so for |
| // the time being if there's more than one device at a task we |
| // assume they're all GPUs. |
| |
| int next_rank = 0; |
| while (true) { |
| selected.insert(next_device); |
| auto next_dev_it = tdm->find(next_device); |
| CHECK(next_dev_it != tdm->end()); |
| DevRec* dr = &next_dev_it->second; |
| dr->local_rank = next_rank; |
| ++next_rank; |
| if (selected.size() == tdm->size()) { |
| break; |
| } |
| // For the present time we assume Locality links only cover GPUs. |
| // For multiple CPUs, just take them in order. |
| const InterconnectLink* best_link = nullptr; |
| if (parsed_name.type == "GPU") { |
| for (const InterconnectLink& il : dr->locality->links().link()) { |
| parsed_name.id = il.device_id(); |
| string endpoint_device = |
| DeviceNameUtils::ParsedNameToString(parsed_name); |
| // Skip the device if we've already seen it. |
| if (selected.find(endpoint_device) != selected.end()) { |
| continue; |
| } |
| // Skip the device if it is not participating in this collective |
| // instance. |
| if (tdm->find(endpoint_device) == tdm->end()) { |
| continue; |
| } |
| if (best_link == nullptr || il.strength() > best_link->strength()) { |
| best_link = &il; |
| } |
| } |
| } |
| if (best_link != nullptr) { |
| // Follow the best edge |
| parsed_name.id = best_link->device_id(); |
| next_device = DeviceNameUtils::ParsedNameToString(parsed_name); |
| } else { |
| // No good edges, alas. Pick the lowest initial rank among remaining |
| // devices. |
| least_rank = -1; |
| for (const auto& it : *tdm) { |
| if (selected.find(it.second.device) != selected.end()) { |
| continue; |
| } |
| if (least_rank < 0 || it.second.original_rank < least_rank) { |
| least_rank = it.second.original_rank; |
| next_device = it.second.device; |
| } |
| } |
| CHECK_GE(least_rank, 0); |
| } |
| } |
| } |
| |
| // The first time a CollGroupParams is established for a group, we compute a |
| // rank order for all the devices in the group that is well-suited to a ring |
| // algorithm. |
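| // Example (illustrative): with tasks ".../task:0" and ".../task:1" holding |
| // two devices each, task:0's devices keep global ranks 0 and 1 (their local |
| // ranks) and task:1's devices get global ranks 2 and 3 (local rank + 2). |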
| GlobalDeviceMap EstablishGlobalRank(const CollGroupParams& gp, |
| const string& gpu_ring_order) { |
| VLOG(1) << "EstablishGlobalRank"; |
| GlobalDeviceMap gdm = BuildDevRecs(gp); |
| for (auto& iter : gdm) { |
| TaskDeviceMap& tdm = iter.second; |
| OrderTaskDeviceMap(gpu_ring_order, &tdm); |
| } |
| // Stitch the per-task rank orders together in lexicographic task order. |
| std::set<string> tasks; |
| for (const CollGroupMember& member : gp.members) { |
| tasks.insert(member.task); |
| } |
| int next_rank = 0; |
| for (const string& task : tasks) { |
| TaskDeviceMap* tdm = &gdm[task]; |
| for (auto& it : *tdm) { |
| it.second.global_rank = it.second.local_rank + next_rank; |
| } |
| next_rank += tdm->size(); |
| } |
| return gdm; |
| } |
| |
| // Count the devices associated with each task and set |
| // gp->same_num_devices_per_task. |
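| // Example (illustrative): two tasks with two devices each yields |
| // same_num_devices_per_task == true; tasks with two and three devices |
| // yields false. |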
| void SetDevPerTask(CollGroupParams* gp) { |
| gp->num_devices_per_task.clear(); |
| for (const CollGroupMember& member : gp->members) { |
| gp->num_devices_per_task[member.task]++; |
| } |
| gp->same_num_devices_per_task = false; |
| int dev_per_task = -1; |
| for (const auto& task_dev : gp->num_devices_per_task) { |
| if (dev_per_task == -1) { |
| dev_per_task = task_dev.second; |
| } else if (dev_per_task != task_dev.second) { |
| return; |
| } |
| } |
| gp->same_num_devices_per_task = true; |
| } |
| |
| } // namespace |
| |
| void CollectiveParamResolverLocal::FinishGroup(GroupRec* gr) { |
| // Populate group member task and is_local. |
| for (CollGroupMember& member : gr->group.members) { |
| member.task = TaskNameFromDeviceName(member.device.name()); |
| member.is_local = member.task == task_name_; |
| } |
| // Establish the order of the members by considering localities of all |
| // devices. |
| CompleteDefaultRanking(&gr->group); |
| SetDevPerTask(&gr->group); |
| gr->group.num_tasks = |
| static_cast<int32>(gr->group.num_devices_per_task.size()); |
| } |
| |
| void CollectiveParamResolverLocal::CancelGroup(int32 group_key) { |
| std::vector<StatusCallback> pending_done; |
| GroupRec* gr = nullptr; |
| { |
| mutex_lock l(group_mu_); |
| auto it = group_table_.find(group_key); |
| if (it == group_table_.end()) { |
| return; |
| } |
| gr = it->second.get(); |
| } |
| { |
| mutex_lock l(gr->mu); |
| if (gr->group.members.size() == gr->group.group_size) { |
| // The group is already complete. There's no need to cancel. |
| return; |
| } |
| gr->status = errors::Cancelled("group is cancelled"); |
| pending_done.swap(gr->pending_done); |
| gr->pending_params.clear(); |
| } |
| for (const StatusCallback& done : pending_done) { |
| done(errors::Cancelled("group is cancelled")); |
| } |
| } |
| |
| void CollectiveParamResolverLocal::SetDefaultRank(const string& device, |
| CollectiveParams* cp) { |
| CHECK_EQ(cp->group.group_size, cp->group.members.size()) << cp->ToString(); |
| for (int i = 0; i < cp->group.group_size; ++i) { |
| if (cp->group.members[i].device.name() == device) { |
| cp->default_rank = i; |
| break; |
| } |
| } |
| } |
| |
| void CollectiveParamResolverLocal::InitInstanceSharedParams( |
| const CollectiveParams* cp, InstanceRec* ir) { |
| ir->shared->instance = cp->instance; |
| ir->shared->default_rank = -1; |
| } |
| |
| // NOTE(ayushd): The DeviceLocality objects in attributes will have LocalLinks |
| // to all devices that they are physically connected to and that are visible |
| // to the TensorFlow runtime. This set of devices may be a superset of the |
| // devices participating in this instance of collectives. |
| void CollectiveParamResolverLocal::CompleteDefaultRanking(CollGroupParams* gp) { |
| // Sort gp->members to avoid nondeterminism. |
| std::sort(gp->members.begin(), gp->members.end(), |
| [](const CollGroupMember& lhs, const CollGroupMember& rhs) { |
| return lhs.device.name() < rhs.device.name(); |
| }); |
| // Establish an instance-specific default rank order for devices |
| // based on localities. This rank order should be a good ring |
| // order, if possible. |
| GlobalDeviceMap gdm = EstablishGlobalRank(*gp, gpu_ring_order_); |
| // Reorder the members to reflect the new global rank ordering. |
| std::vector<CollGroupMember> new_members(gp->group_size); |
| for (const auto& git : gdm) { |
| const TaskDeviceMap& tdm = git.second; |
| for (const auto& tit : tdm) { |
| const DevRec& dr = tit.second; |
| new_members[dr.global_rank] = std::move(gp->members[dr.original_rank]); |
| } |
| } |
| |
| if (VLOG_IS_ON(2)) { |
| string buf; |
| for (const auto& m : new_members) |
| strings::StrAppend(&buf, "\n", m.device.name()); |
| VLOG(2) << "Optimized device order for group " << gp->group_key << ": " |
| << buf; |
| } |
| gp->members = std::move(new_members); |
| } |
| |
| CollectiveParamResolverLocal::InstanceRec* |
| CollectiveParamResolverLocal::GetOrCreateInstanceRec(CollectiveParams* cp, |
| bool* created) { |
| *created = false; |
| InstanceRec* irec = nullptr; |
| { |
| mutex_lock l(instance_mu_); |
| auto group_it = instance_table_.find(cp->group.group_key); |
| if (group_it != instance_table_.end()) { |
| auto instance_it = group_it->second.find(cp->instance.instance_key); |
| if (instance_it != group_it->second.end()) { |
| irec = instance_it->second.get(); |
| } |
| } |
| if (irec == nullptr) { |
| // Create new InstanceRec. |
| irec = new InstanceRec; |
| *created = true; |
| { |
| mutex_lock il(irec->mu); |
| irec->known.resize(cp->group.group_size, false); |
| } |
| InitInstanceSharedParams(cp, irec); |
| instance_table_[cp->group.group_key][cp->instance.instance_key].reset( |
| irec); |
| } |
| } |
| Status status; |
| { |
| mutex_lock l(status_mu_); |
| status = status_; |
| } |
| if (!status.ok()) { |
| mutex_lock l(irec->mu); |
| irec->status = status; |
| } |
| return irec; |
| } |
| |
| Status CollectiveParamResolverLocal::LookupAndPopulateGroupParams( |
| CollGroupParams* group) { |
| mutex_lock l(group_mu_); |
| auto group_rec = group_table_.find(group->group_key); |
| if (group_rec == group_table_.end()) { |
| return errors::InvalidArgument("Group ", group->group_key, |
| " is not initialized. Please call the " |
| "group initialization op before invoking " |
| "the collective op."); |
| } |
| mutex_lock lock(group_rec->second->mu); |
| if (!group_rec->second->status.ok()) { |
| return errors::FailedPrecondition( |
| "Failed to run collective due to unsuccessful group initialization; " |
| "group initialization failed with error ", |
| group_rec->second->status.ToString()); |
| } |
| *group = group_rec->second->group; |
| return Status::OK(); |
| } |
| |
| void CollectiveParamResolverLocal::CompleteParamsAsync( |
| const DeviceAttributes& device, CollectiveParams* cp, |
| CancellationManager* cancel_mgr, const StatusCallback& done) { |
| VLOG(1) << "CompleteParams local " << device.name() << " for " << cp << ": " |
| << cp->ToString(); |
| if (cp->run_group_initialization) { |
| CompleteGroupLocal(device, &cp->group, cancel_mgr, |
| [this, device, cp, done](const Status& s) { |
| if (s.ok()) { |
| CompleteInstanceLocal(device.name(), cp, done); |
| } else { |
| done(s); |
| } |
| }); |
| } else { |
| // For Collective V3 ops, the group is already initialized. Fetch the |
| // attributes of the already-initialized group to pass to instance |
| // initialization. |
| auto s = LookupAndPopulateGroupParams(&cp->group); |
| if (s.ok()) { |
| CompleteInstanceLocal(device.name(), cp, done); |
| } else { |
| done(s); |
| } |
| } |
| } |
| |
| void CollectiveParamResolverLocal::CompleteInstanceAsync( |
| const CompleteInstanceRequest* request, CompleteInstanceResponse* response, |
| CancellationManager* cancel_mgr, const StatusCallback& done) { |
| done( |
| errors::Internal("CompleteInstance is not implemented by " |
| "CollectiveParamResolverLocal which is " |
| "intended only for non-distributed deployment.")); |
| } |
| |
| // TODO(b/111897089): we need a better way to pick the collective |
| // implementation. Ideally the choice would take the topology and link |
| // strengths into account. |
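| // Example (illustrative): a REDUCTION_COLLECTIVE on GPU resolves to |
| // "NcclReduce" when NCCL is available and enabled, and to "RingReduce" |
| // otherwise. |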
| void CollectiveParamResolverLocal::AssignCollectiveType(CollectiveParams* cp) { |
| // We use the NCCL implementation if this is an environment which supports |
| // NCCL, i.e. `LookupParamResolverInstance` for `NcclReduce` returns OK, and |
| // NCCL use is indicated either in `ConfigProto` or via `communication_hint`. |
| // |
| // After enough testing, we may simplify this logic to use NCCL whenever |
| // available. |
| CollectiveImplementationInterface* col_impl; |
| bool use_nccl = |
| (nccl_ || cp->instance.impl_details.communication_hint == "nccl") && |
| cp->group.device_type == DEVICE_GPU && |
| CollectiveRegistry::LookupParamResolverInstance("NcclReduce", &col_impl) |
| .ok(); |
| cp->instance.impl_details.collective_name = GetCollectiveName(cp, use_nccl); |
| VLOG(1) << "AssignCollectiveType " |
| << cp->instance.impl_details.collective_name; |
| } |
| |
| void CollectiveParamResolverLocal::CompleteInstanceLocal( |
| const string& device, CollectiveParams* cp, const StatusCallback& done) { |
| VLOG(1) << "CompleteInstanceLocal " << device |
| << " instance_key: " << cp->instance.instance_key << " group_key " |
| << cp->group.group_key; |
| |
| bool created_irec; |
| InstanceRec* ir = GetOrCreateInstanceRec(cp, &created_irec); |
| if (!created_irec) { |
| // Check that the preexisting IRec is consistent with the params passed into |
| // this invocation. |
| if (ir->shared->instance.type != cp->instance.type || |
| ir->shared->instance.data_type != cp->instance.data_type) { |
| done(errors::Internal("Collective instance ", cp->instance.instance_key, |
| " expected type ", ir->shared->instance.type, |
| " and data_type ", ir->shared->instance.data_type, |
| " but got type ", cp->instance.type, |
| " and data_type ", cp->instance.data_type)); |
| return; |
| } |
| } |
| CompleteInstanceFromInitializedIRec(device, cp, ir, done); |
| } |
| |
| void CollectiveParamResolverLocal::CompleteInstanceFromInitializedIRec( |
| const string& device, CollectiveParams* cp, InstanceRec* ir, |
| const StatusCallback& done) { |
| auto expected_shape = cp->instance.shape; |
| Status status; |
| // Populate the fields common across instance. |
| { |
| mutex_lock l(ir->mu); |
| status = ir->status; |
| if (status.ok()) { |
| // custom operator= does a deep copy. |
| cp->instance = ir->shared->instance; |
| } |
| } |
| if (!status.ok()) { |
| done(status); |
| return; |
| } |
| if (expected_shape != cp->instance.shape) { |
| done(errors::InvalidArgument( |
| "Shape mismatch in the collective instance ", cp->instance.instance_key, |
| ". Op at device ", device, " expected shape ", |
| expected_shape.DebugString(), " but another member in the group ", |
| "expected shape ", cp->instance.shape.DebugString(), ". This is likely", |
| " due to different input shapes at different members of the collective", |
| " op.")); |
| return; |
| } |
| // Populate the fields common across task. |
| AssignCollectiveType(cp); |
| SetDefaultRank(device, cp); |
| |
| CollectiveImplementationInterface* col_impl; |
| status = CollectiveRegistry::LookupParamResolverInstance( |
| cp->instance.impl_details.collective_name, &col_impl); |
| if (!status.ok()) { |
| done(status); |
| return; |
| } |
| |
| // If this is a broadcast, we may need to wait for all group members in |
| // order to discover the source rank. |
| if (cp->instance.type == BROADCAST_COLLECTIVE) { |
| WaitForGroup(ir, cp, [col_impl, ir, device, cp, done](InstanceRec* irec) { |
| Status s; |
| if (ir != irec) { |
| s = errors::Internal("Expected ir ", ir, " and irec ", irec, |
| " to be equal"); |
| } else { |
| mutex_lock l(irec->mu); |
| s = irec->status; |
| cp->source_rank = irec->source_rank; |
| } |
| if (s.ok()) { |
| s = col_impl->InitializeCollectiveParams(cp); |
| } |
| done(s); |
| }); |
| } else { |
| done(col_impl->InitializeCollectiveParams(cp)); |
| } |
| } |
| |
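| // Waits until all members of the group have checked in for this instance. |
| // Example (illustrative): for a broadcast over a group of size 4, three |
| // BcastRecv participants have their callbacks queued in known_waiters; the |
| // BcastSend participant arrives with is_source == true and claims |
| // source_rank, and once known_count reaches 4 every queued callback fires |
| // with the discovered source rank. |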
| void CollectiveParamResolverLocal::WaitForGroup(InstanceRec* ir, |
| CollectiveParams* cp, |
| const IRConsumer& f) { |
| std::vector<IRConsumer> ready_waiters; |
| do { |
| mutex_lock l(ir->mu); |
| if (!ir->status.ok()) { |
| break; |
| } |
| CHECK_EQ(cp->group.group_size, ir->known.size()); |
| CHECK_GE(cp->default_rank, 0); |
| if (!ir->known[cp->default_rank]) { |
| ir->known[cp->default_rank] = true; |
| ++ir->known_count; |
| if (cp->is_source) { |
| // Initialize source rank. |
| if (ir->source_rank >= 0) { |
| ir->status = errors::Internal("Instance ", cp->instance.instance_key, |
| " already has source ", ir->source_rank, |
| ", received second claim from ", |
| cp->default_rank); |
| } else { |
| ir->source_rank = cp->default_rank; |
| } |
| } |
| } |
| if (ir->known_count < cp->group.group_size) { |
| ir->known_waiters.push_back(f); |
| return; |
| } |
| CHECK_EQ(ir->known_count, cp->group.group_size); |
| if (ir->source_rank < 0) { |
| // NOTE(ayushd): changing the error message below would also require |
| // updating CompleteParamsBroadcastForgotSend test in |
| // CollectiveParamResolverLocalTest. |
| ir->status = |
| errors::Internal("Instance ", cp->instance.instance_key, |
| " found no source for broadcast. This " |
| "could mean that there were group_size=", |
| ir->known_count, " BcastRecvs but no BcastSend."); |
| } |
| if (!ir->known_waiters.empty()) { |
| ready_waiters = std::move(ir->known_waiters); |
| } |
| } while (false); |
| f(ir); |
| for (auto& f : ready_waiters) { |
| f(ir); |
| } |
| } |
| |
| void CollectiveParamResolverLocal::StartAbort(const Status& s) { |
| { |
| mutex_lock l(status_mu_); |
| if (!status_.ok()) { |
| VLOG(2) << "CollectiveParamResolverLocal already aborted. Ignoring " |
| "subsequent abortion with status: " |
| << s; |
| return; |
| } |
| status_ = s; |
| } |
| StartAbortLocal(s); |
| } |
| |
| void CollectiveParamResolverLocal::StartAbortLocal(const Status& s) { |
| std::vector<StatusCallback> pending_done; |
| { |
| mutex_lock l(group_mu_); |
| for (const auto& item : group_table_) { |
| GroupRec* gr = item.second.get(); |
| { |
| mutex_lock gl(gr->mu); |
| gr->status = s; |
| for (auto& done : gr->pending_done) { |
| pending_done.push_back(std::move(done)); |
| } |
| gr->pending_done.clear(); |
| gr->pending_params.clear(); |
| } |
| } |
| } |
| for (const StatusCallback& done : pending_done) { |
| done(s); |
| } |
| std::vector<InstanceRec*> instances; |
| { |
| mutex_lock l(instance_mu_); |
| for (const auto& group_entry : instance_table_) { |
| for (const auto& item : group_entry.second) { |
| instances.push_back(item.second.get()); |
| } |
| } |
| } |
| for (InstanceRec* ir : instances) { |
| std::vector<IRConsumer> known_waiters; |
| { |
| mutex_lock il(ir->mu); |
| ir->status = s; |
| known_waiters.swap(ir->known_waiters); |
| } |
| for (const IRConsumer& done : known_waiters) { |
| done(ir); |
| } |
| } |
| } |
| |
| } // namespace tensorflow |