/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
See the License for the specific language governing permissions and
limitations under the License.
#include <string>
#include <vector>
#include "absl/container/flat_hash_set.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
// Forward declarations. The declarations below name these types only through
// pointers, so their full definitions are not required here.
class BufRendezvous;
class CancellationManager;
class CompleteGroupRequest;
class CompleteGroupResponse;
class CompleteInstanceRequest;
class CompleteInstanceResponse;
class Device;
class DeviceMgr;
class GetStepSequenceRequest;
class GetStepSequenceResponse;
class NcclManager;
class Tensor;
// Types of supported collective operations. A value of this enum is stored in
// CollInstanceParams::type, which defaults to UNDEFINED_COLLECTIVE.
enum CollectiveType {
// Some collective op implementations require runtime group configuration from
// the OpKernel. Currently, this struct is used to set the communicator key for
// the NCCL-based collective implementation.
struct CollGroupRuntimeDetails {
string communicator_key; // for communicator-based techniques e.g. NCCL
// Returns a string form of these details (defined out of line).
string ToString() const;
// One member of a collective group; CollGroupParams::members holds these in
// default rank order.
struct CollGroupMember {
// Attributes of the member's device.
DeviceAttributes device;
// Presumably the name of the task that hosts the device — TODO confirm.
string task;
// Presumably true when the member is local to the current process — TODO
// confirm. No default initializer, so it must be set explicitly.
bool is_local;
// Data common to all members of a device group.
// All members share the same device set but its order is
// particular to an instance so it is stored there.
struct CollGroupParams {
// Key and size of the group; both are set to 0 by the constructor's
// initializer list.
int32 group_key;
int32 group_size;
// Device type of the group; set to DEVICE_CPU by the constructor.
DeviceType device_type;
// Members in this group, in default rank order.
std::vector<CollGroupMember> members;
// True if every task has the same number of devices.
bool same_num_devices_per_task = false;
// Task -> number of devices on that task.
std::unordered_map<string, int32> num_devices_per_task;
int32 num_tasks; // number of distinct tasks in group
// Runtime configuration of the group (see CollGroupRuntimeDetails).
CollGroupRuntimeDetails runtime_details;
// Returns a string form of these parameters (defined out of line).
string ToString() const;
: group_key(0), group_size(0), device_type(DEVICE_CPU), num_tasks(0) {} // defaults: key, size and task count 0; CPU device type
// The best implementation of a collective op depends on many factors
// including the number of devices involved, the topology of
// interconnects between them and the sizes of inputs. This structure
// is used in generating and representing data movement choreography
// for each specific algorithm, hence it does not have a single, fixed
// interpretation. On first execution the runtime will update this
// structure with decisions that will guide all subsequent executions.
struct CollImplDetails {
// Name of the collective implementation; presumably the name under which it
// was registered with CollectiveRegistry — TODO confirm.
string collective_name;
// Per-subdivision permutations. CollectiveParams::subdiv_rank records this
// device's rank within each of them.
std::vector<std::vector<int>> subdiv_permutations;
// subdiv_offsets and max_subdivs_per_device are used together as follows:
// When subdiv_offsets is provided (non-empty) it is used as is. When
// subdiv_offsets is not provided, subdivisions are generated dynamically,
// constrained by max_subdivs_per_device. When subdiv_offsets is empty AND
// max_subdivs_per_device = 0, an internal default kMaxSubdivsPerDeviceDefault
// is used. When max_subdivs_per_device = -1, no subdivision is done.
int max_subdivs_per_device = -1; // Upper bound on subdivisions per device.
std::vector<int> subdiv_offsets;
std::vector<int> subdiv_source_rank; // rank of source in each subdiv
dependencies; // collective instances on which this node depends
string communication_hint; // user-supplied hint for implementation choice,
// e.g. ring or nccl
float timeout_seconds; // If non zero, set a completion timeout for the
// collective op to detect staleness.
// NOTE(review): timeout_seconds has no default member initializer here —
// confirm that every producer sets it before it is read.
// Data common to all members of a collective instance.
// TODO(b/163171014) Refactor this struct to not be a union of all fields.
struct CollInstanceParams {
// Identifies all participating graph nodes.
int32 instance_key = -1;
// Kind of collective, data type and shape for this instance. Defaults are
// UNDEFINED_COLLECTIVE, DT_FLOAT and {0} respectively.
CollectiveType type = UNDEFINED_COLLECTIVE;
DataType data_type = DT_FLOAT;
TensorShape shape = {0};
// Implementation-specific choices (see CollImplDetails).
CollImplDetails impl_details;
// Returns a string form of these parameters (defined out of line).
string ToString() const;
// Copy assignment, defined out of line.
CollInstanceParams& operator=(const struct CollInstanceParams& other);
std::vector<string> devices; // permuter only
// For permuter only.
// Each index of `permutation` is a sending rank, and the value stored at that
// index is the rank that receives from it.
// Example: permutation = {2,0,1} means
// rank 0 sends to rank 2
// rank 1 sends to rank 0
// rank 2 sends to rank 1
std::vector<int> permutation;
// Unique to a single CollectiveOp node.
// Derives from core::RefCounted and combines the group-wide and instance-wide
// parameters with fields that are specific to one node.
struct CollectiveParams : public core::RefCounted {
CollGroupParams group;
CollInstanceParams instance;
string name = ""; // node name used only for log or error messages
int default_rank = -1; // index of this op within device_names
bool is_source = false; // broadcast only
int source_rank = -1; // broadcast only
// Rank of this device in each subdivision permutation.
std::vector<int> subdiv_rank;
OpKernel* merge_op = nullptr; // reduction only
OpKernel* final_op = nullptr; // reduction only
// Returns a string form of these parameters (defined out of line).
string ToString() const;
// Defaults to true. NOTE(review): presumably controls whether group
// initialization is run for this op — confirm against the users of this flag.
bool run_group_initialization = true;
// Forward declaration; the class itself is defined later in this file.
class CollectiveExecutor;
// Interface that provides resolution of device localities.
// Every operation is pure virtual and reports its outcome as a Status.
class DeviceResolverInterface {
virtual ~DeviceResolverInterface() {}
// Populates *attributes with the DeviceAttributes of the specified device.
virtual Status GetDeviceAttributes(const string& device,
DeviceAttributes* attributes) = 0;
// Returns all device attributes of a task through the `attributes`
// out-parameter.
virtual Status GetAllDeviceAttributes(
const string& task, std::vector<DeviceAttributes>* attributes) = 0;
// Updates device attributes. It returns an error if any device already
// exists in the DeviceResolver and has a different incarnation.
virtual Status UpdateDeviceAttributes(
const std::vector<DeviceAttributes>& attributes) = 0;
// Interface that provides resolution of shared CollectiveParams fields.
// The Complete*Async methods return void and take a `done` StatusCallback
// instead of returning a Status.
class ParamResolverInterface {
virtual ~ParamResolverInterface() {}
// Called by each collective op at first execution in order to fill out
// the CollectiveParams structure with data gathered from the full
// (maybe distributed) collection of peer nodes.
virtual void CompleteParamsAsync(const DeviceAttributes& device,
CollectiveParams* cp,
CancellationManager* cancel_mgr,
const StatusCallback& done) = 0;
// Completes group_params with data gathered from all devices in the group.
// This blocks until all devices are there.
virtual void CompleteGroupAsync(const DeviceAttributes& device,
CollGroupParams* group_params,
CancellationManager* cancel_mgr,
const StatusCallback& done) = 0;
// Used within a distributed implementation to discover/verify data
// shared across an instance group.
// Note: this works differently from CompleteGroupAsync as a refactor is in
// progress. Unlike the two methods above it takes request/response objects
// rather than a DeviceAttributes and a params struct.
virtual void CompleteInstanceAsync(const CompleteInstanceRequest* request,
CompleteInstanceResponse* response,
CancellationManager* cancel_mgr,
const StatusCallback& done) = 0;
// Aborts the resolver. After abortion the resolver can no longer be used.
virtual void StartAbort(const Status& s) = 0;
// Graphs which utilize Collective Ops in a common instance must
// execute with identical step_ids even if they are disjoint graphs
// run by otherwise independent tasks. This interface supplies
// coordinated step_ids to use in such cases.
class StepSequenceInterface {
virtual ~StepSequenceInterface() {}
// Used with a distributed implementation to coordinate step_id
// sequences across tasks.
virtual void GetStepSequenceAsync(const GetStepSequenceRequest* request,
GetStepSequenceResponse* response,
const StatusCallback& done) = 0;
// Refreshes the local per-graph_key step_id sequence from the collective
// group leader, if applicable.
virtual void RefreshStepIdSequenceAsync(int64_t graph_key,
const StatusCallback& done) = 0;
// Returns the step_id that should be used for initiating a new execution
// on the specified graph. May return the same step_id multiple times if
// RetireStepId or RefreshStepIdSequenceAsync is not called.
virtual int64_t NextStepId(int64_t graph_key) = 0;
// Reports that execution of the given step has completed successfully.
// Should be called immediately after a step completes with OK status,
// prior to calling NextStepId(). If the step fails, don't call.
virtual void RetireStepId(int64_t graph_key, int64_t step_id) = 0;
// Forward declaration; the interface itself is defined later in this file.
class NcclCommunicatorInterface;
// Interface that provides access to per-step CollectiveExecutor
// instances and various distributed resolution capabilities.
// Extends StepSequenceInterface, so implementations also supply step_ids.
class CollectiveExecutorMgrInterface : public StepSequenceInterface {
virtual ~CollectiveExecutorMgrInterface() {}
// Returns the step-specific CollectiveExecutor, creating if one does not
// already exist. The caller assumes ownership of one Ref on the object.
virtual CollectiveExecutor* FindOrCreate(int64_t step_id) = 0;
// If there is a CollectiveExecutor for step_id, remove it from the
// table.
virtual void Cleanup(int64_t step_id) = 0;
// Accessors for the parameter resolver, device resolver and NCCL
// communicator. NOTE(review): ownership of the returned pointers is not
// stated here — confirm before storing or deleting them.
virtual ParamResolverInterface* GetParamResolver() const = 0;
virtual DeviceResolverInterface* GetDeviceResolver() const = 0;
virtual NcclCommunicatorInterface* GetNcclCommunicator() const = 0;
// Interface that a Collective Op implementation uses to exchange data
// with peers. Note that data exchange is currently limited to types
// for which DMAHelper::CanUseDMA() returns true, i.e. dense numeric
// types.
class CollectiveRemoteAccess {
virtual ~CollectiveRemoteAccess() {}
// Receive side of an exchange: names the peer (device, task, locality flag),
// a key, and the destination device/context/allocator attributes/tensor.
// Completion is reported through `done`.
// NOTE(review): presumably `key` must match the key of the corresponding
// PostToPeer call — confirm.
virtual void RecvFromPeer(const string& peer_device, const string& peer_task,
bool peer_is_local, const string& key,
Device* to_device, DeviceContext* to_device_ctx,
const AllocatorAttributes& to_alloc_attr,
Tensor* to_tensor,
const DeviceLocality& client_locality,
int dev_to_dev_stream_index,
CancellationManager* cancellation_manager,
const StatusCallback& done) = 0;
// Send side of an exchange: names the peer (device, task), a key, and the
// source device/context/allocator attributes/tensor. Completion is reported
// through `done`.
virtual void PostToPeer(const string& peer_device, const string& peer_task,
const string& key, Device* from_device,
DeviceContext* from_device_ctx,
const AllocatorAttributes& from_alloc_attr,
const Tensor* from_tensor,
const DeviceLocality& client_locality,
CancellationManager* cancellation_manager,
const StatusCallback& done) = 0;
// Checks the health of a collective peer. It probes the peer to see if it is
// alive. Note that if a peer has restarted, it's considered a different one,
// so CheckPeerHealth fails.
virtual void CheckPeerHealth(const string& peer_task, int64_t timeout_in_ms,
const StatusCallback& done) = 0;
// Returns the BufRendezvous associated with this object.
virtual BufRendezvous* buf_rendezvous() = 0;
// Starts aborting with status `s`.
virtual void StartAbort(const Status& s) = 0;
// A step-specific object that can execute a collective operation completely
// described by a CollectiveParams object.
// Reference-counted via core::RefCounted; see the nested Handle class for a
// scoped reference.
class CollectiveExecutor : public core::RefCounted {
// No-op in this base class.
virtual void StartAbort(const Status& s) {}
// Base-class fallback for executing a collective. The message below states
// that no CollectiveExecutor has been provided in the calling context.
virtual void ExecuteAsync(OpKernelContext* ctx,
const CollectiveParams* col_params,
const string& exec_key, StatusCallback done) {
"A collective Op has been called in a context in which "
"a CollectiveExecutor has not been provided."));
// Base-class fallback for parameter completion; it carries the same
// "CollectiveExecutor has not been provided" message as ExecuteAsync.
virtual void CompleteParamsAsync(const DeviceAttributes& device,
CollectiveParams* cp,
CancellationManager* cancel_mgr,
StatusCallback done) {
"A collective Op has been called in a context in which "
"a CollectiveExecutor has not been provided."));
// Forwards the call unchanged to the ParamResolverInterface obtained from
// cem_.
virtual void CompleteGroupAsync(const DeviceAttributes& device,
CollGroupParams* group_params,
CancellationManager* cancel_mgr,
StatusCallback done) {
return cem_->GetParamResolver()->CompleteGroupAsync(device, group_params,
cancel_mgr, done);
// Runs the potentially-blocking closure/expensive callback.
virtual void RunClosure(std::function<void()> closure) = 0;
// Returns nullptr in this base class.
virtual CollectiveRemoteAccess* remote_access() { return nullptr; }
// `WaitForDependencies` and `UnblockDependencies` are used for fine-grained
// control of execution order between collective instances. These functions
// are intended to be called in `Run` function of collective implementations,
// and may be used to make part, or whole, of the collective execution ordered
// with respect to other collective instances.
// `WaitForDependencies` will block until it is safe to continue the callee's
// execution, where safety is defined as: ordered with respect to the
// collective instances defined in the callee's `wait_for` attribute.
// The base-class version is a no-op.
virtual void WaitForDependencies(const CollectiveParams& col_params) {}
// `UnblockDependencies` unblocks the dependent collective instances by
// recording that this caller's device has completed the critical portion of
// the collective execution.
// The base-class version is a no-op.
virtual void UnblockDependencies(const CollectiveParams& col_params) {}
// Used to designate an invalid group or instance key.
static int64_t kInvalidId;
// Lexically scoped handle for Ref.
// The constructor calls Ref() unless `inherit_ref` is true; the destructor
// always calls Unref(). With `inherit_ref` set, the handle therefore takes
// over a reference that already exists instead of adding one.
// NOTE(review): no copy or move operations are declared here, so a copied
// Handle would call Unref() a second time — confirm Handle is never copied.
class Handle {
explicit Handle(CollectiveExecutor* ce, bool inherit_ref) : ce_(ce) {
if (!inherit_ref) ce->Ref();
~Handle() { ce_->Unref(); }
CollectiveExecutor* get() const { return ce_; }
CollectiveExecutor* ce_;
// Stores the manager pointer; CompleteGroupAsync reaches the param resolver
// through it.
explicit CollectiveExecutor(CollectiveExecutorMgrInterface* cem)
: cem_(cem) {}
// For use only by derived classes
static OpKernelContext::Params* CtxParams(OpKernelContext* ctx);
CollectiveExecutorMgrInterface* cem_;
// Bundles the executor, communicator, device manager, op context, parameters
// and tensors that are handed to a collective implementation (see
// CollectiveImplementationInterface::InitializeCollectiveContext).
// Pointer members marked "Not owned" are not owned by this struct.
struct CollectiveContext {
CollectiveExecutor* col_exec; // Not owned
NcclCommunicatorInterface* nccl_communicator; // Not owned
const DeviceMgr* dev_mgr; // Not owned
OpKernelContext* op_ctx; // Not owned
OpKernelContext::Params* op_params; // Not owned
const CollectiveParams* col_params; // Not owned
const string exec_key;
const int64_t step_id;
const Tensor* input; // Not owned
Tensor* output; // Not owned
Device* device; // The device for which this instance labors
const string device_name;
DeviceLocality device_locality;
// Constructor, defined out of line. Its parameters mirror the leading members
// above; `device`, `device_name` and `device_locality` are not parameters.
CollectiveContext(CollectiveExecutor* col_exec,
NcclCommunicatorInterface* nccl_communicator,
const DeviceMgr* dev_mgr, OpKernelContext* ctx,
OpKernelContext::Params* op_params,
const CollectiveParams* col_params, const string& exec_key,
int64_t step_id, const Tensor* input, Tensor* output);
// Interface to an NCCL communicator. All operations are pure virtual.
class NcclCommunicatorInterface {
virtual ~NcclCommunicatorInterface() = default;
// Returns a communicator key. NOTE(review): presumably the value stored in
// CollGroupRuntimeDetails::communicator_key — confirm.
virtual string GenerateCommunicatorKey() = 0;
// Takes a shared CollectiveContext and a completion callback.
virtual void Enqueue(std::shared_ptr<CollectiveContext> col_ctx,
StatusCallback done) = 0;
// Starts aborting with status `s`.
virtual void StartAbort(const Status& s) = 0;
// Interface of a Collective Op implementation. Each specific CollectiveOp will
// implement this interface and register the implementation via the
// CollectiveRegistry detailed below. See common_runtime/ring_reducer and
// common_runtime/hierarchical_tree_broadcaster for examples.
// Implementations are reference-counted via core::RefCounted.
class CollectiveImplementationInterface : public core::RefCounted {
virtual ~CollectiveImplementationInterface() = default;
// Initializes the portions of `col_params` specific to this
// implementation. Called exactly once for every Collective instance during
// the CollectiveParams resolution process when the graph is first executed,
// at the end of `CompleteInstanceLocal()`.
// NOTE(ayushd): This is effectively a static function because it modifies the
// `col_params` passed in and should not manipulate any data members. However
// because it is virtual and needs to be implemented by every derived class we
// do not mark it as static.
virtual Status InitializeCollectiveParams(CollectiveParams* col_params) = 0;
// Prepares the CollectiveContext for executing this CollectiveImplementation.
// Called from CollectiveExecutor right before calling Run(). The
// CollectiveContext passed in must outlive the CollectiveImplementation
// object.
virtual Status InitializeCollectiveContext(
std::shared_ptr<CollectiveContext> col_ctx) = 0;
// Processes and moves data according to the logic of this Collective
// implementation. Relies on appropriate initialization of op-specific
// CollectiveParams in InitializeCollectiveParams(), as well as appropriate
// context initialization in InitializeCollectiveContext().
// Returns void; takes a `done` StatusCallback.
virtual void Run(StatusCallback done) = 0;
// Static-methods only class for registering and looking up collective
// implementations.
class CollectiveRegistry {
// Callable that creates a new implementation object on each call.
using Factory = std::function<CollectiveImplementationInterface*()>;
// Looks up a previously registered CollectiveImplementation under
// `collective_name`. If found, creates an instance of the implementation and
// assigns it to `implementation`.
static Status Lookup(const string& collective_name,
CollectiveImplementationInterface** implementation);
// Looks up a previously registered CollectiveImplementation under
// `collective_name`. If found, returns the static instance of this
// implementation via `implementation`. This instance should only be used to
// call InitializeCollectiveParams.
static Status LookupParamResolverInstance(
const string& collective_name,
CollectiveImplementationInterface** implementation);
// Returns all registered collective implementations.
static void GetAll(
std::vector<CollectiveImplementationInterface*>* implementations);
// CollectiveRegistration is a friend so that it can call Register().
friend class CollectiveRegistration;
// Registers a CollectiveImplementation with name `collective_name` and
// factory `factory`. The latter is a function used to create instances of
// the CollectiveImplementation. Also creates a static instance of the
// implementation - this instance is used during param resolution and should
// only be used to call InitializeCollectiveParams.
static Status Register(const string& collective_name, Factory factory);
// Shared lookup helper. NOTE(review): presumably `param_resolver` selects
// between the behaviors of Lookup and LookupParamResolverInstance — confirm
// against the definition.
static Status LookupHelper(const string& collective_name,
CollectiveImplementationInterface** implementation,
bool param_resolver);
// Class used to call CollectiveRegistry::Register. This should only be used to
// create a global static object.
class CollectiveRegistration {
// Registers `factory` under `collective_name`. The Status returned by
// Register is passed to TF_CHECK_OK.
CollectiveRegistration(const string& collective_name,
CollectiveRegistry::Factory factory) {
TF_CHECK_OK(CollectiveRegistry::Register(collective_name, factory));
// Registers `impl_type` as a collective implementation under the stringified
// `collective_id`. Expands to a `static` CollectiveRegistration object named
// register_<collective_id>_collective whose factory returns `new impl_type`.
// The expansion already ends with a semicolon.
#define REGISTER_COLLECTIVE(collective_id, impl_type)                    \
  static CollectiveRegistration register_##collective_id##_collective(  \
      #collective_id, []() { return new impl_type; });
} // namespace tensorflow