blob: 3829a431f51a9099dbc1b908d6891fbc57d8f601 [file] [log] [blame]
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Module for `ClusterCoordinator` and relevant cluster-worker related library.
This is currently under development and the API is subject to change.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import contextlib
import enum
import functools
import os
import re
import sys
import threading
import time
import weakref
from six.moves import queue
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import parameter_server_strategy_v2
from tensorflow.python.distribute.coordinator import metric_utils
from tensorflow.python.eager import cancellation
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import executor
from tensorflow.python.eager import function as tf_function
from tensorflow.python.framework import errors
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
# Maximum time for failed worker to come back is 1 hour
_WORKER_MAXIMUM_RECOVERY_SEC = 3600
# Maximum size for queued closures, "infinite" if set to 0.
# When the maximum queue size is reached, further schedule calls will become
# blocking until some previously queued closures are executed on workers.
# Note that using an "infinite" queue size can take a non-trivial portion of
# memory, and even lead to coordinator OOM. Modify the size to a smaller value
# for coordinator with constrained memory resource (only recommended for
# advanced users). Also used in unit tests to ensure the correctness when the
# queue is full.
_CLOSURE_QUEUE_MAX_SIZE = 256 * 1024
# RPC error message from PS
_RPC_ERROR_FROM_PS = "GRPC error information from remote target /job:ps"
# InvalidArgumentError (unknown device) will not have "GRPC error..." string.
_JOB_WORKER_STRING_IDENTIFIER = "/job:worker"
class _RemoteValueStatus(enum.Enum):
"""The status of a `RemoteValue` object.
A `RemoteValue` object can have three states:
1) not ready: no value, no non-retryable error and not aborted;
2) aborted: i.e. the execution of function was aborted because of task
failure, but can be retried;
3) ready: i.e. has value or has non-tryable error;
The initial state of a `RemoteValue` is "not ready". When its corresponding
closure has
been executed at least once, it will become aborted or ready. The state
transitions are:
1) not ready -> 2) aborted:
when the corresponding closure is aborted due to worker failure, and the
worker failure is not immediately handled.
1) not ready -> 3) ready:
when the corresponding closure has been executed successfully.
2) aborted -> 3) ready:
when the `RemoteValue` is rebuilt by rerunning the corresponding closure
and the closure has been executed successfully.
3) ready -> 2) aborted:
when the corresponding closure had been executed successfully but later
the corresponding remote worker failed. This is currently only implemented
for resource `RemoteValue` like iterators.
"""
NOT_READY = "NOT_READY"
ABORTED = "ABORTED"
READY = "READY"
@tf_export("distribute.experimental.coordinator.RemoteValue", v1=[])
class RemoteValue(object):
  """An asynchronously available value of a scheduled function.

  This class is used as the return value of
  `tf.distribute.experimental.coordinator.ClusterCoordinator.schedule` where
  the underlying value becomes available at a later time once the function has
  been executed.

  Using `tf.distribute.experimental.coordinator.RemoteValue` as an input to
  a subsequent function scheduled with
  `tf.distribute.experimental.coordinator.ClusterCoordinator.schedule` is
  currently not supported.

  Example:

  ```python
  strategy = tf.distribute.experimental.ParameterServerStrategy(
      cluster_resolver=...)
  coordinator = (
      tf.distribute.experimental.coordinator.ClusterCoordinator(strategy))

  with strategy.scope():
    v1 = tf.Variable(initial_value=0.0)
    v2 = tf.Variable(initial_value=1.0)

  @tf.function
  def worker_fn():
    v1.assign_add(0.1)
    v2.assign_sub(0.2)
    return v1.read_value() / v2.read_value()

  result = coordinator.schedule(worker_fn)
  # Note that `fetch()` gives the actual result instead of a `tf.Tensor`.
  assert result.fetch() == 0.125

  for _ in range(10):
    # `worker_fn` will be run on arbitrary workers that are available. The
    # `result` value will be available later.
    result = coordinator.schedule(worker_fn)
  ```
  """

  def fetch(self):
    """Wait for the result of `RemoteValue` to be ready and return the result.

    This makes the value concrete by copying the remote value to local.

    Returns:
      The actual output of the `tf.function` associated with this `RemoteValue`,
      previously returned by a
      `tf.distribute.experimental.coordinator.ClusterCoordinator.schedule` call.
      This can be a single value, or a structure of values, depending on the
      output of the `tf.function`.

    Raises:
      tf.errors.CancelledError: If the function that produces this `RemoteValue`
        is aborted or cancelled due to failure.
    """
    raise NotImplementedError("Must be implemented in subclasses.")
class RemoteValueImpl(RemoteValue):
  """Implementation of `RemoteValue`.

  State changes go through `_status_available_event`. The `_set_*` methods
  update the status, values and error, then set the event. `_get_values`,
  `_get_error` and `fetch` wait on that event before reading.
  """

  def __init__(self, closure, type_spec):  # pylint: disable=super-init-not-called
    """Initializes a `RemoteValueImpl`.

    Args:
      closure: The closure from which the `RemoteValue` is created.
      type_spec: The type spec for this `RemoteValue` which is used to trace
        functions that take this `RemoteValue` as input.
    """
    self._closure = closure
    self._type_spec = type_spec
    self._values = None
    # Lazily computed numpy copy of `_values`, filled in by `fetch`. Every
    # state change below resets it so a stale copy is never returned.
    self._fetched_numpys = None
    self._error = None
    self._status_available_event = threading.Event()
    self._status = _RemoteValueStatus.NOT_READY

  def _set_aborted(self):
    self._status = _RemoteValueStatus.ABORTED
    self._values = None
    self._error = None
    self._fetched_numpys = None
    # Wake up any waiting thread. The event is cleared again only by
    # `_rebuild_on`, right before the closure is re-executed.
    self._status_available_event.set()

  def _rebuild_on(self, worker):
    self._status_available_event.clear()
    # TODO(yuefengz): we may need to rebuild its inputs as well.
    self._closure.execute_on(worker)

  def _set_values(self, tensors):
    self._status = _RemoteValueStatus.READY
    self._values = tensors
    self._error = None
    # The values may be new (e.g. after a rebuild following an earlier
    # `fetch`), so drop any previously cached numpy copy.
    self._fetched_numpys = None
    self._status_available_event.set()

  def _set_error(self, exception):
    self._status = _RemoteValueStatus.READY
    self._values = None
    self._error = exception
    self._fetched_numpys = None
    self._status_available_event.set()

  def _get_values(self):
    self._status_available_event.wait()
    return self._values

  def _get_error(self):
    self._status_available_event.wait()
    return self._error

  def fetch(self):
    self._status_available_event.wait()
    if self._status is _RemoteValueStatus.ABORTED:
      raise errors.CancelledError(
          None, None,
          "The corresponding function is aborted. Please reschedule the "
          "function.")
    if self._error is not None:
      raise self._error
    if self._fetched_numpys is None:
      # Copy every leaf that supports `.numpy()` to host; keep others as-is.
      self._fetched_numpys = nest.map_structure(
          lambda x: x.numpy() if hasattr(x, "numpy") else x, self._values)
    return self._fetched_numpys
class InputError(Exception):
  """Wraps an error found in the inputs of a closure.

  `Closure.execute_on` raises nothing itself for such errors. Instead it stores
  an `InputError` on the closure's output `RemoteValue`.

  Attributes:
    original_exception: the exception that was found in the inputs.
  """

  def __init__(self, original_exception):
    message = ("Input has an error, the original exception is %r, "
               "error message is %s." %
               (original_exception, str(original_exception)))
    super().__init__(message)
    # Keep the cause so callers can inspect it without parsing the message.
    self.original_exception = original_exception
def _maybe_rebuild_remote_values(worker, structure):
  """Rebuilds aborted `RemoteValue`s in `structure` and reports their errors.

  Every `RemoteValue` leaf with ABORTED status is re-executed on `worker`
  under the worker's failure handler. If the rebuild raises, the exception is
  recorded as that value's error.

  Args:
    worker: the `Worker` on which aborted values are rebuilt.
    structure: a nested structure whose leaves may be `RemoteValue`s.

  Returns:
    The first error found while traversing `structure`, or None.
  """
  found_errors = []

  def _rebuild_and_collect(leaf):
    if not isinstance(leaf, RemoteValue):
      return
    if leaf._status is _RemoteValueStatus.ABORTED:  # pylint: disable=protected-access
      try:
        with worker.failure_handler.wait_on_failure(
            on_recovery_fn=functools.partial(leaf._rebuild_on, worker),  # pylint: disable=protected-access
            worker_device_name=worker.device_name):
          leaf._rebuild_on(worker)  # pylint: disable=protected-access
      except Exception as e:  # pylint: disable=broad-except
        leaf._set_error(e)  # pylint: disable=protected-access
    leaf_error = leaf._get_error()  # pylint: disable=protected-access
    if leaf_error:
      found_errors.append(leaf_error)

  # Visit every leaf (so all aborted values get rebuilt) before reporting.
  nest.map_structure(_rebuild_and_collect, structure)
  return found_errors[0] if found_errors else None
def _maybe_get_remote_value(val):
  """Unwraps `val` into its values when it is a `RemoteValue`.

  Non-`RemoteValue` inputs are returned unchanged.

  Raises:
    AssertionError: if `val` is a `RemoteValue` that holds an error.
  """
  if not isinstance(val, RemoteValue):
    return val
  if val._get_error():  # pylint: disable=protected-access
    raise AssertionError(
        "RemoteValue doesn't have a value because it has errors.")
  return val._get_values()  # pylint: disable=protected-access
def _maybe_as_type_spec(val):
  """Replaces a `RemoteValue` with its type spec; other values pass through.

  Raises:
    ValueError: if `val` is a `RemoteValue` without a type spec. This is the
      case for outputs of scheduled functions that are not `tf.function`s.
  """
  if not isinstance(val, RemoteValue):
    return val
  spec = val._type_spec  # pylint: disable=protected-access
  if spec is None:
    raise ValueError("Output of a scheduled function that is not "
                     "tf.function cannot be the input of another function.")
  return spec
@tf_export("distribute.experimental.coordinator.PerWorkerValues", v1=[])
class PerWorkerValues(object):
  """A container that holds a list of values, one value per worker.

  `tf.distribute.experimental.coordinator.PerWorkerValues` contains a collection
  of values, where each of the values is located on its corresponding worker,
  and upon being used as one of the `args` or `kwargs` of
  `tf.distribute.experimental.coordinator.ClusterCoordinator.schedule()`, the
  value specific to a worker will be passed into the function being executed at
  that corresponding worker.

  Currently, the only supported path to create an object of
  `tf.distribute.experimental.coordinator.PerWorkerValues` is through calling
  `iter` on a `ClusterCoordinator.create_per_worker_dataset`-returned
  distributed dataset instance. The mechanism to create a custom
  `tf.distribute.experimental.coordinator.PerWorkerValues` is not yet supported.
  """

  def __init__(self, values):
    # Frozen into a tuple; position i holds the value for worker index i.
    frozen_values = tuple(values)
    self._values = frozen_values
def _select_worker_slice(worker_id, structured):
  """Replaces each `PerWorkerValues` leaf with its entry for `worker_id`.

  All other leaves of `structured` are returned unchanged.
  """
  return nest.map_structure(
      lambda item: (item._values[worker_id]  # pylint: disable=protected-access
                    if isinstance(item, PerWorkerValues) else item),
      structured)
def _disallow_remote_value_as_input(structured):
  """Raises ValueError if any leaf of `structured` is a `RemoteValue`."""

  def _reject_remote_value(leaf):
    if not isinstance(leaf, RemoteValue):
      return
    raise ValueError(
        "`tf.distribute.experimental.coordinator.RemoteValue` used "
        "as an input to scheduled function is not yet "
        "supported.")

  nest.map_structure(_reject_remote_value, structured)
class Closure(object):
  """A function scheduled for remote execution, bundled with its arguments.

  A `tf.function` is traced into a concrete function at construction time.
  Both traced and already-concrete functions are wrapped through
  `cancellation_mgr.get_cancelable_function`. Plain Python callables are
  stored as-is and yield no output type spec. `output_remote_value` receives
  the result (or error) of executing this closure.
  """

  def __init__(self, function, cancellation_mgr, args=None, kwargs=None):
    if not callable(function):
      raise ValueError("Function passed to `ClusterCoordinator.schedule` must "
                       "be a callable object.")
    self._args = args or ()
    self._kwargs = kwargs or {}
    for structure in (self._args, self._kwargs):
      _disallow_remote_value_as_input(structure)

    concrete_fn = None
    if isinstance(function, def_function.Function):
      # Trace against worker 0's slice of any `PerWorkerValues` arguments.
      sample_args = _select_worker_slice(0, self._args)
      sample_kwargs = _select_worker_slice(0, self._kwargs)

      # Note: no need to handle function registration failure since this kind
      # of failure will not raise exceptions as designed in the runtime. The
      # coordinator has to rely on subsequent operations that raise to catch
      # function registration failure.

      # The tracing count acts as a state tracker so the metric only records
      # actual tracing, not function cache lookups.
      with metric_utils.monitored_timer(
          "function_tracing",
          state_tracker=function._get_tracing_count):  # pylint: disable=protected-access
        concrete_fn = function.get_concrete_function(
            *nest.map_structure(_maybe_as_type_spec, sample_args),
            **nest.map_structure(_maybe_as_type_spec, sample_kwargs))
    elif isinstance(function, tf_function.ConcreteFunction):
      concrete_fn = function

    if concrete_fn is None:
      # A regular Python function carries no output signature information.
      output_type_spec = None
      self._function = function
    else:
      self._concrete_function = concrete_fn
      # The output type spec comes from the concrete function's outputs.
      output_type_spec = func_graph.convert_structure_to_signature(
          concrete_fn.structured_outputs)
      self._function = cancellation_mgr.get_cancelable_function(concrete_fn)

    self.output_remote_value = RemoteValueImpl(self, output_type_spec)

  def mark_cancelled(self):
    """Stores a `CancelledError` on this closure's output `RemoteValue`."""
    cancelled_error = errors.CancelledError(
        None, None, "The corresponding function is "
        "cancelled. Please reschedule the function.")
    self.output_remote_value._set_error(cancelled_error)  # pylint: disable=protected-access

  def execute_on(self, worker):
    """Executes the closure on the given worker.

    Aborted `RemoteValue` inputs are rebuilt first. If an input has an error,
    it is wrapped in `InputError` (unless it already is one) and stored on
    the output instead of running the function.

    Args:
      worker: a `Worker` object.
    """
    worker_args = _select_worker_slice(worker.worker_index, self._args)
    worker_kwargs = _select_worker_slice(worker.worker_index, self._kwargs)

    # `or` short-circuits: kwargs are only inspected when args are error-free.
    input_error = (
        _maybe_rebuild_remote_values(worker, worker_args) or
        _maybe_rebuild_remote_values(worker, worker_kwargs))
    if input_error:
      if not isinstance(input_error, InputError):
        input_error = InputError(input_error)
      self.output_remote_value._set_error(input_error)  # pylint: disable=protected-access
      return

    with ops.device(worker.device_name), \
         context.executor_scope(worker.executor), \
         metric_utils.monitored_timer("closure_execution"):
      output_values = self._function(
          *nest.map_structure(_maybe_get_remote_value, worker_args),
          **nest.map_structure(_maybe_get_remote_value, worker_kwargs))
    self.output_remote_value._set_values(output_values)  # pylint: disable=protected-access
class _CoordinatedClosureQueue(object):
  """Manage a queue of closures, inflight count and errors from execution.

  This class is thread-safe.
  """

  def __init__(self):
    # `self._inflight_closure_count` only tracks the number of inflight closures
    # that are "in generation". Once an error occurs, error generation is
    # incremented and all subsequent arriving closures (from inflight) are
    # considered "out of generation".
    self._inflight_closure_count = 0

    self._queue_lock = threading.Lock()

    # Condition indicating that all pending closures (either queued or inflight)
    # have been processed, failed, or cancelled.
    self._stop_waiting_condition = threading.Condition(self._queue_lock)

    # Condition indicating that an item becomes available in queue (not empty).
    self._closures_queued_condition = threading.Condition(self._queue_lock)

    # Condition indicating that a queue slot becomes available (not full).
    # Note that even with "infinite" queue size, there is still a "practical"
    # size limit for the queue depending on host memory capacity, and thus the
    # queue will eventually become full with a lot of enqueued closures.
    self._queue_free_slot_condition = threading.Condition(self._queue_lock)

    # Condition indicating there are no inflight closures.
    self._no_inflight_closure_condition = threading.Condition(self._queue_lock)

    # Used to cancel in-flight closures.
    self._cancellation_mgr = cancellation.CancellationManager()

    if _CLOSURE_QUEUE_MAX_SIZE <= 0:
      logging.warning(
          "In a `ClusterCoordinator`, creating an infinite closure queue can "
          "consume a significant amount of memory and even lead to OOM.")
    self._queue = queue.Queue(maxsize=_CLOSURE_QUEUE_MAX_SIZE)
    self._error = None

    # The following lock makes sure that once `wait` is called and until it
    # returns, no `put` can be executed, because `wait` would not know what to
    # do with newly put closures. The lock acts as a cutoff, so closures put
    # into the queue while waiting are not this `wait`'s responsibility.
    #
    # We cannot reuse `self._queue_lock` since when `wait` waits for a
    # condition, `self._queue_lock` is released.
    #
    # We don't use a reader/writer lock on purpose to reduce the complexity
    # of the code.
    self._put_wait_lock = threading.Lock()

  def _cancel_all_closures(self):
    """Clears the queue and sets remaining closures cancelled error.

    This method expects self._queue_lock to be held prior to entry.
    """
    self._cancellation_mgr.start_cancel()
    # Let inflight closures drain (they report back via mark_* / put_back).
    while self._inflight_closure_count > 0:
      self._no_inflight_closure_condition.wait()
    while True:
      try:
        closure = self._queue.get(block=False)
        self._queue_free_slot_condition.notify()
        closure.mark_cancelled()
      except queue.Empty:
        break
    # The cancellation manager cannot be reused once cancelled. After all
    # closures (queued or inflight) are cleaned up, recreate the cancellation
    # manager with clean state.
    # Note on thread-safety: this is triggered when one of these
    # ClusterCoordinator APIs is called: `schedule`, `wait`, and `done`. At the
    # same time, no new closures can be constructed (which reads the
    # _cancellation_mgr to get cancellable functions).
    self._cancellation_mgr = cancellation.CancellationManager()

  def _raise_if_error(self):
    """Raises the error if one exists.

    If an error exists, cancel the closures in queue, raises it, and clear
    the error.

    This method expects self._queue_lock to be held prior to entry.
    """
    if self._error:
      logging.error("Start cancelling closures due to error %r: %s",
                    self._error, self._error)
      self._cancel_all_closures()
      try:
        raise self._error  # pylint: disable=raising-bad-type
      finally:
        self._error = None

  def put(self, closure):
    """Put a closure into the queue for later execution.

    If `mark_failed` was called before `put`, the error from the first
    invocation of `mark_failed` will be raised.

    Args:
      closure: The `Closure` to put into the queue.
    """
    with self._put_wait_lock, self._queue_lock:
      self._queue_free_slot_condition.wait_for(lambda: not self._queue.full())
      self._queue.put(closure, block=False)
      self._raise_if_error()
      self._closures_queued_condition.notify()

  def get(self, timeout=None):
    """Return a closure from the queue to be executed.

    Returns None if `timeout` expires while the queue is empty.
    """
    with self._queue_lock:
      while self._queue.empty():
        if not self._closures_queued_condition.wait(timeout=timeout):
          return None
      closure = self._queue.get(block=False)
      self._queue_free_slot_condition.notify()
      self._inflight_closure_count += 1
      return closure

  def mark_finished(self):
    """Let the queue know that a closure has been successfully executed."""
    with self._queue_lock:
      if self._inflight_closure_count < 1:
        raise AssertionError("There is no inflight closures to mark_finished.")
      self._inflight_closure_count -= 1
      if self._inflight_closure_count == 0:
        self._no_inflight_closure_condition.notify_all()
      if self._queue.empty() and self._inflight_closure_count == 0:
        self._stop_waiting_condition.notify_all()

  def put_back(self, closure):
    """Put the closure back into the queue as it was not properly executed.

    If an error is pending, the closure is marked cancelled instead.
    """
    with self._queue_lock:
      if self._inflight_closure_count < 1:
        raise AssertionError("There is no inflight closures to put_back.")
      if self._error:
        closure.mark_cancelled()
      else:
        self._queue_free_slot_condition.wait_for(lambda: not self._queue.full())
        self._queue.put(closure, block=False)
        self._closures_queued_condition.notify()
      self._inflight_closure_count -= 1
      if self._inflight_closure_count == 0:
        self._no_inflight_closure_condition.notify_all()

  def wait(self, timeout=None):
    """Wait for all closures to be finished before returning.

    If `mark_failed` was called before or during `wait`, the error from the
    first invocation of `mark_failed` will be raised.

    Args:
      timeout: A float specifying a timeout for the wait in seconds.

    Returns:
      True unless the given timeout expired, in which case it returns False.
    """
    with self._put_wait_lock, self._queue_lock:
      while (not self._error and
             (not self._queue.empty() or self._inflight_closure_count > 0)):
        if not self._stop_waiting_condition.wait(timeout=timeout):
          return False
      self._raise_if_error()
      return True

  def mark_failed(self, e):
    """Sets error and unblocks any wait() call."""
    with self._queue_lock:
      # TODO(yuefengz): maybe record all failure and give users more
      # information?
      if self._inflight_closure_count < 1:
        raise AssertionError("There is no inflight closures to mark_failed.")
      # Only the first error is kept; later ones are dropped.
      if self._error is None:
        self._error = e
      self._inflight_closure_count -= 1
      if self._inflight_closure_count == 0:
        self._no_inflight_closure_condition.notify_all()
      self._stop_waiting_condition.notify_all()

  def done(self):
    """Returns true if the queue is empty and there is no inflight closure.

    If `mark_failed` was called before `done`, the error from the first
    invocation of `mark_failed` will be raised.
    """
    with self._queue_lock:
      self._raise_if_error()
      return self._queue.empty() and self._inflight_closure_count == 0
