PSv2: Shutdown the worker preemption thread in ClusterCoordinator's destructor.
When ClusterCoordinator's destructor is called, mark the worker preemption thread as finished to stop the thread.
Note that the Cluster and Worker objects will not be destroyed yet, because the threading.Thread object holds a reference to the cluster object.
Two follow-up CLs will address Cluster/Worker deletion:
1. Close the worker threads in ClusterCoordinator's destructor.
2. Use weakref as target function for threading.Thread objects, to ensure Cluster/Worker objects can be garbage collected in memory.
PiperOrigin-RevId: 350639066
Change-Id: I23b54912f9b450fb2ccbe6b1da69d2bf5bc55aa5
diff --git a/tensorflow/python/distribute/coordinator/cluster_coordinator.py b/tensorflow/python/distribute/coordinator/cluster_coordinator.py
index 159fc33..3829a43 100644
--- a/tensorflow/python/distribute/coordinator/cluster_coordinator.py
+++ b/tensorflow/python/distribute/coordinator/cluster_coordinator.py
@@ -613,12 +613,19 @@
self._server_def = server_def
self._cluster = cluster
self._cluster_update_lock = threading.Lock()
- self._cluster_due_for_update = threading.Event()
+ self._cluster_due_for_update_or_finish = threading.Event()
self._worker_up_cond = threading.Condition(self._cluster_update_lock)
+ self._should_preemption_thread_run = True
threading.Thread(target=self._preemption_handler,
name="WorkerPreemptionHandler",
daemon=True).start()
+ def _mark_finished(self):
+ """Ensure the worker preemption thread is closed."""
+ self._should_preemption_thread_run = False
+ with self._cluster_update_lock:
+ self._cluster_due_for_update_or_finish.set()
+
def _validate_preemption_failure(self, e):
"""Validates that the given exception represents worker preemption."""
if _is_worker_failure(e):
@@ -658,7 +665,7 @@
on_failure_fn()
with self._cluster_update_lock:
- self._cluster_due_for_update.set()
+ self._cluster_due_for_update_or_finish.set()
self._worker_up_cond.wait(_WORKER_MAXIMUM_RECOVERY_SEC)
logging.info("Worker %s has been recovered.", worker_device_name)
@@ -676,7 +683,10 @@
restarted workers.
"""
while True:
- self._cluster_due_for_update.wait()
+ self._cluster_due_for_update_or_finish.wait()
+ if not self._should_preemption_thread_run:
+ break
+
with self._cluster_update_lock:
try:
# TODO(haoyuzhang): support partial cluster recovery
@@ -687,7 +697,7 @@
# all workers that they are recovered from failure.
logging.info("Cluster successfully recovered.")
self._worker_up_cond.notify_all()
- self._cluster_due_for_update.clear()
+ self._cluster_due_for_update_or_finish.clear()
except Exception as e: # pylint: disable=broad-except
self._validate_preemption_failure(e)
# NOTE: Since the first RPC (GetStatus) of update_server_def is
@@ -996,6 +1006,10 @@
self._strategy.extended._used_with_coordinator = True
self._cluster = Cluster(strategy)
+ def __del__(self):
+ # TODO(xingliu): Stop the worker threads.
+ self._cluster.failure_handler._mark_finished()
+
@property
def strategy(self):
"""Returns the `Strategy` associated with the `ClusterCoordinator`."""
diff --git a/tensorflow/python/distribute/coordinator/fault_tolerance_test.py b/tensorflow/python/distribute/coordinator/fault_tolerance_test.py
index 6f5fc05..48bca52 100644
--- a/tensorflow/python/distribute/coordinator/fault_tolerance_test.py
+++ b/tensorflow/python/distribute/coordinator/fault_tolerance_test.py
@@ -19,6 +19,7 @@
from __future__ import division
from __future__ import print_function
+import gc
import os
import threading
import time
@@ -47,6 +48,7 @@
_RPC_ERROR_FROM_WORKER = "GRPC error information from remote target /job:worker"
_RPC_ERROR_FROM_PS = "GRPC error information from remote target /job:ps"
+_WORKER_PREEMPTION_THREAD_NAME = "WorkerPreemptionHandler"
class Model(object):
@@ -152,6 +154,25 @@
restart_thread.start()
return restart_thread
+ def _ensure_threads_closed(self):
+ """Ensure worker and preemption threads are closed."""
+ # Wait for threads to close.
+ self.cluster_coord = None
+ gc.collect()
+ time.sleep(1)
+
+ # Verify thread names.
+ running_threads = set()
+ for thread in threading.enumerate():
+ logging.info("Running thread name:%s", thread.name)
+ if thread.name is not None:
+ running_threads.add(thread.name)
+ # TODO(xingliu): Verify worker threads are closed.
+ self.assertNotIn(_WORKER_PREEMPTION_THREAD_NAME, running_threads)
+
+ def testClusterCoordinatorDestroyed(self):
+ self._ensure_threads_closed()
+
def testWorkerPreemptionBetweenFunctions(self):
model = Model(self.cluster_coord)
model.schedule_training_functions(2)
@@ -346,6 +367,7 @@
if isinstance(e, errors.AbortedError):
self.assertIn("RecvTensor expects a different device incarnation",
str(e))
+ self._ensure_threads_closed()
def testTwoWorkersPreempted(self):
if self.num_workers < 2: