PSv2: Shut down the worker preemption thread in ClusterCoordinator's destructor.

When ClusterCoordinator's destructor is called, mark the worker preemption thread as finished so that it exits its wait loop and terminates.
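
The mechanism is a standard flag-plus-event handshake: the handler loop blocks on an event, and shutdown sets a flag and then the event, so the blocked thread wakes up and observes the flag. A minimal standalone sketch of the pattern (illustrative names, not the actual coordinator symbols):

    import threading

    class PreemptionHandlerSketch(object):

      def __init__(self):
        self._should_run = True
        self._due_for_update_or_finish = threading.Event()
        threading.Thread(target=self._loop, daemon=True).start()

      def _loop(self):
        while True:
          # Blocks until there is either an update to process or a stop
          # request; both paths set the same event.
          self._due_for_update_or_finish.wait()
          if not self._should_run:
            break
          # ... handle the cluster update, then re-arm the event ...
          self._due_for_update_or_finish.clear()

      def _mark_finished(self):
        self._should_run = False
        # Wake the thread so it can observe the flag and exit.
        self._due_for_update_or_finish.set()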

Note that the Cluster and Worker objects will still not be destroyed, because the threading.Thread object holds a reference to the cluster object.
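
This is easy to reproduce in isolation: a thread whose target is a bound method keeps the owning object alive, so a weakref to that object never dies while the thread exists. A hypothetical repro using only the standard library:

    import gc
    import threading
    import weakref

    class Owner(object):

      def __init__(self):
        self._stop = threading.Event()
        # The bound method self._run gives the Thread object a strong
        # reference back to self.
        threading.Thread(target=self._run, daemon=True).start()

      def _run(self):
        self._stop.wait()

    o = Owner()
    ref = weakref.ref(o)
    del o
    gc.collect()
    print(ref() is None)  # False: the running thread still keeps Owner alive.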

Two follow-up CLs will address Cluster/Worker deletion:
1. Close the worker threads in ClusterCoordinator's destructor.
2. Use a weakref as the target function for threading.Thread objects, so that Cluster/Worker objects can be garbage collected (see the sketch after this list).
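
A sketch of the weakref approach in item 2 (hypothetical names; the actual follow-up CL may differ):

    import threading
    import time
    import weakref

    class WorkerSketch(object):

      def __init__(self):
        # Hand the thread a weak reference rather than a bound method,
        # so the Thread object does not keep the worker alive.
        weak_self = weakref.ref(self)
        threading.Thread(target=WorkerSketch._loop, args=(weak_self,),
                         daemon=True).start()

      @staticmethod
      def _loop(weak_self):
        while True:
          worker = weak_self()  # Re-acquire a strong reference each pass.
          if worker is None:
            return  # The worker was garbage collected; end the thread.
          worker.step()
          del worker  # Drop the strong reference before sleeping.
          time.sleep(0.1)

      def step(self):
        pass  # Placeholder for the real per-iteration work.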

PiperOrigin-RevId: 350639066
Change-Id: I23b54912f9b450fb2ccbe6b1da69d2bf5bc55aa5
diff --git a/tensorflow/python/distribute/coordinator/cluster_coordinator.py b/tensorflow/python/distribute/coordinator/cluster_coordinator.py
index 159fc33..3829a43 100644
--- a/tensorflow/python/distribute/coordinator/cluster_coordinator.py
+++ b/tensorflow/python/distribute/coordinator/cluster_coordinator.py
@@ -613,12 +613,19 @@
     self._server_def = server_def
     self._cluster = cluster
     self._cluster_update_lock = threading.Lock()
-    self._cluster_due_for_update = threading.Event()
+    self._cluster_due_for_update_or_finish = threading.Event()
     self._worker_up_cond = threading.Condition(self._cluster_update_lock)
+    self._should_preemption_thread_run = True
     threading.Thread(target=self._preemption_handler,
                      name="WorkerPreemptionHandler",
                      daemon=True).start()
 
+  def _mark_finished(self):
+    """Ensure the worker preemption thread is closed."""
+    self._should_preemption_thread_run = False
+    with self._cluster_update_lock:
+      self._cluster_due_for_update_or_finish.set()
+
   def _validate_preemption_failure(self, e):
     """Validates that the given exception represents worker preemption."""
     if _is_worker_failure(e):
@@ -658,7 +665,7 @@
         on_failure_fn()
 
       with self._cluster_update_lock:
-        self._cluster_due_for_update.set()
+        self._cluster_due_for_update_or_finish.set()
         self._worker_up_cond.wait(_WORKER_MAXIMUM_RECOVERY_SEC)
         logging.info("Worker %s has been recovered.", worker_device_name)
 
@@ -676,7 +683,10 @@
     restarted workers.
     """
     while True:
-      self._cluster_due_for_update.wait()
+      self._cluster_due_for_update_or_finish.wait()
+      if not self._should_preemption_thread_run:
+        break
+
       with self._cluster_update_lock:
         try:
           # TODO(haoyuzhang): support partial cluster recovery
@@ -687,7 +697,7 @@
           # all workers that they are recovered from failure.
           logging.info("Cluster successfully recovered.")
           self._worker_up_cond.notify_all()
-          self._cluster_due_for_update.clear()
+          self._cluster_due_for_update_or_finish.clear()
         except Exception as e:  # pylint: disable=broad-except
           self._validate_preemption_failure(e)
           # NOTE: Since the first RPC (GetStatus) of update_server_def is
@@ -996,6 +1006,10 @@
     self._strategy.extended._used_with_coordinator = True
     self._cluster = Cluster(strategy)
 
+  def __del__(self):
+    # TODO(xingliu): Stop the worker threads.
+    self._cluster.failure_handler._mark_finished()
+
   @property
   def strategy(self):
     """Returns the `Strategy` associated with the `ClusterCoordinator`."""
diff --git a/tensorflow/python/distribute/coordinator/fault_tolerance_test.py b/tensorflow/python/distribute/coordinator/fault_tolerance_test.py
index 6f5fc05..48bca52 100644
--- a/tensorflow/python/distribute/coordinator/fault_tolerance_test.py
+++ b/tensorflow/python/distribute/coordinator/fault_tolerance_test.py
@@ -19,6 +19,7 @@
 from __future__ import division
 from __future__ import print_function
 
+import gc
 import os
 import threading
 import time
@@ -47,6 +48,7 @@
 
 _RPC_ERROR_FROM_WORKER = "GRPC error information from remote target /job:worker"
 _RPC_ERROR_FROM_PS = "GRPC error information from remote target /job:ps"
+_WORKER_PREEMPTION_THREAD_NAME = "WorkerPreemptionHandler"
 
 
 class Model(object):
@@ -152,6 +154,25 @@
     restart_thread.start()
     return restart_thread
 
+  def _ensure_threads_closed(self):
+    """Ensure worker and preemption threads are closed."""
+    # Wait for threads to close.
+    self.cluster_coord = None
+    gc.collect()
+    time.sleep(1)
+
+    # Verify thread names.
+    running_threads = set()
+    for thread in threading.enumerate():
+      logging.info("Running thread name: %s", thread.name)
+      if thread.name is not None:
+        running_threads.add(thread.name)
+    # TODO(xingliu): Verify worker threads are closed.
+    self.assertNotIn(_WORKER_PREEMPTION_THREAD_NAME, running_threads)
+
+  def testClusterCoordinatorDestroyed(self):
+    self._ensure_threads_closed()
+
   def testWorkerPreemptionBetweenFunctions(self):
     model = Model(self.cluster_coord)
     model.schedule_training_functions(2)
@@ -346,6 +367,7 @@
       if isinstance(e, errors.AbortedError):
         self.assertIn("RecvTensor expects a different device incarnation",
                       str(e))
+      self._ensure_threads_closed()
 
   def testTwoWorkersPreempted(self):
     if self.num_workers < 2: