Avoid heap allocations for sync execution
The heap allocations of ExecuteNode and NodeItem were a clear drag
in profiles.
PiperOrigin-RevId: 282070222
Change-Id: Ifbd02722735574e1de2ad4a946099f50b9f27342
diff --git a/tensorflow/core/common_runtime/eager/eager_executor.cc b/tensorflow/core/common_runtime/eager/eager_executor.cc
index 3b4c791..8cec323 100644
--- a/tensorflow/core/common_runtime/eager/eager_executor.cc
+++ b/tensorflow/core/common_runtime/eager/eager_executor.cc
@@ -84,6 +84,36 @@
}
}
+Status EagerExecutor::SyncExecute(EagerNode* node) {
+ if (Async()) {
+ return errors::Internal("Executor does not support sync execution");
+ }
+ if (node->AsAsync() != nullptr) {
+ return errors::Internal("Executor does not support executing async nodes");
+ }
+ Status s = status();
+ if (!s.ok()) {
+ return s;
+ }
+
+ uint64 id = next_node_id_++;
+
+ s = node->Prepare();
+ if (!s.ok()) {
+ return s;
+ }
+
+ // Run the node inline on the caller's stack (no queue, no heap NodeItem); an id was still claimed above so NotifyWaiters can wake WaitForAllPendingNodes-style waiters.
+ s = node->Run();
+ tensorflow::mutex_lock l(node_queue_mutex_);
+ if (!s.ok()) {
+ status_ = s;
+ ok_ = false;
+ }
+ NotifyWaiters(id);
+ return s;
+}
+
Status EagerExecutor::AddOrExecute(std::unique_ptr<EagerNode> node) {
Status status;
core::RefCountPtr<NodeItem> item(new NodeItem);
@@ -220,30 +250,8 @@
}
unfinished_nodes_.clear();
}
- if (!node_done_notifications_.empty() && need_notification) {
- uint64 upperbound_id = 0;
- if (!unfinished_nodes_.empty()) {
- upperbound_id = unfinished_nodes_.begin()->first - 1;
- } else if (!node_queue_.empty()) {
- upperbound_id = node_queue_.front()->id - 1;
- } else {
- upperbound_id = next_node_id_ - 1;
- }
- DVLOG(3) << "Notify node done: [id " << item->id << " to "
- << upperbound_id << "] ";
- // Note that we notify all waiting threads in case an error has
- // occurred. These calling threads are responsible for checking status_
- // before proceeding.
- const auto range =
- status_.ok()
- ? make_pair(node_done_notifications_.lower_bound(item->id),
- node_done_notifications_.upper_bound(upperbound_id))
- : make_pair(node_done_notifications_.begin(),
- node_done_notifications_.end());
- for (auto it = range.first; it != range.second; ++it) {
- it->second->notify_all();
- }
- node_done_notifications_.erase(range.first, range.second);
+ if (need_notification) {
+ NotifyWaiters(item->id);
}
}
for (auto& item : items_to_destroy) {
@@ -255,6 +263,34 @@
// a deadlock.
}
+void EagerExecutor::NotifyWaiters(uint64 id) {
+ if (!node_done_notifications_.empty()) {
+ uint64 upperbound_id = 0;
+ if (!unfinished_nodes_.empty()) {
+ upperbound_id = unfinished_nodes_.begin()->first - 1;
+ } else if (!node_queue_.empty()) {
+ upperbound_id = node_queue_.front()->id - 1;
+ } else {
+ upperbound_id = next_node_id_ - 1;
+ }
+ DVLOG(3) << "Notify node done: [id " << id << " to " << upperbound_id
+ << "] ";
+ // On error (!status_.ok()) we wake every registered waiter, not just the
+ // range [id, upperbound_id], so blocked threads can observe the failure;
+ // each woken thread is responsible for re-checking status_ before proceeding.
+ const auto range =
+ status_.ok()
+ ? make_pair(node_done_notifications_.lower_bound(id),
+ node_done_notifications_.upper_bound(upperbound_id))
+ : make_pair(node_done_notifications_.begin(),
+ node_done_notifications_.end());
+ for (auto it = range.first; it != range.second; ++it) {
+ it->second->notify_all();
+ }
+ node_done_notifications_.erase(range.first, range.second);
+ }
+}
+
void EagerExecutor::Run() {
auto thread_exited_notifier =
gtl::MakeCleanup([this] { thread_exited_notification_.Notify(); });
diff --git a/tensorflow/core/common_runtime/eager/eager_executor.h b/tensorflow/core/common_runtime/eager/eager_executor.h
index 97e9c44..fb65ec02 100644
--- a/tensorflow/core/common_runtime/eager/eager_executor.h
+++ b/tensorflow/core/common_runtime/eager/eager_executor.h
@@ -106,6 +106,9 @@
bool Async() const;
+ // Inline execute node if executor is in sync mode.
+ Status SyncExecute(EagerNode* node);
+
// - Async Mode: schedules `node` for execution.
// - Sync Mode: inline execute the 'node' directly.
// If an error occurs (e.g. EagerExecutor has already been shut down), the
@@ -165,6 +168,7 @@
const char* StateStringLocked() EXCLUSIVE_LOCKS_REQUIRED(node_queue_mutex_);
void NodeDone(const core::RefCountPtr<NodeItem>& item, const Status& status);
+ void NotifyWaiters(uint64 id) EXCLUSIVE_LOCKS_REQUIRED(node_queue_mutex_);
// Starts execution of pending EagerNodes. This function loops till
// thread_done_ is set to true. If any errors are encontered, these are set
diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc
index 0d5dd4e..923c7e1 100644
--- a/tensorflow/core/common_runtime/eager/execute.cc
+++ b/tensorflow/core/common_runtime/eager/execute.cc
@@ -659,15 +659,24 @@
output_dtypes[i], ctx, &retvals[i]));
}
- std::unique_ptr<EagerNode> node(new ExecuteNode(
- ctx, op->Inputs(), op->remote_func_params(), std::move(kernel),
- graph_collector, output_dtypes, op->GetCancellationManager(),
- executor.Async(), {retvals, num_outputs}));
- // Note that for async mode, execution order will make sure that all
- // input handles are ready before executing them.
- // TODO(b/137118203): Consider executing "cheap" kernels inline for
- // performance.
- Status s = executor.AddOrExecute(std::move(node));
+ Status s;
+ if (executor.Async()) {
+ auto node = absl::make_unique<ExecuteNode>(
+ ctx, op->Inputs(), op->remote_func_params(), std::move(kernel),
+ graph_collector, output_dtypes, op->GetCancellationManager(),
+ executor.Async(), absl::Span<TensorHandle*>(retvals, num_outputs));
+ // For async mode, execution order will make sure that all
+ // input handles are ready before executing them.
+ // TODO(b/137118203): Consider executing "cheap" kernels inline for
+ // performance.
+ s = executor.AddOrExecute(std::move(node));
+ } else {
+ ExecuteNode node(ctx, op->Inputs(), op->remote_func_params(),
+ std::move(kernel), graph_collector, output_dtypes,
+ op->GetCancellationManager(), executor.Async(),
+ {retvals, num_outputs});
+ s = executor.SyncExecute(&node);
+ }
// Since the operation failed, we need to Unref any outputs that were
// allocated.
if (!s.ok()) {