/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_PROCESS_FUNCTION_LIBRARY_RUNTIME_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_PROCESS_FUNCTION_LIBRARY_RUNTIME_H_
#include <unordered_map>
// clang-format off
// Required for IS_MOBILE_PLATFORM
#include "tensorflow/core/platform/platform.h"
// clang-format on
#include "absl/types/optional.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/protobuf/config.pb.h"
#if !defined(IS_MOBILE_PLATFORM)
#include "tensorflow/core/protobuf/remote_tensor_handle.pb.h"
#endif // IS_MOBILE_PLATFORM
namespace tensorflow {
// A class that stores all the FunctionLibraryRuntime objects, one per device.
class ProcessFunctionLibraryRuntime {
public:
// Creates FunctionLibraryRuntime objects for each device in the provided
// DeviceMgr. Caller needs to make sure that device_mgr, lib_def and parent
// (if provided) outlive this object.
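//
// An illustrative construction sketch (a hedged example, not part of this
// header's API; `device_mgr`, `env`, `lib_def`, and `optimizer_opts` are
// assumed to be supplied by the caller and to outlive the runtime, and
// TF_GRAPH_DEF_VERSION comes from core/public/version.h):
//
//   ProcessFunctionLibraryRuntime pflr(
//       device_mgr, env, /*config=*/nullptr, TF_GRAPH_DEF_VERSION, lib_def,
//       optimizer_opts);
//   // The exact device name depends on how devices were registered.
//   FunctionLibraryRuntime* flr =
//       pflr.GetFLR("/job:localhost/replica:0/task:0/device:CPU:0");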
ProcessFunctionLibraryRuntime(
const DeviceMgr* device_mgr, Env* env, const ConfigProto* config,
int graph_def_version, const FunctionLibraryDefinition* lib_def,
const OptimizerOptions& optimizer_options,
thread::ThreadPool* thread_pool = nullptr,
DistributedFunctionLibraryRuntime* parent = nullptr,
const CustomKernelCreator* custom_kernel_creator = nullptr,
const SessionMetadata* metadata = nullptr);
virtual ~ProcessFunctionLibraryRuntime() {
// Deleting the FunctionLibraryRuntime map will delete the function handles
// registered in it, which may call ReleaseHandle in this class again to
// release their sub-function. These circular calls may cause a segfault
// since the flr_map_ may have already been deleted. Explicitly releasing
// flr_map_ here, and checking flr_map_ in ReleaseHandle, avoids this.
flr_map_.reset();
}
// Sends `tensors_to_send` from `source_device` to `target_device` using
// `rendezvous`. `key_prefix` is used as a prefix for the keys sent to the
// Rendezvous. `device_context` should be the DeviceContext of the device
// doing the sending. `alloc_attrs` should either be empty or be the size of
// `tensors_to_send` and indicates how the input tensors are allocated. Method
// takes references on each of the `tensors_to_send`. Method doesn't block.
static Status SendTensors(const string& source_device,
const string& target_device,
const string& key_prefix, int64 src_incarnation,
gtl::ArraySlice<Tensor> tensors_to_send,
DeviceContext* device_context,
const std::vector<AllocatorAttributes>& alloc_attrs,
Rendezvous* rendezvous);
// Receives `received_tensors` from `target_device` (originally sent from
// `source_device`) using `rendezvous`. Uses `key_prefix` to construct the
// keys to be retrieved. `device_context` should be for the device receiving
// the tensors. `alloc_attrs` indicates how to allocate the received
// tensors and should either be empty or `num_tensors` in size. Method doesn't
// block and calls `done` when `num_tensors` are fetched.
static void ReceiveTensorsAsync(
const string& source_device, const string& target_device,
const string& key_prefix, int64 src_incarnation, int64 num_tensors,
DeviceContext* device_context,
const std::vector<AllocatorAttributes>& alloc_attrs,
Rendezvous* rendezvous, std::vector<Tensor>* received_tensors,
StatusCallback done);
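//
// A hedged sketch of pairing the two calls above (error handling elided;
// `incarnation`, `src_ctx`, `dst_ctx`, `rendezvous`, and the tensors `t0`
// and `t1` are assumed to come from the caller):
//
//   std::vector<Tensor> to_send = {t0, t1};
//   Status send_status = ProcessFunctionLibraryRuntime::SendTensors(
//       "/job:localhost/replica:0/task:0/device:CPU:0",
//       "/job:localhost/replica:0/task:0/device:GPU:0", "prefix",
//       incarnation, to_send, src_ctx, /*alloc_attrs=*/{}, rendezvous);
//   std::vector<Tensor> received;
//   ProcessFunctionLibraryRuntime::ReceiveTensorsAsync(
//       "/job:localhost/replica:0/task:0/device:CPU:0",
//       "/job:localhost/replica:0/task:0/device:GPU:0", "prefix",
//       incarnation, /*num_tensors=*/2, dst_ctx, /*alloc_attrs=*/{},
//       rendezvous, &received,
//       [](const Status& s) { /* `received` is populated once s is OK. */ });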
static const char kDefaultFLRDevice[];
// Returns the FunctionLibraryRuntime for the corresponding device_name.
FunctionLibraryRuntime* GetFLR(const string& device_name) const;
// Returns the return types for the function identified by handle `h`.
Status GetRetTypes(FunctionLibraryRuntime::Handle h,
DataTypeVector* ret_types);
// Returns the device incarnation for the given device_name.
Status GetDeviceIncarnation(const string& device_name,
int64* incarnation) const;
// For a given canonicalized key signature of the function instantiated
// on device `device_name` and a `local_handle`, creates a handle and returns
// that value. Uses framework/function.h::Canonicalize to canonicalize the
// function signature.
FunctionLibraryRuntime::Handle AddHandle(
const string& function_key, const string& device_name,
FunctionLibraryRuntime::LocalHandle local_handle);
// Returns a handle if found for the given key, else returns kInvalidHandle.
FunctionLibraryRuntime::Handle GetHandle(const string& function_key) const;
// For the given handle instantiated on device `device_name`, returns the
// local index of that instantiation. If the function was not instantiated on
// `device_name`, or if the function is multi-device, returns
// kInvalidLocalHandle.
FunctionLibraryRuntime::LocalHandle GetHandleOnDevice(
const string& device_name, FunctionLibraryRuntime::Handle handle) const;
// Fills `output_devices` with the devices on which the results will
// be produced. If some output is produced on CPU, the corresponding Device*
// is set to nullptr. If some output is DT_RESOURCE, the corresponding Device*
// is set to the device backing the resource.
// REQUIRES: `handle` identifies a multi-device function.
Status GetOutputDevices(FunctionLibraryRuntime::Handle handle,
std::vector<Device*>* output_devices) const;
// Returns true if function with handle `handle` was instantiated on device
// `device_name`. Returns false for multi-device functions.
bool IsInstantiatedOnDevice(const string& device_name,
FunctionLibraryRuntime::Handle handle) const;
// Instantiates the function. See framework/function.h for more details.
// Allows for function_name to be instantiated on different devices
// as specified in attrs.
Status Instantiate(const string& function_name, AttrSlice attrs,
const FunctionLibraryRuntime::InstantiateOptions& options,
FunctionLibraryRuntime::Handle* handle);
// Delegates to the local FLR that owns the state corresponding to `handle`
// and tells it to release it. If the `handle` isn't needed at all, the local
// FLR might call RemoveHandle on this to get rid of the state owned by the
// Proc FLR.
// For multi-device functions, calls ReleaseHandle on local FLRs for each
// component function that is part of this multi-device function.
// Each local FLR might call RemoveHandle on this.
Status ReleaseHandle(FunctionLibraryRuntime::Handle handle);
// Runs the function with the given `handle`. The function could have been
// instantiated on any device. See framework/function.h for more details.
void Run(const FunctionLibraryRuntime::Options& opts,
FunctionLibraryRuntime::Handle handle, gtl::ArraySlice<Tensor> args,
std::vector<Tensor>* rets,
FunctionLibraryRuntime::DoneCallback done) const;
void Run(const FunctionLibraryRuntime::Options& opts,
FunctionLibraryRuntime::Handle handle, CallFrameInterface* frame,
FunctionLibraryRuntime::DoneCallback done) const;
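//
// A hedged sketch of the typical instantiate/run/release flow (`pflr`,
// `attrs`, and `args` are assumed to exist in the caller's scope; depending
// on configuration, `run_opts` may also need a rendezvous and step_id):
//
//   FunctionLibraryRuntime::InstantiateOptions inst_opts;
//   inst_opts.target = "/job:localhost/replica:0/task:0/device:CPU:0";
//   FunctionLibraryRuntime::Handle handle;
//   TF_RETURN_IF_ERROR(
//       pflr.Instantiate("MyFunction", attrs, inst_opts, &handle));
//   FunctionLibraryRuntime::Options run_opts;
//   std::vector<Tensor> rets;
//   pflr.Run(run_opts, handle, args, &rets, [&](const Status& s) {
//     // `rets` is valid once `s` is OK; the handle may be released
//     // afterwards with pflr.ReleaseHandle(handle).
//   });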
const DeviceMgr* device_mgr() { return device_mgr_; }
const DeviceSet* device_set() { return &device_set_; }
const ConfigProto* config() const { return config_ ? &(*config_) : nullptr; }
const FunctionLibraryDefinition* GetFunctionLibraryDefinition() const {
return lib_def_;
}
protected:
friend class FunctionLibraryRuntimeImpl;
struct InternalArgs {
std::vector<Tensor> local_args;
#if !defined(IS_MOBILE_PLATFORM)
std::vector<eager::RemoteTensorHandle*> remote_args;
#endif // IS_MOBILE_PLATFORM
};
struct InternalArgsView {
public:
explicit InternalArgsView(gtl::ArraySlice<Tensor> tensors)
: local_args(tensors) {}
explicit InternalArgsView(const InternalArgs& args)
: local_args(args.local_args) {
#if !defined(IS_MOBILE_PLATFORM)
remote_args = args.remote_args;
#endif // IS_MOBILE_PLATFORM
}
gtl::ArraySlice<Tensor> local_args;
#if !defined(IS_MOBILE_PLATFORM)
absl::Span<eager::RemoteTensorHandle* const> remote_args;
#endif // IS_MOBILE_PLATFORM
};
// Structure to keep track of how a component function (a single-device
// piece of a multi-device function) fits into the multi-device function.
struct ComponentFunctionData {
// The handle for the instantiated component function.
FunctionLibraryRuntime::Handle handle_;
// arg_indices_.size() is the number of arguments to the component function.
// The i-th argument of the component function comes from the
// `arg_indices_[i]`-th argument of the multi-device function.
std::vector<int> arg_indices_;
// ret_indices_.size() is the number of return values of the component
// function. The i-th return value of the component function goes to the
// `ret_indices_[i]`-th return value of the multi-device function.
std::vector<int> ret_indices_;
// arg_alloc_attrs_[i] are the allocator attributes of the i-th argument to
// the component function.
std::vector<AllocatorAttributes> arg_alloc_attrs_;
// ret_alloc_attrs_[i] are the allocator attributes of the i-th return value
// of the component function.
std::vector<AllocatorAttributes> ret_alloc_attrs_;
};
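//
// A hedged illustration of ComponentFunctionData's index mapping: for a
// multi-device function f(x, y, z) -> (a, b), a component function that
// consumes z and x (in that order) and produces b would be described by
//
//   ComponentFunctionData comp;
//   comp.arg_indices_ = {2, 0};  // z is f's argument 2, x is argument 0.
//   comp.ret_indices_ = {1};     // b is f's return value 1.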
// Data structure holding information for a single instantiated multi-device
// function.
// The fields are filled in during instantiation. Once the object is
// added to mdevice_data_, all fields are constant.
struct MultiDeviceFunctionData {
MultiDeviceFunctionData(const string& function_name,
const string& function_key, int num_outputs,
FunctionLibraryDefinition&& lib_def,
DataTypeVector ret_types)
: function_name_(function_name),
function_key_(function_key),
instantiation_counter_(1),
lib_def_(std::move(lib_def)),
num_outputs_(num_outputs),
ret_types_(std::move(ret_types)) {}
const string function_name_;
const string function_key_;
uint64 instantiation_counter_;
// A library that contains definitions of component functions and their
// transitive dependencies.
FunctionLibraryDefinition lib_def_;
// Stored here to resize the output tensor vector when function is run.
const int num_outputs_;
DataTypeVector ret_types_;
// Maps the device name to the information about the component function to
// be run on that device.
std::unordered_map<string, ComponentFunctionData> glue_;
};
struct CleanUpItem {
string device;
uint64 step_id;
FunctionLibraryRuntime::Handle local_handle;
};
// If `handle` represents a multi-device function, returns the multi-device
// data associated with it; otherwise returns nullptr.
MultiDeviceFunctionData* IsMultiDevice(
FunctionLibraryRuntime::Handle handle) const;
virtual void RunRemoteDevice(const FunctionLibraryRuntime::Options& opts,
FunctionLibraryRuntime::Handle local_handle,
const InternalArgsView& args,
std::vector<Tensor>* rets,
FunctionLibraryRuntime::DoneCallback done) const;
void RunMultiDevice(
const FunctionLibraryRuntime::Options& opts,
FunctionLibraryRuntime::Handle handle, std::vector<Tensor>* rets,
std::vector<std::unique_ptr<CleanUpItem>>* cleanup_items,
FunctionLibraryRuntime::DoneCallback done,
std::function<InternalArgs(const ComponentFunctionData& comp_data)>
get_component_args) const;
FunctionLibraryRuntime::DoneCallback ApplyCleanUpToDoneCallback(
std::vector<std::unique_ptr<CleanUpItem>>* items,
FunctionLibraryRuntime::DoneCallback done) const;
DistributedFunctionLibraryRuntime* const parent_;
private:
FunctionLibraryRuntime::Handle AddHandleLocked(
const string& function_key, const string& device_name,
FunctionLibraryRuntime::LocalHandle local_handle)
EXCLUSIVE_LOCKS_REQUIRED(mu_);
// For a given device_name, returns a DeviceContext for copying
// tensors to/from the device.
Status GetDeviceContext(const string& device_name,
DeviceContext** device_context) const;
// Looks up the information for the given `handle` and returns the name
// of the device where the function is registered.
string GetDeviceName(FunctionLibraryRuntime::Handle handle) const;
// Removes handle from the state owned by this object.
Status RemoveHandle(FunctionLibraryRuntime::Handle handle);
// Clones ProcessFunctionLibraryRuntime and FunctionLibraryDefinition
// (transferring ownership of both to the caller). Note that the
// ProcessFunctionLibraryRuntime borrows a pointer to the
// FunctionLibraryDefinition and so the FunctionLibraryDefinition should
// outlive the ProcessFunctionLibraryRuntime.
//
// The `skip_flib_def` argument controls whether the method should clone the
// FunctionLibraryDefinition (default behavior) or return an empty function
// library. The latter is used by tf.data, which manages
// FunctionLibraryDefinitions for its functions independently (and passes
// these into the FunctionLibraryRuntime through an overlay), to avoid linear
// runtime w.r.t. the number of functions in the current function library.
Status Clone(Env* env, int graph_def_version,
const OptimizerOptions& optimizer_options,
const CustomKernelCreator* custom_kernel_creator,
std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr,
bool skip_flib_def = false) const;
Status ReleaseMultiDeviceHandle(FunctionLibraryRuntime::Handle handle);
Status InstantiateMultiDevice(
const string& function_name, AttrSlice attrs,
const FunctionLibraryRuntime::InstantiateOptions& options,
FunctionLibraryRuntime::Handle* handle);
FunctionLibraryRuntime::Handle AddMultiDeviceHandle(
const std::unique_ptr<MultiDeviceFunctionData> data,
const string& function_key);
// TODO(iga): Reword
// Pins each arg that emits a `DT_RESOURCE` tensor to the device on which the
// corresponding resource lives. This ensures that the Placer assigns ops that
// access these resources to the appropriate devices.
Status PinArgsAndRets(const std::vector<string>& input_devices,
const std::vector<string>& output_devices,
const DeviceSet& device_set,
const std::vector<Node*>& arg_nodes,
const std::vector<Node*>& ret_nodes) const;
void RunInternal(const FunctionLibraryRuntime::Options& opts,
FunctionLibraryRuntime::Handle handle,
const InternalArgsView& args, std::vector<Tensor>* rets,
std::vector<std::unique_ptr<CleanUpItem>>* cleanup_items,
FunctionLibraryRuntime::DoneCallback done) const;
void CleanUp(std::vector<std::unique_ptr<CleanUpItem>>* items,
FunctionLibraryRuntime::DoneCallback done) const;
// Data structure holding information for a single remote function
// instantiation (to be executed on `target_device`).
class FunctionData {
public:
FunctionData(const string& target_device,
FunctionLibraryRuntime::LocalHandle local_handle,
const string& function_key)
: target_device_(target_device),
local_handle_(local_handle),
function_key_(function_key) {}
string target_device() { return target_device_; }
const string& function_key() { return function_key_; }
FunctionLibraryRuntime::LocalHandle local_handle() {
mutex_lock l(mu_);
return local_handle_;
}
// Initializes the FunctionData object by potentially making an Initialize
// call to the DistributedFunctionLibraryRuntime.
Status DistributedInit(
DistributedFunctionLibraryRuntime* parent, const string& function_name,
const FunctionLibraryDefinition& lib_def, AttrSlice attrs,
const FunctionLibraryRuntime::InstantiateOptions& options);
private:
mutex mu_;
const string target_device_;
FunctionLibraryRuntime::LocalHandle local_handle_ GUARDED_BY(mu_);
const string function_key_;
bool init_started_ GUARDED_BY(mu_) = false;
Status init_result_ GUARDED_BY(mu_);
Notification init_done_;
};
mutable mutex mu_;
Env* const env_;
const absl::optional<const ConfigProto> config_;
const DeviceMgr* const device_mgr_;
DeviceSet device_set_;
const FunctionLibraryDefinition* lib_def_;
thread::ThreadPool* default_thread_pool_;
// Holds all the function instantiations. Maps function_keys to handles.
std::unordered_map<string, FunctionLibraryRuntime::Handle> table_
GUARDED_BY(mu_);
// Function data for instantiated remote functions.
std::unordered_map<FunctionLibraryRuntime::Handle,
std::unique_ptr<FunctionData>>
function_data_ GUARDED_BY(mu_);
// Function data for instantiated multi-device functions.
std::unordered_map<FunctionLibraryRuntime::Handle,
std::unique_ptr<MultiDeviceFunctionData>>
mdevice_data_ GUARDED_BY(mu_);
std::unique_ptr<
std::unordered_map<Device*, std::unique_ptr<FunctionLibraryRuntime>>>
flr_map_;
int next_handle_ GUARDED_BY(mu_);
const SessionMetadata* const session_metadata_;
};
} // namespace tensorflow
#endif // TENSORFLOW_CORE_COMMON_RUNTIME_PROCESS_FUNCTION_LIBRARY_RUNTIME_H_