class WorkerPreemptionHandler(object):
  """Handles worker preemptions.

  A daemon thread (`_preemption_handler`) waits for a failure signal. It then
  re-applies the stored server def to the eager context and notifies the
  threads blocked in `wait_on_failure`.
  """

  def __init__(self, server_def, cluster):
    self._server_def = server_def
    self._cluster = cluster
    self._cluster_update_lock = threading.Lock()
    # Set when a cluster update is needed, or when the handler thread should
    # exit (see `_mark_finished`).
    self._cluster_due_for_update_or_finish = threading.Event()
    # Notified (under `_cluster_update_lock`) after a successful cluster update.
    self._worker_up_cond = threading.Condition(self._cluster_update_lock)
    self._should_preemption_thread_run = True
    threading.Thread(target=self._preemption_handler,
                     name="WorkerPreemptionHandler",
                     daemon=True).start()

  def _mark_finished(self):
    """Ensure the worker preemption thread is closed."""
    self._should_preemption_thread_run = False
    with self._cluster_update_lock:
      self._cluster_due_for_update_or_finish.set()

  def _validate_preemption_failure(self, e):
    """Re-raises `e` unless it represents a worker failure."""
    if _is_worker_failure(e):
      return
    raise e

  @contextlib.contextmanager
  def wait_on_failure(self,
                      on_failure_fn=None,
                      on_recovery_fn=None,
                      worker_device_name="(unknown)"):
    """Catches worker preemption error and wait until failed workers are back.

    Args:
      on_failure_fn: an optional function to run if preemption happens.
      on_recovery_fn: an optional function to run when a worker is recovered
        from preemption.
      worker_device_name: the device name of the worker instance that is passing
        through the failure.

    Yields:
      None.
    """
    try:
      yield
    except errors.OpError as e:
      # If the error is due to temporary connectivity issues between worker and
      # ps, put back closure, ignore error and do not mark worker as failure.
      if self._cluster._record_and_ignore_transient_ps_failure(e):  # pylint: disable=protected-access
        if on_failure_fn:
          on_failure_fn()
        return

      self._validate_preemption_failure(e)
      logging.error("Worker %s failed with error: %s", worker_device_name, e)
      if on_failure_fn:
        on_failure_fn()

      with self._cluster_update_lock:
        self._cluster_due_for_update_or_finish.set()
        # `Condition.wait` returns False if the timeout expired before the
        # handler thread signalled a successful cluster update.
        recovered = self._worker_up_cond.wait(_WORKER_MAXIMUM_RECOVERY_SEC)
      if recovered:
        logging.info("Worker %s has been recovered.", worker_device_name)
      else:
        logging.warning(
            "Worker %s was not recovered within %d seconds; continuing anyway.",
            worker_device_name, _WORKER_MAXIMUM_RECOVERY_SEC)

      if on_recovery_fn:
        # The recovery function may itself hit a worker failure; handle it
        # the same way.
        with self.wait_on_failure(
            on_recovery_fn=on_recovery_fn,
            worker_device_name=worker_device_name):
          on_recovery_fn()

  def _preemption_handler(self):
    """A loop that handles preemption.

    This loop waits for signal of worker preemption and upon worker preemption,
    it waits until all workers are back and updates the cluster about the
    restarted workers.
    """
    while True:
      self._cluster_due_for_update_or_finish.wait()
      if not self._should_preemption_thread_run:
        break

      with self._cluster_update_lock:
        try:
          # TODO(haoyuzhang): support partial cluster recovery
          logging.info("Cluster now being recovered.")
          context.context().update_server_def(self._server_def)

          # Cluster updated successfully, clear the update signal, and notify
          # all workers that they are recovered from failure.
          logging.info("Cluster successfully recovered.")
          self._worker_up_cond.notify_all()
          self._cluster_due_for_update_or_finish.clear()
        except Exception as e:  # pylint: disable=broad-except
          # NOTE(review): a non-worker-failure error is re-raised here, which
          # ends this loop and therefore this handler thread.
          self._validate_preemption_failure(e)
          # NOTE: Since the first RPC (GetStatus) of update_server_def is
          # currently blocking by default, error should only happen if:
          # (1) More workers failed while waiting for the previous workers to
          #     come back;
          # (2) Worker failed when exchanging subsequent RPCs after the first
          #     RPC returns.
          # Consider adding backoff retry logic if we see the error logged
          # too frequently.
          logging.error("Cluster update failed with error: %s. Retrying...", e)
