blob: cf92ee5f3af39d7227ab7d7c5737d9517bc6ff90 [file] [log] [blame]
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/captured_function.h"
#include <utility>
#include "absl/time/clock.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function_handle_cache.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/stats_aggregator.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/kernels/data/stats_utils.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/notification.h"
namespace tensorflow {
namespace data {
namespace {
// Simplistic implementation of the `StepStatsCollectorInterface` that only
// cares about collecting the CPU time needed to execute a captured function.
class SimpleStepStatsCollector : public StepStatsCollectorInterface {
public:
void IncrementProcessingTime(int64 delta) {
mutex_lock l(mu_);
processing_time_ += delta;
}
NodeExecStatsInterface* CreateNodeExecStats(const NodeDef* node) override {
return new SimpleNodeExecStats(this);
}
string ReportAllocsOnResourceExhausted(const string& err) override {
return "";
}
int64 processing_time() {
tf_shared_lock l(mu_);
return processing_time_;
}
private:
class SimpleNodeExecStats : public NodeExecStatsInterface {
public:
explicit SimpleNodeExecStats(SimpleStepStatsCollector* step_stats_collector)
: step_stats_collector_(step_stats_collector) {}
void Done(const string& device) override {
step_stats_collector_->IncrementProcessingTime(end_time_ns_ -
start_time_ns_);
delete this;
}
void RecordExecutorStarted() override {
start_time_ns_ = absl::GetCurrentTimeNanos();
}
void RecordComputeStarted() override {}
void RecordComputeEnded() override {}
void RecordExecutorEnded() override {
end_time_ns_ = absl::GetCurrentTimeNanos();
}
bool TrackAllocations() const override { return false; }
void SetMemory(OpKernelContext* ctx) override {}
void SetOutput(int slot, const Tensor* tensor) override {}
void SetReferencedTensors(const TensorReferenceVector& tensors) override {}
void SetScheduled(int64 nanos) override {}
private:
int64 start_time_ns_ = 0;
int64 end_time_ns_ = 0;
SimpleStepStatsCollector* step_stats_collector_; // Not owned.
};
mutex mu_;
int64 processing_time_ GUARDED_BY(mu_) = 0;
};
Status RunShortCircuit(const ShortCircuitInfo& info,
const std::vector<Tensor>& args,
const CapturedFunction* const func,
std::vector<Tensor>* rets) {
VLOG(3) << "Running function " << func->func().name() << " short circuit";
size_t num_args = args.size();
for (size_t i = 0; i < info.indices.size(); ++i) {
if (info.indices[i] < num_args) {
rets->push_back(args[info.indices[i]]);
} else {
rets->push_back(func->captured_inputs()[info.indices[i] - num_args]);
}
}
return Status::OK();
}
Status RunShortCircuit(const ShortCircuitInfo& info, std::vector<Tensor>&& args,
const CapturedFunction* const func,
std::vector<Tensor>* rets) {
VLOG(3) << "Running function " << func->func().name() << " short circuit";
size_t num_args = args.size();
for (size_t i = 0; i < info.indices.size(); ++i) {
if (info.indices[i] < num_args) {
if (info.can_move[i]) {
rets->push_back(std::move(args[info.indices[i]]));
} else {
rets->push_back(args[info.indices[i]]);
}
} else {
rets->push_back(func->captured_inputs()[info.indices[i] - num_args]);
}
}
return Status::OK();
}
Status CreateShortCircuitInfo(OpKernelConstruction* ctx,
const NameAttrList& func,
ShortCircuitInfo* info) {
auto& indices = info->indices;
FunctionLibraryRuntime::Handle fn_handle;
TF_RETURN_IF_ERROR(ctx->function_library()->Instantiate(
func.name(), AttrSlice(&func.attr()), &fn_handle));
auto cleanup = gtl::MakeCleanup([ctx, fn_handle]() {
Status s = ctx->function_library()->ReleaseHandle(fn_handle);
if (!s.ok()) {
LOG(WARNING) << "Failed to release handle: " << s.error_message();
}
});
// If the function contains any stateful operations, we conservatively execute
// the entire function.
if (ctx->function_library()->IsStateful(func.name())) {
return Status::OK();
}
const FunctionBody* fn_body =
ctx->function_library()->GetFunctionBody(fn_handle);
indices.resize(fn_body->ret_nodes.size());
for (size_t i = 0; i < fn_body->ret_nodes.size(); ++i) {
Node* ret_node = fn_body->ret_nodes[i];
Node* ret_input_node;
TF_RETURN_IF_ERROR(ret_node->input_node(0, &ret_input_node));
while (ret_input_node->def().op() == "Identity") {
TF_RETURN_IF_ERROR(ret_input_node->input_node(0, &ret_input_node));
}
if (ret_input_node->def().op() == FunctionLibraryDefinition::kArgOp) {
TF_RETURN_IF_ERROR(
GetNodeAttr(ret_input_node->def(), "index", &(indices[i])));
} else {
indices.clear();
break;
}
}
// Compute the `can_move` vector.
if (!indices.empty()) {
auto& can_move = info->can_move;
std::map<int, int> last_use;
for (size_t i = 0; i < indices.size(); ++i) {
last_use[indices[i]] = i;
}
can_move.resize(indices.size());
for (size_t i = 0; i < indices.size(); ++i) {
can_move[i] = last_use[indices[i]] == i;
}
}
return Status::OK();
}
Status CreateFunctionLibraryDefinition(
const FunctionLibraryDefinition* lib_def, const string& func_name,
std::unique_ptr<FunctionLibraryDefinition>* result) {
DCHECK(lib_def != nullptr);
const FunctionDef* fdef = lib_def->Find(func_name);
if (TF_PREDICT_FALSE(fdef == nullptr)) {
return errors::FailedPrecondition(strings::StrCat(
"Could not find required function definition ", func_name));
}
*result = absl::make_unique<FunctionLibraryDefinition>(
lib_def->ReachableDefinitions(*fdef));
return (*result)->CopyFunctionDefFrom(func_name, *lib_def);
}
Status IsFunctionStateful(const FunctionLibraryDefinition& library,
const FunctionDef& function_def) {
if (!function_def.signature().is_stateful()) {
return Status::OK();
}
for (const NodeDef& node_def : function_def.node_def()) {
TF_RETURN_IF_ERROR(IsNodeStateful(library, node_def));
}
return Status::OK();
}
// Returns whether an op has been whitelisted as stateless. Uses a heuristic to
// whitelist source dataset ops which have been marked stateful due to
// b/65524810. Also looks up the `op_def->name` in the global
// `WhitelistedStatefulOpRegistry`.
bool IsOpWhitelisted(const OpDef* op_def) {
return (op_def->output_arg_size() == 1 &&
op_def->output_arg(0).type() == DT_VARIANT &&
(absl::EndsWith(op_def->name(), "Dataset") ||
absl::EndsWith(op_def->name(), "DatasetV2"))) ||
WhitelistedStatefulOpRegistry::Global()->Contains(op_def->name());
}
Status LookupFunction(const FunctionLibraryDefinition& lib_def,
const string& name, const FunctionDef** fdef) {
*fdef = lib_def.Find(name);
if (*fdef == nullptr) {
return errors::InvalidArgument(
"Failed to find function ", name,
" in function library: ", lib_def.ToProto().DebugString());
}
return Status::OK();
}
class CallFrameBase : public CallFrameInterface {
public:
explicit CallFrameBase(DataTypeSlice ret_types)
: ret_types_(ret_types), retvals_(ret_types.size()) {}
// Caller methods.
Status ConsumeRetvals(std::vector<Tensor>* retvals) {
retvals->reserve(retvals_.size());
int i = 0;
for (auto&& val : retvals_) {
if (!val) {
return errors::Internal("No return value for index ", i, ".");
}
retvals->emplace_back(std::move(val.value()));
++i;
}
return Status::OK();
}
size_t num_retvals() const override { return retvals_.size(); }
// Callee methods.
Status SetRetval(int index, const Tensor& val) override {
if (index < retvals_.size() && val.dtype() == ret_types_[index] &&
!retvals_[index]) {
retvals_[index] = val;
return Status::OK();
} else if (index >= retvals_.size()) {
return errors::InvalidArgument("Return value ", index,
" is out of range.");
} else if (val.dtype() != ret_types_[index]) {
return errors::InvalidArgument("Expected type ",
DataTypeString(ret_types_[index]),
" for return value ", index, " but got ",
DataTypeString(val.dtype()), ".");
} else {
return errors::Internal("Attempted to set return value ", index,
" more than once.");
}
}
private:
DataTypeSlice ret_types_;
std::vector<gtl::optional<Tensor>> retvals_;
TF_DISALLOW_COPY_AND_ASSIGN(CallFrameBase);
};
class OwnedArgsCallFrame : public CallFrameBase {
public:
OwnedArgsCallFrame(std::vector<Tensor>&& args,
const std::vector<Tensor>* captured_inputs,
DataTypeSlice ret_types)
: CallFrameBase(ret_types),
args_(std::move(args)),
captured_inputs_(captured_inputs) {}
size_t num_args() const override {
return args_.size() + captured_inputs_->size();
}
// Callee methods.
Status GetArg(int index, Tensor* val) const override {
if (index < args_.size()) {
// TODO(mrry): Consider making `CallFrameInterface::GetArg` non-const in
// order to be able to `std::move(args_[index])` into `*val`.
*val = args_[index];
return Status::OK();
} else if (index < args_.size() + captured_inputs_->size()) {
*val = (*captured_inputs_)[index - args_.size()];
return Status::OK();
} else {
return errors::InvalidArgument("Argument ", index, " is out of range.");
}
}
private:
std::vector<Tensor> args_;
const std::vector<Tensor>* const captured_inputs_; // Not owned.
};
class BorrowedArgsCallFrame : public CallFrameBase {
public:
BorrowedArgsCallFrame(const std::vector<Tensor>& args,
const std::vector<Tensor>* captured_inputs,
DataTypeSlice ret_types)
: CallFrameBase(ret_types),
args_(args),
captured_inputs_(captured_inputs) {}
size_t num_args() const override {
return args_.size() + captured_inputs_->size();
}
// Callee methods.
Status GetArg(int index, Tensor* val) const override {
if (index < args_.size()) {
*val = args_[index];
return Status::OK();
} else if (index < args_.size() + captured_inputs_->size()) {
*val = (*captured_inputs_)[index - args_.size()];
return Status::OK();
} else {
return errors::InvalidArgument("Argument ", index, " is out of range.");
}
}
private:
const std::vector<Tensor>& args_; // Not owned.
const std::vector<Tensor>* const captured_inputs_; // Not owned.
};
} // namespace
Status IsNodeStateful(const FunctionLibraryDefinition& library,
const NodeDef& node) {
const OpDef* op_def;
// TODO(jsimsa): Fix C++ unit tests so that we do not have to ignore
// `LookUpOpDef` errors here.
if (!OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok() ||
IsOpWhitelisted(op_def) || !op_def->is_stateful() ||
op_def->name() == "Assert") {
return Status::OK();
}
if (op_def->name() == "If") {
const FunctionDef* then_func =
library.Find(node.attr().at("then_branch").func().name());
const FunctionDef* else_func =
library.Find(node.attr().at("else_branch").func().name());
if (then_func != nullptr) {
TF_RETURN_IF_ERROR(IsFunctionStateful(library, *then_func));
}
if (else_func != nullptr) {
TF_RETURN_IF_ERROR(IsFunctionStateful(library, *else_func));
}
return Status::OK();
}
if (op_def->name() == "While") {
const FunctionDef* cond_func =
library.Find(node.attr().at("cond").func().name());
const FunctionDef* body_func =
library.Find(node.attr().at("body").func().name());
if (cond_func != nullptr) {
TF_RETURN_IF_ERROR(IsFunctionStateful(library, *cond_func));
}
if (body_func != nullptr) {
TF_RETURN_IF_ERROR(IsFunctionStateful(library, *body_func));
}
return Status::OK();
}
return errors::FailedPrecondition(op_def->name(), " is stateful.");
}
Status MakeIteratorFromInputElement(
IteratorContext* ctx, const std::vector<Tensor>& input_element,
int64 thread_index, const InstantiatedCapturedFunction& inst_captured_func,
StringPiece prefix, std::unique_ptr<IteratorBase>* out_iterator) {
std::vector<Tensor> return_values;
TF_RETURN_IF_ERROR(inst_captured_func.RunWithBorrowedArgs(ctx, input_element,
&return_values));
if (!(return_values.size() == 1 && return_values[0].dtype() == DT_VARIANT &&
TensorShapeUtils::IsScalar(return_values[0].shape()))) {
return errors::InvalidArgument(
"Function must return a single scalar of dtype DT_VARIANT.");
}
// Retrieve the dataset that was created in `f`.
DatasetBase* returned_dataset;
TF_RETURN_IF_ERROR(
GetDatasetFromVariantTensor(return_values[0], &returned_dataset));
// Create an iterator for the dataset that was returned by `f`.
return returned_dataset->MakeIterator(
ctx, strings::StrCat(prefix, "[", thread_index, "]"), out_iterator);
}
/* static */
Status FunctionMetadata::Create(
OpKernelConstruction* ctx, const string& func_name, Params params,
std::shared_ptr<FunctionMetadata>* out_metadata) {
NameAttrList func;
TF_RETURN_IF_ERROR(ctx->GetAttr(func_name, &func));
return Create(ctx, std::move(func), params, out_metadata);
}
Status FunctionMetadata::Create(
OpKernelConstruction* ctx, NameAttrList&& func, Params params,
std::shared_ptr<FunctionMetadata>* out_metadata) {
out_metadata->reset(new FunctionMetadata(std::move(func), params));
TF_RETURN_IF_ERROR(CreateFunctionLibraryDefinition(
ctx->function_library()->GetFunctionLibraryDefinition(),
(*out_metadata)->func_.name(), &(*out_metadata)->lib_def_));
TF_RETURN_IF_ERROR(CreateShortCircuitInfo(
ctx, (*out_metadata)->func_, &(*out_metadata)->short_circuit_info_));
(*out_metadata)->ValidateMultiDevice();
return Status::OK();
}
void FunctionMetadata::ValidateMultiDevice() {
const FunctionDef* fdef = lib_def_->Find(func_.name());
if (is_multi_device_function_) {
auto attr = fdef->attr().find(FunctionLibraryDefinition::kIntsOnDeviceAttr);
if (attr != fdef->attr().end() && attr->second.b()) {
LOG(WARNING)
<< "Disabling multi-device execution for a function that uses the "
<< FunctionLibraryDefinition::kIntsOnDeviceAttr << " attribute.";
is_multi_device_function_ = false;
return;
}
auto validate_arg = [this](const OpDef::ArgDef& arg) {
if (!arg.number_attr().empty() || !arg.type_list_attr().empty()) {
LOG(WARNING) << "Disabling multi-device execution for a function with "
"a vector argument "
<< arg.name() << ".";
is_multi_device_function_ = false;
}
};
for (const auto& arg : fdef->signature().input_arg()) {
validate_arg(arg);
}
for (const auto& arg : fdef->signature().output_arg()) {
validate_arg(arg);
}
}
}
/* static */
Status CapturedFunction::Create(
OpKernelContext* ctx,
const std::shared_ptr<const FunctionMetadata> metadata,
const string& argument_name,
std::unique_ptr<CapturedFunction>* out_function) {
OpInputList inputs;
TF_RETURN_IF_ERROR(ctx->input_list(argument_name, &inputs));
std::vector<Tensor> captured_inputs(inputs.begin(), inputs.end());
return Create(ctx, metadata, std::move(captured_inputs), out_function);
}
/* static */
Status CapturedFunction::Create(
OpKernelContext* ctx,
const std::shared_ptr<const FunctionMetadata> metadata,
std::vector<Tensor>&& captured_inputs,
std::unique_ptr<CapturedFunction>* out_function) {
*out_function = absl::WrapUnique(
new CapturedFunction(metadata, std::move(captured_inputs)));
return Status::OK();
}
Status CapturedFunction::AddToGraph(
SerializationContext* ctx, DatasetBase::DatasetGraphDefBuilder* b,
std::vector<Node*>* other_arguments,
DataTypeVector* other_arguments_types) const {
other_arguments->reserve(captured_inputs_.size());
other_arguments_types->reserve(captured_inputs_.size());
for (const Tensor& t : captured_inputs_) {
Node* node;
DatasetBase* input;
Status s = GetDatasetFromVariantTensor(t, &input);
if (s.ok()) {
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input, &node));
} else {
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
}
other_arguments->emplace_back(node);
other_arguments_types->emplace_back(t.dtype());
}
TF_RETURN_IF_ERROR(
b->AddFunction(ctx, metadata_->func().name(), *metadata_->lib_def()));
return Status::OK();
}
Status CapturedFunction::Instantiate(
IteratorContext* ctx, std::unique_ptr<InstantiatedCapturedFunction>*
instantiated_captured_function) {
// The context's runtime will be used for all subsequent calls.
FunctionLibraryRuntime* lib = ctx->flr();
FunctionLibraryRuntime::InstantiateOptions inst_opts;
inst_opts.lib_def = metadata_->lib_def();
inst_opts.create_kernels_eagerly = true;
inst_opts.default_device_to_target = metadata_->use_default_device();
inst_opts.config_proto =
lib->config_proto() ? *lib->config_proto() : ConfigProto();
if (!metadata_->use_inter_op_parallelism()) {
inst_opts.executor_type = "SINGLE_THREADED_EXECUTOR";
}
TF_RETURN_IF_ERROR(IsMultiDevice(ctx, &inst_opts.is_multi_device_function));
// We infer the target device from the function library runtime.
DCHECK(lib->device() != nullptr);
inst_opts.target = lib->device()->name();
if (inst_opts.is_multi_device_function) {
// Compute devices of non-captured inputs.
//
// We infer the number of non-captured inputs by subtracting the number
// of captured inputs from the number of input arguments and we infer the
// input devices from the function library runtime.
const FunctionDef* fdef;
TF_RETURN_IF_ERROR(
LookupFunction(*metadata_->lib_def(), metadata_->func().name(), &fdef));
size_t num_non_captured_inputs =
fdef->signature().input_arg_size() - captured_inputs_.size();
for (size_t i = 0; i < num_non_captured_inputs; ++i) {
inst_opts.input_devices.push_back(inst_opts.target);
}
// Compute devices of captured inputs.
// TODO(jsimsa): Correctly handle tensors on devices other than CPU:0.
Device* cpu_device;
TF_RETURN_IF_ERROR(lib->device_mgr()->LookupDevice("CPU:0", &cpu_device));
std::unordered_map<int, DtypeAndPartialTensorShape>&
input_resource_variable_dtypes_and_shapes =
inst_opts.input_resource_dtypes_and_shapes;
for (size_t i = 0; i < captured_inputs_.size(); ++i) {
const auto& input = captured_inputs_[i];
DataType dtype = input.dtype();
if (dtype == DT_RESOURCE) {
const ResourceHandle& handle = input.flat<ResourceHandle>()(0);
inst_opts.input_devices.push_back(handle.device());
const auto& dtypes_and_shapes = handle.dtypes_and_shapes();
// Set dtypes and shapes for resource variable inputs.
if (!dtypes_and_shapes.empty()) {
input_resource_variable_dtypes_and_shapes[num_non_captured_inputs +
i] =
dtypes_and_shapes.at(0);
}
} else if (MTypeFromDType(dtype) == HOST_MEMORY) {
inst_opts.input_devices.push_back(cpu_device->name());
} else {
// Fall back to using the function library runtime device.
inst_opts.input_devices.push_back(inst_opts.target);
}
}
for (size_t i = 0; i < fdef->signature().output_arg_size(); ++i) {
inst_opts.output_devices.push_back(inst_opts.target);
}
}
FunctionLibraryRuntime::Handle f_handle;
TF_RETURN_IF_ERROR(ctx->function_handle_cache()->Instantiate(
metadata_->func().name(), AttrSlice(&metadata_->func().attr()), inst_opts,
&f_handle));
DataTypeVector ret_types;
TF_RETURN_IF_ERROR(lib->GetRetTypes(f_handle, &ret_types));
*instantiated_captured_function =
absl::WrapUnique<InstantiatedCapturedFunction>(
new InstantiatedCapturedFunction(lib, f_handle, std::move(ret_types),
*ctx->runner(),
ctx->cancellation_manager(), this));
return Status::OK();
}
bool CapturedFunction::IsStateful() const { return !CheckExternalState().ok(); }
Status CapturedFunction::CheckExternalState() const {
for (const auto& name : lib_def()->ListFunctionNames()) {
TF_RETURN_IF_ERROR(
IsFunctionStateful(*lib_def(), *(lib_def()->Find(name))));
}
return Status::OK();
}
InstantiatedCapturedFunction::InstantiatedCapturedFunction(
FunctionLibraryRuntime* lib, FunctionLibraryRuntime::Handle f_handle,
DataTypeVector ret_types, std::function<void(std::function<void()>)> runner,
CancellationManager* cancellation_manager, CapturedFunction* captured_func)
: lib_(lib),
f_handle_(f_handle),
ret_types_(std::move(ret_types)),
captured_runner_(std::move(runner)),
cancellation_manager_(cancellation_manager),
captured_func_(captured_func) {}
// NOTE: We don't release f_handle_ here and instead delegate the function
// handle releasing to the FunctionHandleCache. This is because in some cases
// (RepeatDatasetOp in particular), we want to keep the function state (e.g.
// random number generator) even after the Iterator is reset after going through
// one epoch.
InstantiatedCapturedFunction::~InstantiatedCapturedFunction() {}
Status InstantiatedCapturedFunction::Run(IteratorContext* ctx,
std::vector<Tensor>&& args,
std::vector<Tensor>* rets) const {
auto& info = captured_func_->short_circuit_info();
if (!info.indices.empty()) {
return RunShortCircuit(info, std::move(args), captured_func_, rets);
}
FunctionLibraryRuntime::Options f_opts;
ScopedStepContainer step_container(
f_opts.step_id, [this](const string& name) {
lib_->device()->resource_manager()->Cleanup(name).IgnoreError();
});
f_opts.step_container = &step_container;
f_opts.runner = ctx->runner();
f_opts.create_rendezvous = ShouldCreateRendezvous();
CancellationManager cancellation_manager;
f_opts.cancellation_manager = &cancellation_manager;
std::function<void()> deregister_fn;
TF_RETURN_IF_ERROR(ConnectCancellationManagers(
cancellation_manager_, &cancellation_manager, &deregister_fn));
auto cleanup = gtl::MakeCleanup(std::move(deregister_fn));
OwnedArgsCallFrame frame(std::move(args), &captured_func_->captured_inputs(),
ret_types_);
Notification n;
Status s;
lib_->Run(f_opts, f_handle_, &frame, [&n, &s](Status func_status) {
s.Update(func_status);
n.Notify();
});
n.WaitForNotification();
TF_RETURN_IF_ERROR(s);
return frame.ConsumeRetvals(rets);
}
Status InstantiatedCapturedFunction::RunWithBorrowedArgs(
IteratorContext* ctx, const std::vector<Tensor>& args,
std::vector<Tensor>* rets) const {
auto& info = captured_func_->short_circuit_info();
if (!info.indices.empty()) {
return RunShortCircuit(info, args, captured_func_, rets);
}
FunctionLibraryRuntime::Options f_opts;
ScopedStepContainer step_container(
f_opts.step_id, [this](const string& name) {
lib_->device()->resource_manager()->Cleanup(name).IgnoreError();
});
f_opts.step_container = &step_container;
f_opts.runner = ctx->runner();
f_opts.create_rendezvous = ShouldCreateRendezvous();
CancellationManager cancellation_manager;
f_opts.cancellation_manager = &cancellation_manager;
std::function<void()> deregister_fn;
TF_RETURN_IF_ERROR(ConnectCancellationManagers(
cancellation_manager_, &cancellation_manager, &deregister_fn));
auto cleanup = gtl::MakeCleanup(std::move(deregister_fn));
BorrowedArgsCallFrame frame(args, &captured_func_->captured_inputs(),
ret_types_);
Notification n;
Status s;
lib_->Run(f_opts, f_handle_, &frame, [&n, &s](Status func_status) {
s.Update(func_status);
n.Notify();
});
n.WaitForNotification();
TF_RETURN_IF_ERROR(s);
return frame.ConsumeRetvals(rets);
}
Status InstantiatedCapturedFunction::RunInstantiated(
const std::vector<Tensor>& args, std::vector<Tensor>* rets) {
auto& info = captured_func_->short_circuit_info();
if (!info.indices.empty()) {
return RunShortCircuit(info, args, captured_func_, rets);
}
FunctionLibraryRuntime::Options f_opts;
ScopedStepContainer step_container(
f_opts.step_id, [this](const string& name) {
lib_->device()->resource_manager()->Cleanup(name).IgnoreError();
});
f_opts.step_container = &step_container;
f_opts.runner = &captured_runner_;
f_opts.create_rendezvous = ShouldCreateRendezvous();
CancellationManager cancellation_manager;
f_opts.cancellation_manager = &cancellation_manager;
std::function<void()> deregister_fn;
TF_RETURN_IF_ERROR(ConnectCancellationManagers(
cancellation_manager_, &cancellation_manager, &deregister_fn));
auto cleanup = gtl::MakeCleanup(std::move(deregister_fn));
BorrowedArgsCallFrame frame(args, &captured_func_->captured_inputs(),
ret_types_);
Notification n;
Status s;
lib_->Run(f_opts, f_handle_, &frame, [&n, &s](Status func_status) {
s.Update(func_status);
n.Notify();
});
n.WaitForNotification();
TF_RETURN_IF_ERROR(s);
return frame.ConsumeRetvals(rets);
}
void InstantiatedCapturedFunction::RunAsync(
IteratorContext* ctx, std::vector<Tensor>&& args, std::vector<Tensor>* rets,
FunctionLibraryRuntime::DoneCallback done, const string& prefix) const {
auto& info = captured_func_->short_circuit_info();
if (!info.indices.empty()) {
// Run the `done` callback on a threadpool thread, because it will
// potentially do a non-trivial amount of (e.g. copying) work, and we may
// want to run that concurrently with the next invocation.
Status s = RunShortCircuit(info, std::move(args), captured_func_, rets);
(*ctx->runner())(
std::bind([s](FunctionLibraryRuntime::DoneCallback& done) { done(s); },
std::move(done)));
return;
}
// NOTE(mrry): This method does not transfer ownership of `ctx`, and it may
// be deleted before `done` is called. Take care not to capture `ctx` in any
// code that may execute asynchronously in this function.
OwnedArgsCallFrame* frame = new OwnedArgsCallFrame(
std::move(args), &captured_func_->captured_inputs(), ret_types_);
FunctionLibraryRuntime::Options f_opts;
ResourceMgr* resource_mgr = lib_->device()->resource_manager();
ScopedStepContainer* step_container = new ScopedStepContainer(
f_opts.step_id, [resource_mgr](const string& name) {
resource_mgr->Cleanup(name).IgnoreError();
});
f_opts.step_container = step_container;
f_opts.runner = ctx->runner();
f_opts.create_rendezvous = ShouldCreateRendezvous();
auto cancellation_manager = absl::make_unique<CancellationManager>();
f_opts.cancellation_manager = cancellation_manager.get();
std::function<void()> deregister_fn;
Status s = ConnectCancellationManagers(
ctx->cancellation_manager(), cancellation_manager.get(), &deregister_fn);
if (!s.ok()) {
done(s);
return;
}
std::shared_ptr<SimpleStepStatsCollector> stats_collector;
if (ctx->model() || ctx->stats_aggregator()) {
stats_collector = absl::make_unique<SimpleStepStatsCollector>();
}
f_opts.stats_collector = stats_collector.get();
// Transfer ownership of the cancellation manager to `callback`.
CancellationManager* raw_cancellation_manager =
cancellation_manager.release();
auto callback = std::bind(
[this, rets, step_container, raw_cancellation_manager, frame](
const FunctionLibraryRuntime::DoneCallback& done,
IteratorContext* ctx, const std::function<void()>& deregister_fn,
const string& prefix,
const std::shared_ptr<SimpleStepStatsCollector>& stats_collector,
// Begin unbound arguments.
Status s) {
delete step_container;
deregister_fn();
delete raw_cancellation_manager;
if (s.ok()) {
s = frame->ConsumeRetvals(rets);
}
delete frame;
if (ctx->model()) {
// TODO(b/129085499) Utilize the `node_name` which would be unique
// than the prefix for the function execution time statistics.
// prefix_with_func_name would then be node_name + func_name.
if (ctx->stats_aggregator()) {
string prefix_end =
str_util::Split(prefix, "::", str_util::SkipEmpty()).back();
string prefix_with_func_name =
strings::StrCat(prefix_end, stats_utils::kDelimiter,
captured_func_->func().name());
ctx->stats_aggregator()->AddToHistogram(
stats_utils::ExecutionTimeHistogramName(prefix_with_func_name),
{static_cast<float>(stats_collector->processing_time())},
ctx->model()->NumElements(prefix));
}
ctx->model()->AddProcessingTime(prefix,
stats_collector->processing_time());
ctx->model()->RecordStart(prefix, false /* stop_output */);
}
done(s);
if (ctx->model()) {
ctx->model()->RecordStop(prefix, false /* start_output */);
}
},
std::move(done), ctx, std::move(deregister_fn), prefix,
std::move(stats_collector), std::placeholders::_1);
lib_->Run(f_opts, f_handle_, frame, std::move(callback));
}
bool InstantiatedCapturedFunction::ShouldCreateRendezvous() const {
return lib_->device()->device_type() != DEVICE_CPU ||
captured_func_->is_multi_device_function();
}
CapturedFunction::CapturedFunction(
const std::shared_ptr<const FunctionMetadata> metadata,
std::vector<Tensor> captured_inputs)
: metadata_(metadata), captured_inputs_(std::move(captured_inputs)) {}
Status CapturedFunction::IsMultiDevice(IteratorContext* ctx,
bool* is_multi_device) {
if (!metadata_->is_multi_device_function()) {
*is_multi_device = false;
return Status::OK();
}
const FunctionDef* fdef;
TF_RETURN_IF_ERROR(
LookupFunction(*metadata_->lib_def(), metadata_->func().name(), &fdef));
Device* current_device = ctx->flr()->device();
DeviceType current_device_type(current_device->device_type());
DeviceNameUtils::ParsedName current_device_name;
if (!DeviceNameUtils::ParseFullName(current_device->name(),
&current_device_name)) {
return errors::InvalidArgument("Failed to parse device name: ",
current_device->name());
}
// Check if any of the captured inputs are placed on a device not compatible
// with the current device. For non-captured inputs, we assume they are placed
// on the current device.
for (const auto& input : captured_inputs_) {
DataType dtype = input.dtype();
if (dtype == DT_RESOURCE) {
const ResourceHandle& handle = input.flat<ResourceHandle>()(0);
DeviceNameUtils::ParsedName resource_device_name;
if (!DeviceNameUtils::ParseFullName(handle.device(),
&resource_device_name)) {
return errors::InvalidArgument("Failed to parse device name: ",
handle.device());
}
if (!DeviceNameUtils::AreCompatibleDevNames(current_device_name,
resource_device_name)) {
*is_multi_device = true;
return Status::OK();
}
}
}
// Check if all ops could be placed on the current device.
for (const auto& name : metadata_->lib_def()->ListFunctionNames()) {
const FunctionDef* fdef;
TF_RETURN_IF_ERROR(LookupFunction(*metadata_->lib_def(), name, &fdef));
for (const auto& node : fdef->node_def()) {
// Check if the op has a kernel availabe for the current device.
if (!KernelDefAvailable(current_device_type, node)) {
*is_multi_device = true;
return Status::OK();
}
// If the op has a requested device, check if the requested device is
// compatible with the current device.
if (!node.device().empty()) {
DeviceNameUtils::ParsedName node_device_name;
if (!DeviceNameUtils::ParseFullName(node.device(), &node_device_name)) {
return errors::InvalidArgument("Failed to parse device name: ",
node.device());
}
if (!DeviceNameUtils::AreCompatibleDevNames(current_device_name,
node_device_name)) {
*is_multi_device = true;
return Status::OK();
}
}
}
}
*is_multi_device = false;
return Status::OK();
}
} // namespace data
} // namespace tensorflow