Fix access to dangling pointer, which can happen when GetTapeSet() is called
while the current thread is exiting and its thread_local variables are being
deleted.

PiperOrigin-RevId: 428064347
Change-Id: I48ab57603ac4fccf6a7f0b214b9df05e18ea46b7
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index 6df6a20..c4fe0e4 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -826,6 +826,16 @@
 
 static std::atomic<int64_t> _uid;
 
+// This struct is responsible for marking thread_local storage as destroyed.
+// Access to the `alive` field in already-destroyed ThreadLocalDestructionMarker
+// is safe because it's a trivial type, so long as nobody creates a new
+// thread_local in the space where now-destroyed marker used to be.
+// Hopefully creating new thread_locals while destructing a thread is rare.
+struct ThreadLocalDestructionMarker {
+  ~ThreadLocalDestructionMarker() { alive = false; }
+  bool alive = true;
+};
+
 }  // namespace
 
 TF_Status* GetStatus() {
@@ -1714,7 +1724,12 @@
 // stack.
 tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>* GetTapeSet() {
   thread_local std::unique_ptr<tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>>
-      tape_set = nullptr;
+      tape_set;
+  thread_local ThreadLocalDestructionMarker marker;
+  if (!marker.alive) {
+    // This thread is being destroyed. It is unsafe to access tape_set.
+    return nullptr;
+  }
   if (tape_set == nullptr) {
     tape_set.reset(new tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>);
   }
@@ -1725,7 +1740,13 @@
 GetVariableWatcherSet() {
   thread_local std::unique_ptr<
       tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>>
-      variable_watcher_set = nullptr;
+      variable_watcher_set;
+  thread_local ThreadLocalDestructionMarker marker;
+  if (!marker.alive) {
+    // This thread is being destroyed. It is unsafe to access
+    // variable_watcher_set.
+    return nullptr;
+  }
   if (variable_watcher_set == nullptr) {
     variable_watcher_set.reset(
         new tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>);
@@ -1789,7 +1810,12 @@
 };
 
 AccumulatorSet* GetAccumulatorSet() {
-  thread_local std::unique_ptr<AccumulatorSet> accumulator_set{nullptr};
+  thread_local std::unique_ptr<AccumulatorSet> accumulator_set;
+  thread_local ThreadLocalDestructionMarker marker;
+  if (!marker.alive) {
+    // This thread is being destroyed. It is unsafe to access accumulator_set.
+    return nullptr;
+  }
   if (accumulator_set == nullptr) {
     accumulator_set.reset(new AccumulatorSet);
   }
@@ -1914,7 +1940,9 @@
 
 void TFE_Py_TapeSetRemove(PyObject* tape) {
   auto* stack = GetTapeSet();
-  stack->erase(reinterpret_cast<TFE_Py_Tape*>(tape));
+  if (stack != nullptr) {
+    stack->erase(reinterpret_cast<TFE_Py_Tape*>(tape));
+  }
   // We kept a reference to the tape in the set to ensure it wouldn't get
   // deleted under us; cleaning it up here.
   Py_DECREF(tape);
@@ -1998,7 +2026,7 @@
   if (!TensorShapesAndDtypes(tensors, &tensor_ids, &dtypes)) {
     return nullptr;
   }
-  auto tape_set = *GetTapeSet();
+  auto& tape_set = *GetTapeSet();
   for (TFE_Py_Tape* tape : tape_set) {
     if (tape->tape->ShouldRecord(tensor_ids, dtypes)) {
       Py_RETURN_TRUE;
@@ -2009,7 +2037,7 @@
 }
 
 PyObject* TFE_Py_ForwardAccumulatorPushState() {
-  auto forward_accumulators = *GetAccumulatorSet();
+  auto& forward_accumulators = *GetAccumulatorSet();
   for (TFE_Py_ForwardAccumulator* accumulator : forward_accumulators) {
     accumulator->accumulator->PushState();
   }
@@ -2017,7 +2045,7 @@
 }
 
 PyObject* TFE_Py_ForwardAccumulatorPopState() {
-  auto forward_accumulators = *GetAccumulatorSet();
+  auto& forward_accumulators = *GetAccumulatorSet();
   for (TFE_Py_ForwardAccumulator* accumulator : forward_accumulators) {
     accumulator->accumulator->PopState();
   }
@@ -2039,7 +2067,7 @@
   // gradients are possible.
   bool some_tape_watching = false;
   if (CouldBackprop()) {
-    auto tape_set = *GetTapeSet();
+    auto& tape_set = *GetTapeSet();
     for (TFE_Py_Tape* tape : tape_set) {
       if (tape->tape->ShouldRecord(tensor_ids, dtypes)) {
         if (tape->tape->IsPersistent() || some_tape_watching) {
@@ -2052,7 +2080,7 @@
     }
   }
   if (CouldForwardprop()) {
-    auto forward_accumulators = *GetAccumulatorSet();
+    auto& forward_accumulators = *GetAccumulatorSet();
     for (TFE_Py_ForwardAccumulator* accumulator : forward_accumulators) {
       if (accumulator->accumulator->ShouldRecord(tensor_ids, dtypes)) {
         if (some_tape_watching) {
@@ -2290,9 +2318,12 @@
 
 PyObject* ForwardAccumulatorDeleteGradient(PyObject* tensor_id,
                                            PyObject* weak_tensor_ref) {
-  int64_t parsed_tensor_id = MakeInt(tensor_id);
-  for (TFE_Py_ForwardAccumulator* accumulator : *GetAccumulatorSet()) {
-    accumulator->accumulator->DeleteGradient(parsed_tensor_id);
+  auto* accumulator_set = GetAccumulatorSet();
+  if (accumulator_set != nullptr) {
+    int64_t parsed_tensor_id = MakeInt(tensor_id);
+    for (TFE_Py_ForwardAccumulator* accumulator : *accumulator_set) {
+      accumulator->accumulator->DeleteGradient(parsed_tensor_id);
+    }
   }
   Py_DECREF(weak_tensor_ref);
   Py_DECREF(tensor_id);
@@ -2723,7 +2754,13 @@
 }
 
 void TFE_Py_TapeSetDeleteTrace(int64_t tensor_id) {
-  for (TFE_Py_Tape* tape : *GetTapeSet()) {
+  auto* tape_set = GetTapeSet();
+  if (tape_set == nullptr) {
+    // Current thread is being destructed, and the tape set has already
+    // been cleared.
+    return;
+  }
+  for (TFE_Py_Tape* tape : *tape_set) {
     tape->tape->DeleteTrace(tensor_id);
   }
 }
@@ -2884,8 +2921,11 @@
 }
 
 void TFE_Py_ForwardAccumulatorSetRemove(PyObject* accumulator) {
-  GetAccumulatorSet()->erase(
-      reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator));
+  auto* accumulator_set = GetAccumulatorSet();
+  if (accumulator_set != nullptr) {
+    accumulator_set->erase(
+        reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator));
+  }
   Py_DECREF(accumulator);
 }
 
@@ -2915,7 +2955,7 @@
     tensorflow::Safe_PyObjectPtr empty_list(PyList_New(0));
     return PyTuple_Pack(2, empty_tuple.get(), empty_list.get());
   }
-  auto accumulators = *GetAccumulatorSet();
+  auto& accumulators = *GetAccumulatorSet();
   tensorflow::Safe_PyObjectPtr tensors_fast(
       PySequence_Fast(tensors, "Expected a sequence of input Tensors."));
   if (tensors_fast == nullptr || PyErr_Occurred()) {