class Worker(object):
  """A single worker of the cluster, draining the shared closure queue.

  Attributes:
    worker_index: The index of the worker in the cluster.
    device_name: The device string of the worker, e.g. "/job:worker/task:1".
    executor: The worker's executor for remote function execution.
    failure_handler: The failure handler used to handle worker preemption
      failure.
  """

  def __init__(self, worker_index, device_name, cluster):
    self.worker_index = worker_index
    self.device_name = device_name
    self.executor = executor.new_executor(enable_async=False)
    self.failure_handler = cluster.failure_handler
    self._cluster = cluster
    # Weak references to resource `RemoteValue`s created on this worker.
    self._resource_remote_value_refs = []

    # Start processing only after every attribute above has been assigned.
    processing_thread = threading.Thread(
        target=self._process_queue,
        name="WorkerClosureProcessingLoop-%d" % self.worker_index,
        daemon=True)
    processing_thread.start()

  def _set_resources_aborted(self):
    """Marks every still-alive resource `RemoteValue` of this worker aborted."""
    # TODO(yuefengz): maybe we can query whether a tensor is valid or not
    # instead of marking a tensor aborted?
    for resource in (ref() for ref in self._resource_remote_value_refs):
      if resource:
        resource._set_aborted()  # pylint: disable=protected-access

  def _set_dead(self):
    raise NotImplementedError("_set_dead is not implemented.")

  def _process_closure(self, closure):
    """Runs `closure` on this worker, with preemption handling.

    On preemption the failure handler puts the closure back into the queue.
    Any other exception is stored on the closure's output and reported to the
    queue via `mark_failed`.
    """

    def _put_closure_back():
      self._cluster._closure_queue.put_back(closure)  # pylint: disable=protected-access

    try:
      with self._cluster.failure_handler.wait_on_failure(
          on_failure_fn=_put_closure_back,
          on_recovery_fn=self._set_resources_aborted,
          worker_device_name=self.device_name):
        closure.execute_on(self)
        # TODO(yuefengz): we don't have to materialize results every step.
        with metric_utils.monitored_timer("remote_value_fetch"):
          closure.output_remote_value.fetch()
        # Skipped when a preemption was swallowed above (closure put back).
        self._cluster._closure_queue.mark_finished()  # pylint: disable=protected-access
    except Exception as e:  # pylint: disable=broad-except
      # Cancellation errors are derived from an earlier failure; don't log.
      if not isinstance(e, errors.CancelledError):
        logging.error(
            "/job:worker/task:%d encountered the following error when "
            "processing closure: %r:%s", self.worker_index, e, e)
      closure.output_remote_value._set_error(e)  # pylint: disable=protected-access
      self._cluster._closure_queue.mark_failed(e)  # pylint: disable=protected-access

  def _maybe_delay(self):
    """Sleeps before processing when the start-delay env vars ask for it.

    With `TF_COORDINATOR_SCHEDULE_START_DELAY_MAX` set (non-zero), worker i
    waits `TF_COORDINATOR_SCHEDULE_START_DELAY` * i seconds, capped at the max.
    With only `TF_COORDINATOR_SCHEDULE_START_DELAY` set, every worker waits
    that many seconds, unscaled by its index.
    """
    base_delay = int(os.environ.get("TF_COORDINATOR_SCHEDULE_START_DELAY", "0"))
    max_delay = int(
        os.environ.get("TF_COORDINATOR_SCHEDULE_START_DELAY_MAX", "0"))
    delay_secs = (min(base_delay * self.worker_index, max_delay)
                  if max_delay else base_delay)
    if delay_secs <= 0:
      return
    logging.info("Worker %d sleeping for %d seconds before running function",
                 self.worker_index, delay_secs)
    time.sleep(delay_secs)

  def _process_queue(self):
    """Thread body: optional start delay, then process closures forever."""
    self._maybe_delay()
    while True:
      next_closure = self._cluster._closure_queue.get()  # pylint: disable=protected-access
      self._process_closure(next_closure)

  def _create_resource(self, function, args=None, kwargs=None):
    """Creates a per-worker resource represented by a `RemoteValue`.

    Creation is lazy: the returned `RemoteValue` starts in the ABORTED state.
    It is therefore (re)built on a worker when a closure using it as an input
    is executed.

    Args:
      function: the resource function to be run remotely. It should be a
        `tf.function`, a concrete function or a Python function.
      args: positional arguments to be passed to the function.
      kwargs: keyword arguments to be passed to the function.

    Returns:
      one or several RemoteValue objects depending on the function return
      values.
    """
    # Some notes about the concurrency: currently all the activities related to
    # the same worker such as creating resources, setting resources' aborted
    # status, and executing closures happen on the same thread. This allows us
    # to have simpler logic of concurrency.
    resource_closure = Closure(
        function,
        self._cluster._closure_queue._cancellation_mgr,  # pylint: disable=protected-access
        args=args,
        kwargs=kwargs)
    remote_value = resource_closure.output_remote_value
    self._register_resource(remote_value)

    # The following is a short-term solution to lazily create resources in
    # parallel.
    # TODO(b/160343165): we should create resources eagerly, i.e. schedule the
    # resource creation function as soon as users call this method.
    remote_value._set_aborted()  # pylint: disable=protected-access
    return remote_value

  def _register_resource(self, resource_remote_value):
    """Tracks `resource_remote_value` (weakly) for abort-on-recovery."""
    if not isinstance(resource_remote_value, RemoteValue):
      raise ValueError("Resource being registered is not of type "
                       "`tf.distribute.experimental.coordinator.RemoteValue`.")
    self._resource_remote_value_refs.append(weakref.ref(resource_remote_value))
class Cluster(object):
"""A cluster with workers.
We assume all function errors are fatal and based on this assumption our
error reporting logic is:
1) Both `schedule` and `join` can raise a non-retryable error which is the
first error seen by the coordinator from any previously scheduled functions.
2) When an error is raised, there is no guarantee on how many previously
scheduled functions have been executed; functions that have not been executed
will be thrown away and marked as cancelled.
3) After an error is raised, the internal state of error will be cleared.
I.e. functions can continue to be scheduled and subsequent calls of `schedule`
or `join` will not raise the same error again.
Attributes:
failure_handler: The failure handler used to handler worker preemption
failure.
workers: a list of `Worker` objects in the cluster.
"""
def __init__(self, strategy):
"""Initializes the cluster instance."""
self._num_workers = strategy._num_workers
self._num_ps = strategy._num_ps
# Ignore PS failures reported by workers due to transient connection errors.
# Transient connectivity issues between workers and PS are relayed by the
# workers to the coordinator, leading the coordinator to believe that there
# are PS failures. The difference between transient vs. permanent PS failure
# is the number of reports from the workers. When this env var is set to a
# positive integer K, the coordinator ignores up to K reports of a failed PS
# task, i.e., only when there are more than K trials of executing closures
# fail due to errors from the same PS instance do we consider the PS
# instance encounters a failure.
# TODO(b/164279603): Remove this workaround when the underlying connectivity
# issue in gRPC server is resolved.
self._transient_ps_failures_threshold = int(
os.environ.get("TF_COORDINATOR_IGNORE_TRANSIENT_PS_FAILURES", 3))
self._potential_ps_failures_lock = threading.Lock()
self._potential_ps_failures_count = [0] * self._num_ps
self._closure_queue = _CoordinatedClosureQueue()
self.failure_handler = WorkerPreemptionHandler(context.get_server_def(),
self)
worker_device_strings = [
"/job:worker/replica:0/task:%d" % i for i in range(self._num_workers)
]
self.workers = [
Worker(i, w, self) for i, w in enumerate(worker_device_strings)
]
self._strategy = strategy
def _record_and_ignore_transient_ps_failure(self, e):
"""Records potential PS failures and return if failure should be ignored."""
if self._transient_ps_failures_threshold <= 0 or not _is_ps_failure(e):
return False
ps_tasks = _extract_failed_ps_instances(str(e))
with self._potential_ps_failures_lock:
for t in ps_tasks:
self._potential_ps_failures_count[t] += 1
# The number of UnavailableError encountered on this PS task exceeds the
# maximum number of ignored error
if (self._potential_ps_failures_count[t] >=
self._transient_ps_failures_threshold):
return False
return True
def schedule(self, function, args, kwargs):
  """Schedules `function` to be dispatched to a worker for execution.

  Args:
    function: The function to be dispatched to a worker for execution
      asynchronously.
    args: Positional arguments for `function`.
    kwargs: Keyword arguments for `function`.

  Returns:
    A `RemoteValue` object.
  """
  # `_being_scheduled` is raised only for the duration of the `Closure`
  # construction. Reset it in `finally` so an exception from `Closure(...)`
  # cannot leave the flag stuck at True for every subsequent call.
  self._strategy.extended._being_scheduled = True  # pylint: disable=protected-access
  try:
    closure = Closure(
        function,
        self._closure_queue._cancellation_mgr,  # pylint: disable=protected-access
        args=args,
        kwargs=kwargs)
  finally:
    self._strategy.extended._being_scheduled = False  # pylint: disable=protected-access
  self._closure_queue.put(closure)
  return closure.output_remote_value
