blob: 1543a341c2a0fd1288965dbb4b965092163b66db [file] [log] [blame]
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include <iterator>
#include <utility>
#include "absl/memory/memory.h"
#include "absl/strings/str_join.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/function_optimization_registry.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/common_runtime/partitioning_utils.h"
#include "tensorflow/core/common_runtime/placer.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/common_runtime/rendezvous_util.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_node_util.h"
#include "tensorflow/core/graph/graph_partition.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/notification.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/dump_graph.h"
#include "tensorflow/core/util/ptr_util.h"
#include "tensorflow/core/util/reffed_status_callback.h"
namespace tensorflow {
const char ProcessFunctionLibraryRuntime::kDefaultFLRDevice[] = "null";
void ProcessFunctionLibraryRuntime::FunctionData::DistributedInit(
DistributedFunctionLibraryRuntime* parent, const string& function_name,
const FunctionLibraryDefinition& lib_def, AttrSlice attrs,
const FunctionLibraryRuntime::InstantiateOptions& options,
FunctionLibraryRuntime::DoneCallback done) {
{
mutex_lock l(mu_);
is_cross_process_ = true;
if (init_started_) {
init_done_.WaitForNotification();
done(init_result_);
return;
}
init_started_ = true;
}
parent->Instantiate(function_name, lib_def, attrs, options, &local_handle_,
[this, done](const Status& s) {
init_done_.Notify();
done(s);
});
}
ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime(
const DeviceMgr* device_mgr, Env* env, const ConfigProto* config,
int graph_def_version, const FunctionLibraryDefinition* lib_def,
const OptimizerOptions& optimizer_options,
thread::ThreadPool* default_thread_pool,
DistributedFunctionLibraryRuntime* parent,
const CustomKernelCreator* custom_kernel_creator,
const SessionMetadata* session_metadata,
Rendezvous::Factory rendezvous_factory)
: parent_(parent),
env_(env),
config_(config ? absl::make_optional(*config) : absl::nullopt),
device_mgr_(device_mgr),
lib_def_(lib_def),
default_thread_pool_(default_thread_pool),
flr_map_(new std::unordered_map<Device*,
std::unique_ptr<FunctionLibraryRuntime>>),
next_handle_(0),
session_metadata_(session_metadata),
rendezvous_factory_(std::move(rendezvous_factory)) {
if (device_mgr == nullptr) {
(*flr_map_)[nullptr] = NewFunctionLibraryRuntime(
nullptr, env, config_ ? &(*config_) : nullptr, nullptr,
graph_def_version, lib_def_, default_thread_pool, optimizer_options,
custom_kernel_creator, session_metadata_, this);
return;
}
for (Device* d : device_mgr->ListDevices()) {
(*flr_map_)[d] = NewFunctionLibraryRuntime(
device_mgr, env, config_ ? &(*config_) : nullptr, d, graph_def_version,
lib_def_, default_thread_pool, optimizer_options, custom_kernel_creator,
session_metadata_, this);
}
InitializeDeviceSet();
}
/* static */
Status ProcessFunctionLibraryRuntime::SendTensors(
const string& source_device, const string& target_device,
const string& key_prefix, int64 src_incarnation,
gtl::ArraySlice<Tensor> tensors_to_send, DeviceContext* device_context,
const std::vector<AllocatorAttributes>& alloc_attrs,
RendezvousInterface* rendezvous) {
std::vector<string> keys;
for (int i = 0; i < tensors_to_send.size(); ++i) {
string name = strings::StrCat(key_prefix, i);
string key = Rendezvous::CreateKey(source_device, src_incarnation,
target_device, name, FrameAndIter(0, 0));
keys.push_back(key);
}
TF_RETURN_IF_ERROR(SendTensorsToRendezvous(
rendezvous, device_context, alloc_attrs, keys, tensors_to_send));
return Status::OK();
}
/* static */
void ProcessFunctionLibraryRuntime::ReceiveTensorsAsync(
const string& source_device, const string& target_device,
const string& key_prefix, int64 src_incarnation, int64 num_tensors,
DeviceContext* device_context,
const std::vector<AllocatorAttributes>& alloc_attrs,
RendezvousInterface* rendezvous, std::vector<Tensor>* received_tensors,
StatusCallback done) {
std::vector<string> keys;
for (int64 i = 0; i < num_tensors; ++i) {
string name = strings::StrCat(key_prefix, i);
string key = Rendezvous::CreateKey(source_device, src_incarnation,
target_device, name, FrameAndIter(0, 0));
keys.push_back(key);
}
RecvOutputsFromRendezvousAsync(rendezvous, device_context, alloc_attrs, keys,
received_tensors, std::move(done));
}
Status ProcessFunctionLibraryRuntime::GetRetTypes(
FunctionLibraryRuntime::Handle h, DataTypeVector* ret_types) {
FunctionLibraryRuntime* flr = nullptr;
{
tf_shared_lock l(mu_);
auto miter = mdevice_data_.find(h);
if (miter != mdevice_data_.end()) {
*ret_types = miter->second->ret_types_;
return Status::OK();
}
auto fiter = function_data_.find(h);
if (fiter != function_data_.end()) {
flr = GetFLR(fiter->second->target_device());
}
}
if (flr != nullptr) {
return flr->GetRetTypes(h, ret_types);
}
return errors::InvalidArgument("Handle ", h, " not found.");
}
Status ProcessFunctionLibraryRuntime::GetDeviceIncarnation(
const string& device_name, int64* incarnation) const {
FunctionLibraryRuntime* flr = GetFLR(device_name);
if (flr == nullptr) {
return errors::InvalidArgument("Device name: ", device_name, " not found.");
}
*incarnation = flr->device()->attributes().incarnation();
return Status::OK();
}
Status ProcessFunctionLibraryRuntime::GetDeviceContext(
const string& device_name, DeviceContext** device_context) const {
*device_context = nullptr;
FunctionLibraryRuntime* flr = GetFLR(device_name);
if (flr == nullptr) {
return errors::InvalidArgument("Device name: ", device_name, " not found.");
}
Device* device = flr->device();
string device_type = device->parsed_name().type;
if (device_type == "CPU" || device_type == "TPU_SYSTEM") {
// "TPU_SYSTEM" indicates that `device` is a CPU.
return Status::OK();
}
if (device_type == "GPU" || device_type == "TPU") {
auto* dev_info = flr->device()->tensorflow_gpu_device_info();
if (dev_info) {
*device_context = dev_info->default_context;
return Status::OK();
}
}
return errors::Internal("Device type: ", device_type,
" is currently unsupported for remote ",
"function executions");
}
void ProcessFunctionLibraryRuntime::InitializeDeviceSet() {
DeviceMgr const* all_devices = device_mgr_;
if (parent_ != nullptr && parent_->remote_device_mgr() != nullptr) {
all_devices = parent_->remote_device_mgr();
}
device_set_.reset(new DeviceSet);
for (auto d : all_devices->ListDevices()) {
device_set_->AddDevice(d);
}
}
FunctionLibraryRuntime* ProcessFunctionLibraryRuntime::GetFLR(
const string& device_name) const {
Device* device = nullptr;
if (device_name != kDefaultFLRDevice) {
if (!device_mgr_->LookupDevice(device_name, &device).ok()) {
VLOG(1) << "Could not find device: " << device_name;
return nullptr;
}
}
const auto& iter = flr_map_->find(device);
if (iter == flr_map_->end()) {
LOG(ERROR) << "Could not find device: " << device_name;
return nullptr;
}
return iter->second.get();
}
FunctionLibraryRuntime::Handle ProcessFunctionLibraryRuntime::AddHandle(
const string& function_key, const string& device_name,
FunctionLibraryRuntime::LocalHandle local_handle) {
mutex_lock l(mu_);
return AddHandleLocked(function_key, device_name, local_handle);
}
FunctionLibraryRuntime::Handle ProcessFunctionLibraryRuntime::AddHandleLocked(
const string& function_key, const string& device_name,
FunctionLibraryRuntime::LocalHandle local_handle) {
auto h = next_handle_;
function_data_[h] =
absl::make_unique<FunctionData>(device_name, local_handle, function_key);
table_[function_key] = h;
next_handle_++;
return h;
}
FunctionLibraryRuntime::Handle
ProcessFunctionLibraryRuntime::AddMultiDeviceHandle(
std::unique_ptr<MultiDeviceFunctionData> data, const string& function_key) {
mutex_lock l(mu_);
auto h = next_handle_;
mdevice_data_[h] = std::move(data);
table_[function_key] = h;
next_handle_++;
return h;
}
FunctionLibraryRuntime::Handle ProcessFunctionLibraryRuntime::GetHandle(
const string& function_key) const {
tf_shared_lock l(mu_);
return gtl::FindWithDefault(table_, function_key, kInvalidHandle);
}
bool ProcessFunctionLibraryRuntime::IsInstantiatedOnDevice(
const string& device_name, FunctionLibraryRuntime::Handle handle) const {
return GetHandleOnDevice(device_name, handle) != kInvalidHandle;
}
FunctionLibraryRuntime::LocalHandle
ProcessFunctionLibraryRuntime::GetHandleOnDevice(
const string& device_name, FunctionLibraryRuntime::Handle handle) const {
tf_shared_lock l(mu_);
auto miter = mdevice_data_.find(handle);
if (miter != mdevice_data_.end()) {
return kInvalidLocalHandle;
}
auto iter = function_data_.find(handle);
if (iter == function_data_.end()) {
return kInvalidLocalHandle;
}
FunctionData* function_data = iter->second.get();
if (function_data->target_device() != device_name) {
return kInvalidLocalHandle;
}
return function_data->local_handle();
}
string ProcessFunctionLibraryRuntime::GetDeviceName(
FunctionLibraryRuntime::Handle handle) const {
tf_shared_lock l(mu_);
auto iter = function_data_.find(handle);
CHECK(iter != function_data_.end());
FunctionData* function_data = iter->second.get();
return function_data->target_device();
}
ProcessFunctionLibraryRuntime::MultiDeviceFunctionData*
ProcessFunctionLibraryRuntime::IsMultiDevice(
FunctionLibraryRuntime::Handle handle) const {
tf_shared_lock l(mu_);
const auto& it = mdevice_data_.find(handle);
if (it != mdevice_data_.end()) {
return it->second.get();
}
return nullptr;
}
namespace {
// Sets `group` to the first colocation group specified in `node`. If no
// group is specified, does not touch `group`.
void GetColocationGroup(const Node* node, string* group) {
// We hoist the conversion from C-style string literal to string here,
// so that we can avoid the many repeated calls to strlen().
static const StringPiece kColocationAttrNameStringPiece(kColocationAttrName);
const AttrValue* attr_value =
node->attrs().Find(kColocationAttrNameStringPiece);
if (attr_value != nullptr && attr_value->has_list() &&
attr_value->list().s_size() > 0) {
*group = attr_value->list().s(0);
}
}
const string* AssignedOrRequestedDeviceName(const Node& node) {
if (node.has_assigned_device_name()) {
return &node.assigned_device_name();
}
return &node.requested_device();
}
Status SetArgShape(
const std::unordered_map<int, DtypeAndPartialTensorShape>&
input_resource_dtypes_and_shapes,
const std::vector<Node*>& arg_nodes) {
for (Node* n : arg_nodes) {
int index;
TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "index", &index));
DataType dtype;
TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "T", &dtype));
if (dtype == DT_RESOURCE) {
auto dtype_and_shape_iter = input_resource_dtypes_and_shapes.find(index);
if (dtype_and_shape_iter != input_resource_dtypes_and_shapes.end()) {
AttrValue dtype_attr_value;
dtype_attr_value.mutable_list()->add_type(
dtype_and_shape_iter->second.dtype);
n->AddAttr("_handle_dtypes", dtype_attr_value);
TensorShapeProto shape_proto;
dtype_and_shape_iter->second.shape.AsProto(&shape_proto);
AttrValue shape_attr_value;
*shape_attr_value.mutable_list()->add_shape() = shape_proto;
n->AddAttr("_handle_shapes", shape_attr_value);
}
}
}
return Status::OK();
}
// Returns the local tensors referred by `args`.
std::vector<Tensor> GetLocalArgs(gtl::ArraySlice<FunctionArg> args) {
std::vector<Tensor> tensors;
for (const auto& arg : args) {
if (arg.index() == 0) {
tensors.push_back(absl::get<Tensor>(arg));
}
}
return tensors;
}
} // anonymous namespace
Status ProcessFunctionLibraryRuntime::PinArgsAndRets(
const std::vector<string>& input_devices,
const std::vector<string>& output_devices, const DeviceSet& device_set,
const std::vector<Node*>& arg_nodes, const std::vector<Node*>& ret_nodes,
Device* default_device) const {
// If output_devices are not specified, we want to set the output device
// based on the device of the output producing node. The output producing
// node can be an arg node because functions can simply return their
// arguments. To make sure that the output producing nodes have assigned
// devices, we assign them to arguments first.
for (Node* node : arg_nodes) {
const AttrValue* attr_value;
TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
int64 index = attr_value->i();
node->set_assigned_device_name(input_devices[index]);
}
for (Node* node : ret_nodes) {
if (output_devices.empty()) {
DataType dtype;
TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "T", &dtype));
VLOG(3) << "Trying to determine device for node " << node->name()
<< "[T=" << DataTypeString(dtype) << "]";
// If output_devices are empty, the node producing retval
// must have explicitly assigned device or a colocation constraint
// to a node with explicitly assigned device.
for (const auto& it : node->in_edges()) {
if (it->IsControlEdge()) continue;
Node* src_node = it->src();
const string* src_device = AssignedOrRequestedDeviceName(*src_node);
string colocation_group = "";
GetColocationGroup(src_node, &colocation_group);
VLOG(3) << "Considering src: " << src_node->name()
<< " src_device: " << *src_device
<< " colo group: " << colocation_group;
while (src_device->empty() && colocation_group.empty() &&
src_node->IsIdentity()) {
// Only follows the real data input of Identity, not control edges.
Node* input_node;
TF_RETURN_IF_ERROR(src_node->input_node(0, &input_node));
src_node = input_node;
src_device = AssignedOrRequestedDeviceName(*src_node);
GetColocationGroup(src_node, &colocation_group);
VLOG(3) << "Considering src: " << src_node->name()
<< " src_device: " << *src_device
<< " colo group: " << colocation_group;
}
// If colocation_group is not set and output producing node is assigned
// to a remote device, colocate the retval node with its input node.
// TODO(yujingzhang): Remove this when we support outputting tensors on
// remote devices.
const bool remote_src_device =
!src_device->empty() && GetFLR(*src_device) == nullptr;
if (colocation_group.empty() && remote_src_device) {
colocation_group =
absl::StrCat(kColocationGroupPrefix, it->src()->name());
VLOG(3) << "Considering src: " << src_node->name()
<< " colo group: " << colocation_group;
}
// If resource is produced by a function call node, we can't trust
// source node device assignment, because multi-device functions can
// return resource placed on multiple devices. In such case we leave
// retval device assignment empty, and rely on placer to infer correct
// assignment based on actual output device.
const bool can_use_src_node_device =
!(dtype == DT_RESOURCE && IsFunctionCall(*lib_def_, *src_node));
if (!colocation_group.empty()) {
AttrValue::ListValue colo_attr;
colo_attr.add_s(colocation_group);
std::vector<string> colo_slice = {colocation_group};
node->AddAttr(kColocationAttrName, colo_slice);
} else if (!src_device->empty() && can_use_src_node_device) {
// src_device can be a partially specified device. Find the
// matching device in the device_set.
DeviceNameUtils::ParsedName parsed;
if (!DeviceNameUtils::ParseFullName(*src_device, &parsed)) {
return errors::InvalidArgument(
"Failed to parse explicit device specification ", *src_device);
}
std::vector<Device*> matching_devices;
device_set.FindMatchingDevices(parsed, &matching_devices);
if (matching_devices.empty()) {
if (default_device != nullptr) {
matching_devices.push_back(default_device);
} else {
return errors::InvalidArgument(
"Unable to find any devices for spec ", *src_device);
}
} else if (matching_devices.size() != 1) {
// Convert a vector of devices to a string.
// Using absl::StrJoin did not work in Android builds.
string devices = "[";
for (Device* device : matching_devices) {
devices.append(device->name());
devices.append(", ");
}
if (devices.size() > 2) {
devices.resize(devices.size() - 2);
}
devices.append("]");
return errors::InvalidArgument(
"When FunctionLibraryRuntime::Options.output_devices are "
"not specified for a multi-device function, the device "
"specification on the output node must match exactly one "
"device. Matched devices are ",
devices);
}
VLOG(3) << "Setting output device to " << matching_devices[0]->name()
<< " for node " << SummarizeNode(*node);
node->set_assigned_device_name(matching_devices[0]->name());
} else if (!src_device->empty() && !can_use_src_node_device) {
VLOG(3) << "Did not set device for a resource output node "
<< SummarizeNode(*node);
}
}
} else {
const AttrValue* attr_value;
TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
int64 index = attr_value->i();
// output_devices size is checked in InstantiateMultiDevice
DCHECK_GT(output_devices.size(), index);
VLOG(3) << "Setting output device to " << output_devices[index]
<< " for return at index " << index;
node->set_assigned_device_name(output_devices[index]);
}
}
return Status::OK();
}
namespace {
Status ValidateNoListArguments(
const protobuf::RepeatedPtrField<OpDef::ArgDef>& args, const char* arg_type,
const string& function_name) {
for (const OpDef::ArgDef& arg : args) {
if (!arg.number_attr().empty() || !arg.type_list_attr().empty()) {
return errors::InvalidArgument(
"Function ", function_name, " has an ", arg_type, " named \"",
arg.name(),
"\" that is a list of tensors."
" Multi-device functions support only single-tensor inputs "
" and outputs");
}
}
return Status::OK();
}
Status ValidateMultiDeviceOptions(
const FunctionDef& fdef,
const FunctionLibraryRuntime::InstantiateOptions& options) {
const OpDef& signature = fdef.signature();
// Multi-device functions currently do not support list inputs or outputs.
TF_RETURN_IF_ERROR(ValidateNoListArguments(signature.input_arg(), "input",
signature.name()));
TF_RETURN_IF_ERROR(ValidateNoListArguments(signature.output_arg(), "output",
signature.name()));
if (fdef.attr().count(FunctionLibraryDefinition::kIntsOnDeviceAttr) != 0 &&
fdef.attr().at(FunctionLibraryDefinition::kIntsOnDeviceAttr).b()) {
return errors::Unimplemented(
"Function '", signature.name(), "' has `",
FunctionLibraryDefinition::kIntsOnDeviceAttr,
"` attribute set. This attribute is not currently supported by "
"multi-device functions.");
}
if (options.input_devices.size() != signature.input_arg_size()) {
return errors::InvalidArgument(
"InstantiateOptions.input_devices must have the same length "
"as the number of arguments: input_devices length = ",
options.input_devices.size(),
" number of arguments = ", signature.input_arg_size());
}
if (!options.output_devices.empty() &&
options.output_devices.size() != signature.output_arg_size()) {
return errors::InvalidArgument(
"InstantiateOptions.output_devices must either be empty or have the "
"same length as the number of arguments: output_devices length = ",
options.output_devices.size(),
" number of arguments = ", signature.output_arg_size());
}
return Status::OK();
}
Status GetGraphAndArgRets(
const string& function_name, AttrSlice attrs, const FunctionDef* fdef,
const FunctionLibraryDefinition* lib_def, std::unique_ptr<Graph>* graph,
std::vector<Node*>* arg_nodes, std::vector<Node*>* ret_nodes,
std::vector<string>* ret_node_names, DataTypeVector* ret_types,
std::vector<string>* control_ret_node_names) {
std::unique_ptr<FunctionBody> fbody;
// TODO(iga): FunctionDefToBodyHelper copies fdef. Avoid this copy.
TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fdef, attrs, lib_def, &fbody));
if (!fbody) {
LOG(ERROR) << "Failed to get FunctionBody for \"" << function_name << "\"";
return errors::Internal("Failed to construct FunctionBody for ",
function_name);
}
*graph = std::unique_ptr<Graph>(fbody->graph);
arg_nodes->reserve(fbody->arg_nodes.size());
std::copy(fbody->arg_nodes.begin(), fbody->arg_nodes.end(),
std::back_inserter(*arg_nodes));
ret_nodes->reserve(fbody->ret_nodes.size());
std::copy(fbody->ret_nodes.begin(), fbody->ret_nodes.end(),
std::back_inserter(*ret_nodes));
fbody->graph = nullptr;
ret_node_names->reserve(fbody->ret_nodes.size());
for (const Node* node : fbody->ret_nodes) {
ret_node_names->push_back(node->name());
}
for (const auto& ret_type : fbody->ret_types) {
ret_types->push_back(ret_type);
}
control_ret_node_names->reserve(fbody->control_ret_nodes.size());
for (const Node* node : fbody->control_ret_nodes) {
control_ret_node_names->push_back(node->name());
}
return Status::OK();
}
} // anonymous namespace
Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
const string& function_name, AttrSlice attrs,
const FunctionLibraryRuntime::InstantiateOptions& options,
FunctionLibraryRuntime::Handle* handle) {
// Check if this function has already been instantiated.
const string& function_key = Canonicalize(function_name, attrs, options);
{
mutex_lock l(mu_);
const auto& it = table_.find(function_key);
if (it != table_.end()) {
*handle = it->second;
++mdevice_data_[*handle]->instantiation_counter_;
return Status::OK();
}
}
VLOG(1) << "Instantiating MultiDevice function \"" << function_name
<< "\" on default device \"" << options.target << "\"";
if (VLOG_IS_ON(3)) {
int index = 0;
VLOG(3) << "Requested input devices:";
for (const string& device : options.input_devices) {
VLOG(3) << " [input " << index++ << "] " << device;
}
index = 0;
VLOG(3) << "Requested output devices:";
for (const string& device : options.output_devices) {
VLOG(3) << " [output " << index++ << "] " << device;
}
}
const FunctionLibraryDefinition* lib_def =
options.lib_def == nullptr ? lib_def_ : options.lib_def;
const FunctionDef* fdef = lib_def->Find(function_name);
if (fdef == nullptr) {
return errors::InvalidArgument("Failed to find function \"", function_name,
"\" in function library: ", lib_def);
}
TF_RETURN_IF_ERROR(ValidateMultiDeviceOptions(*fdef, options));
std::unique_ptr<Graph> graph;
std::vector<Node*> arg_nodes, ret_nodes;
std::vector<string> ret_node_names;
DataTypeVector ret_types;
std::vector<string> control_ret_node_names;
TF_RETURN_IF_ERROR(GetGraphAndArgRets(
function_name, attrs, fdef, lib_def, &graph, &arg_nodes, &ret_nodes,
&ret_node_names, &ret_types, &control_ret_node_names));
if (options.graph_collector != nullptr) {
GraphDef def;
graph->ToGraphDef(&def);
*def.mutable_library() = lib_def->ReachableDefinitions(def).ToProto();
options.graph_collector->CollectRawGraph(def);
}
Device* default_device = nullptr;
if (options.default_device_to_target && !options.target.empty()) {
// Make the `target` device the default device if nothing else is hard
// coded. This allows the same function definition to be specialized to
// different devices depending on the `PartitionedCallOp` device.
FunctionLibraryRuntime* flr = GetFLR(options.target);
if (flr == nullptr) {
return errors::InvalidArgument(
"Cannot instantiate multi-device function with target device ",
options.target);
}
default_device = flr->device();
}
TF_RETURN_IF_ERROR(
SetArgShape(options.input_resource_dtypes_and_shapes, arg_nodes));
TF_RETURN_IF_ERROR(PinArgsAndRets(
options.input_devices, options.output_devices, *device_set_, arg_nodes,
ret_nodes,
options.config_proto.allow_soft_placement() ? default_device : nullptr));
auto data = absl::make_unique<MultiDeviceFunctionData>(
function_name, function_key, ret_node_names.size(),
lib_def->ReachableDefinitions(*fdef), std::move(ret_types));
// Mapping from a function body node name to the control output name.
std::unordered_map<string, string> node_name_to_control_ret;
bool control_rets_updated = false;
TF_RETURN_IF_ERROR(FunctionOptimizationPassRegistry::Global().Run(
*device_set_, options.config_proto, &graph, &data->lib_def_,
&control_ret_node_names, &control_rets_updated));
if (control_rets_updated) {
// Function graph pass may have resulted in different nodes/node names for
// control rets.
for (const auto& control_ret : control_ret_node_names) {
node_name_to_control_ret.emplace(control_ret, control_ret);
}
} else {
for (const auto& control_ret : fdef->control_ret()) {
node_name_to_control_ret.emplace(control_ret.second, control_ret.first);
}
}
GraphOptimizationPassOptions optimization_options;
// TODO(iga): Thread other relevant options from SessionOptions.
SessionOptions session_options;
session_options.env = env_;
session_options.config = options.config_proto;
optimization_options.session_options = &session_options;
optimization_options.graph = &graph;
optimization_options.flib_def = &data->lib_def_;
optimization_options.device_set = device_set_.get();
optimization_options.is_function_graph = true;
DumpGraph("Before running PRE_PLACEMENT passes", graph.get());
TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
OptimizationPassRegistry::PRE_PLACEMENT, optimization_options));
// TODO(b/124993244): Smartly merge options in nested defuns, and raise
// exceptions/warnings in case where nested function call options are ignored.
DumpGraph("Before calling Placer", graph.get());
Placer placer(graph.get(), function_name, optimization_options.flib_def,
device_set_.get(), default_device,
options.config_proto.allow_soft_placement(),
options.config_proto.log_device_placement());
TF_RETURN_IF_ERROR(placer.Run());
DumpGraph("Before running POST_PLACEMENT passes", graph.get());
TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
OptimizationPassRegistry::POST_PLACEMENT, optimization_options));
Device* cpu_device;
TF_RETURN_IF_ERROR(device_mgr_->LookupDevice("CPU:0", &cpu_device));
if (options.optimize_graph_fn) {
DumpGraph("Before running graph optimization fn", graph.get());
Status status = options.optimize_graph_fn(
std::move(ret_node_names), std::move(control_ret_node_names),
&data->lib_def_, *device_set_, cpu_device, &graph);
if (!status.ok()) {
LOG(WARNING) << "Ignoring multi-device function optimization failure: "
<< status.ToString();
}
DumpGraph("After optimization", graph.get());
}
DumpGraph("Before running POST_REWRITE_FOR_EXEC passes", graph.get());
TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, optimization_options));
if (options.graph_collector != nullptr) {
GraphDef def;
graph->ToGraphDef(&def);
*def.mutable_library() = lib_def->ReachableDefinitions(def).ToProto();
options.graph_collector->CollectOptimizedGraph(def);
}
VLOG(4) << "Main function graph to be partitioned:";
VLOG(4) << DebugString(graph->ToGraphDefDebug());
std::unordered_map<string, std::unique_ptr<Graph>> subgraphs;
TF_RETURN_IF_ERROR(
PartitionFunctionGraph(*device_set_, std::move(graph), &subgraphs));
for (const auto& pair : subgraphs) {
DumpGraph(strings::StrCat("Before running POST_PARTITIONING passes (",
pair.first, ")"),
pair.second.get());
}
optimization_options.graph = nullptr;
optimization_options.device_set = nullptr;
optimization_options.partition_graphs = &subgraphs;
// Normally POST_PARTITIONING passes are run by distributed workers.
// Distributed workers are currently not supported in this code path, so we
// run the passes here.
TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
OptimizationPassRegistry::POST_PARTITIONING, optimization_options));
for (const auto& pair : subgraphs) {
const auto* optimized_subgraph = pair.second.get();
DumpGraph(
strings::StrCat("After all optimization passes (", pair.first, ")"),
optimized_subgraph);
if (VLOG_IS_ON(1)) {
DumpGraphDefToFile(
strings::StrCat("pflr_after_all_optimization_passes_",
reinterpret_cast<uintptr_t>(optimized_subgraph), "_",
pair.first),
optimized_subgraph->ToGraphDefDebug());
}
}
if (options.graph_collector != nullptr) {
for (const auto& pair : subgraphs) {
GraphDef def;
pair.second->ToGraphDef(&def);
*def.mutable_library() = lib_def->ReachableDefinitions(def).ToProto();
options.graph_collector->CollectPartitionedGraph(def);
}
}
// We must preserve control returns in each of the function components,
// otherwise after function inlining we might prune side-effectful nodes.
const auto control_ret =
[&node_name_to_control_ret](const Node* n) -> absl::optional<string> {
const auto it = node_name_to_control_ret.find(n->name());
return it != node_name_to_control_ret.end()
? absl::make_optional<string>(it->second)
: absl::nullopt;
};
int i = 0;
// Generate a random function_name to avoid one function reuse the partition
// function instantiated by another function.
FunctionLibraryDefinition* data_lib_def = &data->lib_def_;
FunctionNameGenerator name_generator(
data_lib_def, absl::StrCat(function_name, "_", random::New64()));
auto subgraph_size = subgraphs.size();
gtl::InlinedVector<Status, 4> instantiate_status(subgraph_size);
BlockingCounter counter(static_cast<int>(subgraph_size));
auto runner = [this, subgraph_size](std::function<void()> fn) {
// NOTE: Only use thread pool to instantiate sub-function when there are
// more than 8 sub-functions. We want to avoid cost of switching thread when
// there are only a few sub-functions.
if (default_thread_pool_ != nullptr && subgraph_size > 8) {
default_thread_pool_->Schedule(fn);
} else {
fn();
}
};
for (const auto& pair : subgraphs) {
Status* status = &instantiate_status[i];
string unique_name = name_generator.GetName();
ComponentFunctionData* comp_data = &data->glue_[pair.first];
runner([this, &pair, comp_data, unique_name, data_lib_def, &control_ret,
&options, status, &counter, &data] {
const string& target = pair.first;
const string& device_type =
device_set_->FindDeviceByName(target)->device_type();
Graph* subgraph = pair.second.get();
status->Update(UpdateArgAndRetvalMetadata(
subgraph, device_type, &comp_data->arg_indices_,
&comp_data->ret_indices_, &comp_data->arg_alloc_attrs_,
&comp_data->ret_alloc_attrs_));
if (!status->ok()) {
counter.DecrementCount();
return;
}
FunctionDef shard;
status->Update(
GraphToFunctionDef(*subgraph, unique_name, control_ret, &shard));
if (!status->ok()) {
counter.DecrementCount();
return;
}
status->Update(data_lib_def->AddFunctionDef(shard));
if (!status->ok()) {
counter.DecrementCount();
return;
}
FunctionLibraryRuntime::InstantiateOptions opts;
opts.executor_type = options.executor_type;
opts.target = target;
opts.lib_def = data_lib_def;
opts.create_kernels_eagerly = options.create_kernels_eagerly;
opts.state_handle = options.state_handle;
auto attrs = AttrSlice(&shard.attr());
VLOG(1) << "Start instantiating component function " << unique_name
<< " on device " << target;
VLOG(4) << DebugString(shard);
auto* component_handle = new FunctionLibraryRuntime::Handle;
auto done = [this, status, unique_name, comp_data, component_handle,
&data, &counter](const Status& s) {
status->Update(s);
VLOG(1) << "Finished instantiating component function " << unique_name
<< " with handle " << *component_handle << " status: " << s;
if (status->ok()) {
{
mutex_lock l(mu_);
if (function_data_[*component_handle]->is_cross_process()) {
data->is_cross_process_ = true;
}
}
comp_data->handle_ = *component_handle;
}
delete component_handle;
counter.DecrementCount();
};
FunctionLibraryRuntime* flr = GetFLR(opts.target);
if (flr != nullptr) {
// Initialize local function synchronously.
Status s = flr->Instantiate(unique_name, attrs, opts, component_handle);
done(s);
} else {
// Initialize remote function asynchronously.
InstantiateRemote(unique_name, attrs, opts, component_handle, done);
}
});
i += 1;
}
counter.Wait();
StatusGroup group;
for (auto& status : instantiate_status) {
group.Update(status);
}
TF_RETURN_IF_ERROR(group.as_summary_status());
*handle = AddMultiDeviceHandle(std::move(data), function_key);
VLOG(2) << "Instantiated MultiDevice function \"" << function_name
<< "\" with handle " << *handle;
return Status::OK();
}
Status ProcessFunctionLibraryRuntime::GetOutputDevices(
FunctionLibraryRuntime::Handle handle,
std::vector<Device*>* output_devices) const {
const MultiDeviceFunctionData* data = IsMultiDevice(handle);
if (data == nullptr) {
return errors::InvalidArgument(
"Failed for find multi-device function handle ", handle);
}
for (const auto& pair : data->glue_) {
const ComponentFunctionData& comp_data = pair.second;
DCHECK(comp_data.ret_alloc_attrs_.size() == comp_data.ret_indices_.size());
const string& target = pair.first;
FunctionLibraryRuntime* target_flr = GetFLR(target);
if (target_flr == nullptr) {
if (!comp_data.ret_indices_.empty()) {
return errors::Unimplemented(
"Currently, outputting tensors on remote devices is not supported. "
"The ",
comp_data.ret_indices_[0],
"-th return value of the function outputs to target_device: ",
target,
" Please copy the tensor to local device explicitly using "
"tf.identity and return the new Tensor instead.");
}
continue;
}
Device* target_device = target_flr->device();
const FunctionBody* fbody = target_flr->GetFunctionBody(comp_data.handle_);
DCHECK(fbody != nullptr);
output_devices->resize(data->num_outputs_);
for (int j = 0; j < comp_data.ret_indices_.size(); ++j) {
int ret_index = comp_data.ret_indices_[j];
if (fbody->ret_types[j] == DT_RESOURCE) {
(*output_devices)[ret_index] = target_device;
} else {
(*output_devices)[ret_index] =
comp_data.ret_alloc_attrs_[j].on_host() ? nullptr : target_device;
}
}
}
return Status::OK();
}
void ProcessFunctionLibraryRuntime::RunRemoteDevice(
const FunctionLibraryRuntime::Options& opts,
FunctionLibraryRuntime::Handle local_handle,
gtl::ArraySlice<FunctionArg> args, std::vector<Tensor>* rets,
FunctionLibraryRuntime::DoneCallback done) const {
parent_->Run(opts, local_handle, GetLocalArgs(args), rets, std::move(done));
}
void ProcessFunctionLibraryRuntime::RunMultiDevice(
const FunctionLibraryRuntime::Options& opts,
FunctionLibraryRuntime::Handle handle, std::vector<Tensor>* rets,
std::vector<std::unique_ptr<CleanUpItem>>* cleanup_items,
FunctionLibraryRuntime::DoneCallback done,
std::function<Status(const ComponentFunctionData& comp_data,
InternalArgs* args)>
get_component_args) const {
if (opts.create_rendezvous) {
// FLR->Run() is the default entry point. It checks for cancellation,
// creates rendezvous, etc.
// Letting create_rendezvous through will do the wrong thing - each
// component function will get a separate rendezvous created by its FLR.
done(
errors::Internal("Cannot call ProcessFunctionLibraryRuntime::Run with "
"create_rendezvous=true. Please run the function "
"using FunctionLibraryRuntime::Run"));
return;
}
const MultiDeviceFunctionData* data = IsMultiDevice(handle);
if (data == nullptr) {
done(
errors::InvalidArgument("Failed for find multi-device function handle ",
handle, ". Was the function instantiated?"));
return;
}
VLOG(1) << "Running multi-device function " << data->function_name_;
VLOG(4) << " with " << opts.DebugString();
if (data->glue_.empty()) {
// Trivial case where the function body is empty.
done(Status::OK());
return;
}
// Check whether we have the right rendezvous.
if (opts.rendezvous && data->is_cross_process_ &&
!opts.rendezvous->is_cross_process()) {
done(errors::InvalidArgument(
"Running a cross process function ", data->function_name_,
" without an appropriate cross process Rendezvous."));
return;
}
auto* refcounted_done = new ReffedStatusCallback(std::move(done));
for (int i = 0; i < data->glue_.size(); ++i) {
refcounted_done->Ref();
}
FunctionLibraryRuntime::Options opts_copy = opts;
for (const auto& pair : data->glue_) {
const string& target = pair.first;
const ComponentFunctionData& comp_data = pair.second;
FunctionLibraryRuntime::Handle handle = pair.second.handle_;
opts_copy.args_alloc_attrs = comp_data.arg_alloc_attrs_;
opts_copy.rets_alloc_attrs = comp_data.ret_alloc_attrs_;
opts_copy.remote_execution = false;
InternalArgs comp_args;
Status s = get_component_args(comp_data, &comp_args);
if (!s.ok()) {
VLOG(2) << "Failed to get component function arguments: " << s;
refcounted_done->UpdateStatus(s);
refcounted_done->Unref();
continue;
}
std::vector<Tensor>* comp_rets = new std::vector<Tensor>;
rets->resize(data->num_outputs_);
FunctionLibraryRuntime* flr = GetFLR(target);
if (flr != nullptr) {
// When target device has private thread pool, use the target device
// runner
thread::ThreadPool* pool = flr->device()->tensorflow_device_thread_pool();
opts_copy.runner = (pool == nullptr) ? opts_copy.runner : flr->runner();
VLOG(1) << "Running component function on device " << target
<< " with handle " << handle;
VLOG(4) << " with " << opts_copy.DebugString();
flr->Run(opts_copy, handle, GetLocalArgs(comp_args.args), comp_rets,
[comp_rets, rets, comp_data, refcounted_done,
data](const Status& status) {
if (!status.ok()) {
VLOG(2) << "Component function execution failed: " << status;
const string function_and_msg = strings::StrCat(
errors::FormatFunctionForError(data->function_name_),
" ", status.error_message());
refcounted_done->UpdateStatus(
Status(status.code(), function_and_msg));
} else {
for (int i = 0; i < comp_rets->size(); ++i) {
(*rets)[comp_data.ret_indices_[i]] = (*comp_rets)[i];
}
}
delete comp_rets;
// refcounted_done is thread-safe
refcounted_done->Unref();
});
} else {
opts_copy.remote_execution = true;
VLOG(1) << "Running component function on device " << target
<< " with handle " << handle;
VLOG(4) << " with " << opts_copy.DebugString();
RunInternal(
opts_copy, handle, comp_args.args, comp_rets, cleanup_items,
[comp_rets, rets, comp_data, refcounted_done](const Status& status) {
if (!status.ok()) {
VLOG(2) << "Component function execution failed: " << status;
refcounted_done->UpdateStatus(status);
} else {
for (int i = 0; i < comp_rets->size(); ++i) {
(*rets)[comp_data.ret_indices_[i]] = (*comp_rets)[i];
}
}
delete comp_rets;
// refcounted_done is thread-safe
refcounted_done->Unref();
});
}
}
refcounted_done->Unref();
}
Status ProcessFunctionLibraryRuntime::Instantiate(
const string& function_name, AttrSlice attrs,
const FunctionLibraryRuntime::InstantiateOptions& options,
FunctionLibraryRuntime::Handle* handle) {
if (options.is_multi_device_function) {
return InstantiateMultiDevice(function_name, attrs, options, handle);
}
*handle = kInvalidHandle;
FunctionLibraryRuntime* flr = GetFLR(options.target);
if (flr != nullptr) {
return flr->Instantiate(function_name, attrs, options, handle);
}
Status status;
Notification notification;
InstantiateRemote(function_name, attrs, options, handle,
[&status, &notification](const Status& s) {
status = s;
notification.Notify();
});
notification.WaitForNotification();
return status;
}
Status ProcessFunctionLibraryRuntime::IsCrossProcess(
FunctionLibraryRuntime::Handle handle, bool* is_cross_process) const {
tf_shared_lock l(mu_);
const auto& mdevice_it = mdevice_data_.find(handle);
if (mdevice_it != mdevice_data_.end()) {
*is_cross_process = mdevice_it->second->is_cross_process_;
return Status::OK();
}
const auto& it = function_data_.find(handle);
if (it != function_data_.end()) {
*is_cross_process = it->second->is_cross_process();
return Status::OK();
}
return errors::InvalidArgument("Handle ", handle, " not found.");
}
void ProcessFunctionLibraryRuntime::InstantiateRemote(
const string& function_name, AttrSlice attrs,
const FunctionLibraryRuntime::InstantiateOptions& options,
FunctionLibraryRuntime::Handle* handle,
FunctionLibraryRuntime::DoneCallback done) {
if (parent_ == nullptr) {
done(errors::Internal(
"Currently don't support instantiating functions on device: ",
options.target));
return;
}
auto target = options.target;
VLOG(1) << "ProcessFLR Instantiate: " << function_name << " on: " << target;
string function_key = Canonicalize(function_name, attrs, options);
FunctionData* f;
{
mutex_lock l(mu_);
FunctionLibraryRuntime::Handle h =
gtl::FindWithDefault(table_, function_key, kInvalidHandle);
if (h == kInvalidHandle || function_data_.count(h) == 0) {
h = AddHandleLocked(function_key, target, kInvalidHandle);
}
f = function_data_[h].get();
*handle = h;
}
f->DistributedInit(
parent_, function_name,
options.lib_def == nullptr ? *lib_def_ : *options.lib_def, attrs, options,
[this, function_name, target, handle, done](const Status& s) {
VLOG(1) << "ProcessFLR Instantiate [success]: " << function_name
<< " on: " << target << " with handle: " << *handle
<< " (this: " << this << ")";
done(s);
});
}
Status ProcessFunctionLibraryRuntime::RemoveHandle(
FunctionLibraryRuntime::Handle handle) {
mutex_lock l(mu_);
table_.erase(function_data_[handle]->function_key());
function_data_.erase(handle);
return Status::OK();
}
Status ProcessFunctionLibraryRuntime::ReleaseMultiDeviceHandle(
FunctionLibraryRuntime::Handle handle) {
std::unique_ptr<MultiDeviceFunctionData> mdata;
{
mutex_lock l(mu_);
auto it = mdevice_data_.find(handle);
--it->second->instantiation_counter_;
if (it->second->instantiation_counter_ != 0) {
return Status::OK();
}
mdata = std::move(it->second);
table_.erase(mdata->function_key_);
mdevice_data_.erase(it);
}
// If we are here we are releasing the last instantiation of `handle`.
// Release all component function handles.
Status overall_status;
for (const auto& it : mdata->glue_) {
const string& device = it.first;
FunctionLibraryRuntime::Handle flr_handle = it.second.handle_;
FunctionLibraryRuntime* flr = GetFLR(device);
if (flr == nullptr) {
// TODO(nareshmodi): Implement DeregisterGraph call to remote device if
// parent is not null.
if (parent_ != nullptr) {
return errors::Unimplemented(
"Releasing a multi-device component handle on a remote device is "
"not yet implemented.");
}
return errors::InvalidArgument(
"Failed to find FunctionLibraryRuntime for device ", device,
" when releasing multi-device function handle ", handle);
}
Status status = flr->ReleaseHandle(flr_handle);
if (!status.ok()) {
overall_status = status;
}
}
return overall_status;
}
Status ProcessFunctionLibraryRuntime::ReleaseHandle(
FunctionLibraryRuntime::Handle handle) {
// Return directly if all function handles has already been released.
if (flr_map_ == nullptr) return Status::OK();
if (IsMultiDevice(handle)) {
return ReleaseMultiDeviceHandle(handle);
}
FunctionLibraryRuntime* flr = nullptr;
string target_device;
{
mutex_lock l(mu_);
CHECK_EQ(1, function_data_.count(handle)) << " handle: " << handle;
target_device = function_data_[handle]->target_device();
}
flr = GetFLR(target_device);
if (flr != nullptr) {
return flr->ReleaseHandle(handle);
}
return errors::InvalidArgument("Handle not found: ", handle);
}
FunctionLibraryRuntime::DoneCallback
ProcessFunctionLibraryRuntime::ApplyCleanUpToDoneCallback(
std::vector<std::unique_ptr<CleanUpItem>>* items,
FunctionLibraryRuntime::DoneCallback done, const int64 step_id,
const Rendezvous* created_rendezvous) const {
return
[this, items, done = std::move(done), step_id,
created_rendezvous](const Status& status) {
if (created_rendezvous) {
DCHECK(rendezvous_factory_);
created_rendezvous->Unref();
Status s = rendezvous_factory_.CleanUp(step_id);
if (!s.ok()) {
LOG(ERROR) << s;
}
}
auto* local_status = new Status(status);
CleanUp(items, [local_status, done](const Status& cleanup_status) {
local_status->Update(cleanup_status);
done(*local_status);
delete local_status;
});
delete items;
};
}
void ProcessFunctionLibraryRuntime::Run(
const FunctionLibraryRuntime::Options& opts,
FunctionLibraryRuntime::Handle handle, gtl::ArraySlice<Tensor> args,
std::vector<Tensor>* rets,
FunctionLibraryRuntime::DoneCallback done) const {
FunctionLibraryRuntime::Options new_opts = opts;
Rendezvous* created_rendezvous = nullptr;
if (!opts.rendezvous) {
if (rendezvous_factory_) {
Status s =
rendezvous_factory_(opts.step_id, device_mgr_, &created_rendezvous);
if (!s.ok()) {
done(s);
return;
}
new_opts.rendezvous = created_rendezvous;
} else {
done(
errors::FailedPrecondition("The caller does not provide a rendezvous "
"and ProcessFunctionLibraryRuntime was "
"created without a rendezvous factory."));
return;
}
new_opts.create_rendezvous = false;
}
auto* cleanup_items = new std::vector<std::unique_ptr<CleanUpItem>>;
done = ApplyCleanUpToDoneCallback(cleanup_items, std::move(done),
new_opts.step_id, created_rendezvous);
bool multi_device;
{
tf_shared_lock l(mu_);
multi_device = mdevice_data_.find(handle) != mdevice_data_.end();
}
if (multi_device) {
auto get_component_args = [&args](const ComponentFunctionData& comp_data,
InternalArgs* comp_args) -> Status {
for (const auto& tensor :
GetArgsForIndices(comp_data.arg_indices_, args)) {
comp_args->args.push_back(tensor);
}
return Status::OK();
};
return RunMultiDevice(new_opts, handle, rets, cleanup_items,
std::move(done), std::move(get_component_args));
}
std::vector<FunctionArg> local_args;
for (const auto& tensor : args) {
local_args.push_back(tensor);
}
RunInternal(new_opts, handle, local_args, rets, cleanup_items,
std::move(done));
}
void ProcessFunctionLibraryRuntime::RunInternal(
const FunctionLibraryRuntime::Options& opts,
FunctionLibraryRuntime::Handle handle, gtl::ArraySlice<FunctionArg> args,
std::vector<Tensor>* rets,
std::vector<std::unique_ptr<CleanUpItem>>* cleanup_items,
FunctionLibraryRuntime::DoneCallback done) const {
FunctionLibraryRuntime* flr = nullptr;
string target_device;
FunctionLibraryRuntime::LocalHandle local_handle;
{
tf_shared_lock l(mu_);
auto iter = function_data_.find(handle);
if (iter == function_data_.end()) {
done(errors::NotFound("Handle: ", handle, " not found."));
return;
}
FunctionData* function_data = iter->second.get();
target_device = function_data->target_device();
local_handle = function_data->local_handle();
}
if (!opts.remote_execution) {
done(
errors::InvalidArgument("ProcessFunctionLibraryRuntime::Run should "
"only be called for multi-device functions or "
"for remote execution."));
return;
}
flr = GetFLR(target_device);
if (flr != nullptr) {
auto rendezvous = opts.rendezvous;
string source_device = opts.source_device;
DeviceContext* device_context;
Status s = GetDeviceContext(source_device, &device_context);
if (!s.ok()) {
done(s);
return;
}
int64 src_incarnation, target_incarnation;
s = GetDeviceIncarnation(source_device, &src_incarnation);
s.Update(GetDeviceIncarnation(target_device, &target_incarnation));
if (!s.ok()) {
done(s);
return;
}
std::vector<Tensor> local_args = GetLocalArgs(args);
// Send the args over to the target device.
s = SendTensors(source_device, target_device, "arg_", src_incarnation,
local_args, device_context, opts.args_alloc_attrs,
rendezvous);
if (!s.ok()) {
done(s);
return;
}
const std::vector<AllocatorAttributes>& rets_alloc_attrs =
opts.rets_alloc_attrs;
std::vector<Tensor>* remote_rets = new std::vector<Tensor>;
flr->Run(opts, handle, local_args, remote_rets,
[source_device, target_device, target_incarnation, rendezvous,
device_context, rets_alloc_attrs, remote_rets, rets,
done = std::move(done)](const Status& status) mutable {
if (!status.ok()) {
delete remote_rets;
done(status);
return;
}
int64 num_returns = remote_rets->size();
delete remote_rets;
// Now receive the return values from the target.
ReceiveTensorsAsync(target_device, source_device, "ret_",
target_incarnation, num_returns,
device_context, rets_alloc_attrs, rendezvous,
rets, std::move(done));
});
return;
}
if (parent_ != nullptr) {
auto cleanup_item = absl::make_unique<CleanUpItem>();
cleanup_item->device = target_device;
cleanup_item->step_id = opts.step_id;
cleanup_item->local_handle = local_handle;
cleanup_items->emplace_back(std::move(cleanup_item));
RunRemoteDevice(opts, local_handle, args, rets, std::move(done));
return;
}
done(errors::Internal("Could not find device"));
}
void ProcessFunctionLibraryRuntime::Run(
const FunctionLibraryRuntime::Options& opts,
FunctionLibraryRuntime::Handle handle, CallFrameInterface* frame,
FunctionLibraryRuntime::DoneCallback done) const {
std::vector<Tensor> args;
args.reserve(frame->num_args());
for (size_t i = 0; i < frame->num_args(); ++i) {
Tensor arg;
Status s = frame->GetArg(i, &arg);
args.push_back(std::move(arg));
if (!s.ok()) {
done(s);
}
}
std::vector<Tensor>* rets = new std::vector<Tensor>;
rets->reserve(frame->num_retvals());
Run(opts, handle, args, rets,
[frame, rets, done = std::move(done)](const Status& status) {
std::unique_ptr<std::vector<Tensor>> rets_releaser(rets);
if (!status.ok()) {
done(status);
return;
}
if (rets->size() != frame->num_retvals()) {
done(errors::Internal(
"Number of return values from function (", rets->size(),
") did not match expected number of return values (",
frame->num_retvals(), ")."));
return;
}
for (size_t i = 0; i < frame->num_retvals(); ++i) {
Status s = frame->SetRetval(i, (*rets)[i]);
if (!s.ok()) {
done(s);
return;
}
}
done(Status::OK());
});
}
void ProcessFunctionLibraryRuntime::Run(
const FunctionLibraryRuntime::Options& opts,
FunctionLibraryRuntime::Handle handle, const FunctionArgsInterface& args,
std::vector<Tensor>* rets,
FunctionLibraryRuntime::DoneCallback done) const {
const std::vector<Tensor> lcoal_inputs = args.GetLocalTensors();
Run(opts, handle, lcoal_inputs, rets, std::move(done));
}
void ProcessFunctionLibraryRuntime::CleanUp(
std::vector<std::unique_ptr<CleanUpItem>>* items,
FunctionLibraryRuntime::DoneCallback done) const {
auto* refcounted_done = new ReffedStatusCallback(std::move(done));
for (auto& item : *items) {
refcounted_done->Ref();
auto* flr = GetFLR(item->device);
if (flr != nullptr) {
// TODO(fishx): cleanup state for local execution.
refcounted_done->UpdateStatus(
errors::Internal("Cleanup items shouldn't contain local item."));
refcounted_done->Unref();
} else if (parent_ != nullptr) {
parent_->CleanUp(item->step_id, item->local_handle,
[refcounted_done](const Status& status) {
if (!status.ok()) {
refcounted_done->UpdateStatus(status);
}
// refcounted_done is thread-safe
refcounted_done->Unref();
});
} else {
refcounted_done->UpdateStatus(
errors::Internal("Could not find device in cleanup."));
refcounted_done->Unref();
}
}
refcounted_done->Unref();
}
Status ProcessFunctionLibraryRuntime::Clone(
Env* env, int graph_def_version, const OptimizerOptions& optimizer_options,
const CustomKernelCreator* custom_kernel_creator,
std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr,
bool skip_flib_def) const {
if (skip_flib_def) {
*out_lib_def = absl::make_unique<FunctionLibraryDefinition>(
lib_def_->default_registry(), FunctionDefLibrary{});
} else {
*out_lib_def = absl::make_unique<FunctionLibraryDefinition>(*lib_def_);
}
*out_pflr = absl::make_unique<ProcessFunctionLibraryRuntime>(
device_mgr_, env, config_ ? &(*config_) : nullptr, graph_def_version,
out_lib_def->get(), optimizer_options, default_thread_pool_, parent_,
custom_kernel_creator, session_metadata_, rendezvous_factory_);
return Status::OK();
}
} // namespace tensorflow