| /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "tensorflow/core/common_runtime/function.h" |
| |
| #include <deque> |
| #include <vector> |
| |
| #include "absl/algorithm/container.h" |
| #include "absl/memory/memory.h" |
| #include "absl/strings/str_cat.h" |
| #include "tensorflow/core/common_runtime/device.h" |
| #include "tensorflow/core/common_runtime/executor.h" |
| #include "tensorflow/core/common_runtime/executor_factory.h" |
| #include "tensorflow/core/common_runtime/graph_optimizer.h" |
| #include "tensorflow/core/common_runtime/memory_types.h" |
| #include "tensorflow/core/common_runtime/rendezvous_mgr.h" |
| #include "tensorflow/core/framework/collective.h" |
| #include "tensorflow/core/framework/function.h" |
| #include "tensorflow/core/framework/node_def.pb.h" |
| #include "tensorflow/core/framework/node_def_util.h" |
| #include "tensorflow/core/framework/op.h" |
| #include "tensorflow/core/framework/op_kernel.h" |
| #include "tensorflow/core/framework/versions.pb.h" |
| #include "tensorflow/core/graph/algorithm.h" |
| #include "tensorflow/core/graph/control_flow.h" |
| #include "tensorflow/core/graph/gradients.h" |
| #include "tensorflow/core/graph/graph_constructor.h" |
| #include "tensorflow/core/graph/node_builder.h" |
| #include "tensorflow/core/graph/optimizer_cse.h" |
| #include "tensorflow/core/lib/core/threadpool.h" |
| #include "tensorflow/core/lib/gtl/map_util.h" |
| #include "tensorflow/core/platform/macros.h" |
| #include "tensorflow/core/profiler/lib/traceme.h" |
| #include "tensorflow/core/protobuf/config.pb.h" |
| |
| // See core/kernels/function_ops.cc for related kernels. |
| |
| namespace tensorflow { |
| |
// A few string constants used throughout this module.
// Short aliases for the special op names declared on
// FunctionLibraryDefinition.
static constexpr const char* const kArgOp = FunctionLibraryDefinition::kArgOp;
static constexpr const char* const kDeviceArgOp =
    FunctionLibraryDefinition::kDeviceArgOp;
static constexpr const char* const kRetOp = FunctionLibraryDefinition::kRetOp;
static constexpr const char* const kDeviceRetOp =
    FunctionLibraryDefinition::kDeviceRetOp;
static constexpr const char* const kGradientOp =
    FunctionLibraryDefinition::kGradientOp;
// Name prefix used for nodes synthesized by the Add* helpers below.
static constexpr const char* const kNodeLabel = "Func";
static constexpr const char* const kFuncAttr =
    FunctionLibraryDefinition::kFuncAttr;
| |
| // Represents the index-th output of a node. |
| struct Endpoint { |
| Node* node; |
| int index; |
| |
| // Returns the string name represents this endpoint. |
| string name() const { |
| if (index == 0) { |
| return node->name(); |
| } else { |
| return strings::StrCat(node->name(), ":", index); |
| } |
| } |
| |
| DataType dtype() const { return node->output_type(index); } |
| }; |
| |
| struct EndpointHash { |
| uint64 operator()(const Endpoint& x) const { |
| return Hash64(reinterpret_cast<const char*>(&x.node), sizeof(Node*), |
| x.index); |
| } |
| }; |
| |
| struct EndpointEq { |
| bool operator()(const Endpoint& x, const Endpoint& y) const { |
| return (x.node == y.node) && (x.index == y.index); |
| } |
| }; |
| |
| // The following Add* routines are used to add a few graph nodes while |
| // functions are transformed. |
| static Node* AddNoOp(StringPiece name, Graph* g) { |
| NodeDef ndef; |
| ndef.set_name(g->NewName(absl::StrCat(kNodeLabel, "/", name))); |
| ndef.set_op("NoOp"); |
| Status s; |
| Node* ret = g->AddNode(ndef, &s); |
| TF_CHECK_OK(s); |
| return ret; |
| } |
| |
| static Node* AddIdentity(StringPiece name, Graph* g, Endpoint input) { |
| DCHECK_LT(0, input.dtype()); |
| NodeDef ndef; |
| ndef.set_name(g->NewName(absl::StrCat(kNodeLabel, "/", name))); |
| ndef.set_op("Identity"); |
| ndef.add_input(input.name()); |
| AddNodeAttr("T", BaseType(input.dtype()), &ndef); |
| Status s; |
| Node* ret = g->AddNode(ndef, &s); |
| TF_CHECK_OK(s); |
| g->AddEdge(input.node, input.index, ret, 0); |
| return ret; |
| } |
| |
| static Node* AddArg(Graph* g, DataType dtype, int index) { |
| DCHECK_LT(0, dtype); |
| DCHECK_LT(dtype, DT_FLOAT_REF); |
| NodeDef ndef; |
| ndef.set_name(g->NewName(kNodeLabel)); |
| ndef.set_op(kArgOp); |
| AddNodeAttr("T", dtype, &ndef); |
| AddNodeAttr("index", index, &ndef); |
| Status s; |
| Node* ret = g->AddNode(ndef, &s); |
| TF_CHECK_OK(s); |
| return ret; |
| } |
| |
| static Node* AddRet(Graph* g, Endpoint input, int index) { |
| DCHECK_LT(0, input.dtype()); |
| DCHECK_LT(input.dtype(), DT_FLOAT_REF); |
| NodeDef ndef; |
| ndef.set_name(g->NewName(kNodeLabel)); |
| ndef.set_op(kRetOp); |
| ndef.add_input(input.name()); |
| AddNodeAttr("T", input.dtype(), &ndef); |
| AddNodeAttr("index", index, &ndef); |
| Status s; |
| Node* ret = g->AddNode(ndef, &s); |
| TF_CHECK_OK(s); |
| g->AddEdge(input.node, input.index, ret, 0); |
| return ret; |
| } |
| |
| // FunctionLibraryRuntime implementation that forwards all the function calls to |
| // the base runtime implementation, and only overrides FunctionLibraryDefinition |
| // in calls to Instantiate (if caller doesn't provide the |
| // InstantiateOptions::lib_def option). |
| // |
| // When the function library runtime (FunctionLibraryRuntimeImpl specifically) |
| // instantiates a function into a Graph object, it also creates an Executor for |
| // it. That executor has a pointer to the function library runtime instance, |
| // that is used to instantiate all nested function calls. |
| // |
| // The function library definition used to instantiate the function must be |
| // preserved in the Executor's function library runtime. |
| // |
| // IMPORTANT: This runtime is intended for use only in executors created for |
| // functions instantiated into a graph in FunctionLibraryRuntimeImpl. |
class FunctionLibraryRuntimeOverlay : public FunctionLibraryRuntime {
 public:
  // Neither pointer is owned; both must outlive this overlay.
  FunctionLibraryRuntimeOverlay(FunctionLibraryRuntime* base_flr,
                                const FunctionLibraryDefinition* lib_def)
      : base_flr_(base_flr), lib_def_(lib_def) {}
  ~FunctionLibraryRuntimeOverlay() override;

  // Forwards to the base runtime, filling in `lib_def_` as
  // InstantiateOptions::lib_def when the caller didn't specify one.
  Status Instantiate(const string& function_name, AttrSlice attrs,
                     const InstantiateOptions& options,
                     Handle* handle) override;

  Status ReleaseHandle(Handle handle) override;

  const FunctionBody* GetFunctionBody(Handle h) override;

  Status GetRetTypes(Handle h, DataTypeVector* ret_types) override;

  void Run(const Options& opts, Handle handle, gtl::ArraySlice<Tensor> args,
           std::vector<Tensor>* rets, DoneCallback done) override;

  void Run(const Options& opts, Handle handle, CallFrameInterface* call_frame,
           DoneCallback done) override;

  // Always fails: kernel creation through an overlay is not supported.
  Status CreateKernel(const NodeDef& ndef, OpKernel** kernel) override;

  // Looks up `function_name` in lib_def_ only; does NOT consult the base FLR.
  bool IsStateful(const string& function_name) const override;

  // Returns lib_def_ when set, otherwise the base runtime's definition.
  const FunctionLibraryDefinition* GetFunctionLibraryDefinition()
      const override;

  // The accessors below forward directly to the base runtime.
  Env* env() override;
  const ConfigProto* const config_proto() override;
  Device* device() override;
  const Device* device() const override;
  std::function<void(std::function<void()>)>* runner() override;
  const DeviceMgr* device_mgr() const override;

  string DebugString(Handle handle) override;
  int graph_def_version() const override;

  Status Clone(std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
               std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr,
               FunctionLibraryRuntime** out_flr,
               bool skip_flib_def = false) override;

 private:
  FunctionLibraryRuntime* base_flr_;          // not owned
  const FunctionLibraryDefinition* lib_def_;  // not owned
};
| |
// Nothing to clean up: base_flr_ and lib_def_ are non-owning.
FunctionLibraryRuntimeOverlay::~FunctionLibraryRuntimeOverlay() = default;
| |
| Status FunctionLibraryRuntimeOverlay::Instantiate( |
| const string& function_name, AttrSlice attrs, |
| const InstantiateOptions& options, Handle* handle) { |
| // We automatically set the `lib_def` option for all instantiations, if the |
| // caller doesn't set this option explicitly. |
| if (!options.lib_def && lib_def_) { |
| InstantiateOptions options_copy = options; |
| options_copy.lib_def = lib_def_; |
| return base_flr_->Instantiate(function_name, attrs, options_copy, handle); |
| } else { |
| return base_flr_->Instantiate(function_name, attrs, options, handle); |
| } |
| } |
| |
// Handles and execution are shared with the base runtime, so the methods
// below forward unconditionally.

Status FunctionLibraryRuntimeOverlay::ReleaseHandle(Handle handle) {
  return base_flr_->ReleaseHandle(handle);
}

const FunctionBody* FunctionLibraryRuntimeOverlay::GetFunctionBody(Handle h) {
  return base_flr_->GetFunctionBody(h);
}

Status FunctionLibraryRuntimeOverlay::GetRetTypes(Handle h,
                                                  DataTypeVector* ret_types) {
  return base_flr_->GetRetTypes(h, ret_types);
}

void FunctionLibraryRuntimeOverlay::Run(const Options& opts, Handle handle,
                                        gtl::ArraySlice<Tensor> args,
                                        std::vector<Tensor>* rets,
                                        DoneCallback done) {
  base_flr_->Run(opts, handle, args, rets, std::move(done));
}

void FunctionLibraryRuntimeOverlay::Run(const Options& opts, Handle handle,
                                        CallFrameInterface* call_frame,
                                        DoneCallback done) {
  base_flr_->Run(opts, handle, call_frame, std::move(done));
}
| |
// Unconditionally returns an Internal error; see rationale below.
Status FunctionLibraryRuntimeOverlay::CreateKernel(const NodeDef&, OpKernel**) {
  // We don't have access to base_lib_def_ in base function library runtime (aka
  // FunctionLibraryRuntimeImpl), so to make sure we do not create a kernel with
  // the wrong lib_def we just disable creation of new kernels through overlays.
  //
  // When we call Instantiate from the base runtime with the lib_def option,
  // the base runtime implementation is responsible for correctly passing it
  // through to all kernel constructions.
  return errors::Internal(
      "Overlay function library runtime doesn't support kernel creation.");
}
| |
| bool FunctionLibraryRuntimeOverlay::IsStateful( |
| const string& function_name) const { |
| // Important: we do not forward lookup to the base FLR. |
| const OpDef* op_def; |
| const Status s = lib_def_->LookUpOpDef(function_name, &op_def); |
| return s.ok() && op_def->is_stateful(); |
| } |
| |
// The accessors below simply forward to the base (non-overlay) runtime.

Env* FunctionLibraryRuntimeOverlay::env() { return base_flr_->env(); }

const ConfigProto* const FunctionLibraryRuntimeOverlay::config_proto() {
  return base_flr_->config_proto();
}

Device* FunctionLibraryRuntimeOverlay::device() { return base_flr_->device(); }

const Device* FunctionLibraryRuntimeOverlay::device() const {
  return base_flr_->device();
}

std::function<void(std::function<void()>)>*
FunctionLibraryRuntimeOverlay::runner() {
  return base_flr_->runner();
}

const DeviceMgr* FunctionLibraryRuntimeOverlay::device_mgr() const {
  return base_flr_->device_mgr();
}

// Prefers the overlay's definition; falls back to the base runtime's when no
// overlay lib_def was provided.
const FunctionLibraryDefinition*
FunctionLibraryRuntimeOverlay::GetFunctionLibraryDefinition() const {
  return lib_def_ ? lib_def_ : base_flr_->GetFunctionLibraryDefinition();
}

string FunctionLibraryRuntimeOverlay::DebugString(Handle handle) {
  return base_flr_->DebugString(handle);
}

int FunctionLibraryRuntimeOverlay::graph_def_version() const {
  return base_flr_->graph_def_version();
}

Status FunctionLibraryRuntimeOverlay::Clone(
    std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
    std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr,
    FunctionLibraryRuntime** out_flr, bool skip_flib_def) {
  // NOTE(ezhulenev): The cloned FunctionLibraryRuntime will be missing the
  // FunctionLibraryDefinition override, but that's ok because we anyway do not
  // copy / clone instantiated items from the base FLR.
  return base_flr_->Clone(out_lib_def, out_pflr, out_flr, skip_flib_def);
}
| |
// Per-device function library runtime: instantiates functions into
// FunctionBody/Graph form, keeps the instantiated Items, and creates kernels
// and executors for them.
class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
 public:
  FunctionLibraryRuntimeImpl(const DeviceMgr* dmgr, Env* env,
                             const ConfigProto* config, Device* device,
                             int graph_def_version,
                             const FunctionLibraryDefinition* lib_def,
                             thread::ThreadPool* default_thread_pool,
                             const OptimizerOptions& optimizer_options,
                             const CustomKernelCreator* custom_kernel_creator,
                             const SessionMetadata* session_metadata,
                             ProcessFunctionLibraryRuntime* parent);

  ~FunctionLibraryRuntimeImpl() override;

  Status Instantiate(const string& function_name, AttrSlice attrs,
                     const InstantiateOptions& options,
                     Handle* handle) override;

  Status ReleaseHandle(Handle handle) override;

  const FunctionBody* GetFunctionBody(Handle handle) override;

  Status GetRetTypes(Handle handle, DataTypeVector* ret_types) override;

  Status CreateKernel(const NodeDef& ndef, OpKernel** kernel) override;

  void Run(const Options& opts, Handle handle, gtl::ArraySlice<Tensor> args,
           std::vector<Tensor>* rets, DoneCallback done) override;
  void Run(const Options& opts, Handle handle, CallFrameInterface* frame,
           DoneCallback done) override;

  bool IsStateful(const string& function) const override;

  const FunctionLibraryDefinition* GetFunctionLibraryDefinition()
      const override {
    return base_lib_def_;
  }

  Device* device() override { return device_; }
  const Device* device() const override { return device_; }

  std::function<void(std::function<void()>)>* runner() override {
    return &default_runner_;
  }

  const DeviceMgr* device_mgr() const override { return device_mgr_; }
  Env* env() override { return env_; }
  const ConfigProto* const config_proto() override { return config_; }
  int graph_def_version() const override { return graph_def_version_; }

  string DebugString(Handle h) override;

  Status Clone(std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
               std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr,
               FunctionLibraryRuntime** out_flr,
               bool skip_flib_def = false) override;

 private:
  typedef FunctionLibraryRuntimeImpl ME;

  // The raw pointers below are borrowed: the destructor (which only resets
  // items_) does not delete any of them.
  const DeviceMgr* const device_mgr_;
  Device* const device_;
  Env* const env_;
  const ConfigProto* const config_;
  const int graph_def_version_;
  const FunctionLibraryDefinition* const base_lib_def_;
  GraphOptimizer optimizer_;
  const CustomKernelCreator* custom_kernel_creator_;
  const SessionMetadata* const session_metadata_;
  Executor::Args::Runner default_runner_;
  const string device_name_;

  // Cached closures (set up in the constructor) for op signature lookup and
  // kernel creation against base_lib_def_.
  std::function<Status(const string&, const OpDef**)> get_func_sig_;
  std::function<Status(const NodeDef&, OpKernel**)> create_kernel_;

  // Protects next_handle_ and items_.
  mutable mutex mu_;

  // Next local handle to assign on instantiation.
  int next_handle_ GUARDED_BY(mu_);

  // The instantiated and transformed function is encoded as a Graph
  // object, and an executor is created for the graph.
  struct Item {
    uint64 instantiation_counter = 0;
    std::unique_ptr<const Graph> graph = nullptr;
    const FunctionLibraryDefinition* lib_def = nullptr;  // Not owned.
    FunctionBody* func_graph = nullptr;
    Executor* exec = nullptr;
    FunctionLibraryRuntimeOverlay* overlay_flr = nullptr;
    string executor_type;
    Executor::RendezvousFactory rendezvous_factory = nullptr;

    ~Item() {
      delete this->func_graph;
      delete this->exec;
      delete this->overlay_flr;
    }
  };
  // Held via unique_ptr so the destructor can release all items explicitly
  // and ReleaseHandle can detect that teardown already happened.
  std::unique_ptr<std::unordered_map<Handle, std::unique_ptr<Item>>> items_
      GUARDED_BY(mu_);

  ProcessFunctionLibraryRuntime* parent_ = nullptr;  // not owned.

  // Overloads the CreateKernel method, providing a FunctionLibraryRuntime
  // to use for kernel creation and execution. In particular, this method can
  // accept a FunctionLibraryRuntimeOverlay that overlays a different
  // FunctionLibraryDefinition.
  Status CreateKernel(const NodeDef& ndef, FunctionLibraryRuntime* flr,
                      OpKernel** kernel);
  Status FunctionDefToBody(const FunctionDef& fdef, AttrSlice attrs,
                           const FunctionLibraryDefinition* lib_def,
                           std::unique_ptr<FunctionBody>* fbody);
  Status CreateItem(Item** item);
  Status GetOrCreateItem(LocalHandle local_handle, Item** item);
  Status InstantiateSymbolicGradient(const NameAttrList& func,
                                     const FunctionLibraryDefinition* lib_def,
                                     std::unique_ptr<FunctionBody>* g_body);
  bool IsLocalTarget(const InstantiateOptions& options) const;
  AttrValueMap FixAttrs(const AttrSlice& attrs);
  void RunRemote(const Options& opts, Handle handle,
                 gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets,
                 Item* item, DoneCallback done);

  void ExecutorArgsFromOptions(const FunctionLibraryRuntime::Options& run_opts,
                               CallFrameInterface* frame,
                               Executor::Args* exec_args);

  TF_DISALLOW_COPY_AND_ASSIGN(FunctionLibraryRuntimeImpl);
};
| |
FunctionLibraryRuntimeImpl::FunctionLibraryRuntimeImpl(
    const DeviceMgr* dmgr, Env* env, const ConfigProto* config, Device* device,
    int graph_def_version, const FunctionLibraryDefinition* lib_def,
    thread::ThreadPool* default_thread_pool,
    const OptimizerOptions& optimizer_options,
    const CustomKernelCreator* custom_kernel_creator,
    const SessionMetadata* session_metadata,
    ProcessFunctionLibraryRuntime* parent)
    : device_mgr_(dmgr),
      device_(device),
      env_(env),
      config_(config),
      graph_def_version_(graph_def_version),
      base_lib_def_(lib_def),
      optimizer_(optimizer_options),
      custom_kernel_creator_(custom_kernel_creator),
      session_metadata_(session_metadata),
      default_runner_(nullptr),
      // With no device, fall back to the process-wide default device name.
      device_name_(device_ == nullptr
                       ? ProcessFunctionLibraryRuntime::kDefaultFLRDevice
                       : device_->name()),
      next_handle_(0),
      items_(new std::unordered_map<Handle, std::unique_ptr<Item>>),
      parent_(parent) {
  // Op signature lookup resolves against this runtime's base library.
  get_func_sig_ = [this](const string& op, const OpDef** sig) {
    return base_lib_def_->LookUpOpDef(op, sig);
  };
  create_kernel_ = [this](const NodeDef& ndef, OpKernel** kernel) {
    return CreateKernel(ndef, kernel);
  };
  // Choose a default runner: prefer the device's own thread pool, then the
  // caller-provided default pool. If neither exists, default_runner_ remains
  // null.
  thread::ThreadPool* pool = nullptr;
  if (device_ != nullptr) {
    pool = device_->tensorflow_device_thread_pool();
  }
  if (pool == nullptr) {
    pool = default_thread_pool;
  }
  if (pool != nullptr) {
    default_runner_ = [pool](Executor::Args::Closure c) {
      pool->Schedule(std::move(c));
    };
  }
}
| |
FunctionLibraryRuntimeImpl::~FunctionLibraryRuntimeImpl() {
  // Deleting the items_ list will delete all the function handles registered
  // in this object. A function may contain a few sub-functions which have also
  // been registered in this object. Deleting the parent function will call
  // ReleaseHandle in this class again for each of the sub-functions. These
  // circular calls may cause a segfault since the items_ may have already been
  // partially deleted when releasing handles of sub-functions. Explicitly
  // release items_ here and check it in ReleaseHandle to avoid this.
  items_.reset();
}
| |
// An asynchronous op kernel which executes an instantiated function
// defined in a library.
class CallOp : public AsyncOpKernel {
 public:
  CallOp(FunctionLibraryRuntime::Handle handle, OpKernelConstruction* ctx)
      : AsyncOpKernel(ctx), handle_(handle) {}