def join(self):
  """Waits until every function scheduled on this cluster has executed.

  This delegates to the `wait()` method of the cluster's closure queue.
  """
  closure_queue = self._closure_queue
  closure_queue.wait()
def done(self):
  """Reports whether every function scheduled on this cluster has executed.

  Returns:
    The result of the cluster's closure queue `done()` call.
  """
  closure_queue = self._closure_queue
  return closure_queue.done()
@tf_export("distribute.experimental.coordinator.ClusterCoordinator", v1=[])
class ClusterCoordinator(object):
  """An object to schedule and coordinate remote function execution.

  This class is used to create fault-tolerant resources and dispatch functions
  to remote TensorFlow servers.

  Currently, this class is not supported to be used in a standalone manner. It
  should be used in conjunction with a `tf.distribute` strategy that is designed
  to work with it. The `ClusterCoordinator` class currently only works with
  `tf.distribute.experimental.ParameterServerStrategy`.

  __The `schedule`/`join` APIs__

  The most important APIs provided by this class are the `schedule`/`join`
  pair. The `schedule` API is non-blocking in that it queues a `tf.function`
  and returns a `RemoteValue` immediately. The queued functions will be
  dispatched to remote workers in background threads and their `RemoteValue`s
  will be filled asynchronously. Since `schedule` doesn't require worker
  assignment, the `tf.function` passed in can be executed on any available
  worker. If the worker it is executed on becomes unavailable before its
  completion, it will be migrated to another worker. Because of this, and
  because function execution is not atomic, a function may be executed more
  than once.

  __Handling Task Failure__

  This class when used with
  `tf.distribute.experimental.ParameterServerStrategy`, comes with built-in
  fault tolerance for worker failures. That is, when some workers are not
  available for any reason to be reached from the coordinator, the training
  progress continues to be made with the remaining workers. Upon recovery of a
  failed worker, it will be added for function execution after datasets created
  by `create_per_worker_dataset` are re-built on it.

  When a parameter server fails, a `tf.errors.UnavailableError` is raised by
  `schedule`, `join` or `done`. In this case, in addition to bringing back the
  failed parameter server, users should restart the coordinator so that it
  reconnects to workers and parameter servers, re-creates the variables, and
  loads checkpoints. If the coordinator fails, after the user brings it back,
  the program will automatically connect to workers and parameter servers, and
  continue the progress from a checkpoint.

  It is thus essential that in user's program, a checkpoint file is periodically
  saved, and restored at the start of the program. If an
  `tf.keras.optimizers.Optimizer` is checkpointed, after restoring from a
  checkpoint, its `iterations` property roughly indicates the number of steps
  that have been made. This can be used to decide how many epochs and steps are
  needed before the training completion.

  See `tf.distribute.experimental.ParameterServerStrategy` docstring for an
  example usage of this API.

  This is currently under development, and the API as well as implementation
  are subject to changes.
  """

  def __init__(self, strategy):
    """Initialization of a `ClusterCoordinator` instance.

    Args:
      strategy: a supported `tf.distribute.Strategy` object. Currently, only
        `tf.distribute.experimental.ParameterServerStrategy` is supported.

    Raises:
      ValueError: if the strategy being used is not supported.
    """
    if not isinstance(strategy,
                      parameter_server_strategy_v2.ParameterServerStrategyV2):
      raise ValueError(
          "Only `tf.distribute.experimental.ParameterServerStrategy` "
          "is supported to work with "
          "`tf.distribute.experimental.coordinator.ClusterCoordinator` "
          "currently.")
    self._strategy = strategy
    self._strategy.extended._used_with_coordinator = True  # pylint: disable=protected-access
    self._cluster = Cluster(strategy)

  def __del__(self):
    # If `__init__` raised (e.g. for an unsupported strategy) before
    # `_cluster` was assigned, `__del__` still runs on the partially built
    # object; guard against the missing attribute instead of raising
    # AttributeError during garbage collection.
    # TODO(xingliu): Stop the worker threads.
    cluster = getattr(self, "_cluster", None)
    if cluster is not None:
      cluster.failure_handler._mark_finished()  # pylint: disable=protected-access

  @property
  def strategy(self):
    """Returns the `Strategy` associated with the `ClusterCoordinator`."""
    return self._strategy

  def schedule(self, fn, args=None, kwargs=None):
    """Schedules `fn` to be dispatched to a worker for asynchronous execution.

    This method is non-blocking in that it queues the `fn` which will be
    executed later and returns a
    `tf.distribute.experimental.coordinator.RemoteValue` object immediately.
    `fetch` can be called on it to wait for the function execution to finish
    and retrieve its output from a remote worker. On the other hand, call
    `tf.distribute.experimental.coordinator.ClusterCoordinator.join` to wait for
    all scheduled functions to finish.

    `schedule` guarantees that `fn` will be executed on a worker at least once;
    it could be more than once if its corresponding worker fails in the middle
    of its execution. Note that since worker can fail at any point when
    executing the function, it is possible that the function is partially
    executed, but `tf.distribute.experimental.coordinator.ClusterCoordinator`
    guarantees that in those events, the function will eventually be executed on
    any worker that is available.

    If any previously scheduled function raises an error, `schedule` will raise
    any one of those errors, and clear the errors collected so far. If this
    happens, some of the previously scheduled functions may not have been
    executed. User can call `fetch` on the returned
    `tf.distribute.experimental.coordinator.RemoteValue` to inspect if they have
    executed, failed, or cancelled, and reschedule the corresponding function if
    needed.

    When `schedule` raises, it guarantees that there is no function that is
    still being executed.

    At this time, there is no support for worker assignment for function
    execution, or priority of the workers.

    `args` and `kwargs` are the arguments passed into `fn`, when `fn` is
    executed on a worker. They can be
    `tf.distribute.experimental.coordinator.PerWorkerValues` and in this case,
    the argument will be substituted with the corresponding component on the
    target worker. Arguments that are not
    `tf.distribute.experimental.coordinator.PerWorkerValues` will be passed into
    `fn` as-is. Currently, `tf.distribute.experimental.coordinator.RemoteValue`
    is not supported to be input `args` or `kwargs`.

    Args:
      fn: A `tf.function`; the function to be dispatched to a worker for
        execution asynchronously. Regular Python functions cannot be
        scheduled.
      args: Positional arguments for `fn`.
      kwargs: Keyword arguments for `fn`.

    Returns:
      A `tf.distribute.experimental.coordinator.RemoteValue` object that
      represents the output of the function scheduled.

    Raises:
      TypeError: if `fn` is not a `tf.function` or a concrete function.
      Exception: one of the exceptions caught by the coordinator from any
        previously scheduled function, since the last time an error was thrown
        or since the beginning of the program.
    """
    if not isinstance(fn,
                      (def_function.Function, tf_function.ConcreteFunction)):
      raise TypeError(
          "`tf.distribute.experimental.coordinator.ClusterCoordinator.schedule`"
          " only accepts a `tf.function` or a concrete function.")
    # Slot variables are usually created during function tracing time; thus
    # `schedule` needs to be called within the `strategy.scope()`.
    with self.strategy.scope():
      return self._cluster.schedule(fn, args=args, kwargs=kwargs)

  def join(self):
    """Blocks until all the scheduled functions have finished execution.

    If any previously scheduled function raises an error, `join` will fail by
    raising any one of those errors, and clear the errors collected so far. If
    this happens, some of the previously scheduled functions may not have been
    executed. Users can call `fetch` on the returned
    `tf.distribute.experimental.coordinator.RemoteValue` to inspect if they have
    executed, failed, or cancelled. If some that have been cancelled need to be
    rescheduled, users should call `schedule` with the function again.

    When `join` returns or raises, it guarantees that there is no function that
    is still being executed.

    Raises:
      Exception: one of the exceptions caught by the coordinator by any
        previously scheduled function since the last time an error was thrown or
        since the beginning of the program.
    """
    self._cluster.join()

  def done(self):
    """Returns whether all the scheduled functions have finished execution.

    If any previously scheduled function raises an error, `done` will fail by
    raising any one of those errors.

    When `done` returns True or raises, it guarantees that there is no function
    that is still being executed.

    Returns:
      Whether all the scheduled functions have finished execution.
    Raises:
      Exception: one of the exceptions caught by the coordinator by any
        previously scheduled function since the last time an error was thrown or
        since the beginning of the program.
    """
    return self._cluster.done()

  def create_per_worker_dataset(self, dataset_fn):
    """Create dataset on workers by calling `dataset_fn` on worker devices.

    This creates the given dataset generated by dataset_fn on workers
    and returns an object that represents the collection of those individual
    datasets. Calling `iter` on such collection of datasets returns a
    `tf.distribute.experimental.coordinator.PerWorkerValues`, which is a
    collection of iterators, where the iterators have been placed on respective
    workers.

    Calling `next` on a `PerWorkerValues` of iterator is unsupported. The
    iterator is meant to be passed as an argument into
    `tf.distribute.experimental.coordinator.ClusterCoordinator.schedule`. When
    the scheduled function is about to be executed by a worker, the
    function will receive the individual iterator that corresponds to the
    worker. The `next` method can be called on an iterator inside a
    scheduled function when the iterator is an input of the function.

    Currently the `schedule` method assumes workers are all the same and thus
    assumes the datasets on different workers are the same, except they may be
    shuffled differently if they contain a `dataset.shuffle` operation and a
    random seed is not set. Because of this, we also recommend the datasets to
    be repeated indefinitely and schedule a finite number of steps instead of
    relying on the `OutOfRangeError` from a dataset.

    Example:

    ```python
    strategy = tf.distribute.experimental.ParameterServerStrategy(
        cluster_resolver=...)
    coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
        strategy=strategy)

    @tf.function
    def worker_fn(iterator):
      return next(iterator)

    def per_worker_dataset_fn():
      return strategy.distribute_datasets_from_function(
          lambda x: tf.data.Dataset.from_tensor_slices([3] * 3))

    per_worker_dataset = coordinator.create_per_worker_dataset(
        per_worker_dataset_fn)
    per_worker_iter = iter(per_worker_dataset)
    remote_value = coordinator.schedule(worker_fn, args=(per_worker_iter,))
    assert remote_value.fetch() == 3
    ```

    Args:
      dataset_fn: The dataset function that returns a dataset. This is to be
        executed on the workers.

    Returns:
      An object that represents the collection of those individual
      datasets. `iter` is expected to be called on this object that returns
      a `tf.distribute.experimental.coordinator.PerWorkerValues` of the
      iterators (that are on the workers).
    """
    input_workers = input_lib.InputWorkers([
        (w.device_name, [w.device_name]) for w in self._cluster.workers
    ])

    return _PerWorkerDistributedDataset(dataset_fn, input_workers, self)

  def _create_per_worker_resources(self, fn, args=None, kwargs=None):
    """Creates a resource on every worker of the cluster.

    Each worker's `_create_resource` is called with `fn`, `args` and `kwargs`,
    in worker order. The resources are represented by
    `tf.distribute.experimental.coordinator.RemoteValue`s.

    Args:
      fn: The function that creates the resource; it is handed to every
        worker's `_create_resource`.
      args: Positional arguments for `fn`.
      kwargs: Keyword arguments for `fn`.

    Returns:
      A `tf.distribute.experimental.coordinator.PerWorkerValues` object, which
      wraps a tuple of `tf.distribute.experimental.coordinator.RemoteValue`
      objects, one per worker.
    """
    results = tuple(
        w._create_resource(fn, args=args, kwargs=kwargs)  # pylint: disable=protected-access
        for w in self._cluster.workers)
    return PerWorkerValues(results)

  def fetch(self, val):
    """Blocking call to fetch results from the remote values.

    This is a wrapper around
    `tf.distribute.experimental.coordinator.RemoteValue.fetch` for a
    `RemoteValue` structure; it returns the execution results of
    `RemoteValue`s. If not ready, wait for them while blocking the caller.

    Example:
    ```python
    strategy = ...
    coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
        strategy)

    def dataset_fn():
      return tf.data.Dataset.from_tensor_slices([1, 1, 1])

    with strategy.scope():
      v = tf.Variable(initial_value=0)

    @tf.function
    def worker_fn(iterator):
      def replica_fn(x):
        v.assign_add(x)
        return v.read_value()
      return strategy.run(replica_fn, args=(next(iterator),))

    distributed_dataset = coordinator.create_per_worker_dataset(dataset_fn)
    distributed_iterator = iter(distributed_dataset)
    result = coordinator.schedule(worker_fn, args=(distributed_iterator,))
    assert coordinator.fetch(result) == 1
    ```

    Args:
      val: The value to fetch the results from. If this is structure of
        `tf.distribute.experimental.coordinator.RemoteValue`, `fetch()` will be
        called on the individual
        `tf.distribute.experimental.coordinator.RemoteValue` to get the result.

    Returns:
      If `val` is a `tf.distribute.experimental.coordinator.RemoteValue` or a
      structure of `tf.distribute.experimental.coordinator.RemoteValue`s,
      return the fetched `tf.distribute.experimental.coordinator.RemoteValue`
      values immediately if they are available, or block the call until they are
      available, and return the fetched
      `tf.distribute.experimental.coordinator.RemoteValue` values with the same
      structure. If `val` is other types, return it as-is.
    """

    def _maybe_fetch(leaf):
      # Only `RemoteValue` leaves are fetched; everything else passes through.
      if isinstance(leaf, RemoteValue):
        return leaf.fetch()
      return leaf

    # TODO(yuefengz): we should fetch values in a batch.
    return nest.map_structure(_maybe_fetch, val)
