[Executor] Split `ExecutorState` into `PropagatorState` and `ExecutorState<PropagatorStateType>`.
This change is part of an ongoing refactoring to simplify "executor.cc" and enable the substitution of more efficient implementations of `PropagateOutputs()`.
PiperOrigin-RevId: 304262448
Change-Id: I46a2d7fcdde89a71c502d272f35adfd34b0c4cab
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 4fd816d..95a7b4d 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -2546,6 +2546,7 @@
"common_runtime/debugger_state_interface.h",
"common_runtime/device_resolver_local.h",
"common_runtime/dma_helper.h",
+ "common_runtime/entry.h",
"common_runtime/executor.h",
"common_runtime/executor_factory.h",
"common_runtime/function_optimization_registry.h",
@@ -2553,6 +2554,7 @@
"common_runtime/graph_view.h",
"common_runtime/immutable_executor_state.h",
"common_runtime/input_colocation_exemption_registry.h",
+ "common_runtime/inspecting_placer.h",
"common_runtime/isolate_placer_inspection_required_ops_pass.h",
"common_runtime/local_device.h",
"common_runtime/lower_function_call_op.h",
@@ -2567,7 +2569,7 @@
"common_runtime/partitioning_utils.h",
"common_runtime/placer.h",
"common_runtime/process_util.h",
- "common_runtime/inspecting_placer.h",
+ "common_runtime/propagator_state.h",
"common_runtime/profile_handler.h",
"common_runtime/renamed_device.h",
"common_runtime/rendezvous_mgr.h",
@@ -2640,6 +2642,7 @@
"common_runtime/process_function_library_runtime.cc",
"common_runtime/process_state.cc",
"common_runtime/process_util.cc",
+ "common_runtime/propagator_state.cc",
"common_runtime/renamed_device.cc",
"common_runtime/rendezvous_mgr.cc",
"common_runtime/rendezvous_util.cc",
diff --git a/tensorflow/core/common_runtime/entry.h b/tensorflow/core/common_runtime/entry.h
new file mode 100644
index 0000000..27c1838
--- /dev/null
+++ b/tensorflow/core/common_runtime/entry.h
@@ -0,0 +1,142 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_ENTRY_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_ENTRY_H_
+
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "tensorflow/core/lib/gtl/manual_constructor.h"
+
+namespace tensorflow {
+
+class mutex;
+class Tensor;
+
+// An Entry store a single input value for an individual kernel invocation in
+// an executor.
+//
+// Either a tensor pointer (pass-by-reference) or a tensor (pass-by-value).
+struct Entry {
+ enum class State {
+ NO_VALUE = 0, // The default state for a newly-created Entry.
+ HAS_VALUE, // `this->val` is valid.
+ HAS_CONST_TENSOR, // `this->const_tensor` is valid.
+ HAS_REF_TENSOR, // `this->ref_tensor` is valid.
+ };
+
+ Entry() : state(State::NO_VALUE) {}
+ Entry(const Entry& other) : state(other.state), alloc_attr(other.alloc_attr) {
+ switch (state) {
+ case State::NO_VALUE:
+ break;
+ case State::HAS_VALUE:
+ val.Init(*other.val);
+ break;
+ case State::HAS_CONST_TENSOR:
+ const_tensor = other.const_tensor;
+ break;
+ case State::HAS_REF_TENSOR:
+ ref_tensor = other.ref_tensor;
+ break;
+ }
+ }
+
+ ~Entry() {
+ if (state == State::HAS_VALUE) val.Destroy();
+ }
+
+ Entry& operator=(const Entry& other) {
+ if (state == State::HAS_VALUE) {
+ val.Destroy();
+ }
+ state = other.state;
+ alloc_attr = other.alloc_attr;
+ switch (state) {
+ case State::NO_VALUE:
+ break;
+ case State::HAS_VALUE:
+ val.Init(*other.val);
+ break;
+ case State::HAS_CONST_TENSOR:
+ const_tensor = other.const_tensor;
+ break;
+ case State::HAS_REF_TENSOR:
+ ref_tensor = other.ref_tensor;
+ break;
+ }
+ return *this;
+ }
+
+ Entry& operator=(Entry&& other) {
+ if (state == State::HAS_VALUE) {
+ val.Destroy();
+ }
+ state = other.state;
+ alloc_attr = other.alloc_attr;
+ switch (state) {
+ case State::NO_VALUE:
+ break;
+ case State::HAS_VALUE:
+ val.Init(std::move(*other.val));
+ break;
+ case State::HAS_CONST_TENSOR:
+ const_tensor = other.const_tensor;
+ break;
+ case State::HAS_REF_TENSOR:
+ ref_tensor = other.ref_tensor;
+ break;
+ }
+ return *this;
+ }
+
+ // Clears the <val> field, and sets this entry to the `NO_VALUE` state.
+ void ClearVal() {
+ if (state == State::HAS_VALUE) {
+ val.Destroy();
+ }
+ state = State::NO_VALUE;
+ }
+
+ union {
+ // A tensor value. Valid iff `state_ == HAS_VALUE`.
+ ManualConstructor<Tensor> val;
+
+ // A pointer to a constant tensor value. Valid iff `state_ ==
+ // HAS_CONST_TENSOR`.
+ const Tensor* const_tensor;
+
+ // A tensor reference and associated mutex. Valid iff `state_ ==
+ // HAS_REF_TENSOR`.
+ struct {
+ Tensor* tensor;
+ mutex* mu;
+ } ref_tensor;
+ };
+
+ // The current state of this entry, indicating which member of the above
+ // union is active.
+ State state;
+
+ // The attributes of the allocator that creates the tensor.
+ AllocatorAttributes alloc_attr;
+};
+
+// TODO(b/152925936): Re-evaluate this constant with current usage patterns.
+typedef gtl::InlinedVector<Entry, 4> EntryVector;
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_ENTRY_H_
diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc
index 4397258..39f396d 100644
--- a/tensorflow/core/common_runtime/executor.cc
+++ b/tensorflow/core/common_runtime/executor.cc
@@ -21,10 +21,12 @@
#include "absl/memory/memory.h"
#include "tensorflow/core/common_runtime/costmodel_manager.h"
+#include "tensorflow/core/common_runtime/entry.h"
#include "tensorflow/core/common_runtime/executor_factory.h"
#include "tensorflow/core/common_runtime/graph_view.h"
#include "tensorflow/core/common_runtime/immutable_executor_state.h"
#include "tensorflow/core/common_runtime/pending_counts.h"
+#include "tensorflow/core/common_runtime/propagator_state.h"
#include "tensorflow/core/common_runtime/renamed_device.h"
#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/framework/allocator.h"
@@ -112,8 +114,6 @@
} // namespace nodestats
-class ExecutorImpl;
-
// Time the execution of kernels (in CPU cycles). Used to dynamically identify
// inexpensive kernels which can be dispatched inline.
struct KernelTimer {
@@ -124,6 +124,7 @@
}
};
+// TODO(b/152925936): Re-evaluate these constants with current usage patterns.
typedef gtl::InlinedVector<TensorValue, 4> TensorValueVec;
typedef gtl::InlinedVector<AllocatorAttributes, 4> AllocatorAttributeVec;
@@ -140,6 +141,7 @@
void RunAsync(const Args& args, DoneCallback done) override;
private:
+ template <class PropagatorStateType>
friend class ExecutorState;
// Stores execution time information about the kernels in an executor's graph.
@@ -212,8 +214,55 @@
};
// The state associated with one invocation of ExecutorImpl::Run.
-// ExecutorState dispatches nodes when they become ready and keeps
-// track of how many predecessors of a node have not done (pending_).
+//
+// ExecutorState dispatches nodes when they become ready, and delegates to an
+// instance of `PropagatorStateType` to keep track of how many predecessors of a
+// are still pending.
+//
+// The template argument `class PropagatorStateType` must define the following
+// public members:
+// * A type `TaggedNode`, representing a node to be processed, with public
+// members:
+// * `const NodeItem& get_node_item() const`
+// * `bool get_is_dead() const`
+// * A type `TaggedNodeReadyQueue`, representing a queue of nodes to be
+// processed, with public members (having the same meanings as in an
+// `std::vector<TaggedNode>`):
+// * `void push_back(const TaggedNode& node)`
+// * `TaggedNode front() const`
+// * `void pop_front()`
+// * `bool empty() const`
+// * A type `TaggedNodeSeq`, representing a list of nodes to be schedules, with
+// public members (having the same meanings as in an
+// `std::vector<TaggedNode>`):
+// * `size_t size() const`
+// * `bool empty() const`
+// * `void clear()`
+// * `const_iterator begin() const`
+// * `const_iterator end() const`
+// * A public constructor, `PropagatorStateType(const ImmutableExecutorState&
+// immutable_state, int64 step_id)`.
+// * The following public methods:
+// * `void ActivateRoots(gtl::ArraySlice<const NodeItem*> roots,
+// TaggedNodeSeq* ready)`, which creates `TaggedNode` instances for the
+// nodes in `roots` and adds them to `*ready`
+// * `void PropagateOutputs(const TaggedNode& tagged_node, EntryVector*
+// outputs, TaggedNodeSeq* ready)`, which propagates `outputs` from the
+// given `tagged_node` to the destinations of its output edges, and adds
+// any newly runnable nodes to `*ready`
+// * `Entry* GetInputTensors(const TaggedNode& tagged_node) const`, which
+// returns a pointer to the input tensors for the given `tagged_node`
+// * `FrameAndIter GetFrameAndIter(const TaggedNode& tagged_node) const`,
+// which creates a `FrameAndIter` for the given `tagged_node`
+// * `void DumpState()`, which dumps the dynamic state of the executing graph
+// * `void MaybeMarkStarted(const TaggedNode& tagged_node)`, which records
+// that a node has started
+// * `void MaybeMarkCompleted(const TaggedNode& tagged_node)`, which records
+// that a node has completed
+//
+// See `PropagatorState` in "./propagator_state.h" for an example of a type that
+// can be used to instantiate `PropagatorStateType`.
+template <class PropagatorStateType>
class ExecutorState {
public:
ExecutorState(const Executor::Args& args,
@@ -224,452 +273,58 @@
void RunAsync(Executor::DoneCallback done);
private:
- // Either a tensor pointer (pass-by-reference) or a tensor (pass-by-value).
- struct Entry {
- enum class State {
- NO_VALUE = 0, // The default state for a newly-created Entry.
- HAS_VALUE, // `this->val` is valid.
- HAS_CONST_TENSOR, // `this->const_tensor` is valid.
- HAS_REF_TENSOR, // `this->ref_tensor` is valid.
- };
+ // Use `TaggedNode` types defined by `PropagatorStateType`.
+ typedef typename PropagatorStateType::TaggedNode TaggedNode;
+ typedef
+ typename PropagatorStateType::TaggedNodeReadyQueue TaggedNodeReadyQueue;
+ typedef typename PropagatorStateType::TaggedNodeSeq TaggedNodeSeq;
- Entry() : state(State::NO_VALUE) {}
- Entry(const Entry& other)
- : state(other.state), alloc_attr(other.alloc_attr) {
- switch (state) {
- case State::NO_VALUE:
- break;
- case State::HAS_VALUE:
- val.Init(*other.val);
- break;
- case State::HAS_CONST_TENSOR:
- const_tensor = other.const_tensor;
- break;
- case State::HAS_REF_TENSOR:
- ref_tensor = other.ref_tensor;
- break;
- }
- }
+ struct AsyncState;
- ~Entry() {
- if (state == State::HAS_VALUE) val.Destroy();
- }
+ // Process a ready node in current thread.
+ void Process(TaggedNode node, int64 scheduled_nsec);
- Entry& operator=(const Entry& other) {
- if (state == State::HAS_VALUE) {
- val.Destroy();
- }
- state = other.state;
- alloc_attr = other.alloc_attr;
- switch (state) {
- case State::NO_VALUE:
- break;
- case State::HAS_VALUE:
- val.Init(*other.val);
- break;
- case State::HAS_CONST_TENSOR:
- const_tensor = other.const_tensor;
- break;
- case State::HAS_REF_TENSOR:
- ref_tensor = other.ref_tensor;
- break;
- }
- return *this;
- }
+ Status ProcessSync(const NodeItem& item, OpKernelContext::Params* params,
+ EntryVector* outputs, NodeExecStatsInterface* stats);
+ void ProcessAsync(const NodeItem& item, const OpKernelContext::Params& params,
+ const TaggedNode& tagged_node, Entry* first_input,
+ NodeExecStatsInterface* stats);
+ void ProcessNoop(NodeExecStatsInterface* stats);
+ void ProcessConstTensor(const NodeItem& item, EntryVector* outputs,
+ NodeExecStatsInterface* stats);
- Entry& operator=(Entry&& other) {
- if (state == State::HAS_VALUE) {
- val.Destroy();
- }
- state = other.state;
- alloc_attr = other.alloc_attr;
- switch (state) {
- case State::NO_VALUE:
- break;
- case State::HAS_VALUE:
- val.Init(std::move(*other.val));
- break;
- case State::HAS_CONST_TENSOR:
- const_tensor = other.const_tensor;
- break;
- case State::HAS_REF_TENSOR:
- ref_tensor = other.ref_tensor;
- break;
- }
- return *this;
- }
+ // Before invoking item->kernel, fills in its "inputs".
+ Status PrepareInputs(const NodeItem& item, Entry* first_input,
+ TensorValueVec* inputs,
+ AllocatorAttributeVec* input_alloc_attrs,
+ bool* is_input_dead);
- // Clears the <val> field, and sets this entry to the `NO_VALUE` state.
- void ClearVal() {
- if (state == State::HAS_VALUE) {
- val.Destroy();
- }
- state = State::NO_VALUE;
- }
+ // After item->kernel computation is done, processes its outputs.
+ Status ProcessOutputs(const NodeItem& item, OpKernelContext* ctx,
+ EntryVector* outputs, NodeExecStatsInterface* stats);
- union {
- // A tensor value. Valid iff `state_ == HAS_VALUE`.
- ManualConstructor<Tensor> val;
+ // Called after each node finishes. Takes ownership of "stats". Returns true
+ // if execution has completed.
+ //
+ // This method will clear `*ready` before returning.
+ bool NodeDone(const Status& s, TaggedNodeSeq* ready,
+ NodeExecStatsInterface* stats,
+ TaggedNodeReadyQueue* inline_ready);
- // A pointer to a constant tensor value. Valid iff `state_ ==
- // HAS_CONST_TENSOR`.
- const Tensor* const_tensor;
+ // Schedule all the expensive nodes in '*ready', and put all the inexpensive
+ // nodes in 'ready' into 'inline_ready'.
+ //
+ // This method will clear `*ready` before returning.
+ void ScheduleReady(TaggedNodeSeq* ready, TaggedNodeReadyQueue* inline_ready);
- // A tensor reference and associated mutex. Valid iff `state_ ==
- // HAS_REF_TENSOR`.
- struct {
- Tensor* tensor;
- mutex* mu;
- } ref_tensor;
- };
-
- // The current state of this entry, indicating which member of the above
- // union is active.
- State state;
-
- // The attributes of the allocator that creates the tensor.
- AllocatorAttributes alloc_attr;
- };
+ // Clean up when this executor is done.
+ void Finish();
+ void ScheduleFinish();
// Contains the device context assigned by the device at the beginning of a
// step.
DeviceContext* device_context_ = nullptr;
- struct TaggedNode;
- typedef gtl::InlinedVector<TaggedNode, 8> TaggedNodeSeq;
- typedef gtl::InlinedVector<Entry, 4> EntryVector;
-
- struct IterationState {
- explicit IterationState(const PendingCounts* pending_counts,
- int total_input_tensors)
- : input_tensors(new Entry[total_input_tensors]),
- outstanding_ops(0),
- outstanding_frame_count(0),
- counts(*pending_counts) { // Initialize with copy of *pending_counts
- }
-
- // The state of an iteration.
-
- // One copy per iteration. For iteration k, i-th node's j-th input is in
- // input_tensors[k][immutable_state_.nodes[i].input_start + j]. An entry is
- // either a tensor pointer (pass-by-reference) or a tensor (pass-by-value).
- //
- // NOTE: No need to protect input_tensors[i] by any locks because it
- // is resized once. Each element of tensors_ is written once by the
- // source node of an edge and is cleared by the destination of the same
- // edge. The latter node is never run concurrently with the former node.
- Entry* input_tensors;
-
- // The number of outstanding ops for each iteration.
- size_t outstanding_ops;
-
- // The number of outstanding frames for each iteration.
- int outstanding_frame_count;
- int pending(PendingCounts::Handle h) { return counts.pending(h); }
- int decrement_pending(PendingCounts::Handle h, int v) {
- return counts.decrement_pending(h, v);
- }
- // Mark a merge node as live
- // REQUIRES: Node corresponding to "h" is a merge node
- void mark_live(PendingCounts::Handle h) { counts.mark_live(h); }
- // Mark a node to show that processing has started.
- void mark_started(PendingCounts::Handle h) { counts.mark_started(h); }
- // Mark a node to show that processing has completed.
- void mark_completed(PendingCounts::Handle h) { counts.mark_completed(h); }
- PendingCounts::NodeState node_state(PendingCounts::Handle h) {
- return counts.node_state(h);
- }
-
- int dead_count(PendingCounts::Handle h) { return counts.dead_count(h); }
- void increment_dead_count(PendingCounts::Handle h) {
- counts.increment_dead_count(h);
- }
- PendingCounts::AdjustResult adjust_for_activation(PendingCounts::Handle h,
- bool increment_dead) {
- return counts.adjust_for_activation(h, increment_dead);
- }
-
- ~IterationState() { delete[] input_tensors; }
-
- private:
- PendingCounts counts;
- };
-
- struct FrameState {
- explicit FrameState(const ImmutableExecutorState& immutable_state,
- int parallel_iters)
- : immutable_state(immutable_state),
- max_parallel_iterations(parallel_iters),
- num_outstanding_iterations(1),
- iterations(parallel_iters + 1),
- iterations_raw(iterations.data()) {}
-
- // A new frame is created for each loop. Execution starts at iteration 0.
- // When a value at iteration 0 passes through a NextIteration node,
- // iteration 1 is created and starts running. Note that iteration 0 may
- // still be running so multiple iterations may run in parallel. The
- // frame maintains the state of iterations in several data structures
- // such as pending_count and input_tensors. When iteration 0 completes,
- // we garbage collect the state of iteration 0.
- //
- // A frame instance is considered "done" and can be garbage collected
- // if all its inputs have entered and all its iterations are "done".
- //
- // A frame manages the live iterations of an iterative computation.
- // Iteration i is considered "done" when there are no outstanding ops,
- // frames at iteration i are done, all recvs for this iteration are
- // completed, and iteration i-1 is done. For iteration 0, we instead
- // wait for there to be no more pending inputs of the frame.
- //
- // Frames and iterations are garbage collected once they are done.
- // The state we need to keep around is highly dependent on the
- // parallelism enabled by the scheduler. We may want to have the
- // scheduler dynamically control the outstanding number of live
- // parallel frames and iterations. To reduce the state space, the
- // scheduler might want to schedule ops in inner frames first and
- // lower iterations first.
- //
- // This frame state is mostly initialized lazily on demand so we
- // don't introduce unnecessary overhead.
-
- // The immutable state of the executor the frame is in.
- const ImmutableExecutorState& immutable_state;
-
- // The name of this frame, which is the concatenation of its parent
- // frame name, the iteration of the parent frame when this frame was
- // created, and the value of the attr 'frame_name'.
- string frame_name;
-
- // The unique id for this frame. Generated by fingerprinting
- // frame_name.
- uint64 frame_id;
-
- // The iteration id of its parent frame when this frame is created.
- // -1 if there is no parent frame. The frame_name/parent_iter pair
- // uniquely identifies this FrameState.
- int64 parent_iter = -1;
-
- // The FrameState of its parent frame.
- FrameState* parent_frame = nullptr;
-
- // The maximum allowed number of parallel iterations.
- const int max_parallel_iterations;
-
- // The number of inputs this frame is still waiting.
- int num_pending_inputs = 0;
-
- // The highest iteration number we have reached so far in this frame.
- int64 iteration_count TF_GUARDED_BY(mu) = 0;
-
- // The number of outstanding iterations.
- int num_outstanding_iterations TF_GUARDED_BY(mu) = 1;
-
- private:
- // The active iteration states of this frame.
- gtl::InlinedVector<IterationState*, 12> iterations;
- IterationState** const iterations_raw TF_GUARDED_BY(mu);
- IterationState* iterations_first TF_GUARDED_BY(mu);
-
- public:
- // The NextIteration nodes to enter a new iteration. If the number of
- // outstanding iterations reaches the limit, we will defer the start of
- // the next iteration until the number of outstanding iterations falls
- // below the limit.
- std::vector<std::pair<const NodeItem*, Entry>> next_iter_roots
- TF_GUARDED_BY(mu);
-
- // The values of the loop invariants for this loop. They are added into
- // this list as they "enter" the frame. When a loop invariant enters,
- // we make it available to all active iterations. When the frame starts
- // a new iteration, we make all the current loop invariants available
- // to the new iteration.
- std::vector<std::pair<const NodeItem*, Entry>> inv_values TF_GUARDED_BY(mu);
-
- // The list of dead exit node items for the current highest iteration. We
- // will only "execute" the dead exits of the final iteration.
- std::vector<const NodeItem*> dead_exits TF_GUARDED_BY(mu);
-
- // Static information specific to this frame.
- PendingCounts* pending_counts = nullptr;
- int total_input_tensors = 0;
- std::vector<const NodeItem*>* nodes = nullptr;
-
- // Lock ordering: ExecutorState.mu_ < mu;
- // during structured traversal: parent_frame->mu < mu.
- mutex mu;
-
- void InitializeFrameInfo(const string& enter_name) {
- const ImmutableExecutorState::FrameInfo* finfo =
- immutable_state.get_frame_info(enter_name);
- DCHECK_NE(finfo, nullptr);
- pending_counts = finfo->pending_counts.get();
- total_input_tensors = finfo->total_inputs;
- num_pending_inputs = finfo->input_count;
- nodes = finfo->nodes.get();
- }
-
- inline IterationState* GetIteration(int64 iter)
- TF_EXCLUSIVE_LOCKS_REQUIRED(mu) {
- if (TF_PREDICT_TRUE(iter == 0)) {
- return iterations_first;
- } else {
- size_t index = iter % (max_parallel_iterations + 1);
- return iterations_raw[index];
- }
- }
-
- inline void SetIteration(int64 iter, IterationState* state)
- TF_EXCLUSIVE_LOCKS_REQUIRED(mu) {
- size_t index = iter % (max_parallel_iterations + 1);
- DCHECK(state == nullptr || iterations[index] == nullptr);
- iterations_raw[index] = state;
- if (index == 0) {
- iterations_first = state;
- }
- }
-
- // Decrement the outstanding op count and clean up the iterations in the
- // frame. Return true iff the execution of the frame is done.
- inline bool DecrementOutstandingOps(const GraphView* gview, int64 iter,
- TaggedNodeSeq* ready) {
- mutex_lock l(mu);
- return DecrementOutstandingOpsLocked(gview, iter, ready);
- }
-
- // Decrement the outstanding op count and clean up the iterations in the
- // frame. Return true iff the execution of the frame is done.
- inline bool DecrementOutstandingOpsLocked(const GraphView* gview,
- int64 iter, TaggedNodeSeq* ready)
- TF_EXCLUSIVE_LOCKS_REQUIRED(mu) {
- IterationState* istate = GetIteration(iter);
- istate->outstanding_ops--;
- if (istate->outstanding_ops != 0) {
- return false;
- } else {
- return CleanupIterations(gview, iter, ready);
- }
- }
-
- // Returns true if the computation in the frame is completed.
- inline bool IsFrameDone() TF_EXCLUSIVE_LOCKS_REQUIRED(mu) {
- return (num_pending_inputs == 0 && num_outstanding_iterations == 0);
- }
-
- // Returns true if the iteration of the frame is completed.
- bool IsIterationDone(int64 iter) TF_EXCLUSIVE_LOCKS_REQUIRED(mu);
-
- // Increments the iteration id. If this is a new iteration, initialize it.
- void IncrementIteration(const GraphView* gview, TaggedNodeSeq* ready)
- TF_EXCLUSIVE_LOCKS_REQUIRED(mu);
-
- // Activate all the deferred NextIteration nodes in a new iteration.
- void ActivateNexts(const GraphView* gview, int64 iter, TaggedNodeSeq* ready)
- TF_EXCLUSIVE_LOCKS_REQUIRED(mu);
-
- // Activate all the current loop invariants in a new iteration.
- void ActivateLoopInvs(const GraphView* gview, int64 iter,
- TaggedNodeSeq* ready) TF_EXCLUSIVE_LOCKS_REQUIRED(mu);
-
- // Add a new loop invariant and make it available to all active
- // iterations.
- void AddLoopInv(const NodeItem* item, const Entry& entry,
- TaggedNodeSeq* ready) TF_EXCLUSIVE_LOCKS_REQUIRED(mu);
-
- // Activate the successors of a node. Contents of *outputs are left in an
- // indeterminate state after returning from this method.
- void ActivateNodes(const NodeItem* item, const bool is_dead, int64 iter,
- EntryVector* outputs, TaggedNodeSeq* ready)
- TF_EXCLUSIVE_LOCKS_REQUIRED(mu);
-
- // Cleanup iterations of this frame starting from iteration iter.
- bool CleanupIterations(const GraphView* gview, int64 iter,
- TaggedNodeSeq* ready)
- TF_EXCLUSIVE_LOCKS_REQUIRED(mu);
-
- void DumpIterationState(ExecutorState* parent) {
- mutex_lock l(mu);
- for (IterationState* iteration : iterations) {
- if (iteration) {
- LOG(WARNING) << " Iteration:";
- parent->DumpIterationState(this, iteration);
- }
- }
- }
-
- ~FrameState() {
- for (size_t i = 0; i < iterations.size(); ++i) {
- delete iterations[i];
- iterations[i] = nullptr;
- }
- }
-
- private:
- // REQUIRES: `!item->is_any_consumer_merge_or_control_trigger`.
- void ActivateNodesFastPath(const NodeItem* item, const bool is_dead,
- int64 iter, EntryVector* outputs,
- TaggedNodeSeq* ready)
- TF_EXCLUSIVE_LOCKS_REQUIRED(mu);
-
- void ActivateNodesSlowPath(const NodeItem* item, const bool is_dead,
- int64 iter, EntryVector* outputs,
- TaggedNodeSeq* ready)
- TF_EXCLUSIVE_LOCKS_REQUIRED(mu);
- };
-
- // A tagged node: <frame*, iter, node*>.
- struct TaggedNode {
- const NodeItem* node_item;
- FrameState* input_frame; // = nullptr;
- int64 input_iter; // = -1;
- bool is_dead; // = false;
-
- TaggedNode() {}
-
- TaggedNode(const NodeItem* node_item, FrameState* in_frame, int64 in_iter,
- bool dead)
- : node_item(node_item),
- input_frame(in_frame),
- input_iter(in_iter),
- is_dead(dead) {}
- };
-
- // A drop-in replacement for std::deque<TaggedNode>. We typically don't
- // have that many nodes in the ready queue, so we just use a vector and
- // don't free up memory from the queue as we consume nodes.
- class TaggedNodeReadyQueue {
- public:
- TaggedNodeReadyQueue() : front_index_(0) {}
-
- void push_back(const TaggedNode& node) { ready_.push_back(node); }
- TaggedNode front() const {
- DCHECK_LT(front_index_, ready_.size());
- return ready_[front_index_];
- }
- void pop_front() {
- DCHECK_LT(front_index_, ready_.size());
- front_index_++;
- if ((front_index_ == ready_.size()) || (front_index_ > 16384)) {
- if (front_index_ == ready_.size()) {
- ready_.clear();
- } else {
- // Lots of unused entries at beginning of vector: move everything
- // down to start of vector.
- ready_.erase(ready_.begin(), ready_.begin() + front_index_);
- }
- front_index_ = 0;
- }
- }
- bool empty() const { return ready_.empty(); }
- const TaggedNode* begin() const { return ready_.begin() + front_index_; }
- const TaggedNode* end() const { return ready_.end(); }
-
- private:
- gtl::InlinedVector<TaggedNode, 16> ready_;
- int front_index_;
- };
-
- struct AsyncState;
-
const bool vlog_; // true if VLOG_IS_ON(1). Used to check vlog cheaply.
// true if LogMemory::IsEnabled(). Used to check memory enabled cheaply.
@@ -702,14 +357,7 @@
bool sync_on_finish_;
const bool run_all_kernels_inline_;
- // Owned.
-
- // A flag that is set on error after the frame state has been
- // dumped for diagnostic purposes.
- bool dumped_on_error_ = false;
-
- // The root frame in which the execution of this step is started.
- FrameState* root_frame_;
+ PropagatorStateType propagator_;
// Invoked when the execution finishes.
Executor::DoneCallback done_cb_;
@@ -724,110 +372,12 @@
mutex mu_;
Status status_ TF_GUARDED_BY(mu_);
-
- // Mapping from frame name to outstanding frames. A new frame is created
- // at some iteration of an active frame. So the unique key for the new
- // child frame is composed of the name of the parent frame, the iteration
- // number at which the parent frame is creating the new frame, and the
- // name of the new frame from nodedef.
- gtl::FlatMap<string, FrameState*> outstanding_frames_ TF_GUARDED_BY(mu_);
-
- // The unique name of a frame.
- inline string MakeFrameName(FrameState* frame, int64 iter_id,
- const string& name) {
- return strings::StrCat(frame->frame_name, ";", iter_id, ";", name);
- }
-
- // Find an existing or create a new child frame in the frame 'frame' at
- // iteration 'iter'.
- void FindOrCreateChildFrame(FrameState* frame, int64 iter,
- const NodeItem& node_item, FrameState** child);
-
- // Delete a frame. Called when the frame is done.
- void DeleteFrame(FrameState* frame, TaggedNodeSeq* ready);
-
- // Cleanup frames and iterations starting from frame/iter. Called when
- // a child frame is done.
- void CleanupFramesIterations(FrameState* frame, int64 iter,
- TaggedNodeSeq* ready);
-
- // Process a ready node in current thread.
- void Process(TaggedNode node, int64 scheduled_nsec);
-
- Status ProcessSync(const NodeItem& item, OpKernelContext::Params* params,
- EntryVector* outputs,
- NodeExecStatsInterface* stats);
- void ProcessAsync(const NodeItem& item, const OpKernelContext::Params& params,
- const TaggedNode& tagged_node, Entry* first_input,
- NodeExecStatsInterface* stats);
- void ProcessNoop(NodeExecStatsInterface* stats);
- void ProcessConstTensor(const NodeItem& item, EntryVector* outputs,
- NodeExecStatsInterface* stats);
-
- // Before invoking item->kernel, fills in its "inputs".
- Status PrepareInputs(const NodeItem& item, Entry* first_input,
- TensorValueVec* inputs,
- AllocatorAttributeVec* input_alloc_attrs,
- bool* is_input_dead);
-
- // After item->kernel computation is done, processes its outputs.
- Status ProcessOutputs(const NodeItem& item, OpKernelContext* ctx,
- EntryVector* outputs, NodeExecStatsInterface* stats);
-
- // After processing the outputs, propagates the outputs to their dsts.
- // Contents of *outputs are left in an indeterminate state after
- // returning from this method.
- void PropagateOutputs(const TaggedNode& tagged_node, const NodeItem* item,
- EntryVector* outputs, TaggedNodeSeq* ready);
-
- // Called after each node finishes. Takes ownership of "stats". Returns true
- // if execution has completed.
- //
- // This method will clear `*ready` before returning.
- bool NodeDone(const Status& s, TaggedNodeSeq* ready,
- NodeExecStatsInterface* stats,
- TaggedNodeReadyQueue* inline_ready);
-
- // Schedule all the expensive nodes in '*ready', and put all the inexpensive
- // nodes in 'ready' into 'inline_ready'.
- //
- // This method will clear `*ready` before returning.
- void ScheduleReady(TaggedNodeSeq* ready, TaggedNodeReadyQueue* inline_ready);
-
- // For debugging/logging only.
- inline void MaybeMarkCompleted(FrameState* frame, int64 iter,
- const int node_id);
-
- // Provide debugging output about an outstanding node in the executor.
- void DumpPendingNodeState(const int node_id, const Entry* input_vector,
- bool show_nodes_with_no_ready_inputs);
- void DumpActiveNodeState(const int node_id, const Entry* input_vector);
-
- // Provide debugging output about an outstanding iteration in the executor.
- void DumpIterationState(const FrameState* frame, IterationState* iteration);
-
- // Provide debugging output of the state of the executor.
- void DumpState();
- const Tensor* GetTensorValueForDump(const Entry& input);
-
- // Clean up when this executor is done.
- void Finish();
- void ScheduleFinish();
-
- // A standalone routine for this expression so that we can express
- // that we don't want thread safety analysis on this reference (it's
- // safe to do without the lock because the iterations array never
- // resizes and this particular iteration's array element will not
- // be changed out from under us because the iteration is still alive).
- Entry* GetInputTensors(FrameState* input_frame,
- int64 input_iter) const TF_NO_THREAD_SAFETY_ANALYSIS {
- return input_frame->GetIteration(input_iter)->input_tensors;
- }
};
-ExecutorState::ExecutorState(const Executor::Args& args,
- const ImmutableExecutorState& immutable_state,
- ExecutorImpl::KernelStats* kernel_stats)
+template <class PropagatorStateType>
+ExecutorState<PropagatorStateType>::ExecutorState(
+ const Executor::Args& args, const ImmutableExecutorState& immutable_state,
+ ExecutorImpl::KernelStats* kernel_stats)
: vlog_(VLOG_IS_ON(1)),
log_memory_(LogMemory::IsEnabled()),
step_id_(args.step_id),
@@ -850,39 +400,25 @@
runner_(args.runner),
sync_on_finish_(args.sync_on_finish),
run_all_kernels_inline_(args.run_all_kernels_inline),
+ propagator_(immutable_state, step_id_),
num_outstanding_ops_(0) {
if (args.user_intra_op_threadpool != nullptr) {
Device* device = immutable_state_.params().device;
user_device_ = RenamedDevice::NewRenamedDevice(
device->name(), device, false, false, args.user_intra_op_threadpool);
}
-
- // We start the entire execution in iteration 0 of the root frame
- // so let us create the root frame and the state for iteration 0.
- // We assume root_frame_->frame_name.empty().
- root_frame_ = new FrameState(immutable_state_, 1);
- root_frame_->frame_id = 0; // must be 0
- root_frame_->InitializeFrameInfo(root_frame_->frame_name);
-
- // Initialize iteration 0.
- root_frame_->SetIteration(
- 0, new IterationState(root_frame_->pending_counts,
- root_frame_->total_input_tensors));
-
- outstanding_frames_.insert({root_frame_->frame_name, root_frame_});
}
-ExecutorState::~ExecutorState() {
- for (auto name_frame : outstanding_frames_) {
- delete name_frame.second;
- }
+template <class PropagatorStateType>
+ExecutorState<PropagatorStateType>::~ExecutorState() {
if (device_context_) {
device_context_->Unref();
}
delete slice_reader_cache_;
}
-void ExecutorState::RunAsync(Executor::DoneCallback done) {
+template <class PropagatorStateType>
+void ExecutorState<PropagatorStateType>::RunAsync(Executor::DoneCallback done) {
TaggedNodeSeq ready;
// Ask the device to fill in the device context map.
@@ -897,19 +433,12 @@
// Initialize the ready queue.
ready.reserve(immutable_state_.root_nodes().size());
- for (const NodeItem* item : immutable_state_.root_nodes()) {
- DCHECK_EQ(item->num_inputs, 0);
- ready.push_back(TaggedNode{item, root_frame_, 0, false});
- }
+ propagator_.ActivateRoots(immutable_state_.root_nodes(), &ready);
+ num_outstanding_ops_ = ready.size();
if (ready.empty()) {
delete this;
done(Status::OK());
} else {
- num_outstanding_ops_ = ready.size();
- {
- mutex_lock l(root_frame_->mu);
- root_frame_->GetIteration(0)->outstanding_ops = ready.size();
- }
done_cb_ = std::move(done);
// Schedule to run all the ready ops in thread pool.
ScheduleReady(&ready, nullptr);
@@ -921,7 +450,8 @@
// asynchronous kernels because OpKernelContext methods like input_type(i) needs
// the param points to valid input type vector. It's not an issue for
// sync kernels because these vectors are kept on the stack.
-struct ExecutorState::AsyncState {
+template <class PropagatorStateType>
+struct ExecutorState<PropagatorStateType>::AsyncState {
AsyncState(const OpKernelContext::Params& p, const TaggedNode& _tagged_node,
const NodeItem* _item, Entry* _first_input,
NodeExecStatsInterface* _stats)
@@ -978,10 +508,10 @@
return profiler::TraceMe::Active(profiler::GetTFTraceMeLevel(is_expensive));
}
-Status ExecutorState::ProcessSync(const NodeItem& item,
- OpKernelContext::Params* params,
- EntryVector* outputs,
- NodeExecStatsInterface* stats) {
+template <class PropagatorStateType>
+Status ExecutorState<PropagatorStateType>::ProcessSync(
+ const NodeItem& item, OpKernelContext::Params* params, EntryVector* outputs,
+ NodeExecStatsInterface* stats) {
Status s;
OpKernelContext ctx(params, item.num_outputs);
nodestats::SetOpStart(stats);
@@ -1018,11 +548,11 @@
return s;
}
-void ExecutorState::ProcessAsync(const NodeItem& item,
- const OpKernelContext::Params& params,
- const TaggedNode& tagged_node,
- Entry* first_input,
- NodeExecStatsInterface* stats) {
+template <class PropagatorStateType>
+void ExecutorState<PropagatorStateType>::ProcessAsync(
+ const NodeItem& item, const OpKernelContext::Params& params,
+ const TaggedNode& tagged_node, Entry* first_input,
+ NodeExecStatsInterface* stats) {
AsyncOpKernel* async_kernel = item.kernel->AsAsync();
DCHECK(async_kernel != nullptr);
AsyncState* state =
@@ -1040,7 +570,7 @@
if (vlog_) {
VLOG(2) << "Async kernel done: " << state->item->node_id << " step "
<< step_id_ << " " << SummarizeNodeDef(state->item->kernel->def())
- << (state->tagged_node.is_dead ? " is dead" : "")
+ << (state->tagged_node.get_is_dead() ? " is dead" : "")
<< " device: " << device->name();
}
@@ -1049,12 +579,10 @@
for (int i = 0; i < num_inputs; ++i) {
(first_input + i)->ClearVal();
}
- FrameState* input_frame = state->tagged_node.input_frame;
- const int64 input_iter = state->tagged_node.input_iter;
- MaybeMarkCompleted(input_frame, input_iter, state->item->node_id);
+ propagator_.MaybeMarkCompleted(state->tagged_node);
TaggedNodeSeq ready;
if (s.ok()) {
- PropagateOutputs(state->tagged_node, state->item, &outputs, &ready);
+ propagator_.PropagateOutputs(state->tagged_node, &outputs, &ready);
}
outputs.clear();
const bool completed = NodeDone(s, &ready, stats, nullptr);
@@ -1074,14 +602,16 @@
}
}
-void ExecutorState::ProcessNoop(NodeExecStatsInterface* stats) {
+template <class PropagatorStateType>
+void ExecutorState<PropagatorStateType>::ProcessNoop(
+ NodeExecStatsInterface* stats) {
nodestats::SetOpStart(stats);
nodestats::SetOpEnd(stats);
}
-void ExecutorState::ProcessConstTensor(const NodeItem& item,
- EntryVector* outputs,
- NodeExecStatsInterface* stats) {
+template <class PropagatorStateType>
+void ExecutorState<PropagatorStateType>::ProcessConstTensor(
+ const NodeItem& item, EntryVector* outputs, NodeExecStatsInterface* stats) {
nodestats::SetOpStart(stats);
nodestats::SetOpEnd(stats);
outputs->resize(1);
@@ -1091,12 +621,11 @@
output.alloc_attr = item.output_attrs()[0];
}
-void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) {
+template <class PropagatorStateType>
+void ExecutorState<PropagatorStateType>::Process(TaggedNode tagged_node,
+ int64 scheduled_nsec) {
profiler::TraceMe activity(
- [&] {
- return absl::StrCat("ExecutorState::Process#id=", step_id_,
- ",iter_num=", tagged_node.input_iter, "#");
- },
+ [&] { return absl::StrCat("ExecutorState::Process#id=", step_id_, "#"); },
2);
WithContext wc(context_);
TaggedNodeSeq ready;
@@ -1164,22 +693,14 @@
while (!inline_ready.empty()) {
tagged_node = inline_ready.front();
inline_ready.pop_front();
- const NodeItem& item = *tagged_node.node_item;
- FrameState* input_frame = tagged_node.input_frame;
- const int64 input_iter = tagged_node.input_iter;
+ const NodeItem& item = tagged_node.get_node_item();
const int id = item.node_id;
- // TODO(misard) Replace with a finer-grain enabling flag once we
- // add better optional debugging support.
- if (vlog_ && VLOG_IS_ON(1)) {
- mutex_lock l(input_frame->mu);
- input_frame->GetIteration(input_iter)
- ->mark_started(immutable_state_.pending_ids()[id]);
- }
+ propagator_.MaybeMarkStarted(tagged_node);
params.track_allocations = false;
stats = nullptr;
- if (stats_collector_ && !tagged_node.is_dead) {
+ if (stats_collector_ && !tagged_node.get_is_dead()) {
stats = stats_collector_->CreateNodeExecStats(&item.kernel->def());
// Track allocations if and only if we are collecting statistics, and
// `stats` object is expecting allocations to be tracked.
@@ -1191,19 +712,18 @@
if (vlog_) {
VLOG(1) << "Process node: " << id << " step " << params.step_id << " "
<< SummarizeNodeDef(item.kernel->def())
- << (tagged_node.is_dead ? " is dead" : "")
+ << (tagged_node.get_is_dead() ? " is dead" : "")
<< " device: " << device->name();
}
- Entry* input_tensors = GetInputTensors(input_frame, input_iter);
- Entry* first_input = input_tensors + item.input_start;
+ Entry* first_input = propagator_.GetInputTensors(tagged_node);
outputs.clear();
// Only execute this node if it is not dead or it is a send/recv
// transfer node. For transfer nodes, we need to propagate the "dead"
// bit even when the node is dead.
bool launched_asynchronously = false;
- if (tagged_node.is_dead && !item.is_transfer_node) {
+ if (tagged_node.get_is_dead() && !item.is_transfer_node) {
outputs.resize(item.num_outputs);
} else if (TF_PREDICT_FALSE(item.is_noop)) {
ProcessNoop(stats);
@@ -1220,7 +740,7 @@
for (int i = 0; i < num_inputs; ++i) {
(first_input + i)->ClearVal();
}
- MaybeMarkCompleted(input_frame, input_iter, id);
+ propagator_.MaybeMarkCompleted(tagged_node);
// Continue to process the nodes in 'inline_ready'.
completed = NodeDone(s, &ready, stats, &inline_ready);
continue;
@@ -1228,7 +748,7 @@
// Set up compute params.
params.op_kernel = item.kernel;
- params.frame_iter = FrameAndIter(input_frame->frame_id, input_iter);
+ params.frame_iter = propagator_.GetFrameAndIter(tagged_node);
params.is_input_dead = is_input_dead;
params.output_attr_array = item.output_attrs();
params.forward_from_array = item.forward_from();
@@ -1246,7 +766,7 @@
if (vlog_) {
VLOG(2) << "Synchronous kernel done: " << id << " step "
<< params.step_id << " " << SummarizeNodeDef(item.kernel->def())
- << (tagged_node.is_dead ? " is dead: " : "")
+ << (tagged_node.get_is_dead() ? " is dead: " : "")
<< " device: " << device->name();
}
@@ -1255,10 +775,10 @@
for (int i = 0; i < num_inputs; ++i) {
(first_input + i)->ClearVal();
}
- MaybeMarkCompleted(input_frame, input_iter, id);
+ propagator_.MaybeMarkCompleted(tagged_node);
// Propagates outputs.
if (s.ok()) {
- PropagateOutputs(tagged_node, &item, &outputs, &ready);
+ propagator_.PropagateOutputs(tagged_node, &outputs, &ready);
}
outputs.clear();
if (stats) {
@@ -1273,10 +793,10 @@
if (completed) ScheduleFinish();
}
-Status ExecutorState::PrepareInputs(const NodeItem& item, Entry* first_input,
- TensorValueVec* inputs,
- AllocatorAttributeVec* input_alloc_attrs,
- bool* is_input_dead) {
+template <class PropagatorStateType>
+Status ExecutorState<PropagatorStateType>::PrepareInputs(
+ const NodeItem& item, Entry* first_input, TensorValueVec* inputs,
+ AllocatorAttributeVec* input_alloc_attrs, bool* is_input_dead) {
inputs->clear();
inputs->resize(item.num_inputs);
input_alloc_attrs->clear();
@@ -1384,9 +904,10 @@
return Status::OK();
}
-Status ExecutorState::ProcessOutputs(const NodeItem& item, OpKernelContext* ctx,
- EntryVector* outputs,
- NodeExecStatsInterface* stats) {
+template <class PropagatorStateType>
+Status ExecutorState<PropagatorStateType>::ProcessOutputs(
+ const NodeItem& item, OpKernelContext* ctx, EntryVector* outputs,
+ NodeExecStatsInterface* stats) {
DCHECK_EQ(0, outputs->size());
outputs->resize(item.num_outputs);
@@ -1397,7 +918,7 @@
// add better optional debugging support.
if (vlog_ && VLOG_IS_ON(1)) {
LOG(WARNING) << this << " Compute status: " << s;
- DumpState();
+ propagator_.DumpState();
}
if (s.code() == error::RESOURCE_EXHAUSTED) {
if (stats_collector_) {
@@ -1481,120 +1002,10 @@
return s;
}
-void ExecutorState::PropagateOutputs(const TaggedNode& tagged_node,
- const NodeItem* item, EntryVector* outputs,
- TaggedNodeSeq* ready) {
- profiler::TraceMe activity(
- [&]() {
- return strings::StrCat(
- "ExecutorPropagateOutputs#", "id=", step_id_,
- ",kernel_name=", item->kernel->name_view(),
- ",num_output_edges=", item->num_output_edges,
- ",num_output_control_edges=", item->num_output_control_edges, "#");
- },
- profiler::GetTFTraceMeLevel(/*is_expensive=*/false));
-
- FrameState* input_frame = tagged_node.input_frame;
- const int64 input_iter = tagged_node.input_iter;
- const bool is_dead = tagged_node.is_dead;
-
- // Propagates outputs along out edges, and puts newly ready nodes
- // into the ready queue.
- DCHECK(ready->empty());
- bool is_frame_done = false;
- FrameState* output_frame = input_frame;
- int64 output_iter = input_iter;
-
- if (!item->is_enter_exit_or_next_iter) {
- // Fast path for nodes types that don't need special handling
- DCHECK_EQ(input_frame, output_frame);
- // Normal path for most nodes
- mutex_lock l(input_frame->mu);
- output_frame->ActivateNodes(item, is_dead, output_iter, outputs, ready);
- is_frame_done = input_frame->DecrementOutstandingOpsLocked(
- &immutable_state_.graph_view(), input_iter, ready);
- } else if (item->is_enter) {
- FindOrCreateChildFrame(input_frame, input_iter, *item, &output_frame);
- output_iter = 0;
- {
- mutex_lock l(output_frame->mu);
- if (item->is_constant_enter) {
- // Propagate to all active iterations if this is a loop invariant.
- output_frame->AddLoopInv(item, (*outputs)[0], ready);
- } else {
- output_frame->ActivateNodes(item, is_dead, output_iter, outputs, ready);
- }
- output_frame->num_pending_inputs--;
- }
- is_frame_done = input_frame->DecrementOutstandingOps(
- &immutable_state_.graph_view(), input_iter, ready);
- } else if (item->is_exit) {
- if (is_dead) {
- mutex_lock l(input_frame->mu);
- // Stop and remember this node if it is a dead exit.
- if (input_iter == input_frame->iteration_count) {
- input_frame->dead_exits.push_back(item);
- }
- is_frame_done = input_frame->DecrementOutstandingOpsLocked(
- &immutable_state_.graph_view(), input_iter, ready);
- } else {
- output_frame = input_frame->parent_frame;
- output_iter = input_frame->parent_iter;
- {
- mutex_lock l(output_frame->mu);
- output_frame->ActivateNodes(item, is_dead, output_iter, outputs, ready);
- }
- is_frame_done = input_frame->DecrementOutstandingOps(
- &immutable_state_.graph_view(), input_iter, ready);
- }
- } else {
- DCHECK(item->is_next_iteration);
- mutex_lock l(input_frame->mu);
- if (is_dead) {
- // Stop the deadness propagation.
- output_frame = nullptr;
- } else {
- if (input_iter == input_frame->iteration_count &&
- input_frame->num_outstanding_iterations ==
- input_frame->max_parallel_iterations) {
- // Reached the maximum for parallel iterations.
- input_frame->next_iter_roots.push_back({item, (*outputs)[0]});
- output_frame = nullptr;
- } else {
- // If this is a new iteration, start it.
- if (input_iter == input_frame->iteration_count) {
- input_frame->IncrementIteration(&immutable_state_.graph_view(),
- ready);
- }
- output_iter = input_iter + 1;
- }
- }
- if (output_frame != nullptr) {
- // This is the case when node is not Enter, Exit, or NextIteration.
- DCHECK(input_frame == output_frame);
- output_frame->ActivateNodes(item, is_dead, output_iter, outputs, ready);
- }
- is_frame_done = input_frame->DecrementOutstandingOpsLocked(
- &immutable_state_.graph_view(), input_iter, ready);
- }
-
- // At this point, this node is completely done. We also know if the
- // completion of this node makes its frame completed.
- if (is_frame_done) {
- FrameState* parent_frame = input_frame->parent_frame;
- const int64 parent_iter = input_frame->parent_iter;
- DeleteFrame(input_frame, ready);
- if (parent_frame != nullptr) {
- // The completion of frame may cause completions in its parent frame.
- // So clean things up recursively.
- CleanupFramesIterations(parent_frame, parent_iter, ready);
- }
- }
-}
-
-bool ExecutorState::NodeDone(const Status& s, TaggedNodeSeq* ready,
- NodeExecStatsInterface* stats,
- TaggedNodeReadyQueue* inline_ready) {
+template <class PropagatorStateType>
+bool ExecutorState<PropagatorStateType>::NodeDone(
+ const Status& s, TaggedNodeSeq* ready, NodeExecStatsInterface* stats,
+ TaggedNodeReadyQueue* inline_ready) {
nodestats::SetAllEnd(stats);
if (stats) {
if (stats_collector_) {
@@ -1658,8 +1069,9 @@
return completed;
}
-void ExecutorState::ScheduleReady(TaggedNodeSeq* ready,
- TaggedNodeReadyQueue* inline_ready) {
+template <class PropagatorStateType>
+void ExecutorState<PropagatorStateType>::ScheduleReady(
+ TaggedNodeSeq* ready, TaggedNodeReadyQueue* inline_ready) {
if (ready->empty()) return;
int64 scheduled_nsec = 0;
@@ -1693,7 +1105,7 @@
} else {
for (auto& tagged_node : *ready) {
const NodeItem& item = *tagged_node.node_item;
- if (tagged_node.is_dead || !kernel_stats_->IsExpensive(item)) {
+ if (tagged_node.get_is_dead() || !kernel_stats_->IsExpensive(item)) {
// Inline this inexpensive node.
inline_ready->push_back(tagged_node);
} else {
@@ -1721,135 +1133,8 @@
ready->clear();
}
-inline void ExecutorState::MaybeMarkCompleted(FrameState* frame, int64 iter,
- const int node_id) {
- // TODO(misard) Replace with a finer-grain enabling flag once we
- // add better optional debugging support.
- if (vlog_ && VLOG_IS_ON(1)) {
- mutex_lock l(frame->mu);
- frame->GetIteration(iter)->mark_completed(
- immutable_state_.pending_ids()[node_id]);
- }
-}
-
-const Tensor* ExecutorState::GetTensorValueForDump(const Entry& input) {
- switch (input.state) {
- case Entry::State::NO_VALUE:
- return kEmptyTensor;
- case Entry::State::HAS_VALUE:
- return input.val.get();
- case Entry::State::HAS_CONST_TENSOR:
- return input.const_tensor;
- case Entry::State::HAS_REF_TENSOR:
- return input.ref_tensor.tensor;
- }
-}
-
-void ExecutorState::DumpPendingNodeState(
- const int node_id, const Entry* input_vector,
- const bool show_nodes_with_no_ready_inputs) {
- const NodeItem& node_item = *immutable_state_.graph_view().node(node_id);
- const int input_base = node_item.input_start;
- if (!show_nodes_with_no_ready_inputs) {
- bool has_ready_input = false;
- for (int i = 0; i < node_item.num_inputs; ++i) {
- const Entry& input = input_vector[input_base + i];
- const Tensor* tensor = GetTensorValueForDump(input);
- if (tensor->IsInitialized()) {
- has_ready_input = true;
- break;
- }
- }
- if (!has_ready_input) {
- return;
- }
- }
- LOG(WARNING) << " Pending Node: " << node_item.DebugString();
- for (int i = 0; i < node_item.num_inputs; ++i) {
- const Entry& input = input_vector[input_base + i];
- const Tensor* tensor = GetTensorValueForDump(input);
- if (tensor->IsInitialized()) {
- LOG(WARNING) << " Input " << i << ": "
- << strings::StrCat(
- "Tensor<type: ", DataTypeString(tensor->dtype()),
- " shape: ", tensor->shape().DebugString(), ">");
- } else {
- LOG(WARNING) << " Input " << i << ": not present";
- }
- }
-}
-
-void ExecutorState::DumpActiveNodeState(const int node_id,
- const Entry* input_vector) {
- const NodeItem& node_item = *immutable_state_.graph_view().node(node_id);
- LOG(WARNING) << " Active Node: " << node_item.DebugString();
- const int input_base = node_item.input_start;
- for (int i = 0; i < node_item.num_inputs; ++i) {
- const Entry& input = input_vector[input_base + i];
- const Tensor* tensor = GetTensorValueForDump(input);
- if (tensor->IsInitialized()) {
- LOG(WARNING) << " Input " << i << ": "
- << strings::StrCat(
- "Tensor<type: ", DataTypeString(tensor->dtype()),
- " shape: ", tensor->shape().DebugString(), ">");
- } else {
- LOG(WARNING) << " Input " << i << ": not present";
- }
- }
-}
-
-void ExecutorState::DumpIterationState(const FrameState* frame,
- IterationState* iteration) {
- const std::vector<const NodeItem*>* nodes = frame->nodes;
- // Dump any waiting nodes that are holding on to tensors.
- for (const NodeItem* node : *nodes) {
- PendingCounts::Handle pending_id =
- immutable_state_.pending_ids()[node->node_id];
- if (iteration->node_state(pending_id) == PendingCounts::PENDING_NOTREADY ||
- iteration->node_state(pending_id) == PendingCounts::PENDING_READY) {
- DumpPendingNodeState(node->node_id, iteration->input_tensors, false);
- }
- }
- // Then the active nodes.
- for (const NodeItem* node : *nodes) {
- PendingCounts::Handle pending_id =
- immutable_state_.pending_ids()[node->node_id];
- if (iteration->node_state(pending_id) == PendingCounts::STARTED) {
- DumpActiveNodeState(node->node_id, iteration->input_tensors);
- }
- }
- // Show all input tensors in use.
- const int total_input_tensors = frame->total_input_tensors;
- size_t total_bytes = 0;
- for (int i = 0; i < total_input_tensors; ++i) {
- const Entry& input = iteration->input_tensors[i];
- const Tensor* tensor = GetTensorValueForDump(input);
- if (tensor->IsInitialized()) {
- LOG(WARNING) << " Input " << i << ": "
- << strings::StrCat(
- "Tensor<type: ", DataTypeString(tensor->dtype()),
- " shape: ", tensor->shape().DebugString(),
- ", bytes: ", tensor->TotalBytes(), ">");
- total_bytes += tensor->TotalBytes();
- }
- }
- LOG(WARNING) << " Total bytes " << total_bytes;
-}
-
-void ExecutorState::DumpState() {
- mutex_lock l(mu_);
- if (!dumped_on_error_) {
- LOG(WARNING) << "Dumping state";
- for (auto& frame : outstanding_frames_) {
- LOG(WARNING) << frame.first;
- FrameState* frame_state = frame.second;
- frame_state->DumpIterationState(this);
- }
- dumped_on_error_ = true;
- }
-}
-
-void ExecutorState::ScheduleFinish() {
+template <class PropagatorStateType>
+void ExecutorState<PropagatorStateType>::ScheduleFinish() {
// Checks condition to decide if needs to invoke Finish(). If there are
// in-flight deffered ops, wait for `num_deferred_ops_` reaches 0 to invoke
// Finish(). Otherwise, invoke Finish() directly.
@@ -1868,7 +1153,8 @@
Finish();
}
-void ExecutorState::Finish() {
+template <class PropagatorStateType>
+void ExecutorState<PropagatorStateType>::Finish() {
mu_.lock();
auto status = status_;
auto done_cb = std::move(done_cb_);
@@ -1962,447 +1248,8 @@
}
}
-void ExecutorState::FindOrCreateChildFrame(FrameState* frame, int64 iter,
- const NodeItem& node_item,
- FrameState** child) {
- // Get the child frame name.
- AttrSlice attrs(node_item.kernel->def());
- const string& enter_name = GetNodeAttrString(attrs, "frame_name");
- DCHECK(!enter_name.empty()) << "Could not find \"frame_name\" attr in node "
- << node_item.kernel->name();
- const string child_name = MakeFrameName(frame, iter, enter_name);
-
- {
- mutex_lock executor_lock(mu_);
- auto it = outstanding_frames_.find(child_name);
- if (it != outstanding_frames_.end()) {
- *child = it->second;
- return;
- }
- }
-
- // Need to create a new frame instance.
- // Note that this new frame instance is created without any locks.
- if (vlog_) VLOG(2) << "Create frame: " << child_name;
-
- int parallel_iters;
- bool found_parallel_iters =
- TryGetNodeAttr(attrs, "parallel_iterations", ¶llel_iters);
- DCHECK(found_parallel_iters)
- << "Could not find \"parallel_iterations\" attr in node "
- << node_item.kernel->name();
- FrameState* temp = new FrameState(immutable_state_, parallel_iters);
- temp->frame_name = child_name;
- temp->frame_id = Hash64(child_name);
- temp->parent_frame = frame;
- temp->parent_iter = iter;
- temp->InitializeFrameInfo(enter_name);
-
- // Initialize iteration 0.
- {
- mutex_lock l(temp->mu);
- temp->SetIteration(
- 0, new IterationState(temp->pending_counts, temp->total_input_tensors));
- }
-
- {
- mutex_lock executor_lock(mu_);
- auto it = outstanding_frames_.find(child_name);
- if (it != outstanding_frames_.end()) {
- *child = it->second;
- } else {
- mutex_lock frame_lock(frame->mu);
- frame->GetIteration(iter)->outstanding_frame_count++;
- outstanding_frames_[child_name] = temp;
- *child = temp;
- temp = nullptr;
- }
- }
- delete temp; // Not used so delete it.
-}
-
-void ExecutorState::DeleteFrame(FrameState* frame, TaggedNodeSeq* ready) {
- // First, propagate dead_exits (if any) to the parent frame.
- FrameState* parent_frame = frame->parent_frame;
- const int64 parent_iter = frame->parent_iter;
- if (parent_frame != nullptr) {
- mutex_lock parent_frame_lock(parent_frame->mu);
- // Propagate all the dead exits to the parent frame.
- mutex_lock this_frame_lock(frame->mu);
-
- for (const NodeItem* item : frame->dead_exits) {
- auto parent_iter_state = parent_frame->GetIteration(parent_iter);
-
- auto maybe_add_to_ready = [&](const NodeItem& dst_item, bool dst_ready,
- bool dst_dead) {
- if (dst_ready) {
- if (dst_item.is_control_trigger) dst_dead = false;
- ready->emplace_back(&dst_item, parent_frame, parent_iter, dst_dead);
- parent_iter_state->outstanding_ops++;
- }
- };
-
- auto propagate_to_non_merge = [&](PendingCounts::Handle dst_pending_id) {
- parent_iter_state->increment_dead_count(dst_pending_id);
- return parent_iter_state->decrement_pending(dst_pending_id, 1) == 0;
- };
-
- for (const EdgeInfo& e : item->output_edges()) {
- const NodeItem& dst_item =
- *immutable_state_.graph_view().node(e.dst_id);
- const auto dst_pending_id = immutable_state_.pending_ids()[e.dst_id];
-
- bool dst_dead = true;
- bool dst_ready;
- // We know this is a dead input to dst.
- if (dst_item.is_merge) {
- parent_iter_state->increment_dead_count(dst_pending_id);
- const int dead_cnt = parent_iter_state->dead_count(dst_pending_id);
- dst_dead = (dead_cnt == dst_item.num_inputs);
- dst_ready =
- (parent_iter_state->pending(dst_pending_id) == 1) && dst_dead;
- } else {
- dst_ready = propagate_to_non_merge(dst_pending_id);
- }
- maybe_add_to_ready(dst_item, dst_ready, dst_dead);
- }
-
- for (const ControlEdgeInfo& e : item->output_control_edges()) {
- const NodeItem& dst_item =
- *immutable_state_.graph_view().node(e.dst_id);
- const auto dst_pending_id = immutable_state_.pending_ids()[e.dst_id];
-
- bool dst_dead;
- bool dst_ready;
- // We know this is a dead input to dst.
- if (dst_item.is_merge) {
- parent_iter_state->decrement_pending(dst_pending_id, 2);
- int count = parent_iter_state->pending(dst_pending_id);
- int dead_cnt = parent_iter_state->dead_count(dst_pending_id);
- dst_dead = (dead_cnt == dst_item.num_inputs);
- dst_ready = (count == 0) || ((count == 1) && dst_dead);
- } else {
- dst_dead = true;
- dst_ready = propagate_to_non_merge(dst_pending_id);
- }
- maybe_add_to_ready(dst_item, dst_ready, dst_dead);
- }
- }
- }
-
- // Delete the frame.
- const string& frame_name = frame->frame_name;
- if (vlog_) VLOG(2) << "Delete frame " << frame_name;
- {
- mutex_lock executor_lock(mu_);
- outstanding_frames_.erase(frame_name);
- }
- delete frame;
-}
-
-void ExecutorState::CleanupFramesIterations(FrameState* frame, int64 iter,
- TaggedNodeSeq* ready) {
- bool is_frame_done = false;
- {
- mutex_lock frame_lock(frame->mu);
- frame->GetIteration(iter)->outstanding_frame_count--;
- is_frame_done =
- frame->CleanupIterations(&immutable_state_.graph_view(), iter, ready);
- }
- if (is_frame_done) {
- FrameState* parent_frame = frame->parent_frame;
- const int64 parent_iter = frame->parent_iter;
- DeleteFrame(frame, ready);
- if (parent_frame != nullptr) {
- // The completion of frame may cause completions in its parent frame.
- // So clean things up recursively.
- CleanupFramesIterations(parent_frame, parent_iter, ready);
- }
- }
-}
-
-void ExecutorState::FrameState::ActivateNodesFastPath(const NodeItem* item,
- const bool is_dead,
- int64 iter,
- EntryVector* outputs,
- TaggedNodeSeq* ready) {
- // If we know that none of the item's edge destinations require special
- // handling (i.e. none of the nodes is a merge or control trigger node), we
- // can take a fast path that avoids accessing the destination NodeItem.
- const GraphView& gview = immutable_state.graph_view();
- IterationState* iter_state = GetIteration(iter);
-
-// Add dst to the ready queue if it's ready
-//
-// NOTE(mrry): Use a macro here instead of a lambda, because this method is
-// performance-critical and we need to ensure that the code is inlined.
-#define MAYBE_ADD_TO_READY(dst_id, adjust_result) \
- do { \
- if (!adjust_result.any_pending) { \
- const NodeItem* dst_item = gview.node(dst_id); \
- TaggedNode& t = ready->emplace_back(); \
- t.node_item = dst_item; \
- t.input_frame = this; \
- t.input_iter = iter; \
- t.is_dead = adjust_result.any_dead; \
- iter_state->outstanding_ops++; \
- } \
- } while (0);
-
- Entry* input_tensors = iter_state->input_tensors;
-
- for (const EdgeInfo& e : item->output_edges()) {
- const int dst_id = e.dst_id;
- const PendingCounts::Handle dst_pending_id =
- immutable_state.pending_ids()[dst_id];
- const int src_slot = e.output_slot;
-
- const bool increment_dead =
- (is_dead || ((*outputs)[src_slot].state == Entry::State::NO_VALUE));
- const PendingCounts::AdjustResult adjust_result =
- iter_state->adjust_for_activation(dst_pending_id, increment_dead);
- const int dst_loc = e.input_slot;
- if (e.is_last) {
- input_tensors[dst_loc] = std::move((*outputs)[src_slot]);
- } else {
- input_tensors[dst_loc] = (*outputs)[src_slot];
- }
- MAYBE_ADD_TO_READY(dst_id, adjust_result);
- }
-
- for (const ControlEdgeInfo& e : item->output_control_edges()) {
- const int dst_id = e.dst_id;
- const PendingCounts::Handle dst_pending_id =
- immutable_state.pending_ids()[dst_id];
- const PendingCounts::AdjustResult adjust_result =
- iter_state->adjust_for_activation(dst_pending_id, is_dead);
- MAYBE_ADD_TO_READY(dst_id, adjust_result);
- }
-#undef MAYBE_ADD_TO_READY
-}
-
-void ExecutorState::FrameState::ActivateNodesSlowPath(const NodeItem* item,
- const bool is_dead,
- int64 iter,
- EntryVector* outputs,
- TaggedNodeSeq* ready) {
- // If any of the edge destinations is a merge or a control trigger node,
- // we need to read each destination NodeItem to determine what action
- // to take.
- const GraphView& gview = immutable_state.graph_view();
- IterationState* iter_state = GetIteration(iter);
-
- auto maybe_add_to_ready = [&](int dst_id, const NodeItem* dst_item,
- bool dst_ready, bool dst_dead) {
- // Add dst to the ready queue if it's ready
- if (dst_ready) {
- if (dst_item->is_control_trigger) dst_dead = false;
- ready->emplace_back(dst_item, this, iter, dst_dead);
- iter_state->outstanding_ops++;
- }
- };
-
- Entry* input_tensors = iter_state->input_tensors;
-
- for (const EdgeInfo& e : item->output_edges()) {
- const int dst_id = e.dst_id;
- const NodeItem* dst_item = gview.node(dst_id);
- const PendingCounts::Handle dst_pending_id =
- immutable_state.pending_ids()[dst_id];
- const int src_slot = e.output_slot;
-
- bool dst_dead = false;
- bool dst_ready = false;
- bool dst_need_input = true;
-
- if (dst_item->is_merge) {
- // A merge node is ready if all control inputs have arrived and either
- // a) a live data input becomes available or b) all data inputs are
- // dead. For Merge, pending's LSB is set iff a live data input has
- // arrived.
- if ((*outputs)[src_slot].state != Entry::State::NO_VALUE) {
- // This is a live data input.
- int count = iter_state->pending(dst_pending_id);
- iter_state->mark_live(dst_pending_id);
- // Only the first live edge sets the input and (potentially)
- // triggers execution. The low bit of count is set if and
- // only if no live input has been used yet (mark_live clears
- // it). The node should be started if and only if this is
- // the first live input and there are no pending control
- // edges, i.e. count == 1.
- dst_ready = (count == 1);
- dst_need_input = ((count & 0x1) == 1);
- } else {
- // This is a dead data input. Note that dst_node is dead if node is
- // a dead enter. We need this to handle properly a while loop on
- // the untaken branch of a conditional.
- // TODO(yuanbyu): This is a bit hacky, but a good solution for
- // now.
- iter_state->increment_dead_count(dst_pending_id);
- const int dead_cnt = iter_state->dead_count(dst_pending_id);
- dst_dead = (dead_cnt == dst_item->num_inputs) || item->is_enter;
- dst_ready = (iter_state->pending(dst_pending_id) == 1) && dst_dead;
- dst_need_input = false;
- }
- } else {
- // Handle all other (non-merge) nodes.
- const bool increment_dead =
- (is_dead || ((*outputs)[src_slot].state == Entry::State::NO_VALUE));
- const PendingCounts::AdjustResult adjust_result =
- iter_state->adjust_for_activation(dst_pending_id, increment_dead);
- dst_dead = adjust_result.any_dead;
- dst_ready = !adjust_result.any_pending;
- }
-
- if (dst_need_input) {
- const int dst_loc = e.input_slot;
- if (e.is_last) {
- input_tensors[dst_loc] = std::move((*outputs)[src_slot]);
- } else {
- input_tensors[dst_loc] = (*outputs)[src_slot];
- }
- }
-
- maybe_add_to_ready(dst_id, dst_item, dst_ready, dst_dead);
- }
-
- for (const ControlEdgeInfo& e : item->output_control_edges()) {
- const int dst_id = e.dst_id;
- const NodeItem* dst_item = gview.node(dst_id);
- const PendingCounts::Handle dst_pending_id =
- immutable_state.pending_ids()[dst_id];
-
- bool dst_dead;
- bool dst_ready;
- if (dst_item->is_merge) {
- // A merge node is ready if all control inputs have arrived and either
- // a) a live data input becomes available or b) all data inputs are
- // dead. For Merge, pending's LSB is set iff a live data input has
- // arrived.
- iter_state->decrement_pending(dst_pending_id, 2);
- int count = iter_state->pending(dst_pending_id);
- int dead_cnt = iter_state->dead_count(dst_pending_id);
- dst_dead = (dead_cnt == dst_item->num_inputs);
- dst_ready = (count == 0) || ((count == 1) && dst_dead);
- } else {
- // Handle all other (non-merge) nodes.
- const PendingCounts::AdjustResult adjust_result =
- iter_state->adjust_for_activation(dst_pending_id, is_dead);
- dst_dead = adjust_result.any_dead;
- dst_ready = !adjust_result.any_pending;
- }
- maybe_add_to_ready(dst_id, dst_item, dst_ready, dst_dead);
- }
-}
-
-void ExecutorState::FrameState::ActivateNodes(const NodeItem* item,
- const bool is_dead, int64 iter,
- EntryVector* outputs,
- TaggedNodeSeq* ready) {
- if (TF_PREDICT_FALSE(item->is_any_consumer_merge_or_control_trigger)) {
- ActivateNodesSlowPath(item, is_dead, iter, outputs, ready);
- } else {
- ActivateNodesFastPath(item, is_dead, iter, outputs, ready);
- }
-}
-
-void ExecutorState::FrameState::ActivateNexts(const GraphView* gview,
- int64 iter,
- TaggedNodeSeq* ready) {
- // Propagate the deferred NextIteration nodes to the new iteration.
- for (auto& node_entry : next_iter_roots) {
- const NodeItem* item = node_entry.first;
- const Entry& entry = node_entry.second;
- const bool is_dead = entry.state == Entry::State::NO_VALUE;
- EntryVector outputs{entry};
- ActivateNodes(item, is_dead, iter, &outputs, ready);
- }
- next_iter_roots.clear();
-}
-
-void ExecutorState::FrameState::ActivateLoopInvs(const GraphView* gview,
- int64 iter,
- TaggedNodeSeq* ready) {
- // Propagate loop invariants to the new iteration.
- for (auto& node_entry : inv_values) {
- const NodeItem* item = node_entry.first;
- const Entry& entry = node_entry.second;
- const bool is_dead = entry.state == Entry::State::NO_VALUE;
- EntryVector outputs{entry};
- ActivateNodes(item, is_dead, iter, &outputs, ready);
- }
-}
-
-void ExecutorState::FrameState::AddLoopInv(const NodeItem* item,
- const Entry& entry,
- TaggedNodeSeq* ready) {
- // Store this value.
- inv_values.push_back({item, entry});
-
- // Make this value available to all iterations.
- const bool is_dead = entry.state == Entry::State::NO_VALUE;
- for (int i = 0; i <= iteration_count; ++i) {
- EntryVector outputs{entry};
- ActivateNodes(item, is_dead, i, &outputs, ready);
- }
-}
-
-bool ExecutorState::FrameState::IsIterationDone(int64 iter) {
- IterationState* iter_state = GetIteration(iter);
- if (iter_state->outstanding_ops == 0 &&
- iter_state->outstanding_frame_count == 0) {
- if (iter == 0) {
- // The enclosing frame has no pending input.
- return num_pending_inputs == 0;
- } else {
- // The preceding iteration is deleted (and therefore done).
- return (GetIteration(iter - 1) == nullptr);
- }
- }
- return false;
-}
-
-void ExecutorState::FrameState::IncrementIteration(const GraphView* gview,
- TaggedNodeSeq* ready) {
- iteration_count++;
- const int64 next_iter = iteration_count;
-
- // Initialize the next iteration.
- IterationState* iter_state =
- new IterationState(pending_counts, total_input_tensors);
- SetIteration(next_iter, iter_state);
- num_outstanding_iterations++;
- dead_exits.clear();
-
- // Activate the successors of the deferred roots in the new iteration.
- ActivateNexts(gview, next_iter, ready);
-
- // Activate the loop invariants in the new iteration.
- ActivateLoopInvs(gview, next_iter, ready);
-}
-
-bool ExecutorState::FrameState::CleanupIterations(const GraphView* gview,
- int64 iter,
- TaggedNodeSeq* ready) {
- int64 curr_iter = iter;
- while (curr_iter <= iteration_count && IsIterationDone(curr_iter)) {
- // Delete the iteration curr_iter.
- delete GetIteration(curr_iter);
- SetIteration(curr_iter, nullptr);
- --num_outstanding_iterations;
- ++curr_iter;
-
- // When one iteration is completed, we check for deferred iteration,
- // and start it if there is one.
- if (!next_iter_roots.empty()) {
- IncrementIteration(gview, ready);
- }
- }
- return IsFrameDone();
-}
-
void ExecutorImpl::RunAsync(const Args& args, DoneCallback done) {
- (new ExecutorState(args, immutable_state_, &kernel_stats_))
+ (new ExecutorState<PropagatorState>(args, immutable_state_, &kernel_stats_))
->RunAsync(std::move(done));
}
diff --git a/tensorflow/core/common_runtime/propagator_state.cc b/tensorflow/core/common_runtime/propagator_state.cc
new file mode 100644
index 0000000..e2827a8
--- /dev/null
+++ b/tensorflow/core/common_runtime/propagator_state.cc
@@ -0,0 +1,777 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/common_runtime/propagator_state.h"
+
+#include "tensorflow/core/common_runtime/graph_view.h"
+#include "tensorflow/core/lib/hash/hash.h"
+#include "tensorflow/core/profiler/lib/traceme.h"
+
+namespace tensorflow {
+
+// 1-D, 0 element tensor.
+static const Tensor* const kEmptyTensor = new Tensor;
+
+typedef gtl::InlinedVector<TensorValue, 4> TensorValueVec;
+typedef gtl::InlinedVector<AllocatorAttributes, 4> AllocatorAttributeVec;
+
+PropagatorState::PropagatorState(const ImmutableExecutorState& immutable_state,
+ int64 step_id)
+ : immutable_state_(immutable_state),
+ step_id_(step_id),
+ vlog_(VLOG_IS_ON(1)) {
+ // We start the entire execution in iteration 0 of the root frame
+ // so let us create the root frame and the state for iteration 0.
+ // We assume root_frame_->frame_name.empty().
+ root_frame_ = new FrameState(immutable_state_, 1);
+ root_frame_->frame_id = 0; // must be 0
+ root_frame_->InitializeFrameInfo(root_frame_->frame_name);
+
+ // Initialize iteration 0.
+ root_frame_->SetIteration(
+ 0, new PropagatorState::IterationState(root_frame_->pending_counts,
+ root_frame_->total_input_tensors));
+
+ outstanding_frames_.insert({root_frame_->frame_name, root_frame_});
+}
+
+PropagatorState::~PropagatorState() {
+ for (auto name_frame : outstanding_frames_) {
+ delete name_frame.second;
+ }
+}
+
+void PropagatorState::ActivateRoots(gtl::ArraySlice<const NodeItem*> roots,
+ TaggedNodeSeq* ready) {
+ for (const NodeItem* item : roots) {
+ DCHECK_EQ(item->num_inputs, 0);
+ ready->push_back(TaggedNode{item, root_frame_, 0, false});
+ }
+ mutex_lock l(root_frame_->mu);
+ root_frame_->GetIteration(0)->outstanding_ops = ready->size();
+}
+
+void PropagatorState::PropagateOutputs(const TaggedNode& tagged_node,
+ EntryVector* outputs,
+ TaggedNodeSeq* ready) {
+ profiler::TraceMe activity(
+ [&]() {
+ return strings::StrCat(
+ "ExecutorPropagateOutputs#", "id=", step_id_,
+ ",kernel_name=", tagged_node.node_item->kernel->name_view(),
+ ",num_output_edges=", tagged_node.node_item->num_output_edges,
+ ",num_output_control_edges=",
+ tagged_node.node_item->num_output_control_edges, "#");
+ },
+ profiler::GetTFTraceMeLevel(/*is_expensive=*/false));
+
+ const NodeItem* const item = tagged_node.node_item;
+ FrameState* const input_frame = tagged_node.input_frame;
+ const int64 input_iter = tagged_node.input_iter;
+ const bool is_dead = tagged_node.is_dead;
+
+ // Propagates outputs along out edges, and puts newly ready nodes
+ // into the ready queue.
+ DCHECK(ready->empty());
+ bool is_frame_done = false;
+ FrameState* output_frame = input_frame;
+ int64 output_iter = input_iter;
+
+ if (!item->is_enter_exit_or_next_iter) {
+ // Fast path for nodes types that don't need special handling
+ DCHECK_EQ(input_frame, output_frame);
+ // Normal path for most nodes
+ mutex_lock l(input_frame->mu);
+ output_frame->ActivateNodes(item, is_dead, output_iter, outputs, ready);
+ is_frame_done =
+ input_frame->DecrementOutstandingOpsLocked(input_iter, ready);
+ } else if (item->is_enter) {
+ FindOrCreateChildFrame(input_frame, input_iter, *item, &output_frame);
+ output_iter = 0;
+ {
+ mutex_lock l(output_frame->mu);
+ if (item->is_constant_enter) {
+ // Propagate to all active iterations if this is a loop invariant.
+ output_frame->AddLoopInv(item, (*outputs)[0], ready);
+ } else {
+ output_frame->ActivateNodes(item, is_dead, output_iter, outputs, ready);
+ }
+ output_frame->num_pending_inputs--;
+ }
+ is_frame_done = input_frame->DecrementOutstandingOps(input_iter, ready);
+ } else if (item->is_exit) {
+ if (is_dead) {
+ mutex_lock l(input_frame->mu);
+ // Stop and remember this node if it is a dead exit.
+ if (input_iter == input_frame->iteration_count) {
+ input_frame->dead_exits.push_back(item);
+ }
+ is_frame_done =
+ input_frame->DecrementOutstandingOpsLocked(input_iter, ready);
+ } else {
+ output_frame = input_frame->parent_frame;
+ output_iter = input_frame->parent_iter;
+ {
+ mutex_lock l(output_frame->mu);
+ output_frame->ActivateNodes(item, is_dead, output_iter, outputs, ready);
+ }
+ is_frame_done = input_frame->DecrementOutstandingOps(input_iter, ready);
+ }
+ } else {
+ DCHECK(item->is_next_iteration);
+ mutex_lock l(input_frame->mu);
+ if (is_dead) {
+ // Stop the deadness propagation.
+ output_frame = nullptr;
+ } else {
+ if (input_iter == input_frame->iteration_count &&
+ input_frame->num_outstanding_iterations ==
+ input_frame->max_parallel_iterations) {
+ // Reached the maximum for parallel iterations.
+ input_frame->next_iter_roots.push_back({item, (*outputs)[0]});
+ output_frame = nullptr;
+ } else {
+ // If this is a new iteration, start it.
+ if (input_iter == input_frame->iteration_count) {
+ input_frame->IncrementIteration(ready);
+ }
+ output_iter = input_iter + 1;
+ }
+ }
+ if (output_frame != nullptr) {
+ // This is the case when node is not Enter, Exit, or NextIteration.
+ DCHECK(input_frame == output_frame);
+ output_frame->ActivateNodes(item, is_dead, output_iter, outputs, ready);
+ }
+ is_frame_done =
+ input_frame->DecrementOutstandingOpsLocked(input_iter, ready);
+ }
+
+ // At this point, this node is completely done. We also know if the
+ // completion of this node makes its frame completed.
+ if (is_frame_done) {
+ FrameState* parent_frame = input_frame->parent_frame;
+ const int64 parent_iter = input_frame->parent_iter;
+ DeleteFrame(input_frame, ready);
+ if (parent_frame != nullptr) {
+ // The completion of frame may cause completions in its parent frame.
+ // So clean things up recursively.
+ CleanupFramesIterations(parent_frame, parent_iter, ready);
+ }
+ }
+}
+
+const Tensor* PropagatorState::GetTensorValueForDump(const Entry& input) {
+ switch (input.state) {
+ case Entry::State::NO_VALUE:
+ return kEmptyTensor;
+ case Entry::State::HAS_VALUE:
+ return input.val.get();
+ case Entry::State::HAS_CONST_TENSOR:
+ return input.const_tensor;
+ case Entry::State::HAS_REF_TENSOR:
+ return input.ref_tensor.tensor;
+ }
+}
+
+void PropagatorState::DumpPendingNodeState(
+ const int node_id, const Entry* input_vector,
+ const bool show_nodes_with_no_ready_inputs) {
+ const NodeItem& node_item = *immutable_state_.graph_view().node(node_id);
+ const int input_base = node_item.input_start;
+ if (!show_nodes_with_no_ready_inputs) {
+ bool has_ready_input = false;
+ for (int i = 0; i < node_item.num_inputs; ++i) {
+ const Entry& input = input_vector[input_base + i];
+ const Tensor* tensor = GetTensorValueForDump(input);
+ if (tensor->IsInitialized()) {
+ has_ready_input = true;
+ break;
+ }
+ }
+ if (!has_ready_input) {
+ return;
+ }
+ }
+ LOG(WARNING) << " Pending Node: " << node_item.DebugString();
+ for (int i = 0; i < node_item.num_inputs; ++i) {
+ const Entry& input = input_vector[input_base + i];
+ const Tensor* tensor = GetTensorValueForDump(input);
+ if (tensor->IsInitialized()) {
+ LOG(WARNING) << " Input " << i << ": "
+ << strings::StrCat(
+ "Tensor<type: ", DataTypeString(tensor->dtype()),
+ " shape: ", tensor->shape().DebugString(), ">");
+ } else {
+ LOG(WARNING) << " Input " << i << ": not present";
+ }
+ }
+}
+
+void PropagatorState::DumpActiveNodeState(const int node_id,
+ const Entry* input_vector) {
+ const NodeItem& node_item = *immutable_state_.graph_view().node(node_id);
+ LOG(WARNING) << " Active Node: " << node_item.DebugString();
+ const int input_base = node_item.input_start;
+ for (int i = 0; i < node_item.num_inputs; ++i) {
+ const Entry& input = input_vector[input_base + i];
+ const Tensor* tensor = GetTensorValueForDump(input);
+ if (tensor->IsInitialized()) {
+ LOG(WARNING) << " Input " << i << ": "
+ << strings::StrCat(
+ "Tensor<type: ", DataTypeString(tensor->dtype()),
+ " shape: ", tensor->shape().DebugString(), ">");
+ } else {
+ LOG(WARNING) << " Input " << i << ": not present";
+ }
+ }
+}
+
+void PropagatorState::DumpIterationState(const FrameState* frame,
+ IterationState* iteration) {
+ const std::vector<const NodeItem*>* nodes = frame->nodes;
+ // Dump any waiting nodes that are holding on to tensors.
+ for (const NodeItem* node : *nodes) {
+ PendingCounts::Handle pending_id =
+ immutable_state_.pending_ids()[node->node_id];
+ if (iteration->node_state(pending_id) == PendingCounts::PENDING_NOTREADY ||
+ iteration->node_state(pending_id) == PendingCounts::PENDING_READY) {
+ DumpPendingNodeState(node->node_id, iteration->input_tensors, false);
+ }
+ }
+ // Then the active nodes.
+ for (const NodeItem* node : *nodes) {
+ PendingCounts::Handle pending_id =
+ immutable_state_.pending_ids()[node->node_id];
+ if (iteration->node_state(pending_id) == PendingCounts::STARTED) {
+ DumpActiveNodeState(node->node_id, iteration->input_tensors);
+ }
+ }
+ // Show all input tensors in use.
+ const int total_input_tensors = frame->total_input_tensors;
+ size_t total_bytes = 0;
+ for (int i = 0; i < total_input_tensors; ++i) {
+ const Entry& input = iteration->input_tensors[i];
+ const Tensor* tensor = GetTensorValueForDump(input);
+ if (tensor->IsInitialized()) {
+ LOG(WARNING) << " Input " << i << ": "
+ << strings::StrCat(
+ "Tensor<type: ", DataTypeString(tensor->dtype()),
+ " shape: ", tensor->shape().DebugString(),
+ ", bytes: ", tensor->TotalBytes(), ">");
+ total_bytes += tensor->TotalBytes();
+ }
+ }
+ LOG(WARNING) << " Total bytes " << total_bytes;
+}
+
+void PropagatorState::DumpState() {
+ mutex_lock l(mu_);
+ if (!dumped_on_error_) {
+ LOG(WARNING) << "Dumping state";
+ for (auto& frame : outstanding_frames_) {
+ LOG(WARNING) << frame.first;
+ FrameState* frame_state = frame.second;
+ frame_state->DumpIterationState(this);
+ }
+ dumped_on_error_ = true;
+ }
+}
+
+void PropagatorState::FindOrCreateChildFrame(FrameState* frame, int64 iter,
+ const NodeItem& node_item,
+ FrameState** child) {
+ // Get the child frame name.
+ AttrSlice attrs(node_item.kernel->def());
+ const string& enter_name = GetNodeAttrString(attrs, "frame_name");
+ DCHECK(!enter_name.empty()) << "Could not find \"frame_name\" attr in node "
+ << node_item.kernel->name();
+ const string child_name =
+ strings::StrCat(frame->frame_name, ";", iter, ";", enter_name);
+
+ {
+ mutex_lock executor_lock(mu_);
+ auto it = outstanding_frames_.find(child_name);
+ if (it != outstanding_frames_.end()) {
+ *child = it->second;
+ return;
+ }
+ }
+
+ // Need to create a new frame instance.
+ // Note that this new frame instance is created without any locks.
+ if (vlog_) VLOG(2) << "Create frame: " << child_name;
+
+ int parallel_iters;
+ bool found_parallel_iters =
+ TryGetNodeAttr(attrs, "parallel_iterations", ¶llel_iters);
+ DCHECK(found_parallel_iters)
+ << "Could not find \"parallel_iterations\" attr in node "
+ << node_item.kernel->name();
+ FrameState* temp = new FrameState(immutable_state_, parallel_iters);
+ temp->frame_name = child_name;
+ temp->frame_id = Hash64(child_name);
+ temp->parent_frame = frame;
+ temp->parent_iter = iter;
+ temp->InitializeFrameInfo(enter_name);
+
+ // Initialize iteration 0.
+ {
+ mutex_lock l(temp->mu);
+ temp->SetIteration(
+ 0, new IterationState(temp->pending_counts, temp->total_input_tensors));
+ }
+
+ {
+ mutex_lock executor_lock(mu_);
+ auto it = outstanding_frames_.find(child_name);
+ if (it != outstanding_frames_.end()) {
+ *child = it->second;
+ } else {
+ mutex_lock frame_lock(frame->mu);
+ frame->GetIteration(iter)->outstanding_frame_count++;
+ outstanding_frames_[child_name] = temp;
+ *child = temp;
+ temp = nullptr;
+ }
+ }
+ delete temp; // Not used so delete it.
+}
+
+void PropagatorState::DeleteFrame(FrameState* frame, TaggedNodeSeq* ready) {
+ // First, propagate dead_exits (if any) to the parent frame.
+ FrameState* parent_frame = frame->parent_frame;
+ const int64 parent_iter = frame->parent_iter;
+ if (parent_frame != nullptr) {
+ mutex_lock parent_frame_lock(parent_frame->mu);
+ // Propagate all the dead exits to the parent frame.
+ mutex_lock this_frame_lock(frame->mu);
+
+ for (const NodeItem* item : frame->dead_exits) {
+ auto parent_iter_state = parent_frame->GetIteration(parent_iter);
+
+ auto maybe_add_to_ready = [&](const NodeItem& dst_item, bool dst_ready,
+ bool dst_dead) {
+ if (dst_ready) {
+ if (dst_item.is_control_trigger) dst_dead = false;
+ ready->emplace_back(&dst_item, parent_frame, parent_iter, dst_dead);
+ parent_iter_state->outstanding_ops++;
+ }
+ };
+
+ auto propagate_to_non_merge = [&](PendingCounts::Handle dst_pending_id) {
+ parent_iter_state->increment_dead_count(dst_pending_id);
+ return parent_iter_state->decrement_pending(dst_pending_id, 1) == 0;
+ };
+
+ for (const EdgeInfo& e : item->output_edges()) {
+ const NodeItem& dst_item =
+ *immutable_state_.graph_view().node(e.dst_id);
+ const auto dst_pending_id = immutable_state_.pending_ids()[e.dst_id];
+
+ bool dst_dead = true;
+ bool dst_ready;
+ // We know this is a dead input to dst.
+ if (dst_item.is_merge) {
+ parent_iter_state->increment_dead_count(dst_pending_id);
+ const int dead_cnt = parent_iter_state->dead_count(dst_pending_id);
+ dst_dead = (dead_cnt == dst_item.num_inputs);
+ dst_ready =
+ (parent_iter_state->pending(dst_pending_id) == 1) && dst_dead;
+ } else {
+ dst_ready = propagate_to_non_merge(dst_pending_id);
+ }
+ maybe_add_to_ready(dst_item, dst_ready, dst_dead);
+ }
+
+ for (const ControlEdgeInfo& e : item->output_control_edges()) {
+ const NodeItem& dst_item =
+ *immutable_state_.graph_view().node(e.dst_id);
+ const auto dst_pending_id = immutable_state_.pending_ids()[e.dst_id];
+
+ bool dst_dead;
+ bool dst_ready;
+ // We know this is a dead input to dst.
+ if (dst_item.is_merge) {
+ parent_iter_state->decrement_pending(dst_pending_id, 2);
+ int count = parent_iter_state->pending(dst_pending_id);
+ int dead_cnt = parent_iter_state->dead_count(dst_pending_id);
+ dst_dead = (dead_cnt == dst_item.num_inputs);
+ dst_ready = (count == 0) || ((count == 1) && dst_dead);
+ } else {
+ dst_dead = true;
+ dst_ready = propagate_to_non_merge(dst_pending_id);
+ }
+ maybe_add_to_ready(dst_item, dst_ready, dst_dead);
+ }
+ }
+ }
+
+ // Delete the frame.
+ const string& frame_name = frame->frame_name;
+ if (vlog_) VLOG(2) << "Delete frame " << frame_name;
+ {
+ mutex_lock executor_lock(mu_);
+ outstanding_frames_.erase(frame_name);
+ }
+ delete frame;
+}
+
+void PropagatorState::CleanupFramesIterations(FrameState* frame, int64 iter,
+ TaggedNodeSeq* ready) {
+ bool is_frame_done = false;
+ {
+ mutex_lock frame_lock(frame->mu);
+ frame->GetIteration(iter)->outstanding_frame_count--;
+ is_frame_done = frame->CleanupIterations(iter, ready);
+ }
+ if (is_frame_done) {
+ FrameState* parent_frame = frame->parent_frame;
+ const int64 parent_iter = frame->parent_iter;
+ DeleteFrame(frame, ready);
+ if (parent_frame != nullptr) {
+ // The completion of frame may cause completions in its parent frame.
+ // So clean things up recursively.
+ CleanupFramesIterations(parent_frame, parent_iter, ready);
+ }
+ }
+}
+
+void PropagatorState::FrameState::ActivateNodesFastPath(const NodeItem* item,
+ const bool is_dead,
+ int64 iter,
+ EntryVector* outputs,
+ TaggedNodeSeq* ready) {
+ // If we know that none of the item's edge destinations require special
+ // handling (i.e. none of the nodes is a merge or control trigger node), we
+ // can take a fast path that avoids accessing the destination NodeItem.
+ const GraphView& gview = immutable_state.graph_view();
+ IterationState* iter_state = GetIteration(iter);
+
+// Add dst to the ready queue if it's ready
+//
+// NOTE(mrry): Use a macro here instead of a lambda, because this method is
+// performance-critical and we need to ensure that the code is inlined.
+#define MAYBE_ADD_TO_READY(dst_id, adjust_result) \
+ do { \
+ if (!adjust_result.any_pending) { \
+ const NodeItem* dst_item = gview.node(dst_id); \
+ TaggedNode& t = ready->emplace_back(); \
+ t.node_item = dst_item; \
+ t.input_frame = this; \
+ t.input_iter = iter; \
+ t.is_dead = adjust_result.any_dead; \
+ iter_state->outstanding_ops++; \
+ } \
+ } while (0);
+
+ Entry* input_tensors = iter_state->input_tensors;
+
+ for (const EdgeInfo& e : item->output_edges()) {
+ const int dst_id = e.dst_id;
+ const PendingCounts::Handle dst_pending_id =
+ immutable_state.pending_ids()[dst_id];
+ const int src_slot = e.output_slot;
+
+ const bool increment_dead =
+ (is_dead || ((*outputs)[src_slot].state == Entry::State::NO_VALUE));
+ const PendingCounts::AdjustResult adjust_result =
+ iter_state->adjust_for_activation(dst_pending_id, increment_dead);
+ const int dst_loc = e.input_slot;
+ if (e.is_last) {
+ input_tensors[dst_loc] = std::move((*outputs)[src_slot]);
+ } else {
+ input_tensors[dst_loc] = (*outputs)[src_slot];
+ }
+ MAYBE_ADD_TO_READY(dst_id, adjust_result);
+ }
+
+ for (const ControlEdgeInfo& e : item->output_control_edges()) {
+ const int dst_id = e.dst_id;
+ const PendingCounts::Handle dst_pending_id =
+ immutable_state.pending_ids()[dst_id];
+ const PendingCounts::AdjustResult adjust_result =
+ iter_state->adjust_for_activation(dst_pending_id, is_dead);
+ MAYBE_ADD_TO_READY(dst_id, adjust_result);
+ }
+#undef MAYBE_ADD_TO_READY
+}
+
+void PropagatorState::FrameState::ActivateNodesSlowPath(const NodeItem* item,
+ const bool is_dead,
+ int64 iter,
+ EntryVector* outputs,
+ TaggedNodeSeq* ready) {
+ // If any of the edge destinations is a merge or a control trigger node,
+ // we need to read each destination NodeItem to determine what action
+ // to take.
+ const GraphView& gview = immutable_state.graph_view();
+ IterationState* iter_state = GetIteration(iter);
+
+ auto maybe_add_to_ready = [&](int dst_id, const NodeItem* dst_item,
+ bool dst_ready, bool dst_dead) {
+ // Add dst to the ready queue if it's ready
+ if (dst_ready) {
+ if (dst_item->is_control_trigger) dst_dead = false;
+ ready->emplace_back(dst_item, this, iter, dst_dead);
+ iter_state->outstanding_ops++;
+ }
+ };
+
+ Entry* input_tensors = iter_state->input_tensors;
+
+ for (const EdgeInfo& e : item->output_edges()) {
+ const int dst_id = e.dst_id;
+ const NodeItem* dst_item = gview.node(dst_id);
+ const PendingCounts::Handle dst_pending_id =
+ immutable_state.pending_ids()[dst_id];
+ const int src_slot = e.output_slot;
+
+ bool dst_dead = false;
+ bool dst_ready = false;
+ bool dst_need_input = true;
+
+ if (dst_item->is_merge) {
+ // A merge node is ready if all control inputs have arrived and either
+ // a) a live data input becomes available or b) all data inputs are
+ // dead. For Merge, pending's LSB is set iff a live data input has
+ // arrived.
+ if ((*outputs)[src_slot].state != Entry::State::NO_VALUE) {
+ // This is a live data input.
+ int count = iter_state->pending(dst_pending_id);
+ iter_state->mark_live(dst_pending_id);
+ // Only the first live edge sets the input and (potentially)
+ // triggers execution. The low bit of count is set if and
+ // only if no live input has been used yet (mark_live clears
+ // it). The node should be started if and only if this is
+ // the first live input and there are no pending control
+ // edges, i.e. count == 1.
+ dst_ready = (count == 1);
+ dst_need_input = ((count & 0x1) == 1);
+ } else {
+ // This is a dead data input. Note that dst_node is dead if node is
+ // a dead enter. We need this to handle properly a while loop on
+ // the untaken branch of a conditional.
+ // TODO(yuanbyu): This is a bit hacky, but a good solution for
+ // now.
+ iter_state->increment_dead_count(dst_pending_id);
+ const int dead_cnt = iter_state->dead_count(dst_pending_id);
+ dst_dead = (dead_cnt == dst_item->num_inputs) || item->is_enter;
+ dst_ready = (iter_state->pending(dst_pending_id) == 1) && dst_dead;
+ dst_need_input = false;
+ }
+ } else {
+ // Handle all other (non-merge) nodes.
+ const bool increment_dead =
+ (is_dead || ((*outputs)[src_slot].state == Entry::State::NO_VALUE));
+ const PendingCounts::AdjustResult adjust_result =
+ iter_state->adjust_for_activation(dst_pending_id, increment_dead);
+ dst_dead = adjust_result.any_dead;
+ dst_ready = !adjust_result.any_pending;
+ }
+
+ if (dst_need_input) {
+ const int dst_loc = e.input_slot;
+ if (e.is_last) {
+ input_tensors[dst_loc] = std::move((*outputs)[src_slot]);
+ } else {
+ input_tensors[dst_loc] = (*outputs)[src_slot];
+ }
+ }
+
+ maybe_add_to_ready(dst_id, dst_item, dst_ready, dst_dead);
+ }
+
+ for (const ControlEdgeInfo& e : item->output_control_edges()) {
+ const int dst_id = e.dst_id;
+ const NodeItem* dst_item = gview.node(dst_id);
+ const PendingCounts::Handle dst_pending_id =
+ immutable_state.pending_ids()[dst_id];
+
+ bool dst_dead;
+ bool dst_ready;
+ if (dst_item->is_merge) {
+ // A merge node is ready if all control inputs have arrived and either
+ // a) a live data input becomes available or b) all data inputs are
+ // dead. For Merge, pending's LSB is set iff a live data input has
+ // arrived.
+ iter_state->decrement_pending(dst_pending_id, 2);
+ int count = iter_state->pending(dst_pending_id);
+ int dead_cnt = iter_state->dead_count(dst_pending_id);
+ dst_dead = (dead_cnt == dst_item->num_inputs);
+ dst_ready = (count == 0) || ((count == 1) && dst_dead);
+ } else {
+ // Handle all other (non-merge) nodes.
+ const PendingCounts::AdjustResult adjust_result =
+ iter_state->adjust_for_activation(dst_pending_id, is_dead);
+ dst_dead = adjust_result.any_dead;
+ dst_ready = !adjust_result.any_pending;
+ }
+ maybe_add_to_ready(dst_id, dst_item, dst_ready, dst_dead);
+ }
+}
+
+void PropagatorState::FrameState::ActivateNodes(const NodeItem* item,
+ const bool is_dead, int64 iter,
+ EntryVector* outputs,
+ TaggedNodeSeq* ready) {
+ if (TF_PREDICT_FALSE(item->is_any_consumer_merge_or_control_trigger)) {
+ ActivateNodesSlowPath(item, is_dead, iter, outputs, ready);
+ } else {
+ ActivateNodesFastPath(item, is_dead, iter, outputs, ready);
+ }
+}
+
+void PropagatorState::FrameState::ActivateNexts(int64 iter,
+ TaggedNodeSeq* ready) {
+ // Propagate the deferred NextIteration nodes to the new iteration.
+ for (auto& node_entry : next_iter_roots) {
+ const NodeItem* item = node_entry.first;
+ const Entry& entry = node_entry.second;
+ const bool is_dead = entry.state == Entry::State::NO_VALUE;
+ EntryVector outputs{entry};
+ ActivateNodes(item, is_dead, iter, &outputs, ready);
+ }
+ next_iter_roots.clear();
+}
+
+void PropagatorState::FrameState::ActivateLoopInvs(int64 iter,
+ TaggedNodeSeq* ready) {
+ // Propagate loop invariants to the new iteration.
+ for (auto& node_entry : inv_values) {
+ const NodeItem* item = node_entry.first;
+ const Entry& entry = node_entry.second;
+ const bool is_dead = entry.state == Entry::State::NO_VALUE;
+ EntryVector outputs{entry};
+ ActivateNodes(item, is_dead, iter, &outputs, ready);
+ }
+}
+
+void PropagatorState::FrameState::AddLoopInv(const NodeItem* item,
+ const Entry& entry,
+ TaggedNodeSeq* ready) {
+ // Store this value.
+ inv_values.push_back({item, entry});
+
+ // Make this value available to all iterations.
+ const bool is_dead = entry.state == Entry::State::NO_VALUE;
+ for (int i = 0; i <= iteration_count; ++i) {
+ EntryVector outputs{entry};
+ ActivateNodes(item, is_dead, i, &outputs, ready);
+ }
+}
+
+bool PropagatorState::FrameState::IsIterationDone(int64 iter) {
+ IterationState* iter_state = GetIteration(iter);
+ if (iter_state->outstanding_ops == 0 &&
+ iter_state->outstanding_frame_count == 0) {
+ if (iter == 0) {
+ // The enclosing frame has no pending input.
+ return num_pending_inputs == 0;
+ } else {
+ // The preceding iteration is deleted (and therefore done).
+ return (GetIteration(iter - 1) == nullptr);
+ }
+ }
+ return false;
+}
+
+void PropagatorState::FrameState::IncrementIteration(TaggedNodeSeq* ready) {
+ iteration_count++;
+ const int64 next_iter = iteration_count;
+
+ // Initialize the next iteration.
+ IterationState* iter_state =
+ new IterationState(pending_counts, total_input_tensors);
+ SetIteration(next_iter, iter_state);
+ num_outstanding_iterations++;
+ dead_exits.clear();
+
+ // Activate the successors of the deferred roots in the new iteration.
+ ActivateNexts(next_iter, ready);
+
+ // Activate the loop invariants in the new iteration.
+ ActivateLoopInvs(next_iter, ready);
+}
+
+bool PropagatorState::FrameState::CleanupIterations(int64 iter,
+ TaggedNodeSeq* ready) {
+ int64 curr_iter = iter;
+ while (curr_iter <= iteration_count && IsIterationDone(curr_iter)) {
+ // Delete the iteration curr_iter.
+ delete GetIteration(curr_iter);
+ SetIteration(curr_iter, nullptr);
+ --num_outstanding_iterations;
+ ++curr_iter;
+
+ // When one iteration is completed, we check for deferred iteration,
+ // and start it if there is one.
+ if (!next_iter_roots.empty()) {
+ IncrementIteration(ready);
+ }
+ }
+ return IsFrameDone();
+}
+
+void PropagatorState::FrameState::InitializeFrameInfo(
+ const string& enter_name) {
+ const ImmutableExecutorState::FrameInfo* finfo =
+ immutable_state.get_frame_info(enter_name);
+ DCHECK_NE(finfo, nullptr);
+ pending_counts = finfo->pending_counts.get();
+ total_input_tensors = finfo->total_inputs;
+ num_pending_inputs = finfo->input_count;
+ nodes = finfo->nodes.get();
+}
+
+void PropagatorState::FrameState::SetIteration(int64 iter,
+ IterationState* state)
+ TF_EXCLUSIVE_LOCKS_REQUIRED(mu) {
+ size_t index = iter % (max_parallel_iterations + 1);
+ DCHECK(state == nullptr || iterations[index] == nullptr);
+ iterations_raw[index] = state;
+ if (index == 0) {
+ iterations_first = state;
+ }
+}
+
+// Decrement the outstanding op count and clean up the iterations in the
+// frame. Return true iff the execution of the frame is done.
+bool PropagatorState::FrameState::DecrementOutstandingOps(
+ int64 iter, TaggedNodeSeq* ready) {
+ mutex_lock l(mu);
+ return DecrementOutstandingOpsLocked(iter, ready);
+}
+
+// Decrement the outstanding op count and clean up the iterations in the
+// frame. Return true iff the execution of the frame is done.
+bool PropagatorState::FrameState::DecrementOutstandingOpsLocked(
+ int64 iter, TaggedNodeSeq* ready) TF_EXCLUSIVE_LOCKS_REQUIRED(mu) {
+ IterationState* istate = GetIteration(iter);
+ istate->outstanding_ops--;
+ if (istate->outstanding_ops != 0) {
+ return false;
+ } else {
+ return CleanupIterations(iter, ready);
+ }
+}
+
+// Returns true if the computation in the frame is completed.
+bool PropagatorState::FrameState::IsFrameDone()
+ TF_EXCLUSIVE_LOCKS_REQUIRED(mu) {
+ return (num_pending_inputs == 0 && num_outstanding_iterations == 0);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/propagator_state.h b/tensorflow/core/common_runtime/propagator_state.h
new file mode 100644
index 0000000..d82d3bf
--- /dev/null
+++ b/tensorflow/core/common_runtime/propagator_state.h
@@ -0,0 +1,466 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_PROPAGATOR_STATE_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_PROPAGATOR_STATE_H_
+
+#include <vector>
+
+#include "tensorflow/core/common_runtime/entry.h"
+#include "tensorflow/core/common_runtime/immutable_executor_state.h"
+#include "tensorflow/core/common_runtime/pending_counts.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/control_flow.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+typedef gtl::InlinedVector<TensorValue, 4> TensorValueVec;
+typedef gtl::InlinedVector<AllocatorAttributes, 4> AllocatorAttributeVec;
+
+// Represents the ephemeral "edge state" associated with one invocation of
+// `Executor::Run()`.
+//
+// `PropagatorState` is responsible for propagating values along dataflow
+// edges in a TensorFlow graph and determining which nodes are runnable. The
+// executor primarily updates `PropagatorState` by calling `PropagateOutputs()`
+// after processing a node, and `PropagatorState` dispatches `TaggedNode`s by
+// adding them to a `TaggedNodeSeq`.
+class PropagatorState {
+ public:
+ PropagatorState(const ImmutableExecutorState& immutable_state, int64 step_id);
+ ~PropagatorState();
+
+ private:
+ // Forward declaration so that `TaggedNode` can include a `FrameState*`.
+ struct FrameState;
+
+ public:
+ // A `TaggedNode` corresponds to a single invocation of a node's kernel,
+ // and it is created when the kernel becomes runnable (in a particular
+ // iteration of a particular frame).
+ struct TaggedNode {
+ const NodeItem* node_item;
+ FrameState* input_frame;
+ int64 input_iter;
+ bool is_dead;
+
+ TaggedNode() = default;
+ TaggedNode(const NodeItem* node_item, FrameState* in_frame, int64 in_iter,
+ bool dead)
+ : node_item(node_item),
+ input_frame(in_frame),
+ input_iter(in_iter),
+ is_dead(dead) {}
+
+ const NodeItem& get_node_item() const { return *node_item; }
+
+ bool get_is_dead() const { return is_dead; }
+ };
+
+ // A drop-in replacement for std::deque<TaggedNode>. We typically don't
+ // have that many nodes in the ready queue, so we just use a vector and
+ // don't free up memory from the queue as we consume nodes.
+ class TaggedNodeReadyQueue {
+ public:
+ TaggedNodeReadyQueue() : front_index_(0) {}
+
+ void push_back(const TaggedNode& node) { ready_.push_back(node); }
+ TaggedNode front() const {
+ DCHECK_LT(front_index_, ready_.size());
+ return ready_[front_index_];
+ }
+ void pop_front() {
+ DCHECK_LT(front_index_, ready_.size());
+ front_index_++;
+ if ((front_index_ == ready_.size()) || (front_index_ > kSpillThreshold)) {
+ if (front_index_ == ready_.size()) {
+ ready_.clear();
+ } else {
+ // Lots of unused entries at beginning of vector: move everything
+ // down to start of vector.
+ ready_.erase(ready_.begin(), ready_.begin() + front_index_);
+ }
+ front_index_ = 0;
+ }
+ }
+ bool empty() const { return ready_.empty(); }
+
+ private:
+ // TODO(b/152925936): Re-evaluate these constants with current usage
+ // patterns.
+ static constexpr int kSpillThreshold = 16384;
+ gtl::InlinedVector<TaggedNode, 16> ready_;
+ int front_index_;
+ };
+
+ // TODO(b/152925936): Re-evaluate this constant with current usage patterns.
+ typedef gtl::InlinedVector<TaggedNode, 8> TaggedNodeSeq;
+
+ private:
+ struct IterationState {
+ explicit IterationState(const PendingCounts* pending_counts,
+ int total_input_tensors)
+ : input_tensors(new Entry[total_input_tensors]),
+ outstanding_ops(0),
+ outstanding_frame_count(0),
+ counts(*pending_counts) { // Initialize with copy of *pending_counts
+ }
+
+ // The state of an iteration.
+
+ // One copy per iteration. For iteration k, i-th node's j-th input is in
+ // input_tensors[k][immutable_state_.nodes[i].input_start + j]. An entry is
+ // either a tensor pointer (pass-by-reference) or a tensor (pass-by-value).
+ //
+ // NOTE: No need to protect input_tensors[i] by any locks because it
+ // is resized once. Each element of tensors_ is written once by the
+ // source node of an edge and is cleared by the destination of the same
+ // edge. The latter node is never run concurrently with the former node.
+ Entry* input_tensors;
+
+ // The number of outstanding ops for each iteration.
+ size_t outstanding_ops;
+
+ // The number of outstanding frames for each iteration.
+ int outstanding_frame_count;
+ int pending(PendingCounts::Handle h) { return counts.pending(h); }
+ int decrement_pending(PendingCounts::Handle h, int v) {
+ return counts.decrement_pending(h, v);
+ }
+ // Mark a merge node as live
+ // REQUIRES: Node corresponding to "h" is a merge node
+ void mark_live(PendingCounts::Handle h) { counts.mark_live(h); }
+ // Mark a node to show that processing has started.
+ void mark_started(PendingCounts::Handle h) { counts.mark_started(h); }
+ // Mark a node to show that processing has completed.
+ void mark_completed(PendingCounts::Handle h) { counts.mark_completed(h); }
+ PendingCounts::NodeState node_state(PendingCounts::Handle h) {
+ return counts.node_state(h);
+ }
+
+ int dead_count(PendingCounts::Handle h) { return counts.dead_count(h); }
+ void increment_dead_count(PendingCounts::Handle h) {
+ counts.increment_dead_count(h);
+ }
+ PendingCounts::AdjustResult adjust_for_activation(PendingCounts::Handle h,
+ bool increment_dead) {
+ return counts.adjust_for_activation(h, increment_dead);
+ }
+
+ ~IterationState() { delete[] input_tensors; }
+
+ private:
+ PendingCounts counts;
+ };
+
+ struct FrameState {
+ explicit FrameState(const ImmutableExecutorState& immutable_state,
+ int parallel_iters)
+ : immutable_state(immutable_state),
+ max_parallel_iterations(parallel_iters),
+ num_outstanding_iterations(1),
+ iterations(parallel_iters + 1),
+ iterations_raw(iterations.data()) {}
+
+ // A new frame is created for each loop. Execution starts at iteration 0.
+ // When a value at iteration 0 passes through a NextIteration node,
+ // iteration 1 is created and starts running. Note that iteration 0 may
+ // still be running so multiple iterations may run in parallel. The
+ // frame maintains the state of iterations in several data structures
+ // such as pending_count and input_tensors. When iteration 0 completes,
+ // we garbage collect the state of iteration 0.
+ //
+ // A frame instance is considered "done" and can be garbage collected
+ // if all its inputs have entered and all its iterations are "done".
+ //
+ // A frame manages the live iterations of an iterative computation.
+ // Iteration i is considered "done" when there are no outstanding ops,
+ // frames at iteration i are done, all recvs for this iteration are
+ // completed, and iteration i-1 is done. For iteration 0, we instead
+ // wait for there to be no more pending inputs of the frame.
+ //
+ // Frames and iterations are garbage collected once they are done.
+ // The state we need to keep around is highly dependent on the
+ // parallelism enabled by the scheduler. We may want to have the
+ // scheduler dynamically control the outstanding number of live
+ // parallel frames and iterations. To reduce the state space, the
+ // scheduler might want to schedule ops in inner frames first and
+ // lower iterations first.
+ //
+ // This frame state is mostly initialized lazily on demand so we
+ // don't introduce unnecessary overhead.
+
+ // The immutable state of the executor the frame is in.
+ const ImmutableExecutorState& immutable_state;
+
+ // The name of this frame, which is the concatenation of its parent
+ // frame name, the iteration of the parent frame when this frame was
+ // created, and the value of the attr 'frame_name'.
+ string frame_name;
+
+ // The unique id for this frame. Generated by fingerprinting
+ // frame_name.
+ uint64 frame_id;
+
+ // The iteration id of its parent frame when this frame is created.
+ // -1 if there is no parent frame. The frame_name/parent_iter pair
+ // uniquely identifies this FrameState.
+ int64 parent_iter = -1;
+
+ // The FrameState of its parent frame.
+ FrameState* parent_frame = nullptr;
+
+ // The maximum allowed number of parallel iterations.
+ const int max_parallel_iterations;
+
+ // The number of inputs this frame is still waiting.
+ int num_pending_inputs = 0;
+
+ // The highest iteration number we have reached so far in this frame.
+ int64 iteration_count TF_GUARDED_BY(mu) = 0;
+
+ // The number of outstanding iterations.
+ int num_outstanding_iterations TF_GUARDED_BY(mu) = 1;
+
+ private:
+ // The active iteration states of this frame.
+ gtl::InlinedVector<IterationState*, 12> iterations;
+ IterationState** const iterations_raw TF_GUARDED_BY(mu);
+ IterationState* iterations_first TF_GUARDED_BY(mu);
+
+ public:
+ // The NextIteration nodes to enter a new iteration. If the number of
+ // outstanding iterations reaches the limit, we will defer the start of
+ // the next iteration until the number of outstanding iterations falls
+ // below the limit.
+ std::vector<std::pair<const NodeItem*, Entry>> next_iter_roots
+ TF_GUARDED_BY(mu);
+
+ // The values of the loop invariants for this loop. They are added into
+ // this list as they "enter" the frame. When a loop invariant enters,
+ // we make it available to all active iterations. When the frame starts
+ // a new iteration, we make all the current loop invariants available
+ // to the new iteration.
+ std::vector<std::pair<const NodeItem*, Entry>> inv_values TF_GUARDED_BY(mu);
+
+ // The list of dead exit node items for the current highest iteration. We
+ // will only "execute" the dead exits of the final iteration.
+ std::vector<const NodeItem*> dead_exits TF_GUARDED_BY(mu);
+
+ // Static information specific to this frame.
+ PendingCounts* pending_counts = nullptr;
+ int total_input_tensors = 0;
+ std::vector<const NodeItem*>* nodes = nullptr;
+
+ // Lock ordering: ExecutorState.mu_ < mu;
+ // during structured traversal: parent_frame->mu < mu.
+ mutex mu;
+
+ void InitializeFrameInfo(const string& enter_name);
+
+ inline IterationState* GetIteration(int64 iter)
+ TF_EXCLUSIVE_LOCKS_REQUIRED(mu) {
+ if (TF_PREDICT_TRUE(iter == 0)) {
+ return iterations_first;
+ } else {
+ size_t index = iter % (max_parallel_iterations + 1);
+ return iterations_raw[index];
+ }
+ }
+
+ void SetIteration(int64 iter, IterationState* state);
+
+ // Decrement the outstanding op count and clean up the iterations in the
+ // frame. Return true iff the execution of the frame is done.
+ bool DecrementOutstandingOps(int64 iter, TaggedNodeSeq* ready);
+
+ // Decrement the outstanding op count and clean up the iterations in the
+ // frame. Return true iff the execution of the frame is done.
+ bool DecrementOutstandingOpsLocked(int64 iter, TaggedNodeSeq* ready);
+
+ // Returns true if the computation in the frame is completed.
+ bool IsFrameDone();
+
+ // Returns true if the iteration of the frame is completed.
+ bool IsIterationDone(int64 iter) TF_EXCLUSIVE_LOCKS_REQUIRED(mu);
+
+ // Increments the iteration id. If this is a new iteration, initialize it.
+ void IncrementIteration(TaggedNodeSeq* ready)
+ TF_EXCLUSIVE_LOCKS_REQUIRED(mu);
+
+ // Activate all the deferred NextIteration nodes in a new iteration.
+ void ActivateNexts(int64 iter, TaggedNodeSeq* ready)
+ TF_EXCLUSIVE_LOCKS_REQUIRED(mu);
+
+ // Activate all the current loop invariants in a new iteration.
+ void ActivateLoopInvs(int64 iter, TaggedNodeSeq* ready)
+ TF_EXCLUSIVE_LOCKS_REQUIRED(mu);
+
+ // Add a new loop invariant and make it available to all active
+ // iterations.
+ void AddLoopInv(const NodeItem* item, const Entry& entry,
+ TaggedNodeSeq* ready) TF_EXCLUSIVE_LOCKS_REQUIRED(mu);
+
+ // Activate the successors of a node. Contents of *outputs are left in an
+ // indeterminate state after returning from this method.
+ void ActivateNodes(const NodeItem* item, const bool is_dead, int64 iter,
+ EntryVector* outputs, TaggedNodeSeq* ready)
+ TF_EXCLUSIVE_LOCKS_REQUIRED(mu);
+
+ // Cleanup iterations of this frame starting from iteration iter.
+ bool CleanupIterations(int64 iter, TaggedNodeSeq* ready)
+ TF_EXCLUSIVE_LOCKS_REQUIRED(mu);
+
+ void DumpIterationState(PropagatorState* parent) {
+ mutex_lock l(mu);
+ for (IterationState* iteration : iterations) {
+ if (iteration) {
+ LOG(WARNING) << " Iteration:";
+ parent->DumpIterationState(this, iteration);
+ }
+ }
+ }
+
+ ~FrameState() {
+ for (size_t i = 0; i < iterations.size(); ++i) {
+ delete iterations[i];
+ iterations[i] = nullptr;
+ }
+ }
+
+ private:
+ // REQUIRES: `!item->is_any_consumer_merge_or_control_trigger`.
+ void ActivateNodesFastPath(const NodeItem* item, const bool is_dead,
+ int64 iter, EntryVector* outputs,
+ TaggedNodeSeq* ready)
+ TF_EXCLUSIVE_LOCKS_REQUIRED(mu);
+
+ void ActivateNodesSlowPath(const NodeItem* item, const bool is_dead,
+ int64 iter, EntryVector* outputs,
+ TaggedNodeSeq* ready)
+ TF_EXCLUSIVE_LOCKS_REQUIRED(mu);
+ };
+
+ public:
+ // Creates and adds a `TaggedNode` for each node in `roots` to `*ready`.
+ void ActivateRoots(gtl::ArraySlice<const NodeItem*> roots,
+ TaggedNodeSeq* ready);
+
+ // After processing the outputs, propagates the outputs to their dsts.
+ // Contents of *outputs are left in an indeterminate state after
+ // returning from this method.
+ void PropagateOutputs(const TaggedNode& tagged_node, EntryVector* outputs,
+ TaggedNodeSeq* ready);
+
+ // Returns an array of `Entry` objects corresponding to the inputs of
+ // `tagged_node`.
+ //
+ // NOTE: Thread safety analysis is disabled on this method, because the
+ // underlying `IterationState` and its array of `input_tensors` retain the
+ // same address while the iteration is live.
+ Entry* GetInputTensors(const TaggedNode& tagged_node) const
+ TF_NO_THREAD_SAFETY_ANALYSIS {
+ return tagged_node.input_frame->GetIteration(tagged_node.input_iter)
+ ->input_tensors +
+ tagged_node.node_item->input_start;
+ }
+
+ FrameAndIter GetFrameAndIter(const TaggedNode& tagged_node) const {
+ return {tagged_node.input_frame->frame_id, tagged_node.input_iter};
+ }
+
+ // Provide debugging output of the state of the executor.
+ void DumpState();
+
+ // For debugging/logging only.
+ void MaybeMarkStarted(const TaggedNode& tagged_node) {
+ // TODO(misard) Replace with a finer-grain enabling flag once we add better
+ // optional debugging support.
+ if (TF_PREDICT_FALSE(vlog_) && VLOG_IS_ON(1)) {
+ mutex_lock l(tagged_node.input_frame->mu);
+ tagged_node.input_frame->GetIteration(tagged_node.input_iter)
+ ->mark_started(
+ immutable_state_.pending_ids()[tagged_node.node_item->node_id]);
+ }
+ }
+
+ void MaybeMarkCompleted(const TaggedNode& tagged_node) {
+ // TODO(misard) Replace with a finer-grain enabling flag once we add better
+ // optional debugging support.
+ if (TF_PREDICT_FALSE(vlog_) && VLOG_IS_ON(1)) {
+ mutex_lock l(tagged_node.input_frame->mu);
+ tagged_node.input_frame->GetIteration(tagged_node.input_iter)
+ ->mark_completed(
+ immutable_state_.pending_ids()[tagged_node.node_item->node_id]);
+ }
+ }
+
+ private:
+ // Find an existing or create a new child frame in the frame 'frame' at
+ // iteration 'iter'.
+ void FindOrCreateChildFrame(FrameState* frame, int64 iter,
+ const NodeItem& node_item, FrameState** child);
+
+ // Delete a frame. Called when the frame is done.
+ void DeleteFrame(FrameState* frame, TaggedNodeSeq* ready);
+
+ // Cleanup frames and iterations starting from frame/iter. Called when
+ // a child frame is done.
+ void CleanupFramesIterations(FrameState* frame, int64 iter,
+ TaggedNodeSeq* ready);
+
+ // Provide debugging output about an outstanding node in the executor.
+ void DumpPendingNodeState(const int node_id, const Entry* input_vector,
+ bool show_nodes_with_no_ready_inputs);
+ void DumpActiveNodeState(const int node_id, const Entry* input_vector);
+
+ // Provide debugging output about an outstanding iteration in the executor.
+ void DumpIterationState(const FrameState* frame, IterationState* iteration);
+
+ const Tensor* GetTensorValueForDump(const Entry& input);
+
+ const ImmutableExecutorState& immutable_state_;
+ const int64 step_id_;
+ const bool vlog_;
+
+ mutex mu_;
+
+ // A flag that is set on error after the frame state has been
+ // dumped for diagnostic purposes.
+ bool dumped_on_error_ TF_GUARDED_BY(mu_) = false;
+
+ // The root frame in which the execution of this step is started.
+ FrameState* root_frame_;
+
+ // Mapping from frame name to outstanding frames. A new frame is created
+ // at some iteration of an active frame. So the unique key for the new
+ // child frame is composed of the name of the parent frame, the iteration
+ // number at which the parent frame is creating the new frame, and the
+ // name of the new frame from nodedef.
+ gtl::FlatMap<string, FrameState*> outstanding_frames_ TF_GUARDED_BY(mu_);
+
+ TF_DISALLOW_COPY_AND_ASSIGN(PropagatorState);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_PROPAGATOR_STATE_H_