Add platform configuration argument for WorkerPreemptionHandler to extend support to arbitrary platform.

PiperOrigin-RevId: 437223791
diff --git a/tensorflow/python/distribute/failure_handling/BUILD b/tensorflow/python/distribute/failure_handling/BUILD
index 5750dfa..b5f605f 100644
--- a/tensorflow/python/distribute/failure_handling/BUILD
+++ b/tensorflow/python/distribute/failure_handling/BUILD
@@ -15,6 +15,7 @@
     srcs_version = "PY3",
     deps = [
         ":gce_util",
+        "//tensorflow/python:variables",
         "//tensorflow/python/distribute:multi_worker_util",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/framework:constant_op",
diff --git a/tensorflow/python/distribute/failure_handling/__init__.py b/tensorflow/python/distribute/failure_handling/__init__.py
index b03515e..c8cb802 100644
--- a/tensorflow/python/distribute/failure_handling/__init__.py
+++ b/tensorflow/python/distribute/failure_handling/__init__.py
@@ -12,6 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Library imports for CoordinatedCheckpointManager."""
+"""Library imports for WorkerPreemptionHandler."""
 
-from tensorflow.python.distribute.failure_handling.failure_handling import CoordinatedCheckpointManager
+from tensorflow.python.distribute.failure_handling.failure_handling import WorkerPreemptionHandler
diff --git a/tensorflow/python/distribute/failure_handling/failure_handler_test.py b/tensorflow/python/distribute/failure_handling/failure_handler_test.py
index 5381067..ef38069 100644
--- a/tensorflow/python/distribute/failure_handling/failure_handler_test.py
+++ b/tensorflow/python/distribute/failure_handling/failure_handler_test.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Tests for CoordinatedCheckpointManager."""
+"""Tests for WorkerPreemptionHandler."""
 import os
 import random
 import re
@@ -70,7 +70,7 @@
 
 
 class PreemptionCheckpointTest(test.TestCase, parameterized.TestCase):
-  """Integration test for CoordinatedCheckpointManager."""
+  """Integration test for WorkerPreemptionHandler."""
 
   def _mwms_write_checkpoint_dir(self, checkpoint_dir, cluster_spec, task_type,
                                  task_id):
@@ -123,12 +123,12 @@
         model = Model()
         # Named it fh_ckpt because it'd be better that the user have their
         # regular checkpoint separate from the checkpoint for
-        # CoordinatedCheckpointManager, since we will create CheckpointManager
+        # WorkerPreemptionHandler, since we will create CheckpointManager
         # to manage the checkpoint and only one CheckpointManager should be
         # active in a particular directory at a time.
         fh_ckpt = tracking_util.Checkpoint(model=model)
 
-        failure_handler = failure_handling.CoordinatedCheckpointManager(
+        worker_preemption_watcher = failure_handling.WorkerPreemptionHandler(
             strategy.cluster_resolver, fh_ckpt, checkpoint_dir)
 
       def distributed_train_step(current_epoch, current_step):
@@ -147,13 +147,16 @@
         if current_step == STEPS_PER_EPOCH - 1:
           logging.info('epoch %d finished', current_epoch)
 
-      logging.info('Restored training at %d', failure_handler.total_runs)
-      for epoch in range(failure_handler.total_runs // STEPS_PER_EPOCH,
-                         EPOCHS_TO_RUN):
+      logging.info('Restored training at %d',
+                   worker_preemption_watcher.total_runs)
+      for epoch in range(
+          worker_preemption_watcher.total_runs // STEPS_PER_EPOCH,
+          EPOCHS_TO_RUN):
 
-        for step in range(failure_handler.total_runs % STEPS_PER_EPOCH,
-                          STEPS_PER_EPOCH):
-          failure_handler.run(distributed_train_step, epoch, step)
+        for step in range(
+            worker_preemption_watcher.total_runs % STEPS_PER_EPOCH,
+            STEPS_PER_EPOCH):
+          worker_preemption_watcher.run(distributed_train_step, epoch, step)
         # Add some randomness to when preemption actually happens. We should
         # trigger it for sure if the training is coming to an end and it hasn't
         # been triggered yet.
diff --git a/tensorflow/python/distribute/failure_handling/failure_handling.py b/tensorflow/python/distribute/failure_handling/failure_handling.py
index ca93874..b87e353 100644
--- a/tensorflow/python/distribute/failure_handling/failure_handling.py
+++ b/tensorflow/python/distribute/failure_handling/failure_handling.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Module for `CoordinatedCheckpointManager`.
+"""Module for `WorkerPreemptionHandler`.
 
 This is currently under development and the API is subject to change.
 """
@@ -54,20 +54,125 @@
   return os.path.join(dirpath, base)
 
 
-class CoordinatedCheckpointManager(object):
+class TerminationConfig(object):
+  """Configurations to customize for a platform other than Google's Borg or GCP.
+
+  A TerminationConfig can be created and passed to the
+  `WorkerPreemptionHandler` to provide customization based on the platform.
+  It will deliver three pieces of information:
+
+  * How to decide if there is a termination event soon
+
+  The termination notification and how to fetch it varies across platforms. Thus
+  we accept a user-defined function, `termination_watcher_function`, and execute
+  it repeatedly to check for termination notification.
+  `termination_watcher_function` should be a function that returns True if a
+  termination notification has been made available and False otherwise. And the
+  function should be lightweight and non-blocking so that we can clean up the
+  resources properly if no termination signal is ever raised until training
+  finishes.
+
+  * How to exit the program
+
+  We are asking for an `restart_code` to execute `sys.exit(restart_code)` after
+  saving the checkpoint to exit the training program gracefully. A restart is
+  inevitable to reset the program's state. However, you can configure the
+  `restart_code` to facilitate the restart and make the training experience
+  smooth. How so? Maybe your platform has an agreement to a RESTART_CODE that’s
+  recognized as a program auto-restart signal, or you may have a coordinating
+  script that starts up the training, in which you can configure the program to
+  auto-restart if it ever exits with this RESTART_CODE. In both cases,
+  you can pass in this RESTART_CODE and then wouldn’t even notice that the
+  training has been interrupted and restarted.
+
+  * How long do we have from receiving a termination event notice till the
+  actual termination.
+
+  Some platforms have the gap time as long as, say, one hour. In this case, you
+  might want to utilize this time for training as much as possible until you
+  have to save a checkpoint and exit. We can utilize this information if you
+  pass it through the `time_till_termination` argument.
+
+
+  *The default behavior*:
+
+  If you are training with Google’s Borg system or GCP, we automatically detect
+  the platform and make the right configuration for you. Besides these two
+  platforms, the default behavior on an unrecognized platform is:
+
+  * If `termination_event` is `None`, we will treat `signal.SIGTERM` as a
+  termination event.
+
+  * If `restart_code` not configured, we exit with an arbitrary choice, 42.
+
+  * If `time_till_termination` is not configured, the default is 0, and we will
+  wrap up the current training step, save a checkpoint, and exit the program as
+  soon as we receive the termination signal.
+  """
+
+  def __init__(self,
+               termination_watcher_function=None,
+               restart_code=None,
+               time_till_termination=None):
+    self.termination_watcher_function = termination_watcher_function
+    self.restart_code = restart_code
+    self.time_till_termination = time_till_termination
+
+
+class GCPTerminationConfig(TerminationConfig):
+
+  def __init__(  # pylint: disable=super-init-not-called
+      self,
+      termination_watcher_function=None,
+      restart_code=None,
+      time_till_termination=None):
+    self.termination_watcher_function = termination_watcher_function or gce_util.termination_watcher_function_gce
+    self.restart_code = restart_code or gce_util._RESTARTABLE_EXIT_CODE
+    self.time_till_termination = time_till_termination or gce_util.GRACE_PERIOD_GCE
+
+
+class BorgTerminationConfig(TerminationConfig):
+
+  def __init__(  # pylint: disable=super-init-not-called
+      self,
+      termination_watcher_function=None,
+      restart_code=None,
+      time_till_termination=None):
+    self.termination_watcher_function = termination_watcher_function
+    self.restart_code = restart_code or 42
+    self.time_till_termination = time_till_termination or 0
+
+
+def _complete_config_for_environement(platform_device, termination_config):
+  """Complete un-filled fields of TerminationConfig based on platform."""
+  if platform_device is gce_util.PlatformDevice.GCE_GPU:
+    return GCPTerminationConfig(termination_config.termination_watcher_function,
+                                termination_config.restart_code,
+                                termination_config.time_till_termination)
+
+  else:
+    # The default we chose are the same as the ones used by Borg. So we just
+    # return this.
+    return BorgTerminationConfig(
+        termination_config.termination_watcher_function,
+        termination_config.restart_code,
+        termination_config.time_till_termination)
+
+
+class WorkerPreemptionHandler(object):
   """Preemption and error handler for synchronous training.
 
   The API helps coordinate all workers to save a checkpoint upon receiving a
   preemption signal and helps propagate accurate error messages during training.
   When the program recovers from preemption, the checkpoint that is passed to
-  initialize a `CoordinatedCheckpointManager` object will be loaded
+  initialize a `WorkerPreemptionHandler` object will be loaded
   automatically.
 
-  Right after the initialization, a thread starts to watch out for a preemption
+  Right after the initialization, a thread starts to watch out for a termination
   signal for any member in the cluster, but the signal will only be handled
   (which includes aligning the step to save a checkpoint, saving a checkpoint,
   and exiting with a platform recognized restart code) after entering a
-  `CoordinatedCheckpointManager.run` call.
+  `WorkerPreemptionHandler.run` call.
 
   Example usage:
   ```python
@@ -78,49 +183,54 @@
 
     fh_checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
 
-    coordinated_checkpoint_manager = tf.distribute.CoordinatedCheckpointManager(
+    worker_preemption_watcher = tf.distribute.WorkerPreemptionHandler(
         cluster_resolver, fh_checkpoint, checkpoint_directory)
 
 
-    # `coordinated_checkpoint_manager.total_runs` will be restored to its
+    # `worker_preemption_watcher.total_runs` will be restored to its
     # checkpointed value when training is restored after interruption.
-    for epoch in range(coordinated_checkpoint_manager.total_runs //
+    for epoch in range(worker_preemption_watcher.total_runs //
                        STEPS_PER_EPOCH, num_epochs):
-      for step in range(coordinated_checkpoint_manager.total_runs %
+      for step in range(worker_preemption_watcher.total_runs %
                         STEPS_PER_EPOCH, num_steps):
         # distributed_train_step is a function wrapped by strategy.run
-        loss += coordinated_checkpoint_manager.run(distributed_train_step,
+        loss += worker_preemption_watcher.run(distributed_train_step,
                                                    args=(next(dataset),))
   ```
 
-  `CoordinatedCheckpointManager` will create a CheckpointManager to manage the
+  `WorkerPreemptionHandler` will create a CheckpointManager to manage the
   checkpoint and only one CheckpointManager should be active in a particular
   directory at a time. Thus, if the user would like to save a checkpoint for
   purpose other than fault tolerance, e.g., for evaluation, they should save it
   in a directory different from the one passed to a
-  `CoordinatedCheckpointManager`.
+  `WorkerPreemptionHandler`.
 
   This API targets multi-client distributed training, and right now only
   `tf.distribute.MultiWorkerMirroredStrategy` is supported.
   """
 
-  def __init__(self, cluster_resolver, checkpoint, checkpoint_dir):
+  def __init__(self,
+               cluster_resolver,
+               checkpoint,
+               checkpoint_dir,
+               termination_config=TerminationConfig()):
     """Creates the failure handler.
 
     Args:
       cluster_resolver: a `tf.distribute.cluster_resolver.ClusterResolver`. You
-        may also get it through the `cluster_resolver` attribute of the
-        strategy in use.
+        may also get it through the `cluster_resolver` attribute of the strategy
+        in use.
       checkpoint: a `tf.train.Checkpoint` that will be saved upon preemption and
-        loaded upon restart by the `CoordinatedCheckpointManager` API
-        automatically.
-      checkpoint_dir: a directory for the `CoordinatedCheckpointManager` to play
-        with checkpoints. `CoordinatedCheckpointManager` will create a
+        loaded upon restart by the `WorkerPreemptionHandler` API automatically.
+      checkpoint_dir: a directory for the `WorkerPreemptionHandler` to play with
+        checkpoints. `WorkerPreemptionHandler` will create a
         `tf.train.CheckpointManager` to manage the passed-in `checkpoint`. Since
         only one `tf.train.CheckpointManager` should be active in a particular
         directory at a time, this `checkpoint_dir` arg should preferably be
         separated from where the user saves their checkpoint for non-fault
         tolerance purpose.
+      termination_config: a `TerminationConfig` object to configure for a
+        platform other than Google Borg or GCP.
     """
     self._cluster_resolver = cluster_resolver
     self._checkpoint = checkpoint
@@ -130,7 +240,7 @@
             self._cluster_resolver.task_type,
             self._cluster_resolver.task_id))
 
