Parallel device: avoid deadlocks when the EagerContext's default executor is async
The parallel device now creates one sync executor per device thread instead of sharing the context's (possibly async) default executor.
This requires fixing a tangential use-after-free: the EagerContext assumed all of its thread-local executors were still allocated at shutdown.
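As a quick illustration of the approach, a minimal sketch of the per-thread pattern this change adopts (`ctx` is an assumed, already-created TFE_Context and the op launch is elided; only C API calls exercised in the diff below are used):

  // Each device thread installs its own sync executor instead of inheriting
  // the context's default (possibly async) executor.
  TFE_Executor* executor = TFE_NewExecutor(/*is_async=*/false);
  TFE_ContextSetExecutorForThread(ctx, executor);
  // ... build and launch this thread's per-replica op ...
  // With the cleanup hooks added in this change, the executor and the context
  // may now be destroyed in either order.
  TFE_DeleteExecutor(executor);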
PiperOrigin-RevId: 316783819
Change-Id: I62e7a91dcccb847d4e1c2a5f08e30c2877556618
diff --git a/tensorflow/c/eager/c_api_experimental_test.cc b/tensorflow/c/eager/c_api_experimental_test.cc
index 0c05839..a4d3141 100644
--- a/tensorflow/c/eager/c_api_experimental_test.cc
+++ b/tensorflow/c/eager/c_api_experimental_test.cc
@@ -212,6 +212,35 @@
TFE_DeleteCancellationManager(c_mgr);
}
+TEST(CAPI, ExecutorContextDestructionOrder) {
+ TF_Status* status = TF_NewStatus();
+
+ {
+ TFE_ContextOptions* opts = TFE_NewContextOptions();
+ TFE_Context* ctx = TFE_NewContext(opts, status);
+ ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+ TFE_DeleteContextOptions(opts);
+ TFE_Executor* executor = TFE_NewExecutor(/*is_async=*/false);
+ TFE_ContextSetExecutorForThread(ctx, executor);
+
+ TFE_DeleteContext(ctx);
+ TFE_DeleteExecutor(executor);
+ }
+
+ {
+ TFE_ContextOptions* opts = TFE_NewContextOptions();
+ TFE_Context* ctx = TFE_NewContext(opts, status);
+ ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+ TFE_DeleteContextOptions(opts);
+ TFE_Executor* executor = TFE_NewExecutor(/*is_async=*/false);
+ TFE_ContextSetExecutorForThread(ctx, executor);
+
+ TFE_DeleteExecutor(executor);
+ TFE_DeleteContext(ctx);
+ }
+ TF_DeleteStatus(status);
+}
+
TEST(CAPI, Function_ident_CPU) {
// First create a simple identity function.
TF_Graph* function_graph = TF_NewGraph();
diff --git a/tensorflow/c/eager/parallel_device/parallel_device_lib.cc b/tensorflow/c/eager/parallel_device/parallel_device_lib.cc
index 98cd481..d0149b2 100644
--- a/tensorflow/c/eager/parallel_device/parallel_device_lib.cc
+++ b/tensorflow/c/eager/parallel_device/parallel_device_lib.cc
@@ -37,6 +37,15 @@
using StatusPtr = std::unique_ptr<TF_Status, StatusDeleter>;
+class ExecutorDeleter {
+ public:
+ void operator()(TFE_Executor* to_delete) const {
+ TFE_DeleteExecutor(to_delete);
+ }
+};
+
+using ExecutorPtr = std::unique_ptr<TFE_Executor, ExecutorDeleter>;
+
} // namespace
// Allows a single op at a time to be launched without blocking.
@@ -51,6 +60,13 @@
explicit DeviceThread(const std::string& device)
: status_(TF_NewStatus()),
device_(device),
+ // If the context's default executor is set to async, reusing it in
+ // each thread would cause collectives to deadlock. For consistency we
+ // create a new sync executor for every thread.
+ //
+ // TODO(allenl): We should have an async API that works with the
+ // parallel device.
+ executor_(TFE_NewExecutor(/*is_async=*/false)),
op_(nullptr),
thread_(tensorflow::Env::Default()->StartThread(
tensorflow::ThreadOptions(), "parallel_device_execute",
@@ -105,6 +121,7 @@
StatusPtr status_ TF_GUARDED_BY(execution_mutex_);
const std::string device_;
+ ExecutorPtr executor_ TF_GUARDED_BY(execution_mutex_);
mutable OpPtr op_ TF_GUARDED_BY(execution_mutex_);
std::unique_ptr<Thread> thread_;
};
@@ -186,6 +203,7 @@
std::vector<TensorHandlePtr>* outputs,
TF_Status* status) const {
if (op_ == nullptr) {
+ TFE_ContextSetExecutorForThread(context, executor_.get());
op_.reset(TFE_NewOp(context, operation_name, status));
if (TF_GetCode(status) != TF_OK) return;
TFE_OpSetDevice(op_.get(), device_.c_str(), status);
diff --git a/tensorflow/c/eager/parallel_device/parallel_device_test.cc b/tensorflow/c/eager/parallel_device/parallel_device_test.cc
index e5412db..2fa183d 100644
--- a/tensorflow/c/eager/parallel_device/parallel_device_test.cc
+++ b/tensorflow/c/eager/parallel_device/parallel_device_test.cc
@@ -412,6 +412,7 @@
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
TFE_NewContextOptions(), TFE_DeleteContextOptions);
+ TFE_ContextOptionsSetAsync(opts.get(), async);
std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
TF_CreateConfig(
/*xla*/ false,
@@ -423,9 +424,6 @@
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
- std::unique_ptr<TFE_Executor, decltype(&TFE_DeleteExecutor)> executor(
- TFE_NewExecutor(async), TFE_DeleteExecutor);
- TFE_ContextSetExecutorForThread(context.get(), executor.get());
const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
std::array<const char*, 2> underlying_devices{
@@ -455,8 +453,6 @@
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ExpectScalarEq<float>(result_components[0].get(), 3.);
ExpectScalarEq<float>(result_components[1].get(), 3.);
- // Destroying the context's default executor first isn't safe.
- context.reset();
}
TEST(PARALLEL_DEVICE, TestCollectiveSync) { TestCollective(/*async=*/false); }
diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc
index 5d8cb3d..970c2bc 100644
--- a/tensorflow/core/common_runtime/eager/context.cc
+++ b/tensorflow/core/common_runtime/eager/context.cc
@@ -341,7 +341,28 @@
if (executor == &default_executor_) {
thread_local_executor_.erase(std::this_thread::get_id());
} else {
- thread_local_executor_[std::this_thread::get_id()] = executor;
+ auto thread_id = std::this_thread::get_id();
+ thread_local_executor_[thread_id] = executor;
+ auto& executors_with_cleanups = has_cleanup_[thread_id];
+ if (executors_with_cleanups.find(executor) ==
+ executors_with_cleanups.end()) {
+ executors_with_cleanups.insert(executor);
+ // If the executor is deleted before this context, we need to remove it
+ // from the map to avoid attempting to sync it in our destructor.
+ std::function<void()> cleanup([this, thread_id, executor]() {
+ {
+ tensorflow::mutex_lock l(executor_map_mu_);
+ auto existing = thread_local_executor_.find(thread_id);
+ if (existing != thread_local_executor_.end() &&
+ existing->second == executor) {
+ thread_local_executor_.erase(thread_id);
+ }
+ has_cleanup_[thread_id].erase(executor);
+ }
+ });
+ executor->AddCleanup(reinterpret_cast<intptr_t>(this),
+ std::move(cleanup));
+ }
}
}
@@ -525,6 +546,15 @@
custom_devices_.clear();
ClearCachesAndThreadExecutors();
+ std::unordered_map<std::thread::id, EagerExecutor*> executors_copy;
+ {
+ mutex_lock l(executor_map_mu_);
+ executors_copy = thread_local_executor_;
+ }
+ for (const auto& entry : executors_copy) {
+ // Let the executor know that its cleanup closure is no longer valid.
+ entry.second->RemoveCleanups(reinterpret_cast<intptr_t>(this));
+ }
for (auto& entry : registered_functions_) {
while (!entry.second->Unref()) {
// remove all references.
diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h
index fa57afe..cb6d09f 100644
--- a/tensorflow/core/common_runtime/eager/context.h
+++ b/tensorflow/core/common_runtime/eager/context.h
@@ -639,6 +639,8 @@
// Not owned.
std::unordered_map<std::thread::id, EagerExecutor*> thread_local_executor_
TF_GUARDED_BY(executor_map_mu_);
+ std::unordered_map<std::thread::id, std::unordered_set<EagerExecutor*>>
+ has_cleanup_ TF_GUARDED_BY(executor_map_mu_);
const bool log_memory_;
diff --git a/tensorflow/core/common_runtime/eager/eager_executor.cc b/tensorflow/core/common_runtime/eager/eager_executor.cc
index ddfdabf..7fe321e 100644
--- a/tensorflow/core/common_runtime/eager/eager_executor.cc
+++ b/tensorflow/core/common_runtime/eager/eager_executor.cc
@@ -46,6 +46,11 @@
tensorflow::mutex_lock l(node_queue_mutex_);
state_ = ExecutorState::kShutDown;
nodes_pending_.notify_all();
+ for (const auto& cleanups_for_key : cleanups_) {
+ for (const std::function<void()>& cleanup : cleanups_for_key.second) {
+ cleanup();
+ }
+ }
}
Status EagerExecutor::ShutDown() {
@@ -413,4 +418,10 @@
return Status::OK();
}
+void EagerExecutor::AddCleanup(intptr_t key, std::function<void()> callback) {
+ cleanups_[key].push_back(callback);
+}
+
+void EagerExecutor::RemoveCleanups(intptr_t key) { cleanups_.erase(key); }
+
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/eager/eager_executor.h b/tensorflow/core/common_runtime/eager/eager_executor.h
index aa8864c..34847ab 100644
--- a/tensorflow/core/common_runtime/eager/eager_executor.h
+++ b/tensorflow/core/common_runtime/eager/eager_executor.h
@@ -153,6 +153,13 @@
bool ok() const TF_NO_THREAD_SAFETY_ANALYSIS { return ok_; }
+ // On destruction, runs `callback`. Used by the EagerContext for clearing
+ // thread-local executors.
+ void AddCleanup(intptr_t key, std::function<void()> callback);
+ // If `key` (e.g. a context) is destroyed before the executor, the associated
+ // callbacks are no longer safe to run.
+ void RemoveCleanups(intptr_t key);
+
private:
// Possible states for this executor.
// Executor starts in kActive state. When Shutdown() is called, Executor
@@ -250,6 +257,9 @@
const eager::EagerClient* last_eager_client_;
const bool enable_async_wait_for_remote_function_;
+
+ // Callbacks to run on destruction.
+ std::unordered_map<intptr_t, std::vector<std::function<void()>>> cleanups_;
};
inline bool EagerExecutor::Async() const { return thread_ != nullptr; }
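For clarity, a hedged sketch of the cleanup protocol these declarations enable (`ctx` and `executor` are illustrative locals; the real registration and removal happen inside the EagerContext changes above):

  // Context side: register a closure keyed by the context's address. If the
  // executor is destroyed first, ~EagerExecutor runs the closure, letting the
  // context drop its now-dangling thread-local pointer.
  executor->AddCleanup(reinterpret_cast<intptr_t>(ctx),
                       [ctx, executor]() {
                         // erase `executor` from ctx's thread_local_executor_ map
                       });

  // Context-destructor side: the context is going away first, so the closure
  // above is no longer safe and must never run.
  executor->RemoveCleanups(reinterpret_cast<intptr_t>(ctx));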
diff --git a/tensorflow/python/distribute/parallel_device/parallel_device_test.py b/tensorflow/python/distribute/parallel_device/parallel_device_test.py
index 9dbf258..8fc3dcb 100644
--- a/tensorflow/python/distribute/parallel_device/parallel_device_test.py
+++ b/tensorflow/python/distribute/parallel_device/parallel_device_test.py
@@ -23,6 +23,7 @@
from tensorflow.python.distribute.parallel_device import parallel_device
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
+from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.module import module
@@ -136,7 +137,7 @@
self.assertIn(self.device.components[0], outputs[0].backing_device)
self.assertIn(self.device.components[1], outputs[1].backing_device)
- def test_collective_reduce_async(self):
+ def test_collective_reduce_async_scope(self):
# Note that ops on the parallel device currently don't execute
# asynchronously. The test is just that we don't get deadlocks.
with context.async_scope(), ops.device(self.device.name):
@@ -149,6 +150,27 @@
self.assertIn(self.device.components[0], outputs[0].backing_device)
self.assertIn(self.device.components[1], outputs[1].backing_device)
+ def test_collective_reduce_async_context(self):
+ previous = config.get_synchronous_execution()
+ try:
+ context._reset_context()
+ config.set_synchronous_execution(False)
+ self.setUp()
+ # Note that ops on the parallel device currently don't execute
+ # asynchronously. The test is just that we don't get deadlocks.
+ with ops.device(self.device.name):
+ x = self.device.pack(
+ [constant_op.constant(-1.5),
+ constant_op.constant(3.5)])
+ reduced = _collective_sum(x, num_replicas=2)
+ outputs = self.device.unpack(reduced)
+ self.assertAllClose([2., 2.], outputs)
+ self.assertIn(self.device.components[0], outputs[0].backing_device)
+ self.assertIn(self.device.components[1], outputs[1].backing_device)
+ finally:
+ context._reset_context()
+ config.set_synchronous_execution(previous)
+
def test_checkpointing(self):
prefix = os.path.join(self.get_temp_dir(), "ckpt")
with self.device.scope():