PSV2: Close worker thread when ClusterCoordinator is destroyed.

Close worker thread when __del__() of ClusterCoordinator object is destoryed.

Also rename WorkerPreemptionHanlder._mark_finished() to destory(), to avoid naming confusion to _CoordinatedClosureQueue.mark_finished().

PiperOrigin-RevId: 351884974
Change-Id: I18a81eb1d3562c874cbc37ea40bd4e9f38d34e16
diff --git a/tensorflow/python/distribute/coordinator/cluster_coordinator.py b/tensorflow/python/distribute/coordinator/cluster_coordinator.py
index 3829a43..0579e26 100644
--- a/tensorflow/python/distribute/coordinator/cluster_coordinator.py
+++ b/tensorflow/python/distribute/coordinator/cluster_coordinator.py
@@ -435,6 +435,7 @@
 
     # Condition indicating that an item becomes available in queue (not empty).
     self._closures_queued_condition = threading.Condition(self._queue_lock)
+    self._should_process_closures = True
 
     # Condition indicating that a queue slot becomes available (not full).
     # Note that even with "infinite" queue size, there is still a "practical"
@@ -468,6 +469,11 @@
     # of the code.
     self._put_wait_lock = threading.Lock()
 
+  def stop(self):
+    with self._queue_lock:
+      self._should_process_closures = False
+      self._closures_queued_condition.notifyAll()
+
   def _cancel_all_closures(self):
     """Clears the queue and sets remaining closures cancelled error.
 
@@ -527,9 +533,11 @@
   def get(self, timeout=None):
     """Return a closure from the queue to be executed."""
     with self._queue_lock:
-      while self._queue.empty():
+      while self._queue.empty() and self._should_process_closures:
         if not self._closures_queued_condition.wait(timeout=timeout):
           return None
+      if not self._should_process_closures:
+        return None
       closure = self._queue.get(block=False)
       self._queue_free_slot_condition.notify()
       self._inflight_closure_count += 1
@@ -620,7 +628,7 @@
                      name="WorkerPreemptionHandler",
                      daemon=True).start()
 
-  def _mark_finished(self):
+  def stop(self):
     """Ensure the worker preemption thread is closed."""
     self._should_preemption_thread_run = False
     with self._cluster_update_lock:
@@ -729,12 +737,17 @@
     self.failure_handler = cluster.failure_handler
     self._cluster = cluster
     self._resource_remote_value_refs = []
+    self._should_worker_thread_run = True
 
     # Worker threads need to start after `Worker`'s initialization.
     threading.Thread(target=self._process_queue,
                      name="WorkerClosureProcessingLoop-%d" % self.worker_index,
                      daemon=True).start()
 
+  def stop(self):
+    """Ensure the worker thread is closed."""
+    self._should_worker_thread_run = False
+
   def _set_resources_aborted(self):
     # TODO(yuefengz): maybe we can query whether a tensor is valid or not
     # instead of marking a tensor aborted?
@@ -748,6 +761,7 @@
 
   def _process_closure(self, closure):
     """Runs a closure with preemption handling."""
+    assert closure is not None
     try:
       with self._cluster.failure_handler.wait_on_failure(
           on_failure_fn=lambda: self._cluster._closure_queue.put_back(closure),  # pylint: disable=protected-access
@@ -786,8 +800,10 @@
   def _process_queue(self):
     """Function running in a thread to process closure queues."""
     self._maybe_delay()
-    while True:
+    while self._should_worker_thread_run:
       closure = self._cluster._closure_queue.get()  # pylint: disable=protected-access
+      if not self._should_worker_thread_run or closure is None:
+        return
       self._process_closure(closure)
 
   def _create_resource(self, function, args=None, kwargs=None):
@@ -882,6 +898,14 @@
     ]
     self._strategy = strategy
 
+  def stop(self):
+    """Stop worker, worker preemption threads, and the closure queue."""
+    self.failure_handler.stop()
+
+    for worker in self.workers:
+      worker.stop()
+    self._closure_queue.stop()
+
   def _record_and_ignore_transient_ps_failure(self, e):
     """Records potential PS failures and return if failure should be ignored."""
     if self._transient_ps_failures_threshold <= 0 or not _is_ps_failure(e):
@@ -1007,8 +1031,7 @@
     self._cluster = Cluster(strategy)
 
   def __del__(self):
-    # TODO(xingliu): Stop the worker threads.
-    self._cluster.failure_handler._mark_finished()
+    self._cluster.stop()
 
   @property
   def strategy(self):
diff --git a/tensorflow/python/distribute/coordinator/fault_tolerance_test.py b/tensorflow/python/distribute/coordinator/fault_tolerance_test.py
index 9115abf..af06be8 100644
--- a/tensorflow/python/distribute/coordinator/fault_tolerance_test.py
+++ b/tensorflow/python/distribute/coordinator/fault_tolerance_test.py
@@ -49,6 +49,7 @@
 _RPC_ERROR_FROM_WORKER = "GRPC error information from remote target /job:worker"
 _RPC_ERROR_FROM_PS = "GRPC error information from remote target /job:ps"
 _WORKER_PREEMPTION_THREAD_NAME = "WorkerPreemptionHandler"
+_WORKER_THREAD_PREFIX = "WorkerClosureProcessingLoop"
 
 
 class Model(object):
@@ -129,6 +130,7 @@
   def tearDown(self):
     super(BaseFaultToleranceTest, self).tearDown()
     self._cluster.stop()
+    self._cluster = None
 
   def _restart(self, downtime_secs, job):
     """Kills `job` (index: 0) and restarts it after `downtime_secs`.
@@ -155,20 +157,38 @@
     return restart_thread
 
   def _ensure_threads_closed(self):
-    """Ensure worker and preemption threads are closed."""
+    """Ensures worker and preemption threads are closed."""
+
+    def _get_running_threads():
+      """Returns a set of all running thread names."""
+      running_threads = set()
+      for thread in threading.enumerate():
+        if thread.name is not None:
+          running_threads.add(thread.name)
+      return running_threads
+
+    def _has_thread(prefix, running_threads):
+      """Returns whether any 'running_threads' is prefixed with 'prefix'."""
+      for thread in running_threads:
+        if thread.startswith(prefix):
+          return True
+      return False
+
+    # Worker and preemption threads should exist before releasing
+    # ClusterCoordinator.
+    running_threads = _get_running_threads()
+    self.assertTrue(_has_thread(_WORKER_THREAD_PREFIX, running_threads))
+    self.assertIn(_WORKER_PREEMPTION_THREAD_NAME, running_threads)
+
     # Wait for threads to close.
     self.cluster_coord = None
     gc.collect()
     time.sleep(1)
 
     # Verify thread names.
-    running_threads = set()
-    for thread in threading.enumerate():
-      logging.info("Running thread name:%s", thread.name)
-      if thread.name is not None:
-        running_threads.add(thread.name)
-    # TODO(xingliu): Verify worker threads are closed.
+    running_threads = _get_running_threads()
     self.assertNotIn(_WORKER_PREEMPTION_THREAD_NAME, running_threads)
+    self.assertFalse(_has_thread(_WORKER_THREAD_PREFIX, running_threads))
 
   def _create_model_and_run_indefinitely(self):
     model = Model(self.cluster_coord)