Change LossScaleOptimizer checkpoint format.

Now the format is identical to as if a LossScaleOptimzier is not used, except that the loss scale is saved with a LossScaleOptimizer. This allows saving checkpoints with a LossScaleOptimizer and restoring without a LossScaleOptimizer, and vice versa.

Checkpoints with LossScaleOptimizers created in older versions of TensorFlow can still be loaded. New checkpoints saved will use the new format.

PiperOrigin-RevId: 306511555
Change-Id: Ie316ab8c4fbfec7babd6f7803d337799d0ff10a5
diff --git a/tensorflow/python/keras/mixed_precision/experimental/BUILD b/tensorflow/python/keras/mixed_precision/experimental/BUILD
index 7609745..18390b4 100644
--- a/tensorflow/python/keras/mixed_precision/experimental/BUILD
+++ b/tensorflow/python/keras/mixed_precision/experimental/BUILD
@@ -224,9 +224,16 @@
     name = "keras_test",
     size = "medium",
     srcs = ["keras_test.py"],
+    data = [
+        "//tensorflow/python/keras/mixed_precision/experimental/testdata:lso_ckpt_tf2.2",
+        "//tensorflow/python/keras/mixed_precision/experimental/testdata:lso_savedmodel_tf2.2",
+    ],
     python_version = "PY3",
     shard_count = 10,