  ~CallOp() override {
    // TODO(iga): Release the cached handle_
  }

  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
    FunctionLibraryRuntime* lib = ctx->function_library();
    OP_REQUIRES_ASYNC(ctx, lib != nullptr,
                      errors::Internal("No function library is provided."),
                      done);
    // Propagate the surrounding step's execution context (rendezvous,
    // cancellation, stats, etc.) into the function invocation.
    FunctionLibraryRuntime::Options opts;
    opts.rendezvous = ctx->rendezvous();
    opts.cancellation_manager = ctx->cancellation_manager();
    opts.step_container = ctx->step_container();
    opts.stats_collector = ctx->stats_collector();
    opts.runner = ctx->runner();
    opts.collective_executor = ctx->collective_executor();
    std::vector<Tensor> args;
    args.reserve(ctx->num_inputs());
    for (int i = 0; i < ctx->num_inputs(); ++i) {
      args.push_back(ctx->input(i));
    }
    // Heap-allocated so it outlives ComputeAsync; deleted in the callback.
    std::vector<Tensor>* rets = new std::vector<Tensor>;
    profiler::TraceMe trace_me(
        [&] {
          return absl::StrCat("CallOp #parent_step_id=", ctx->step_id(),
                              ",function_step_id=", opts.step_id, "#");
        },
        /*level=*/2);
    lib->Run(opts, handle_, args, rets,
             [ctx, done, rets](const Status& status) {
               if (!status.ok()) {
                 ctx->SetStatus(status);
               } else {
                 // On success the function must produce exactly the outputs
                 // this kernel declares.
                 const int ret_size = static_cast<int>(rets->size());
                 CHECK_EQ(ret_size, ctx->num_outputs());
                 for (int i = 0; i < ret_size; ++i) {
                   ctx->set_output(i, (*rets)[i]);
                 }
               }
               delete rets;
               done();
             });
  }

 private:
  // Handle of the instantiated function to run; not released here (see the
  // TODO in the destructor).
  FunctionLibraryRuntime::Handle handle_;

  TF_DISALLOW_COPY_AND_ASSIGN(CallOp);
};
| |
// Returns the FunctionBody instantiated for `h` on this runtime's device, or
// nullptr (after logging an error) when `h` has no local handle here.
const FunctionBody* FunctionLibraryRuntimeImpl::GetFunctionBody(Handle h) {
  LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, h);
  if (local_handle == kInvalidLocalHandle) {
    LOG(ERROR) << "Could not find Handle: " << h
               << " on device: " << device_name_;
    return nullptr;
  }

  tf_shared_lock l(mu_);
  auto iter = items_->find(local_handle);
  // A valid local handle must have a corresponding item.
  CHECK(iter != items_->end());
  return iter->second->func_graph;
}
| |
| Status FunctionLibraryRuntimeImpl::GetRetTypes(Handle h, |
| DataTypeVector* ret_types) { |
| if (parent_->IsMultiDevice(h)) { |
| return parent_->GetRetTypes(h, ret_types); |
| } |
| LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, h); |
| if (local_handle == kInvalidLocalHandle) { |
| return errors::InvalidArgument("Handle ", h, " not found."); |
| } |
| const FunctionBody* fbody = GetFunctionBody(h); |
| *ret_types = fbody->ret_types; |
| return Status::OK(); |
| } |
| |
// Public entry point: creates a kernel for `ndef`, resolving function calls
// against this runtime's own function library definition.
Status FunctionLibraryRuntimeImpl::CreateKernel(const NodeDef& ndef,
                                                OpKernel** kernel) {
  return CreateKernel(ndef, this, kernel);
}
| |
// Creates a kernel for `ndef`, using `flr`'s function library definition to
// decide whether `ndef` is a primitive op (registered kernel) or a function
// call (instantiate and wrap in a CallOp).
Status FunctionLibraryRuntimeImpl::CreateKernel(const NodeDef& ndef,
                                                FunctionLibraryRuntime* flr,
                                                OpKernel** kernel) {
  // If a custom kernel creator is given, try that.
  Status s;
  if (custom_kernel_creator_ != nullptr &&
      custom_kernel_creator_->CanCreateKernel(*this, ndef)) {
    std::unique_ptr<OpKernel> ret;
    s = custom_kernel_creator_->CreateKernel(this, ndef, &ret);
    if (s.ok()) {
      *kernel = ret.release();
    } else {
      VLOG(2) << "Custom creator error: " << s;
    }
    return s;
  }

  const FunctionLibraryDefinition* lib_def =
      flr->GetFunctionLibraryDefinition();
  if (lib_def->Find(ndef.op()) == nullptr) {
    // A primitive operation. Creates the registered kernel.
    return CreateNonCachedKernel(device_, flr, ndef, graph_def_version_,
                                 kernel);
  }

  // Try to instantiate this function for the func/attr. Maybe it's
  // cached already.
  InstantiateOptions options;
  // Only pass a lib_def override when it differs from our base library.
  if (lib_def != base_lib_def_) {
    options.lib_def = lib_def;
  }
  Handle handle;
  TF_RETURN_IF_ERROR(
      Instantiate(ndef.op(), AttrSlice(&ndef.attr()), options, &handle));

  const FunctionBody* fbody = GetFunctionBody(handle);
  CHECK_NOTNULL(fbody);

  // TODO(zhifengc): For now, we assume int32 and resources are always on host
  // memory and other types are always on device memory. We should do type
  // inference over function body to derive the correct input/output memory
  // types.
  MemoryTypeVector input_memory_types;
  for (const auto& t : fbody->arg_types) {
    input_memory_types.push_back(MTypeFromDType(t));
  }
  MemoryTypeVector output_memory_types;
  for (const auto& t : fbody->ret_types) {
    output_memory_types.push_back(MTypeFromDType(t));
  }

  // Constructs a CallOp kernel for running the instantiated function.
  auto device_type = DeviceType(device_->attributes().device_type());
  OpKernelConstruction construction(
      device_type, device_, device_->GetAllocator(AllocatorAttributes()), &ndef,
      &fbody->fdef.signature(), flr, fbody->arg_types, input_memory_types,
      fbody->ret_types, output_memory_types, graph_def_version_, &s);
  if (s.ok()) {
    *kernel = new CallOp(handle, &construction);
  }
  return s;
}
| |
| Status FunctionLibraryRuntimeImpl::FunctionDefToBody( |
| const FunctionDef& fdef, AttrSlice attrs, |
| const FunctionLibraryDefinition* lib_def, |
| std::unique_ptr<FunctionBody>* fbody) { |
| if (lib_def == base_lib_def_) { |
| return FunctionDefToBodyHelper(fdef, attrs, lib_def, get_func_sig_, fbody); |
| } else { |
| auto get_func_sig = [lib_def](const string& op, const OpDef** sig) { |
| return lib_def->LookUpOpDef(op, sig); |
| }; |
| return FunctionDefToBodyHelper(fdef, attrs, lib_def, get_func_sig, fbody); |
| } |
| } |
| |
// Builds the body of the gradient of `func`. For a primitive op the
// registered gradient-creator is invoked; for a user-defined function the
// function is instantiated and differentiated via SymbolicGradient().
Status FunctionLibraryRuntimeImpl::InstantiateSymbolicGradient(
    const NameAttrList& func, const FunctionLibraryDefinition* lib_def,
    std::unique_ptr<FunctionBody>* g_body) {
  const FunctionDef* fdef = lib_def->Find(func.name());
  if (fdef == nullptr) {
    // f is a primitive op.
    gradient::Creator creator;
    TF_RETURN_IF_ERROR(gradient::GetOpGradientCreator(func.name(), &creator));
    if (creator == nullptr) {
      return errors::InvalidArgument("No gradient is defined for ",
                                     func.name());
    }
    FunctionDef grad_fdef;
    // TODO(josh11b): Should filter out the attrs from func that aren't used
    // by the gradient function.
    TF_RETURN_IF_ERROR(creator(AttrSlice(&func.attr()), &grad_fdef));
    TF_RETURN_IF_ERROR(
        FunctionDefToBody(grad_fdef, AttrSlice(&func.attr()), lib_def, g_body));
  } else {
    // f is a user-defined function.
    InstantiateOptions options;
    // Only pass a lib_def override when it differs from our base library.
    if (lib_def != base_lib_def_) {
      options.lib_def = lib_def;
    }
    Handle f_handle;
    TF_RETURN_IF_ERROR(
        Instantiate(func.name(), AttrSlice(&func.attr()), options, &f_handle));
    const FunctionBody* f_body = GetFunctionBody(f_handle);
    CHECK_NOTNULL(f_body);
    *g_body = SymbolicGradient(*f_body);
  }
  return Status::OK();
}
| |
| bool FunctionLibraryRuntimeImpl::IsLocalTarget( |
| const InstantiateOptions& options) const { |
| if (device_ == nullptr) return true; |
| if (options.target.empty()) return true; |
| if (options.is_multi_device_function) return false; |
| Device* target_device; |
| if (!device_mgr_->LookupDevice(options.target, &target_device).ok()) { |
| VLOG(1) << "Not instantiating function in FLR because failed to " |
| << "find device " << options.target << " in device manager"; |
| return false; |
| } |
| if (target_device != device_) { |
| VLOG(1) << "Not instantiating function in FLR because target device " |
| << options.target |
| << " is different from FLR's device: " << device_->DebugString(); |
| return false; |
| } |
| return true; |
| } |
| |
// Instantiates `function_name` with `attrs` on this device, reusing and
// ref-counting an existing instantiation when the canonical key matches.
Status FunctionLibraryRuntimeImpl::Instantiate(
    const string& function_name, AttrSlice attrs,
    const InstantiateOptions& options, Handle* handle) {
  // Non-local targets are handled by the process-wide runtime.
  if (!IsLocalTarget(options)) {
    return parent_->Instantiate(function_name, attrs, options, handle);
  }

  // Since this is a local target, ensure that the local `device_name_` appears
  // in the canonical key.
  InstantiateOptions options_copy(options);
  options_copy.target = device_name_;
  const string key = Canonicalize(function_name, attrs, options_copy);

  {
    // Fast path: the function is already instantiated under this key; just
    // bump its instantiation count.
    mutex_lock l(mu_);
    *handle = parent_->GetHandle(key);
    if (*handle != kInvalidHandle) {
      FunctionLibraryRuntime::LocalHandle handle_on_device =
          parent_->GetHandleOnDevice(device_name_, *handle);
      if (handle_on_device == kInvalidLocalHandle) {
        return errors::Internal("LocalHandle not found for handle ", *handle,
                                ".");
      }
      auto item_handle = items_->find(handle_on_device);
      if (item_handle == items_->end()) {
        return errors::Internal("LocalHandle ", handle_on_device,
                                " for handle ", *handle,
                                " not found in items.");
      }
      ++item_handle->second->instantiation_counter;
      return Status::OK();
    }
  }

  // Build the function body outside the lock.
  const FunctionLibraryDefinition* lib_def =
      options.lib_def ? options.lib_def : base_lib_def_;
  std::unique_ptr<FunctionBody> fbody;
  if (function_name == kGradientOp) {
    const AttrValue* f = attrs.Find(kFuncAttr);
    if (f == nullptr) {
      return errors::InvalidArgument("SymbolicGradient is missing attr: f");
    }
    const auto& func = f->func();
    if (func.name() == kGradientOp) {
      return errors::InvalidArgument("Can't take gradient of SymbolicGradient");
    }
    // A registered custom gradient takes precedence over the symbolic one.
    const string grad = lib_def->FindGradient(func.name());
    if (!grad.empty()) {
      return Instantiate(grad, AttrSlice(&func.attr()), options, handle);
    }
    TF_RETURN_IF_ERROR(InstantiateSymbolicGradient(func, lib_def, &fbody));
  } else {
    const FunctionDef* fdef = lib_def->Find(function_name);
    if (fdef == nullptr) {
      return errors::NotFound("Function ", function_name, " is not defined.");
    }
    TF_RETURN_IF_ERROR(FunctionDefToBody(*fdef, attrs, lib_def, &fbody));
  }

  LocalHandle local_handle;
  {
    mutex_lock l(mu_);
    // Re-check under the lock: another thread may have instantiated the same
    // key while we were building the body above.
    *handle = parent_->GetHandle(key);
    if (*handle != kInvalidHandle) {
      local_handle = parent_->GetHandleOnDevice(device_name_, *handle);
      ++(*items_)[local_handle]->instantiation_counter;
    } else {
      *handle = parent_->AddHandle(key, device_name_, next_handle_);
      Item* item = new Item;
      item->func_graph = fbody.release();
      item->instantiation_counter = 1;
      item->executor_type = ExecutorType(options, attrs);
      if (options.lib_def) {
        item->overlay_flr =
            new FunctionLibraryRuntimeOverlay(this, options.lib_def);
      }
      item->rendezvous_factory = [](const int64, const DeviceMgr* device_mgr,
                                    Rendezvous** r) {
        *r = new IntraProcessRendezvous(device_mgr);
        return Status::OK();
      };
      local_handle = next_handle_++;
      items_->emplace(local_handle, std::unique_ptr<Item>(item));
    }
  }

  if (options.create_kernels_eagerly) {
    Item* item;
    TF_RETURN_IF_ERROR(GetOrCreateItem(local_handle, &item));
  }

  return Status::OK();
}
| |
// Decrements the instantiation count for `handle` and, when it reaches zero,
// removes the item from this runtime and from the parent.
Status FunctionLibraryRuntimeImpl::ReleaseHandle(Handle handle) {
  LocalHandle h = parent_->GetHandleOnDevice(device_name_, handle);
  if (h == kInvalidLocalHandle) {
    // Not instantiated on this device; let the parent runtime release it.
    return parent_->ReleaseHandle(handle);
  }
  std::unique_ptr<Item> item_to_delete;
  Status parent_status;
  {
    mutex_lock l(mu_);
    // Return directly if all items have already been released.
    if (items_ == nullptr) return Status::OK();

    auto it = items_->find(h);
    if (it == items_->end()) {
      return errors::Internal(
          "Inconsistent FunctionLibraryRuntime. Expected to find an item for "
          "handle ",
          h, " but found none");
    }
    std::unique_ptr<Item>& item = it->second;
    --item->instantiation_counter;
    if (item->instantiation_counter == 0) {
      // We don't simply erase h's item because that would trigger
      // item destruction while holding mu_. Item destruction can
      // trigger graph destruction. If the graph contains kernels like
      // CallOp or PartitionCallOp, their destructors will release cached
      // function handles, resulting in deadlock here.
      item_to_delete = std::move(item);
      items_->erase(h);
      parent_status = parent_->RemoveHandle(handle);
    }
  }
  // item_to_delete (if any) is destroyed here, outside the lock.
  return parent_status;
}
| |
| void DumpGraph(StringPiece label, const Graph* g) { |
| // TODO(zhifengc): Change Graph to record #nodes. |
| VLOG(2) << "Graph " << label << " #nodes " << g->num_nodes() << " #edges " |
| << g->num_edges(); |
| if (VLOG_IS_ON(5)) { |
| for (const auto& line : str_util::Split(DebugString(g), '\n')) { |
| VLOG(5) << "|| " << line; |
| } |
| } |
| } |
| |
| void OptimizeGraph(FunctionLibraryRuntime* lib, std::unique_ptr<Graph>* g, |
| const GraphOptimizer::Options& graph_optimizer_options) { |
| OptimizerOptions opts; |
| opts.set_do_common_subexpression_elimination(true); |
| opts.set_do_function_inlining(true); |
| opts.set_do_constant_folding(true); |
| GraphOptimizer optimizer(opts); |
| optimizer.Optimize(lib, lib->env(), lib->device(), g, |
| graph_optimizer_options); |
| } |
| |
| void OptimizeGraph(FunctionLibraryRuntime* lib, std::unique_ptr<Graph>* g) { |
| OptimizeGraph(lib, g, GraphOptimizer::Options()); |
| } |
| |
| namespace { |
| // Removes all stateless nodes that do not contribute to a return |
| // value from the function body. Unlike `RemoveDeadNodes()`, which is |
| // triggered by `OptimizerOptions.do_function_inlining`, this pass |
| // ignores the SINK node, from which (by definition) all nodes are |
| // reverse reachable, and preserves all nodes that are reachable from |
| // control output nodes. |
| // |
| // TODO(ezhulenev, skyewm): Function body should not have special treatment of |
| // stateful ops, graph should encode nodes that must execute with `control_ret` |
| // and `control_output`. |
| void PruneFunctionBody(const FunctionDef& fdef, Graph* g) { |
| VLOG(2) << "Pruning function body: function_name=" << fdef.signature().name(); |
| |
| // `control_ret` nodes must be always executed. |
| std::unordered_set<StringPiece, StringPieceHasher> control_ret_nodes; |
| for (const auto& control_ret : fdef.control_ret()) { |
| control_ret_nodes.insert(control_ret.second); |
| } |
| |
| std::unordered_set<const Node*> nodes; |
| for (auto n : g->nodes()) { |
| // NOTE(mrry): "_Retval" nodes are stateful, and so will be added |
| // to the seed set of `nodes`. "_Arg" nodes are also stateful, but we |
| // specifically exclude them as seeds, to avoid unconditionally executing |
| // unused argument nodes (e.g. in a function like `lambda x, y: y`). |
| // TODO(mrry): Investigate whether the `n->IsControlFlow()` test is |
| // still needed. It would be preferable to prune entire loops and/or |
| // conditionals if they are not used in the graph. |
| if (n->IsControlFlow() || |
| (n->op_def().is_stateful() && n->type_string() != kArgOp) || |
| (control_ret_nodes.find(n->name()) != control_ret_nodes.end())) { |
| nodes.insert(n); |
| } |
| } |
| bool changed = PruneForReverseReachability(g, std::move(nodes)); |
| if (changed) { |
| FixupSourceAndSinkEdges(g); |
| } |
| } |
| } // namespace |
| |
// Builds the executor for `(*item)`'s function body: copies the body
// graph, prunes and optimizes it, checks memory types, constructs an
// Executor, and publishes the graph/executor on the item. Must be called
// without holding mu_ because kernel creation calls back into this
// library.
Status FunctionLibraryRuntimeImpl::CreateItem(Item** item) {
  const FunctionBody* fbody;
  FunctionLibraryRuntime* flr;
  string executor_type;
  {
    // Snapshot the fields we need under a shared lock; everything below
    // must run without holding mu_.
    tf_shared_lock l(mu_);
    fbody = (*item)->func_graph;
    flr = (*item)->overlay_flr
              ? static_cast<FunctionLibraryRuntime*>((*item)->overlay_flr)
              : static_cast<FunctionLibraryRuntime*>(this);
    executor_type = (*item)->executor_type;
  }
  const FunctionLibraryDefinition* lib_def =
      flr->GetFunctionLibraryDefinition();
  std::unique_ptr<Graph> g(new Graph(lib_def));
  CopyGraph(*fbody->graph, g.get());

  PruneFunctionBody(fbody->fdef, g.get());
  optimizer_.Optimize(this, env(), device(), &g, /*shape_map=*/nullptr);
  TF_RETURN_IF_ERROR(EnsureMemoryTypes(DeviceType(device()->device_type()),
                                       device()->name(), g.get()));

  // Creates an executor based on the g. This must be done without
  // holding mu_ because create_kernel_ calls back into the library.
  LocalExecutorParams params;
  params.device = device_;
  params.function_library = flr;
  if (flr == this) {
    params.create_kernel = create_kernel_;
  } else {
    // The item was instantiated against an overlay library, so kernels
    // must be created through the overlay FLR.
    params.create_kernel = [this, flr](const NodeDef& ndef, OpKernel** kernel) {
      return CreateKernel(ndef, flr, kernel);
    };
  }
  params.delete_kernel = [](OpKernel* kernel) {
    DeleteNonCachedKernel(kernel);
  };
  params.rendezvous_factory = (*item)->rendezvous_factory;
  params.session_metadata = session_metadata_;
  std::unique_ptr<Executor> exec;
  TF_RETURN_IF_ERROR(NewExecutor(executor_type, params, *g, &exec));
  {
    // Guard item since it is already inserted in items_.
    mutex_lock l(mu_);
    if ((*item)->exec == nullptr) {
      // First creator wins; a concurrent CreateItem that lost the race
      // simply discards its graph and executor.
      (*item)->graph = std::move(g);
      (*item)->exec = exec.release();
    }
  }
  return Status::OK();
}
| |
| Status FunctionLibraryRuntimeImpl::GetOrCreateItem(LocalHandle local_handle, |
| Item** item) { |
| { |
| tf_shared_lock l(mu_); |
| auto iter = items_->find(local_handle); |
| if (iter == items_->end()) { |
| return errors::Internal("Local function handle ", local_handle, |
| " is not valid. Likely an internal error."); |
| } |
| *item = iter->second.get(); |
| if ((*item)->exec != nullptr) { |
| return Status::OK(); |
| } |
| } |
| // NOTE: We need to call CreateItem out of mu_ because creating an |
| // executor needs to call CreateKernel. |
| return CreateItem(item); |
| } |
| |
| void FunctionLibraryRuntimeImpl::ExecutorArgsFromOptions( |
| const FunctionLibraryRuntime::Options& run_opts, CallFrameInterface* frame, |
| Executor::Args* exec_args) { |
| // Inherit the step_id from the caller. |
| exec_args->step_id = run_opts.step_id; |
| exec_args->rendezvous = run_opts.rendezvous; |
| exec_args->stats_collector = run_opts.stats_collector; |
| exec_args->cancellation_manager = run_opts.cancellation_manager; |
| exec_args->step_container = run_opts.step_container; |
| if (run_opts.runner) { |
| exec_args->runner = *run_opts.runner; |
| } else { |
| exec_args->runner = default_runner_; |
| } |
| exec_args->collective_executor = run_opts.collective_executor; |
| exec_args->call_frame = frame; |
| } |
| |
// Executes a locally-instantiated function on behalf of a caller on a
// different device: receives the arguments from `opts.source_device`
// through the rendezvous, runs the function body here, and sends the
// return values back. `done` is invoked exactly once with the final
// status.
void FunctionLibraryRuntimeImpl::RunRemote(const Options& opts, Handle handle,
                                           gtl::ArraySlice<Tensor> args,
                                           std::vector<Tensor>* rets,
                                           Item* item, DoneCallback done) {
  string target_device = parent_->GetDeviceName(handle);
  string source_device = opts.source_device;
  Rendezvous* rendezvous = opts.rendezvous;
  DeviceContext* device_context;
  Status s = parent_->GetDeviceContext(target_device, &device_context);
  if (!s.ok()) {
    done(s);
    return;
  }
  int64 src_incarnation, target_incarnation;
  s = parent_->GetDeviceIncarnation(source_device, &src_incarnation);
  s.Update(parent_->GetDeviceIncarnation(target_device, &target_incarnation));
  if (!s.ok()) {
    done(s);
    return;
  }

  const FunctionBody* fbody = GetFunctionBody(handle);
  // Heap-allocated because they must outlive this function; ownership is
  // released (delete) inside the async callbacks below on every path.
  FunctionCallFrame* frame =
      new FunctionCallFrame(fbody->arg_types, fbody->ret_types);
  Executor::Args* exec_args = new Executor::Args;
  ExecutorArgsFromOptions(opts, frame, exec_args);

  std::vector<AllocatorAttributes> args_alloc_attrs, rets_alloc_attrs;
  args_alloc_attrs.reserve(fbody->arg_types.size());
  rets_alloc_attrs.reserve(fbody->ret_types.size());
  // Note: Functions assume that int32's are always on host memory.
  for (const auto& arg_type : fbody->arg_types) {
    AllocatorAttributes arg_alloc_attrs;
    if (MTypeFromDType(arg_type) == HOST_MEMORY) {
      arg_alloc_attrs.set_on_host(true);
    }
    args_alloc_attrs.push_back(arg_alloc_attrs);
  }
  for (const auto& ret_type : fbody->ret_types) {
    AllocatorAttributes ret_alloc_attrs;
    if (MTypeFromDType(ret_type) == HOST_MEMORY) {
      ret_alloc_attrs.set_on_host(true);
    }
    rets_alloc_attrs.push_back(ret_alloc_attrs);
  }

  bool allow_dead_tensors = opts.allow_dead_tensors;

  // The ProcFLR sends the arguments to the function from the source_device to
  // the target_device. So here we receive those arguments. Similarly, when the
  // computation is done and stored in *rets, we send the return values back
  // to the source_device (caller) so that the ProcFLR can receive them later.
  std::vector<Tensor>* remote_args = new std::vector<Tensor>;
  ProcessFunctionLibraryRuntime::ReceiveTensorsAsync(
      source_device, target_device, "arg_", src_incarnation, args.size(),
      device_context, args_alloc_attrs, rendezvous, remote_args,
      [frame, remote_args, item, source_device, target_device,
       target_incarnation, rendezvous, device_context, rets, done, exec_args,
       rets_alloc_attrs, allow_dead_tensors](const Status& status) {
        Status s = status;
        if (s.ok()) {
          s = frame->SetArgs(*remote_args);
        }
        if (!s.ok()) {
          delete frame;
          delete remote_args;
          delete exec_args;
          done(s);
          return;
        }
        item->exec->RunAsync(
            *exec_args,
            [frame, rets, done, source_device, target_device,
             target_incarnation, rendezvous, device_context, remote_args,
             rets_alloc_attrs, allow_dead_tensors](const Status& status) {
              Status s = status;
              if (s.ok()) {
                s = frame->ConsumeRetvals(rets, allow_dead_tensors);
              }
              delete frame;
              if (!s.ok()) {
                delete remote_args;
                done(s);
                return;
              }
              // Send the computed return values back to the caller.
              s = ProcessFunctionLibraryRuntime::SendTensors(
                  target_device, source_device, "ret_", target_incarnation,
                  *rets, device_context, rets_alloc_attrs, rendezvous);
              delete remote_args;
              done(s);
            });
        // NOTE(review): `exec_args` is deleted as soon as RunAsync is
        // issued, which assumes the executor does not retain the Args
        // reference after the call returns — confirm against Executor.
        delete exec_args;
      });
}
| |
// Runs an instantiated function with `args`, storing outputs in `*rets`.
// Execution is asynchronous; `done` is invoked exactly once with the
// final status. Handles not local to this device are delegated to the
// parent ProcessFunctionLibraryRuntime.
void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
                                     gtl::ArraySlice<Tensor> args,
                                     std::vector<Tensor>* rets,
                                     DoneCallback done) {
  if (opts.cancellation_manager && opts.cancellation_manager->IsCancelled()) {
    done(errors::Cancelled("Function was cancelled before it was started"));
    return;
  }
  Options run_opts = opts;
  if (opts.create_rendezvous) {
    // Create a per-call rendezvous and unref it once the call completes.
    Rendezvous* rendezvous = new IntraProcessRendezvous(device_mgr_);
    run_opts.rendezvous = rendezvous;
    run_opts.create_rendezvous = false;
    done = [done = std::move(done), rendezvous](const Status& status) {
      rendezvous->Unref();
      done(status);
    };
  }

  LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, handle);
  if (local_handle == kInvalidLocalHandle) {
    // Not instantiated on this device; delegate to the parent PFLR.
    parent_->Run(run_opts, handle, args, rets, done);
    return;
  }

  if (run_opts.runner == nullptr) {
    run_opts.runner = &default_runner_;
  }
  DCHECK(run_opts.runner != nullptr);

  Item* item = nullptr;
  Status s = GetOrCreateItem(local_handle, &item);
  if (!s.ok()) {
    done(s);
    return;
  }

  if (run_opts.remote_execution) {
    // NOTE(mrry): `RunRemote()` will set `exec_args->call_frame` for us.
    RunRemote(run_opts, handle, args, rets, item, std::move(done));
    return;
  }

  const FunctionBody* fbody = GetFunctionBody(handle);
  // Heap-allocated so it outlives this scope; deleted in the RunAsync
  // done callback below on every path.
  FunctionCallFrame* frame =
      new FunctionCallFrame(fbody->arg_types, fbody->ret_types);
  s = frame->SetArgs(args);
  if (!s.ok()) {
    delete frame;
    done(s);
    return;
  }

  Executor::Args exec_args;
  ExecutorArgsFromOptions(run_opts, frame, &exec_args);

  bool allow_dead_tensors = run_opts.allow_dead_tensors;
  item->exec->RunAsync(
      // Executor args
      exec_args,
      // Done callback.
      [frame, rets, done, allow_dead_tensors](const Status& status) {
        Status s = status;
        if (s.ok()) {
          s = frame->ConsumeRetvals(rets, allow_dead_tensors);
        }
        delete frame;
        done(s);
      });
}
| |
// Runs an instantiated function using a caller-provided call frame.
// Remote execution is not supported through this interface; `done` is
// invoked exactly once with the final status.
void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
                                     CallFrameInterface* frame,
                                     DoneCallback done) {
  if (opts.cancellation_manager && opts.cancellation_manager->IsCancelled()) {
    done(errors::Cancelled(""));
    return;
  }

  Options run_opts = opts;
  if (opts.create_rendezvous) {
    // Create a per-call rendezvous and unref it once the call completes.
    Rendezvous* rendezvous = new IntraProcessRendezvous(device_mgr_);
    run_opts.rendezvous = rendezvous;
    run_opts.create_rendezvous = false;
    done = [done = std::move(done), rendezvous](const Status& status) {
      rendezvous->Unref();
      done(status);
    };
  }

  LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, handle);
  if (local_handle == kInvalidLocalHandle) {
    // Not instantiated on this device; delegate to the parent PFLR.
    parent_->Run(run_opts, handle, frame, done);
    return;
  }

  if (opts.remote_execution) {
    // NOTE(mrry): This bit is only set for a local function when `parent_`
    // calls back into this class, and the current implementation of
    // `ProcessFunctionLibraryRuntime` currently always uses the vector-based
    // `args`/`rets` interface.
    done(errors::Unimplemented("Remote calling with CallFrameInterface"));
    return;
  }

  Item* item = nullptr;
  Status s = GetOrCreateItem(local_handle, &item);
  if (!s.ok()) {
    done(s);
    return;
  }
  if (run_opts.runner == nullptr) {
    run_opts.runner = &default_runner_;
  }
  DCHECK(run_opts.runner != nullptr);

  Executor::Args exec_args;
  ExecutorArgsFromOptions(run_opts, frame, &exec_args);
  item->exec->RunAsync(exec_args, std::move(done));
}
| |
| bool FunctionLibraryRuntimeImpl::IsStateful(const string& func) const { |
| const OpDef* op_def; |
| const Status s = base_lib_def_->LookUpOpDef(func, &op_def); |
| return s.ok() && op_def->is_stateful(); |
| } |
| |
| string FunctionLibraryRuntimeImpl::DebugString(Handle handle) { |
| Item* item = nullptr; |
| LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, handle); |
| Status s = GetOrCreateItem(local_handle, &item); |
| if (s.ok()) { |
| if (item->graph) { |
| return tensorflow::DebugString(item->graph.get()); |
| } else { |
| return tensorflow::DebugString(item->func_graph->graph); |
| } |
| } else { |
| return s.ToString(); |
| } |
| } |
| |
| Status FunctionLibraryRuntimeImpl::Clone( |
| std::unique_ptr<FunctionLibraryDefinition>* out_lib_def, |
| std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr, |
| FunctionLibraryRuntime** out_flr, bool skip_flib_def) { |
| TF_RETURN_IF_ERROR(parent_->Clone( |
| env_, graph_def_version_, optimizer_.options(), custom_kernel_creator_, |
| out_lib_def, out_pflr, skip_flib_def)); |
| *out_flr = (*out_pflr)->GetFLR(device_->name()); |
| if (*out_flr != nullptr) { |
| return Status::OK(); |
| } else { |
| return errors::Internal("Cloning FunctionLibraryRuntime failed."); |
| } |
| } |
| |
| namespace { |
| |
| struct CustomCreatorSingleton { |
| mutex mu; |
| CustomKernelCreator* custom_creator = nullptr; |
| |
| void Set(CustomKernelCreator* cb) { |
| mutex_lock l(mu); |
| custom_creator = cb; |
| } |
| |
| CustomKernelCreator* Get() { |
| mutex_lock l(mu); |
| return custom_creator; |
| } |
| }; |
| |
| CustomCreatorSingleton* GetCustomCreatorSingleton() { |
| static CustomCreatorSingleton* ccs = new CustomCreatorSingleton; |
| return ccs; |
| } |
| |
| } // namespace |
| |
// Returns the process-wide custom kernel creator installed via
// RegisterDefaultCustomKernelCreator, or nullptr if none was registered.
const CustomKernelCreator* GetDefaultCustomKernelCreator() {
  return GetCustomCreatorSingleton()->Get();
}
| |
// Installs `c` as the process-wide default custom kernel creator. Does
// not take ownership; the caller must keep `c` alive.
void RegisterDefaultCustomKernelCreator(CustomKernelCreator* c) {
  GetCustomCreatorSingleton()->Set(c);
}
| |
| std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime( |
| const DeviceMgr* device_mgr, Env* env, const ConfigProto* config, |
| Device* device, int graph_def_version, |
| const FunctionLibraryDefinition* lib_def, thread::ThreadPool* thread_pool, |
| const OptimizerOptions& optimizer_options, |
| const CustomKernelCreator* custom_kernel_creator, |
| const SessionMetadata* session_metadata, |
| ProcessFunctionLibraryRuntime* parent) { |
| return std::unique_ptr<FunctionLibraryRuntime>(new FunctionLibraryRuntimeImpl( |
| device_mgr, env, config, device, graph_def_version, lib_def, thread_pool, |
| optimizer_options, custom_kernel_creator, session_metadata, parent)); |
| } |
| |
| bool RemoveDeadNodes(Graph* g) { |
| VLOG(2) << "Removing dead nodes"; |
| std::unordered_set<const Node*> nodes; |
| for (auto n : g->nodes()) { |
| if (n->IsSource() || n->IsSink() || n->IsControlFlow() || |
| n->op_def().is_stateful()) { |
| nodes.insert(n); |
| } |
| } |
| return PruneForReverseReachability(g, std::move(nodes)); |
| } |
| |
| namespace { |
| // If 'edges' contains only 1 non-control edge, returns it. Otherwise, |
| // returns a nullptr. |
| const Edge* GetTheOnlyDataEdge(const EdgeSet& edges) { |
| const Edge* ret = nullptr; |
| for (const Edge* e : edges) { |
| if (e->IsControlEdge() || ret) { |
| // Don't touch it if there is a control edge. |
| return nullptr; |
| } |
| if (IsRefType(e->src()->output_type(e->src_output()))) { |
| // Don't touch it if the identity node is effectively de-reffing |
| // a ref. |
| return nullptr; |
| } |
| if (IsRecv(e->src()) || IsSwitch(e->src())) { |
| // Don't touch it if the identity is introduced for control flow. |
| // Recv disables all its successors if it receives a dead signal. |
| // When Recv has an outgoing control edge, the current executor |
| // would not disable the destination. The current solution (see |
| // graph_partition.cc) is to add an identity after Recv and change |
| // the control edge to be from this identity node. So the identity |
| // can't be removed. |
| return nullptr; |
| } |
| ret = e; |
| } |
| return ret; |
| } |
| } // end namespace |
| |
| bool RemoveIdentityNodes(Graph* g) { |
| VLOG(2) << "Removing identity nodes"; |
| bool removed_any = false; |
| gtl::InlinedVector<Node*, 8> matches; |
| for (Node* n : g->nodes()) { |
| if (!n->IsIdentity()) continue; |
| if (!GetTheOnlyDataEdge(n->in_edges())) continue; |
| |
| // Some identity nodes are used as sink nodes to give names to output |
| // tensors. These nodes are not going to be executed unless they are in the |
| // fetch set. But if they are in the fetch set we don't want to remove them. |
| if (n->out_edges().empty()) continue; |
| |
| matches.push_back(n); |
| } |
| if (!matches.empty()) { |
| for (Node* n : matches) { |
| const Edge* in = GetTheOnlyDataEdge(n->in_edges()); |
| for (const Edge* out : n->out_edges()) { |
| if (out->IsControlEdge()) { |
| g->AddControlEdge(in->src(), out->dst()); |
| } else { |
| g->AddEdge(in->src(), in->src_output(), out->dst(), out->dst_input()); |
| } |
| } |
| VLOG(2) << "Remove Identity: " << n->DebugString(); |
| g->RemoveNode(n); |
| removed_any = true; |
| } |
| } |
| return removed_any; |
| } |
| |
// Rewrites each _ListToArray/_ArrayToList node into one Identity node
// per element (plus NoOp nodes that preserve incoming/outgoing control
// edges), then deletes the converter node. Returns true if any node was
// removed.
bool RemoveListArrayConverter(Graph* g) {
  VLOG(2) << "Removing list array converter";
  gtl::InlinedVector<Node*, 8> matches;
  for (Node* n : g->nodes()) {
    if ((n->type_string() == "_ListToArray") ||
        (n->type_string() == "_ArrayToList")) {
      matches.push_back(n);
    }
  }
  bool removed_any = false;
  if (!matches.empty()) {
    for (Node* n : matches) {
      if (n->num_inputs() != n->num_outputs()) {
        continue;  // Not expected. Skip.
      }
      // identity_nodes[i] replaces the converter's i-th input/output pair.
      gtl::InlinedVector<Node*, 8> identity_nodes(n->num_inputs(), nullptr);

      const auto no_op = [&](StringPiece name) -> Node* {
        return AddNoOp(absl::StrCat(n->name(), "/", name), g);
      };

      // The identity inherits the producing node's device so the data
      // transfer pattern is unchanged.
      const auto identity = [&](StringPiece name, Endpoint input) -> Node* {
        Node* node = AddIdentity(absl::StrCat(n->name(), "/", name), g, input);
        node->set_requested_device(input.node->def().device());
        return node;
      };

      // Process input edges first.
      Node* input_control_node = nullptr;
      for (const Edge* e : n->in_edges()) {
        if (e->IsControlEdge()) {
          if (input_control_node == nullptr) {
            // If node "n" has any control dependencies, adds a no-op
            // node (input_control_node) which the additional Identity
            // nodes depends on and the input_control_node depends on
            // the node "n"s control dependencies.
            input_control_node = no_op("input_control_node");
          }
          g->AddControlEdge(e->src(), input_control_node);
        } else {
          const int index = e->dst_input();
          Node** id_node = &identity_nodes[index];
          if (*id_node != nullptr) {
            LOG(ERROR)
                << "RemoveListArrayConverter unexpected duplicated input: "
                << e->dst_input();
            return removed_any;
          }
          *id_node = identity("input", {e->src(), e->src_output()});
        }
      }

      // If node "n" has any control dependencies, the added identity
      // nodes should have control dependencies on input_control_node.
      if (input_control_node != nullptr) {
        for (Node* id : identity_nodes) {
          g->AddControlEdge(input_control_node, id);
        }
      }

      Node* output_control_node = nullptr;
      for (const Edge* e : n->out_edges()) {
        if (e->IsControlEdge()) {
          if (output_control_node == nullptr) {
            // If node "n" is control-depended upon by other nodes,
            // adds a no-op node (output_control_node) which those
            // nodes will depend on and output_control_node depends on
            // all Identity nodes.
            output_control_node = no_op("output_control_node");
          }
          g->AddControlEdge(output_control_node, e->dst());
        } else {
          Node* id_node = identity_nodes[e->src_output()];
          if (id_node == nullptr) {
            LOG(ERROR) << "RemoveListArrayConverter unexpected missing input: "
                       << e->src_output();
            return removed_any;
          }
          CHECK(id_node);
          g->AddEdge(id_node, 0, e->dst(), e->dst_input());
        }
      }

      // If any nodes have control dependencies on node "n", those
      // nodes should have control dependencies on
      // output_control_node.
      if (output_control_node != nullptr) {
        for (Node* id : identity_nodes) {
          g->AddControlEdge(id, output_control_node);
        }
      }

      g->RemoveNode(n);
      removed_any = true;
    }
  }
  return removed_any;
}
| |
| Status NameAndAttrsFromFunctionCall(const NodeDef& call_def, |
| NameAttrList* function) { |
| if (call_def.op() == "PartitionedCall" || |
| call_def.op() == "StatefulPartitionedCall") { |
| TF_RETURN_IF_ERROR(GetNodeAttr(call_def, "f", function)); |
| } else { |
| function->set_name(call_def.op()); |
| *function->mutable_attr() = call_def.attr(); |
| } |
| return Status::OK(); |
| } |
| |
| Status InstantiateFunctionCall(const NodeDef& call_def, |
| FunctionLibraryRuntime* flr, |
| FunctionLibraryRuntime::Handle* handle) { |
| NameAttrList function; |
| TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(call_def, &function)); |
| return flr->Instantiate(function.name(), AttrSlice(&function.attr()), handle); |
| } |
| |
| namespace { |
| |
| std::vector<string> InputDevices(const Node& caller) { |
| std::vector<string> input_devices(caller.in_edges().size()); |
| std::vector<string> input_tensors(caller.in_edges().size()); |
| |
| for (const Edge* edge : caller.in_edges()) { |
| if (edge->IsControlEdge()) continue; |
| const string& input_device = edge->src()->has_assigned_device_name() |
| ? edge->src()->assigned_device_name() |
| : edge->src()->requested_device(); |
| input_devices[edge->dst_input()] = input_device; |
| input_tensors[edge->dst_input()] = |
| absl::StrCat(edge->src()->name(), ":", edge->src_output()); |
| } |
| |
| if (VLOG_IS_ON(4)) { |
| VLOG(4) << "Function instantiation input devices:"; |
| for (int i = 0; i < input_devices.size(); ++i) { |
| if (input_tensors[i].empty()) continue; // skip control edges |
| VLOG(4) << " [index " << i << "]" |
| << " device: " << input_devices[i] |
| << " (input: " << input_tensors[i] << ")"; |
| } |
| } |
| |
| return input_devices; |
| } |
| |
// Place input nodes on the same device as the corresponding caller input
// node. Do not specify any placement for all other nodes.
class DefaultFunctionBodyPlacer : public InlinedFunctionBodyPlacer {
 public:
  explicit DefaultFunctionBodyPlacer(const Node& caller)
      : input_devices_(InputDevices(caller)) {}

  absl::optional<string> InputNodeDevice(int input_index) const override {
    return input_devices_[input_index];
  }
  absl::optional<string> OutputNodeDevice(int output_index) const override {
    return absl::nullopt;
  }
  bool ColocateOutputIdentity() const override { return false; }
  absl::optional<string> ControlNodeDevice() const override {
    return absl::nullopt;
  }
  absl::optional<string> BodyNodeDevice(const NodeDef& ndef) const override {
    return absl::nullopt;
  }

 private:
  // Caller input devices, indexed by input position.
  const std::vector<string> input_devices_;
};
| |
// Place all nodes (inputs, outputs, control and body nodes) on the same
// device as the caller node.
class SingleDeviceFunctionBodyPlacer : public InlinedFunctionBodyPlacer {
 public:
  explicit SingleDeviceFunctionBodyPlacer(const Node& caller)
      : caller_device_(caller.def().device()) {}

  absl::optional<string> InputNodeDevice(int input_index) const override {
    return caller_device_;
  }
  absl::optional<string> OutputNodeDevice(int output_index) const override {
    return caller_device_;
  }
  bool ColocateOutputIdentity() const override { return false; }
  absl::optional<string> ControlNodeDevice() const override {
    return caller_device_;
  }
  absl::optional<string> BodyNodeDevice(const NodeDef& ndef) const override {
    return caller_device_;
  }

 private:
  // The caller node's device string, applied to every placed node.
  const string caller_device_;
};
| |
| // Place input nodes on the same device as the correspinding caller input |
| // node. Do not place output node. Place control nodes on the same device as |
| // caller node. For all function body nodes overrides job, replica and task |
| // parts of the device assignment to match function caller node. |
| class MultiDeviceFunctionBodyPlacer : public InlinedFunctionBodyPlacer { |
| public: |
| explicit MultiDeviceFunctionBodyPlacer(const Node& caller) |
| : caller_device_(caller.def().device()), |
| input_devices_(InputDevices(caller)) { |
| has_parsed_caller_device_ = |
| DeviceNameUtils::ParseFullName(caller_device_, &caller_parsed_device_); |
| } |
| |
| absl::optional<string> InputNodeDevice(int input_index) const override { |
| return input_devices_[input_index]; |
| } |
| absl::optional<string> OutputNodeDevice(int output_index) const override { |
| return absl::nullopt; |
| } |
| bool ColocateOutputIdentity() const override { return true; } |
| absl::optional<string> ControlNodeDevice() const override { |
| return caller_device_; |
| } |
| absl::optional<string> BodyNodeDevice(const NodeDef& ndef) const override { |
| // TODO(ezhulenev): If function would have been instantiated as a |
| // multi-device function and executed via FunctionLibraryRuntime, it could |
| // be potentially placed on any available device. However there are multiple |
| // tests relying on this assumption. Fix them, and remove this line. |
| if (ndef.device().empty()) return caller_device_; |
| |
| if (!has_parsed_caller_device_) return ndef.device(); |
| |
| DeviceNameUtils::ParsedName ndef_parsed_device; |
| if (!DeviceNameUtils::ParseFullName(ndef.device(), &ndef_parsed_device)) |
| return ndef.device(); |
| |
| if (caller_parsed_device_.has_job) { |
| ndef_parsed_device.has_job = caller_parsed_device_.has_job; |
| ndef_parsed_device.job = caller_parsed_device_.job; |
| } |
| |
| if (caller_parsed_device_.has_replica) { |
| ndef_parsed_device.has_replica = caller_parsed_device_.has_replica; |
| ndef_parsed_device.replica = caller_parsed_device_.replica; |
| } |
| |
| if (caller_parsed_device_.has_task) { |
| ndef_parsed_device.has_task = caller_parsed_device_.has_task; |
| ndef_parsed_device.task = caller_parsed_device_.task; |
| } |
| return DeviceNameUtils::ParsedNameToString(ndef_parsed_device); |
| } |
| |
| private: |
| string caller_device_; |
| bool has_parsed_caller_device_; |
| DeviceNameUtils::ParsedName caller_parsed_device_; |
| std::vector<string> input_devices_; |
| }; |
| |
| } // namespace |
| |
// Creates the default placement strategy: body/control/output nodes stay
// unplaced; inputs follow the caller's input devices.
std::unique_ptr<InlinedFunctionBodyPlacer>
InlinedFunctionBodyPlacer::DefaultPlacer(const Graph& graph,
                                         const Node& caller) {
  VLOG(3) << "Create default placer for inlined function body.";
  // `graph` is unused by this strategy; only the caller's inputs matter.
  return absl::make_unique<DefaultFunctionBodyPlacer>(caller);
}
| |
// Creates the single-device strategy: every node is placed on the
// caller's device.
std::unique_ptr<InlinedFunctionBodyPlacer>
InlinedFunctionBodyPlacer::SingleDevicePlacer(const Graph& graph,
                                              const Node& caller) {
  VLOG(3) << "Create single device placer for inlined function body.";
  // `graph` is unused by this strategy; only the caller matters.
  return absl::make_unique<SingleDeviceFunctionBodyPlacer>(caller);
}
| |
// Creates the multi-device strategy: inputs follow caller inputs,
// control nodes follow the caller, and body nodes keep their device with
// the caller's job/replica/task substituted in.
std::unique_ptr<InlinedFunctionBodyPlacer>
InlinedFunctionBodyPlacer::MultiDevicePlacer(const Graph& graph,
                                             const Node& caller) {
  VLOG(3) << "Create multi device placer for inlined function body.";
  // `graph` is unused by this strategy; only the caller matters.
  return absl::make_unique<MultiDeviceFunctionBodyPlacer>(caller);
}
| |
| namespace { |
| |
| Status ValidateNoInline(const FunctionBody* fbody) { |
| const auto attr = AttrSlice(&fbody->fdef.attr()); |
| bool noinline = false; |
| if (TryGetNodeAttr(attr, kNoInlineAttr, &noinline) && noinline) { |
| return errors::InvalidArgument( |
| "Can't inline function marked with '_noinline'"); |
| } |
| return Status::OK(); |
| } |
| |
| using OutputControlSrc = InlineFunctionBodyOptions::OutputControlSource; |
| |
| // Propagate the debug info of `nodes` in function `func` to the `target` node. |
| // If the debug info of any node is missing, its node name and function name |
| // is used. |
| void PropagateDebugInfoToNode(const string& func, |
| const std::vector<const Node*>& nodes, |
| NodeDef* target) { |
| if (nodes.empty() || target->has_experimental_debug_info()) { |
| return; |
| } |
| for (const Node* node : nodes) { |
| const auto& node_def = node->def(); |
| if (node_def.has_experimental_debug_info()) { |
| target->mutable_experimental_debug_info()->MergeFrom( |
| node_def.experimental_debug_info()); |
| } else { |
| target->mutable_experimental_debug_info()->add_original_node_names( |
| node_def.name()); |
| target->mutable_experimental_debug_info()->add_original_func_names(func); |
| } |
| } |
| } |
| } // namespace |
| |
| string InlineFunctionBodyOptions::DebugString() const { |
| const auto true_false = [](bool b) { return b ? "true" : "false"; }; |
| |
| const auto keep_caller_node_str = [this]() -> string { |
| switch (keep_caller_node) { |
| case KeepCallerNode::kDoNotKeep: |
| return "DoNotKeep"; |
| case KeepCallerNode::kFetchable: |
| return "Fetchable"; |
| case KeepCallerNode::kTargetable: |
| return "Targetable"; |
| } |
| }; |
| |
| return absl::StrCat( |
| "disable_inlining=", true_false(disable_inlining), |
| ", ignore_noinline=", true_false(ignore_noinline), |
| ", inline_impl_selection_group_functions=", |
| true_false(inline_impl_selection_group_functions), |
| ", keep_caller_node=", keep_caller_node_str(), ", output_control_src=", |
| output_control_src == OutputControlSrc::kDataOutputs ? "DataOutputs" |
| : "ControlOutputs", |
| ", inlined_function_body_placer=", inlined_function_body_placer.name, |
| ", uniquify_frame_names=", true_false(uniquify_frame_names)); |
| } |
| |
| Status ValidateInlining(const Node* node, const FunctionBody* fbody, |
| const InlineFunctionBodyOptions& options) { |
| // TODO(ezhulenev): Currently common_runtime function inlining can't guarantee |
| // that all side-effectful ops will be executed after inlining. See Grappler |
| // function_optimizer for details. Unify all function inlining mechanism. |
| // Do not inline if `!fbody->control_ret_nodes.empty()`. |
| |
| const auto num_node_inputs = static_cast<size_t>(node->num_inputs()); |
| const auto num_node_outputs = static_cast<size_t>(node->num_outputs()); |
| |
| if (num_node_inputs != fbody->arg_types.size() || |
| num_node_inputs != fbody->arg_nodes.size()) { |
| return errors::InvalidArgument( |
| "Node inputs do not match function arguments: inputs=", num_node_inputs, |
| " arg_types=", fbody->arg_types.size(), |
| " arg_nodes=", fbody->arg_nodes.size()); |
| } |
| |
| if (num_node_outputs != fbody->ret_types.size() || |
| num_node_outputs != fbody->ret_nodes.size()) { |
| return errors::InvalidArgument( |
| "Node outputs do not match function returns: outputs=", |
| num_node_outputs, " ret_types=", fbody->ret_types.size(), |
| " ret_nodes=", fbody->ret_nodes.size()); |
| } |
| |
| for (int i = 0; i < node->num_inputs(); ++i) { |
| if (node->input_type(i) != fbody->arg_types[i]) { |
| return errors::InvalidArgument( |
| "Node input type doesn't match function argument type: ", |
| node->input_type(i), " != ", fbody->arg_types[i], " @ index=", i); |
| } |
| } |
| for (int i = 0; i < node->num_outputs(); ++i) { |
| if (node->output_type(i) != fbody->ret_types[i]) { |
| return errors::InvalidArgument( |
| "Node output type doesn't match function return type: ", |
| node->output_type(i), " != ", fbody->ret_types[i], " @ index=", i); |
| } |
| } |
| |
| if (options.disable_inlining) { |
| return errors::InvalidArgument( |
| "Function inlining explicitly disabled by 'options.disable_inlining'"); |
| } |
| |
| if (!options.inline_impl_selection_group_functions) { |
| bool is_impl_selection_group_function = |
| fbody->fdef.attr().find("api_implements") != fbody->fdef.attr().end(); |
| if (is_impl_selection_group_function) { |
| return errors::InvalidArgument( |
| "Inlining of implementation selection group function ", |
| fbody->fdef.signature().name(), |
| " is disabled by options.inline_impl_selection_group_functions"); |
| } |
| } |
| |
| if (!options.ignore_noinline) { |
| TF_RETURN_IF_ERROR(ValidateNoInline(fbody)); |
| } |
| |
| return Status::OK(); |
| } |
| |
| // Function inlining must preserve function execution semantics with regards to |
| // side-effects visibility. Tensorflow in Eager mode has an automatic control |
| // dependencies tracking mechanism, which enforces well-defined execution order |
| // of all side-effects. Any other frontend (e.g. Swift) must produce graphs |
| // following the same rules, to ensure that function inlining works correctly. |
| // |
| // IMPORTANT: Currently we do not have a true notion of "side-effectful" node, |
| // we assume that all stateful nodes might have side-effects, though it's not |
| // true in practice, e.g. `ReadVariableOp` doesn't have an observable |
| // side-effect. |
| // |
| // Automatic control dependency rules in Tensorflow 2.0 (python in eager mode): |
| // |
| // 1) When a function has a resource (DT_RESOURCE data type) input argument it |
| // "captures" the mutable resource. This is implemented by automatically |
//    adding an incoming control edge from the previous side-effectful op
| // touching that resource, and an outgoing control edge to the next |
| // side-effectful op using the same resource. This serializes the mutations |
| // of the resource to make graph execution deterministic. |
| // |
| // 2) All stateful ops inside a function body are guaranteed to execute in |
| // program order, this is achieved by adding control edges between stateful |
| // ops at graph construction time. Stateful ops (or ops that must execute) |
| // should be in the function control return set. Having a data edge to the |
| // regular function output might be not enough, because after function |
| // inlining it might happen that data output is unused. |
| // |
| // 3) Furthermore, all ops accepting the same resource as an input are |
| // guaranteed to run in program order. This is also done by adding control |
| // edges at graph construction time. The last op touching the resource |
| // must be in a control return set, which will guarantee that all side |
| // effects to the resource will happen before function completion. |
| // |
| // Function inlining must preserve side-effect visibility: |
| // |
// 1) All side-effects to the captured resources that happened before the
//    function call must be visible to the function body nodes using those
//    resources.
| // |
// 2) All side-effects to the captured resources that happened inside the
//    function body must be visible to every op/function using that resource
//    after the function call has completed.
| // |
| // To guarantee that these properties are preserved after inlining we: |
| // |
| // 1) Create "input_control_node" NoOp. Function call node incoming control |
| // edges will be forwarded *to* this node. Function inputs (Identity nodes) |
| // will have a control edge *from* this node. If function body has nodes |
| // without inputs, they will have a control edge *from* this node. |
| // |
| // 2) Create "output_control_node" NoOp. All nodes that have incoming control |
| // edge *from* the function call node, will be forwarded to this node. |
| // |
| // We have two options for choosing which nodes will have a control edge *to* |
| // the "output control node": |
| // a) control returns (`control_ret` field in FunctionDef) |
| // b) data returns (`ret` field in FunctionDef) |
| // |
| // We do a) for multi-device function calls in Tensorflow v2 and b) |
| // for the rest for compatibility with Tensorflow v1. |
| // |
| // Following the automatic control dependencies tracking rules, a node that |
| // has an incoming control edge from the function call node is dependent on |
| // the side-effects happening inside the function body. The output control |
| // node will guarantee side-effects execution order. |
| // |
| // If function call node doesn't have an outgoing control edge, it means that |
| // no one is interested in observing side-effects that might have happened. |
| // |
| // Function inlining might leave the graph in partially-placed state. Function |
| // inlining caller must call Placer to guarantee that all nodes are placed. |
| // |
| // Function inlining with `options.override_device=true` will leave graph in |
| // fully placed state, by overriding all inlined nodes devices with the caller |
| // node device, but it will make functions always single-device. These functions |
| // after inlining will not be able to handle resources on multiple devices. This |
| // is currently acceptable for XLA use cases (XLA cluster is always executed on |
| // a single device). |
| // |
| // TODO(ezhulenev): Documentation above is ahead of implementation below. |
| Status InlineFunctionBody(const FunctionLibraryDefinition& flib_def, Graph* g, |
| Node* caller, const FunctionBody* fbody, |
| const InlineFunctionBodyOptions& options) { |
| VLOG(3) << "Inline function call: " << SummarizeNode(*caller) << " [" |
| << options.DebugString() << "]"; |
| |
| Status validation = ValidateInlining(caller, fbody, options); |
| if (!validation.ok()) { |
| return errors::Internal("Inlining mismatch: ", validation.error_message()); |
| } |
| |
| // Placer is responsible for assigning devices for all nodes that we will add |
| // to the graph. |
| const std::unique_ptr<InlinedFunctionBodyPlacer> placer = |
| options.inlined_function_body_placer.get(*g, *caller); |
| |
| // We can't possibly introduce a duplicate control edge during function |
| // inlining, so we skip this check in calls to the 'g->AddControlEdge(...)'. |
| static constexpr bool kDoNotCheckDuplicates = true; |
| |
| // ------------------------------------------------------------------------ // |
| // Helper functions to create `NoOp` and `Identity` nodes for auxiliary |
| // control nodes and inlined function inputs and outputs. |
| |
| // Add a NoOp node for function control inputs/outputs. |
| const auto no_op = [&](StringPiece name) -> Node* { |
| Node* node = AddNoOp(absl::StrCat(caller->name(), "/", name), g); |
| const absl::optional<string> device = placer->ControlNodeDevice(); |
| if (device.has_value()) node->set_requested_device(*device); |
| return node; |
| }; |
| |
| // Add an Identity node for function input. |
| const auto input_identity = [&](StringPiece name, Endpoint input, |
| int index) -> Node* { |
| Node* node = AddIdentity(absl::StrCat(caller->name(), "/", name), g, input); |
| const absl::optional<string> device = placer->InputNodeDevice(index); |
| if (device.has_value()) node->set_requested_device(*device); |
| return node; |
| }; |
| |
| // Add an Identity node for function output. |
| const auto output_identity = [&](StringPiece name, Endpoint input, |
| int index) -> Node* { |
| Node* node = AddIdentity(absl::StrCat(caller->name(), "/", name), g, input); |
| const absl::optional<string> device = placer->OutputNodeDevice(index); |
| if (device.has_value()) node->set_requested_device(*device); |
| bool colocate_identity = placer->ColocateOutputIdentity(); |
| if (colocate_identity) { |
| node->AddAttr(kColocationAttrName, |
| std::vector<string>{absl::StrCat(kColocationGroupPrefix, |
| input.node->name())}); |
| } |
| return node; |
| }; |
| |
| // ------------------------------------------------------------------------ // |
| // Input edges. For data edges coming into "caller", we first compute the |
| // <src>:<src_output> for the i-th input in "inputs". |
| // If "caller" has any input control dependencies, we add a NoOp |
| // node "input_control_node", which depends on "caller"'s control inputs. |
| std::vector<Endpoint> inputs(caller->num_inputs()); |
| Node* input_control_node = nullptr; |
| for (const Edge* e : caller->in_edges()) { |
| if (e->IsControlEdge()) { |
| if (input_control_node == nullptr) { |
| input_control_node = no_op("input_control_node"); |
| } |
| g->AddControlEdge(e->src(), input_control_node, kDoNotCheckDuplicates); |
| } else { |
| inputs[e->dst_input()] = {e->src(), e->src_output()}; |
| } |
| } |
| if (input_control_node != nullptr) { |
| VLOG(3) << "Created input control node: " << input_control_node->name(); |
| } |
| |
| // ------------------------------------------------------------------------ // |
| // Duplicate fbody->graph into 'g'. First, we copy the nodes of |
| // fbody->graph into 'g' except the source and sink nodes. We copy |
| // edges among nodes in 'fbody->graph'. |
| // |
| // If 'x' is a node in fbody->graph and its copy in 'g' is 'y', we |
| // remember 'y' in node_map[x->id()]. |
| std::vector<Node*> node_map(fbody->graph->num_node_ids()); |
| for (Node* n : fbody->graph->op_nodes()) { |
| NodeDef ndef = n->def(); |
| |
| // Maybe override requested node device assignment. |
| const absl::optional<string> device = placer->BodyNodeDevice(ndef); |
| if (device.has_value()) ndef.set_device(*device); |
| |
| // Add inlined function name to inlined node debug information. |
| PropagateDebugInfoToNode(fbody->fdef.signature().name(), {n}, &ndef); |
| |
| // Add the function node name as a prefix: |
| // 1) to node name to avoid collisions |
| // 2) to frame name to avoid multiple LoopCond nodes in one frame |
| // 3) to colocation attribute |
| const string prefix = strings::StrCat(caller->name(), "/"); |
| TF_RETURN_IF_ERROR(AddPrefixAndSuffixToNode(prefix, /*suffix=*/"", &ndef, |
| options.uniquify_frame_names)); |
| |
| Status added_node; |
| Node* clone = g->AddNode(ndef, &added_node); |
| TF_CHECK_OK(added_node); |
| node_map[n->id()] = clone; |
| |
| // If there is an input control node, and one of: |
| // a) the node has no data or control inputs, or |
| // b) the node is a function call (including SymbolicGradient), |
| // then add a control edge from the input control node to the clone (only |
| // if it does not already have a control input). |
| // |
| // We must not execute any nodes if the original function call would not |
| // have executed. This is especially critical when the function call is |
| // inside a control-flow construct like tf.cond(). Case (a) ensures that |
| // such nodes do not run. |
| // |
| // The purpose of case (b) is to ensure that instances of case (a) created |
| // by further inlining steps also receive the control dependency. |
| // |
| // This edge is required to transfer execution frame down to all function |
| // body nodes of inlined nested function calls. |
| if (input_control_node) { |
| const auto is_input_edge = [](const Edge* e) -> bool { |
| return !e->src()->IsSource(); |
| }; |
| const auto is_control_edge = [](const Edge* e) -> bool { |
| return !e->src()->IsSource() && e->IsControlEdge(); |
| }; |
| |
| // Forward execution frame if: |
| // |
| // a) The node has no data or control inputs. |
| // b) OR the node is a function call without control inputs (control edge |
| // will be used in nested function inlining to forward execution frame |
| // to constants inside the function body). |
| // |
| // c) Do not forward control frame to function argument nodes, they will |
| // be connected to the corresponding function input later. |
| const bool forward_execution_frame = |
| (absl::c_none_of(n->in_edges(), is_input_edge) || // (a) |
| (n->IsFunctionCall() && // (b) |
| absl::c_none_of(n->in_edges(), is_control_edge))) && // |
| !n->IsArg(); // (c) |
| |
| if (forward_execution_frame) { |
| VLOG(4) << "Add control edge from input control node to: " |
| << clone->name(); |
| g->AddControlEdge(input_control_node, clone, kDoNotCheckDuplicates); |
| } |
| } |
| } |
| for (const Edge* e : fbody->graph->edges()) { |
| if (e->src()->IsSource() || e->src()->IsSink() || e->dst()->IsSource() || |
| e->dst()->IsSink()) { |
| continue; |
| } |
| Node* src_copy = node_map[e->src()->id()]; |
| Node* dst_copy = node_map[e->dst()->id()]; |
| g->AddEdge(src_copy, e->src_output(), dst_copy, e->dst_input()); |
| } |
| |
| // ------------------------------------------------------------------------ // |
| // Connect input edges. |
| // |
| // We create one Identity node for each input. Then, we connect inputs[i] to |
| // the i-th identity node added. The nodes that previously connected |
| // to the j-th output of i-th arg node are reconnected to the i-th |
| // identity node. |
| // |
| // The added identity nodes depend on "input_control_node". |
| VLOG(4) << "Add input Identity nodes for each function argument:"; |
| for (std::size_t i = 0; i < fbody->arg_nodes.size(); ++i) { |
| Node* arg = node_map[fbody->arg_nodes[i]->id()]; |
| Node* n = input_identity("input", inputs[i], i); |
| VLOG(4) << " [index " << i << "] " |
| << fbody->fdef.signature().input_arg(i).name() << " as " |
| << n->name() << " (input: " << inputs[i].name() |
| << ", requested_device: " << n->requested_device() << ")"; |
| |
| if (input_control_node) { |
| g->AddControlEdge(input_control_node, n, kDoNotCheckDuplicates); |
| } |
| for (const Edge* e : arg->out_edges()) { |
| if (e->IsControlEdge()) { |
| g->AddControlEdge(n, e->dst(), kDoNotCheckDuplicates); |
| } else { |
| g->AddEdge(n, 0, e->dst(), e->dst_input()); |
| } |
| } |
| node_map[fbody->arg_nodes[i]->id()] = n; |
| g->RemoveNode(arg); // 'arg' is disconnected. |
| } |
| |
| // ------------------------------------------------------------------------ // |
| // Connect output edges. |
| // |
| // For i-th return node in fbody->graph, we add in "g" an identity node |
| // (outputs[i-th]). We then reconnect every incoming edge into the i-th return |
| // node to the added identity node. |
| // |
| // For every data edge coming out of "callee"s i-th output, we reconnect it to |
| // the i-th identity added above. |
| // |
| // If "callee" is control-depended upon by any other nodes, we add a NoOp node |
| // "output_control_node". "output_control_node" depends on all identity nodes |
| // added above or on all control return nodes (controlled by |
| // `options.output_control_src` value). And nodes previously depend on |
| // "callee" is changed to depend on "output_control_node". |
| // |
| // If `keep_node_fetchable` is `true` we always add an output control node, to |
| // guarantee that executing a fetchable node will execute all side-effects. |
| VLOG(4) << "Add output Identity nodes for each function output argument:"; |
| std::vector<Node*> outputs(caller->num_outputs()); |
| for (std::size_t i = 0; i < fbody->ret_nodes.size(); ++i) { |
| Node* ret = node_map[fbody->ret_nodes[i]->id()]; |
| Endpoint data; // Data input for the ret node. |
| for (const Edge* e : ret->in_edges()) { |
| if (!e->IsControlEdge()) { |
| data = {e->src(), e->src_output()}; |
| break; |
| } |
| } |
| CHECK(data.node != nullptr); |
| Node* n = output_identity("output", data, i); |
| outputs[i] = n; |
| VLOG(4) << " [index " << i << "] " |
| << fbody->fdef.signature().output_arg(i).name() << " as " |
| << n->name() << " (ret: " << data.node->name() << ":" << data.index |
| << ", requested_device: " << n->requested_device() << ")"; |
| for (const Edge* e : ret->in_edges()) { |
| if (e->IsControlEdge()) { |
| g->AddControlEdge(e->src(), n, kDoNotCheckDuplicates); |
| } |
| } |
| g->RemoveNode(ret); // 'ret' is disconnected. |
| } |
| |
| Node* output_control_node = nullptr; |
| const bool has_control_outputs = absl::c_any_of( |
| caller->out_edges(), [](const Edge* e) { return e->IsControlEdge(); }); |
| |
| using KeepCallerNode = InlineFunctionBodyOptions::KeepCallerNode; |
| const bool keep_caller_node = |
| options.keep_caller_node == KeepCallerNode::kFetchable || |
| options.keep_caller_node == KeepCallerNode::kTargetable; |
| |
| if (has_control_outputs || keep_caller_node) { |
| output_control_node = no_op("output_control_node"); |
| VLOG(4) << "Add output control node: " << output_control_node->name(); |
| if (options.output_control_src == OutputControlSrc::kDataOutputs) { |
| for (Node* n : outputs) { |
| VLOG(4) << " [data output] add control edge from: " << n->name(); |
| g->AddControlEdge(n, output_control_node, kDoNotCheckDuplicates); |
| } |
| } else { |
| for (Node* fbody_node : fbody->control_ret_nodes) { |
| Node* n = node_map[fbody_node->id()]; |
| VLOG(4) << " [control output] add control edge from: " << n->name(); |
| g->AddControlEdge(n, output_control_node, kDoNotCheckDuplicates); |
| } |
| } |
| } |
| |
| // We can't leave output control node without incoming control edges, because |
| // in this case outgoing control edge will loose execution frame information. |
| // We connect input_control_node and output_control_node with a control edge |
| // to forward execution frame to the controlled nodes. Above we add a control |
| // edge to all function calls inside function body, to guarantee that we will |
| // always have input_control_node when we need it. |
| if (output_control_node && output_control_node->in_edges().empty()) { |
| if (input_control_node) { |
| VLOG(4) |
| << "Add add a control edge between input and output control nodes: " |
| << input_control_node->name() << " to " |
| << output_control_node->name(); |
| g->AddControlEdge(input_control_node, output_control_node, |
| kDoNotCheckDuplicates); |
| } else { |
| VLOG(4) << "Function inlining potentially dropped execution frame " |
| "information from outgoing control edges."; |
| } |
| } |
| |
| for (const Edge* e : caller->out_edges()) { |
| if (e->IsControlEdge()) { |
| g->AddControlEdge(output_control_node, e->dst(), kDoNotCheckDuplicates); |
| } else { |
| g->AddEdge(outputs[e->src_output()], 0, e->dst(), e->dst_input()); |
| } |
| } |
| |
| // ------------------------------------------------------------------------ // |
| // Add an IdentityN or NoOp node in-place of caller node to keep `caller` |
| // fetchable or targetable. |
| |
| if (keep_caller_node) { |
| std::vector<NodeBuilder::NodeOut> output_tensors; |
| absl::c_transform(outputs, std::back_inserter(output_tensors), |
| [](Node* n) { return NodeBuilder::NodeOut(n, 0); }); |
| |
| Node* caller_substitute_node; |
| if (options.keep_caller_node == KeepCallerNode::kTargetable || |
| output_tensors.empty()) { |
| // IdentityN node must have at least one data input. If function has no |
| // data outputs, we can't keep it fetchable. |
| TF_CHECK_OK(NodeBuilder(caller->name(), "NoOp") |
| .Device(caller->requested_device()) |
| .ControlInput(output_control_node) |
| .Finalize(g, &caller_substitute_node)); |
| |
| } else if (options.keep_caller_node == KeepCallerNode::kFetchable) { |
| TF_CHECK_OK(NodeBuilder(caller->name(), "IdentityN") |
| .Device(caller->requested_device()) |
| .Input(output_tensors) |
| .ControlInput(output_control_node) |
| .Finalize(g, &caller_substitute_node)); |
| } |
| } |
| |
| // ------------------------------------------------------------------------ // |
| // 'caller' is replaced with inlined function body nodes and maybe IdentityN |
| // to keep it fetchable. |
| VLOG(3) << "Successfully inlined function call node: " << caller->name(); |
| g->RemoveNode(caller); |
| |
| return Status::OK(); |
| } |
| |
| bool IsFunctionCall(const FunctionLibraryDefinition& lib_def, |
| const Node& node) { |
| return node.IsFunctionCall(); |
| } |
| |
| bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph, |
| const ExpandInlineFunctionsOptions& options) { |
| std::vector<std::pair<Node*, const FunctionBody*>> candidates; |
| |
| const FunctionLibraryDefinition* fld = lib->GetFunctionLibraryDefinition(); |
| |
| for (Node* node : graph->nodes()) { |
| // Skip nodes that are not function calls or SymbolicGradient calls. |
| if (!IsFunctionCall(*lib->GetFunctionLibraryDefinition(), *node)) { |
| continue; |
| } |
| // Skip function calls that marked noinline. |
| bool noinline; |
| if (fld->GetAttr(*node, kNoInlineAttr, &noinline).ok() && noinline) { |
| VLOG(3) << "noinline: " << SummarizeNode(*node); |
| continue; |
| } |
| FunctionLibraryRuntime::Handle handle; |
| Status s = InstantiateFunctionCall(node->def(), lib, &handle); |
| if (!s.ok()) { |
| LOG(ERROR) << "Failed to instantiate a function: " << s.error_message(); |
| continue; |
| } |
| const FunctionBody* fbody = lib->GetFunctionBody(handle); |
| CHECK_NOTNULL(fbody); |
| candidates.emplace_back(node, fbody); |
| } |
| |
| bool inlined_any = false; |
| for (const auto& p : candidates) { |
| Status inlined = InlineFunctionBody(*fld, graph, p.first, p.second, |
| p.first->IsPartitionedCall() |
| ? options.multi_device_options |
| : options.native_options); |
| if (inlined.ok()) { |
| inlined_any = true; |
| } else { |
| VLOG(1) << "Failed to inline function call: node=" << p.first->name() |
| << " error=" << inlined.error_message(); |
| } |
| } |
| |
| // TODO(ezhulenev): Release handles for inlined function calls. |
| |
| return inlined_any; |
| } |
| |
| string NewName(const Node* n, bool pretty) { |
| if (pretty) { |
| return strings::StrCat(n->type_string(), n->id()); |
| } else { |
| return strings::StrCat("n", n->id()); |
| } |
| } |
| |
| // TODO(zhifengc): Maybe this should be the default Graph::AsGraphDef. |
| // and stash the original NodeDef name as an attr for documentation |
| // purpose. |
// Serializes `g` into `gdef`, renaming every node via NewName() so that the
// emitted names are unique and stable even after optimization rewrites.
// Nodes are visited in reverse DFS order starting from nodes with no outgoing
// edges, which yields a possible forward execution order of the graph.
void ToGraphDef(const Graph* g, GraphDef* gdef, bool pretty) {
  // We visit nodes in forward topological sort order, which is a
  // possible execution order of the graph.
  gtl::InlinedVector<const Edge*, 4> inputs;
  gdef->Clear();
  *gdef->mutable_versions() = g->versions();

  // Seed the traversal with all "terminal" nodes (no out edges).
  std::vector<Node*> start_nodes;
  for (Node* n : g->nodes()) {
    if (n->out_edges().empty()) {
      start_nodes.push_back(n);
    }
  }

  ReverseDFSFrom(*g, start_nodes, nullptr, [gdef, pretty, &inputs](Node* n) {
    if (!n->IsOp()) return;  // Skip the internal source/sink nodes.
    NodeDef* ndef = gdef->add_node();
    ndef->set_name(NewName(n, pretty));
    ndef->set_op(n->type_string());
    for (const auto& attr : n->attrs()) {
      (*ndef->mutable_attr())[attr.first] = attr.second;
    }

    // Prefer the assigned device (post-placement) over the requested one.
    if (!n->assigned_device_name().empty()) {
      ndef->set_device(n->assigned_device_name());
    } else {
      ndef->set_device(n->requested_device());
    }

    // Collect in-edges: data edges go into slots [0, num_inputs) indexed by
    // destination port; control edges are appended after those slots.
    inputs.clear();
    inputs.resize(n->num_inputs());
    for (const Edge* e : n->in_edges()) {
      if (e->IsControlEdge()) {
        inputs.push_back(e);
      } else {
        if (inputs[e->dst_input()] == nullptr) {
          inputs[e->dst_input()] = e;
        } else {
          LOG(WARNING) << "Malformed graph node. multiple input edges: "
                       << n->DebugString();
        }
      }
    }
    // node->name() is merely NodeDef::name, which are not guaranteed
    // to be unique and stable after optimization rewrites. Therefore,
    // we use "n<node id>" instead.
    for (const Edge* e : inputs) {
      if (e == nullptr) {
        // A data input slot with no incoming edge (malformed graph).
        ndef->add_input("unknown");
        continue;
      }
      const string srcname = NewName(e->src(), pretty);
      if (!e->src()->IsOp()) {
        // Edges from source/sink are dropped from the serialized form.
      } else if (e->IsControlEdge()) {
        ndef->add_input(strings::StrCat("^", srcname));
      } else if (e->src_output() == 0) {
        ndef->add_input(srcname);
      } else {
        ndef->add_input(strings::StrCat(srcname, ":", e->src_output()));
      }
    }
  });
}
| |
| string DebugString(const Graph* g) { |
| GraphDef gdef; |
| ToGraphDef(g, &gdef); |
| return DebugString(gdef); |
| } |
| |
| FunctionBody::FunctionBody(const FunctionDef& f, DataTypeSlice arg_t, |
| DataTypeSlice ret_t, Graph* g) |
| : fdef(f), |
| graph(g), |
| arg_types(arg_t.begin(), arg_t.end()), |
| ret_types(ret_t.begin(), ret_t.end()) { |
| // 1. Find regular Arg/Ret nodes. |
| this->arg_nodes.resize(arg_types.size()); |
| this->ret_nodes.resize(ret_types.size()); |
| for (Node* n : this->graph->op_nodes()) { |
| gtl::InlinedVector<Node*, 4>* node_vec; |
| if (n->type_string() == kRetOp || n->type_string() == kDeviceRetOp) { |
| node_vec = &this->ret_nodes; |
| } else if (n->type_string() == kArgOp || n->type_string() == kDeviceArgOp) { |
| node_vec = &this->arg_nodes; |
| } else { |
| continue; |
| } |
| int index; |
| TF_CHECK_OK(GetNodeAttr(n->attrs(), "index", &index)); |
| CHECK_LE(0, index); |
| CHECK_LT(index, node_vec->size()); |
| (*node_vec)[index] = n; |
| } |
| // 2. Find ControlRet nodes that must be always executed. |
| std::unordered_set<StringPiece, StringPieceHasher> control_ret_node_names; |
| for (const auto& control_ret : fdef.control_ret()) { |
| control_ret_node_names.insert(control_ret.second); |
| } |
| this->control_ret_nodes.reserve(control_ret_node_names.size()); |
| for (Node* n : this->graph->op_nodes()) { |
| if (control_ret_node_names.count(n->name()) > 0) { |
| this->control_ret_nodes.push_back(n); |
| } |
| } |
| } |
| |
// FunctionBody owns its graph (passed to the constructor); release it here.
FunctionBody::~FunctionBody() { delete this->graph; }
| |
// Helper that builds the body of a symbolic gradient function for `f`:
// Compute() copies the original function body and augments the copy with
// gradient nodes.
class SymbolicGradientHelper {
 public:
  explicit SymbolicGradientHelper(const FunctionBody& f) : fbody_(&f) {}
  ~SymbolicGradientHelper() = default;

