| /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "tensorflow/core/common_runtime/executor.h" |
| |
| #include <atomic> |
| #include <memory> |
| #include <vector> |
| |
| #include "absl/memory/memory.h" |
| #include "tensorflow/core/common_runtime/costmodel_manager.h" |
| #include "tensorflow/core/common_runtime/entry.h" |
| #include "tensorflow/core/common_runtime/executor_factory.h" |
| #include "tensorflow/core/common_runtime/graph_view.h" |
| #include "tensorflow/core/common_runtime/immutable_executor_state.h" |
| #include "tensorflow/core/common_runtime/pending_counts.h" |
| #include "tensorflow/core/common_runtime/propagator_state.h" |
| #include "tensorflow/core/common_runtime/renamed_device.h" |
| #include "tensorflow/core/common_runtime/step_stats_collector.h" |
| #include "tensorflow/core/framework/allocator.h" |
| #include "tensorflow/core/framework/cancellation.h" |
| #include "tensorflow/core/framework/collective.h" |
| #include "tensorflow/core/framework/control_flow.h" |
| #include "tensorflow/core/framework/device_attributes.pb.h" |
| #include "tensorflow/core/framework/log_memory.h" |
| #include "tensorflow/core/framework/node_def_util.h" |
| #include "tensorflow/core/framework/op_kernel.h" |
| #include "tensorflow/core/framework/op_segment.h" |
| #include "tensorflow/core/framework/tensor.h" |
| #include "tensorflow/core/framework/tensor_reference.h" |
| #include "tensorflow/core/framework/types.h" |
| #include "tensorflow/core/framework/types.pb.h" |
| #include "tensorflow/core/graph/edgeset.h" |
| #include "tensorflow/core/graph/graph.h" |
| #include "tensorflow/core/graph/graph_node_util.h" |
| #include "tensorflow/core/lib/core/errors.h" |
| #include "tensorflow/core/lib/core/notification.h" |
| #include "tensorflow/core/lib/core/status.h" |
| #include "tensorflow/core/lib/core/threadpool.h" |
| #include "tensorflow/core/lib/gtl/flatmap.h" |
| #include "tensorflow/core/lib/gtl/inlined_vector.h" |
| #include "tensorflow/core/lib/gtl/manual_constructor.h" |
| #include "tensorflow/core/lib/hash/hash.h" |
| #include "tensorflow/core/platform/context.h" |
| #include "tensorflow/core/platform/env.h" |
| #include "tensorflow/core/platform/logging.h" |
| #include "tensorflow/core/platform/macros.h" |
| #include "tensorflow/core/platform/mutex.h" |
| #include "tensorflow/core/platform/profile_utils/cpu_utils.h" |
| #include "tensorflow/core/platform/thread_annotations.h" |
| #include "tensorflow/core/platform/tracing.h" |
| #include "tensorflow/core/platform/types.h" |
| #include "tensorflow/core/profiler/lib/annotated_traceme.h" |
| #include "tensorflow/core/profiler/lib/scoped_annotation.h" |
| #include "tensorflow/core/profiler/lib/traceme.h" |
| #include "tensorflow/core/util/tensor_slice_reader_cache.h" |
| |
| namespace tensorflow { |
| namespace { |
| |
| // 1-D, 0 element tensor. |
| static const Tensor* const kEmptyTensor = new Tensor; |
| |
| // Helper routines for collecting step stats. |
| namespace nodestats { |
| inline int64 NowInNsec() { return EnvTime::NowNanos(); } |
| |
| void SetScheduled(NodeExecStatsInterface* stats, int64 micros) { |
| if (!stats) return; |
| stats->SetScheduled(micros * EnvTime::kMicrosToNanos); |
| } |
| |
| void SetAllStart(NodeExecStatsInterface* stats) { |
| if (!stats) return; |
| stats->RecordExecutorStarted(); |
| } |
| |
| void SetOpStart(NodeExecStatsInterface* stats) { |
| if (!stats) return; |
| stats->RecordComputeStarted(); |
| } |
| |
| void SetOpEnd(NodeExecStatsInterface* stats) { |
| if (!stats) return; |
| stats->RecordComputeEnded(); |
| } |
| |
| void SetAllEnd(NodeExecStatsInterface* stats) { |
| if (!stats) return; |
| stats->RecordExecutorEnded(); |
| } |
| |
| void SetOutput(NodeExecStatsInterface* stats, int slot, const Tensor* v) { |
| if (!stats) return; |
| stats->SetOutput(slot, v); |
| } |
| |
| void SetMemory(NodeExecStatsInterface* stats, OpKernelContext* ctx) { |
| if (!stats) return; |
| stats->SetMemory(ctx); |
| } |
| |
| } // namespace nodestats |
| |
| // Time the execution of kernels (in CPU cycles). Used to dynamically identify |
| // inexpensive kernels which can be dispatched inline. |
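//
// Usage (mirrors ProcessSync below): wrap the Compute call and feed the
// elapsed cycles back into KernelStats, e.g.
//   KernelTimer timer;
//   device->Compute(op_kernel, &ctx);
//   kernel_stats_->UpdateCostEstimate(item, timer.ElapsedCycles());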
| struct KernelTimer { |
| uint64 start_cycles = profile_utils::CpuUtils::GetCurrentClockCycle(); |
| |
| uint64 ElapsedCycles() { |
| return profile_utils::CpuUtils::GetCurrentClockCycle() - start_cycles; |
| } |
| }; |
| |
| // TODO(b/152925936): Re-evaluate these constants with current usage patterns. |
| typedef gtl::InlinedVector<TensorValue, 4> TensorValueVec; |
| typedef gtl::InlinedVector<AllocatorAttributes, 4> AllocatorAttributeVec; |
| |
| class ExecutorImpl : public Executor { |
| public: |
| explicit ExecutorImpl(const LocalExecutorParams& p) : immutable_state_(p) {} |
| |
| Status Initialize(const Graph& graph) { |
| TF_RETURN_IF_ERROR(immutable_state_.Initialize(graph)); |
| kernel_stats_.Initialize(immutable_state_.graph_view()); |
| return Status::OK(); |
| } |
| |
| void RunAsync(const Args& args, DoneCallback done) override; |
| |
| private: |
| template <class PropagatorStateType> |
| friend class ExecutorState; |
| |
| // Stores execution time information about the kernels in an executor's graph. |
| class KernelStats { |
| public: |
| KernelStats() = default; |
| |
| void Initialize(const GraphView& gview) { |
| is_expensive_ = absl::make_unique<std::atomic<bool>[]>(gview.num_nodes()); |
| cost_estimates_ = |
| absl::make_unique<std::atomic_uint_fast64_t[]>(gview.num_nodes()); |
| for (int32 i = 0; i < gview.num_nodes(); ++i) { |
| if (gview.node(i)) { |
| is_expensive_[i] = |
| gview.node(i)->kernel && gview.node(i)->kernel->IsExpensive(); |
| cost_estimates_[i] = kInitialCostEstimateCycles; |
| } |
| } |
| } |
| |
| // Returns true iff the given node is considered "expensive". The |
| // executor uses this flag to optimize graph execution, for example |
| // by "inlining" inexpensive kernels. |
| bool IsExpensive(const NodeItem& node) const { |
| return is_expensive_[node.node_id] && |
| (cost_estimates_[node.node_id].load(std::memory_order_relaxed) > |
| kOpIsExpensiveThresholdCycles); |
| } |
| |
| // Updates the dynamic cost estimate, which is used to determine whether the |
| // given node is expensive. The new cost estimate is a weighted average of |
| // the old cost estimate and the latest cost. |
| // |
// NOTE: We currently only expect updates to the cost estimate when
// `is_expensive_[node.node_id]` is true (or at least, it *was* true when we
// started to execute the kernel). As a result, we expect that a kernel can
// only ever transition from "expensive" to "inexpensive", but not vice
// versa.
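//
// The update below computes, in CPU cycles,
//   new_estimate = old_estimate * (kCostDecay - 1) / kCostDecay
//                  + elapsed_cycles / kCostDecay
// i.e. an exponential moving average with decay factor 1 / kCostDecay.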
| void UpdateCostEstimate(const NodeItem& node, uint64 elapsed_cycles) { |
| // N.B. Updates to `cost_estimate` are atomic but unlocked. Simultaneous |
| // updates may result in one or more updates being ignored. This does not |
| // affect correctness but may slow down the update frequency. |
| std::atomic_uint_fast64_t& cost_estimate = cost_estimates_[node.node_id]; |
| uint64 new_estimate = (kCostDecay - 1) * |
| cost_estimate.load(std::memory_order_relaxed) / |
| kCostDecay + |
| (elapsed_cycles / kCostDecay); |
| cost_estimate.store(new_estimate, std::memory_order_relaxed); |
| if (new_estimate < kOpIsExpensiveThresholdCycles) { |
| is_expensive_[node.node_id].store(false, std::memory_order_relaxed); |
| } |
| } |
| |
| private: |
// Initial time (in CPU cycles) we expect an operation to take. Used to
// determine whether an operation should be placed in a threadpool.
// Operations start out "expensive".
| static const uint64 kInitialCostEstimateCycles = 100 * 1000 * 1000; |
| static const uint64 kOpIsExpensiveThresholdCycles = 5000; |
| static const uint64 kCostDecay = 10; |
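// For rough intuition only: on a ~1 GHz core, kInitialCostEstimateCycles
// corresponds to about 100 ms and kOpIsExpensiveThresholdCycles to about
// 5 us; cycle counts are not wall time, so treat these as order-of-magnitude
// figures. kCostDecay controls how quickly the moving average forgets old
// measurements.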
| |
| std::unique_ptr<std::atomic<bool>[]> is_expensive_; |
| std::unique_ptr<std::atomic_uint_fast64_t[]> cost_estimates_; |
| }; |
| |
| ImmutableExecutorState immutable_state_; |
| KernelStats kernel_stats_; |
| |
| TF_DISALLOW_COPY_AND_ASSIGN(ExecutorImpl); |
| }; |
| |
| // The state associated with one invocation of ExecutorImpl::Run. |
| // |
| // ExecutorState dispatches nodes when they become ready, and delegates to an |
// instance of `PropagatorStateType` to keep track of how many predecessors of
// a node are still pending.
| // |
| // The template argument `class PropagatorStateType` must define the following |
| // public members: |
| // * A type `TaggedNode`, representing a node to be processed, with public |
| // members: |
| // * `const NodeItem& get_node_item() const` |
| // * `bool get_is_dead() const` |
| // * A type `TaggedNodeReadyQueue`, representing a queue of nodes to be |
| // processed, with public members (having the same meanings as in an |
| // `std::vector<TaggedNode>`): |
| // * `void push_back(const TaggedNode& node)` |
| // * `TaggedNode front() const` |
| // * `void pop_front()` |
| // * `bool empty() const` |
// * A type `TaggedNodeSeq`, representing a list of nodes to be scheduled, with
| // public members (having the same meanings as in an |
| // `std::vector<TaggedNode>`): |
| // * `size_t size() const` |
| // * `bool empty() const` |
| // * `void clear()` |
| // * `const_iterator begin() const` |
| // * `const_iterator end() const` |
| // * A public constructor, `PropagatorStateType(const ImmutableExecutorState& |
| // immutable_state, int64 step_id)`. |
| // * The following public methods: |
| // * `void ActivateRoots(gtl::ArraySlice<const NodeItem*> roots, |
| // TaggedNodeSeq* ready)`, which creates `TaggedNode` instances for the |
| // nodes in `roots` and adds them to `*ready` |
| // * `void PropagateOutputs(const TaggedNode& tagged_node, EntryVector* |
| // outputs, TaggedNodeSeq* ready)`, which propagates `outputs` from the |
| // given `tagged_node` to the destinations of its output edges, and adds |
| // any newly runnable nodes to `*ready` |
| // * `Entry* GetInputTensors(const TaggedNode& tagged_node) const`, which |
| // returns a pointer to the input tensors for the given `tagged_node` |
| // * `FrameAndIter GetFrameAndIter(const TaggedNode& tagged_node) const`, |
| // which creates a `FrameAndIter` for the given `tagged_node` |
| // * `void DumpState()`, which dumps the dynamic state of the executing graph |
| // * `void MaybeMarkStarted(const TaggedNode& tagged_node)`, which records |
| // that a node has started |
| // * `void MaybeMarkCompleted(const TaggedNode& tagged_node)`, which records |
| // that a node has completed |
| // |
| // See `PropagatorState` in "./propagator_state.h" for an example of a type that |
| // can be used to instantiate `PropagatorStateType`. |
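//
// Within this file, ExecutorState drives the propagator as follows:
// RunAsync() calls ActivateRoots(), Process() calls GetInputTensors(),
// GetFrameAndIter(), MaybeMarkStarted() and MaybeMarkCompleted(), and
// PropagateOutputs() fills the list of newly runnable nodes that is handed
// to ScheduleReady() (via NodeDone()).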
| template <class PropagatorStateType> |
| class ExecutorState { |
| public: |
| ExecutorState(const Executor::Args& args, |
| const ImmutableExecutorState& immutable_state_, |
| ExecutorImpl::KernelStats* kernel_stats_); |
| ~ExecutorState(); |
| |
| void RunAsync(Executor::DoneCallback done); |
| |
| private: |
| // Use `TaggedNode` types defined by `PropagatorStateType`. |
| typedef typename PropagatorStateType::TaggedNode TaggedNode; |
| typedef |
| typename PropagatorStateType::TaggedNodeReadyQueue TaggedNodeReadyQueue; |
| typedef typename PropagatorStateType::TaggedNodeSeq TaggedNodeSeq; |
| |
| struct AsyncState; |
| |
// Process a ready node in the current thread.
| void Process(TaggedNode node, int64 scheduled_nsec); |
| |
| Status ProcessSync(const NodeItem& item, OpKernelContext::Params* params, |
| EntryVector* outputs, NodeExecStatsInterface* stats); |
| void ProcessAsync(const NodeItem& item, const OpKernelContext::Params& params, |
| const TaggedNode& tagged_node, Entry* first_input, |
| NodeExecStatsInterface* stats); |
| void ProcessNoop(NodeExecStatsInterface* stats); |
| void ProcessConstTensor(const NodeItem& item, EntryVector* outputs, |
| NodeExecStatsInterface* stats); |
| |
| // Before invoking item->kernel, fills in its "inputs". |
| Status PrepareInputs(const NodeItem& item, Entry* first_input, |
| TensorValueVec* inputs, |
| AllocatorAttributeVec* input_alloc_attrs, |
| bool* is_input_dead); |
| |
| // After item->kernel computation is done, processes its outputs. |
| Status ProcessOutputs(const NodeItem& item, OpKernelContext* ctx, |
| EntryVector* outputs, NodeExecStatsInterface* stats); |
| |
| // Called after each node finishes. Takes ownership of "stats". Returns true |
| // if execution has completed. |
| // |
| // This method will clear `*ready` before returning. |
| bool NodeDone(const Status& s, TaggedNodeSeq* ready, |
| NodeExecStatsInterface* stats, |
| TaggedNodeReadyQueue* inline_ready); |
| |
// Schedule all the expensive nodes in `*ready`, and put all the inexpensive
// nodes in `*ready` into `*inline_ready`.
| // |
| // This method will clear `*ready` before returning. |
| void ScheduleReady(TaggedNodeSeq* ready, TaggedNodeReadyQueue* inline_ready); |
| |
| // Clean up when this executor is done. |
| void Finish(); |
| void ScheduleFinish(); |
| |
| // Contains the device context assigned by the device at the beginning of a |
| // step. |
| DeviceContext* device_context_ = nullptr; |
| |
| const bool vlog_; // true if VLOG_IS_ON(1). Used to check vlog cheaply. |
| |
| // true if LogMemory::IsEnabled(). Used to check memory enabled cheaply. |
| const bool log_memory_; |
| |
| int64 step_id_; |
| // Not owned. |
| RendezvousInterface* rendezvous_; |
| CollectiveExecutor* collective_executor_ = nullptr; |
| SessionState* session_state_; |
| string session_handle_; |
| const SessionMetadata* session_metadata_ = nullptr; |
| TensorStore* tensor_store_; |
| // Step-local container. |
| ScopedStepContainer* step_container_; |
| StepStatsCollectorInterface* const stats_collector_; |
| const tracing::EventCollector* const event_collector_; |
| Context context_; |
| |
| // QUESTION: Make it a checkpoint::TensorSliceReaderCacheWrapper |
| // instead of a pointer? (avoids having to delete). |
| checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache_; |
| CallFrameInterface* call_frame_; |
| const ImmutableExecutorState& immutable_state_; |
| ExecutorImpl::KernelStats* const kernel_stats_; |
| CancellationManager* cancellation_manager_; |
// If not null, use this device to schedule intra-op operations.
| std::unique_ptr<DeviceBase> user_device_; |
| Executor::Args::Runner runner_; |
| bool sync_on_finish_; |
| const bool run_all_kernels_inline_; |
| |
| PropagatorStateType propagator_; |
| |
| // Invoked when the execution finishes. |
| Executor::DoneCallback done_cb_; |
| |
| std::atomic_int_fast32_t num_outstanding_ops_; |
| |
| // Available via OpKernelContext to every OpKernel invocation. |
| mutex num_deferred_ops_mu_; |
| int64 num_deferred_ops_ TF_GUARDED_BY(num_deferred_ops_mu_) = 0; |
| bool finish_when_deferred_ops_done_ TF_GUARDED_BY(num_deferred_ops_mu_) = |
| false; |
| |
| mutex mu_; |
| Status status_ TF_GUARDED_BY(mu_); |
| }; |
| |
| template <class PropagatorStateType> |
| ExecutorState<PropagatorStateType>::ExecutorState( |
| const Executor::Args& args, const ImmutableExecutorState& immutable_state, |
| ExecutorImpl::KernelStats* kernel_stats) |
| : vlog_(VLOG_IS_ON(1)), |
| log_memory_(LogMemory::IsEnabled()), |
| step_id_(args.step_id), |
| rendezvous_(args.rendezvous), |
| collective_executor_(args.collective_executor), |
| session_state_(args.session_state), |
| session_handle_(args.session_handle), |
| session_metadata_(immutable_state.params().session_metadata), |
| tensor_store_(args.tensor_store), |
| step_container_(args.step_container), |
| stats_collector_(args.stats_collector), |
| event_collector_( |
| tracing::GetEventCollector(tracing::EventCategory::kCompute)), |
| context_(ContextKind::kThread), |
| slice_reader_cache_(new checkpoint::TensorSliceReaderCacheWrapper), |
| call_frame_(args.call_frame), |
| immutable_state_(immutable_state), |
| kernel_stats_(kernel_stats), |
| cancellation_manager_(args.cancellation_manager), |
| runner_(args.runner), |
| sync_on_finish_(args.sync_on_finish), |
| run_all_kernels_inline_(args.run_all_kernels_inline), |
| propagator_(immutable_state, step_id_), |
| num_outstanding_ops_(0) { |
| if (args.user_intra_op_threadpool != nullptr) { |
| Device* device = immutable_state_.params().device; |
| user_device_ = RenamedDevice::NewRenamedDevice( |
| device->name(), device, false, false, args.user_intra_op_threadpool); |
| } |
| } |
| |
| template <class PropagatorStateType> |
| ExecutorState<PropagatorStateType>::~ExecutorState() { |
| if (device_context_) { |
| device_context_->Unref(); |
| } |
| delete slice_reader_cache_; |
| } |
| |
| template <class PropagatorStateType> |
| void ExecutorState<PropagatorStateType>::RunAsync(Executor::DoneCallback done) { |
| TaggedNodeSeq ready; |
| |
| // Ask the device to fill in the device context map. |
| Device* device = immutable_state_.params().device; |
| const Status get_context_status = |
| device->TryGetDeviceContext(&device_context_); |
| if (!get_context_status.ok()) { |
| delete this; |
| done(get_context_status); |
| return; |
| } |
| |
| // Initialize the ready queue. |
| ready.reserve(immutable_state_.root_nodes().size()); |
| propagator_.ActivateRoots(immutable_state_.root_nodes(), &ready); |
| num_outstanding_ops_ = ready.size(); |
| if (ready.empty()) { |
| delete this; |
| done(Status::OK()); |
| } else { |
| done_cb_ = std::move(done); |
| // Schedule to run all the ready ops in thread pool. |
| ScheduleReady(&ready, nullptr); |
| } |
| } |
| |
// State kept alive for executing an asynchronous node in another
// thread. NOTE: We need to make a copy of p.inputs and p.input_alloc_attrs for
// asynchronous kernels because OpKernelContext methods like input_type(i) need
// the params to point to valid input vectors. This is not an issue for
// sync kernels because these vectors are kept on the stack.
| template <class PropagatorStateType> |
| struct ExecutorState<PropagatorStateType>::AsyncState { |
| AsyncState(const OpKernelContext::Params& p, const TaggedNode& _tagged_node, |
| const NodeItem* _item, Entry* _first_input, |
| NodeExecStatsInterface* _stats) |
| : saved_inputs(*p.inputs), |
| saved_input_alloc_attrs(*p.input_alloc_attrs), |
| params(p), |
| tagged_node(_tagged_node), |
| item(_item), |
| first_input(_first_input), |
| // ParamsButClearingEigenGPUDevice does equivalent of |
| // params.eigen_gpu_device = nullptr; |
| ctx(ParamsButClearingEigenGPUDevice(¶ms), item->num_outputs), |
| stats(_stats) { |
| params.inputs = &saved_inputs; |
| params.input_alloc_attrs = &saved_input_alloc_attrs; |
| } |
| |
| TensorValueVec saved_inputs; |
| AllocatorAttributeVec saved_input_alloc_attrs; |
| OpKernelContext::Params params; |
| TaggedNode tagged_node; |
| const NodeItem* item; |
| Entry* first_input; |
| OpKernelContext ctx; |
| NodeExecStatsInterface* stats; |
| |
| private: |
| OpKernelContext::Params* ParamsButClearingEigenGPUDevice( |
| OpKernelContext::Params* p) { |
| // Ensure OpKernelContext constructor will make a new eigen GPU device if |
| // necessary. |
| p->eigen_gpu_device = nullptr; // Force allocation |
| return p; |
| } |
| }; |
| |
// Returns true if the kernel invocation might be traced by the given event
// collector or the profiler. Returns false only if it definitely will not be
// traced.
| bool MightTrace(const tracing::EventCollector* event_collector, |
| bool is_expensive) { |
// Tracing will only be enabled if either `event_collector` is non-null, or
// the profiler's annotation / TraceMe machinery is enabled for this
// particular kernel. Although `profiler::TraceMe`,
// `profiler::ScopedAnnotation`, and `tracing::ScopedRegion` check subsets of
// these properties internally in their constructors, the cost of passing the
// necessary arguments to them can be significant, so we avoid constructing
// them in the common case (when we know they will not be used).
| if (event_collector != nullptr) { |
| return true; |
| } |
| |
| if (profiler::ScopedAnnotation::IsEnabled()) return true; |
| |
| return profiler::TraceMe::Active(profiler::GetTFTraceMeLevel(is_expensive)); |
| } |
| |
| template <class PropagatorStateType> |
| Status ExecutorState<PropagatorStateType>::ProcessSync( |
| const NodeItem& item, OpKernelContext::Params* params, EntryVector* outputs, |
| NodeExecStatsInterface* stats) { |
| Status s; |
| OpKernelContext ctx(params, item.num_outputs); |
| nodestats::SetOpStart(stats); |
| |
| OpKernel* op_kernel = item.kernel; |
| Device* device = immutable_state_.params().device; |
| const bool is_expensive = kernel_stats_->IsExpensive(item); |
| |
| if (TF_PREDICT_FALSE(MightTrace(event_collector_, is_expensive))) { |
| tracing::ScopedRegion region(tracing::EventCategory::kCompute, |
| op_kernel->name_view()); |
| profiler::AnnotatedTraceMe activity( |
| [&] { |
| return op_kernel->TraceString( |
| &ctx, /*verbose=*/profiler::TfOpDetailsEnabled()); |
| }, |
| profiler::GetTFTraceMeLevel(is_expensive)); |
| device->Compute(op_kernel, &ctx); |
| nodestats::SetOpEnd(stats); |
| s = ProcessOutputs(item, &ctx, outputs, stats); |
| } else { |
| // In the common case, avoid creating any tracing objects. |
| if (is_expensive) { |
| KernelTimer timer; |
| device->Compute(op_kernel, &ctx); |
| kernel_stats_->UpdateCostEstimate(item, timer.ElapsedCycles()); |
| } else { |
| device->Compute(op_kernel, &ctx); |
| } |
| nodestats::SetOpEnd(stats); |
| s = ProcessOutputs(item, &ctx, outputs, stats); |
| } |
| nodestats::SetMemory(stats, &ctx); |
| return s; |
| } |
| |
| template <class PropagatorStateType> |
| void ExecutorState<PropagatorStateType>::ProcessAsync( |
| const NodeItem& item, const OpKernelContext::Params& params, |
| const TaggedNode& tagged_node, Entry* first_input, |
| NodeExecStatsInterface* stats) { |
| AsyncOpKernel* async_kernel = item.kernel->AsAsync(); |
| DCHECK(async_kernel != nullptr); |
| AsyncState* state = |
| new AsyncState(params, tagged_node, &item, first_input, stats); |
| |
| auto done = [this, state]() { |
| Device* device = immutable_state_.params().device; |
| NodeExecStatsInterface* stats = state->stats; // Shorthand |
| Entry* first_input = state->first_input; // Shorthand |
| |
| nodestats::SetOpEnd(stats); |
| EntryVector outputs; |
| Status s = ProcessOutputs(*state->item, &state->ctx, &outputs, stats); |
| nodestats::SetMemory(stats, &state->ctx); |
| if (vlog_) { |
| VLOG(2) << "Async kernel done: " << state->item->node_id << " step " |
| << step_id_ << " " << SummarizeNodeDef(state->item->kernel->def()) |
| << (state->tagged_node.get_is_dead() ? " is dead" : "") |
| << " device: " << device->name(); |
| } |
| |
| // Clears inputs. |
| const int num_inputs = state->item->num_inputs; |
| for (int i = 0; i < num_inputs; ++i) { |
| (first_input + i)->ClearVal(); |
| } |
| propagator_.MaybeMarkCompleted(state->tagged_node); |
| TaggedNodeSeq ready; |
| if (s.ok()) { |
| propagator_.PropagateOutputs(state->tagged_node, &outputs, &ready); |
| } |
| outputs.clear(); |
| const bool completed = NodeDone(s, &ready, stats, nullptr); |
| delete state; |
| if (completed) ScheduleFinish(); |
| }; |
| nodestats::SetOpStart(stats); |
| { |
| profiler::AnnotatedTraceMe activity( |
| [&] { |
| return async_kernel->TraceString( |
| &state->ctx, /*verbose=*/profiler::TfOpDetailsEnabled()); |
| }, |
| profiler::GetTFTraceMeLevel(kernel_stats_->IsExpensive(item))); |
| immutable_state_.params().device->ComputeAsync(async_kernel, &state->ctx, |
| std::move(done)); |
| } |
| } |
| |
| template <class PropagatorStateType> |
| void ExecutorState<PropagatorStateType>::ProcessNoop( |
| NodeExecStatsInterface* stats) { |
| nodestats::SetOpStart(stats); |
| nodestats::SetOpEnd(stats); |
| } |
| |
| template <class PropagatorStateType> |
| void ExecutorState<PropagatorStateType>::ProcessConstTensor( |
| const NodeItem& item, EntryVector* outputs, NodeExecStatsInterface* stats) { |
| nodestats::SetOpStart(stats); |
| nodestats::SetOpEnd(stats); |
| outputs->resize(1); |
| Entry& output = (*outputs)[0]; |
| output.state = Entry::State::HAS_CONST_TENSOR; |
| output.const_tensor = item.const_tensor; |
| output.alloc_attr = item.output_attrs()[0]; |
| } |
| |
| template <class PropagatorStateType> |
| void ExecutorState<PropagatorStateType>::Process(TaggedNode tagged_node, |
| int64 scheduled_nsec) { |
| profiler::TraceMe activity( |
| [&] { return absl::StrCat("ExecutorState::Process#id=", step_id_, "#"); }, |
| 2); |
| WithContext wc(context_); |
| TaggedNodeSeq ready; |
| TaggedNodeReadyQueue inline_ready; |
| |
| // Parameters passed to OpKernel::Compute. |
| TensorValueVec inputs; |
| AllocatorAttributeVec input_alloc_attrs; |
| |
| OpKernelContext::Params params; |
| params.step_id = step_id_; |
| // Override device's threadpool if user provides an intra_op_threadpool |
| Device* device = immutable_state_.params().device; |
| if (user_device_) { |
| params.device = user_device_.get(); |
| } else { |
| params.device = device; |
| } |
| params.log_memory = log_memory_; |
| params.rendezvous = rendezvous_; |
| params.collective_executor = collective_executor_; |
| params.session_state = session_state_; |
| params.session_handle = session_handle_; |
| params.session_metadata = session_metadata_; |
| params.tensor_store = tensor_store_; |
| params.cancellation_manager = cancellation_manager_; |
| params.call_frame = call_frame_; |
| params.function_library = immutable_state_.params().function_library; |
| params.resource_manager = device->resource_manager(); |
| params.step_container = step_container_; |
| params.slice_reader_cache = slice_reader_cache_; |
| params.inputs = &inputs; |
| params.input_alloc_attrs = &input_alloc_attrs; |
| params.runner = &runner_; |
| params.run_all_kernels_inline = run_all_kernels_inline_; |
| params.stats_collector = stats_collector_; |
| params.inc_num_deferred_ops_function = [this]() { |
| mutex_lock lock(num_deferred_ops_mu_); |
| num_deferred_ops_++; |
| }; |
| params.dec_num_deferred_ops_function = [this]() { |
| bool finish_when_deferred_ops_done = false; |
| { |
| mutex_lock lock(num_deferred_ops_mu_); |
| num_deferred_ops_--; |
| if (num_deferred_ops_ == 0) { |
| finish_when_deferred_ops_done = finish_when_deferred_ops_done_; |
| } |
| } |
| // Invoke Finish if the graph processing has completed. Finish is always |
| // called exactly once per ExecutorState, either here if there are any |
| // deferred ops, or in ScheduleFinish if there aren't any deferred ops. |
| if (finish_when_deferred_ops_done) Finish(); |
| }; |
| |
| // Set the device_context for this device, if it exists. |
| params.op_device_context = device_context_; |
| |
| Status s; |
| NodeExecStatsInterface* stats = nullptr; |
| |
| EntryVector outputs; |
| bool completed = false; |
| inline_ready.push_back(tagged_node); |
| while (!inline_ready.empty()) { |
| tagged_node = inline_ready.front(); |
| inline_ready.pop_front(); |
| const NodeItem& item = tagged_node.get_node_item(); |
| const int id = item.node_id; |
| |
| propagator_.MaybeMarkStarted(tagged_node); |
| |
| params.track_allocations = false; |
| stats = nullptr; |
| if (stats_collector_ && !tagged_node.get_is_dead()) { |
| stats = stats_collector_->CreateNodeExecStats(&item.kernel->def()); |
// Track allocations if and only if we are collecting statistics, and the
// `stats` object expects allocations to be tracked.
| params.track_allocations = stats ? stats->TrackAllocations() : false; |
| nodestats::SetScheduled(stats, scheduled_nsec); |
| nodestats::SetAllStart(stats); |
| } |
| |
| if (vlog_) { |
| VLOG(1) << "Process node: " << id << " step " << params.step_id << " " |
| << SummarizeNodeDef(item.kernel->def()) |
| << (tagged_node.get_is_dead() ? " is dead" : "") |
| << " device: " << device->name(); |
| } |
| |
| Entry* first_input = propagator_.GetInputTensors(tagged_node); |
| outputs.clear(); |
| |
| // Only execute this node if it is not dead or it is a send/recv |
| // transfer node. For transfer nodes, we need to propagate the "dead" |
| // bit even when the node is dead. |
| bool launched_asynchronously = false; |
| if (tagged_node.get_is_dead() && !item.is_transfer_node) { |
| outputs.resize(item.num_outputs); |
| } else if (TF_PREDICT_FALSE(item.is_noop)) { |
| ProcessNoop(stats); |
| } else if (item.const_tensor != nullptr && !params.track_allocations) { |
| ProcessConstTensor(item, &outputs, stats); |
| } else { |
| // Prepares inputs. |
| bool is_input_dead = false; |
| s = PrepareInputs(item, first_input, &inputs, &input_alloc_attrs, |
| &is_input_dead); |
| if (!s.ok()) { |
| // Clear inputs. |
| const int num_inputs = item.num_inputs; |
| for (int i = 0; i < num_inputs; ++i) { |
| (first_input + i)->ClearVal(); |
| } |
| propagator_.MaybeMarkCompleted(tagged_node); |
| // Continue to process the nodes in 'inline_ready'. |
| completed = NodeDone(s, &ready, stats, &inline_ready); |
| continue; |
| } |
| |
| // Set up compute params. |
| params.op_kernel = item.kernel; |
| params.frame_iter = propagator_.GetFrameAndIter(tagged_node); |
| params.is_input_dead = is_input_dead; |
| params.output_attr_array = item.output_attrs(); |
| params.forward_from_array = item.forward_from(); |
| params.outputs_required_array = item.outputs_required.get(); |
| |
| if (item.kernel_is_async) { |
| ProcessAsync(item, params, tagged_node, first_input, stats); |
| launched_asynchronously = true; |
| } else { |
| s = ProcessSync(item, ¶ms, &outputs, stats); |
| } |
| } |
| |
| if (!launched_asynchronously) { |
| if (vlog_) { |
| VLOG(2) << "Synchronous kernel done: " << id << " step " |
| << params.step_id << " " << SummarizeNodeDef(item.kernel->def()) |
| << (tagged_node.get_is_dead() ? " is dead: " : "") |
| << " device: " << device->name(); |
| } |
| |
| // Clears inputs. |
| const int num_inputs = item.num_inputs; |
| for (int i = 0; i < num_inputs; ++i) { |
| (first_input + i)->ClearVal(); |
| } |
| propagator_.MaybeMarkCompleted(tagged_node); |
| // Propagates outputs. |
| if (s.ok()) { |
| propagator_.PropagateOutputs(tagged_node, &outputs, &ready); |
| } |
| outputs.clear(); |
| if (stats) { |
| scheduled_nsec = nodestats::NowInNsec(); |
| } |
| // Postprocess. |
| completed = NodeDone(s, &ready, stats, &inline_ready); |
| } |
| } // while !inline_ready.empty() |
| |
| // This thread of computation is done if completed = true. |
| if (completed) ScheduleFinish(); |
| } |
| |
| template <class PropagatorStateType> |
| Status ExecutorState<PropagatorStateType>::PrepareInputs( |
| const NodeItem& item, Entry* first_input, TensorValueVec* inputs, |
| AllocatorAttributeVec* input_alloc_attrs, bool* is_input_dead) { |
| inputs->clear(); |
| inputs->resize(item.num_inputs); |
| input_alloc_attrs->clear(); |
| input_alloc_attrs->resize(item.num_inputs); |
| |
| *is_input_dead = false; |
| |
| bool is_merge = item.is_merge; |
| for (int i = 0; i < item.num_inputs; ++i) { |
| const bool expect_ref = IsRefType(item.input_type(i)); |
| Entry* entry = first_input + i; |
| (*input_alloc_attrs)[i] = entry->alloc_attr; |
| |
| // i-th input. |
| TensorValue* inp = &(*inputs)[i]; |
| |
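// Dispatch on the entry's state: NO_VALUE (only legal for merge/transfer
// nodes), HAS_VALUE (a plain tensor), HAS_CONST_TENSOR (a borrowed,
// immutable tensor), or HAS_REF_TENSOR (a mutex-guarded reference that may
// need to be dereferenced to a value).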
| switch (entry->state) { |
| case Entry::State::NO_VALUE: { |
| // Only merge and transfer nodes can have no-value inputs. |
| if (!is_merge) { |
| DCHECK(item.is_transfer_node) |
| << item.kernel->name() << " - input " << i; |
| entry->state = Entry::State::HAS_CONST_TENSOR; |
| entry->const_tensor = kEmptyTensor; |
| // NOTE(mrry): This `const_cast` is necessary because `TensorValue` |
| // stores a non-const `Tensor*`, and relies on the `OpKernelContext` |
| // accessors making dynamic checks that prevent using an immutable |
| // tensor as a mutable tensor. |
| inp->tensor = const_cast<Tensor*>(kEmptyTensor); |
| *is_input_dead = true; |
| } |
| break; |
| } |
| |
| case Entry::State::HAS_VALUE: { |
| if (expect_ref) { |
| return AttachDef( |
| errors::InvalidArgument(i, "-th input expects a ref type"), |
| item.kernel->def()); |
| } |
| inp->tensor = entry->val.get(); |
| break; |
| } |
| |
| case Entry::State::HAS_CONST_TENSOR: { |
| if (expect_ref) { |
| return AttachDef( |
| errors::InvalidArgument(i, "-th input expects a ref type"), |
| item.kernel->def()); |
| } |
| // NOTE(mrry): This `const_cast` is necessary because `TensorValue` |
| // stores a non-const `Tensor*`, and relies on the `OpKernelContext` |
| // accessors making dynamic checks that prevent using an immutable |
| // tensor as a mutable tensor. |
| inp->tensor = const_cast<Tensor*>(entry->const_tensor); |
| break; |
| } |
| |
| case Entry::State::HAS_REF_TENSOR: { |
| { |
| tf_shared_lock ml(*entry->ref_tensor.mu); |
| if (!entry->ref_tensor.tensor->IsInitialized() && |
| !item.is_initialization_op) { |
| return AttachDef(errors::FailedPrecondition( |
| "Attempting to use uninitialized value ", |
| item.kernel->requested_input(i)), |
| item.kernel->def()); |
| } |
| } |
| |
| if (expect_ref) { |
| inp->mutex_if_ref = entry->ref_tensor.mu; |
| inp->tensor = entry->ref_tensor.tensor; |
| } else { |
| // Automatically deref the tensor ref when the op expects a |
| // tensor but is given a ref to a tensor. Need to deref it |
| // under the mutex. |
| { |
| mutex* ref_mu = entry->ref_tensor.mu; |
| Tensor* ref_tensor = entry->ref_tensor.tensor; |
| tf_shared_lock l(*ref_mu); |
| entry->val.Init(*ref_tensor); |
| } |
| entry->state = Entry::State::HAS_VALUE; |
| |
| inp->tensor = entry->val.get(); |
| // The dtype of entry->ref_tensor.tensor could have been changed by |
| // another operation that ran after the operation that "produced" it |
| // executed, so re-validate that the type of the dereferenced tensor |
| // matches the expected input type. |
| if (item.input_type(i) != inp->tensor->dtype()) { |
| return AttachDef( |
| errors::InvalidArgument( |
| i, "-th input expects type ", |
| DataTypeString(item.input_type(i)), |
| " but automatically dereferenced input tensor has type ", |
| DataTypeString(inp->tensor->dtype())), |
| item.kernel->def()); |
| } |
| } |
| break; |
| } |
| } |
| } |
| return Status::OK(); |
| } |
| |
| template <class PropagatorStateType> |
| Status ExecutorState<PropagatorStateType>::ProcessOutputs( |
| const NodeItem& item, OpKernelContext* ctx, EntryVector* outputs, |
| NodeExecStatsInterface* stats) { |
| DCHECK_EQ(0, outputs->size()); |
| outputs->resize(item.num_outputs); |
| |
| Status s = ctx->status(); |
| if (!s.ok()) { |
| s = AttachDef(s, item.kernel->def()); |
| // TODO(misard) Replace with a finer-grain enabling flag once we |
| // add better optional debugging support. |
| if (vlog_ && VLOG_IS_ON(1)) { |
| LOG(WARNING) << this << " Compute status: " << s; |
| propagator_.DumpState(); |
| } |
| if (s.code() == error::RESOURCE_EXHAUSTED) { |
| if (stats_collector_) { |
| string err = stats_collector_->ReportAllocsOnResourceExhausted( |
| s.error_message()); |
| s = Status(s.code(), strings::StrCat(s.error_message(), err)); |
| } else { |
| s = Status( |
| s.code(), |
| strings::StrCat( |
| s.error_message(), |
| "\nHint: If you want to see a list of allocated tensors when " |
| "OOM happens, add report_tensor_allocations_upon_oom " |
| "to RunOptions for current allocation info.\n")); |
| } |
| } |
| return s; |
| } |
| |
| for (int i = 0; i < item.num_outputs; ++i) { |
| const TensorValue val = ctx->release_output(i); |
| if (val.tensor == nullptr) { |
| // Unless it's a Switch or a Recv, or the executor has marked the output |
// as not required, the node must produce a tensor value at the i-th output.
| if (!(item.is_recv_or_switch || |
| (item.outputs_required && !item.outputs_required[i]))) { |
| s.Update(errors::Internal("Missing ", i, "-th output from ", |
| FormatNodeDefForError(item.kernel->def()))); |
| } |
| } else { |
| Entry* out = &((*outputs)[i]); |
| |
| // Set the allocator attributes of the output entry. |
| out->alloc_attr = ctx->output_alloc_attr(i); |
| |
// Sanity check of output tensor types. We use dtype_safe() here because,
// for reference-typed outputs, the dtype must be read under the ref mutex.
| DataType dtype = val.dtype_safe(); |
| if (dtype == item.output_type(i)) { |
| if (stats && val.tensor->IsInitialized()) { |
| nodestats::SetOutput(stats, i, val.tensor); |
| } |
| if (val.is_ref()) { |
| out->state = Entry::State::HAS_REF_TENSOR; |
| out->ref_tensor.tensor = val.tensor; |
| out->ref_tensor.mu = val.mutex_if_ref; |
| if (log_memory_) { |
| Tensor to_log; |
| { |
| // Dereference the tensor under the lock. |
| tf_shared_lock l(*out->ref_tensor.mu); |
| to_log = *out->ref_tensor.tensor; |
| } |
| LogMemory::RecordTensorOutput(ctx->op_kernel().name(), |
| ctx->step_id(), i, to_log); |
| } |
| } else { |
// NOTE that std::move is used here, so val.tensor goes to
// uninitialized state (val.tensor->IsInitialized() returns false).
| out->state = Entry::State::HAS_VALUE; |
| out->val.Init(std::move(*val.tensor)); |
| if (log_memory_) { |
| LogMemory::RecordTensorOutput(ctx->op_kernel().name(), |
| ctx->step_id(), i, *out->val); |
| } |
| } |
| } else { |
| s.Update( |
| errors::Internal("Output ", i, " of type ", DataTypeString(dtype), |
| " does not match declared output type ", |
| DataTypeString(item.output_type(i)), " for node ", |
| FormatNodeDefForError(item.kernel->def()))); |
| } |
| } |
| if (!val.is_ref()) { |
// If OpKernelContext returned outputs by value, we wouldn't need to
// delete the tensor here.
| delete val.tensor; |
| } |
| } |
| return s; |
| } |
| |
| template <class PropagatorStateType> |
| bool ExecutorState<PropagatorStateType>::NodeDone( |
| const Status& s, TaggedNodeSeq* ready, NodeExecStatsInterface* stats, |
| TaggedNodeReadyQueue* inline_ready) { |
| nodestats::SetAllEnd(stats); |
| if (stats) { |
| if (stats_collector_) { |
| stats->Done(immutable_state_.params().device->name()); |
| } else { |
| delete stats; |
| } |
| } |
| |
| bool abort_run = false; |
| if (!s.ok()) { |
| // Some error happened. This thread of computation is done. |
| mutex_lock l(mu_); |
| if (status_.ok()) { |
| abort_run = true; |
| |
| // If execution has been cancelled, mark any new errors as being derived. |
| // This ensures any errors triggered by cancellation are marked as |
| // derived. |
| if (cancellation_manager_ && cancellation_manager_->IsCancelled()) { |
| status_ = StatusGroup::MakeDerived(s); |
| } else { |
| status_ = s; |
| } |
| } |
| } |
| if (abort_run) { |
| TRACEPRINTF("StartAbort: %s", s.ToString().c_str()); |
| if (cancellation_manager_) { |
// Only log when the abort happens during the actual run time.
| auto device_name = immutable_state_.params().device->name(); |
| // Use VLOG instead of LOG(warning) because error status is expected when |
| // the executor is run under the grappler optimization phase or when |
| // iterating through a tf.data input pipeline. |
| VLOG(1) << "[" << device_name << "] Executor start aborting: " << s; |
| } |
| |
| if (rendezvous_) { |
| rendezvous_->StartAbort(s); |
| } |
| if (collective_executor_) { |
| collective_executor_->StartAbort(s); |
| } |
| if (cancellation_manager_) { |
| cancellation_manager_->StartCancel(); |
| } |
| } |
| |
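// Update the count of outstanding ops: this node is done (-1), and
// `ready_size` newly ready nodes will be scheduled (+ready_size), for a net
// change of ready_size - 1. When the count reaches zero the step is complete.
// On error the ready nodes are not scheduled, so only the decrement applies.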
| bool completed = false; |
| const size_t ready_size = ready->size(); |
| if (ready_size == 0 || !s.ok()) { |
| completed = (num_outstanding_ops_.fetch_sub(1) == 1); |
| } else if (ready_size > 1) { |
| num_outstanding_ops_.fetch_add(ready_size - 1, std::memory_order_relaxed); |
| } |
| |
| // Schedule the ready nodes in 'ready'. |
| if (s.ok()) { |
| ScheduleReady(ready, inline_ready); |
| } |
| return completed; |
| } |
| |
| template <class PropagatorStateType> |
| void ExecutorState<PropagatorStateType>::ScheduleReady( |
| TaggedNodeSeq* ready, TaggedNodeReadyQueue* inline_ready) { |
| if (ready->empty()) return; |
| |
| int64 scheduled_nsec = 0; |
| if (stats_collector_) { |
| scheduled_nsec = nodestats::NowInNsec(); |
| } |
| |
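// Scheduling strategy: when `run_all_kernels_inline_` is set, all ready
// kernels run sequentially, either from a single `runner_` closure or via the
// caller's `inline_ready` queue. Otherwise, if there is no `inline_ready`
// queue, every ready node is dispatched to `runner_`; if there is one, dead
// or inexpensive nodes are queued for inline execution, expensive nodes are
// dispatched to `runner_`, and at most one expensive node is kept back to run
// on the current thread when nothing was queued inline.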
| if (run_all_kernels_inline_) { |
| if (inline_ready == nullptr) { |
// Schedule all ready kernels from a single closure. This ensures that,
| // regardless of the `runner_` implementation, all kernels will run |
| // sequentially on the same thread, and thread wakeup overhead and |
| // executor mutex contention will be minimized. |
| runner_([this, ready = std::move(*ready), scheduled_nsec]() { |
| for (auto& tagged_node : ready) { |
| Process(tagged_node, scheduled_nsec); |
| } |
| }); |
| } else { |
| for (auto& tagged_node : *ready) { |
| inline_ready->push_back(tagged_node); |
| } |
| } |
| } else { |
| const TaggedNode* curr_expensive_node = nullptr; |
| if (inline_ready == nullptr) { |
| // Schedule to run all the ready ops in thread pool. |
| for (auto& tagged_node : *ready) { |
| runner_([=]() { Process(tagged_node, scheduled_nsec); }); |
| } |
| } else { |
| for (auto& tagged_node : *ready) { |
| const NodeItem& item = *tagged_node.node_item; |
| if (tagged_node.get_is_dead() || !kernel_stats_->IsExpensive(item)) { |
| // Inline this inexpensive node. |
| inline_ready->push_back(tagged_node); |
| } else { |
| if (curr_expensive_node) { |
| // Dispatch to another thread since there is plenty of work to |
| // do for this thread. |
| runner_(std::bind(&ExecutorState::Process, this, |
| *curr_expensive_node, scheduled_nsec)); |
| } |
| curr_expensive_node = &tagged_node; |
| } |
| } |
| } |
| if (curr_expensive_node) { |
| if (inline_ready->empty()) { |
| inline_ready->push_back(*curr_expensive_node); |
| } else { |
// There are inline nodes to run already. We dispatch this expensive
// node to another thread.
| runner_(std::bind(&ExecutorState::Process, this, *curr_expensive_node, |
| scheduled_nsec)); |
| } |
| } |
| } |
| ready->clear(); |
| } |
| |
| template <class PropagatorStateType> |
| void ExecutorState<PropagatorStateType>::ScheduleFinish() { |
// Checks the condition to decide whether Finish() needs to be invoked. If
// there are in-flight deferred ops, wait until `num_deferred_ops_` reaches 0
// to invoke Finish(). Otherwise, invoke Finish() directly.
| // Note that it is critical that the ScheduleFinish / Finish codepath does not |
| // block, otherwise we might deadlock. See b/124523000 for details. |
| { |
| mutex_lock lock(num_deferred_ops_mu_); |
| if (num_deferred_ops_ > 0) { |
| finish_when_deferred_ops_done_ = true; |
| return; |
| } |
| } |
| // Finish is always called exactly once per ExecutorState, either here if |
| // there aren't any deferred ops, or in the dec_num_deferred_ops_function if |
| // there are deferred ops. |
| Finish(); |
| } |
| |
| template <class PropagatorStateType> |
| void ExecutorState<PropagatorStateType>::Finish() { |
| mu_.lock(); |
| auto status = status_; |
| auto done_cb = std::move(done_cb_); |
| auto runner = std::move(runner_); |
| mu_.unlock(); |
| int64 step_id = step_id_; |
| CHECK(done_cb != nullptr); |
| Device* device = immutable_state_.params().device; |
| |
| // There are several potential race conditions below. To name a few: |
| // 1. Even if the device's status is OK at the precise moment when |
| // num_deferred_ops_ reaches 0, it could go bad before device->RefreshStatus() |
| // is called below, caused by work enqueued onto the same device by other |
| // concurrent ExecutorState objects. |
| // 2. Some implementations of Device::RefreshStatus, such as |
| // XlaDevice::RefreshStatus, may be inherently racy because it releases the |
| // device mutex after a stream pointer is acquired and before the stream is |
| // queried for status. |
| // 3. It's the same for some implementations of Device::Sync, such as |
| // XlaDevice::Sync. |
| // |
| // However, these race conditions are acceptable because a stream (and |
| // therefore an XlaDevice) can only go from OK to not-OK, never the opposite, |
| // which means we will at worst report errors when there isn't any, never the |
| // opposite. |
| |
// An early exit for devices that don't allow sync on completion. Ops that run
// on these devices should have used num_deferred_ops correctly to ensure the
// device has finished all relevant work at this point.
| if (!device->AllowsSyncOnCompletion()) { |
| status.Update(device->RefreshStatus()); |
| if (!status.ok()) { |
| // In device async execution mode, it's possible for device execution to |
| // lag behind ExecutorState scheduling so much that this is the first |
| // place a device execution error surfaces. |
| // If so, all ExecutorState::NodeDone calls have already happened with OK |
| // status. This is the last defense where StartCancel must be called to |
| // abort all computation still running on any device. |
| // TODO(b/124523000): Always call Finish in a separate thread, so even if |
| // StartCancel blocks the current thread's execution, we won't encounter |
| // deadlocks caused by inter-op thread exhaustion. |
| if (rendezvous_) { |
| rendezvous_->StartAbort(status); |
| } |
| if (collective_executor_) { |
| collective_executor_->StartAbort(status); |
| } |
| if (cancellation_manager_) { |
| cancellation_manager_->StartCancel(); |
| } |
| } |
| delete this; |
| runner([step_id, status, done_cb = std::move(done_cb)]() { |
| profiler::TraceMe traceme( |
| [&] { |
| return absl::StrCat("ExecutorDoneCallback#id=", step_id, "#"); |
| }, |
| 2); |
| done_cb(status); |
| }); |
| return; |
| } |
| |
| if (sync_on_finish_ && status.ok()) { |
| // Block until the device has finished all queued operations. For |
| // devices like GPUs that continue to execute Ops after their Compute |
| // methods have completed, this ensures that control is not returned to |
| // the user until the step (and its side-effects) has actually completed. |
| device->Sync([this, step_id, runner = std::move(runner), |
| done_cb = std::move(done_cb)](const Status& status) mutable { |
| delete this; |
| runner([step_id, status, done_cb = std::move(done_cb)]() { |
| profiler::TraceMe traceme( |
| [&] { |
| return absl::StrCat("ExecutorDoneCallback#id=", step_id, "#"); |
| }, |
| 2); |
| done_cb(status); |
| }); |
| }); |
| } else { |
| delete this; |
| runner([step_id, status, done_cb = std::move(done_cb)]() { |
| profiler::TraceMe traceme( |
| [&] { |
| return absl::StrCat("ExecutorDoneCallback#id=", step_id, "#"); |
| }, |
| 2); |
| done_cb(status); |
| }); |
| } |
| } |
| |
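// Note: the ExecutorState allocated here owns itself; it is deleted either on
// the early-exit paths in RunAsync() or in Finish() once the step completes.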
| void ExecutorImpl::RunAsync(const Args& args, DoneCallback done) { |
| (new ExecutorState<PropagatorState>(args, immutable_state_, &kernel_stats_)) |
| ->RunAsync(std::move(done)); |
| } |
| |
| } // namespace |
| |
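// A minimal usage sketch for the factory function below (illustrative only;
// the field names come from LocalExecutorParams in executor.h, and the kernel
// creation/deletion callbacks are elided):
//
//   LocalExecutorParams params;
//   params.device = device;
//   params.function_library = lib_runtime;
//   params.create_kernel = ...;   // build an OpKernel for a node
//   params.delete_kernel = ...;   // release it when the executor is destroyed
//   Executor* executor = nullptr;
//   TF_RETURN_IF_ERROR(NewLocalExecutor(params, graph, &executor));
//   std::unique_ptr<Executor> owned_executor(executor);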
| Status NewLocalExecutor(const LocalExecutorParams& params, const Graph& graph, |
| Executor** executor) { |
| ExecutorImpl* impl = new ExecutorImpl(params); |
| const Status s = impl->Initialize(graph); |
| if (s.ok()) { |
| *executor = impl; |
| } else { |
| delete impl; |
| } |
| return s; |
| } |
| |
| Status CreateNonCachedKernel(Device* device, FunctionLibraryRuntime* flib, |
| const std::shared_ptr<const NodeProperties>& props, |
| int graph_def_version, OpKernel** kernel) { |
| const auto device_type = DeviceType(device->attributes().device_type()); |
| auto allocator = device->GetAllocator(AllocatorAttributes()); |
| return CreateOpKernel(device_type, device, allocator, flib, |
| device->resource_manager(), props, graph_def_version, |
| kernel); |
| } |
| |
| void DeleteNonCachedKernel(OpKernel* kernel) { delete kernel; } |
| |
| namespace { |
| |
| class DefaultExecutorRegistrar { |
| public: |
| DefaultExecutorRegistrar() { |
| Factory* factory = new Factory; |
| ExecutorFactory::Register("", factory); |
| ExecutorFactory::Register("DEFAULT", factory); |
| } |
| |
| private: |
| class Factory : public ExecutorFactory { |
| Status NewExecutor(const LocalExecutorParams& params, const Graph& graph, |
| std::unique_ptr<Executor>* out_executor) override { |
| Executor* ret = nullptr; |
TF_RETURN_IF_ERROR(NewLocalExecutor(params, graph, &ret));
| out_executor->reset(ret); |
| return Status::OK(); |
| } |
| }; |
| }; |
| static DefaultExecutorRegistrar registrar; |
| |
| } // namespace |
| |
| } // namespace tensorflow |