# pylint: disable=missing-function-docstring
@contextlib.contextmanager
def handle_parameter_server_failure():
  """Context manager that can turn an `UnavailableError` into a process exit.

  If the wrapped code raises `errors.UnavailableError` and the environment
  variable `TF_CLIENT_NON_FATAL_RESTART_EXIT_CODE` is set, the process exits
  via `sys.exit` with that value parsed as an integer. Otherwise the error is
  re-raised unchanged. Exceptions of any other type are not caught.

  Raises:
    UnavailableError: re-raised when the exit-code variable is unset.
    SystemExit: when the exit-code variable is set.
    ValueError: when the exit-code variable is set but is not an integer.
  """
  try:
    yield
  except errors.UnavailableError:
    restart_exit_code = os.environ.get("TF_CLIENT_NON_FATAL_RESTART_EXIT_CODE")
    if restart_exit_code is not None:
      sys.exit(int(restart_exit_code))
    else:
      raise
class _PerWorkerDistributedDataset(object):
  """Represents worker-distributed datasets created from dataset function."""

  def __init__(self, dataset_fn, input_workers, coordinator):
    """Makes an iterable from datasets created by the given function.

    Unless `dataset_fn` is already a `ConcreteFunction`, it is traced into one
    here. During that tracing a variable creator is installed that raises
    `ValueError` if `dataset_fn` tries to create a variable.

    Args:
      dataset_fn: A function that returns a `Dataset`: a plain Python
        callable, a `tf.function`, or a `ConcreteFunction`.
      input_workers: an `InputWorkers` object.
      coordinator: a `ClusterCoordinator` object, used to create dataset
        resources.
    """

    def disallow_variable_creation(next_creator, **kwargs):
      raise ValueError("Creating variables in `dataset_fn` is not allowed.")

    if not isinstance(dataset_fn, tf_function.ConcreteFunction):
      # Wrap plain Python callables first; wrapping does not trace, so it is
      # safe outside the creator scope. Tracing happens in
      # `get_concrete_function`, which runs under the creator scope.
      if not isinstance(dataset_fn, def_function.Function):
        dataset_fn = def_function.function(dataset_fn)
      with variable_scope.variable_creator_scope(disallow_variable_creation):
        dataset_fn = dataset_fn.get_concrete_function()
    self._dataset_fn = dataset_fn
    self._input_workers = input_workers
    self._coordinator = coordinator
    self._element_spec = None

  def __iter__(self):
    # We would like users to create iterators outside `tf.function`s so that we
    # can track them.
    if (not context.executing_eagerly() or
        ops.get_default_graph().building_function):
      raise RuntimeError(
          "__iter__() is not supported inside of tf.function or in graph mode.")

    def _create_per_worker_iterator():
      dataset = self._dataset_fn()
      return iter(dataset)

    # Note: nothing here caches iterators. Every `iter()` call on this object
    # invokes `_create_per_worker_resources` again.
    per_worker_iterator = self._coordinator._create_per_worker_resources(  # pylint: disable=protected-access
        _create_per_worker_iterator)

    # Setting type_spec of each RemoteValue so that functions taking these
    # RemoteValues as inputs can be traced. The element spec is the same for
    # every worker, so look it up once.
    element_spec = self._dataset_fn.structured_outputs.element_spec
    for iterator_remote_value in per_worker_iterator._values:  # pylint: disable=protected-access
      iterator_remote_value._type_spec = (  # pylint: disable=protected-access
          iterator_ops.IteratorSpec(element_spec))
    return _PerWorkerDistributedIterator(per_worker_iterator._values)  # pylint: disable=protected-access

  @property
  def element_spec(self):
    """The type specification of an element of this dataset."""
    raise NotImplementedError("Passing `AsyncDistributedDataset` to a "
                              "tf.function is not supported.")
