| /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "tensorflow/core/common_runtime/single_threaded_executor.h" |
| |
| #include <utility> |
| |
| #include "tensorflow/core/common_runtime/entry.h" |
| #include "tensorflow/core/common_runtime/executor.h" |
| #include "tensorflow/core/common_runtime/executor_factory.h" |
| #include "tensorflow/core/common_runtime/renamed_device.h" |
| #include "tensorflow/core/graph/algorithm.h" |
| #include "tensorflow/core/lib/core/errors.h" |
| #include "tensorflow/core/lib/core/status.h" |
| #include "tensorflow/core/lib/gtl/cleanup.h" |
| #include "tensorflow/core/platform/errors.h" |
| #include "tensorflow/core/platform/macros.h" |
| |
| namespace tensorflow { |
| |
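| // Returns an error if `n` cannot be safely executed by the synchronous, |
| // single-threaded executor: reference-typed outputs are never supported, |
| // Switch nodes are rejected because they require deadness propagation, and |
| // other low-level control-flow nodes are rejected unless |
| // `allow_control_flow_sync_execution` is true. |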
| Status ValidateOpIsSafeForSyncExecution( |
| const Node& n, bool allow_control_flow_sync_execution) { |
| for (DataType dt : n.output_types()) { |
| if (IsRefType(dt)) { |
| return errors::Unimplemented( |
| "Single-threaded executor does not support reference-typed " |
| "edges. But saw type ", |
| DataTypeString(dt), " in outputs of node ", n.name()); |
| } |
| } |
| // Executing Switch nodes requires propagating deadness, which is |
| // not currently supported in the SingleThreadedExecutor. |
| if (n.IsSwitch()) { |
| return errors::FailedPrecondition( |
| "Single-threaded executor does not support switch op, but saw node ", |
| n.name(), |
| ". Perhaps your graph contains old-style control flow primitives? " |
| "Try using tf.compat.v1.enable_control_flow_v2()."); |
| } |
| if (n.IsControlFlow() && !allow_control_flow_sync_execution) { |
| return errors::FailedPrecondition( |
| "Single-threaded executor does not support low level control flow, " |
| " but saw control flow node ", |
| n.name(), |
| ". Perhaps your graph contains old-style control flow primitives? " |
| "Try using tf.compat.v1.enable_control_flow_v2()."); |
| } |
| return Status::OK(); |
| } |
| |
| namespace { |
| |
| typedef gtl::InlinedVector<TensorValue, 4> TensorValueVec; |
| typedef gtl::InlinedVector<AllocatorAttributes, 4> AllocatorAttributeVec; |
| |
| static const string& kSingleThreadedExecutor = |
| *new string("SINGLE_THREADED_EXECUTOR"); |
| |
| class SingleThreadedExecutorImpl : public Executor { |
| public: |
| explicit SingleThreadedExecutorImpl(const LocalExecutorParams& params) |
| : params_(params) {} |
| |
| ~SingleThreadedExecutorImpl() override { |
| for (const KernelState& kernel_state : kernels_) { |
| params_.delete_kernel(kernel_state.kernel); |
| } |
| for (const ConstTensorKernelState& kernel_state : const_tensor_kernels_) { |
| params_.delete_kernel(kernel_state.kernel); |
| } |
| } |
| |
| Status Initialize(const Graph& graph) { |
| // Topologically sort `graph` to get a sequence of OpKernels. |
| std::vector<Node*> ordered_nodes; |
| ordered_nodes.reserve(graph.num_nodes()); |
| GetReversePostOrder(graph, &ordered_nodes); |
| int ordered_nodes_size = ordered_nodes.size(); |
| if (ordered_nodes_size != graph.num_nodes()) { |
| return errors::InvalidArgument("Graph had ", graph.num_nodes(), |
| " but reverse post-order had ", |
| ordered_nodes.size()); |
| } |
| |
| // We reserve space for two fewer nodes because we do not need to create |
| // kernels for the _SOURCE and _SINK nodes. |
| kernels_.reserve(ordered_nodes.size() - 2); |
| std::vector<Node*> nodes_with_kernels; |
| std::vector<Node*> nodes_with_const_tensor_kernels; |
| nodes_with_kernels.reserve(ordered_nodes.size() - 2); |
| |
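| // Maps each `_Arg` index to its node, and each node that will receive a |
| // `KernelState` to that state's index in `kernels_`; both are used below |
| // to wire node outputs to flat input slots. |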
| std::map<size_t, Node*> arg_index_to_node_map; |
| absl::flat_hash_map<Node*, size_t> node_to_index_map; |
| |
| // Create the kernel and input-related structures for each node in `graph`. |
| for (Node* n : ordered_nodes) { |
| if (n->IsSource() || n->IsSink()) { |
| continue; |
| } |
| TF_RETURN_IF_ERROR(ValidateOpIsSafeForSyncExecution( |
| *n, params_.allow_control_flow_sync_execution)); |
| if (n->IsArg()) { |
| int32_t arg_index; |
| TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &arg_index)); |
| if (arg_index < 0) { |
| return errors::InvalidArgument("Invalid argument index ", arg_index, |
| " in node ", n->name()); |
| } |
| arg_index_to_node_map[arg_index] = n; |
| // We do not create a kernel for Arg nodes, and instead inline the |
| // argument handling directly in the executor code. |
| continue; |
| } |
| |
| OpKernel* kernel; |
| TF_RETURN_IF_ERROR(params_.create_kernel(n->properties(), &kernel)); |
| |
| const Tensor* const_tensor; |
| if (n->num_outputs() == 1 && (const_tensor = kernel->const_tensor())) { |
| // Nodes that produce a single constant tensor are handled specially: |
| // we evaluate the tensor once, and propagate it to its consumers as |
| // a `const Tensor*`, to avoid refcount manipulation. |
| const size_t kernel_index = const_tensor_kernels_.size(); |
| const_tensor_kernels_.push_back({}); |
| nodes_with_const_tensor_kernels.push_back(n); |
| ConstTensorKernelState& kernel_state = |
| const_tensor_kernels_[kernel_index]; |
| kernel_state.kernel = kernel; |
| kernel_state.const_tensor = *const_tensor; |
| } else { |
| const size_t kernel_index = kernels_.size(); |
| kernels_.push_back({}); |
| nodes_with_kernels.push_back(n); |
| KernelState& kernel_state = kernels_[kernel_index]; |
| kernel_state.kernel = kernel; |
| kernel_state.num_inputs = n->num_inputs(); |
| kernel_state.num_outputs = n->num_outputs(); |
| node_to_index_map[n] = kernel_index; |
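| // Kernel inputs are stored contiguously in one flat vector: each kernel's |
| // inputs begin where the previous kernel's inputs end (e.g. if kernel 0 |
| // has 3 inputs, kernel 1's inputs start at flat index 3). |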
| if (kernel_index == 0) { |
| kernel_state.input_start_index = 0; |
| } else { |
| const KernelState& previous_kernel_state = kernels_[kernel_index - 1]; |
| kernel_state.input_start_index = |
| previous_kernel_state.input_start_index + |
| previous_kernel_state.num_inputs; |
| } |
| } |
| } |
| |
| // Build the mapping from each Arg node output to the input slots of the |
| // destination nodes that consume it. |
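| // Illustrative example: if argument 0 feeds input 1 of the kernel at |
| // `kernels_[3]` and input 0 of the kernel at `kernels_[5]`, then |
| // `arg_output_locations_[0]` holds the flat indices |
| // `kernels_[3].input_start_index + 1` and `kernels_[5].input_start_index`. |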
| if (!arg_index_to_node_map.empty()) { |
| const size_t num_args = arg_index_to_node_map.rbegin()->first + 1; |
| arg_output_locations_.resize(num_args); |
| for (const auto& arg_index_node_pair : arg_index_to_node_map) { |
| const size_t arg_index = arg_index_node_pair.first; |
| const Node* arg_node = arg_index_node_pair.second; |
| arg_output_locations_[arg_index].reserve(arg_node->out_edges().size()); |
| for (const Edge* e : arg_node->out_edges()) { |
| if (e->src_output() == Graph::kControlSlot) { |
| continue; |
| } else if (e->src_output() != 0) { |
| return errors::Internal("Invalid output index ", e->src_output(), |
| " from argument node ", arg_index); |
| } |
| arg_output_locations_[arg_index].push_back( |
| kernels_[node_to_index_map[e->dst()]].input_start_index + |
| e->dst_input()); |
| } |
| } |
| } |
| |
| // Build the mapping from each const tensor kernel to the input slots of |
| // the destination nodes that consume its value. |
| for (size_t i = 0; i < const_tensor_kernels_.size(); ++i) { |
| Node* n = nodes_with_const_tensor_kernels[i]; |
| ConstTensorKernelState& kernel_state = const_tensor_kernels_[i]; |
| for (const Edge* e : n->out_edges()) { |
| if (e->src_output() == Graph::kControlSlot) { |
| continue; |
| } else if (e->src_output() != 0) { |
| return errors::Internal("Invalid output index ", e->src_output(), |
| " from node ", n->DebugString()); |
| } |
| kernel_state.output_locations.push_back( |
| kernels_[node_to_index_map[e->dst()]].input_start_index + |
| e->dst_input()); |
| } |
| |
| bool on_host = |
| kernel_state.kernel->output_memory_types()[0] == HOST_MEMORY; |
| kernel_state.output_alloc_attr.set_on_host(on_host); |
| } |
| |
| // Build the mapping from each node output to the input slots of the |
| // destination nodes that consume it. |
| for (size_t i = 0; i < kernels_.size(); ++i) { |
| Node* n = nodes_with_kernels[i]; |
| KernelState& kernel_state = kernels_[i]; |
| kernel_state.output_locations.resize(kernel_state.num_outputs); |
| for (const Edge* e : n->out_edges()) { |
| if (!e->IsControlEdge()) { |
| kernel_state.output_locations[e->src_output()].push_back( |
| kernels_[node_to_index_map[e->dst()]].input_start_index + |
| e->dst_input()); |
| } |
| } |
| |
| // Compute allocator attributes for each node output, and corresponding |
| // node input. |
| kernel_state.output_alloc_attrs.resize(kernel_state.num_outputs); |
| AllocatorAttributes* attrs = kernel_state.output_alloc_attrs.data(); |
| |
| OpKernel* op_kernel = kernel_state.kernel; |
| for (int out = 0; out < n->num_outputs(); out++) { |
| DCHECK_LT(out, op_kernel->output_memory_types().size()); |
| bool on_host = op_kernel->output_memory_types()[out] == HOST_MEMORY; |
| if (on_host) { |
| AllocatorAttributes h; |
| h.set_on_host(on_host); |
| attrs[out].Merge(h); |
| } |
| } |
| } |
| |
| if (!kernels_.empty()) { |
| const KernelState& last_kernel_state = kernels_.back(); |
| total_num_inputs_ = |
| last_kernel_state.input_start_index + last_kernel_state.num_inputs; |
| input_alloc_attrs_.resize(total_num_inputs_); |
| for (size_t i = 0; i < kernels_.size(); ++i) { |
| for (size_t j = 0; j < kernels_[i].output_locations.size(); ++j) { |
| for (size_t output_location : kernels_[i].output_locations[j]) { |
| input_alloc_attrs_[output_location] = |
| kernels_[i].output_alloc_attrs[j]; |
| } |
| } |
| } |
| } else { |
| total_num_inputs_ = 0; |
| } |
| return Status::OK(); |
| } |
| |
| Status Run(const Args& args) override { |
| // The inputs to each kernel are stored contiguously in `inputs`. |
| // |
| // We use `kernels_[i].input_start_index` and `kernels_[i].num_inputs` to |
| // determine the range of elements in this vector that correspond to |
| // the inputs of `kernels_[i]`. |
| // |
| // This vector has the following layout: |
| // |
| // * Kernel 0, input 0. |
| // * Kernel 0, input 1. |
| // * ... |
| // * Kernel 0, input `kernels_[0].num_inputs - 1`. |
| // * Kernel 1, input 0. |
| // * ... |
| // * Kernel 1, input `kernels_[1].num_inputs - 1`. |
| // * ... |
| // * Kernel `kernels_.size() - 1`, input 0. |
| // * ... |
| // * Kernel `kernels_.size() - 1`, input `kernels_.back().num_inputs - 1`. |
| // |
| // Note that kernels with zero inputs do not correspond to any elements in |
| // this vector. |
| // |
| // We use `ManualConstructor<Tensor>` to avoid the overhead of |
| // default-constructing an invalid `Tensor` for each slot at the beginning |
| // of execution: |
| // * Elements are initialized when the outputs of a kernel execution are |
| // propagated to the inputs of kernels that depend on them. |
| // * The elements corresponding to the inputs for kernel `i` are destroyed |
| // after kernel `i` executes. |
| // * In an error case (see below), we use the connectivity information in |
| // `KernelState::output_locations` to determine which locations have been |
| // initialized, and manually destroy them. |
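| // |
| // Illustrative example: for kernels A (2 inputs), B (0 inputs), and C |
| // (3 inputs), in that order, this vector has five entries |
| // [A:0, A:1, C:0, C:1, C:2], and `input_start_index` is 0 for A, 2 for B, |
| // and 2 for C. |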
| std::vector<Entry> inputs(total_num_inputs_); |
| |
| // TODO(mrry): Can we avoid copying into these vectors? Consider modifying |
| // OpKernelContext to take the TensorValueVec as a pointer into `inputs`. |
| TensorValueVec node_inputs; |
| AllocatorAttributeVec input_alloc_attrs; |
| |
| // Override intra op thread pool if requested. |
| Device* device = params_.device; |
| std::unique_ptr<Device> user_device; |
| if (args.user_intra_op_threadpool != nullptr) { |
| user_device = RenamedDevice::NewRenamedDevice( |
| device->name(), device, /*owns_underlying=*/false, |
| /*isolate_session_state=*/false, args.user_intra_op_threadpool); |
| device = user_device.get(); |
| } |
| |
| // Prepare the parameters that will be the same for all kernels. |
| OpKernelContext::Params params; |
| params.step_id = args.step_id; |
| params.device = device; |
| params.log_memory = false; // TODO(mrry): Too severe? |
| params.rendezvous = args.rendezvous; |
| params.session_state = args.session_state; |
| params.session_metadata = params_.session_metadata; |
| params.tensor_store = args.tensor_store; |
| params.cancellation_manager = args.cancellation_manager; |
| params.call_frame = args.call_frame; |
| params.function_library = params_.function_library; |
| params.resource_manager = device->resource_manager(); |
| params.step_container = args.step_container; |
| params.collective_executor = args.collective_executor; |
| params.stack_trace = args.stack_trace; |
| params.slice_reader_cache = nullptr; // TODO(mrry): Too severe? |
| params.inputs = &node_inputs; |
| params.input_alloc_attrs = &input_alloc_attrs; |
| |
| Args::Runner runner_copy = args.runner; |
| params.runner = &runner_copy; |
| params.run_all_kernels_inline = args.run_all_kernels_inline; |
| params.stats_collector = args.stats_collector; |
| params.executor_type = &kSingleThreadedExecutor; |
| |
| // NOTE(mrry): We assume the graph contains no loops or conditionals, so |
| // every kernel runs in the default frame and iteration. |
| params.frame_iter = FrameAndIter(0, 0); |
| params.is_input_dead = false; |
| |
| device->TryGetDeviceContext(¶ms.op_device_context).IgnoreError(); |
| auto context_cleanup = gtl::MakeCleanup([¶ms] { |
| if (params.op_device_context != nullptr) { |
| params.op_device_context->Unref(); |
| } |
| }); |
| |
| // TODO(mrry): Consider implementing forwarding. |
| params.forward_from_array = nullptr; |
| |
| const size_t received_args = |
| args.call_frame ? args.call_frame->num_args() : 0; |
| if (TF_PREDICT_FALSE(arg_output_locations_.size() > received_args)) { |
| return errors::InvalidArgument("Expected ", arg_output_locations_.size(), |
| " arguments, but only received ", |
| received_args, "."); |
| } |
| |
| // ArgOp is a relatively expensive OpKernel due to the Tensor |
| // allocations that it performs. Therefore we specialize its implementation |
| // and forward arguments directly to the inputs of kernels that consume |
| // them. |
| for (size_t i = 0; i < arg_output_locations_.size(); ++i) { |
| const size_t num_destinations = arg_output_locations_[i].size(); |
| if (num_destinations > 0) { |
| if (args.call_frame->CanConsumeArg(i)) { |
| // The first destination input can consume the argument. |
| Entry& first_input = inputs[arg_output_locations_[i][0]]; |
| first_input.state = Entry::State::HAS_VALUE; |
| first_input.val.Init(); |
| args.call_frame->ConsumeArg(i, first_input.val.get()); |
| // All subsequent destination inputs get a shallow copy of the first |
| // destination input. |
| // |
| // NOTE: If we had metadata about which kernels might attempt to |
| // forward their input, we could arrange the kernel order so that |
| // one of those kernels was executed last. |
| for (size_t j = 1; j < num_destinations; ++j) { |
| Entry& input = inputs[arg_output_locations_[i][j]]; |
| input.state = Entry::State::HAS_VALUE; |
| input.val.Init(*first_input.val); |
| } |
| } else { |
| const Tensor* arg; |
| TF_RETURN_IF_ERROR(args.call_frame->GetArg(i, &arg)); |
| for (size_t j = 0; j < num_destinations; ++j) { |
| Entry& input = inputs[arg_output_locations_[i][j]]; |
| // NOTE: We must make at least one shallow copy of the argument |
| // tensor that remains live until all consuming kernels have |
| // executed, to keep the reference count > 1, and inhibit buffer |
| // forwarding. For simplicity, we shallow copy into the input entry |
| // for each consuming kernel. |
| input.state = Entry::State::HAS_VALUE; |
| input.val.Init(*arg); |
| } |
| } |
| } |
| } |
| |
| // Kernels that return a constant value (e.g. ConstOp) are relatively |
| // expensive due to the Tensor allocations that they perform. Therefore we |
| // specialize their implementation and forward their constant value directly |
| // to the inputs of kernels that consume them. |
| for (const ConstTensorKernelState& kernel_state : const_tensor_kernels_) { |
| for (size_t i = 0; i < kernel_state.output_locations.size(); ++i) { |
| Entry& input = inputs[kernel_state.output_locations[i]]; |
| input.state = Entry::State::HAS_CONST_TENSOR; |
| input.const_tensor = &kernel_state.const_tensor; |
| } |
| } |
| |
| // Execute the kernels one-at-a-time in topological order. |
| for (size_t i = 0; i < kernels_.size(); ++i) { |
| const KernelState& kernel_state = kernels_[i]; |
| |
| // Prepare the per-kernel parameters. |
| const size_t input_start_index = kernel_state.input_start_index; |
| const size_t num_inputs = kernel_state.num_inputs; |
| const size_t num_outputs = kernel_state.num_outputs; |
| |
| node_inputs.clear(); |
| node_inputs.resize(num_inputs); |
| input_alloc_attrs.clear(); |
| input_alloc_attrs.resize(num_inputs); |
| for (size_t j = 0; j < num_inputs; ++j) { |
| Entry& input = inputs[input_start_index + j]; |
| switch (input.state) { |
| case Entry::State::HAS_CONST_TENSOR: |
| // NOTE(mrry): This `const_cast` is necessary because `TensorValue` |
| // stores a non-const `Tensor*`, and relies on the `OpKernelContext` |
| // accessors making dynamic checks that prevent using an immutable |
| // tensor as a mutable tensor. |
| node_inputs[j].tensor = const_cast<Tensor*>(input.const_tensor); |
| break; |
| case Entry::State::HAS_VALUE: |
| node_inputs[j].tensor = input.val.get(); |
| break; |
| default: |
| DCHECK(false) << "Input did not have a valid value."; |
| } |
| input_alloc_attrs[j] = input_alloc_attrs_[input_start_index + j]; |
| } |
| params.op_kernel = kernel_state.kernel; |
| params.output_attr_array = kernel_state.output_alloc_attrs.data(); |
| OpKernelContext ctx(¶ms, num_outputs); |
| |
| // Actually execute the kernel. |
| device->Compute(kernel_state.kernel, &ctx); |
| TF_RETURN_IF_ERROR(ctx.status()); |
| |
| // Free the inputs to the current kernel. |
| for (size_t j = 0; j < num_inputs; ++j) { |
| inputs[input_start_index + j].ClearVal(); |
| } |
| |
| // Forward the outputs of the kernel to the inputs of subsequent kernels. |
| for (size_t j = 0; j < num_outputs; ++j) { |
| TensorValue val = ctx.release_output(j); |
| const size_t num_destinations = kernel_state.output_locations[j].size(); |
| if (num_destinations > 0) { |
| // TODO(mrry): Consider flattening the `output_locations` vector |
| // to improve the cache-friendliness of this loop. |
| for (size_t k = 0; k < num_destinations - 1; ++k) { |
| // TODO(mrry): Validate that the types match the expected values or |
| // ensure that the necessary validation has already happened. |
| Entry& input = inputs[kernel_state.output_locations[j][k]]; |
| input.state = Entry::State::HAS_VALUE; |
| if (val.tensor != nullptr) { |
| input.val.Init(*val.tensor); |
| } else { |
| input.val.Init(Tensor(kernel_state.kernel->output_type(j))); |
| } |
| } |
| // Move the output tensor to the last consumer to avoid copying it. |
| Entry& input = |
| inputs[kernel_state.output_locations[j][num_destinations - 1]]; |
| input.state = Entry::State::HAS_VALUE; |
| if (val.tensor != nullptr) { |
| input.val.Init(std::move(*val.tensor)); |
| } else { |
| input.val.Init(Tensor(kernel_state.kernel->output_type(j))); |
| } |
| } |
| delete val.tensor; |
| } |
| } |
| return Status::OK(); |
| } |
| |
| // When asynchronous execution is requested, schedule a single closure on |
| // `args.runner` that executes all kernels synchronously via `Run()`. This |
| // keeps potentially expensive work off the calling thread, even though the |
| // execution itself is single-threaded. |
| // |
| // This also avoids stack-overflow issues with functional control flow. |
| void RunAsync(const Args& args, DoneCallback done) override { |
| args.runner([this, args, done]() { done(Run(args)); }); |
| } |
| |
| private: |
| const LocalExecutorParams params_; |
| |
| // All following members are read-only after Initialize(). |
| |
| // The sum of the number of inputs for each kernel. This determines the |
| // length of the flat `inputs` vector. See comment at the beginning of |
| // `Run()` for details. |
| size_t total_num_inputs_; |
| |
| // Represents cached graph structure state for each kernel. |
| struct KernelState { |
| // The kernel object. Not owned. |
| // |
| // This pointer is managed by `params_.create_kernel()` and |
| // `params_.delete_kernel()`. |
| OpKernel* kernel; |
| |
| // These fields determine the range of elements in `inputs` that corresponds |
| // to the inputs of `kernel`. |
| size_t input_start_index; |
| size_t num_inputs; |
| |
| size_t num_outputs; |
| |
| // For the `j`th output of `kernel`, `output_locations[j]` contains the |
| // locations in the flat `inputs` vector to which that output must be |
| // copied. See comment at the beginning of `Run()` for details. |
| std::vector<std::vector<size_t>> |
| output_locations; // Length = `num_outputs`. |
| |
| // Memory space information for each output of `kernel`. |
| std::vector<AllocatorAttributes> |
| output_alloc_attrs; // Length = `num_outputs`. |
| }; |
| std::vector<KernelState> kernels_; |
| |
| // For the `i`th argument, `arg_output_locations_[i]` contains the locations |
| // in the flat `inputs` vector to which that argument must be copied. |
| std::vector<std::vector<size_t>> |
| arg_output_locations_; // Length = `num_args`. |
| |
| // Represents cached graph structure state for each kernel that produces |
| // a single constant-valued tensor. |
| struct ConstTensorKernelState { |
| // The kernel object. Not owned. |
| // |
| // This pointer is managed by `params_.create_kernel()` and |
| // `params_.delete_kernel()`. |
| OpKernel* kernel; |
| |
| // The cached value of `kernel->const_tensor()`. |
| // |
| // NOTE: We keep a `Tensor` rather than a `const Tensor*` here in order to |
| // keep the reference count on the underlying buffer above 1. Otherwise, a |
| // kernel could interpret the input as a forwardable tensor, and mutate the |
| // underlying constant tensor. |
| Tensor const_tensor; |
| |
| // For the single output of `kernel`, `output_locations` contains the |
| // locations in the flat `inputs` vector to which that output must be |
| // copied. See comment at the beginning of `Run()` for details. |
| std::vector<size_t> output_locations; // One entry per consuming input. |
| |
| // Memory space information for the single output of `kernel`. |
| AllocatorAttributes output_alloc_attr; |
| }; |
| std::vector<ConstTensorKernelState> const_tensor_kernels_; |
| |
| // Memory space information for each input. This information is stored in the |
| // same order as the flat `inputs` vector. See comment at the beginning of |
| // `Run()` for details. |
| std::vector<AllocatorAttributes> |
| input_alloc_attrs_; // Length = `total_num_inputs_`. |
| }; |
| |
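| // Registers the single-threaded executor under the name |
| // "SINGLE_THREADED_EXECUTOR", so that it can be requested by name through |
| // the `ExecutorFactory` registry. Illustrative usage (a sketch, assuming |
| // the `NewExecutor()` helper declared in executor_factory.h; `params` and |
| // `graph` setup elided): |
| // |
| // std::unique_ptr<Executor> executor; |
| // TF_RETURN_IF_ERROR(NewExecutor( |
| // "SINGLE_THREADED_EXECUTOR", params, graph, &executor)); |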
| class SingleThreadedExecutorRegistrar { |
| public: |
| SingleThreadedExecutorRegistrar() { |
| ExecutorFactory::Register(kSingleThreadedExecutor, new Factory()); |
| } |
| |
| private: |
| class Factory : public ExecutorFactory { |
| Status NewExecutor(const LocalExecutorParams& params, const Graph& graph, |
| std::unique_ptr<Executor>* out_executor) override { |
| Executor* ret; |
| TF_RETURN_IF_ERROR(NewSingleThreadedExecutor(params, graph, &ret)); |
| out_executor->reset(ret); |
| return Status::OK(); |
| } |
| }; |
| }; |
| static SingleThreadedExecutorRegistrar registrar; |
| |
| } // namespace |
| |
| Status NewSingleThreadedExecutor(const LocalExecutorParams& params, |
| const Graph& graph, Executor** executor) { |
| auto impl = absl::make_unique<SingleThreadedExecutorImpl>(params); |
| TF_RETURN_IF_ERROR(impl->Initialize(graph)); |
| *executor = impl.release(); |
| return Status::OK(); |
| } |
| |
| } // namespace tensorflow |