-    # The number of calls to `CoordinatedCheckpointManager.run` when the latest
+    # The number of calls to `WorkerPreemptionHandler.run` when the latest
     # checkpoint was saved.
     self._checkpointed_runs = variables.Variable(
         initial_value=constant_op.constant(0, dtype=dtypes.int64),
@@ -162,9 +272,9 @@
 
     # An internal step counter that's restored to checkpointed_iterations when
     # training is restored. It increments by one every time
-    # `CoordinatedCheckpointManager.run` is called. Note that in this case, the
+    # `WorkerPreemptionHandler.run` is called. Note that in this case, the
     # user must pass a single-step training function to
-    # `CoordinatedCheckpointManager.run` instead of a multiple-step one.
+    # `WorkerPreemptionHandler.run` instead of a multiple-step one.
     self._run_counter = self._checkpointed_runs.numpy()
 
     # The worker itself has received preeption signal.
@@ -188,36 +298,39 @@
     self._cluster_wise_termination_watcher_thread.start()
     logging.info('Start watcher for peer\'s signal.')
 
-    self._poll_gce_signal_thread = None
+    self._poll_termination_signal_thread = None
+
     self._platform_device = gce_util.detect_platform()
-    if self._platform_device is gce_util.PlatformDevice.GCE_GPU:
-      self._start_polling_for_gce_signal()
-      self._exit_code = gce_util._RESTARTABLE_EXIT_CODE
-    elif self._platform_device is gce_util.PlatformDevice.INTERNAL:
-      self._start_watching_for_signal()
-      self._exit_code = _RESTARTABLE_EXIT_CODE
+
+    completed_termination_config = _complete_config_for_environement(
+        self._platform_device, termination_config)
+    self._termination_watcher_function = completed_termination_config.termination_watcher_function
+    self._restart_code = completed_termination_config.restart_code
+    self._time_till_termination = completed_termination_config.time_till_termination
+
+    if completed_termination_config.termination_watcher_function:
+      self._start_polling_for_termination_signal()
     else:
-      raise NotImplementedError('CoordinatedCheckpointManager is only supported'
-                                ' for MultiWorkerMirroredStrategy with GPU.')
+      self._start_watching_for_signal()
 
   def _start_watching_for_signal(self):
     signal.signal(signal.SIGTERM, self._sigterm_handler_fn)
 
-  def _start_polling_for_gce_signal(self):
-    self._poll_gce_signal_thread_should_stop = threading.Event()
-    self._poll_gce_signal_thread = threading.Thread(
-        target=self._poll_gce_signal,
+  def _start_polling_for_termination_signal(self):
+    self._poll_termination_signal_thread_should_stop = threading.Event()
+    self._poll_termination_signal_thread = threading.Thread(
+        target=self._poll_termination_signal,
         name='WorkerTerminationSignalWatcher-%s' % self._id_in_cluster,
         daemon=True)
-    self._poll_gce_signal_thread.start()
+    self._poll_termination_signal_thread.start()
     logging.info('Start polling for termination signal.')
 
-  def _poll_gce_signal(self):
-    """Poll maintenance notice from GCE and notify peers if receiving one."""
+  def _poll_termination_signal(self):
+    """Poll maintenance notice and notify peers if receiving one."""
     while True:
-      if self._poll_gce_signal_thread_should_stop.is_set():
+      if self._poll_termination_signal_thread_should_stop.is_set():
         return
-      if gce_util.signal_polling_fn():
+      if self._termination_watcher_function():
         # For the countdown.
         self._signal_receipt_time = time.time()
         break
@@ -227,11 +340,13 @@
                  self._id_in_cluster)
     self._received_own_sigterm.set()
 
-  def _stop_poll_gce_signal_thread(self):
-    if self._poll_gce_signal_thread:
-      self._poll_gce_signal_thread_should_stop.set()
-      self._poll_gce_signal_thread.join()
-      self._poll_gce_signal_thread = None
+  def _stop_poll_termination_signal_thread(self):
+    if self._poll_termination_signal_thread:
+
+      self._poll_termination_signal_thread_should_stop.set()
+      self._poll_termination_signal_thread.join()
+
+      self._poll_termination_signal_thread = None
       logging.info('Shut down watcher for one\'s own termination signal')
 
   def _stop_cluster_wise_termination_watcher_thread(self):
@@ -258,17 +373,17 @@
 
   def __del__(self):
     self._stop_cluster_wise_termination_watcher_thread()
-    self._stop_poll_gce_signal_thread()
+    self._stop_poll_termination_signal_thread()
 
   @property
   def total_runs(self):
-    """Returns the number of times `CoordinatedCheckpointManager.run` is called.
+    """Returns the number of times `WorkerPreemptionHandler.run` is called.
 
     This value tracks the number of all calls to
-    `CoordinatedCheckpointManager.run` including those before the program is
+    `WorkerPreemptionHandler.run` including those before the program is
     restarted and the training is restored. The user can compute their total
     number of iterations by:
-    `coordinated_checkpoint_manager.run * number_of_steps_in_train_function`,
+    `worker_preemption_watcher.run * number_of_steps_in_train_function`,
     while for tf.distribute.MultiWorkerMirroredStrategy users,
     `number_of_steps_in_train_function` should be one.
     """
@@ -292,11 +407,11 @@
     function (i.e., containing a call to `tf.distribute.Strategy.run`). For
     `tf.distribute.MultiWorkerMirroredStrategy` users, we recommend passing in a
     single-step `distributed_train_function` to
-    `CoordinatedCheckpointManager.run` so that the checkpoint can be saved in
+    `WorkerPreemptionHandler.run` so that the checkpoint can be saved in
     time in case a preemption signal or maintenance notice is sent.
 
     Besides the preemption and error handling part,
-    `CoordinatedCheckpointManager.run(distributed_train_function, *args,
+    `WorkerPreemptionHandler.run(distributed_train_function, *args,
     **kwargs)` has the same effect and output as
     `distributed_train_function(*args, **kwargs)`. `distributed_train_function`
     can return either some or no result. The following is a shortened example:
@@ -317,15 +432,15 @@
       return strategy.reduce(
           tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
 
-    for epoch in range(coordinated_checkpoint_manager.total_runs //
+    for epoch in range(worker_preemption_watcher.total_runs //
                        STEPS_PER_EPOCH, EPOCHS_TO_RUN):
       iterator = iter(multi_worker_dataset)
       total_loss = 0.0
       num_batches = 0
 
-      for step in range(coordinated_checkpoint_manager.total_runs %
+      for step in range(worker_preemption_watcher.total_runs %
                         STEPS_PER_EPOCH, STEPS_PER_EPOCH):
-        total_loss += coordinated_checkpoint_manager.run(distributed_train_step)
+        total_loss += worker_preemption_watcher.run(distributed_train_step)
         num_batches += 1
 
       train_loss = total_loss / num_batches
@@ -392,9 +507,9 @@
     logging.info('Checkpoint finished at path %s',
                  self._write_checkpoint_manager.directory)
     logging.info('Checkpoint time: %f', end_time - start_time)
-    self._stop_poll_gce_signal_thread()
+    self._stop_poll_termination_signal_thread()
     self._stop_cluster_wise_termination_watcher_thread()
-    sys.exit(self._exit_code)
+    sys.exit(self._restart_code)
 
   def _checkpoint_if_preempted(self):
     """Checkpoint if any worker has received a preemption signal.
@@ -431,7 +546,6 @@
       try:
         context.context().set_config_key_value(_RUN_COUNT_KEY, step_to_save_at)
         logging.info('Termination caught in main thread on preempted worker')
-
         logging.info('%s set to %s', _RUN_COUNT_KEY, step_to_save_at)
 
         n_workers = multi_worker_util.worker_count(
@@ -440,7 +554,6 @@
         for i in range(n_workers):
           context.context().get_config_key_value(f'{_ACKNOWLEDGE_KEY}_{i}')
           logging.info('Sigterm acknowledgement from replica %d received', i)
-
       # This is to handle the case that some other worker receives termination
       # notice as well, and it has made a step key available right before this
       # worker attempts to set it. In this case, it incurs a config key
@@ -498,5 +611,6 @@
 
       ack_key = f'{_ACKNOWLEDGE_KEY}_{self._id_in_cluster}'
       context.context().set_config_key_value(ack_key, '1')
-      logging.info('CoordinatedCheckpointManager._wait_for_signal: %s set, '
-                   'preemption awareness acknowledged', ack_key)
+      logging.info(
+          'WorkerPreemptionHandler._wait_for_signal: %s set, '
+          'preemption awareness acknowledged', ack_key)
diff --git a/tensorflow/python/distribute/failure_handling/gce_failure_handler_test.py b/tensorflow/python/distribute/failure_handling/gce_failure_handler_test.py
index 386be23..8f84dc5 100644
--- a/tensorflow/python/distribute/failure_handling/gce_failure_handler_test.py
+++ b/tensorflow/python/distribute/failure_handling/gce_failure_handler_test.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Tests for GCE specifics of CoordinatedCheckpointManager."""
+"""Tests for GCE specifics of WorkerPreemptionHandler."""
 import os
 import random
 import re
@@ -69,7 +69,7 @@
 
 
 class GceFailureHandlingTest(test.TestCase, parameterized.TestCase):
-  """Integration test for CoordinatedCheckpointManager."""
+  """Integration test for WorkerPreemptionHandler."""
 
   def _mwms_write_checkpoint_dir(self, checkpoint_dir, cluster_spec, task_type,
                                  task_id):
@@ -85,33 +85,33 @@
   def worker_fn(self,
                 checkpoint_dir,
                 cluster_spec,
-                maintenance_event,
-                training_finished,
+                maintenance_event=None,
+                training_finished=None,
                 frequent_send=False):
 
     _enable_coordination_service(cluster_spec)
     strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy()
 
-    def mock_request_compute_metadata(*args, **kwargs):
-      del kwargs  # Unused.
-      if args[0] == 'instance/maintenance-event':
-        if not frequent_send:
-          time.sleep(1)
-          if (not maintenance_event.is_set()) and (random.randrange(0, 20) >
-                                                   18):
-            maintenance_event.set()
-            logging.info('Maintenance notice available.')
-            return 'TERMINATE_ON_HOST_MAINTENANCE'
-        elif frequent_send and not maintenance_event.is_set():
-          logging.info('Maintenance notice available.')
-          return 'TERMINATE_ON_HOST_MAINTENANCE'
+    def mock_termination_watcher_function_gce(*args, **kwargs):
+      del args, kwargs
+      if not frequent_send:
+        time.sleep(1)
+        if (not maintenance_event.is_set()) and (random.randrange(0, 20) > 18):
+          maintenance_event.set()
+          logging.info('Termination notice available.')
+          return True
 
-      return 'NONE'
+      elif frequent_send and not maintenance_event.is_set():
+        logging.info('Termination notice available.')
+        return True
 
-    with mock.patch.object(gce_util, 'request_compute_metadata',
-                           mock_request_compute_metadata), mock.patch.object(
-                               gce_util, 'detect_platform',
-                               lambda: gce_util.PlatformDevice.GCE_GPU):
+      return False
+
+    with mock.patch.object(
+        gce_util, 'termination_watcher_function_gce',
+        mock_termination_watcher_function_gce), mock.patch.object(
+            gce_util, 'detect_platform',
+            lambda: gce_util.PlatformDevice.GCE_GPU):
 
       class Model(module.Module):
 
@@ -129,7 +129,7 @@
         model = Model()
         fh_ckpt = tracking_util.Checkpoint(model=model)
 
-        failure_handler = failure_handling.CoordinatedCheckpointManager(
+        worker_preemption_watcher = failure_handling.WorkerPreemptionHandler(
             strategy.cluster_resolver, fh_ckpt, checkpoint_dir)
 
       def distributed_train_step(current_epoch, current_step):
@@ -143,13 +143,15 @@
         if current_step == STEPS_PER_EPOCH - 1:
           logging.info('epoch %d finished', current_epoch)
 
-      logging.info('Start training at %d', failure_handler.total_runs)
-      for epoch in range(failure_handler.total_runs // STEPS_PER_EPOCH,
-                         EPOCHS_TO_RUN):
+      logging.info('Start training at %d', worker_preemption_watcher.total_runs)
+      for epoch in range(
+          worker_preemption_watcher.total_runs // STEPS_PER_EPOCH,
+          EPOCHS_TO_RUN):
 
-        for step in range(failure_handler.total_runs % STEPS_PER_EPOCH,
-                          STEPS_PER_EPOCH):
-          failure_handler.run(distributed_train_step, epoch, step)
+        for step in range(
+            worker_preemption_watcher.total_runs % STEPS_PER_EPOCH,
+            STEPS_PER_EPOCH):
+          worker_preemption_watcher.run(distributed_train_step, epoch, step)
 
       training_finished.set()
 
@@ -165,7 +167,7 @@
         try:
           # Explicitly call __del__ since making it None and gc.collect does
           # not invoke __del__ here.
-          failure_handler.__del__()
+          worker_preemption_watcher.__del__()
 
           time.sleep(2)
 
diff --git a/tensorflow/python/distribute/failure_handling/gce_util.py b/tensorflow/python/distribute/failure_handling/gce_util.py
index 7c73328..29b1306 100644
--- a/tensorflow/python/distribute/failure_handling/gce_util.py
+++ b/tensorflow/python/distribute/failure_handling/gce_util.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Integration of CoordinatedCheckpointManager with GCE specific logic."""
+"""Util of GCE specifics to ingegrate with WorkerPreemptionHandler."""
 import enum
 import os
 
@@ -25,9 +25,10 @@
 GCP_METADATA_HEADER = {'Metadata-Flavor': 'Google'}
 _GCE_METADATA_URL_ENV_VARIABLE = 'GCE_METADATA_IP'
 _RESTARTABLE_EXIT_CODE = 143
+GRACE_PERIOD_GCE = 0
 
 
-def request_compute_metadata(path: str) -> str:
+def request_compute_metadata(path):
   """Returns GCE VM compute metadata."""
   gce_metadata_endpoint = 'http://' + os.environ.get(
       _GCE_METADATA_URL_ENV_VARIABLE, 'metadata.google.internal')
@@ -41,7 +42,7 @@
     return info
 
 
-def signal_polling_fn() -> bool:
+def termination_watcher_function_gce():
   result = request_compute_metadata(
       'instance/maintenance-event') == 'TERMINATE_ON_HOST_MAINTENANCE'
   return result
@@ -74,7 +75,7 @@
   UNSUPPORTED = 'unsupported'
 
 
-def detect_platform() -> PlatformDevice:
+def detect_platform():
   """Returns the platform and device information."""
   if on_gcp():
     if context.context().list_physical_devices('GPU'):