class _PerWorkerDistributedIterator(PerWorkerValues):
  """Distributed iterator for `ClusterCoordinator`.

  Iterating on the coordinator is not supported: both `next()` and
  `get_next()` raise `NotImplementedError`.
  """

  def get_next(self, name=None):
    """Would return the next input for all replicas; not supported yet.

    Args:
      name: unused.

    Raises:
      NotImplementedError: always.
    """
    del name  # Unused.
    raise NotImplementedError("Iterating over an `AsyncDistributedIterator` "
                              "is not supported right now.")

  def __next__(self):
    return self.get_next()
def _extract_failed_ps_instances(err_msg):
  """Returns the set of PS task indices mentioned in `err_msg`.

  Args:
    err_msg: an error message string.

  Returns:
    A set of ints, one for each distinct `<N>` appearing as
    `/job:ps/replica:0/task:<N>` in the message.
  """
  # With a capture group, `findall` returns just the digit run of each match.
  task_ids = re.findall(r"/job:ps/replica:0/task:([0-9]+)", err_msg)
  return {int(task_id) for task_id in task_ids}
def _is_ps_failure(error):
  """Returns whether `error` is considered a parameter server failure.

  Only an `errors.UnavailableError` whose message contains
  `_RPC_ERROR_FROM_PS` qualifies.
  """
  if not isinstance(error, errors.UnavailableError):
    return False
  return _RPC_ERROR_FROM_PS in str(error)
def _is_worker_failure(error):
  """Returns whether `error` is considered a worker failure.

  The message must contain `_JOB_WORKER_STRING_IDENTIFIER` and must not
  contain `_RPC_ERROR_FROM_PS`. Beyond that, the error qualifies if it is an
  `UnavailableError`/`AbortedError`, or one of a few specific
  `InvalidArgumentError`/`NotFoundError` messages matched below.
  """
  message = str(error)
  if (_JOB_WORKER_STRING_IDENTIFIER not in message or
      _RPC_ERROR_FROM_PS in message):
    return False

  # TODO(haoyuzhang): Consider using special status code if error from a
  # remote is derived from RPC errors originated from other hosts.
  if isinstance(error, (errors.UnavailableError, errors.AbortedError)):
    return True

  # These errors can show up when the remote task fails and restarts within
  # so short an interval that no RPCs were exchanged to detect the failure.
  # gRPC then reuses the channel (which is different from a connection) for
  # the replacement server listening on the same address.
  # TODO(b/159961667): Fix "Unable to find the relevant tensor
  # remote_handle" part.
  stale_channel_markers = (
      "unknown device",
      "Unable to find the relevant tensor remote_handle",
  )
  if isinstance(error, errors.InvalidArgumentError) and any(
      marker in message for marker in stale_channel_markers):
    return True

  # TODO(b/162541228): The following error type is very rare and only
  # observed in large-scale testing; the set of recognized errors should be
  # reduced. It can happen when function registration fails, which in the
  # observed cases only affected dataset related functions.
  unregistered_function_marker = (
      "is neither a type of a primitive operation nor a name of a function "
      "registered")
  return (isinstance(error, errors.NotFoundError) and
          unregistered_function_marker in message)