Fix access to dangling pointer, which can happen when GetTapeSet() is called
while the current thread is exiting and its thread_local variables are being
deleted.
PiperOrigin-RevId: 428064347
Change-Id: I48ab57603ac4fccf6a7f0b214b9df05e18ea46b7
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index 6df6a20..c4fe0e4 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -826,6 +826,16 @@
static std::atomic<int64_t> _uid;
+// This struct is responsible for marking thread_local storage as destroyed.
+// Access to the `alive` field in already-destroyed ThreadLocalDestructionMarker
+// is safe because it's a trivial type, so long as nobody creates a new
+// thread_local in the space where the now-destroyed marker used to be.
+// Hopefully creating new thread_locals while destructing a thread is rare.
+struct ThreadLocalDestructionMarker {
+ ~ThreadLocalDestructionMarker() { alive = false; }
+ bool alive = true;
+};
+
} // namespace
TF_Status* GetStatus() {
@@ -1714,7 +1724,12 @@
// stack.
tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>* GetTapeSet() {
thread_local std::unique_ptr<tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>>
- tape_set = nullptr;
+ tape_set;
+ thread_local ThreadLocalDestructionMarker marker;
+ if (!marker.alive) {
+ // This thread is being destroyed. It is unsafe to access tape_set.
+ return nullptr;
+ }
if (tape_set == nullptr) {
tape_set.reset(new tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>);
}
@@ -1725,7 +1740,13 @@
GetVariableWatcherSet() {
thread_local std::unique_ptr<
tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>>
- variable_watcher_set = nullptr;
+ variable_watcher_set;
+ thread_local ThreadLocalDestructionMarker marker;
+ if (!marker.alive) {
+ // This thread is being destroyed. It is unsafe to access
+ // variable_watcher_set.
+ return nullptr;
+ }
if (variable_watcher_set == nullptr) {
variable_watcher_set.reset(
new tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>);
@@ -1789,7 +1810,12 @@
};
AccumulatorSet* GetAccumulatorSet() {
- thread_local std::unique_ptr<AccumulatorSet> accumulator_set{nullptr};
+ thread_local std::unique_ptr<AccumulatorSet> accumulator_set;
+ thread_local ThreadLocalDestructionMarker marker;
+ if (!marker.alive) {
+ // This thread is being destroyed. It is unsafe to access accumulator_set.
+ return nullptr;
+ }
if (accumulator_set == nullptr) {
accumulator_set.reset(new AccumulatorSet);
}
@@ -1914,7 +1940,9 @@
void TFE_Py_TapeSetRemove(PyObject* tape) {
auto* stack = GetTapeSet();
- stack->erase(reinterpret_cast<TFE_Py_Tape*>(tape));
+ if (stack != nullptr) {
+ stack->erase(reinterpret_cast<TFE_Py_Tape*>(tape));
+ }
// We kept a reference to the tape in the set to ensure it wouldn't get
// deleted under us; cleaning it up here.
Py_DECREF(tape);
@@ -1998,7 +2026,7 @@
if (!TensorShapesAndDtypes(tensors, &tensor_ids, &dtypes)) {
return nullptr;
}
- auto tape_set = *GetTapeSet();
+ auto& tape_set = *GetTapeSet();
for (TFE_Py_Tape* tape : tape_set) {
if (tape->tape->ShouldRecord(tensor_ids, dtypes)) {
Py_RETURN_TRUE;
@@ -2009,7 +2037,7 @@
}
PyObject* TFE_Py_ForwardAccumulatorPushState() {
- auto forward_accumulators = *GetAccumulatorSet();
+ auto& forward_accumulators = *GetAccumulatorSet();
for (TFE_Py_ForwardAccumulator* accumulator : forward_accumulators) {
accumulator->accumulator->PushState();
}
@@ -2017,7 +2045,7 @@
}
PyObject* TFE_Py_ForwardAccumulatorPopState() {
- auto forward_accumulators = *GetAccumulatorSet();
+ auto& forward_accumulators = *GetAccumulatorSet();
for (TFE_Py_ForwardAccumulator* accumulator : forward_accumulators) {
accumulator->accumulator->PopState();
}
@@ -2039,7 +2067,7 @@
// gradients are possible.
bool some_tape_watching = false;
if (CouldBackprop()) {
- auto tape_set = *GetTapeSet();
+ auto& tape_set = *GetTapeSet();
for (TFE_Py_Tape* tape : tape_set) {
if (tape->tape->ShouldRecord(tensor_ids, dtypes)) {
if (tape->tape->IsPersistent() || some_tape_watching) {
@@ -2052,7 +2080,7 @@
}
}
if (CouldForwardprop()) {
- auto forward_accumulators = *GetAccumulatorSet();
+ auto& forward_accumulators = *GetAccumulatorSet();
for (TFE_Py_ForwardAccumulator* accumulator : forward_accumulators) {
if (accumulator->accumulator->ShouldRecord(tensor_ids, dtypes)) {
if (some_tape_watching) {
@@ -2290,9 +2318,12 @@
PyObject* ForwardAccumulatorDeleteGradient(PyObject* tensor_id,
PyObject* weak_tensor_ref) {
- int64_t parsed_tensor_id = MakeInt(tensor_id);
- for (TFE_Py_ForwardAccumulator* accumulator : *GetAccumulatorSet()) {
- accumulator->accumulator->DeleteGradient(parsed_tensor_id);
+ auto* accumulator_set = GetAccumulatorSet();
+ if (accumulator_set != nullptr) {
+ int64_t parsed_tensor_id = MakeInt(tensor_id);
+ for (TFE_Py_ForwardAccumulator* accumulator : *accumulator_set) {
+ accumulator->accumulator->DeleteGradient(parsed_tensor_id);
+ }
}
Py_DECREF(weak_tensor_ref);
Py_DECREF(tensor_id);
@@ -2723,7 +2754,13 @@
}
void TFE_Py_TapeSetDeleteTrace(int64_t tensor_id) {
- for (TFE_Py_Tape* tape : *GetTapeSet()) {
+ auto* tape_set = GetTapeSet();
+ if (tape_set == nullptr) {
+ // Current thread is being destructed, and the tape set has already
+ // been cleared.
+ return;
+ }
+ for (TFE_Py_Tape* tape : *tape_set) {
tape->tape->DeleteTrace(tensor_id);
}
}
@@ -2884,8 +2921,11 @@
}
void TFE_Py_ForwardAccumulatorSetRemove(PyObject* accumulator) {
- GetAccumulatorSet()->erase(
- reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator));
+ auto* accumulator_set = GetAccumulatorSet();
+ if (accumulator_set != nullptr) {
+ accumulator_set->erase(
+ reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator));
+ }
Py_DECREF(accumulator);
}
@@ -2915,7 +2955,7 @@
tensorflow::Safe_PyObjectPtr empty_list(PyList_New(0));
return PyTuple_Pack(2, empty_tuple.get(), empty_list.get());
}
- auto accumulators = *GetAccumulatorSet();
+ auto& accumulators = *GetAccumulatorSet();
tensorflow::Safe_PyObjectPtr tensors_fast(
PySequence_Fast(tensors, "Expected a sequence of input Tensors."));
if (tensors_fast == nullptr || PyErr_Occurred()) {