-    tags = ["no_windows"],  # b/139083295: bfloat16 tests fail on Windows
+    tags = [
+        "no_pip",
+        "no_windows",  # b/139083295: bfloat16 tests fail on Windows
+    ],
     deps = [
         ":test_util",
         "//tensorflow/python:client_testlib",
diff --git a/tensorflow/python/keras/mixed_precision/experimental/keras_test.py b/tensorflow/python/keras/mixed_precision/experimental/keras_test.py
index a27be08..d2e80cf 100644
--- a/tensorflow/python/keras/mixed_precision/experimental/keras_test.py
+++ b/tensorflow/python/keras/mixed_precision/experimental/keras_test.py
@@ -41,6 +41,7 @@
 from tensorflow.python.keras.engine import base_layer
 from tensorflow.python.keras.engine import base_layer_utils
 from tensorflow.python.keras.engine import input_spec
+from tensorflow.python.keras.engine import sequential
 from tensorflow.python.keras.layers import core
 from tensorflow.python.keras.mixed_precision.experimental import get_layer_policy
 from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer
@@ -994,6 +995,56 @@
     self.assertEqual(backend.get_value(loss_scale._num_good_steps), 1)
 
   @keras_parameterized.run_all_keras_modes
+  def test_restore_old_loss_scale_checkpoint(self):
+    # Ensure a checkpoint from TF 2.2 can be loaded. The checkpoint format
+    # of LossScaleOptimizer changed, but old checkpoints can still be loaded
+    opt = gradient_descent.SGD(0.1, momentum=0.1)
+    opt = loss_scale_optimizer.LossScaleOptimizer(opt, 'dynamic')
+    model = sequential.Sequential([core.Dense(2,)])
+
+    # The checkpoint and expected values were obtained from the program in
+    # testdata/BUILD.
+    ckpt_dir = test.test_src_dir_path(
+        'python/keras/mixed_precision/experimental/testdata/lso_ckpt_tf2.2')
+    model.load_weights(os.path.join(ckpt_dir, 'ckpt'))
+    model.compile(opt, 'mse', run_eagerly=testing_utils.should_run_eagerly())
+    model(np.zeros((2, 2)))  # Create model weights
+    opt._create_all_weights(model.weights)
+    expected_kernel = np.array([[9.229685, 10.901115], [10.370763, 9.757362]])
+    expected_slot = np.array([[10.049943, 9.917691], [10.049943, 9.917691]])
+    self.assertAllClose(self.evaluate(model.weights[0]), expected_kernel)
+    self.assertAllClose(
+        self.evaluate(opt.get_slot(model.weights[0], 'momentum')),
+        expected_slot)
+    self.assertEqual(self.evaluate(opt.loss_scale()), 32768)
+    self.assertEqual(self.evaluate(opt.loss_scale._num_good_steps), 1)
+
+    # Check restoring works even after the model is compiled and the weights
+    # have been created.
+    model.fit(np.random.normal(size=(2, 2)), np.random.normal(size=(2, 2)))
+    self.assertNotAllClose(self.evaluate(model.weights[0]), expected_kernel)
+    self.assertNotAllClose(
+        self.evaluate(opt.get_slot(model.weights[0], 'momentum')),
+        expected_slot)
+    model.load_weights(os.path.join(ckpt_dir, 'ckpt'))
+    self.assertAllClose(self.evaluate(model.weights[0]), expected_kernel)
+    self.assertAllClose(
+        self.evaluate(opt.get_slot(model.weights[0], 'momentum')),
+        expected_slot)
+    self.assertEqual(self.evaluate(opt.loss_scale()), 32768)
+    self.assertEqual(self.evaluate(opt.loss_scale._num_good_steps), 1)
+
+  def test_restore_old_saved_model(self):
+    saved_model_dir = test.test_src_dir_path(
+        'python/keras/mixed_precision/experimental/testdata/'
+        'lso_savedmodel_tf2.2')
+    model = save.load_model(saved_model_dir)
+    expected_kernel = np.array([[9.229685, 10.901115], [10.370763, 9.757362]])
+    self.assertAllClose(backend.eval(model.weights[0]), expected_kernel)
+    self.assertIsInstance(model.optimizer,
+                          loss_scale_optimizer.LossScaleOptimizer)
+
+  @keras_parameterized.run_all_keras_modes
   @parameterized.named_parameters(
       {
           'testcase_name': 'base',
diff --git a/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer.py b/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer.py
index 1c14955..d6a786a 100644
--- a/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer.py
+++ b/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer.py
@@ -22,6 +22,7 @@
 from tensorflow.python.distribute import mirrored_strategy
 from tensorflow.python.distribute import one_device_strategy
 from tensorflow.python.distribute import tpu_strategy
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import smart_cond
 from tensorflow.python.keras import backend
@@ -32,6 +33,7 @@
 from tensorflow.python.ops import math_ops
 from tensorflow.python.training.experimental import loss_scale as loss_scale_module
 from tensorflow.python.training.experimental import mixed_precision
+from tensorflow.python.training.tracking import base as trackable
 from tensorflow.python.util.tf_export import keras_export
 
 
@@ -51,8 +53,126 @@
     self.value = value
 
 
+class _DelegatingTrackableMixin(object):
+  """A mixin that delegates all Trackable methods to another trackable object.
+
+  This class must be used with multiple inheritance. A class that subclasses
+  Trackable can also subclass this class, which causes all Trackable methods to
+  be delegated to the trackable object passed in the constructor.
+
+  A subclass can use this mixin to appear as if it were the trackable passed to
+  the constructor, from a Checkpoint's perspective. LossScaleOptimizer uses this
+  mixin, so that the checkpoint format for a LossScaleOptimizer is identical to
+  the checkpoint format for a normal optimizer. This allows a model to be saved
+  with a normal Optimizer and restored with a LossScaleOptimizer, or vice versa.
+  The only difference in checkpoint format is that the loss scale is also saved
+  with a LossScaleOptimizer.
+  """
+
+  def __init__(self, trackable_obj):
+    self._trackable = trackable_obj
+
+  # pylint: disable=protected-access
+  @property
+  def _setattr_tracking(self):
+    return self._trackable._setattr_tracking
+
+  @_setattr_tracking.setter
+  def _setattr_tracking(self, value):
+    self._trackable._setattr_tracking = value
+
+  @property
+  def _update_uid(self):
+    return self._trackable._update_uid
+
+  @_update_uid.setter
+  def _update_uid(self, value):
+    self._trackable._update_uid = value
+
+  @property
+  def _unconditional_checkpoint_dependencies(self):
+    return self._trackable._unconditional_checkpoint_dependencies
+
+  @property
+  def _unconditional_dependency_names(self):
+    return self._trackable._unconditional_dependency_names
+
+  @property
+  def _name_based_restores(self):
+    return self._trackable._name_based_restores
+
+  def _maybe_initialize_trackable(self):
+    return self._trackable._maybe_initialize_trackable()
+
+  @property
+  def _object_identifier(self):
+    return self._trackable._object_identifier
+
+  @property
+  def _tracking_metadata(self):
+    return self._trackable._tracking_metadata
+
+  def _no_dependency(self, value):
+    return self._trackable._no_dependency(value)
+
+  def _name_based_attribute_restore(self, checkpoint):
+    return self._trackable._name_based_attribute_restore(checkpoint)
+
+  @property
+  def _checkpoint_dependencies(self):
+    return self._trackable._checkpoint_dependencies
+
+  @property
+  def _deferred_dependencies(self):
+    return self._trackable._deferred_dependencies
+
+  def _lookup_dependency(self, name):
+    self._trackable._lookup_dependency(name)
+
+  def _add_variable_with_custom_getter(self,
+                                       name,
+                                       shape=None,
+                                       dtype=dtypes.float32,
+                                       initializer=None,
+                                       getter=None,
+                                       overwrite=False,
+                                       **kwargs_for_getter):
+    return self._trackable._add_variable_with_custom_getter(
+        name, shape, dtype, initializer, getter, overwrite, **kwargs_for_getter)
+
+  def _preload_simple_restoration(self, name, shape):
+    return self._trackable._preload_simple_restoration(name, shape)
+
+  def _track_trackable(self, trackable, name, overwrite=False):  # pylint: disable=redefined-outer-name
+    return self._trackable._track_trackable(trackable, name, overwrite)
+
+  def _handle_deferred_dependencies(self, name, trackable):  # pylint: disable=redefined-outer-name
+    return self._trackable._handle_deferred_dependencies(name, trackable)
+
+  def _restore_from_checkpoint_position(self, checkpoint_position):
+    return self._trackable._restore_from_checkpoint_position(
+        checkpoint_position)
+
+  def _single_restoration_from_checkpoint_position(self, checkpoint_position,
+                                                   visit_queue):
+    return self._trackable._single_restoration_from_checkpoint_position(
+        checkpoint_position, visit_queue)
+
+  def _gather_saveables_for_checkpoint(self):
+    return self._trackable._gather_saveables_for_checkpoint()
+
+  def _list_extra_dependencies_for_serialization(self, serialization_cache):
+    return self._trackable._list_extra_dependencies_for_serialization(
+        serialization_cache)
+
+  def _list_functions_for_serialization(self, serialization_cache):
+    return self._trackable._list_functions_for_serialization(
+        serialization_cache)
+  # pylint: enable=protected-access
+
+
 @keras_export('keras.mixed_precision.experimental.LossScaleOptimizer')
-class LossScaleOptimizer(optimizer_v2.OptimizerV2):
+class LossScaleOptimizer(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2):
   """An optimizer that applies loss scaling.
 
   Loss scaling is a process that multiplies the loss by a multiplier called the
@@ -144,6 +264,11 @@
     self._loss_scale = keras_loss_scale_module.get(loss_scale)
     if self._loss_scale is None:
       raise ValueError('loss_scale cannot be None.')
+
+    # We don't call super().__init__, since we do not want to call OptimizerV2's
+    # constructor.
+    _DelegatingTrackableMixin.__init__(self, self._optimizer)
+
     for weight in loss_scale_module.get_loss_scale_weights(self._loss_scale):
       # We cannot call `track_variable` in the LossScale class itself, because a
       # file outside of Keras cannot depend on a Keras file. Calling it here
@@ -151,12 +276,15 @@
       # a Keras class, and the only way to use LossScale with a Keras class is
       # through the LossScaleOptimizer.
       backend.track_variable(weight)
-    self._track_trackable(self._optimizer, 'base_optimizer')
     self._track_trackable(self._loss_scale, 'loss_scale')
 
     # Needed because the superclass's __getattribute__ checks this.
     self._hyper = {}
 
+    # To support restoring TensorFlow 2.2 checkpoints.
+    self._track_trackable(FakeOptimizerForRestoration(self._optimizer),
+                          'base_optimizer')
+
   @property
   def loss_scale(self):
     """The `LossScale` instance associated with this optimizer."""
@@ -348,6 +476,21 @@
   def _aggregate_gradients(self, grads_and_vars):
     return self._optimizer._aggregate_gradients(grads_and_vars)  # pylint: disable=protected-access
 
+  def _restore_slot_variable(self, slot_name, variable, slot_variable):
+    return self._optimizer._restore_slot_variable(slot_name, variable,  # pylint: disable=protected-access
+                                                  slot_variable)
+
+  def _create_or_restore_slot_variable(self, slot_variable_position, slot_name,
+                                       variable):
+    return self._optimizer._create_or_restore_slot_variable(  # pylint: disable=protected-access
+        slot_variable_position, slot_name, variable)
+
+  def get_slot(self, var, slot_name):
+    return self._optimizer.get_slot(var, slot_name)
+
+  def add_slot(self, var, slot_name, initializer='zeros'):
+    return self._optimizer.add_slot(var, slot_name, initializer)
+
   # For the most part, we only expose methods in the base OptimizerV2, not
   # individual subclasses like Adam. However, although "learning_rate" and "lr"
   # properties are not part of the base OptimizerV2 class, they are part of most
@@ -369,23 +512,6 @@
   def lr(self, lr):
     self._optimizer.lr = lr
 
-  def get_slot(self, var, slot_name):
-    # We cannot implement get_slot for the following reason: When saving a
-    # checkpoint, two optimizers cannot share slot variables. Since both the
-    # LossScaleOptimizer and the wrapped optimizer (self and self._optimizer
-    # respectively) are checkpointed, we cannot expose the wrapped optimizer's
-    # slots in the LossScaleOptimizer. Otherwise, a checkpoint would believe
-    # both optimizers share slot variables.
-    raise AttributeError(
-        'You cannot call get_slot on a LossScaleOptimizer. This limitation '
-        'will be removed in the future.')
-
-  def add_slot(self, var, slot_name, initializer='zeros'):
-    # We disallow adding a slot for consistency with `get_slot`.
-    raise AttributeError(
-        'You cannot call add_slot on a LossScaleOptimizer. This limitation '
-        'will be removed in the future.')
-
   # We do not override some OptimizerV2 methods. For each, we describe why we do
   # not delegate them to self._optimizer:
   # * get_updates: get_updates() calls get_gradients(). Since we override
@@ -402,6 +528,51 @@
   # TODO(reedwm): Maybe throw an error if mixed precision is used without this
   # optimizer being used.
 
+  # Trackable delegations: Delegate all Trackable methods to the wrapped
+  # optimizer. This is so the checkpoint format for a LossScaleOptimizer is
+  # identical to the checkpoint format for a normal optimizer, except the loss
+  # scale is stored in the checkpoint.
+
+
+class FakeOptimizerForRestoration(trackable.Trackable):
+  """A fake optimizer used to support restoring TensorFlow 2.2 checkpoints.
+
+  The checkpoint format for LossScaleOptimizers changed after TF 2.2. This class
+  exists to support restoring TF 2.2 checkpoints in newer version of TensorFlow.
+
+  In TF 2.2, LossScaleOptimizer would track the wrapped optimizer by calling the
+  following in LossScaleOptimizer.__init__
+
+  ```
+  self._track_trackable(self._optimizer, 'base_optimizer')
+  ```
+
+  This means a dependency from the LossScaleOptimizer to the wrapped optimizer
+  would be stored in the checkpoint. However now, the checkpoint format with a
+  LossScaleOptimizer is the same as the format without a LossScaleOptimizer,
+  except the loss scale is also stored. This means there is no dependency from
+  the LossScaleOptimizer to the wrapped optimizer. Instead, the
+  LossScaleOptimizer acts as if it is the wrapped optimizer, from a checkpoint's
+  perspective, by overriding all Trackable methods and delegating them to the
+  wrapped optimizer.
+
+  To allow restoring TF 2.2. checkpoints, LossScaleOptimizer adds a dependency
+  on this class instead of the inner optimizer. When restored, this class will
+  instead restore the slot variables of the inner optimizer. Since this class
+  has no variables, it does not affect the checkpoint when saved.
+  """
+
+  def __init__(self, optimizer):
+    self._optimizer = optimizer
+
+  def get_slot_names(self):
+    return self._optimizer.get_slot_names()
+
+  def _create_or_restore_slot_variable(self, slot_variable_position, slot_name,
+                                       variable):
+    return self._optimizer._create_or_restore_slot_variable(  # pylint: disable=protected-access
+        slot_variable_position, slot_name, variable)
+
 
 # pylint: disable=protected-access
 mixed_precision._register_wrapper_optimizer_cls(optimizer_v2.OptimizerV2,
diff --git a/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer_test.py b/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer_test.py
index cbabda3..20252ff 100644
--- a/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer_test.py
+++ b/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer_test.py
@@ -305,20 +305,6 @@
       opt.set_weights([np.array(2.)])
       self.assertEqual(self.evaluate(opt.variables()[0]), 2)
 
-  def testSlotMethodErrors(self):
-    opt = gradient_descent.SGD(1.0, momentum=1.0)
-    opt = loss_scale_optimizer.LossScaleOptimizer(opt, 'dynamic')
-    with self.assertRaisesRegexp(
-        AttributeError,
-        'You cannot call get_slot on a LossScaleOptimizer. This limitation '
-        'will be removed in the future.'):
-      opt.get_slot(None, None)
-    with self.assertRaisesRegexp(
-        AttributeError,
-        'You cannot call add_slot on a LossScaleOptimizer. This limitation '
-        'will be removed in the future.'):
-      opt.add_slot(None, None)
-
   def testPassingNoneToLossScale(self):
     opt = gradient_descent.SGD()
     with self.assertRaisesRegexp(ValueError, r'loss_scale cannot be None'):
@@ -394,9 +380,49 @@
       run_fn = lambda: opt.minimize(loss, [var])
       strategy.experimental_run(run_fn)
 
-  @parameterized.named_parameters(*TESTCASES)
-  def testCheckpoint(self, strategy_fn):
+  @parameterized.named_parameters({
+      'testcase_name': 'SaveAndRestoreBase',
+      'strategy_fn': default_strategy_fn,
+      'save_with_ls': True,
+      'restore_with_ls': True,
+  }, {
+      'testcase_name': 'SaveAndRestoreDistribute',
+      'strategy_fn': create_mirrored_strategy,
+      'save_with_ls': True,
+      'restore_with_ls': True,
+  }, {
+      'testcase_name': 'SaveBase',
+      'strategy_fn': default_strategy_fn,
+      'save_with_ls': True,
+      'restore_with_ls': False,
+  }, {
+      'testcase_name': 'SaveDistribute',
+      'strategy_fn': create_mirrored_strategy,
+      'save_with_ls': True,
+      'restore_with_ls': False,
+  }, {
+      'testcase_name': 'RestoreBase',
+      'strategy_fn': default_strategy_fn,
+      'save_with_ls': False,
+      'restore_with_ls': True,
+  }, {
+      'testcase_name': 'RestoreDistribute',
+      'strategy_fn': create_mirrored_strategy,
+      'save_with_ls': False,
+      'restore_with_ls': True,
+  })
+  def testCheckpoint(self, strategy_fn, save_with_ls, restore_with_ls):
+
+    class MySGD(gradient_descent.SGD):
+      """A custom optimizer that tracks an extra variable."""
+
+      def __init__(self, *args, **kwargs):
+        super(MySGD, self).__init__(*args, **kwargs)
+        self.my_var = variables.Variable(0.)
+        self._track_trackable(self.my_var, 'my_var')
+
     strategy = strategy_fn()
+    replicas = strategy.num_replicas_in_sync
     if (isinstance(strategy, mirrored_strategy.MirroredStrategy) and
         not context.executing_eagerly()):
       # TODO(b/121381184): Enable running the test in this case.
@@ -405,38 +431,89 @@
     with self.test_session(), strategy.scope():
       # Build and run a simple model.
       var = variables.Variable([2.0])
-      loss_scale = loss_scale_module.DynamicLossScale(
-          initial_loss_scale=1., increment_period=2.,
-          multiplier=2.)
-      opt = gradient_descent.SGD(1., momentum=1.)
-      opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale)
-      run_fn = lambda: opt.minimize(lambda: var + 1., var_list=[var])
+      opt = inner_opt = MySGD(1., momentum=1.)
+      if save_with_ls:
+        loss_scale = loss_scale_module.DynamicLossScale(
+            initial_loss_scale=1., increment_period=2.,
+            multiplier=2.)
+        opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale)
+      run_fn = lambda: opt.minimize(lambda: var / replicas + 1., var_list=[var])
       opt_op = strategy.experimental_run(run_fn)
       self.evaluate(variables.global_variables_initializer())
-      self.evaluate(opt_op)
-      self.assertEqual(self.evaluate(loss_scale()), 1.)
-      self.assertEqual(self.evaluate(loss_scale._num_good_steps), 1)
-      slot_var = opt._optimizer.get_slot(var, 'momentum')
-      slot_value = self.evaluate(slot_var).item()
+      self.evaluate(strategy.experimental_local_results(opt_op))
+
+      # Assert values.
+      self.assertEqual(self.evaluate(var), 1.)
+      if save_with_ls:
+        self.assertEqual(self.evaluate(loss_scale()), 1.)
+        self.assertEqual(self.evaluate(loss_scale._num_good_steps), 1)
+      slot_var = opt.get_slot(var, 'momentum')
+      self.assertEqual(self.evaluate(slot_var).item(), -1)
+      self.assertEqual(self.evaluate(opt.iterations), 1)
+
+      # Set optimizer variable to check arbitrary optimizer attributes can be
+      # saved/restored
+      self.evaluate(inner_opt.my_var.assign(1.))
 
       # Save a checkpoint.
       checkpoint = trackable_utils.Checkpoint(optimizer=opt, var=var)
       prefix = os.path.join(self.get_temp_dir(), 'ckpt')
       save_path = checkpoint.save(prefix)
 
-      # Run model again.
-      self.evaluate(strategy.experimental_run(run_fn))
-      self.assertEqual(self.evaluate(loss_scale()), 2.)
-      self.assertEqual(self.evaluate(loss_scale._num_good_steps), 0)
-      self.assertNotAlmostEqual(self.evaluate(slot_var).item(), slot_value)
+      # Create new model
+      var = variables.Variable([2.0])
+      opt = inner_opt = MySGD(1., momentum=1.)
+      if restore_with_ls:
+        loss_scale = loss_scale_module.DynamicLossScale(
+            initial_loss_scale=1., increment_period=2.,
+            multiplier=2.)
+        opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale)
 
-      # Load checkpoint and ensure loss scale is back to it's original value.
+      # Restore new model.
+      checkpoint = trackable_utils.Checkpoint(optimizer=opt, var=var)
       status = checkpoint.restore(save_path)
-      status.assert_consumed()
+      if save_with_ls:
+        status.assert_existing_objects_matched()
+      else:
+        status.assert_nontrivial_match()
+
+      # Assert restored values. We can only assert in eager mode since the
+      # variables are uninitialized in graph mode
+      if context.executing_eagerly():
+        self.assertEqual(self.evaluate(var), 1.)
+        if save_with_ls and restore_with_ls:
+          self.assertEqual(self.evaluate(loss_scale()), 1.)
+          self.assertEqual(self.evaluate(loss_scale._num_good_steps), 1)
+        elif restore_with_ls:
+          self.assertEqual(self.evaluate(loss_scale()), 1.)
+          self.assertEqual(self.evaluate(loss_scale._num_good_steps), 0)
+        self.assertEqual(self.evaluate(opt.iterations), 1)
+
+      # Run the model again.
+      run_fn = lambda: opt.minimize(lambda: var / replicas + 1., var_list=[var])
+      opt_op = strategy.experimental_run(run_fn)
+
+      # Assert new values.
+      self.evaluate(variables.global_variables_initializer())
       status.run_restore_ops()
-      self.assertEqual(self.evaluate(loss_scale()), 1.)
-      self.assertEqual(self.evaluate(loss_scale._num_good_steps), 1)
-      self.assertAlmostEqual(self.evaluate(slot_var).item(), slot_value)
+      self.evaluate(strategy.experimental_local_results(opt_op))
+      self.assertEqual(self.evaluate(var), -1)
+      slot_var = opt.get_slot(var, 'momentum')
+      self.assertEqual(self.evaluate(slot_var).item(), -2)
+      self.assertEqual(self.evaluate(opt.iterations), 2)
+      self.assertEqual(self.evaluate(inner_opt.my_var), 1)
+
+      # Restore model again to test restoring after slots are created
+      status = checkpoint.restore(save_path)
+      if save_with_ls and restore_with_ls:
+        status.assert_consumed()
+      elif save_with_ls:
+        status.assert_existing_objects_matched()
+      elif restore_with_ls:
+        status.assert_nontrivial_match()
+      status.run_restore_ops()
+      self.assertEqual(self.evaluate(var), 1)
+      self.assertEqual(self.evaluate(slot_var).item(), -1)
 
   def testGetConfig(self):
     opt = gradient_descent.SGD(2., momentum=0.5)
diff --git a/tensorflow/python/keras/mixed_precision/experimental/testdata/BUILD b/tensorflow/python/keras/mixed_precision/experimental/testdata/BUILD
new file mode 100644
index 0000000..39e8906
--- /dev/null
+++ b/tensorflow/python/keras/mixed_precision/experimental/testdata/BUILD
@@ -0,0 +1,48 @@
+# Description:
+#   Contains checkpoints and SavedModels for testing purposes.
+
+package(
+    default_visibility = [
+        "//tensorflow/python/keras:__subpackages__",
+        "//tensorflow/tools/pip_package:__pkg__",
+    ],
+    licenses = ["notice"],  # Apache 2.0
+)
+
+exports_files(["LICENSE"])
+
+# These files were generated by running the following program with TensorFlow
+# 2.2rc2. The final release of TF 2.2 was not out when this change was created.:
+
+# import os
+# import numpy as np
+# import tensorflow as tf
+#
+# tf.random.set_seed(1)
+# opt = tf.keras.optimizers.SGD(0.1, momentum=0.1)
+# opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(opt, 'dynamic')
+# model = tf.keras.Sequential([tf.keras.layers.Dense(2)])
+# model.compile(opt, 'mse')
+#
+# x = np.ones((10, 2))
+# y = x * 100
+# model.fit(x, y)
+# weight_dir = os.environ['TF_LSO_WEIGHT_DIR']
+# model_dir = os.environ['TF_LSO_MODEL_DIR']
+# model.save_weights(weight_dir)
+# model.save(model_dir)
+# print(model.get_weights()[0])
+# print(opt._optimizer.get_slot(model.weights[0], 'momentum'))
+# print(opt.loss_scale)
+
+filegroup(
+    name = "lso_ckpt_tf2.2",
+    srcs = glob(["lso_ckpt_tf2.2/**"]),
+    tags = ["no_pip"],
+)
+
+filegroup(
+    name = "lso_savedmodel_tf2.2",
+    srcs = glob(["lso_savedmodel_tf2.2/**"]),
+    tags = ["no_pip"],
+)
diff --git a/tensorflow/python/keras/mixed_precision/experimental/testdata/lso_ckpt_tf2.2/checkpoint b/tensorflow/python/keras/mixed_precision/experimental/testdata/lso_ckpt_tf2.2/checkpoint
new file mode 100644
index 0000000..30b5254
--- /dev/null
+++ b/tensorflow/python/keras/mixed_precision/experimental/testdata/lso_ckpt_tf2.2/checkpoint
@@ -0,0 +1,2 @@
+model_checkpoint_path: "ckpt"
+all_model_checkpoint_paths: "ckpt"
diff --git a/tensorflow/python/keras/mixed_precision/experimental/testdata/lso_ckpt_tf2.2/ckpt.data-00000-of-00002 b/tensorflow/python/keras/mixed_precision/experimental/testdata/lso_ckpt_tf2.2/ckpt.data-00000-of-00002
new file mode 100644
index 0000000..119d528
--- /dev/null
+++ b/tensorflow/python/keras/mixed_precision/experimental/testdata/lso_ckpt_tf2.2/ckpt.data-00000-of-00002
Binary files differ
diff --git a/tensorflow/python/keras/mixed_precision/experimental/testdata/lso_ckpt_tf2.2/ckpt.data-00001-of-00002 b/tensorflow/python/keras/mixed_precision/experimental/testdata/lso_ckpt_tf2.2/ckpt.data-00001-of-00002
new file mode 100644
index 0000000..b3f9682
--- /dev/null
+++ b/tensorflow/python/keras/mixed_precision/experimental/testdata/lso_ckpt_tf2.2/ckpt.data-00001-of-00002
Binary files differ
diff --git a/tensorflow/python/keras/mixed_precision/experimental/testdata/lso_ckpt_tf2.2/ckpt.index b/tensorflow/python/keras/mixed_precision/experimental/testdata/lso_ckpt_tf2.2/ckpt.index
new file mode 100644
index 0000000..123174c
--- /dev/null
+++ b/tensorflow/python/keras/mixed_precision/experimental/testdata/lso_ckpt_tf2.2/ckpt.index
Binary files differ
diff --git a/tensorflow/python/keras/mixed_precision/experimental/testdata/lso_savedmodel_tf2.2/saved_model.pb b/tensorflow/python/keras/mixed_precision/experimental/testdata/lso_savedmodel_tf2.2/saved_model.pb
new file mode 100644
index 0000000..07701c3
--- /dev/null
+++ b/tensorflow/python/keras/mixed_precision/experimental/testdata/lso_savedmodel_tf2.2/saved_model.pb
Binary files differ
diff --git a/tensorflow/python/keras/mixed_precision/experimental/testdata/lso_savedmodel_tf2.2/variables/variables.data-00000-of-00002 b/tensorflow/python/keras/mixed_precision/experimental/testdata/lso_savedmodel_tf2.2/variables/variables.data-00000-of-00002
new file mode 100644
index 0000000..7053750
--- /dev/null
+++ b/tensorflow/python/keras/mixed_precision/experimental/testdata/lso_savedmodel_tf2.2/variables/variables.data-00000-of-00002
Binary files differ
diff --git a/tensorflow/python/keras/mixed_precision/experimental/testdata/lso_savedmodel_tf2.2/variables/variables.data-00001-of-00002 b/tensorflow/python/keras/mixed_precision/experimental/testdata/lso_savedmodel_tf2.2/variables/variables.data-00001-of-00002
new file mode 100644
index 0000000..0136799
--- /dev/null
+++ b/tensorflow/python/keras/mixed_precision/experimental/testdata/lso_savedmodel_tf2.2/variables/variables.data-00001-of-00002
Binary files differ
diff --git a/tensorflow/python/keras/mixed_precision/experimental/testdata/lso_savedmodel_tf2.2/variables/variables.index b/tensorflow/python/keras/mixed_precision/experimental/testdata/lso_savedmodel_tf2.2/variables/variables.index
new file mode 100644
index 0000000..3f1b412
--- /dev/null
+++ b/tensorflow/python/keras/mixed_precision/experimental/testdata/lso_savedmodel_tf2.2/variables/variables.index
Binary files differ
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.mixed_precision.experimental.-loss-scale-optimizer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.mixed_precision.experimental.-loss-scale-optimizer.pbtxt
index b486a58..4f586cf 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.mixed_precision.experimental.-loss-scale-optimizer.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.mixed_precision.experimental.-loss-scale-optimizer.pbtxt
@@ -1,6 +1,7 @@
 path: "tensorflow.keras.mixed_precision.experimental.LossScaleOptimizer"
 tf_class {
   is_instance: "<class \'tensorflow.python.keras.mixed_precision.experimental.loss_scale_optimizer.LossScaleOptimizer\'>"
+  is_instance: "<class \'tensorflow.python.keras.mixed_precision.experimental.loss_scale_optimizer._DelegatingTrackableMixin\'>"
   is_instance: "<class \'tensorflow.python.keras.optimizer_v2.optimizer_v2.OptimizerV2\'>"
   is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
   is_instance: "<type \'object\'>"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.mixed_precision.experimental.-loss-scale-optimizer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.mixed_precision.experimental.-loss-scale-optimizer.pbtxt
index b486a58..4f586cf 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.mixed_precision.experimental.-loss-scale-optimizer.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.mixed_precision.experimental.-loss-scale-optimizer.pbtxt
@@ -1,6 +1,7 @@
 path: "tensorflow.keras.mixed_precision.experimental.LossScaleOptimizer"
 tf_class {
   is_instance: "<class \'tensorflow.python.keras.mixed_precision.experimental.loss_scale_optimizer.LossScaleOptimizer\'>"
+  is_instance: "<class \'tensorflow.python.keras.mixed_precision.experimental.loss_scale_optimizer._DelegatingTrackableMixin\'>"
   is_instance: "<class \'tensorflow.python.keras.optimizer_v2.optimizer_v2.OptimizerV2\'>"
   is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
   is_instance: "<type \'object\'>"