  // Returns a newly allocated FunctionBody computing gradients of `fbody_`'s
  // outputs with respect to its inputs.
  std::unique_ptr<FunctionBody> Compute();

 private:
  // Not owned; must outlive this helper.
  const FunctionBody* fbody_;

  // Makes a copy of fbody_ in gbody.
  void Copy(FunctionBody* gbody);

  TF_DISALLOW_COPY_AND_ASSIGN(SymbolicGradientHelper);
};
| |
| void SymbolicGradientHelper::Copy(FunctionBody* gbody) { |
| const Graph& src = *(fbody_->graph); |
| gbody->graph = new Graph(src.op_registry()); |
| Graph* dst = gbody->graph; |
| |
| std::vector<Node*> node_map(src.num_node_ids()); |
| |
| // Copy just the fdef attributes (copy '_noinline' and other similar flags to |
| // the gradient function body). |
| *(gbody->fdef.mutable_attr()) = fbody_->fdef.attr(); |
| |
| // Copy the nodes. |
| node_map[src.source_node()->id()] = dst->source_node(); |
| node_map[src.sink_node()->id()] = dst->sink_node(); |
| for (Node* n : src.op_nodes()) { |
| node_map[n->id()] = dst->CopyNode(n); |
| } |
| |
| // Copy the edges. |
| for (const Edge* e : src.edges()) { |
| Node* src_copy = node_map[e->src()->id()]; |
| Node* dst_copy = node_map[e->dst()->id()]; |
| dst->AddEdge(src_copy, e->src_output(), dst_copy, e->dst_input()); |
| } |
| |
| // Save inputs in copied graph. |
| CHECK_EQ(fbody_->arg_types.size(), fbody_->arg_nodes.size()); |
| gbody->arg_types = fbody_->arg_types; |
| for (std::size_t i = 0; i < fbody_->arg_nodes.size(); ++i) { |
| gbody->arg_nodes.push_back(node_map[fbody_->arg_nodes[i]->id()]); |
| } |
| |
| // Save outputs in copied graph. |
| CHECK_EQ(fbody_->ret_types.size(), fbody_->ret_nodes.size()); |
| gbody->ret_types = fbody_->ret_types; |
| for (std::size_t i = 0; i < fbody_->ret_nodes.size(); ++i) { |
| gbody->ret_nodes.push_back(node_map[fbody_->ret_nodes[i]->id()]); |
| } |
| } |
| |
// Builds the gradient function body: copies fbody_, adds one new _Arg per
// original output (the incoming gradient dy_i), calls AddSymbolicGradients to
// grow the graph with gradient-computing nodes, and finally replaces the
// original return nodes with returns of dx_i (gradient w.r.t. each input).
std::unique_ptr<FunctionBody> SymbolicGradientHelper::Compute() {
  FunctionBody* gbody = new FunctionBody;
  Copy(gbody);  // copy fbody_ into gbody.

  Graph* g = gbody->graph;

  const int num_y = static_cast<int>(gbody->ret_nodes.size());

  // Populate 'y_node_outputs_' with node function body outputs.
  // Populate 'y_grad_nodes' with initial gradient nodes for each return node
  // of the original function body (these will be 'arg' nodes in the function
  // gradient body).
  std::vector<NodeOut> y_node_outputs;
  y_node_outputs.reserve(num_y);
  std::vector<NodeOut> y_grad_node_outputs;
  y_grad_node_outputs.reserve(num_y);
  for (int i = 0; i < num_y; ++i) {
    Node* y = gbody->ret_nodes[i];
    y_node_outputs.push_back({y, 0});
    DCHECK_EQ(y->type_string(), kRetOp);
    const DataType dtype = y->input_type(0);
    // The new dy_i argument is appended after the original arguments.
    const int index = static_cast<int>(gbody->arg_nodes.size());
    Node* dy = AddArg(g, dtype, index);
    gbody->arg_types.push_back(dtype);
    gbody->arg_nodes.push_back(dy);
    y_grad_node_outputs.push_back({dy, 0});
  }

  // Populate 'x_nodes' with function args (excluding 'y_grad_node_outputs').
  const size_t num_x = fbody_->arg_nodes.size();
  std::vector<NodeOut> x_node_outputs;
  x_node_outputs.reserve(num_x);
  for (size_t i = 0; i < fbody_->arg_nodes.size(); ++i) {
    x_node_outputs.push_back({gbody->arg_nodes[i], 0});
  }

  // Call AddSymbolicGradients which will add nodes to graph 'g' that
  // compute the function gradient (adding an entry in 'x_grad_node_outputs'
  // for each node in 'x_node_outputs').
  std::vector<NodeOut> x_grad_node_outputs;
  TF_CHECK_OK(AddSymbolicGradients(y_node_outputs, x_node_outputs,
                                   y_grad_node_outputs, &x_grad_node_outputs,
                                   g));

  // Remove the old return nodes from the function body.
  for (Node* n : gbody->ret_nodes) {
    g->RemoveNode(n);
  }
  // The gradient function returns one value per original input.
  gbody->ret_types = fbody_->arg_types;
  // TODO(apassos): use the right dtype for gradients of  resource variables
  for (int i = 0; i < gbody->ret_types.size(); ++i) {
    if (gbody->ret_types[i] == DT_RESOURCE) {
      gbody->ret_types[i] = DT_FLOAT;
    }
  }
  gbody->ret_nodes.clear();
  // Add new return nodes to the function gradient body for each node
  // in 'x_grad_nodes'.
  const int arg_types_size = static_cast<int>(fbody_->arg_types.size());
  for (int i = 0; i < arg_types_size; ++i) {
    Endpoint grad = {x_grad_node_outputs[i].node, x_grad_node_outputs[i].index};
    Node* ret = AddRet(g, grad, i);
    gbody->ret_nodes.push_back(ret);
  }

  return std::unique_ptr<FunctionBody>(gbody);
}
| |
| std::unique_ptr<FunctionBody> SymbolicGradient(const FunctionBody& f) { |
| return SymbolicGradientHelper(f).Compute(); |
| } |
| |
| Status FunctionDefToBodyHelper( |
| const FunctionDef& fdef, const AttrSlice& attrs, |
| const FunctionLibraryDefinition* const lib_def, |
| const std::function<Status(const string&, const OpDef**)>& get_func_sig, |
| std::unique_ptr<FunctionBody>* fbody) { |
| // Instantiates the function template into a graph def. |
| InstantiationResult result; |
| TF_RETURN_IF_ERROR(InstantiateFunction(fdef, attrs, get_func_sig, &result)); |
| |
| std::unique_ptr<Graph> graph(new Graph(lib_def)); |
| GraphConstructorOptions opts; |
| opts.allow_internal_ops = true; |
| opts.expect_device_spec = false; |
| TF_RETURN_IF_ERROR(ConvertNodeDefsToGraph(opts, result.nodes, graph.get())); |
| |
| // Call BuildControlFlowInfo to validate that this function body has |
| // well-formed control flow. |
| std::vector<ControlFlowInfo> dummy; |
| TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph.get(), &dummy)); |
| |
| *fbody = absl::make_unique<FunctionBody>(fdef, result.arg_types, |
| result.ret_types, graph.release()); |
| return Status::OK(); |
| } |
| |
| Status FunctionDefToBodyHelper(const FunctionDef& fdef, const AttrSlice& attrs, |
| const FunctionLibraryDefinition* lib_def, |
| std::unique_ptr<FunctionBody>* fbody) { |
| const auto get_func_sig = [&lib_def](const string& op, const OpDef** sig) { |
| return lib_def->LookUpOpDef(op, sig); |
| }; |
| return FunctionDefToBodyHelper(fdef, attrs, lib_def, get_func_sig, fbody); |
| } |
| |
| } // end namespace tensorflow |