PSV2: Close worker threads when ClusterCoordinator is destroyed.

Close worker threads in __del__() when the ClusterCoordinator object is
destroyed. Also rename WorkerPreemptionHandler._mark_finished() to stop(),
to avoid naming confusion with _CoordinatedClosureQueue.mark_finished().
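
The shutdown is cooperative: stop() flips a flag while holding the queue
lock and wakes every thread blocked on the condition variable, and each
worker loop exits once get() returns None. Below is a minimal,
self-contained sketch of that pattern with simplified, hypothetical names
(not the actual _CoordinatedClosureQueue/Worker classes):

    import queue
    import threading

    class StoppableClosureQueue:
        """Queue whose consumer threads exit cleanly after stop()."""

        def __init__(self):
            self._queue = queue.Queue()
            self._queue_lock = threading.Lock()
            self._queued_condition = threading.Condition(self._queue_lock)
            self._should_process_closures = True

        def put(self, closure):
            with self._queue_lock:
                self._queue.put(closure, block=False)
                self._queued_condition.notify()

        def get(self):
            with self._queue_lock:
                # Sleep until a closure arrives or stop() is requested.
                while self._queue.empty() and self._should_process_closures:
                    self._queued_condition.wait()
                if not self._should_process_closures:
                    return None  # Tells the worker loop to exit.
                return self._queue.get(block=False)

        def stop(self):
            with self._queue_lock:
                self._should_process_closures = False
                # Wake all blocked consumers so they observe the flag.
                self._queued_condition.notify_all()

    def worker_loop(q):
        while True:
            closure = q.get()
            if closure is None:  # stop() was called; end the thread.
                return
            closure()

    q = StoppableClosureQueue()
    t = threading.Thread(target=worker_loop, args=(q,), daemon=True)
    t.start()
    # The queued closure may or may not run before stop() takes effect.
    q.put(lambda: print("processed"))
    q.stop()
    t.join()  # Returns promptly instead of blocking forever.

In this sketch stop() abandons closures still in the queue; the real
coordinator cancels remaining closures separately (_cancel_all_closures).
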
PiperOrigin-RevId: 351884974
Change-Id: I18a81eb1d3562c874cbc37ea40bd4e9f38d34e16
diff --git a/tensorflow/python/distribute/coordinator/cluster_coordinator.py b/tensorflow/python/distribute/coordinator/cluster_coordinator.py
index 3829a43..0579e26 100644
--- a/tensorflow/python/distribute/coordinator/cluster_coordinator.py
+++ b/tensorflow/python/distribute/coordinator/cluster_coordinator.py
@@ -435,6 +435,7 @@
# Condition indicating that an item becomes available in queue (not empty).
self._closures_queued_condition = threading.Condition(self._queue_lock)
+ self._should_process_closures = True
# Condition indicating that a queue slot becomes available (not full).
# Note that even with "infinite" queue size, there is still a "practical"
@@ -468,6 +469,11 @@
# of the code.
self._put_wait_lock = threading.Lock()
+ def stop(self):
+ with self._queue_lock:
+ self._should_process_closures = False
+ self._closures_queued_condition.notifyAll()
+
def _cancel_all_closures(self):
"""Clears the queue and sets remaining closures cancelled error.
@@ -527,9 +533,11 @@
def get(self, timeout=None):
"""Return a closure from the queue to be executed."""
with self._queue_lock:
- while self._queue.empty():
+ while self._queue.empty() and self._should_process_closures:
if not self._closures_queued_condition.wait(timeout=timeout):
return None
+ if not self._should_process_closures:
+ return None
closure = self._queue.get(block=False)
self._queue_free_slot_condition.notify()
self._inflight_closure_count += 1
@@ -620,7 +628,7 @@
name="WorkerPreemptionHandler",
daemon=True).start()
- def _mark_finished(self):
+ def stop(self):
"""Ensure the worker preemption thread is closed."""
self._should_preemption_thread_run = False
with self._cluster_update_lock:
@@ -729,12 +737,17 @@
self.failure_handler = cluster.failure_handler
self._cluster = cluster
self._resource_remote_value_refs = []
+ self._should_worker_thread_run = True
# Worker threads need to start after `Worker`'s initialization.
threading.Thread(target=self._process_queue,
name="WorkerClosureProcessingLoop-%d" % self.worker_index,
daemon=True).start()
+ def stop(self):
+ """Ensure the worker thread is closed."""
+ self._should_worker_thread_run = False
+
def _set_resources_aborted(self):
# TODO(yuefengz): maybe we can query whether a tensor is valid or not
# instead of marking a tensor aborted?
@@ -748,6 +761,7 @@
def _process_closure(self, closure):
"""Runs a closure with preemption handling."""
+ assert closure is not None
try:
with self._cluster.failure_handler.wait_on_failure(
on_failure_fn=lambda: self._cluster._closure_queue.put_back(closure), # pylint: disable=protected-access
@@ -786,8 +800,10 @@
def _process_queue(self):
"""Function running in a thread to process closure queues."""
self._maybe_delay()
- while True:
+ while self._should_worker_thread_run:
closure = self._cluster._closure_queue.get() # pylint: disable=protected-access
+ if not self._should_worker_thread_run or closure is None:
+ return
self._process_closure(closure)
def _create_resource(self, function, args=None, kwargs=None):
@@ -882,6 +898,14 @@
]
self._strategy = strategy
+ def stop(self):
+ """Stop worker, worker preemption threads, and the closure queue."""
+ self.failure_handler.stop()
+
+ for worker in self.workers:
+ worker.stop()
+ self._closure_queue.stop()
+
def _record_and_ignore_transient_ps_failure(self, e):
"""Records potential PS failures and return if failure should be ignored."""
if self._transient_ps_failures_threshold <= 0 or not _is_ps_failure(e):
@@ -1007,8 +1031,7 @@
self._cluster = Cluster(strategy)
def __del__(self):
- # TODO(xingliu): Stop the worker threads.
- self._cluster.failure_handler._mark_finished()
+ self._cluster.stop()
@property
def strategy(self):
diff --git a/tensorflow/python/distribute/coordinator/fault_tolerance_test.py b/tensorflow/python/distribute/coordinator/fault_tolerance_test.py
index 9115abf..af06be8 100644
--- a/tensorflow/python/distribute/coordinator/fault_tolerance_test.py
+++ b/tensorflow/python/distribute/coordinator/fault_tolerance_test.py
@@ -49,6 +49,7 @@
_RPC_ERROR_FROM_WORKER = "GRPC error information from remote target /job:worker"
_RPC_ERROR_FROM_PS = "GRPC error information from remote target /job:ps"
_WORKER_PREEMPTION_THREAD_NAME = "WorkerPreemptionHandler"
+_WORKER_THREAD_PREFIX = "WorkerClosureProcessingLoop"
class Model(object):
@@ -129,6 +130,7 @@
def tearDown(self):
super(BaseFaultToleranceTest, self).tearDown()
self._cluster.stop()
+ self._cluster = None
def _restart(self, downtime_secs, job):
"""Kills `job` (index: 0) and restarts it after `downtime_secs`.
@@ -155,20 +157,38 @@
return restart_thread
def _ensure_threads_closed(self):
- """Ensure worker and preemption threads are closed."""
+ """Ensures worker and preemption threads are closed."""
+
+ def _get_running_threads():
+ """Returns a set of all running thread names."""
+ running_threads = set()
+ for thread in threading.enumerate():
+ if thread.name is not None:
+ running_threads.add(thread.name)
+ return running_threads
+
+ def _has_thread(prefix, running_threads):
+ """Returns whether any 'running_threads' is prefixed with 'prefix'."""
+ for thread in running_threads:
+ if thread.startswith(prefix):
+ return True
+ return False
+
+ # Worker and preemption threads should exist before releasing
+ # ClusterCoordinator.
+ running_threads = _get_running_threads()
+ self.assertTrue(_has_thread(_WORKER_THREAD_PREFIX, running_threads))
+ self.assertIn(_WORKER_PREEMPTION_THREAD_NAME, running_threads)
+
# Wait for threads to close.
self.cluster_coord = None
gc.collect()
time.sleep(1)
# Verify thread names.
- running_threads = set()
- for thread in threading.enumerate():
- logging.info("Running thread name:%s", thread.name)
- if thread.name is not None:
- running_threads.add(thread.name)
- # TODO(xingliu): Verify worker threads are closed.
+ running_threads = _get_running_threads()
self.assertNotIn(_WORKER_PREEMPTION_THREAD_NAME, running_threads)
+ self.assertFalse(_has_thread(_WORKER_THREAD_PREFIX, running_threads))
def _create_model_and_run_indefinitely(self):
model = Model(self.cluster_coord)