Ignore a limited number of transient timeouts in PS training.
PiperOrigin-RevId: 409426418
Change-Id: I0c04c47cb365fbbb2bf40a452657d22844504e88
diff --git a/tensorflow/python/distribute/coordinator/cluster_coordinator.py b/tensorflow/python/distribute/coordinator/cluster_coordinator.py
index 87f165c..eb35d94 100644
--- a/tensorflow/python/distribute/coordinator/cluster_coordinator.py
+++ b/tensorflow/python/distribute/coordinator/cluster_coordinator.py
@@ -544,6 +544,19 @@
on_transient_failure_fn()
return
+ # If the error is due to temporary connectivity issues that cause the
+ # server-side RPCs to be cancelled, TF might not abort the step and the
+ # closure might time out. The coordinator ignores a limited number of
+ # such failures without marking the worker as failed.
+ if self._cluster._record_and_ignore_transient_timeouts(e): # pylint: disable=protected-access
+ logging.error(
+ "Remote function on worker %s failed with %r:%s\n"
+ "This derived error is ignored and not reported to users.",
+ worker_device_name, e, e)
+ if on_transient_failure_fn:
+ on_transient_failure_fn()
+ return
+
# Ignoring derived CancelledErrors to tolerate transient failures in
# PS-worker communication, which initially exposed as an UnavailableError
# and then lead to sub-function cancellation, subsequently getting
@@ -818,6 +831,18 @@
self._potential_ps_failures_lock = threading.Lock()
self._potential_ps_failures_count = [0] * self._num_ps
+ # Ignore worker timeouts due to transient connection errors.
+ # Transient connectivity issues might cause the server side to unexpectedly
+ # cancel RPC handling logic, leading to closure execution timeouts. When
+ # _transient_timeouts_threshold is positive, the cluster coordinator
+ # ignores DeadlineExceeded errors from workers until the number observed
+ # reaches that threshold; after that the error is raised to users.
+ self._transient_timeouts_threshold = int(
+ os.environ.get("TF_COORDINATOR_IGNORE_TRANSIENT_TIMEOUTS",
+ self._num_workers // 10))
+ self._transient_timeouts_lock = threading.Lock()
+ self._transient_timeouts_count = 0
+
self.closure_queue = _CoordinatedClosureQueue()
self.failure_handler = WorkerPreemptionHandler(context.get_server_def(),
self)
@@ -852,6 +877,18 @@
return False
return True
+ def _record_and_ignore_transient_timeouts(self, e):
+ """Records a timeout error and returns whether it should be ignored."""
+ if self._transient_timeouts_threshold <= 0:
+ return False  # Ignoring timeouts is disabled.
+ if not isinstance(e, errors.DeadlineExceededError):
+ return False  # Only DeadlineExceededError is counted.
+ with self._transient_timeouts_lock:
+ self._transient_timeouts_count += 1
+ if self._transient_timeouts_count >= self._transient_timeouts_threshold:
+ return False  # Threshold reached; stop ignoring.
+ return True
+
def schedule(self, function, args, kwargs):
"""Schedules `function` to be dispatched to a worker for execution.