Ignore a limited number of transient timeouts in PS training.
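
Parameter server training can hit brief connectivity blips that cancel
server-side RPC handling and leave a scheduled closure waiting until it hits
its deadline. Instead of immediately treating the worker as failed, the
coordinator now tolerates a bounded number of such DeadlineExceededErrors
before surfacing the error to users. The threshold defaults to
num_workers // 10 and can be overridden through the
TF_COORDINATOR_IGNORE_TRANSIENT_TIMEOUTS environment variable.

A minimal usage sketch of overriding the threshold; it assumes a parameter
server cluster is already configured via TF_CONFIG, and the value 5 is
purely illustrative:

    import os
    import tensorflow as tf

    # Tolerate up to 5 transient closure timeouts before reporting the error.
    # Set this before constructing the ClusterCoordinator, since the threshold
    # is read when the coordinator's cluster object is created.
    os.environ["TF_COORDINATOR_IGNORE_TRANSIENT_TIMEOUTS"] = "5"

    resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
    strategy = tf.distribute.experimental.ParameterServerStrategy(resolver)
    coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
        strategy)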

PiperOrigin-RevId: 409426418
Change-Id: I0c04c47cb365fbbb2bf40a452657d22844504e88
diff --git a/tensorflow/python/distribute/coordinator/cluster_coordinator.py b/tensorflow/python/distribute/coordinator/cluster_coordinator.py
index 87f165c..eb35d94 100644
--- a/tensorflow/python/distribute/coordinator/cluster_coordinator.py
+++ b/tensorflow/python/distribute/coordinator/cluster_coordinator.py
@@ -544,6 +544,19 @@
           on_transient_failure_fn()
         return
 
+      # If the error is due to temporary connectivity issues that cause the
+      # server-side RPCs to be cancelled, TF might not abort the step and the
+      # closure might time out. The coordinator ignores a certain number of
+      # such failures without marking the worker as failed.
+      if self._cluster._record_and_ignore_transient_timeouts(e):  # pylint: disable=protected-access
+        logging.error(
+            "Remote function on worker %s failed with %r:%s\n"
+            "This derived error is ignored and not reported to users.",
+            worker_device_name, e, e)
+        if on_transient_failure_fn:
+          on_transient_failure_fn()
+        return
+
       # Ignoring derived CancelledErrors to tolerate transient failures in
       # PS-worker communication, which initially exposed as an UnavailableError
       # and then lead to sub-function cancellation, subsequently getting
@@ -818,6 +831,18 @@
     self._potential_ps_failures_lock = threading.Lock()
     self._potential_ps_failures_count = [0] * self._num_ps
 
+    # Ignore worker timeouts due to transient connection errors.
+    # Transient connectivity issues might cause the server side to unexpectedly
+    # cancel RPC handling logic, leading to closure execution timeouts. When
+    # _transient_timeouts_threshold is set to a positive number, the cluster
+    # coordinator ignores DeadlineExceededErrors from workers up to that number
+    # of times before raising the error to users.
+    self._transient_timeouts_threshold = int(
+        os.environ.get("TF_COORDINATOR_IGNORE_TRANSIENT_TIMEOUTS",
+                       self._num_workers // 10))
+    self._transient_timeouts_lock = threading.Lock()
+    self._transient_timeouts_count = 0
+
     self.closure_queue = _CoordinatedClosureQueue()
     self.failure_handler = WorkerPreemptionHandler(context.get_server_def(),
                                                    self)
@@ -852,6 +877,18 @@
           return False
     return True
 
+  def _record_and_ignore_transient_timeouts(self, e):
+    """Records observed timeout error and return if it should be ignored."""
+    if self._transient_timeouts_threshold <= 0:
+      return False
+    if not isinstance(e, errors.DeadlineExceededError):
+      return False
+    with self._transient_timeouts_lock:
+      self._transient_timeouts_count += 1
+      if self._transient_timeouts_count >= self._transient_timeouts_threshold:
+        return False
+    return True
+
   def schedule(self, function, args, kwargs):
     """Schedules `function` to be dispatched to a worker for execution.