Fix flaky test: third_party/tensorflow/c/eager:c_api_test
1. Race condition in Execute_MatMul_CPU_Runtime_Error: the executor may not have been poisoned yet when matmul2 is scheduled, so wait for all pending nodes (and consume the expected error) before issuing matmul2.
2. Deadlock: avoid holding remote_state_mu_ while cleaning up the EagerClientCache.
3. Possible use-after-delete in the streaming enqueue handler and in remote_copy_node.
PiperOrigin-RevId: 267656988
diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc
index d3b755f..6702e26 100644
--- a/tensorflow/c/eager/c_api_test.cc
+++ b/tensorflow/c/eager/c_api_test.cc
@@ -1069,10 +1069,13 @@
// still fail.
TF_SetStatus(status, TF_OK, "");
TFE_DeleteTensorHandle(retvals[0]);
+ TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
+ TFE_ExecutorWaitForAllPendingNodes(executor, status);
+ EXPECT_NE(TF_OK, TF_GetCode(status));
+ TF_SetStatus(status, TF_OK, "");
retvals[0] = nullptr;
TFE_Execute(matmul2, &retvals[0], &num_retvals, status);
EXPECT_NE(TF_OK, TF_GetCode(status));
- TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
TFE_ExecutorClearError(executor);
TFE_ExecutorWaitForAllPendingNodes(executor, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
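Note on fix 1: the error from the failing matmul only poisons the executor once its node is actually processed on the executor thread, so the test has to block on TFE_ExecutorWaitForAllPendingNodes (and consume the expected error status) before it can assert that scheduling matmul2 fails. The toy executor below is a minimal standalone sketch of that ordering constraint, not TensorFlow's EagerExecutor:

    #include <cassert>
    #include <condition_variable>
    #include <deque>
    #include <functional>
    #include <mutex>
    #include <thread>

    class ToyExecutor {
     public:
      ToyExecutor() : worker_([this] { Run(); }) {}
      ~ToyExecutor() {
        {
          std::lock_guard<std::mutex> l(mu_);
          shutdown_ = true;
        }
        cv_.notify_all();
        worker_.join();
      }

      // Analogue of EagerExecutor::Add(): rejects new nodes once poisoned.
      bool Add(std::function<bool()> node) {
        std::lock_guard<std::mutex> l(mu_);
        if (poisoned_) return false;
        queue_.push_back(std::move(node));
        ++pending_;
        cv_.notify_all();
        return true;
      }

      // Analogue of TFE_ExecutorWaitForAllPendingNodes(): blocks until every
      // scheduled node has run, then reports whether any of them failed.
      bool WaitForAllPendingNodes() {
        std::unique_lock<std::mutex> l(mu_);
        cv_.wait(l, [this] { return pending_ == 0; });
        return !poisoned_;
      }

     private:
      void Run() {
        std::unique_lock<std::mutex> l(mu_);
        while (true) {
          cv_.wait(l, [this] { return shutdown_ || !queue_.empty(); });
          if (shutdown_) return;
          auto node = std::move(queue_.front());
          queue_.pop_front();
          l.unlock();
          const bool ok = node();  // Run the node outside the lock.
          l.lock();
          if (!ok) poisoned_ = true;  // A failed node poisons the executor.
          --pending_;
          cv_.notify_all();
        }
      }

      std::mutex mu_;
      std::condition_variable cv_;
      std::deque<std::function<bool()>> queue_;
      int pending_ = 0;
      bool poisoned_ = false;
      bool shutdown_ = false;
      std::thread worker_;
    };

    int main() {
      ToyExecutor executor;
      executor.Add([] { return false; });  // "matmul1" fails asynchronously.
      // Without this barrier, the next Add() may race ahead of the failure
      // and still be accepted.
      assert(!executor.WaitForAllPendingNodes());
      assert(!executor.Add([] { return true; }));  // "matmul2" is reliably rejected.
    }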
diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc
index a989de3..bef4f8f 100644
--- a/tensorflow/core/common_runtime/eager/context.cc
+++ b/tensorflow/core/common_runtime/eager/context.cc
@@ -263,16 +263,18 @@
CloseRemoteContexts();
}
- mutex_lock l(remote_state_mu_);
-
- default_executor_.ShutDown().IgnoreError();
- std::unordered_map<std::thread::id, EagerExecutor*> executors_copy;
{
- mutex_lock l(executor_map_mu_);
- executors_copy = thread_local_executor_;
- }
- for (const auto& it : executors_copy) {
- it.second->ShutDown().IgnoreError();
+ mutex_lock l(remote_state_mu_);
+
+ default_executor_.ShutDown().IgnoreError();
+ std::unordered_map<std::thread::id, EagerExecutor*> executors_copy;
+ {
+ mutex_lock l(executor_map_mu_);
+ executors_copy = thread_local_executor_;
+ }
+ for (const auto& it : executors_copy) {
+ it.second->ShutDown().IgnoreError();
+ }
}
// This shuts down the completion queue and joins the thread polling it.
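Note on fix 2: tensorflow::mutex_lock is a scoped RAII lock, so declared at function scope it stayed held through the rest of the destructor, including the completion-queue shutdown below, and any thread joined there that also takes remote_state_mu_ (the EagerClientCache cleanup path) could deadlock against it. Wrapping the lock in its own brace block releases it first. A minimal sketch of the pattern with std::lock_guard and hypothetical names:

    #include <mutex>
    #include <thread>

    std::mutex state_mu;  // Stand-in for remote_state_mu_.

    void PollThread() {
      std::lock_guard<std::mutex> l(state_mu);  // Cleanup path takes the lock too.
      // ... drain the completion queue ...
    }

    void ShutdownBad(std::thread& poller) {
      std::lock_guard<std::mutex> l(state_mu);  // Held until the function returns.
      // ... shut down executors ...
      poller.join();  // Deadlocks if the poller is blocked on state_mu.
    }

    void ShutdownGood(std::thread& poller) {
      {
        std::lock_guard<std::mutex> l(state_mu);  // Scope limited to this block.
        // ... shut down executors ...
      }  // state_mu released here.
      poller.join();  // Safe: the poller can now acquire state_mu and finish.
    }

    int main() {
      std::thread poller(PollThread);
      ShutdownGood(poller);  // Only the fixed path is exercised here.
    }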
diff --git a/tensorflow/core/common_runtime/eager/eager_executor.cc b/tensorflow/core/common_runtime/eager/eager_executor.cc
index 72f148c..a009171 100644
--- a/tensorflow/core/common_runtime/eager/eager_executor.cc
+++ b/tensorflow/core/common_runtime/eager/eager_executor.cc
@@ -121,6 +121,8 @@
// try to call EagerExecutor::Add()
{
tensorflow::mutex_lock l(node_queue_mutex_);
+ VLOG(3) << "Add node [id " << next_node_id_ << "]" << node->DebugString()
+ << " with status: " << status_.ToString();
if (state_ != ExecutorState::kActive) {
status = errors::FailedPrecondition(
"EagerExecutor accepts new EagerNodes to run only in Active state. "
@@ -190,7 +192,8 @@
}
void EagerExecutor::NodeDone(NodeItem* item, const Status& status) {
- VLOG(3) << "Node Done: [id " << item->id << "] " << item->node->DebugString();
+ VLOG(3) << "Node Done: [id " << item->id << "] " << item->node->DebugString()
+ << " with status: " << status.ToString();
std::unique_ptr<NodeItem> current_item;
std::vector<std::unique_ptr<NodeItem>> items_to_destroy;
{
diff --git a/tensorflow/core/distributed_runtime/eager/remote_copy_node.cc b/tensorflow/core/distributed_runtime/eager/remote_copy_node.cc
index 8a5f5a3..b458caa 100644
--- a/tensorflow/core/distributed_runtime/eager/remote_copy_node.cc
+++ b/tensorflow/core/distributed_runtime/eager/remote_copy_node.cc
@@ -341,10 +341,11 @@
}
StartSend();
+ const std::shared_ptr<CapturedSharedState>& captured_state = captured_state_;
auto done_wrapper = std::bind(
- [this](const StatusCallback& done, const Status& s) {
+ [captured_state](const StatusCallback& done, const Status& s) {
if (!s.ok() && errors::IsCancelled(s)) {
- Status send_status = captured_state_->GetSendStatus();
+ Status send_status = captured_state->GetSendStatus();
if (!send_status.ok()) {
// In this case, Recv is cancelled because the Send op failed.
// Return the status of the Send op instead.
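Note on fix 3a: the Recv done-callback can outlive the RemoteCopyNode, so reading captured_state_ through the captured `this` was a potential use-after-delete; binding a by-value copy of the shared_ptr keeps the state alive for as long as the callback does. A minimal standalone sketch of the pattern, with illustrative names rather than TensorFlow's:

    #include <functional>
    #include <iostream>
    #include <memory>
    #include <string>

    struct SharedState {
      std::string send_status = "cancelled by failed send";
    };

    class CopyNode {
     public:
      CopyNode() : state_(std::make_shared<SharedState>()) {}

      std::function<void()> MakeDoneCallback() {
        // BAD:  [this] { ... state_->send_status ... }  -- dangles if *this dies.
        // GOOD: copy the shared_ptr; the state outlives the node if needed.
        std::shared_ptr<SharedState> state = state_;
        return [state] { std::cout << state->send_status << "\n"; };
      }

     private:
      std::shared_ptr<SharedState> state_;
    };

    int main() {
      std::function<void()> done;
      {
        CopyNode node;
        done = node.MakeDoneCallback();
      }        // node destroyed here...
      done();  // ...but the callback still safely reads the shared state.
    }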
diff --git a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h
index ae94770..29328df 100644
--- a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h
+++ b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h
@@ -84,7 +84,13 @@
// executor queue.
void StreamingEnqueueHandler(
StreamingCall<EnqueueRequest, EnqueueResponse>* call) {
+ call->Ref();
enqueue_streaming_thread_.Schedule([this, call]() {
+ if (call->RefCountIsOne()) {
+ // This StreamingCall has already been shut down. Nothing more to do.
+ call->Unref();
+ return;
+ }
// NOTE(fishx): Use the address of StreamingCall as the stream_id since we
// reuse the same StreamingCall for multiple requests in the same
// streaming connection.
@@ -100,6 +106,7 @@
<< " on request " << call->request().DebugString();
call->Finish(ToGrpcStatus(status));
}
+ call->Unref();
// We do not tell gRPC to accept a new StreamingEnqueue request because
// this method can be called multiple times for a given streaming call.
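Note on fix 3b: the extra Ref() taken before scheduling keeps the call alive across the thread hop and doubles as a shutdown probe: if the deferred closure observes RefCountIsOne(), every other owner has already dropped the call, so the closure only releases its own reference and returns. A minimal sketch of that pattern, using a stand-in RefCounted class rather than TensorFlow's core::RefCounted:

    #include <atomic>
    #include <functional>

    class RefCounted {
     public:
      void Ref() { count_.fetch_add(1, std::memory_order_relaxed); }
      void Unref() {
        if (count_.fetch_sub(1, std::memory_order_acq_rel) == 1) delete this;
      }
      bool RefCountIsOne() const {
        return count_.load(std::memory_order_acquire) == 1;
      }

     protected:
      virtual ~RefCounted() = default;

     private:
      std::atomic<int> count_{1};  // The creator owns the initial reference.
    };

    struct StreamingCall : RefCounted {
      bool handled = false;
    };

    // Mirrors StreamingEnqueueHandler: ref before scheduling, re-check the
    // count on the worker thread before touching the call.
    std::function<void()> MakeTask(StreamingCall* call) {
      call->Ref();  // Keep the call alive across the thread hop.
      return [call] {
        if (call->RefCountIsOne()) {
          call->Unref();  // Stream already shut down; nothing left to do.
          return;
        }
        call->handled = true;  // Safe: another owner still holds a reference.
        call->Unref();
      };
    }

    int main() {
      auto* call = new StreamingCall;
      auto task = MakeTask(call);
      call->Unref();  // Simulate stream shutdown before the task runs.
      task();         // Detects RefCountIsOne() and releases the last reference.
    }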