# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Module for `WorkerPreemptionHandler`.
This is currently under development and the API is subject to change.
"""
import os
import signal
import sys
import threading
import time
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.distribute.failure_handling import gce_util
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import checkpoint_management
_RUN_COUNT_KEY = 'RUN_TO_CHECKPOINT'
_ACKNOWLEDGE_KEY = 'RECEIVED_SIGNAL'
_ITERATION_VARIABLE = 'checkpointed_runs'
# TODO(wxinyi): Move the hardcoded 42 to an internal module.
_RESTARTABLE_EXIT_CODE = 42
_STOP_WATCHING_CLUSTER_VALUE = 'STOP_WATCHER'
def _mwms_write_checkpoint_dir(checkpoint_dir, task_type, task_id,
cluster_spec):
"""Returns checkpoint_dir for chief and a temp dir for any other worker."""
dirpath = os.path.dirname(checkpoint_dir)
base = os.path.basename(checkpoint_dir)
if not multi_worker_util.is_chief(
cluster_spec=cluster_spec, task_type=task_type, task_id=task_id):
base_dirpath = 'workertemp_' + str(task_id)
dirpath = os.path.join(dirpath, base_dirpath)
gfile.MakeDirs(dirpath)
return os.path.join(dirpath, base)
class TerminationConfig(object):
"""Configurations to customize for a platform other than Google's Borg or GCP.
A TerminationConfig can be created and passed to the
`WorkerPreemptionHandler` to provide customization based on the platform.
It will deliver three pieces of information:
* How to decide if there is a termination event soon
The termination notification and how to fetch it vary across platforms. Thus
we accept a user-defined function, `termination_watcher_function`, and execute
it repeatedly to check for a termination notification.
`termination_watcher_function` should return True if a termination
notification is available and False otherwise. The function should be
lightweight and non-blocking so that resources can be cleaned up properly if
no termination signal is raised before training finishes. (A sketch of such a
watcher appears at the end of this docstring.)
* How to exit the program
We ask for a `restart_code` and execute `sys.exit(restart_code)` after saving
the checkpoint, so the training program exits gracefully. A restart is
inevitable to reset the program's state. However, you can configure the
`restart_code` to facilitate the restart and make the training experience
smooth: your platform may have an agreed-upon RESTART_CODE that is recognized
as a program auto-restart signal, or you may have a coordinating script that
starts up the training, which you can configure to auto-restart the program
whenever it exits with this RESTART_CODE. In either case, pass in this
RESTART_CODE and you won't even notice that the training has been interrupted
and restarted.
* How long we have from receiving a termination notice until the actual
termination
On some platforms this gap can be as long as, say, one hour. In that case, you
might want to use as much of this time as possible for training before you
have to save a checkpoint and exit. We can make use of this information if you
pass it through the `time_till_termination` argument.
*The default behavior*:
If you are training with Google’s Borg system or GCP, we automatically detect
the platform and make the right configuration for you. Besides these two
platforms, the default behavior on an unrecognized platform is:
* If `termination_watcher_function` is `None`, we will treat `signal.SIGTERM`
as a termination event.
* If `restart_code` is not configured, we exit with an arbitrary choice, 42.
* If `time_till_termination` is not configured, the default is 0, and we will
wrap up the current training step, save a checkpoint, and exit the program as
soon as we receive the termination signal.
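If none of the defaults fit your platform, the three pieces above can be
supplied explicitly. A minimal sketch, assuming a hypothetical platform that
drops a notice file before reclaiming a VM (the file path, exit code, and
grace period below are made up for illustration):
```python
import os

def file_based_termination_watcher():
  # Lightweight, non-blocking check for the (hypothetical) preemption notice.
  return os.path.exists('/var/run/preempt_notice')

termination_config = TerminationConfig(
    termination_watcher_function=file_based_termination_watcher,
    restart_code=143,          # exit code the (hypothetical) scheduler restarts on
    time_till_termination=30)  # seconds between notice and actual termination
```
The resulting object can then be passed to `WorkerPreemptionHandler` through
its `termination_config` argument.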
"""
def __init__(self,
termination_watcher_function=None,
restart_code=None,
time_till_termination=None):
self.termination_watcher_function = termination_watcher_function
self.restart_code = restart_code
self.time_till_termination = time_till_termination
class GCPTerminationConfig(TerminationConfig):
def __init__( # pylint: disable=super-init-not-called
self,
termination_watcher_function=None,
restart_code=None,
time_till_termination=None):
self.termination_watcher_function = termination_watcher_function or gce_util.termination_watcher_function_gce
self.restart_code = restart_code or gce_util._RESTARTABLE_EXIT_CODE
self.time_till_termination = time_till_termination or gce_util.GRACE_PERIOD_GCE
class BorgTerminationConfig(TerminationConfig):
def __init__( # pylint: disable=super-init-not-called
self,
termination_watcher_function=None,
restart_code=None,
time_till_termination=None):
self.termination_watcher_function = termination_watcher_function
self.restart_code = restart_code or 42
self.time_till_termination = time_till_termination or 0
def _complete_config_for_environment(platform_device, termination_config):
"""Completes unfilled fields of a TerminationConfig based on the platform."""
if platform_device is gce_util.PlatformDevice.GCE_GPU:
return GCPTerminationConfig(termination_config.termination_watcher_function,
termination_config.restart_code,
termination_config.time_till_termination)
else:
# The defaults we chose are the same as the ones used by Borg, so we just
# return this.
return BorgTerminationConfig(
termination_config.termination_watcher_function,
termination_config.restart_code,
termination_config.time_till_termination)
class WorkerPreemptionHandler(object):
"""Preemption and error handler for synchronous training.
The API helps coordinate all workers to save a checkpoint upon receiving a
preemption signal and helps propagate accurate error messages during training.
When the program recovers from preemption, the checkpoint that is passed to
initialize a `WorkerPreemptionHandler` object will be loaded
automatically.
Right after the initialization, a thread starts watching for a termination
signal to any member in the cluster, but the signal will only be handled
(which includes aligning the step to save a checkpoint at, saving a
checkpoint, and exiting with a platform-recognized restart code) after
entering a `WorkerPreemptionHandler.run` call.
Example usage:
```python
strategy = tf.distribute.MultiWorkerMirroredStrategy()
with strategy.scope():
dataset, model, optimizer = ...
fh_checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
worker_preemption_watcher = tf.distribute.WorkerPreemptionHandler(
cluster_resolver, fh_checkpoint, checkpoint_directory)
# `worker_preemption_watcher.total_runs` will be restored to its
# checkpointed value when training is restored after interruption.
for epoch in range(worker_preemption_watcher.total_runs //
STEPS_PER_EPOCH, num_epochs):
for step in range(worker_preemption_watcher.total_runs %
STEPS_PER_EPOCH, num_steps):
# `distributed_train_step` is a single-step train function that wraps a
# `strategy.run` call.
loss += worker_preemption_watcher.run(distributed_train_step, next(dataset))
```
`WorkerPreemptionHandler` will create a CheckpointManager to manage the
checkpoint and only one CheckpointManager should be active in a particular
directory at a time. Thus, if the user would like to save a checkpoint for a
purpose other than fault tolerance, e.g., for evaluation, they should save it
in a directory different from the one passed to the
`WorkerPreemptionHandler`, as sketched below.
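A minimal sketch, reusing `fh_checkpoint` from the example above and a
hypothetical evaluation directory:
```python
# Keep evaluation checkpoints outside the directory owned by the handler.
eval_checkpoint_manager = tf.train.CheckpointManager(
    fh_checkpoint, directory='/tmp/eval_ckpts', max_to_keep=3)
eval_checkpoint_manager.save()
```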
This API targets multi-client distributed training, and right now only
`tf.distribute.MultiWorkerMirroredStrategy` is supported.
"""
def __init__(self,
cluster_resolver,
checkpoint,
checkpoint_dir,
termination_config=TerminationConfig()):
"""Creates the failure handler.
Args:
cluster_resolver: a `tf.distribute.cluster_resolver.ClusterResolver`. You
may also get it through the `cluster_resolver` attribute of the strategy
in use.
checkpoint: a `tf.train.Checkpoint` that will be saved upon preemption and
loaded upon restart by the `WorkerPreemptionHandler` API automatically.
checkpoint_dir: a directory where the `WorkerPreemptionHandler` saves and
restores checkpoints. `WorkerPreemptionHandler` will create a
`tf.train.CheckpointManager` to manage the passed-in `checkpoint`. Since
only one `tf.train.CheckpointManager` should be active in a particular
directory at a time, this `checkpoint_dir` arg should preferably be
separate from where the user saves checkpoints for non-fault-tolerance
purposes.
termination_config: a `TerminationConfig` object to configure for a
platform other than Google Borg or GCP.
"""
self._cluster_resolver = cluster_resolver
self._checkpoint = checkpoint
self._id_in_cluster = str(
multi_worker_util.id_in_cluster(
self._cluster_resolver.cluster_spec(),
self._cluster_resolver.task_type,
self._cluster_resolver.task_id))
# The number of calls to `WorkerPreemptionHandler.run` when the latest
# checkpoint was saved.
self._checkpointed_runs = variables.Variable(
initial_value=constant_op.constant(0, dtype=dtypes.int64),
trainable=False,
name=_ITERATION_VARIABLE)
if not hasattr(self._checkpoint,
_ITERATION_VARIABLE):
setattr(self._checkpoint, _ITERATION_VARIABLE,
self._checkpointed_runs)
# Make CheckpointManagers. MultiWorkerMirroredStrategy requires different
# setup on chief and on other workers.
self._read_checkpoint_manager = checkpoint_management.CheckpointManager(
checkpoint, directory=checkpoint_dir, max_to_keep=1)
if multi_worker_util.is_chief(
cluster_spec=cluster_resolver.cluster_spec(),
task_type=cluster_resolver.task_type,
task_id=cluster_resolver.task_id):
self._write_checkpoint_manager = self._read_checkpoint_manager
else:
self._write_checkpoint_manager = checkpoint_management.CheckpointManager(
checkpoint,
_mwms_write_checkpoint_dir(checkpoint_dir, cluster_resolver.task_type,
cluster_resolver.task_id,
cluster_resolver.cluster_spec()),
max_to_keep=1)
self._read_checkpoint_manager.restore_or_initialize()
# An internal step counter that's restored to `checkpointed_runs` when
# training is restored. It increments by one every time
# `WorkerPreemptionHandler.run` is called. Note that for this count to be
# accurate, the user must pass a single-step training function to
# `WorkerPreemptionHandler.run` instead of a multiple-step one.
self._run_counter = self._checkpointed_runs.numpy()
# The worker itself has received a preemption signal.
self._received_own_sigterm = threading.Event()
# Some member (could be oneself) has received preemption signal, and the
# step number to save a checkpoint has been aligned.
self._received_sigterm_and_step = threading.Event()
# When training is interrupted, we explicitly call the cleanup methods for the
# thread watching for the local worker's termination signal and the thread
# watching for cluster-wise information, before we save a checkpoint and exit.
# If training finishes without interruption, we rely on __del__ to clean up.
# However, there is no guarantee when or whether __del__ is executed, so we
# make the threads daemons to avoid them preventing the program from exiting.
self._cluster_wise_termination_watcher_thread = threading.Thread(
target=self._wait_for_signal,
name='PeerTerminationWatcher-%s' % self._id_in_cluster,
daemon=True)
self._cluster_wise_termination_watcher_thread.start()
logging.info('Start watcher for peer\'s signal.')
self._poll_termination_signal_thread = None
self._platform_device = gce_util.detect_platform()
completed_termination_config = _complete_config_for_environment(
self._platform_device, termination_config)
self._termination_watcher_function = completed_termination_config.termination_watcher_function
self._restart_code = completed_termination_config.restart_code
self._time_till_termination = completed_termination_config.time_till_termination
if completed_termination_config.termination_watcher_function:
self._start_polling_for_termination_signal()
else:
self._start_watching_for_signal()
def _start_watching_for_signal(self):
signal.signal(signal.SIGTERM, self._sigterm_handler_fn)
def _start_polling_for_termination_signal(self):
self._poll_termination_signal_thread_should_stop = threading.Event()
self._poll_termination_signal_thread = threading.Thread(
target=self._poll_termination_signal,
name='WorkerTerminationSignalWatcher-%s' % self._id_in_cluster,
daemon=True)
self._poll_termination_signal_thread.start()
logging.info('Start polling for termination signal.')
def _poll_termination_signal(self):
"""Poll maintenance notice and notify peers if receiving one."""
while True:
if self._poll_termination_signal_thread_should_stop.is_set():
return
if self._termination_watcher_function():
# For the countdown.
self._signal_receipt_time = time.time()
break
time.sleep(1)
logging.info('Member %s has received termination notice.',
self._id_in_cluster)
self._received_own_sigterm.set()
def _stop_poll_termination_signal_thread(self):
if self._poll_termination_signal_thread:
self._poll_termination_signal_thread_should_stop.set()
self._poll_termination_signal_thread.join()
self._poll_termination_signal_thread = None
logging.info('Shut down watcher for one\'s own termination signal')
def _stop_cluster_wise_termination_watcher_thread(self):
"""Stop the thread that is _wait_for_signal."""
if self._cluster_wise_termination_watcher_thread:
try:
if self._cluster_wise_termination_watcher_thread.is_alive():
context.context().set_config_key_value(_RUN_COUNT_KEY,
_STOP_WATCHING_CLUSTER_VALUE)
except: # pylint: disable=bare-except
# We'll ignore any error in the process of setting this key. There will
# certainly be an AlreadyExistsError since all workers are trying to push
# this key. Or some worker might have exited already, leading to an
# errors.UnavailableError or errors.AbortedError. We'll also ignore other
# errors since they are not important to the process.
pass
finally:
self._cluster_wise_termination_watcher_thread.join()
self._cluster_wise_termination_watcher_thread = None
logging.info('Shut down watcher for peer\'s termination signal.')
def __del__(self):
self._stop_cluster_wise_termination_watcher_thread()
self._stop_poll_termination_signal_thread()
@property
def total_runs(self):
"""Returns the number of times `WorkerPreemptionHandler.run` is called.
This value tracks the number of all calls to
`WorkerPreemptionHandler.run` including those before the program is
restarted and the training is restored. The user can compute their total
number of iterations by:
`worker_preemption_watcher.run * number_of_steps_in_train_function`,
while for tf.distribute.MultiWorkerMirroredStrategy users,
`number_of_steps_in_train_function` should be one.
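For example, with a single-step train function, the restored counter can be
used to resume the epoch and step loops (names follow the usage example in
the class docstring):
```python
completed_epochs = worker_preemption_watcher.total_runs // STEPS_PER_EPOCH
step_within_epoch = worker_preemption_watcher.total_runs % STEPS_PER_EPOCH
```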
"""
return self._run_counter
def run(self,
distributed_train_function,
*args,
**kwargs):
"""Runs a training function with error and preemption handling.
This function handles the preemption signal from any peer in the cluster by
saving the training progress and exiting gracefully. (Specifically, when
running on Borg, it exits with a special code so that the cluster
automatically restarts the training after the down worker is back.) It will
also propagate any program error encountered during execution of
`distributed_train_function` to all workers so that they can raise the same
error.
The `distributed_train_function` argument should be a distributed train
function (i.e., containing a call to `tf.distribute.Strategy.run`). For
`tf.distribute.MultiWorkerMirroredStrategy` users, we recommend passing in a
single-step `distributed_train_function` to
`WorkerPreemptionHandler.run` so that the checkpoint can be saved in
time in case a preemption signal or maintenance notice is sent.
Besides the preemption and error handling part,
`WorkerPreemptionHandler.run(distributed_train_function, *args,
**kwargs)` has the same effect and output as
`distributed_train_function(*args, **kwargs)`. `distributed_train_function`
may or may not return a result. The following is a shortened example:
```python
@tf.function
def distributed_train_step(iterator):
# A distributed single-step training function.
def step_fn(inputs):
# A per-replica single-step training function.
x, y = inputs
...
return loss
per_replica_losses = strategy.run(step_fn, args=(next(iterator),))
return strategy.reduce(
tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
for epoch in range(worker_preemption_watcher.total_runs //
STEPS_PER_EPOCH, EPOCHS_TO_RUN):
iterator = iter(multi_worker_dataset)
total_loss = 0.0
num_batches = 0
for step in range(worker_preemption_watcher.total_runs %
STEPS_PER_EPOCH, STEPS_PER_EPOCH):
total_loss += worker_preemption_watcher.run(distributed_train_step, iterator)
num_batches += 1
train_loss = total_loss / num_batches
print('Epoch: %d, train_loss: %f.' % (epoch, train_loss))
train_accuracy.reset_states()
```
Args:
distributed_train_function: A (single-step) distributed training function.
*args: args for `distributed_train_function`.
**kwargs: kwargs for `distributed_train_function`.
Raises:
Any program error encountered by any member of the cluster while executing
the `distributed_train_function`, or any error from the error propagation
process.
Returns:
Result of running the `distributed_train_function`.
"""
# TODO(wxinyi): after we support use with TPUStrategy, we should expand the
# API doc to state that `distributed_train_function` does not need to be a
# single-step training function, since a multi-step host-training loop is
# the dominant use case for TPU users. Besides, passing in a multi-step
# `distributed_train_function` will require the user to track their own
# training steps.
try:
self._checkpoint_if_preempted()
result = distributed_train_function(*args, **kwargs)
self._run_counter += 1
except errors.OpError as e:
logging.info('Propagating error to cluster: %r: %s', e, e)
try:
context.context().report_error_to_cluster(e.error_code, e.message)
except Exception as ex: # pylint: disable=broad-except
logging.info('Ignoring error during error propagation: %r:%s', ex, ex)
raise
return result
def _save_checkpoint_and_exit(self):
"""Saves the checkpoint and exit program."""
logging.info('Starting checkpoint and exit')
self._checkpointed_runs.assign(self.total_runs)
start_time = time.monotonic()
self._write_checkpoint_manager.save()
# All workers need to participate in saving a checkpoint to avoid
# deadlock. They need to write to different paths so that they do not
# overwrite each other. We make temporary directories for non-chief
# workers to write to, and clean them up afterward.
if not multi_worker_util.is_chief(
cluster_spec=self._cluster_resolver.cluster_spec(),
task_type=self._cluster_resolver.task_type,
task_id=self._cluster_resolver.task_id):
gfile.DeleteRecursively(
os.path.dirname(self._write_checkpoint_manager.directory))
end_time = time.monotonic()
logging.info('Checkpoint finished at path %s',
self._write_checkpoint_manager.directory)
logging.info('Checkpoint time: %f', end_time - start_time)
self._stop_poll_termination_signal_thread()
self._stop_cluster_wise_termination_watcher_thread()
sys.exit(self._restart_code)
def _checkpoint_if_preempted(self):
"""Checkpoint if any worker has received a preemption signal.
This function handles a preemption signal reported by any worker in the
cluster. The current implementation relies on the fact that all workers in a
MultiWorkerMirroredStrategy training cluster differ in their step numbers by
at most 1.
- If the signal comes from the worker itself (i.e., where this failure
handler sits), the worker notifies all peers to checkpoint after they finish
CURRENT_STEP+1 steps, where CURRENT_STEP is the step this worker has just
finished. The worker then waits for all peers to acknowledge that they have
received its preemption signal and the final-step number before it proceeds
to train the final step.
- If the signal comes from another member of the cluster but NO final-step
info is available yet, the worker proceeds with training, because the info
will be available after finishing the next step.
- If the signal comes from another member of the cluster and final-step info
is available, the worker keeps training if it has not yet finished that step;
otherwise, it checkpoints and exits with a cluster-recognized restart code.
"""
if self._received_sigterm_and_step.is_set():
run_count_key = context.context().get_config_key_value(_RUN_COUNT_KEY)
if run_count_key == str(self._run_counter):
self._save_checkpoint_and_exit()
elif self._received_own_sigterm.is_set():
step_to_save_at = str(self._run_counter + 1)
try:
context.context().set_config_key_value(_RUN_COUNT_KEY, step_to_save_at)
logging.info('Termination caught in main thread on preempted worker')
logging.info('%s set to %s', _RUN_COUNT_KEY, step_to_save_at)
n_workers = multi_worker_util.worker_count(
self._cluster_resolver.cluster_spec(),
self._cluster_resolver.task_type)
for i in range(n_workers):
context.context().get_config_key_value(f'{_ACKNOWLEDGE_KEY}_{i}')
logging.info('Sigterm acknowledgement from replica %d received', i)
# This is to handle the case that some other worker receives termination
# notice as well, and it has made a step key available right before this
# worker attempts to set it. In this case, it incurs a config key
# AlreadyExistsError.
# With MultiWorkerMirroredStrategy, every step contains collective ops
# (all-reduce, all-gather, etc.) that require the participation of all
# workers, which forms a synchronization point. Thus the max difference in
# the training progresses made by the workers is less than one complete
# step (e.g., one worker is finishing up the post-collective ops part of
# step N, and another is doing the pre-collective ops part of step N+1.)
#
# We can safely ignore this AlreadyExistsError. Say both worker-a and
# worker-b have received preemption notice, and worker-b encounters an
# AlreadyExistsError here because worker-a has already uploaded a value as
# the last step to finish before saving a checkpoint. Assume worker-b has
# finished step N and attempts to set the last step as N+1. If the training
# progress made by worker-a is ahead of that of worker-b, then worker-a
# must be running step N+1 due to the mechanism mentioned above and has
# set the last step as step N+1. If worker-a is behind worker-b, then it
# cannot possibly have set the last step as step N (not to mention a step
# number less than N), because worker-a would only do so before executing
# step N. Consider that when a worker resolves and sets the last step to
# finish, it waits until receiving acknowledgment from all workers before
# continuing to train the next step. And thus, in this case, worker-b
# would never have finished step N, which requires the participation of
# worker-a. In either case, we can safely ignore the error and revisit the
# checkpoint and exit option after finishing the current step, step N+1.
# At that time, it will be redirected to the earlier branch.
except errors.AlreadyExistsError:
logging.info(
'Member %s has received termination notice. But some other'
' worker has received it as well! Leaving'
' it to them to decide when to checkpoint. ', self._id_in_cluster)
return
def _sigterm_handler_fn(self, signum, frame):
"""Upload the to-be-preempted worker's id to coordination service."""
del signum, frame
logging.info('Member %s has received termination signal.',
self._id_in_cluster)
self._received_own_sigterm.set()
def _wait_for_signal(self):
"""Watch out for peer preemption signal and step-to-save and acknowledge."""
step_key = context.context().get_config_key_value(_RUN_COUNT_KEY)
# get_config_key_value does not return until it gets some result. Thus at
# the time to clean up, we upload a _STOP_WATCHING_CLUSTER_VALUE as the
# value so we can join the thread executing _wait_for_signal.
if step_key != _STOP_WATCHING_CLUSTER_VALUE:
# This must be set before we set the ack key below, otherwise its value in
# _checkpoint_if_preempted may be outdated.
self._received_sigterm_and_step.set()
ack_key = f'{_ACKNOWLEDGE_KEY}_{self._id_in_cluster}'
context.context().set_config_key_value(ack_key, '1')
logging.info(
'WorkerPreemptionHandler._wait_for_signal: %s set, '
'preemption awareness acknowledged', ack_key)