/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_PARAM_RESOLVER_LOCAL_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_PARAM_RESOLVER_LOCAL_H_

#include <functional>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <vector>

#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/platform/thread_annotations.h"
namespace tensorflow {
class CompleteGroupRequest;
class CompleteGroupResponse;
class CompleteInstanceRequest;
class CompleteInstanceResponse;
class ConfigProto;
class DeviceMgr;

// Implements ParamResolverInterface for a single-task context.
// It also implements the functionality necessary to serve as the
// group leader for param resolution in a multi-task context.
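//
// Example usage (an illustrative sketch only; config, dev_mgr, dev_resolver,
// nccl_communicator, device, col_params and cancel_mgr are assumed to have
// been set up by the caller, and StatusCallback is assumed to be the
// std::function<void(const Status&)> callback type used by collective.h):
//
//   CollectiveParamResolverLocal resolver(
//       config, dev_mgr, dev_resolver, nccl_communicator,
//       "/job:localhost/replica:0/task:0");
//   resolver.CompleteParamsAsync(device, col_params, cancel_mgr,
//                                [](const Status& s) {
//                                  if (!s.ok()) LOG(ERROR) << s;
//                                });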
class CollectiveParamResolverLocal : public ParamResolverInterface {
public:
CollectiveParamResolverLocal(const ConfigProto& config,
const DeviceMgr* dev_mgr,
DeviceResolverInterface* dev_resolver,
NcclCommunicatorInterface* nccl_communicator,
const string& task_name);
~CollectiveParamResolverLocal() override {}
void CompleteParamsAsync(const DeviceAttributes& device, CollectiveParams* cp,
CancellationManager* cancel_mgr,
const StatusCallback& done) override;
void CompleteGroupAsync(const DeviceAttributes& device,
CollGroupParams* group_params,
CancellationManager* cancel_mgr,
const StatusCallback& done) override;
void CompleteInstanceAsync(const CompleteInstanceRequest* request,
CompleteInstanceResponse* response,
CancellationManager* cancel_mgr,
const StatusCallback& done) override;
void StartAbort(const Status& s) override;
protected:
// For access to InstanceRec and CompleteDefaultRanking.
friend class CollectiveParamResolverLocalTest;
// Used to complete/verify CollGroup.
struct GroupRec {
mutable mutex mu;
CollGroupParams group TF_GUARDED_BY(mu);
Status status TF_GUARDED_BY(mu);
std::unordered_map<string, int64_t> incarnations_by_device_name
TF_GUARDED_BY(mu);
std::vector<CollGroupParams*> pending_params TF_GUARDED_BY(mu);
std::vector<StatusCallback> pending_done TF_GUARDED_BY(mu);
};
// Finds the GroupRec that corresponds to group_params->group_key, and
// populates group_params from it. Waits until the GroupRec is fully
// populated, or an error arises, before calling done. group_params is only
// valid if the status passed to done is ok. Ownership of the GroupRec stays
// with this object and does not pass to the callback.
void CompleteGroupLocal(const DeviceAttributes& device,
CollGroupParams* group_params,
CancellationManager* cancel_mgr, StatusCallback done)
TF_LOCKS_EXCLUDED(group_mu_);
// Finishes the group parameters once all members of the group are there.
void FinishGroup(GroupRec* gr) TF_EXCLUSIVE_LOCKS_REQUIRED(gr->mu);
// Cancels the group if it's still pending.
void CancelGroup(int32 group_key) TF_LOCKS_EXCLUDED(group_mu_);
// Looks up an already initialized group and populates group_params from it.
Status LookupAndPopulateGroupParams(CollGroupParams* group_params);
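// Illustrative group-resolution flow (a sketch inferred from the declarations
// above; the exact sequencing is an implementation detail of the .cc file):
//
//   1. Each member calls CompleteGroupLocal() with the same group_key.
//   2. Until all group.group_size members have joined, their CollGroupParams*
//      and done callbacks are parked in GroupRec::pending_params and
//      GroupRec::pending_done.
//   3. When the last member arrives, FinishGroup() fixes the final membership
//      and all parked done callbacks are invoked.
//   4. For a group that is already initialized,
//      LookupAndPopulateGroupParams() fills in the caller's CollGroupParams
//      directly.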
// Used to complete/verify CollInstance.
struct InstanceRec;
typedef std::function<void(InstanceRec*)> IRConsumer;
struct InstanceRec {
mutex mu;
// Values to be shared by all instances, constant after initialization.
CollectiveParams* shared;
// If an error occurs during initialization this structure stays in the
// table with a non-OK status. Purging the table and restarting needs to be
// done at a higher level.
Status status TF_GUARDED_BY(mu);
// These fields are used to count the instances that have called
// in and become known while resolving broadcast source identity and
// communicator key.
int source_rank TF_GUARDED_BY(mu);
string communicator_key TF_GUARDED_BY(mu);
int known_count TF_GUARDED_BY(mu);
std::vector<bool> known TF_GUARDED_BY(mu);
std::vector<IRConsumer> known_waiters TF_GUARDED_BY(mu);
InstanceRec()
: shared(new CollectiveParams()), source_rank(-1), known_count(0) {}
~InstanceRec() { shared->Unref(); }
};
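// Illustrative sketch of how the "known" bookkeeping above is typically
// consumed (a sketch only, assuming the waiter semantics implied by
// known_waiters; my_rank and group_size stand in for values taken from the
// caller's CollectiveParams):
//
//   mutex_lock l(ir->mu);
//   if (!ir->known[my_rank]) {
//     ir->known[my_rank] = true;
//     ++ir->known_count;
//   }
//   if (ir->known_count == group_size) {
//     // Everyone has checked in; run and clear any registered waiters.
//     for (const IRConsumer& f : ir->known_waiters) f(ir);
//     ir->known_waiters.clear();
//   }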
// Finds the InstanceRec with the same instance_key as cp. If it doesn't
// already exist, creates and initializes it from cp.
// *created is set to true if a new InstanceRec is created, false otherwise.
//
// Precondition: cp->group must be complete, i.e. the value set by
// CompleteGroupLocal. *cp must be populated with all the fields required by
// InitInstanceSharedParams. Ownership of the InstanceRec stays with this
// object and does not pass to the caller.
InstanceRec* GetOrCreateInstanceRec(CollectiveParams* cp, bool* created)
TF_LOCKS_EXCLUDED(instance_mu_, group_mu_);
// Populates *ir with device membership from cp->group, then initializes it to
// be specific to cp->instance.instance_key, i.e. orders the devices and
// tasks.
//
// Precondition: cp is populated with all DeviceLocalities.
void InitInstanceSharedParams(const CollectiveParams* cp, InstanceRec* ir);
// Establishes the final order of gp->device_names and gp->task_names by
// considering localities of all devices.
void CompleteDefaultRanking(CollGroupParams* gp);
// Finishes populating *cp.
// Precondition: cp->group has been fully populated by CompleteGroupLocal.
void CompleteInstanceLocal(const string& device, CollectiveParams* cp,
const StatusCallback& done)
TF_LOCKS_EXCLUDED(instance_mu_, group_mu_);
// Finishes populating *cp from a fully initialized *ir.
// Precondition: cp->group and *ir are fully populated.
void CompleteInstanceFromInitializedIRec(const string& device,
CollectiveParams* cp,
InstanceRec* ir,
const StatusCallback& done)
TF_LOCKS_EXCLUDED(ir->mu);
// Complete instance params after waiting for group.
// Precondition: *cp has complete group data and default_rank.
void WaitForGroup(InstanceRec* ir, CollectiveParams* cp, const IRConsumer& f)
TF_LOCKS_EXCLUDED(ir->mu);
// If cp.device_names contains only devices local to this process, populates
// *localities; otherwise returns an error.
Status GetLocalDeviceLocalities(const CollectiveParams& cp,
std::vector<DeviceLocality>* localities);
// Sets cp->default_rank according to the position of device in the group's
// current device ordering.
void SetDefaultRank(const string& device, CollectiveParams* cp);
// Sets cp->instance.type based on collective op type, and attempts to assign
// best implementation.
void AssignCollectiveType(CollectiveParams* cp);
void StartAbortLocal(const Status& s)
TF_LOCKS_EXCLUDED(status_mu_, group_mu_, instance_mu_);
const bool nccl_;
const DeviceMgr* dev_mgr_;
DeviceResolverInterface* dev_resolver_; // Not owned.
NcclCommunicatorInterface* nccl_communicator_; // Not owned.
string task_name_;
string gpu_ring_order_;
mutex group_mu_;
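// Group state for this resolver, keyed by group_key.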
gtl::FlatMap<int32, std::unique_ptr<GroupRec>> group_table_
TF_GUARDED_BY(group_mu_);
mutex instance_mu_;
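// Per-instance state; the outer key is assumed to be group_key and the inner
// key instance_key.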
gtl::FlatMap<int32, gtl::FlatMap<int32, std::unique_ptr<InstanceRec>>>
instance_table_ TF_GUARDED_BY(instance_mu_);
mutex status_mu_;
Status status_ TF_GUARDED_BY(status_mu_);
};

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_PARAM_RESOLVER_LOCAL_H_