Cancel in-flight closures when there is an error.
PiperOrigin-RevId: 324542620
Change-Id: I1d6cddf8130df74f00ce7b0a3b6b84f553990e78
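
This change replaces the closure queue's per-closure error-generation
bookkeeping with a single `cancellation.CancellationManager` shared by all
closures: on a non-retryable error the queue calls `start_cancel()`, which
aborts queued and in-flight executions alike. A minimal sketch of the
manager's behavior, assuming a TF 2.x build where the private
`tensorflow.python.eager.cancellation` module is importable:

    import tensorflow as tf
    from tensorflow.python.eager import cancellation

    @tf.function
    def add_one(x):
      return x + 1.0

    concrete = add_one.get_concrete_function(tf.TensorSpec([], tf.float32))
    mgr = cancellation.CancellationManager()
    cancelable = mgr.get_cancelable_function(concrete)

    print(cancelable(tf.constant(1.0)))  # runs normally before cancellation
    mgr.start_cancel()
    # Executions in flight at (or started after) start_cancel() fail with
    # tf.errors.CancelledError instead of running to completion.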
diff --git a/tensorflow/python/distribute/client/BUILD b/tensorflow/python/distribute/client/BUILD
index 0f7b7df..35d8de9 100644
--- a/tensorflow/python/distribute/client/BUILD
+++ b/tensorflow/python/distribute/client/BUILD
@@ -32,6 +32,7 @@
"//tensorflow/python/distribute:input_lib",
"//tensorflow/python/distribute:parameter_server_strategy_v2",
"//tensorflow/python/distribute:values",
+ "//tensorflow/python/eager:cancellation",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:def_function",
"//tensorflow/python/eager:executor",
diff --git a/tensorflow/python/distribute/client/client.py b/tensorflow/python/distribute/client/client.py
index 533d5f1..7bef5e2 100644
--- a/tensorflow/python/distribute/client/client.py
+++ b/tensorflow/python/distribute/client/client.py
@@ -31,15 +31,19 @@
import weakref
from absl import logging
from six.moves import queue
+
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import parameter_server_strategy_v2
from tensorflow.python.distribute.client import metric_utils
+from tensorflow.python.eager import cancellation
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import executor
from tensorflow.python.eager import function as tf_function
from tensorflow.python.eager import remote
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import ops
@@ -247,20 +251,28 @@
self._values = tuple(values)
+def _select_worker_slice(worker_id, structured):
+ """Selects the worker slice of each of the items in `structured`."""
+
+ def _get(x):
+ return x._values[worker_id] if isinstance(x, PerWorkerValues) else x # pylint: disable=protected-access
+
+ return nest.map_structure(_get, structured)
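+
+# For illustration (hypothetical values; assumes `PerWorkerValues(values)`
+# stores `values` as shown above):
+#   pwv = PerWorkerValues(["it_w0", "it_w1"])
+#   _select_worker_slice(0, (pwv, 3))  # -> ("it_w0", 3)
+#   _select_worker_slice(1, (pwv, 3))  # -> ("it_w1", 3)
+# `PerWorkerValues` entries are sliced by `worker_id`; everything else passes
+# through `nest.map_structure` unchanged.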
+
+
class Closure(object):
"""Hold a function to be scheduled and its arguments."""
- def __init__(self, function, args=None, kwargs=None):
+ def __init__(self, function, cancellation_mgr, args=None, kwargs=None):
if not callable(function):
raise ValueError("Function passed to `Client.schedule` must be a "
"callable object.")
self._args = args or ()
self._kwargs = kwargs or {}
- self._function = function
if isinstance(function, def_function.Function):
- replica_args = self._select_worker_slice(0, self._args)
- replica_kwargs = self._select_worker_slice(0, self._kwargs)
+ replica_args = _select_worker_slice(0, self._args)
+ replica_kwargs = _select_worker_slice(0, self._kwargs)
# Note: no need to handle function registration failure since this kind of
# failure will not raise exceptions as designed in the runtime. The client
@@ -276,25 +288,22 @@
concrete_function = function.get_concrete_function(
*nest.map_structure(_maybe_as_type_spec, replica_args),
**nest.map_structure(_maybe_as_type_spec, replica_kwargs))
+ self._function = cancellation_mgr.get_cancelable_function(
+ concrete_function)
self._output_remote_values = nest.map_structure(
lambda x: RemoteValue(self, x), concrete_function.structured_outputs)
elif isinstance(function, tf_function.ConcreteFunction):
+ self._function = cancellation_mgr.get_cancelable_function(function)
self._output_remote_values = nest.map_structure(
lambda x: RemoteValue(self, x), function.structured_outputs)
else:
# Regular python functions.
+ self._function = function
# TODO(yuefengz): maybe we should trace python functions if their inputs
# are Python primitives, tensors and composite tensors.
self._output_remote_values = RemoteValue(self, None)
- def _select_worker_slice(self, worker_id, structured):
- """Selects the worker slice of each of the items in `structured`."""
-
- def _get(x):
- return x._values[worker_id] if isinstance(x, PerWorkerValues) else x # pylint: disable=protected-access
-
- return nest.map_structure(_get, structured)
-
def _fetch_output_remote_values(self):
"""Temporary method used to sync the scheduler."""
# It will do nothing if there is no return value.
@@ -319,9 +328,8 @@
Args:
worker: a `Worker` object.
"""
- replica_args = self._select_worker_slice(worker.worker_index, self._args)
- replica_kwargs = self._select_worker_slice(worker.worker_index,
- self._kwargs)
+ replica_args = _select_worker_slice(worker.worker_index, self._args)
+ replica_kwargs = _select_worker_slice(worker.worker_index, self._kwargs)
e = (
_maybe_get_error_and_rebuild_remote_values(worker, replica_args) or
@@ -350,8 +358,7 @@
This class is thread-safe.
"""
- def __init__(self):
-
+ def __init__(self, cancellation_mgr):
# `self._inflight_closure_count` tracks the number of inflight closures.
# Once an error occurs, subsequently arriving closures (whether queued or
# inflight) are cancelled rather than put back on the queue.
@@ -359,17 +366,26 @@
self._inflight_closure_count = 0
self._queue_lock = threading.Lock()
+
# Condition indicating that all pending closures (either queued or inflight)
# have been processed, failed, or cancelled.
self._stop_waiting_condition = threading.Condition(self._queue_lock)
+
# Condition indicating that an item becomes available in queue (not empty).
self._closures_queued_condition = threading.Condition(self._queue_lock)
+
# Condition indicating that a queue slot becomes available (not full).
# Note that even with "infinite" queue size, there is still a "practical"
# size limit for the queue depending on host memory capacity, and thus the
# queue will eventually become full with a lot of enqueued closures.
self._queue_free_slot_condition = threading.Condition(self._queue_lock)
+ # Condition indicating that there are no inflight closures.
+ self._no_inflight_closure_condition = threading.Condition(self._queue_lock)
+
+ # Used to cancel in-flight closures.
+ self._cancellation_mgr = cancellation_mgr
+
if _CLOSURE_QUEUE_MAX_SIZE <= 0:
logging.warning(
"In ParameterServerClient, creating an infinite closure queue can "
@@ -377,31 +393,6 @@
self._queue = queue.Queue(maxsize=_CLOSURE_QUEUE_MAX_SIZE)
self._error = None
- # Error generation is a counter that helps us track whether a closure
- # should be cancelled when it is being put back to `self._queue`. It works
- # in the following way:
- # 1) Error generation starts off at 0.
- # 2) When a worker thread calls `get()`, the closure's error generation
- # is copied from this queue's error generation.
- # 3) If any worker thread experiences an error that's categorized as a
- # non-retryable error, the queue's error will be set, error generation
- # increments by 1, and the queue is cleared (with the closures marked
- # with cancelled error), so other worker threads stop getting closures
- # from the queue. Worker preemption is categorized as a retryable error.
- # 4) At this point, if `put()` or `wait()` is called (usually by the main
- # thread via `schedule` and `join`), the error is raised through that
- # call.
- # 5) The closures that are inflight, i.e. that are being executed remotely,
- # will not be aware of such error event. If the worker that's executing
- # the closure happens to be interrupted, the closure should not be put
- # back to the queue, and be cancelled with error instead. Checking the
- # generation id of the closure and queue is how the worker thread tells
- # whether the closure should be put back. Likewise for `mark_finished`
- # and `mark_failed`: if the arriving closure is considered out of
- # generation in those two methods, it is simply discarded (the inflight
- # closure count still decrements).
- self._error_generation = 0
-
# The following is a lock to make sure that while `wait` is called and before
# it returns, no `put` can be executed, since `wait` would not know what to
# do with newly put closures. This lock adds a cutoff
@@ -415,11 +406,14 @@
# of the code.
self._put_wait_lock = threading.Lock()
- def _cancel_closures_in_queue(self):
+ def _cancel_all_closures(self):
"""Clears the queue and sets remaining closures cancelled error.
This method expects self._queue_lock to be held prior to entry.
"""
+ self._cancellation_mgr.start_cancel()
+ while self._inflight_closure_count > 0:
+ self._no_inflight_closure_condition.wait()
while True:
try:
closure = self._queue.get(block=False)
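
A standard-library sketch of the wait/notify pairing `_cancel_all_closures`
relies on (simplified; the real method is entered with `_queue_lock` already
held, and `Condition.wait` releases that lock while blocked, so worker threads
can still acquire it to decrement the count):

    import threading

    lock = threading.Lock()
    no_inflight = threading.Condition(lock)
    inflight_count = 0

    def mark_done():
      # What mark_finished/mark_failed/put_back do with the count.
      global inflight_count
      with lock:
        inflight_count -= 1
        if inflight_count == 0:
          no_inflight.notify_all()

    def cancel_all():
      with lock:  # the real code already holds the lock at this point
        while inflight_count > 0:
          no_inflight.wait()  # releases `lock` while blocked
        # Nothing is inflight anymore; safe to drain the queue here.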
@@ -437,8 +431,8 @@
This method expects self._queue_lock to be held prior to entry.
"""
if self._error:
+ self._cancel_all_closures()
try:
- self._cancel_closures_in_queue()
raise self._error # pylint: disable=raising-bad-type
finally:
self._error = None
@@ -466,16 +460,17 @@
return None
closure = self._queue.get(block=False)
self._queue_free_slot_condition.notify()
- closure._error_generation = self._error_generation # pylint: disable=protected-access
self._inflight_closure_count += 1
return closure
- def mark_finished(self, closure):
+ def mark_finished(self):
"""Let the queue know that a closure has been successfully executed."""
with self._queue_lock:
if self._inflight_closure_count < 1:
raise AssertionError("There is no inflight closures to mark_finished.")
self._inflight_closure_count -= 1
+ if self._inflight_closure_count == 0:
+ self._no_inflight_closure_condition.notifyAll()
if self._queue.empty() and self._inflight_closure_count == 0:
self._stop_waiting_condition.notifyAll()
@@ -484,17 +479,15 @@
with self._queue_lock:
if self._inflight_closure_count < 1:
raise AssertionError("There is no inflight closures to put_back.")
- self._inflight_closure_count -= 1
- if closure._error_generation < self._error_generation: # pylint: disable=protected-access
- # If the closure to put back is out of generation, cancel the closure
- # and ignore it.
- logging.info("Function %r should no longer be dispatched; marking "
- "as cancelled.")
+ if self._error:
closure._set_output_remote_values_cancelled() # pylint: disable=protected-access
- return
- self._queue_free_slot_condition.wait_for(lambda: not self._queue.full())
- self._queue.put(closure, block=False)
- self._closures_queued_condition.notify()
+ else:
+ self._queue_free_slot_condition.wait_for(lambda: not self._queue.full())
+ self._queue.put(closure, block=False)
+ self._closures_queued_condition.notify()
+ self._inflight_closure_count -= 1
+ if self._inflight_closure_count == 0:
+ self._no_inflight_closure_condition.notifyAll()
def wait(self, timeout=None):
"""Wait for all closures to be finished before returning.
@@ -516,22 +509,18 @@
self._raise_if_error()
return True
- def mark_failed(self, e, closure):
+ def mark_failed(self, e):
"""Sets error and unblocks any wait() call."""
with self._queue_lock:
# TODO(yuefengz): maybe record all failures and give users more
# information?
if self._inflight_closure_count < 1:
raise AssertionError("There is no inflight closures to mark_failed.")
+ if self._error is None:
+ self._error = e
self._inflight_closure_count -= 1
- if closure._error_generation < self._error_generation: # pylint: disable=protected-access
- # If the closure to mark fail is out of generation, simply ignore it
- # (with the actual error associated with the closure preserved).
- return
- assert self._error is None
- self._error = e
- self._error_generation += 1
- self._cancel_closures_in_queue()
+ if self._inflight_closure_count == 0:
+ self._no_inflight_closure_condition.notifyAll()
self._stop_waiting_condition.notifyAll()
def done(self):
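
Taken together, a hedged sketch of the queue's contract as the worker threads
in this file exercise it (method names as above; remote execution simplified
to a local call):

    queue = _CoordinatedClosureQueue(cancellation.CancellationManager())

    queue.put(closure)       # re-raises a previously recorded error, if any
    c = queue.get()          # blocks until available; marks the closure inflight
    try:
      c._function()          # executed on a remote worker in the real code
      queue.mark_finished()  # decrements the inflight count; may wake wait()
    except Exception as e:   # pylint: disable=broad-except
      queue.mark_failed(e)   # records the first error; wait()/put() re-raise it
    queue.wait()             # returns once queue is empty and nothing inflight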
@@ -678,7 +667,7 @@
# TODO(yuefengz): we don't have to materialize results every step.
with metric_utils.monitored_timer("remote_value_fetch"):
closure._fetch_output_remote_values() # pylint: disable=protected-access
- self._cluster._closure_queue.mark_finished(closure) # pylint: disable=protected-access
+ self._cluster._closure_queue.mark_finished() # pylint: disable=protected-access
except Exception as e: # pylint: disable=broad-except
logging.error(
"/job:worker/task:%d encountered the following error when processing "
@@ -686,7 +675,7 @@
nest.map_structure(
lambda x: x._set_error(e), # pylint: disable=protected-access
closure._output_remote_values) # pylint: disable=protected-access
- self._cluster._closure_queue.mark_failed(e, closure) # pylint: disable=protected-access
+ self._cluster._closure_queue.mark_failed(e) # pylint: disable=protected-access
def _process_queue(self):
while True:
@@ -710,7 +699,8 @@
# the same worker such as creating resources, setting resources' aborted
# status, and executing closures happen on the same thread. This allows us
# to have simpler logic of concurrency.
- closure = Closure(function=function, args=args, kwargs=kwargs)
+ closure = Closure(
+ function, self._cluster._cancellation_mgr, args=args, kwargs=kwargs) # pylint: disable=protected-access
resource_remote_value = closure._output_remote_values # pylint: disable=protected-access
self._register_resource(resource_remote_value)
@@ -775,7 +765,8 @@
protocol=cluster_resolver.rpc_layer,
cluster_device_filters=device_filters)
- self._closure_queue = _CoordinatedClosureQueue()
+ self._cancellation_mgr = cancellation.CancellationManager()
+ self._closure_queue = _CoordinatedClosureQueue(self._cancellation_mgr)
self.failure_handler = WorkerPreemptionHandler(context.get_server_def())
worker_device_strings = [
"/job:worker/replica:0/task:%d" % i for i in range(self._num_workers)
@@ -796,7 +787,8 @@
Returns:
A structure of `RemoteValue` object.
"""
- closure = Closure(function=function, args=args, kwargs=kwargs)
+ closure = Closure(
+ function, self._cancellation_mgr, args=args, kwargs=kwargs)
self._closure_queue.put(closure)
return closure._output_remote_values # pylint: disable=protected-access
@@ -893,8 +885,8 @@
function execution to finish and retrieve its output from the remote worker.
`schedule` guarantees that `fn` will be executed on a worker at least once;
- it could be more than once if a worker fails and restarts in the middle of
- function scheduling. Note that since worker can fail at any point when
+ it could be more than once if its corresponding worker fails in the middle
+ of its execution. Note that since a worker can fail at any point when
executing the function, it is possible that the function is partially
executed, but `Client` guarantees that in those events, the function will
eventually be fully executed, possibly on a different worker that is
@@ -904,14 +896,12 @@
by raising any one of those errors, and clear the errors collected so far.
There are two implications when this happens: 1) user should call `schedule`
with `fn` again to re-schedule, and 2) some of the previously scheduled
- functions may no longer execute. User can call `fetch` on the returned
+ functions may not have been executed. Users can call `fetch` on the returned
`RemoteValue` to inspect if they have executed, failed, or been cancelled, and
reschedule the corresponding function if needed.
- When `schedule` raises, it is possible that there are still functions being
- executed on workers, at the time `schedule` raises. When this happens, users
- can call `join` again to wait for all pending async function execution to
- finish, and bring the cluster into a consistent state.
+ When `schedule` raises, it is guaranteed that no function is still being
+ executed.
At this time, there is no support for worker assignment for function
execution, or priority of the workers.
@@ -940,7 +930,8 @@
# TODO(b/160702436): Invoke `strategy.run` for user's function so it enters
# a `ReplicaContext` in a logically correct way.
with distribute_lib.ReplicaContext(
- self._strategy, replica_id_in_sync_group=0):
+ self._strategy,
+ replica_id_in_sync_group=constant_op.constant(0, dtypes.int32)):
with self._translate_parameter_server_failure():
return self.cluster.schedule(fn, args=args, kwargs=kwargs)
@@ -949,17 +940,14 @@
If any previously scheduled function raises an error, `join` will fail by
raising any one of those errors, and clear the errors collected so far. If
- this happens, some of the previously scheduled functions may no longer
- execute. Users can call `fetch` on the returned `RemoteValue` to inspect if
+ this happens, some of the previously scheduled functions may not have been
+ executed. Users can call `fetch` on the returned `RemoteValue` to inspect if
they have executed, failed, or been cancelled. If some that have been cancelled
need to be rescheduled, users should call `schedule` with the function
again.
- Note: `join` raises an exception as soon as the client detects one, and this
- means it is possible that there are still functions being executed on
- workers, at the time `join` raises. When this happens, users can call `join`
- again to wait for all pending async function execution to finish, and bring
- the cluster into a consistent state.
+ When `join` returns or raises, it is guaranteed that no function is still
+ being executed.
Raises:
Exception: one of the exceptions caught by the client by any previously
@@ -976,6 +964,9 @@
If any previously scheduled function raises an error, `done` will fail by
raising any one of those errors.
+
+ When `done` returns True or raises, it is guaranteed that no function is
+ still being executed.
"""
return self.cluster.done()
@@ -1091,7 +1082,7 @@
raise
-class _PerWorkerDistributedDataset(object): # pylint: disable=protected-access
+class _PerWorkerDistributedDataset(object):
"""Represents worker-distributed datasets created from dataset function."""
def __init__(self, dataset_fn, input_workers, client):
@@ -1107,13 +1098,13 @@
if isinstance(dataset_fn, def_function.Function):
with variable_scope.variable_creator_scope(disallow_variable_creation):
- self._dataset_fn = dataset_fn.get_concrete_function()
- elif isinstance(dataset_fn, tf_function.ConcreteFunction):
- self._dataset_fn = dataset_fn
- else:
+ dataset_fn = dataset_fn.get_concrete_function()
+ elif not isinstance(dataset_fn, tf_function.ConcreteFunction):
with variable_scope.variable_creator_scope(disallow_variable_creation):
- self._dataset_fn = def_function.function(
- dataset_fn).get_concrete_function()
+ dataset_fn = def_function.function(dataset_fn).get_concrete_function()
+ self._dataset_fn = (
+ client.cluster._cancellation_mgr.get_cancelable_function( # pylint: disable=protected-access
+ dataset_fn))
self._input_workers = input_workers
self._client = client
self._element_spec = None
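
From the caller's side, the updated guarantee reads as in this hypothetical
sketch (`train_step` and `iterator` are illustrative names; `Client` and
`FunctionRetryableError` are as defined in this file):

    client = Client(strategy)  # strategy: a ParameterServerStrategyV2
    result = client.schedule(train_step, args=(iterator,))
    try:
      client.join()
    except Exception:  # an earlier closure failed
      # Per the updated docstrings, nothing is still executing at this point:
      # every scheduled function has finished, failed, or been cancelled.
      try:
        result.fetch()
      except FunctionRetryableError:
        # Cancelled before completion; safe to reschedule.
        result = client.schedule(train_step, args=(iterator,))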
diff --git a/tensorflow/python/distribute/client/client_test.py b/tensorflow/python/distribute/client/client_test.py
index 1215240..459633a 100644
--- a/tensorflow/python/distribute/client/client_test.py
+++ b/tensorflow/python/distribute/client/client_test.py
@@ -30,22 +30,34 @@
from tensorflow.python.util import nest
+class MockCancellationManager(object):
+
+ def __init__(self):
+ self.cancelled = False
+
+ def start_cancel(self):
+ self.cancelled = True
+
+ def get_cancelable_function(self, func):
+ return func
+
+
class CoordinatedClosureQueueTest(test.TestCase):
def testBasic(self):
- queue = client._CoordinatedClosureQueue()
+ queue = client._CoordinatedClosureQueue(MockCancellationManager())
closure1 = self._create_closure()
queue.put(closure1)
self.assertIs(closure1, queue.get())
self.assertFalse(queue.done())
queue.put_back(closure1)
self.assertEqual(closure1, queue.get())
- queue.mark_finished(closure1)
+ queue.mark_finished()
self.assertTrue(queue.done())
queue.wait()
def testProcessAtLeastOnce(self):
- closure_queue = client._CoordinatedClosureQueue()
+ closure_queue = client._CoordinatedClosureQueue(MockCancellationManager())
labels = ['A', 'B', 'C', 'D', 'E']
processed_count = collections.defaultdict(int)
@@ -63,7 +75,7 @@
closure_queue.put_back(closure)
continue
closure._function()
- closure_queue.mark_finished(closure)
+ closure_queue.mark_finished()
def get_func(label):
@@ -76,7 +88,8 @@
return func
for label in labels:
- closure_queue.put(client.Closure(get_func(label)))
+ closure_queue.put(
+ client.Closure(get_func(label), MockCancellationManager()))
t1 = threading.Thread(target=process_queue, daemon=True)
t1.start()
t2 = threading.Thread(target=process_queue, daemon=True)
@@ -93,7 +106,7 @@
coord.join([t1, t2])
def testNotifyBeforeWait(self):
- closure_queue = client._CoordinatedClosureQueue()
+ closure_queue = client._CoordinatedClosureQueue(MockCancellationManager())
def func():
logging.info('func running')
@@ -102,10 +115,10 @@
def process_queue():
with coord.stop_on_exception():
- closure = closure_queue.get()
- closure_queue.mark_finished(closure)
+ closure_queue.get()
+ closure_queue.mark_finished()
- closure_queue.put(client.Closure(func))
+ closure_queue.put(client.Closure(func, MockCancellationManager()))
t = threading.Thread(target=process_queue)
t.start()
coord.join([t])
@@ -114,8 +127,30 @@
# doesn't time out.
closure_queue.wait()
+ def _assert_one_unblock_the_other(self, first_fn, second_fn):
+ """Asserts `second_fn` wouldn't return before `first_fn` is finished."""
+ first_fn_done = threading.Event()
+ second_fn_done = threading.Event()
+ coord = coordinator.Coordinator(clean_stop_exception_types=[])
+
+ def wrapped_first_fn():
+ with coord.stop_on_exception():
+ self.assertFalse(second_fn_done.is_set())
+ first_fn()
+ first_fn_done.set()
+
+ self.assertFalse(first_fn_done.is_set())
+ t = threading.Thread(target=wrapped_first_fn)
+ t.start()
+
+ second_fn()
+ self.assertTrue(first_fn_done.is_set())
+ second_fn_done.set()
+
+ coord.join([t])
+
def testWaitRaiseErrorAfterMarkFailure(self):
- closure_queue = client._CoordinatedClosureQueue()
+ closure_queue = client._CoordinatedClosureQueue(MockCancellationManager())
closure_queue.put(self._create_closure())
closure = closure_queue.get()
@@ -126,22 +161,17 @@
# all inflight closures are finished.
def mark_finished_fn():
- with coord.stop_on_exception():
- self.assertFalse(wait_finish_event.is_set())
- try:
- raise ValueError('Some error.')
- except ValueError as e:
- closure_queue.mark_failed(e, closure)
- wait_finish_event.wait()
+ try:
+ raise ValueError('Some error.')
+ except ValueError as e:
+ closure_queue.mark_failed(e)
- t = threading.Thread(target=mark_finished_fn)
- t.start()
+ def wait_fn():
+ with self.assertRaises(ValueError):
+ closure_queue.wait()
- with self.assertRaises(ValueError):
- closure_queue.wait()
- wait_finish_event.set()
+ self._assert_one_unblock_the_other(mark_finished_fn, wait_fn)
- coord.join([t])
self.assertTrue(closure_queue.done())
def _create_closure(self):
@@ -150,10 +180,10 @@
def some_function():
return 1.0
- return client.Closure(some_function)
+ return client.Closure(some_function, MockCancellationManager())
def _put_two_closures_and_get_one(self):
- closure_queue = client._CoordinatedClosureQueue()
+ closure_queue = client._CoordinatedClosureQueue(MockCancellationManager())
closure1 = self._create_closure()
closure_queue.put(closure1)
@@ -166,9 +196,9 @@
return closure_queue, closure1, closure2
def testPutRaiseError(self):
- closure_queue, closure1, closure2 = self._put_two_closures_and_get_one()
+ closure_queue, _, closure2 = self._put_two_closures_and_get_one()
- closure_queue.mark_failed(ValueError(), closure1)
+ closure_queue.mark_failed(ValueError())
with self.assertRaises(ValueError):
closure_queue.put(self._create_closure())
@@ -185,9 +215,9 @@
closure_queue.put(self._create_closure())
def testWaitRaiseError(self):
- closure_queue, closure1, closure2 = self._put_two_closures_and_get_one()
+ closure_queue, _, closure2 = self._put_two_closures_and_get_one()
- closure_queue.mark_failed(ValueError(), closure1)
+ closure_queue.mark_failed(ValueError())
with self.assertRaises(ValueError):
closure_queue.wait()
@@ -203,15 +233,22 @@
closure_queue.wait()
def testDoneRaiseError(self):
- closure_queue, closure1, _ = self._put_two_closures_and_get_one()
- closure_queue.get()
+ closure_queue, _, _ = self._put_two_closures_and_get_one()
self.assertFalse(closure_queue.done())
- closure_queue.mark_failed(ValueError(), closure1)
+ closure_queue.mark_failed(ValueError())
with self.assertRaises(ValueError):
closure_queue.done()
- def _test_error_reporting_and_cancel_flow(self, call_wait):
+ def _set_error(self, closure_queue, closure, error):
+ try:
+ raise error
+ except Exception as e: # pylint: disable=broad-except
+ nest.map_structure(lambda x: x._set_error(e),
+ closure._output_remote_values)
+ closure_queue.mark_failed(e)
+
+ def _test_cancel_closure_when_error(self, call_wait):
closure_queue, closure1, closure2 = self._put_two_closures_and_get_one()
closure_queue.put(self._create_closure())
closure_queue.get()
@@ -219,34 +256,37 @@
self.assertEqual(closure_queue._inflight_closure_count, 2)
# Simulate closure1 failing.
- try:
- raise ValueError('Some error.')
- except ValueError as e:
- nest.map_structure(lambda x: x._set_error(e),
- closure1._output_remote_values)
- self.assertEqual(closure_queue._error_generation, 0) # pylint: disable=g-assert-in-except
- closure_queue.mark_failed(e, closure1)
- self.assertEqual(closure_queue._error_generation, 1)
- # At this moment, there are one inflight, nothing
- # in queue (because the ones in queue should have been removed and
- # cancelled).
- self.assertTrue(closure_queue._queue.empty())
- # Doesn't include out of generation closures.
+ self._set_error(closure_queue, closure1, ValueError('Some error.'))
+
+ # At this moment, there is one closure inflight and one in the queue.
+ self.assertEqual(closure_queue._queue.qsize(), 1)
self.assertEqual(closure_queue._inflight_closure_count, 1)
- coord = coordinator.Coordinator(clean_stop_exception_types=[])
closure3 = self._create_closure()
- with self.assertRaises(ValueError):
- # Verifying `wait()` or `put()` raises even if one closure is in
- # flight.
- if call_wait:
- closure_queue.wait()
- else:
- closure_queue.put(closure3)
- # At this moment, there is one inflight, nothing in queue.
+ def fake_cancellation():
+ self._set_error(closure_queue, closure2,
+ ValueError('Fake cancellation error.'))
+
+ def report_error():
+ # It should not report the fake cancellation error.
+ with self.assertRaisesRegex(ValueError, 'Some error.'):
+ # Verifying `wait()` or `put()` raises even if one closure is in
+ # flight.
+ if call_wait:
+ closure_queue.wait()
+ else:
+ closure_queue.put(closure3)
+
+ self._assert_one_unblock_the_other(fake_cancellation, report_error)
+
+ # Cancellation manager has been called.
+ self.assertTrue(closure_queue._cancellation_mgr.cancelled)
+
+ # At this moment, nothing is inflight and the queue is empty.
self.assertTrue(closure_queue._queue.empty())
- self.assertEqual(closure_queue._inflight_closure_count, 1)
+ self.assertEqual(closure_queue._inflight_closure_count, 0)
+ self.assertIsNone(closure_queue._error)
# This asserts that closure1 has errored.
with self.assertRaisesRegex(ValueError, 'Some error.'):
@@ -260,107 +300,36 @@
'function.'):
closure3._fetch_output_remote_values()
- # Closure2 is inflight, so it shouldn't be ready.
+ # Closure2 was an inflight closure when it got cancelled.
self.assertEqual(closure2._output_remote_values._status,
- client._RemoteValueStatus.NOT_READY)
-
- # And `wait` should block because closure2 is not back yet.
- self.assertFalse(closure_queue.wait(timeout=20))
-
- # Now let's assume that closure2 isn't successful due to worker preemption,
- # and now it's attempted to be put back, but ends up getting cancelled.
- self.assertEqual(closure2._error_generation, 0)
- self.assertEqual(closure_queue._error_generation, 1)
- closure_queue.put_back(closure2)
-
- with self.assertRaisesRegex(
- client.FunctionRetryableError,
- 'The corresponding function is cancelled. Please reschedule the '
- 'function.'):
+ client._RemoteValueStatus.READY)
+ with self.assertRaisesRegex(ValueError, 'Fake cancellation error.'):
closure2._fetch_output_remote_values()
- # At this moment, there is nothing inflight, and the queue is also empty
- # (because closure2 should not be added back to the queue).
- self.assertTrue(closure_queue._queue.empty())
- self.assertEqual(closure_queue._inflight_closure_count, 0)
+ # This asserts that the queue is back to a clean state.
+ self.testBasic()
- closure4 = self._create_closure()
+ def testWaitRaiseErrorAfterCancelClosure(self):
+ self._test_cancel_closure_when_error(call_wait=True)
- e = threading.Event()
-
- def get_fn():
- with coord.stop_on_exception():
- # This should end up getting closure4, not closure2, because closure2
- # has been cancelled and should not be got.
- closure_got = closure_queue.get()
- e.set()
- self.assertEqual(closure_got._error_generation, 1)
- self.assertEqual(closure_queue._error_generation, 1)
- self.assertIs(closure4, closure_got)
- self.assertIsNot(closure2, closure_got)
-
- t = threading.Thread(target=get_fn)
- t.start()
-
- time.sleep(10)
-
- # Make sure `closure_got = closure_queue.get()` is unblocked as a result of
- # `closure_queue.put(closure4)`.
- self.assertFalse(e.is_set())
- closure_queue.put(closure4)
- self.assertTrue(e.wait())
- coord.join([t])
-
- self.assertEqual(closure_queue._inflight_closure_count, 1)
- closure_queue.mark_finished(closure4)
- # The queue is now cleared and nothing inflight.
- self.assertEqual(closure_queue._inflight_closure_count, 0)
- closure_queue.wait()
-
- def testWaitRaiseErrorAfterAnErrorIsReported(self):
- self._test_error_reporting_and_cancel_flow(call_wait=True)
-
- def testPutRaiseErrorAfterAnErrorIsReported(self):
- self._test_error_reporting_and_cancel_flow(call_wait=False)
+ def testPutRaiseErrorAfterCancelClosure(self):
+ self._test_cancel_closure_when_error(call_wait=False)
def testStateIsRestoredAfterJoinIsCalled(self):
- closure_queue, closure1, closure2 = self._put_two_closures_and_get_one()
- closure_queue.get()
- self.assertEqual(closure_queue._inflight_closure_count, 2)
- closure_queue.mark_failed(ValueError('test error'), closure1)
+ closure_queue, _, _ = self._put_two_closures_and_get_one()
+ self.assertEqual(closure_queue._inflight_closure_count, 1)
+ closure_queue.mark_failed(ValueError('test error'))
with self.assertRaises(ValueError):
closure_queue.put(self._create_closure())
- closure_queue.mark_failed(ValueError('test error'), closure2)
- # closure2's error is previous generation so should not raise at this
- # following put, and _error should have been cleared.
+ # Its error should have been cleared.
self.assertIsNone(closure_queue._error)
closure_queue.put(self._create_closure())
self.assertIsNone(closure_queue._error)
- def testStateIsRestoredAfterJoinIsCalled_WaitShouldReturn(self):
- closure_queue, closure1, closure2 = self._put_two_closures_and_get_one()
- closure_queue.put(self._create_closure())
- closure_queue.get() # got closure2
- self.assertFalse(closure_queue._queue.empty()) # still has closure3
- self.assertEqual(closure_queue._inflight_closure_count, 2) # closure1,2
- closure_queue.mark_failed(ValueError('test error'), closure1)
- self.assertTrue(closure_queue._queue.empty()) # closure3 cancelled
- self.assertEqual(closure_queue._inflight_closure_count, 1)
- with self.assertRaises(ValueError):
- closure_queue.wait() # reports error from closure1
-
- # `wait` should block because closure2 is not back yet, even if closure2
- # was sent inflight before the error.
- self.assertFalse(closure_queue.wait(timeout=20))
- self.assertEqual(closure_queue._inflight_closure_count, 1)
- closure_queue.mark_finished(closure2)
- closure_queue.wait() # wait should pass immediately
- self.assertEqual(closure_queue._inflight_closure_count, 0)
-
def testThreadSafety(self):
thread_count = 10
- queue = client._CoordinatedClosureQueue()
+ queue = client._CoordinatedClosureQueue(MockCancellationManager())
# Each thread performs 20 queue actions: 10 are `put_back` and 10 are
# `mark_finished`.
@@ -372,7 +341,7 @@
if i % 2 == 0:
queue.put_back(closure)
else:
- queue.mark_finished(closure)
+ queue.mark_finished()
threads = [threading.Thread(target=func) for i in range(thread_count)]
for t in threads:
diff --git a/tensorflow/python/distribute/client/parameter_server_client_test.py b/tensorflow/python/distribute/client/parameter_server_client_test.py
index db22a47..32c7ff9 100644
--- a/tensorflow/python/distribute/client/parameter_server_client_test.py
+++ b/tensorflow/python/distribute/client/parameter_server_client_test.py
@@ -19,7 +19,10 @@
from __future__ import division
from __future__ import print_function
+import functools
+import threading
from absl import logging
+
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import sharded_variable
@@ -40,6 +43,48 @@
from tensorflow.python.training.server_lib import ClusterSpec
+class ErrorReportingThread(threading.Thread):
+
+ error = None
+
+ def __init__(self, *args, **kwargs):
+ assert "target" in kwargs
+ target = kwargs["target"]
+
+ @functools.wraps(target)
+ def wrapped_target(*args, **kwargs):
+ try:
+ return target(*args, **kwargs)
+ except Exception as e: # pylint: disable=broad-except
+ ErrorReportingThread.error = e
+
+ kwargs["target"] = wrapped_target
+ super(ErrorReportingThread, self).__init__(*args, **kwargs)
+
+
+class TestCaseWithErrorReportingThread(test.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ cls._threading_thread = threading.Thread
+ threading.Thread = ErrorReportingThread
+ super(TestCaseWithErrorReportingThread, cls).setUpClass()
+
+ @classmethod
+ def tearDownClass(cls):
+ super(TestCaseWithErrorReportingThread, cls).tearDownClass()
+ threading.Thread = cls._threading_thread
+
+ def setUp(self):
+ ErrorReportingThread.error = None
+ super(TestCaseWithErrorReportingThread, self).setUp()
+
+ def tearDown(self):
+ super(TestCaseWithErrorReportingThread, self).tearDown()
+ if ErrorReportingThread.error:
+ raise ErrorReportingThread.error # pylint: disable=raising-bad-type
+
+
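+# Hedged usage sketch: with `threading.Thread` patched to
+# `ErrorReportingThread`, an exception raised on a background thread created
+# during a test is captured and re-raised in tearDown, failing the test
+# instead of only being logged. Illustrative only:
+#
+#   class DemoTest(TestCaseWithErrorReportingThread):
+#
+#     def test_background_failure_surfaces(self):
+#       def worker():
+#         raise ValueError('boom')  # swallowed by a plain Thread
+#       t = threading.Thread(target=worker)  # an ErrorReportingThread here
+#       t.start()
+#       t.join()
+#       # tearDown re-raises ValueError('boom'); the test fails loudly.
+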
def make_client(num_workers, num_ps):
# TODO(rchao): Test the internal rpc_layer version.
cluster_def = multi_worker_test_base.create_in_process_cluster(
@@ -52,7 +97,7 @@
return parameter_server_client.ParameterServerClient(cluster_resolver)
-class ParameterServerClientTest(test.TestCase):
+class ParameterServerClientTest(TestCaseWithErrorReportingThread):
@classmethod
def setUpClass(cls):
@@ -304,7 +349,7 @@
self.assertEqual(var_sum, 10.0)
-class ErrorReportingTest(test.TestCase):
+class ErrorReportingTest(TestCaseWithErrorReportingThread):
@classmethod
def setUpClass(cls):
@@ -344,8 +389,16 @@
while True:
self.client.schedule(self._normal_function)
+ def testScheduleRaiseErrorWithMultipleFailure(self):
+ for _ in range(3):
+ self.client.schedule(self._normal_function)
+ self.client.schedule(self._error_function)
+ with self.assertRaises(errors.InvalidArgumentError):
+ while True:
+ self.client.schedule(self._error_function)
+ self.client.join()
+
def testErrorWillbeCleared(self):
- self.skipTest("b/157597579")
self.client.schedule(self._error_function)
with self.assertRaises(errors.InvalidArgumentError):
self.client.join()
@@ -356,7 +409,7 @@
with self.assertRaises(errors.InvalidArgumentError):
self.client.join()
- def testFutureReturnError(self):
+ def testRemoteValueReturnError(self):
result = self.client.schedule(self._error_function)
with self.assertRaises(errors.InvalidArgumentError):