Add a platform configuration argument to WorkerPreemptionHandler to extend support to arbitrary platforms.
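
A minimal usage sketch of the new argument (the watcher function, the exit
code 7, and the 30-second grace period below are hypothetical,
platform-specific choices; the import path is illustrative):

  import os
  from tensorflow.python.distribute.failure_handling import failure_handling

  def watcher_fn():
    # Hypothetical: the platform writes a notice file before terminating.
    return os.path.exists('/tmp/termination-notice')

  termination_config = failure_handling.TerminationConfig(
      termination_watcher_function=watcher_fn,
      restart_code=7,
      time_till_termination=30)
  handler = failure_handling.WorkerPreemptionHandler(
      strategy.cluster_resolver, fh_ckpt, checkpoint_dir,
      termination_config=termination_config)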
PiperOrigin-RevId: 437223791
diff --git a/tensorflow/python/distribute/failure_handling/BUILD b/tensorflow/python/distribute/failure_handling/BUILD
index 5750dfa..b5f605f 100644
--- a/tensorflow/python/distribute/failure_handling/BUILD
+++ b/tensorflow/python/distribute/failure_handling/BUILD
@@ -15,6 +15,7 @@
srcs_version = "PY3",
deps = [
":gce_util",
+ "//tensorflow/python:variables",
"//tensorflow/python/distribute:multi_worker_util",
"//tensorflow/python/eager:context",
"//tensorflow/python/framework:constant_op",
diff --git a/tensorflow/python/distribute/failure_handling/__init__.py b/tensorflow/python/distribute/failure_handling/__init__.py
index b03515e..c8cb802 100644
--- a/tensorflow/python/distribute/failure_handling/__init__.py
+++ b/tensorflow/python/distribute/failure_handling/__init__.py
@@ -12,6 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Library imports for CoordinatedCheckpointManager."""
+"""Library imports for WorkerPreemptionHandler."""
-from tensorflow.python.distribute.failure_handling.failure_handling import CoordinatedCheckpointManager
+from tensorflow.python.distribute.failure_handling.failure_handling import WorkerPreemptionHandler
diff --git a/tensorflow/python/distribute/failure_handling/failure_handler_test.py b/tensorflow/python/distribute/failure_handling/failure_handler_test.py
index 5381067..ef38069 100644
--- a/tensorflow/python/distribute/failure_handling/failure_handler_test.py
+++ b/tensorflow/python/distribute/failure_handling/failure_handler_test.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Tests for CoordinatedCheckpointManager."""
+"""Tests for WorkerPreemptionHandler."""
import os
import random
import re
@@ -70,7 +70,7 @@
class PreemptionCheckpointTest(test.TestCase, parameterized.TestCase):
- """Integration test for CoordinatedCheckpointManager."""
+ """Integration test for WorkerPreemptionHandler."""
def _mwms_write_checkpoint_dir(self, checkpoint_dir, cluster_spec, task_type,
task_id):
@@ -123,12 +123,12 @@
model = Model()
# Named it fh_ckpt because it'd be better that the user have their
# regular checkpoint separate from the checkpoint for
- # CoordinatedCheckpointManager, since we will create CheckpointManager
+ # WorkerPreemptionHandler, since we will create a CheckpointManager
# to manage the checkpoint and only one CheckpointManager should be
# active in a particular directory at a time.
fh_ckpt = tracking_util.Checkpoint(model=model)
- failure_handler = failure_handling.CoordinatedCheckpointManager(
+ worker_preemption_watcher = failure_handling.WorkerPreemptionHandler(
strategy.cluster_resolver, fh_ckpt, checkpoint_dir)
def distributed_train_step(current_epoch, current_step):
@@ -147,13 +147,16 @@
if current_step == STEPS_PER_EPOCH - 1:
logging.info('epoch %d finished', current_epoch)
- logging.info('Restored training at %d', failure_handler.total_runs)
- for epoch in range(failure_handler.total_runs // STEPS_PER_EPOCH,
- EPOCHS_TO_RUN):
+ logging.info('Restored training at %d',
+ worker_preemption_watcher.total_runs)
+ for epoch in range(
+ worker_preemption_watcher.total_runs // STEPS_PER_EPOCH,
+ EPOCHS_TO_RUN):
- for step in range(failure_handler.total_runs % STEPS_PER_EPOCH,
- STEPS_PER_EPOCH):
- failure_handler.run(distributed_train_step, epoch, step)
+ for step in range(
+ worker_preemption_watcher.total_runs % STEPS_PER_EPOCH,
+ STEPS_PER_EPOCH):
+ worker_preemption_watcher.run(distributed_train_step, epoch, step)
# Add some randomness to when preemption actually happens. We should
# trigger it for sure if the training is coming to an end and it hasn't
# been triggered yet.
diff --git a/tensorflow/python/distribute/failure_handling/failure_handling.py b/tensorflow/python/distribute/failure_handling/failure_handling.py
index ca93874..b87e353 100644
--- a/tensorflow/python/distribute/failure_handling/failure_handling.py
+++ b/tensorflow/python/distribute/failure_handling/failure_handling.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Module for `CoordinatedCheckpointManager`.
+"""Module for `WorkerPreemptionHandler`.
This is currently under development and the API is subject to change.
"""
@@ -54,20 +54,125 @@
return os.path.join(dirpath, base)
-class CoordinatedCheckpointManager(object):
+class TerminationConfig(object):
+ """Configurations to customize for a platform other than Google's Borg or GCP.
+
+ A TerminationConfig can be created and passed to the
+ `WorkerPreemptionHandler` to provide customization based on the platform.
+ It will deliver three pieces of information:
+
+ * How to decide whether a termination event is coming soon
+
+ The termination notification and how to fetch it vary across platforms. Thus
+ we accept a user-defined function, `termination_watcher_function`, and execute
+ it repeatedly to check for a termination notification.
+ `termination_watcher_function` should return True if a termination
+ notification is available and False otherwise. The function should be
+ lightweight and non-blocking so that we can clean up resources properly if no
+ termination signal is ever raised before training finishes.
+
+ * How to exit the program
+
+ We ask for a `restart_code` and execute `sys.exit(restart_code)` after saving
+ the checkpoint, so that the training program exits gracefully. A restart is
+ inevitable to reset the program's state. However, you can configure the
+ `restart_code` to facilitate the restart and make the training experience
+ smooth. How so? Maybe your platform recognizes an agreed-upon RESTART_CODE as
+ a program auto-restart signal, or you may have a coordinating script that
+ starts up the training, which you can configure to auto-restart the program
+ if it ever exits with this RESTART_CODE. In both cases, you can pass in this
+ RESTART_CODE and then you would not even notice that the training has been
+ interrupted and restarted.
+
+ * How much time we have between receiving a termination notice and the
+ actual termination
+
+ Some platforms leave a gap as long as, say, one hour. In this case, you might
+ want to use this time for training as much as possible until you have to save
+ a checkpoint and exit. We can make use of this information if you pass it
+ through the `time_till_termination` argument.
+
+
+ *The default behavior*:
+
+ If you are training with Google’s Borg system or GCP, we automatically detect
+ the platform and make the right configuration for you. Besides these two
+ platforms, the default behavior on an unrecognized platform is:
+
+ * If `termination_watcher_function` is `None`, we treat `signal.SIGTERM` as
+ the termination event.
+
+ * If `restart_code` is not configured, we exit with an arbitrary choice, 42.
+
+ * If `time_till_termination` is not configured, the default is 0: we wrap up
+ the current training step, save a checkpoint, and exit the program as soon
+ as we receive the termination signal.
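+
+ Example (a minimal sketch; the watcher function, the exit code 7, and the
+ one-minute grace period below are hypothetical, platform-specific choices,
+ and `os` is assumed to be imported):
+
+ ```python
+ def watcher_fn():
+   # Hypothetical: the platform writes a notice file shortly before
+   # terminating the worker.
+   return os.path.exists('/tmp/termination-notice')
+
+ config = TerminationConfig(
+     termination_watcher_function=watcher_fn,
+     restart_code=7,
+     time_till_termination=60)
+ ```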
+ """
+
+ def __init__(self,
+ termination_watcher_function=None,
+ restart_code=None,
+ time_till_termination=None):
+ self.termination_watcher_function = termination_watcher_function
+ self.restart_code = restart_code
+ self.time_till_termination = time_till_termination
+
+
+class GCPTerminationConfig(TerminationConfig):
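+  """A `TerminationConfig` with defaults filled in for GCP (GCE with GPU)."""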
+
+ def __init__( # pylint: disable=super-init-not-called
+ self,
+ termination_watcher_function=None,
+ restart_code=None,
+ time_till_termination=None):
+ self.termination_watcher_function = termination_watcher_function or gce_util.termination_watcher_function_gce
+ self.restart_code = restart_code or gce_util._RESTARTABLE_EXIT_CODE
+ self.time_till_termination = time_till_termination or gce_util.GRACE_PERIOD_GCE
+
+
+class BorgTerminationConfig(TerminationConfig):
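+  """A `TerminationConfig` with the Borg defaults, also the fallback for unrecognized platforms."""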
+
+ def __init__( # pylint: disable=super-init-not-called
+ self,
+ termination_watcher_function=None,
+ restart_code=None,
+ time_till_termination=None):
+ self.termination_watcher_function = termination_watcher_function
+ self.restart_code = restart_code or 42
+ self.time_till_termination = time_till_termination or 0
+
+
+def _complete_config_for_environment(platform_device, termination_config):
+ """Complete unfilled fields of TerminationConfig based on the platform."""
+ if platform_device is gce_util.PlatformDevice.GCE_GPU:
+ return GCPTerminationConfig(termination_config.termination_watcher_function,
+ termination_config.restart_code,
+ termination_config.time_till_termination)
+
+ else:
+ # The defaults we chose are the same as the ones used by Borg, so we just
+ # return a BorgTerminationConfig.
+ return BorgTerminationConfig(
+ termination_config.termination_watcher_function,
+ termination_config.restart_code,
+ termination_config.time_till_termination)
+
+
+class WorkerPreemptionHandler(object):
"""Preemption and error handler for synchronous training.
The API helps coordinate all workers to save a checkpoint upon receiving a
preemption signal and helps propagate accurate error messages during training.
When the program recovers from preemption, the checkpoint that is passed to
- initialize a `CoordinatedCheckpointManager` object will be loaded
+ initialize a `WorkerPreemptionHandler` object will be loaded
automatically.
- Right after the initialization, a thread starts to watch out for a preemption
+ Right after the initialization, a thread starts to watch out for a termination
signal for any member in the cluster, but the signal will only be handled
(which includes aligning the step to save a checkpoint, saving a checkpoint,
and exiting with a platform recognized restart code) after entering a
- `CoordinatedCheckpointManager.run` call.
+ `WorkerPreemptionHandler.run` call.
Example usage:
```python
@@ -78,49 +183,54 @@
fh_checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
- coordinated_checkpoint_manager = tf.distribute.CoordinatedCheckpointManager(
+ worker_preemption_watcher = tf.distribute.WorkerPreemptionHandler(
cluster_resolver, fh_checkpoint, checkpoint_directory)
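+  # Optionally pass `termination_config=` to customize the handler for
+  # platforms other than Borg or GCP (see `TerminationConfig`).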
- # `coordinated_checkpoint_manager.total_runs` will be restored to its
+ # `worker_preemption_watcher.total_runs` will be restored to its
# checkpointed value when training is restored after interruption.
- for epoch in range(coordinated_checkpoint_manager.total_runs //
+ for epoch in range(worker_preemption_watcher.total_runs //
STEPS_PER_EPOCH, num_epochs):
- for step in range(coordinated_checkpoint_manager.total_runs %
+ for step in range(worker_preemption_watcher.total_runs %
STEPS_PER_EPOCH, num_steps):
# distributed_train_step is a function wrapped by strategy.run
- loss += coordinated_checkpoint_manager.run(distributed_train_step,
+ loss += worker_preemption_watcher.run(distributed_train_step,
args=(next(dataset),))
```
- `CoordinatedCheckpointManager` will create a CheckpointManager to manage the
+ `WorkerPreemptionHandler` will create a CheckpointManager to manage the
checkpoint and only one CheckpointManager should be active in a particular
directory at a time. Thus, if the user would like to save a checkpoint for
purpose other than fault tolerance, e.g., for evaluation, they should save it
in a directory different from the one passed to a
- `CoordinatedCheckpointManager`.
+ `WorkerPreemptionHandler`.
This API targets multi-client distributed training, and right now only
`tf.distribute.MultiWorkerMirroredStrategy` is supported.
"""
- def __init__(self, cluster_resolver, checkpoint, checkpoint_dir):
+ def __init__(self,
+ cluster_resolver,
+ checkpoint,
+ checkpoint_dir,
+ termination_config=TerminationConfig()):
"""Creates the failure handler.
Args:
cluster_resolver: a `tf.distribute.cluster_resolver.ClusterResolver`. You
- may also get it through the `cluster_resolver` attribute of the
- strategy in use.
+ may also get it through the `cluster_resolver` attribute of the strategy
+ in use.
checkpoint: a `tf.train.Checkpoint` that will be saved upon preemption and
- loaded upon restart by the `CoordinatedCheckpointManager` API
- automatically.
- checkpoint_dir: a directory for the `CoordinatedCheckpointManager` to play
- with checkpoints. `CoordinatedCheckpointManager` will create a
+ loaded upon restart by the `WorkerPreemptionHandler` API automatically.
+ checkpoint_dir: a directory for the `WorkerPreemptionHandler` to play with
+ checkpoints. `WorkerPreemptionHandler` will create a
`tf.train.CheckpointManager` to manage the passed-in `checkpoint`. Since
only one `tf.train.CheckpointManager` should be active in a particular
directory at a time, this `checkpoint_dir` arg should preferably be
separated from where the user saves their checkpoint for non-fault
tolerance purpose.
+ termination_config: a `TerminationConfig` object, used to customize the
+ handler for a platform other than Google Borg or GCP.
"""
self._cluster_resolver = cluster_resolver
self._checkpoint = checkpoint
@@ -130,7 +240,7 @@
self._cluster_resolver.task_type,
self._cluster_resolver.task_id))
- # The number of calls to `CoordinatedCheckpointManager.run` when the latest
+ # The number of calls to `WorkerPreemptionHandler.run` when the latest
# checkpoint was saved.
self._checkpointed_runs = variables.Variable(
initial_value=constant_op.constant(0, dtype=dtypes.int64),
@@ -162,9 +272,9 @@
# An internal step counter that's restored to checkpointed_iterations when
# training is restored. It increments by one every time
- # `CoordinatedCheckpointManager.run` is called. Note that in this case, the
+ # `WorkerPreemptionHandler.run` is called. Note that in this case, the
# user must pass a single-step training function to
- # `CoordinatedCheckpointManager.run` instead of a multiple-step one.
+ # `WorkerPreemptionHandler.run` instead of a multiple-step one.
self._run_counter = self._checkpointed_runs.numpy()
# The worker itself has received a preemption signal.
@@ -188,36 +298,39 @@
self._cluster_wise_termination_watcher_thread.start()
logging.info('Start watcher for peer\'s signal.')
- self._poll_gce_signal_thread = None
+ self._poll_termination_signal_thread = None
+
self._platform_device = gce_util.detect_platform()
- if self._platform_device is gce_util.PlatformDevice.GCE_GPU:
- self._start_polling_for_gce_signal()
- self._exit_code = gce_util._RESTARTABLE_EXIT_CODE
- elif self._platform_device is gce_util.PlatformDevice.INTERNAL:
- self._start_watching_for_signal()
- self._exit_code = _RESTARTABLE_EXIT_CODE
+
+ completed_termination_config = _complete_config_for_environment(
+ self._platform_device, termination_config)
+ self._termination_watcher_function = completed_termination_config.termination_watcher_function
+ self._restart_code = completed_termination_config.restart_code
+ self._time_till_termination = completed_termination_config.time_till_termination
+
+ if completed_termination_config.termination_watcher_function:
+ self._start_polling_for_termination_signal()
else:
- raise NotImplementedError('CoordinatedCheckpointManager is only supported'
- ' for MultiWorkerMirroredStrategy with GPU.')
+ self._start_watching_for_signal()
def _start_watching_for_signal(self):
signal.signal(signal.SIGTERM, self._sigterm_handler_fn)
- def _start_polling_for_gce_signal(self):
- self._poll_gce_signal_thread_should_stop = threading.Event()
- self._poll_gce_signal_thread = threading.Thread(
- target=self._poll_gce_signal,
+ def _start_polling_for_termination_signal(self):
+ self._poll_termination_signal_thread_should_stop = threading.Event()
+ self._poll_termination_signal_thread = threading.Thread(
+ target=self._poll_termination_signal,
name='WorkerTerminationSignalWatcher-%s' % self._id_in_cluster,
daemon=True)
- self._poll_gce_signal_thread.start()
+ self._poll_termination_signal_thread.start()
logging.info('Start polling for termination signal.')
- def _poll_gce_signal(self):
- """Poll maintenance notice from GCE and notify peers if receiving one."""
+ def _poll_termination_signal(self):
+ """Poll maintenance notice and notify peers if receiving one."""
while True:
- if self._poll_gce_signal_thread_should_stop.is_set():
+ if self._poll_termination_signal_thread_should_stop.is_set():
return
- if gce_util.signal_polling_fn():
+ if self._termination_watcher_function():
# For the countdown.
self._signal_receipt_time = time.time()
break
@@ -227,11 +340,13 @@
self._id_in_cluster)
self._received_own_sigterm.set()
- def _stop_poll_gce_signal_thread(self):
- if self._poll_gce_signal_thread:
- self._poll_gce_signal_thread_should_stop.set()
- self._poll_gce_signal_thread.join()
- self._poll_gce_signal_thread = None
+ def _stop_poll_termination_signal_thread(self):
+ if self._poll_termination_signal_thread:
+ self._poll_termination_signal_thread_should_stop.set()
+ self._poll_termination_signal_thread.join()
+ self._poll_termination_signal_thread = None
logging.info('Shut down watcher for one\'s own termination signal')
def _stop_cluster_wise_termination_watcher_thread(self):
@@ -258,17 +373,17 @@
def __del__(self):
self._stop_cluster_wise_termination_watcher_thread()
- self._stop_poll_gce_signal_thread()
+ self._stop_poll_termination_signal_thread()
@property
def total_runs(self):
- """Returns the number of times `CoordinatedCheckpointManager.run` is called.
+ """Returns the number of times `WorkerPreemptionHandler.run` is called.
This value tracks the number of all calls to
- `CoordinatedCheckpointManager.run` including those before the program is
+ `WorkerPreemptionHandler.run` including those before the program is
restarted and the training is restored. The user can compute their total
number of iterations by:
- `coordinated_checkpoint_manager.run * number_of_steps_in_train_function`,
+ `worker_preemption_watcher.total_runs * number_of_steps_in_train_function`,
while for tf.distribute.MultiWorkerMirroredStrategy users,
`number_of_steps_in_train_function` should be one.
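+ For example, a worker that called `run` 30 times before a preemption and 70
+ times after the restart reports `total_runs == 100`.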
"""
@@ -292,11 +407,11 @@
function (i.e., containing a call to `tf.distribute.Strategy.run`). For
`tf.distribute.MultiWorkerMirroredStrategy` users, we recommend passing in a
single-step `distributed_train_function` to
- `CoordinatedCheckpointManager.run` so that the checkpoint can be saved in
+ `WorkerPreemptionHandler.run` so that the checkpoint can be saved in
time in case a preemption signal or maintenance notice is sent.
Besides the preemption and error handling part,
- `CoordinatedCheckpointManager.run(distributed_train_function, *args,
+ `WorkerPreemptionHandler.run(distributed_train_function, *args,
**kwargs)` has the same effect and output as
`distributed_train_function(*args, **kwargs)`. `distributed_train_function`
can return either some or no result. The following is a shortened example:
@@ -317,15 +432,15 @@
return strategy.reduce(
tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
- for epoch in range(coordinated_checkpoint_manager.total_runs //
+ for epoch in range(worker_preemption_watcher.total_runs //
STEPS_PER_EPOCH, EPOCHS_TO_RUN):
iterator = iter(multi_worker_dataset)
total_loss = 0.0
num_batches = 0
- for step in range(coordinated_checkpoint_manager.total_runs %
+ for step in range(worker_preemption_watcher.total_runs %
STEPS_PER_EPOCH, STEPS_PER_EPOCH):
- total_loss += coordinated_checkpoint_manager.run(distributed_train_step)
+ total_loss += worker_preemption_watcher.run(distributed_train_step)
num_batches += 1
train_loss = total_loss / num_batches
@@ -392,9 +507,9 @@
logging.info('Checkpoint finished at path %s',
self._write_checkpoint_manager.directory)
logging.info('Checkpoint time: %f', end_time - start_time)
- self._stop_poll_gce_signal_thread()
+ self._stop_poll_termination_signal_thread()
self._stop_cluster_wise_termination_watcher_thread()
- sys.exit(self._exit_code)
+ sys.exit(self._restart_code)
def _checkpoint_if_preempted(self):
"""Checkpoint if any worker has received a preemption signal.
@@ -431,7 +546,6 @@
try:
context.context().set_config_key_value(_RUN_COUNT_KEY, step_to_save_at)
logging.info('Termination caught in main thread on preempted worker')
-
logging.info('%s set to %s', _RUN_COUNT_KEY, step_to_save_at)
n_workers = multi_worker_util.worker_count(
@@ -440,7 +554,6 @@
for i in range(n_workers):
context.context().get_config_key_value(f'{_ACKNOWLEDGE_KEY}_{i}')
logging.info('Sigterm acknowledgement from replica %d received', i)
-
# This is to handle the case that some other worker receives termination
# notice as well, and it has made a step key available right before this
# worker attempts to set it. In this case, it incurs a config key
@@ -498,5 +611,6 @@
ack_key = f'{_ACKNOWLEDGE_KEY}_{self._id_in_cluster}'
context.context().set_config_key_value(ack_key, '1')
- logging.info('CoordinatedCheckpointManager._wait_for_signal: %s set, '
- 'preemption awareness acknowledged', ack_key)
+ logging.info(
+ 'WorkerPreemptionHandler._wait_for_signal: %s set, '
+ 'preemption awareness acknowledged', ack_key)
diff --git a/tensorflow/python/distribute/failure_handling/gce_failure_handler_test.py b/tensorflow/python/distribute/failure_handling/gce_failure_handler_test.py
index 386be23..8f84dc5 100644
--- a/tensorflow/python/distribute/failure_handling/gce_failure_handler_test.py
+++ b/tensorflow/python/distribute/failure_handling/gce_failure_handler_test.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Tests for GCE specifics of CoordinatedCheckpointManager."""
+"""Tests for GCE specifics of WorkerPreemptionHandler."""
import os
import random
import re
@@ -69,7 +69,7 @@
class GceFailureHandlingTest(test.TestCase, parameterized.TestCase):
- """Integration test for CoordinatedCheckpointManager."""
+ """Integration test for WorkerPreemptionHandler."""
def _mwms_write_checkpoint_dir(self, checkpoint_dir, cluster_spec, task_type,
task_id):
@@ -85,33 +85,33 @@
def worker_fn(self,
checkpoint_dir,
cluster_spec,
- maintenance_event,
- training_finished,
+ maintenance_event=None,
+ training_finished=None,
frequent_send=False):
_enable_coordination_service(cluster_spec)
strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy()
- def mock_request_compute_metadata(*args, **kwargs):
- del kwargs # Unused.
- if args[0] == 'instance/maintenance-event':
- if not frequent_send:
- time.sleep(1)
- if (not maintenance_event.is_set()) and (random.randrange(0, 20) >
- 18):
- maintenance_event.set()
- logging.info('Maintenance notice available.')
- return 'TERMINATE_ON_HOST_MAINTENANCE'
- elif frequent_send and not maintenance_event.is_set():
- logging.info('Maintenance notice available.')
- return 'TERMINATE_ON_HOST_MAINTENANCE'
+ def mock_termination_watcher_function_gce(*args, **kwargs):
+ del args, kwargs
+ if not frequent_send:
+ time.sleep(1)
+ if (not maintenance_event.is_set()) and (random.randrange(0, 20) > 18):
+ maintenance_event.set()
+ logging.info('Termination notice available.')
+ return True
- return 'NONE'
+ elif frequent_send and not maintenance_event.is_set():
+ logging.info('Termination notice available.')
+ return True
- with mock.patch.object(gce_util, 'request_compute_metadata',
- mock_request_compute_metadata), mock.patch.object(
- gce_util, 'detect_platform',
- lambda: gce_util.PlatformDevice.GCE_GPU):
+ return False
+
+ with mock.patch.object(
+ gce_util, 'termination_watcher_function_gce',
+ mock_termination_watcher_function_gce), mock.patch.object(
+ gce_util, 'detect_platform',
+ lambda: gce_util.PlatformDevice.GCE_GPU):
class Model(module.Module):
@@ -129,7 +129,7 @@
model = Model()
fh_ckpt = tracking_util.Checkpoint(model=model)
- failure_handler = failure_handling.CoordinatedCheckpointManager(
+ worker_preemption_watcher = failure_handling.WorkerPreemptionHandler(
strategy.cluster_resolver, fh_ckpt, checkpoint_dir)
def distributed_train_step(current_epoch, current_step):
@@ -143,13 +143,15 @@
if current_step == STEPS_PER_EPOCH - 1:
logging.info('epoch %d finished', current_epoch)
- logging.info('Start training at %d', failure_handler.total_runs)
- for epoch in range(failure_handler.total_runs // STEPS_PER_EPOCH,
- EPOCHS_TO_RUN):
+ logging.info('Start training at %d', worker_preemption_watcher.total_runs)
+ for epoch in range(
+ worker_preemption_watcher.total_runs // STEPS_PER_EPOCH,
+ EPOCHS_TO_RUN):
- for step in range(failure_handler.total_runs % STEPS_PER_EPOCH,
- STEPS_PER_EPOCH):
- failure_handler.run(distributed_train_step, epoch, step)
+ for step in range(
+ worker_preemption_watcher.total_runs % STEPS_PER_EPOCH,
+ STEPS_PER_EPOCH):
+ worker_preemption_watcher.run(distributed_train_step, epoch, step)
training_finished.set()
@@ -165,7 +167,7 @@
try:
# Explicitly call __del__ since making it None and gc.collect does
# not invoke __del__ here.
- failure_handler.__del__()
+ worker_preemption_watcher.__del__()
time.sleep(2)
diff --git a/tensorflow/python/distribute/failure_handling/gce_util.py b/tensorflow/python/distribute/failure_handling/gce_util.py
index 7c73328..29b1306 100644
--- a/tensorflow/python/distribute/failure_handling/gce_util.py
+++ b/tensorflow/python/distribute/failure_handling/gce_util.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Integration of CoordinatedCheckpointManager with GCE specific logic."""
+"""Util of GCE specifics to ingegrate with WorkerPreemptionHandler."""
import enum
import os
@@ -25,9 +25,10 @@
GCP_METADATA_HEADER = {'Metadata-Flavor': 'Google'}
_GCE_METADATA_URL_ENV_VARIABLE = 'GCE_METADATA_IP'
_RESTARTABLE_EXIT_CODE = 143
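+# Grace time (seconds) between a GCE termination notice and the actual
+# termination; 0 means save a checkpoint and exit right away.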
+GRACE_PERIOD_GCE = 0
-def request_compute_metadata(path: str) -> str:
+def request_compute_metadata(path):
"""Returns GCE VM compute metadata."""
gce_metadata_endpoint = 'http://' + os.environ.get(
_GCE_METADATA_URL_ENV_VARIABLE, 'metadata.google.internal')
@@ -41,7 +42,7 @@
return info
-def signal_polling_fn() -> bool:
+def termination_watcher_function_gce():
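+  """Returns True if GCE reports an upcoming maintenance/termination event."""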
result = request_compute_metadata(
'instance/maintenance-event') == 'TERMINATE_ON_HOST_MAINTENANCE'
return result
@@ -74,7 +75,7 @@
UNSUPPORTED = 'unsupported'
-def detect_platform() -> PlatformDevice:
+def detect_platform():
"""Returns the platform and device information."""
if on_gcp():
if context.context().list_physical_devices('GPU'):