Cancel in-flight closures when there is an error.

PiperOrigin-RevId: 324542620
Change-Id: I1d6cddf8130df74f00ce7b0a3b6b84f553990e78
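
In short: `_CoordinatedClosureQueue` now takes a
`cancellation.CancellationManager`; closures built from traced functions are
wrapped with `get_cancelable_function`, and when a closure fails with a
non-retryable error the queue calls `start_cancel()` and waits for every
in-flight closure to finish before the error is surfaced through `schedule`,
`join`, or `done`. The error-generation counter this replaces is removed.
A minimal, hedged sketch of the cancellation primitive the change relies on
(illustrative only; assumes the behavior of
`tensorflow.python.eager.cancellation` as used in the diff below):

    import tensorflow as tf
    from tensorflow.python.eager import cancellation

    @tf.function
    def work(x):
      return x * 2

    concrete = work.get_concrete_function(tf.constant(1))
    mgr = cancellation.CancellationManager()
    cancelable = mgr.get_cancelable_function(concrete)

    print(cancelable(tf.constant(3)))  # runs normally -> 6
    mgr.start_cancel()
    try:
      cancelable(tf.constant(3))
    except tf.errors.CancelledError:
      # Calls made after start_cancel(), and calls still running when it is
      # invoked, fail with CancelledError.
      pass
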
diff --git a/tensorflow/python/distribute/client/BUILD b/tensorflow/python/distribute/client/BUILD
index 0f7b7df..35d8de9 100644
--- a/tensorflow/python/distribute/client/BUILD
+++ b/tensorflow/python/distribute/client/BUILD
@@ -32,6 +32,7 @@
         "//tensorflow/python/distribute:input_lib",
         "//tensorflow/python/distribute:parameter_server_strategy_v2",
         "//tensorflow/python/distribute:values",
+        "//tensorflow/python/eager:cancellation",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/eager:def_function",
         "//tensorflow/python/eager:executor",
diff --git a/tensorflow/python/distribute/client/client.py b/tensorflow/python/distribute/client/client.py
index 533d5f1..7bef5e2 100644
--- a/tensorflow/python/distribute/client/client.py
+++ b/tensorflow/python/distribute/client/client.py
@@ -31,15 +31,19 @@
 import weakref
 from absl import logging
 from six.moves import queue
+
 from tensorflow.python.distribute import distribute_lib
 from tensorflow.python.distribute import input_lib
 from tensorflow.python.distribute import parameter_server_strategy_v2
 from tensorflow.python.distribute.client import metric_utils
+from tensorflow.python.eager import cancellation
 from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
 from tensorflow.python.eager import executor
 from tensorflow.python.eager import function as tf_function
 from tensorflow.python.eager import remote
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import func_graph
 from tensorflow.python.framework import ops
@@ -247,20 +251,28 @@
     self._values = tuple(values)
 
 
+def _select_worker_slice(worker_id, structured):
+  """Selects the worker slice of each of the items in `structured`."""
+
+  def _get(x):
+    return x._values[worker_id] if isinstance(x, PerWorkerValues) else x  # pylint: disable=protected-access
+
+  return nest.map_structure(_get, structured)
+
+
 class Closure(object):
   """Hold a function to be scheduled and its arguments."""
 
-  def __init__(self, function, args=None, kwargs=None):
+  def __init__(self, function, cancellation_mgr, args=None, kwargs=None):
     if not callable(function):
       raise ValueError("Function passed to `Client.schedule` must be a "
                        "callable object.")
     self._args = args or ()
     self._kwargs = kwargs or {}
-    self._function = function
 
     if isinstance(function, def_function.Function):
-      replica_args = self._select_worker_slice(0, self._args)
-      replica_kwargs = self._select_worker_slice(0, self._kwargs)
+      replica_args = _select_worker_slice(0, self._args)
+      replica_kwargs = _select_worker_slice(0, self._kwargs)
 
       # Note: no need to handle function registration failure since this kind of
       # failure will not raise exceptions as designed in the runtime. The client
@@ -276,25 +288,22 @@
         concrete_function = function.get_concrete_function(
             *nest.map_structure(_maybe_as_type_spec, replica_args),
             **nest.map_structure(_maybe_as_type_spec, replica_kwargs))
+      self._function = cancellation_mgr.get_cancelable_function(
+          concrete_function)
       self._output_remote_values = nest.map_structure(
           lambda x: RemoteValue(self, x), concrete_function.structured_outputs)
     elif isinstance(function, tf_function.ConcreteFunction):
+      self._function = cancellation_mgr.get_cancelable_function(
+          function)
       self._output_remote_values = nest.map_structure(
           lambda x: RemoteValue(self, x), function.structured_outputs)
     else:
       # Regular python functions.
+      self._function = function
       # TODO(yuefengz): maybe we should trace python functions if their inputs
       # are Python primitives, tensors and composite tensors.
       self._output_remote_values = RemoteValue(self, None)
 
-  def _select_worker_slice(self, worker_id, structured):
-    """Selects the worker slice of each of the items in `structured`."""
-
-    def _get(x):
-      return x._values[worker_id] if isinstance(x, PerWorkerValues) else x  # pylint: disable=protected-access
-
-    return nest.map_structure(_get, structured)
-
   def _fetch_output_remote_values(self):
     """Temporary method used to sync the scheduler."""
     # It will do nothing if there is no return value.
@@ -319,9 +328,8 @@
     Args:
       worker: a `Worker` object.
     """
-    replica_args = self._select_worker_slice(worker.worker_index, self._args)
-    replica_kwargs = self._select_worker_slice(worker.worker_index,
-                                               self._kwargs)
+    replica_args = _select_worker_slice(worker.worker_index, self._args)
+    replica_kwargs = _select_worker_slice(worker.worker_index, self._kwargs)
 
     e = (
         _maybe_get_error_and_rebuild_remote_values(worker, replica_args) or
@@ -350,8 +358,7 @@
   This class is thread-safe.
   """
 
-  def __init__(self):
-
+  def __init__(self, cancellation_mgr):
     # `self._inflight_closure_count` only tracks the number of inflight closures
     # that are "in generation". Once an error occurs, error generation is
     # incremented and all subsequent arriving closures (from inflight) are
@@ -359,17 +366,26 @@
     self._inflight_closure_count = 0
 
     self._queue_lock = threading.Lock()
+
     # Condition indicating that all pending closures (either queued or inflight)
     # have been processed, failed, or cancelled.
     self._stop_waiting_condition = threading.Condition(self._queue_lock)
+
     # Condition indicating that an item becomes available in queue (not empty).
     self._closures_queued_condition = threading.Condition(self._queue_lock)
+
     # Condition indicating that a queue slot becomes available (not full).
     # Note that even with "infinite" queue size, there is still a "practical"
     # size limit for the queue depending on host memory capacity, and thus the
     # queue will eventually become full with a lot of enqueued closures.
     self._queue_free_slot_condition = threading.Condition(self._queue_lock)
 
+    # Condition indicating that there are no inflight closures.
+    self._no_inflight_closure_condition = threading.Condition(self._queue_lock)
+
+    # Used to cancel in-flight closures.
+    self._cancellation_mgr = cancellation_mgr
+
     if _CLOSURE_QUEUE_MAX_SIZE <= 0:
       logging.warning(
           "In ParameterServerClient, creating an infinite closure queue can "
@@ -377,31 +393,6 @@
     self._queue = queue.Queue(maxsize=_CLOSURE_QUEUE_MAX_SIZE)
     self._error = None
 
-    # Error generation is a counter that helps us track whether a closure
-    # should be cancelled when it is being put back to `self._queue`. It works
-    # in the following way:
-    # 1) Error generation starts off at 0.
-    # 2) When a worker thread calls `get()`, the closure's error generation
-    #    is copied from this queue's error generation.
-    # 3) If any worker thread experiences an error that's categorized as a
-    #    non-retryable error, the queue's error will be set, error generation
-    #    increments by 1, and the queue is cleared (with the closures marked
-    #    with cancelled error), so other worker threads stop getting closures
-    #    from the queue. Worker preemption is categorized as a retryable error.
-    # 4) At this point, if `put()` or `wait()` is called (usually by the main
-    #    thread via `schedule` and `join`), the error is raised through that
-    #    call.
-    # 5) The closures that are inflight, i.e. that are being executed remotely,
-    #    will not be aware of such error event. If the worker that's executing
-    #    the closure happens to be interrupted, the closure should not be put
-    #    back to the queue, and be cancelled with error instead. Checking the
-    #    generation id of the closure and queue is how the worker thread tells
-    #    whether the closure should be put back. Likewise for `mark_finished`
-    #    and `mark_failed`: if the arriving closure is considered out of
-    #    generation in those two methods, it is simply discarded (the inflight
-    #    closure count still decrements).
-    self._error_generation = 0
-
     # The following is a lock to make sure when `wait` is called and before it
     # returns no `put` can be executed during this period. It is because `wait`
     # won't know what to do with newly put closures. This lock adds an cutoff
@@ -415,11 +406,14 @@
     # of the code.
     self._put_wait_lock = threading.Lock()
 
-  def _cancel_closures_in_queue(self):
+  def _cancel_all_closures(self):
     """Clears the queue and sets remaining closures cancelled error.
 
     This method expects self._queue_lock to be held prior to entry.
     """
+    self._cancellation_mgr.start_cancel()
+    while self._inflight_closure_count > 0:
+      self._no_inflight_closure_condition.wait()
     while True:
       try:
         closure = self._queue.get(block=False)
@@ -437,8 +431,8 @@
     This method expects self._queue_lock to be held prior to entry.
     """
     if self._error:
+      self._cancel_all_closures()
       try:
-        self._cancel_closures_in_queue()
         raise self._error  # pylint: disable=raising-bad-type
       finally:
         self._error = None
@@ -466,16 +460,17 @@
           return None
       closure = self._queue.get(block=False)
       self._queue_free_slot_condition.notify()
-      closure._error_generation = self._error_generation  # pylint: disable=protected-access
       self._inflight_closure_count += 1
       return closure
 
-  def mark_finished(self, closure):
+  def mark_finished(self):
     """Let the queue know that a closure has been successfully executed."""
     with self._queue_lock:
       if self._inflight_closure_count < 1:
         raise AssertionError("There is no inflight closures to mark_finished.")
       self._inflight_closure_count -= 1
+      if self._inflight_closure_count == 0:
+        self._no_inflight_closure_condition.notifyAll()
       if self._queue.empty() and self._inflight_closure_count == 0:
         self._stop_waiting_condition.notifyAll()
 
@@ -484,17 +479,15 @@
     with self._queue_lock:
       if self._inflight_closure_count < 1:
         raise AssertionError("There is no inflight closures to put_back.")
-      self._inflight_closure_count -= 1
-      if closure._error_generation < self._error_generation:  # pylint: disable=protected-access
-        # If the closure to put back is out of generation, cancel the closure
-        # and ignore it.
-        logging.info("Function %r should no longer be dispatched; marking "
-                     "as cancelled.")
+      if self._error:
         closure._set_output_remote_values_cancelled()  # pylint: disable=protected-access
-        return
-      self._queue_free_slot_condition.wait_for(lambda: not self._queue.full())
-      self._queue.put(closure, block=False)
-      self._closures_queued_condition.notify()
+      else:
+        self._queue_free_slot_condition.wait_for(lambda: not self._queue.full())
+        self._queue.put(closure, block=False)
+        self._closures_queued_condition.notify()
+      self._inflight_closure_count -= 1
+      if self._inflight_closure_count == 0:
+        self._no_inflight_closure_condition.notifyAll()
 
   def wait(self, timeout=None):
     """Wait for all closures to be finished before returning.
@@ -516,22 +509,18 @@
       self._raise_if_error()
       return True
 
-  def mark_failed(self, e, closure):
+  def mark_failed(self, e):
     """Sets error and unblocks any wait() call."""
     with self._queue_lock:
       # TODO(yuefengz): maybe record all failure and give users more
       # information?
       if self._inflight_closure_count < 1:
         raise AssertionError("There is no inflight closures to mark_failed.")
+      if self._error is None:
+        self._error = e
       self._inflight_closure_count -= 1
-      if closure._error_generation < self._error_generation:  # pylint: disable=protected-access
-        # If the closure to mark fail is out of generation, simply ignore it
-        # (with the actual error associated with the closure preserved).
-        return
-      assert self._error is None
-      self._error = e
-      self._error_generation += 1
-      self._cancel_closures_in_queue()
+      if self._inflight_closure_count == 0:
+        self._no_inflight_closure_condition.notifyAll()
       self._stop_waiting_condition.notifyAll()
 
   def done(self):
@@ -678,7 +667,7 @@
         # TODO(yuefengz): we don't have to materialize results every step.
         with metric_utils.monitored_timer("remote_value_fetch"):
           closure._fetch_output_remote_values()  # pylint: disable=protected-access
-        self._cluster._closure_queue.mark_finished(closure)  # pylint: disable=protected-access
+        self._cluster._closure_queue.mark_finished()  # pylint: disable=protected-access
     except Exception as e:  # pylint: disable=broad-except
       logging.error(
           "/job:worker/task:%d encountered the following error when processing "
@@ -686,7 +675,7 @@
       nest.map_structure(
           lambda x: x._set_error(e),  # pylint: disable=protected-access
           closure._output_remote_values)  # pylint: disable=protected-access
-      self._cluster._closure_queue.mark_failed(e, closure)  # pylint: disable=protected-access
+      self._cluster._closure_queue.mark_failed(e)  # pylint: disable=protected-access
 
   def _process_queue(self):
     while True:
@@ -710,7 +699,8 @@
     # the same worker such as creating resources, setting resources' aborted
     # status, and executing closures happen on the same thread. This allows us
     # to have simpler logic of concurrency.
-    closure = Closure(function=function, args=args, kwargs=kwargs)
+    closure = Closure(
+        function, self._cluster._cancellation_mgr, args=args, kwargs=kwargs)  # pylint: disable=protected-access
     resource_remote_value = closure._output_remote_values  # pylint: disable=protected-access
     self._register_resource(resource_remote_value)
 
@@ -775,7 +765,8 @@
                               protocol=cluster_resolver.rpc_layer,
                               cluster_device_filters=device_filters)
 
-    self._closure_queue = _CoordinatedClosureQueue()
+    self._cancellation_mgr = cancellation.CancellationManager()
+    self._closure_queue = _CoordinatedClosureQueue(self._cancellation_mgr)
     self.failure_handler = WorkerPreemptionHandler(context.get_server_def())
     worker_device_strings = [
         "/job:worker/replica:0/task:%d" % i for i in range(self._num_workers)
@@ -796,7 +787,8 @@
     Returns:
       A structure of `RemoteValue` object.
     """
-    closure = Closure(function=function, args=args, kwargs=kwargs)
+    closure = Closure(
+        function, self._cancellation_mgr, args=args, kwargs=kwargs)
     self._closure_queue.put(closure)
     return closure._output_remote_values  # pylint: disable=protected-access
 
@@ -893,8 +885,8 @@
     function execution to finish and retrieve its output from the remote worker.
 
     `schedule` guarantees that `fn` will be executed on a worker at least once;
-    it could be more than once if a worker fails and restarts in the middle of
-    function scheduling. Note that since worker can fail at any point when
+    it could be more than once if its corresponding worker fails in the middle
+    of its execution. Note that since a worker can fail at any point when
     executing the function, it is possible that the function is partially
     executed, but `Client` guarantees that in those events, the function will
     eventually be fully executed, possibly on a different worker that is
@@ -904,14 +896,12 @@
     by raising any one of those errors, and clear the errors collected so far.
     There are two implications when this happens: 1) user should call `schedule`
     with `fn` again to re-schedule, and 2) some of the previously scheduled
-    functions may no longer execute. User can call `fetch` on the returned
+    functions may not have been executed. Users can call `fetch` on the returned
     `RemoteValue` to inspect if they have executed, failed, or cancelled, and
     reschedule the corresponding function if needed.
 
-    When `schedule` raises, it is possible that there are still functions being
-    executed on workers, at the time `schedule` raises. When this happens, users
-    can call `join` again to wait for all pending async function execution to
-    finish, and bring the cluster into a consistent state.
+    When `schedule` raises, it guarantees that no scheduled function is still
+    being executed.
 
     At this time, there is no support of worker assignment for function
     execution, or priority of the workers.
@@ -940,7 +930,8 @@
     # TODO(b/160702436): Invoke `strategy.run` for user's function so it enters
     # a `ReplicaContext` in a logically correct way.
     with distribute_lib.ReplicaContext(
-        self._strategy, replica_id_in_sync_group=0):
+        self._strategy,
+        replica_id_in_sync_group=constant_op.constant(0, dtypes.int32)):
       with self._translate_parameter_server_failure():
         return self.cluster.schedule(fn, args=args, kwargs=kwargs)
 
@@ -949,17 +940,14 @@
 
     If any previously scheduled function raises an error, `join` will fail by
     raising any one of those errors, and clear the errors collected so far. If
-    this happens, some of the previously scheduled functions may no longer
-    execute. Users can call `fetch` on the returned `RemoteValue` to inspect if
+    this happens, some of the previously scheduled functions may not have been
+    executed. Users can call `fetch` on the returned `RemoteValue` to inspect if
     they have executed, failed, or cancelled. If some that have been cancelled
     need to be rescheduled, users should call `schedule` with the function
     again.
 
-    Note: `join` raises an exception as soon as the client detects one, and this
-    means it is possible that there are still functions being executed on
-    workers, at the time `join` raises. When this happens, users can call `join`
-    again to wait for all pending async function execution to finish, and bring
-    the cluster into a consistent state.
+    When `join` returns or raises, it guarantees that no scheduled function is
+    still being executed.
 
     Raises:
       Exception: one of the exceptions caught by the client by any previously
@@ -976,6 +964,9 @@
 
     If any previously scheduled function raises an error, `done` will fail by
     raising any one of those errors.
+
+    When `done` returns True or raises, it guarantees that no scheduled
+    function is still being executed.
     """
     return self.cluster.done()
 
@@ -1091,7 +1082,7 @@
         raise
 
 
-class _PerWorkerDistributedDataset(object):  # pylint: disable=protected-access
+class _PerWorkerDistributedDataset(object):
   """Represents worker-distributed datasets created from dataset function."""
 
   def __init__(self, dataset_fn, input_workers, client):
@@ -1107,13 +1098,13 @@
 
     if isinstance(dataset_fn, def_function.Function):
       with variable_scope.variable_creator_scope(disallow_variable_creation):
-        self._dataset_fn = dataset_fn.get_concrete_function()
-    elif isinstance(dataset_fn, tf_function.ConcreteFunction):
-      self._dataset_fn = dataset_fn
-    else:
+        dataset_fn = dataset_fn.get_concrete_function()
+    elif not isinstance(dataset_fn, tf_function.ConcreteFunction):
       with variable_scope.variable_creator_scope(disallow_variable_creation):
-        self._dataset_fn = def_function.function(
-            dataset_fn).get_concrete_function()
+        dataset_fn = def_function.function(dataset_fn).get_concrete_function()
+    self._dataset_fn = (
+        client.cluster._cancellation_mgr.get_cancelable_function(  # pylint: disable=protected-access
+            dataset_fn))
     self._input_workers = input_workers
     self._client = client
     self._element_spec = None
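
For context on the `schedule`/`join`/`done` docstring updates above: once an
error is raised, no closure is left running on a worker, so rescheduling is
the only recovery step needed. A hedged usage sketch (illustrative only;
`cluster_resolver` and `train_step` are hypothetical stand-ins):

    from tensorflow.python.distribute.client import parameter_server_client

    client = parameter_server_client.ParameterServerClient(cluster_resolver)
    result = client.schedule(train_step)
    try:
      client.join()
    except Exception:  # pylint: disable=broad-except
      # After `join` raises, no scheduled function is still executing; closures
      # that were cancelled by the error can simply be scheduled again.
      result = client.schedule(train_step)
      client.join()
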
diff --git a/tensorflow/python/distribute/client/client_test.py b/tensorflow/python/distribute/client/client_test.py
index 1215240..459633a 100644
--- a/tensorflow/python/distribute/client/client_test.py
+++ b/tensorflow/python/distribute/client/client_test.py
@@ -30,22 +30,34 @@
 from tensorflow.python.util import nest
 
 
+class MockCancellationManager(object):
+
+  def __init__(self):
+    self.cancelled = False
+
+  def start_cancel(self):
+    self.cancelled = True
+
+  def get_cancelable_function(self, func):
+    return func
+
+
 class CoordinatedClosureQueueTest(test.TestCase):
 
   def testBasic(self):
-    queue = client._CoordinatedClosureQueue()
+    queue = client._CoordinatedClosureQueue(MockCancellationManager())
     closure1 = self._create_closure()
     queue.put(closure1)
     self.assertIs(closure1, queue.get())
     self.assertFalse(queue.done())
     queue.put_back(closure1)
     self.assertEqual(closure1, queue.get())
-    queue.mark_finished(closure1)
+    queue.mark_finished()
     self.assertTrue(queue.done())
     queue.wait()
 
   def testProcessAtLeaseOnce(self):
-    closure_queue = client._CoordinatedClosureQueue()
+    closure_queue = client._CoordinatedClosureQueue(MockCancellationManager())
     labels = ['A', 'B', 'C', 'D', 'E']
     processed_count = collections.defaultdict(int)
 
@@ -63,7 +75,7 @@
             closure_queue.put_back(closure)
             continue
           closure._function()
-          closure_queue.mark_finished(closure)
+          closure_queue.mark_finished()
 
     def get_func(label):
 
@@ -76,7 +88,8 @@
       return func
 
     for label in labels:
-      closure_queue.put(client.Closure(get_func(label)))
+      closure_queue.put(
+          client.Closure(get_func(label), MockCancellationManager()))
     t1 = threading.Thread(target=process_queue, daemon=True)
     t1.start()
     t2 = threading.Thread(target=process_queue, daemon=True)
@@ -93,7 +106,7 @@
     coord.join([t1, t2])
 
   def testNotifyBeforeWait(self):
-    closure_queue = client._CoordinatedClosureQueue()
+    closure_queue = client._CoordinatedClosureQueue(MockCancellationManager())
 
     def func():
       logging.info('func running')
@@ -102,10 +115,10 @@
 
     def process_queue():
       with coord.stop_on_exception():
-        closure = closure_queue.get()
-        closure_queue.mark_finished(closure)
+        closure_queue.get()
+        closure_queue.mark_finished()
 
-    closure_queue.put(client.Closure(func))
+    closure_queue.put(client.Closure(func, MockCancellationManager()))
     t = threading.Thread(target=process_queue)
     t.start()
     coord.join([t])
@@ -114,8 +127,30 @@
     # doesn't time out.
     closure_queue.wait()
 
+  def _assert_one_unblock_the_other(self, first_fn, second_fn):
+    """Asserts `second_fn` wouldn't return before `first_fn` is finished."""
+    first_fn_done = threading.Event()
+    second_fn_done = threading.Event()
+    coord = coordinator.Coordinator(clean_stop_exception_types=[])
+
+    def wrapped_first_fn():
+      with coord.stop_on_exception():
+        self.assertFalse(second_fn_done.is_set())
+        first_fn()
+        first_fn_done.set()
+
+    self.assertFalse(first_fn_done.is_set())
+    t = threading.Thread(target=wrapped_first_fn)
+    t.start()
+
+    second_fn()
+    self.assertTrue(first_fn_done.is_set())
+    second_fn_done.set()
+
+    coord.join([t])
+
   def testWaitRaiseErrorAfterMarkFailure(self):
-    closure_queue = client._CoordinatedClosureQueue()
+    closure_queue = client._CoordinatedClosureQueue(MockCancellationManager())
     closure_queue.put(self._create_closure())
     closure = closure_queue.get()
 
@@ -126,22 +161,17 @@
     # all inflight closures are finished.
 
     def mark_finished_fn():
-      with coord.stop_on_exception():
-        self.assertFalse(wait_finish_event.is_set())
-        try:
-          raise ValueError('Some error.')
-        except ValueError as e:
-          closure_queue.mark_failed(e, closure)
-        wait_finish_event.wait()
+      try:
+        raise ValueError('Some error.')
+      except ValueError as e:
+        closure_queue.mark_failed(e)
 
-    t = threading.Thread(target=mark_finished_fn)
-    t.start()
+    def wait_fn():
+      with self.assertRaises(ValueError):
+        closure_queue.wait()
 
-    with self.assertRaises(ValueError):
-      closure_queue.wait()
-    wait_finish_event.set()
+    self._assert_one_unblock_the_other(mark_finished_fn, wait_fn)
 
-    coord.join([t])
     self.assertTrue(closure_queue.done())
 
   def _create_closure(self):
@@ -150,10 +180,10 @@
     def some_function():
       return 1.0
 
-    return client.Closure(some_function)
+    return client.Closure(some_function, MockCancellationManager())
 
   def _put_two_closures_and_get_one(self):
-    closure_queue = client._CoordinatedClosureQueue()
+    closure_queue = client._CoordinatedClosureQueue(MockCancellationManager())
     closure1 = self._create_closure()
     closure_queue.put(closure1)
 
@@ -166,9 +196,9 @@
     return closure_queue, closure1, closure2
 
   def testPutRaiseError(self):
-    closure_queue, closure1, closure2 = self._put_two_closures_and_get_one()
+    closure_queue, _, closure2 = self._put_two_closures_and_get_one()
 
-    closure_queue.mark_failed(ValueError(), closure1)
+    closure_queue.mark_failed(ValueError())
 
     with self.assertRaises(ValueError):
       closure_queue.put(self._create_closure())
@@ -185,9 +215,9 @@
     closure_queue.put(self._create_closure())
 
   def testWaitRaiseError(self):
-    closure_queue, closure1, closure2 = self._put_two_closures_and_get_one()
+    closure_queue, _, closure2 = self._put_two_closures_and_get_one()
 
-    closure_queue.mark_failed(ValueError(), closure1)
+    closure_queue.mark_failed(ValueError())
 
     with self.assertRaises(ValueError):
       closure_queue.wait()
@@ -203,15 +233,22 @@
     closure_queue.wait()
 
   def testDoneRaiseError(self):
-    closure_queue, closure1, _ = self._put_two_closures_and_get_one()
-    closure_queue.get()
+    closure_queue, _, _ = self._put_two_closures_and_get_one()
 
     self.assertFalse(closure_queue.done())
-    closure_queue.mark_failed(ValueError(), closure1)
+    closure_queue.mark_failed(ValueError())
     with self.assertRaises(ValueError):
       closure_queue.done()
 
-  def _test_error_reporting_and_cancel_flow(self, call_wait):
+  def _set_error(self, closure_queue, closure, error):
+    try:
+      raise error
+    except Exception as e:  # pylint: disable=broad-except
+      nest.map_structure(lambda x: x._set_error(e),
+                         closure._output_remote_values)
+      closure_queue.mark_failed(e)
+
+  def _test_cancel_closure_when_error(self, call_wait):
     closure_queue, closure1, closure2 = self._put_two_closures_and_get_one()
     closure_queue.put(self._create_closure())
     closure_queue.get()
@@ -219,34 +256,37 @@
     self.assertEqual(closure_queue._inflight_closure_count, 2)
 
     # Simulating closure1 fails.
-    try:
-      raise ValueError('Some error.')
-    except ValueError as e:
-      nest.map_structure(lambda x: x._set_error(e),
-                         closure1._output_remote_values)
-      self.assertEqual(closure_queue._error_generation, 0)  # pylint: disable=g-assert-in-except
-      closure_queue.mark_failed(e, closure1)
-      self.assertEqual(closure_queue._error_generation, 1)
-    # At this moment, there are one inflight, nothing
-    # in queue (because the ones in queue should have been removed and
-    # cancelled).
-    self.assertTrue(closure_queue._queue.empty())
-    # Doesn't include out of generation closures.
+    self._set_error(closure_queue, closure1, ValueError('Some error.'))
+
+    # At this moment, there is one inflight closure and one in the queue.
+    self.assertEqual(closure_queue._queue.qsize(), 1)
     self.assertEqual(closure_queue._inflight_closure_count, 1)
 
-    coord = coordinator.Coordinator(clean_stop_exception_types=[])
     closure3 = self._create_closure()
 
-    with self.assertRaises(ValueError):
-      # Verifying `wait()` or `put()` raises even if one closure is in
-      # flight.
-      if call_wait:
-        closure_queue.wait()
-      else:
-        closure_queue.put(closure3)
-    # At this moment, there is one inflight, nothing in queue.
+    def fake_cancellation():
+      self._set_error(closure_queue, closure2,
+                      ValueError('Fake cancellation error.'))
+
+    def report_error():
+      # It should not report the fake cancellation error.
+      with self.assertRaisesRegex(ValueError, 'Some error.'):
+        # Verifying `wait()` or `put()` raises even if one closure is in
+        # flight.
+        if call_wait:
+          closure_queue.wait()
+        else:
+          closure_queue.put(closure3)
+
+    self._assert_one_unblock_the_other(fake_cancellation, report_error)
+
+    # Cancellation manager has been called.
+    self.assertTrue(closure_queue._cancellation_mgr.cancelled)
+
+    # At this moment, there are no inflight closures and nothing in the queue.
     self.assertTrue(closure_queue._queue.empty())
-    self.assertEqual(closure_queue._inflight_closure_count, 1)
+    self.assertEqual(closure_queue._inflight_closure_count, 0)
+    self.assertIsNone(closure_queue._error)
 
     # This asserts that closure1 has errored.
     with self.assertRaisesRegex(ValueError, 'Some error.'):
@@ -260,107 +300,36 @@
           'function.'):
         closure3._fetch_output_remote_values()
 
-    # Closure2 is inflight, so it shouldn't be ready.
+    # Closure2 was an inflight closure when it got cancelled.
     self.assertEqual(closure2._output_remote_values._status,
-                     client._RemoteValueStatus.NOT_READY)
-
-    # And `wait` should block because closure2 is not back yet.
-    self.assertFalse(closure_queue.wait(timeout=20))
-
-    # Now let's assume that closure2 isn't successful due to worker preemption,
-    # and now it's attempted to be put back, but ends up getting cancelled.
-    self.assertEqual(closure2._error_generation, 0)
-    self.assertEqual(closure_queue._error_generation, 1)
-    closure_queue.put_back(closure2)
-
-    with self.assertRaisesRegex(
-        client.FunctionRetryableError,
-        'The corresponding function is cancelled. Please reschedule the '
-        'function.'):
+                     client._RemoteValueStatus.READY)
+    with self.assertRaisesRegex(ValueError, 'Fake cancellation error.'):
       closure2._fetch_output_remote_values()
 
-    # At this moment, there is nothing inflight, and the queue is also empty
-    # (because closure2 should not be added back to the queue).
-    self.assertTrue(closure_queue._queue.empty())
-    self.assertEqual(closure_queue._inflight_closure_count, 0)
+    # This asserts that the queue has been restored to a clean state.
+    self.testBasic()
 
-    closure4 = self._create_closure()
+  def testWaitRaiseErrorAfterCancelClosure(self):
+    self._test_cancel_closure_when_error(call_wait=True)
 
-    e = threading.Event()
-
-    def get_fn():
-      with coord.stop_on_exception():
-        # This should end up getting closure4, not closure2, because closure2
-        # has been cancelled and should not be got.
-        closure_got = closure_queue.get()
-        e.set()
-        self.assertEqual(closure_got._error_generation, 1)
-        self.assertEqual(closure_queue._error_generation, 1)
-        self.assertIs(closure4, closure_got)
-        self.assertIsNot(closure2, closure_got)
-
-    t = threading.Thread(target=get_fn)
-    t.start()
-
-    time.sleep(10)
-
-    # Make sure `closure_got = closure_queue.get()` is unblocked as a result of
-    # `closure_queue.put(closure4)`.
-    self.assertFalse(e.is_set())
-    closure_queue.put(closure4)
-    self.assertTrue(e.wait())
-    coord.join([t])
-
-    self.assertEqual(closure_queue._inflight_closure_count, 1)
-    closure_queue.mark_finished(closure4)
-    # The queue is now cleared and nothing inflight.
-    self.assertEqual(closure_queue._inflight_closure_count, 0)
-    closure_queue.wait()
-
-  def testWaitRaiseErrorAfterAnErrorIsReported(self):
-    self._test_error_reporting_and_cancel_flow(call_wait=True)
-
-  def testPutRaiseErrorAfterAnErrorIsReported(self):
-    self._test_error_reporting_and_cancel_flow(call_wait=False)
+  def testPutRaiseErrorAfterCancelClosure(self):
+    self._test_cancel_closure_when_error(call_wait=False)
 
   def testStateIsRestoredAfterJoinIsCalled(self):
-    closure_queue, closure1, closure2 = self._put_two_closures_and_get_one()
-    closure_queue.get()
-    self.assertEqual(closure_queue._inflight_closure_count, 2)
-    closure_queue.mark_failed(ValueError('test error'), closure1)
+    closure_queue, _, _ = self._put_two_closures_and_get_one()
+    self.assertEqual(closure_queue._inflight_closure_count, 1)
+    closure_queue.mark_failed(ValueError('test error'))
     with self.assertRaises(ValueError):
       closure_queue.put(self._create_closure())
-    closure_queue.mark_failed(ValueError('test error'), closure2)
 
-    # closure2's error is previous generation so should not raise at this
-    # following put, and _error should have been cleared.
+    # Its error should have been cleared.
     self.assertIsNone(closure_queue._error)
     closure_queue.put(self._create_closure())
     self.assertIsNone(closure_queue._error)
 
-  def testStateIsRestoredAfterJoinIsCalled_WaitShouldReturn(self):
-    closure_queue, closure1, closure2 = self._put_two_closures_and_get_one()
-    closure_queue.put(self._create_closure())
-    closure_queue.get()  # got closure2
-    self.assertFalse(closure_queue._queue.empty())  # still has closure3
-    self.assertEqual(closure_queue._inflight_closure_count, 2)  # closure1,2
-    closure_queue.mark_failed(ValueError('test error'), closure1)
-    self.assertTrue(closure_queue._queue.empty())  # closure3 cancelled
-    self.assertEqual(closure_queue._inflight_closure_count, 1)
-    with self.assertRaises(ValueError):
-      closure_queue.wait()  # reports error from closure1
-
-    # `wait` should block because closure2 is not back yet, even if closure2
-    # was sent inflight before the error.
-    self.assertFalse(closure_queue.wait(timeout=20))
-    self.assertEqual(closure_queue._inflight_closure_count, 1)
-    closure_queue.mark_finished(closure2)
-    closure_queue.wait()  # wait should pass immediately
-    self.assertEqual(closure_queue._inflight_closure_count, 0)
-
   def testThreadSafey(self):
     thread_count = 10
-    queue = client._CoordinatedClosureQueue()
+    queue = client._CoordinatedClosureQueue(MockCancellationManager())
 
     # Each thread performs 20 queue actions: 10 are `put_back` and 10 are
     # `mark_finished`.
@@ -372,7 +341,7 @@
         if i % 2 == 0:
           queue.put_back(closure)
         else:
-          queue.mark_finished(closure)
+          queue.mark_finished()
 
     threads = [threading.Thread(target=func) for i in range(thread_count)]
     for t in threads:
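
The queue-level tests above share one pattern: build `_CoordinatedClosureQueue`
with a `MockCancellationManager`, force an error, and check that the queue
re-raises the error and invokes `start_cancel()`. A condensed, hedged sketch of
that pattern (illustrative only; plain Python callables suffice for `Closure`
here):

    queue = client._CoordinatedClosureQueue(MockCancellationManager())
    closure = client.Closure(lambda: 1.0, MockCancellationManager())
    queue.put(closure)
    queue.get()                               # the closure is now in flight
    queue.mark_failed(ValueError('error'))    # records the error
    try:
      queue.put(client.Closure(lambda: 1.0, MockCancellationManager()))
    except ValueError:
      pass                                    # put() re-raises the stored error
    assert queue._cancellation_mgr.cancelled  # start_cancel() was invoked
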
diff --git a/tensorflow/python/distribute/client/parameter_server_client_test.py b/tensorflow/python/distribute/client/parameter_server_client_test.py
index db22a47..32c7ff9 100644
--- a/tensorflow/python/distribute/client/parameter_server_client_test.py
+++ b/tensorflow/python/distribute/client/parameter_server_client_test.py
@@ -19,7 +19,10 @@
 from __future__ import division
 from __future__ import print_function
 
+import functools
+import threading
 from absl import logging
+
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.distribute import multi_worker_test_base
 from tensorflow.python.distribute import sharded_variable
@@ -40,6 +43,48 @@
 from tensorflow.python.training.server_lib import ClusterSpec
 
 
+class ErrorReportingThread(threading.Thread):
+
+  error = None
+
+  def __init__(self, *args, **kwargs):
+    assert "target" in kwargs
+    target = kwargs["target"]
+
+    @functools.wraps(target)
+    def wrapped_target(*args, **kwargs):
+      try:
+        return target(*args, **kwargs)
+      except Exception as e:  # pylint: disable=broad-except
+        ErrorReportingThread.error = e
+
+    kwargs["target"] = wrapped_target
+    super(ErrorReportingThread, self).__init__(*args, **kwargs)
+
+
+class TestCaseWithErrorReportingThread(test.TestCase):
+
+  @classmethod
+  def setUpClass(cls):
+    cls._threading_thread = threading.Thread
+    threading.Thread = ErrorReportingThread
+    super(TestCaseWithErrorReportingThread, cls).setUpClass()
+
+  @classmethod
+  def tearDownClass(cls):
+    super(TestCaseWithErrorReportingThread, cls).tearDownClass()
+    threading.Thread = cls._threading_thread
+
+  def setUp(self):
+    ErrorReportingThread.error = None
+    super(TestCaseWithErrorReportingThread, self).setUp()
+
+  def tearDown(self):
+    super(TestCaseWithErrorReportingThread, self).tearDown()
+    if ErrorReportingThread.error:
+      raise ErrorReportingThread.error  # pylint: disable=raising-bad-type
+
+
 def make_client(num_workers, num_ps):
   # TODO(rchao): Test the internal rpc_layer version.
   cluster_def = multi_worker_test_base.create_in_process_cluster(
@@ -52,7 +97,7 @@
   return parameter_server_client.ParameterServerClient(cluster_resolver)
 
 
-class ParameterServerClientTest(test.TestCase):
+class ParameterServerClientTest(TestCaseWithErrorReportingThread):
 
   @classmethod
   def setUpClass(cls):
@@ -304,7 +349,7 @@
     self.assertEqual(var_sum, 10.0)
 
 
-class ErrorReportingTest(test.TestCase):
+class ErrorReportingTest(TestCaseWithErrorReportingThread):
 
   @classmethod
   def setUpClass(cls):
@@ -344,8 +389,16 @@
       while True:
         self.client.schedule(self._normal_function)
 
+  def testScheduleRaiseErrorWithMultipleFailure(self):
+    for _ in range(3):
+      self.client.schedule(self._normal_function)
+    self.client.schedule(self._error_function)
+    with self.assertRaises(errors.InvalidArgumentError):
+      while True:
+        self.client.schedule(self._error_function)
+    self.client.join()
+
   def testErrorWillbeCleared(self):
-    self.skipTest("b/157597579")
     self.client.schedule(self._error_function)
     with self.assertRaises(errors.InvalidArgumentError):
       self.client.join()
@@ -356,7 +409,7 @@
     with self.assertRaises(errors.InvalidArgumentError):
       self.client.join()
 
-  def testFutureReturnError(self):
+  def testRemoteValueReturnError(self):
     result = self.client.schedule(self._error_function)
 
     with self.assertRaises(errors.InvalidArgumentError):