Change LossScaleOptimizer checkpoint format.
The format is now identical to the format used when a LossScaleOptimizer is not used, except that the loss scale is also saved when a LossScaleOptimizer is used. This allows a checkpoint saved with a LossScaleOptimizer to be restored without one, and vice versa.
Checkpoints created with a LossScaleOptimizer in older versions of TensorFlow can still be loaded. Newly saved checkpoints will use the new format.
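
For example, the following minimal sketch (hypothetical paths; TF 2.x eager execution and the experimental Keras mixed-precision API assumed) saves a checkpoint with a LossScaleOptimizer and restores it with a plain optimizer:

```
import tensorflow as tf

var = tf.Variable([2.0])
opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
    tf.keras.optimizers.SGD(0.1, momentum=0.1), 'dynamic')

# Save a checkpoint while the optimizer is wrapped in a LossScaleOptimizer.
ckpt = tf.train.Checkpoint(optimizer=opt, var=var)
save_path = ckpt.save('/tmp/lso_example/ckpt')  # hypothetical path

# Restore it with a plain optimizer. The checkpoint structure matches; the
# extra loss scale entry is simply left unrestored.
plain_opt = tf.keras.optimizers.SGD(0.1, momentum=0.1)
plain_ckpt = tf.train.Checkpoint(optimizer=plain_opt, var=var)
plain_ckpt.restore(save_path).assert_existing_objects_matched()
```

The reverse direction also works: restoring a plain-optimizer checkpoint with a LossScaleOptimizer leaves the loss scale at its freshly initialized value.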
PiperOrigin-RevId: 306511555
Change-Id: Ie316ab8c4fbfec7babd6f7803d337799d0ff10a5
diff --git a/tensorflow/python/keras/mixed_precision/experimental/BUILD b/tensorflow/python/keras/mixed_precision/experimental/BUILD
index 7609745..18390b4 100644
--- a/tensorflow/python/keras/mixed_precision/experimental/BUILD
+++ b/tensorflow/python/keras/mixed_precision/experimental/BUILD
@@ -224,9 +224,16 @@
name = "keras_test",
size = "medium",
srcs = ["keras_test.py"],
+ data = [
+ "//tensorflow/python/keras/mixed_precision/experimental/testdata:lso_ckpt_tf2.2",
+ "//tensorflow/python/keras/mixed_precision/experimental/testdata:lso_savedmodel_tf2.2",
+ ],
python_version = "PY3",
shard_count = 10,
- tags = ["no_windows"], # b/139083295: bfloat16 tests fail on Windows
+ tags = [
+ "no_pip",
+ "no_windows", # b/139083295: bfloat16 tests fail on Windows
+ ],
deps = [
":test_util",
"//tensorflow/python:client_testlib",
diff --git a/tensorflow/python/keras/mixed_precision/experimental/keras_test.py b/tensorflow/python/keras/mixed_precision/experimental/keras_test.py
index a27be08..d2e80cf 100644
--- a/tensorflow/python/keras/mixed_precision/experimental/keras_test.py
+++ b/tensorflow/python/keras/mixed_precision/experimental/keras_test.py
@@ -41,6 +41,7 @@
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.engine import input_spec
+from tensorflow.python.keras.engine import sequential
from tensorflow.python.keras.layers import core
from tensorflow.python.keras.mixed_precision.experimental import get_layer_policy
from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer
@@ -994,6 +995,56 @@
self.assertEqual(backend.get_value(loss_scale._num_good_steps), 1)
@keras_parameterized.run_all_keras_modes
+ def test_restore_old_loss_scale_checkpoint(self):
+ # Ensure a checkpoint from TF 2.2 can be loaded. The checkpoint format
+ # of LossScaleOptimizer changed, but old checkpoints can still be loaded.
+ opt = gradient_descent.SGD(0.1, momentum=0.1)
+ opt = loss_scale_optimizer.LossScaleOptimizer(opt, 'dynamic')
+ model = sequential.Sequential([core.Dense(2,)])
+
+ # The checkpoint and expected values were obtained from the program in
+ # testdata/BUILD.
+ ckpt_dir = test.test_src_dir_path(
+ 'python/keras/mixed_precision/experimental/testdata/lso_ckpt_tf2.2')
+ model.load_weights(os.path.join(ckpt_dir, 'ckpt'))
+ model.compile(opt, 'mse', run_eagerly=testing_utils.should_run_eagerly())
+ model(np.zeros((2, 2))) # Create model weights
+ opt._create_all_weights(model.weights)
+ expected_kernel = np.array([[9.229685, 10.901115], [10.370763, 9.757362]])
+ expected_slot = np.array([[10.049943, 9.917691], [10.049943, 9.917691]])
+ self.assertAllClose(self.evaluate(model.weights[0]), expected_kernel)
+ self.assertAllClose(
+ self.evaluate(opt.get_slot(model.weights[0], 'momentum')),
+ expected_slot)
+ self.assertEqual(self.evaluate(opt.loss_scale()), 32768)
+ self.assertEqual(self.evaluate(opt.loss_scale._num_good_steps), 1)
+
+ # Check restoring works even after the model is compiled and the weights
+ # have been created.
+ model.fit(np.random.normal(size=(2, 2)), np.random.normal(size=(2, 2)))
+ self.assertNotAllClose(self.evaluate(model.weights[0]), expected_kernel)
+ self.assertNotAllClose(
+ self.evaluate(opt.get_slot(model.weights[0], 'momentum')),
+ expected_slot)
+ model.load_weights(os.path.join(ckpt_dir, 'ckpt'))
+ self.assertAllClose(self.evaluate(model.weights[0]), expected_kernel)
+ self.assertAllClose(
+ self.evaluate(opt.get_slot(model.weights[0], 'momentum')),
+ expected_slot)
+ self.assertEqual(self.evaluate(opt.loss_scale()), 32768)
+ self.assertEqual(self.evaluate(opt.loss_scale._num_good_steps), 1)
+
+ def test_restore_old_saved_model(self):
+ saved_model_dir = test.test_src_dir_path(
+ 'python/keras/mixed_precision/experimental/testdata/'
+ 'lso_savedmodel_tf2.2')
+ model = save.load_model(saved_model_dir)
+ expected_kernel = np.array([[9.229685, 10.901115], [10.370763, 9.757362]])
+ self.assertAllClose(backend.eval(model.weights[0]), expected_kernel)
+ self.assertIsInstance(model.optimizer,
+ loss_scale_optimizer.LossScaleOptimizer)
+
+ @keras_parameterized.run_all_keras_modes
@parameterized.named_parameters(
{
'testcase_name': 'base',
diff --git a/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer.py b/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer.py
index 1c14955..d6a786a 100644
--- a/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer.py
+++ b/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer.py
@@ -22,6 +22,7 @@
from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.distribute import one_device_strategy
from tensorflow.python.distribute import tpu_strategy
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import smart_cond
from tensorflow.python.keras import backend
@@ -32,6 +33,7 @@
from tensorflow.python.ops import math_ops
from tensorflow.python.training.experimental import loss_scale as loss_scale_module
from tensorflow.python.training.experimental import mixed_precision
+from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util.tf_export import keras_export
@@ -51,8 +53,126 @@
self.value = value
+class _DelegatingTrackableMixin(object):
+ """A mixin that delegates all Trackable methods to another trackable object.
+
+ This class must be used with multiple inheritance. A class that subclasses
+ Trackable can also subclass this class, which causes all Trackable methods to
+ be delegated to the trackable object passed in the constructor.
+
+ A subclass can use this mixin to appear as if it were the trackable passed to
+ the constructor, from a Checkpoint's perspective. LossScaleOptimizer uses this
+ mixin, so that the checkpoint format for a LossScaleOptimizer is identical to
+ the checkpoint format for a normal optimizer. This allows a model to be saved
+ with a normal Optimizer and restored with a LossScaleOptimizer, or vice versa.
+ The only difference in checkpoint format is that the loss scale is also saved
+ with a LossScaleOptimizer.
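+
+ For example (a minimal sketch; `MyWrapper` is hypothetical):
+
+ ```
+ class MyWrapper(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2):
+
+   def __init__(self, inner_optimizer):
+     # Skip OptimizerV2.__init__ on purpose: all Trackable bookkeeping is
+     # delegated to the inner optimizer, so from a checkpoint's perspective
+     # MyWrapper behaves as the inner optimizer.
+     _DelegatingTrackableMixin.__init__(self, inner_optimizer)
+ ```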
+ """
+
+ def __init__(self, trackable_obj):
+ self._trackable = trackable_obj
+
+ # pylint: disable=protected-access
+ @property
+ def _setattr_tracking(self):
+ return self._trackable._setattr_tracking
+
+ @_setattr_tracking.setter
+ def _setattr_tracking(self, value):
+ self._trackable._setattr_tracking = value
+
+ @property
+ def _update_uid(self):
+ return self._trackable._update_uid
+
+ @_update_uid.setter
+ def _update_uid(self, value):
+ self._trackable._update_uid = value
+
+ @property
+ def _unconditional_checkpoint_dependencies(self):
+ return self._trackable._unconditional_checkpoint_dependencies
+
+ @property
+ def _unconditional_dependency_names(self):
+ return self._trackable._unconditional_dependency_names
+
+ @property
+ def _name_based_restores(self):
+ return self._trackable._name_based_restores
+
+ def _maybe_initialize_trackable(self):
+ return self._trackable._maybe_initialize_trackable()
+
+ @property
+ def _object_identifier(self):
+ return self._trackable._object_identifier
+
+ @property
+ def _tracking_metadata(self):
+ return self._trackable._tracking_metadata
+
+ def _no_dependency(self, value):
+ return self._trackable._no_dependency(value)
+
+ def _name_based_attribute_restore(self, checkpoint):
+ return self._trackable._name_based_attribute_restore(checkpoint)
+
+ @property
+ def _checkpoint_dependencies(self):
+ return self._trackable._checkpoint_dependencies
+
+ @property
+ def _deferred_dependencies(self):
+ return self._trackable._deferred_dependencies
+
+ def _lookup_dependency(self, name):
+ return self._trackable._lookup_dependency(name)
+
+ def _add_variable_with_custom_getter(self,
+ name,
+ shape=None,
+ dtype=dtypes.float32,
+ initializer=None,
+ getter=None,
+ overwrite=False,
+ **kwargs_for_getter):
+ return self._trackable._add_variable_with_custom_getter(
+ name, shape, dtype, initializer, getter, overwrite, **kwargs_for_getter)
+
+ def _preload_simple_restoration(self, name, shape):
+ return self._trackable._preload_simple_restoration(name, shape)
+
+ def _track_trackable(self, trackable, name, overwrite=False): # pylint: disable=redefined-outer-name
+ return self._trackable._track_trackable(trackable, name, overwrite)
+
+ def _handle_deferred_dependencies(self, name, trackable): # pylint: disable=redefined-outer-name
+ return self._trackable._handle_deferred_dependencies(name, trackable)
+
+ def _restore_from_checkpoint_position(self, checkpoint_position):
+ return self._trackable._restore_from_checkpoint_position(
+ checkpoint_position)
+
+ def _single_restoration_from_checkpoint_position(self, checkpoint_position,
+ visit_queue):
+ return self._trackable._single_restoration_from_checkpoint_position(
+ checkpoint_position, visit_queue)
+
+ def _gather_saveables_for_checkpoint(self):
+ return self._trackable._gather_saveables_for_checkpoint()
+
+ def _list_extra_dependencies_for_serialization(self, serialization_cache):
+ return self._trackable._list_extra_dependencies_for_serialization(
+ serialization_cache)
+
+ def _list_functions_for_serialization(self, serialization_cache):
+ return self._trackable._list_functions_for_serialization(
+ serialization_cache)
+ # pylint: enable=protected-access
+
+
@keras_export('keras.mixed_precision.experimental.LossScaleOptimizer')
-class LossScaleOptimizer(optimizer_v2.OptimizerV2):
+class LossScaleOptimizer(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2):
"""An optimizer that applies loss scaling.
Loss scaling is a process that multiplies the loss by a multiplier called the
@@ -144,6 +264,11 @@
self._loss_scale = keras_loss_scale_module.get(loss_scale)
if self._loss_scale is None:
raise ValueError('loss_scale cannot be None.')
+
+ # We don't call super().__init__, since we do not want to call OptimizerV2's
+ # constructor.
+ _DelegatingTrackableMixin.__init__(self, self._optimizer)
+
for weight in loss_scale_module.get_loss_scale_weights(self._loss_scale):
# We cannot call `track_variable` in the LossScale class itself, because a
# file outside of Keras cannot depend on a Keras file. Calling it here
@@ -151,12 +276,15 @@
# a Keras class, and the only way to use LossScale with a Keras class is
# through the LossScaleOptimizer.
backend.track_variable(weight)
- self._track_trackable(self._optimizer, 'base_optimizer')
self._track_trackable(self._loss_scale, 'loss_scale')
# Needed because the superclass's __getattribute__ checks this.
self._hyper = {}
+ # To support restoring TensorFlow 2.2 checkpoints.
+ self._track_trackable(FakeOptimizerForRestoration(self._optimizer),
+ 'base_optimizer')
+
@property
def loss_scale(self):
"""The `LossScale` instance associated with this optimizer."""
@@ -348,6 +476,21 @@
def _aggregate_gradients(self, grads_and_vars):
return self._optimizer._aggregate_gradients(grads_and_vars) # pylint: disable=protected-access
+ def _restore_slot_variable(self, slot_name, variable, slot_variable):
+ return self._optimizer._restore_slot_variable(slot_name, variable, # pylint: disable=protected-access
+ slot_variable)
+
+ def _create_or_restore_slot_variable(self, slot_variable_position, slot_name,
+ variable):
+ return self._optimizer._create_or_restore_slot_variable( # pylint: disable=protected-access
+ slot_variable_position, slot_name, variable)
+
+ def get_slot(self, var, slot_name):
+ return self._optimizer.get_slot(var, slot_name)
+
+ def add_slot(self, var, slot_name, initializer='zeros'):
+ return self._optimizer.add_slot(var, slot_name, initializer)
+
# For the most part, we only expose methods in the base OptimizerV2, not
# individual subclasses like Adam. However, although "learning_rate" and "lr"
# properties are not part of the base OptimizerV2 class, they are part of most
@@ -369,23 +512,6 @@
def lr(self, lr):
self._optimizer.lr = lr
- def get_slot(self, var, slot_name):
- # We cannot implement get_slot for the following reason: When saving a
- # checkpoint, two optimizers cannot share slot variables. Since both the
- # LossScaleOptimizer and the wrapped optimizer (self and self._optimizer
- # respectively) are checkpointed, we cannot expose the wrapped optimizer's
- # slots in the LossScaleOptimizer. Otherwise, a checkpoint would believe
- # both optimizers share slot variables.
- raise AttributeError(
- 'You cannot call get_slot on a LossScaleOptimizer. This limitation '
- 'will be removed in the future.')
-
- def add_slot(self, var, slot_name, initializer='zeros'):
- # We disallow adding a slot for consistency with `get_slot`.
- raise AttributeError(
- 'You cannot call add_slot on a LossScaleOptimizer. This limitation '
- 'will be removed in the future.')
-
# We do not override some OptimizerV2 methods. For each, we describe why we do
# not delegate them to self._optimizer:
# * get_updates: get_updates() calls get_gradients(). Since we override
@@ -402,6 +528,51 @@
# TODO(reedwm): Maybe throw an error if mixed precision is used without this
# optimizer being used.
+ # Trackable delegations: Delegate all Trackable methods to the wrapped
+ # optimizer. This is so the checkpoint format for a LossScaleOptimizer is
+ # identical to the checkpoint format for a normal optimizer, except the loss
+ # scale is stored in the checkpoint.
+
+
+class FakeOptimizerForRestoration(trackable.Trackable):
+ """A fake optimizer used to support restoring TensorFlow 2.2 checkpoints.
+
+ The checkpoint format for LossScaleOptimizers changed after TF 2.2. This class
+ exists to support restoring TF 2.2 checkpoints in newer versions of TensorFlow.
+
+ In TF 2.2, LossScaleOptimizer would track the wrapped optimizer by calling the
+ following in LossScaleOptimizer.__init__:
+
+ ```
+ self._track_trackable(self._optimizer, 'base_optimizer')
+ ```
+
+ This means a dependency from the LossScaleOptimizer to the wrapped optimizer
+ would be stored in the checkpoint. Now, however, the checkpoint format with a
+ LossScaleOptimizer is the same as the format without a LossScaleOptimizer,
+ except the loss scale is also stored. This means there is no dependency from
+ the LossScaleOptimizer to the wrapped optimizer. Instead, the
+ LossScaleOptimizer acts as if it is the wrapped optimizer, from a checkpoint's
+ perspective, by overriding all Trackable methods and delegating them to the
+ wrapped optimizer.
+
+ To allow restoring TF 2.2 checkpoints, LossScaleOptimizer adds a dependency
+ on this class instead of the inner optimizer. When restored, this class
+ restores the slot variables of the inner optimizer instead. Since this class
+ has no variables, it does not affect the checkpoint when saved.
+ """
+
+ def __init__(self, optimizer):
+ self._optimizer = optimizer
+
+ def get_slot_names(self):
+ return self._optimizer.get_slot_names()
+
+ def _create_or_restore_slot_variable(self, slot_variable_position, slot_name,
+ variable):
+ return self._optimizer._create_or_restore_slot_variable( # pylint: disable=protected-access
+ slot_variable_position, slot_name, variable)
+
# pylint: disable=protected-access
mixed_precision._register_wrapper_optimizer_cls(optimizer_v2.OptimizerV2,
diff --git a/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer_test.py b/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer_test.py
index cbabda3..20252ff 100644
--- a/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer_test.py
+++ b/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer_test.py
@@ -305,20 +305,6 @@
opt.set_weights([np.array(2.)])
self.assertEqual(self.evaluate(opt.variables()[0]), 2)
- def testSlotMethodErrors(self):
- opt = gradient_descent.SGD(1.0, momentum=1.0)
- opt = loss_scale_optimizer.LossScaleOptimizer(opt, 'dynamic')
- with self.assertRaisesRegexp(
- AttributeError,
- 'You cannot call get_slot on a LossScaleOptimizer. This limitation '
- 'will be removed in the future.'):
- opt.get_slot(None, None)
- with self.assertRaisesRegexp(
- AttributeError,
- 'You cannot call add_slot on a LossScaleOptimizer. This limitation '
- 'will be removed in the future.'):
- opt.add_slot(None, None)
-
def testPassingNoneToLossScale(self):
opt = gradient_descent.SGD()
with self.assertRaisesRegexp(ValueError, r'loss_scale cannot be None'):
@@ -394,9 +380,49 @@
run_fn = lambda: opt.minimize(loss, [var])
strategy.experimental_run(run_fn)
- @parameterized.named_parameters(*TESTCASES)
- def testCheckpoint(self, strategy_fn):
+ @parameterized.named_parameters({
+ 'testcase_name': 'SaveAndRestoreBase',
+ 'strategy_fn': default_strategy_fn,
+ 'save_with_ls': True,
+ 'restore_with_ls': True,
+ }, {
+ 'testcase_name': 'SaveAndRestoreDistribute',
+ 'strategy_fn': create_mirrored_strategy,
+ 'save_with_ls': True,
+ 'restore_with_ls': True,
+ }, {
+ 'testcase_name': 'SaveBase',
+ 'strategy_fn': default_strategy_fn,
+ 'save_with_ls': True,
+ 'restore_with_ls': False,
+ }, {
+ 'testcase_name': 'SaveDistribute',
+ 'strategy_fn': create_mirrored_strategy,
+ 'save_with_ls': True,
+ 'restore_with_ls': False,
+ }, {
+ 'testcase_name': 'RestoreBase',
+ 'strategy_fn': default_strategy_fn,
+ 'save_with_ls': False,
+ 'restore_with_ls': True,
+ }, {
+ 'testcase_name': 'RestoreDistribute',
+ 'strategy_fn': create_mirrored_strategy,
+ 'save_with_ls': False,
+ 'restore_with_ls': True,
+ })
+ def testCheckpoint(self, strategy_fn, save_with_ls, restore_with_ls):
+
+ class MySGD(gradient_descent.SGD):
+ """A custom optimizer that tracks an extra variable."""
+
+ def __init__(self, *args, **kwargs):
+ super(MySGD, self).__init__(*args, **kwargs)
+ self.my_var = variables.Variable(0.)
+ self._track_trackable(self.my_var, 'my_var')
+
strategy = strategy_fn()
+ replicas = strategy.num_replicas_in_sync
if (isinstance(strategy, mirrored_strategy.MirroredStrategy) and
not context.executing_eagerly()):
# TODO(b/121381184): Enable running the test in this case.
@@ -405,38 +431,89 @@
with self.test_session(), strategy.scope():
# Build and run a simple model.
var = variables.Variable([2.0])
- loss_scale = loss_scale_module.DynamicLossScale(
- initial_loss_scale=1., increment_period=2.,
- multiplier=2.)
- opt = gradient_descent.SGD(1., momentum=1.)
- opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale)
- run_fn = lambda: opt.minimize(lambda: var + 1., var_list=[var])
+ opt = inner_opt = MySGD(1., momentum=1.)
+ if save_with_ls:
+ loss_scale = loss_scale_module.DynamicLossScale(
+ initial_loss_scale=1., increment_period=2.,
+ multiplier=2.)
+ opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale)
+ run_fn = lambda: opt.minimize(lambda: var / replicas + 1., var_list=[var])
opt_op = strategy.experimental_run(run_fn)
self.evaluate(variables.global_variables_initializer())
- self.evaluate(opt_op)
- self.assertEqual(self.evaluate(loss_scale()), 1.)
- self.assertEqual(self.evaluate(loss_scale._num_good_steps), 1)
- slot_var = opt._optimizer.get_slot(var, 'momentum')
- slot_value = self.evaluate(slot_var).item()
+ self.evaluate(strategy.experimental_local_results(opt_op))
+
+ # Assert values.
+ self.assertEqual(self.evaluate(var), 1.)
+ if save_with_ls:
+ self.assertEqual(self.evaluate(loss_scale()), 1.)
+ self.assertEqual(self.evaluate(loss_scale._num_good_steps), 1)
+ slot_var = opt.get_slot(var, 'momentum')
+ self.assertEqual(self.evaluate(slot_var).item(), -1)
+ self.assertEqual(self.evaluate(opt.iterations), 1)
+
+ # Set an optimizer variable to check that arbitrary optimizer attributes
+ # can be saved/restored.
+ self.evaluate(inner_opt.my_var.assign(1.))
# Save a checkpoint.
checkpoint = trackable_utils.Checkpoint(optimizer=opt, var=var)
prefix = os.path.join(self.get_temp_dir(), 'ckpt')
save_path = checkpoint.save(prefix)
- # Run model again.
- self.evaluate(strategy.experimental_run(run_fn))
- self.assertEqual(self.evaluate(loss_scale()), 2.)
- self.assertEqual(self.evaluate(loss_scale._num_good_steps), 0)
- self.assertNotAlmostEqual(self.evaluate(slot_var).item(), slot_value)
+ # Create new model
+ var = variables.Variable([2.0])
+ opt = inner_opt = MySGD(1., momentum=1.)
+ if restore_with_ls:
+ loss_scale = loss_scale_module.DynamicLossScale(
+ initial_loss_scale=1., increment_period=2.,
+ multiplier=2.)
+ opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale)
- # Load checkpoint and ensure loss scale is back to it's original value.
+ # Restore new model.
+ checkpoint = trackable_utils.Checkpoint(optimizer=opt, var=var)
status = checkpoint.restore(save_path)
- status.assert_consumed()
+ if save_with_ls:
+ status.assert_existing_objects_matched()
+ else:
+ status.assert_nontrivial_match()
+
+ # Assert restored values. We can only assert in eager mode since the
+ # variables are uninitialized in graph mode.
+ if context.executing_eagerly():
+ self.assertEqual(self.evaluate(var), 1.)
+ if save_with_ls and restore_with_ls:
+ self.assertEqual(self.evaluate(loss_scale()), 1.)
+ self.assertEqual(self.evaluate(loss_scale._num_good_steps), 1)
+ elif restore_with_ls:
+ self.assertEqual(self.evaluate(loss_scale()), 1.)
+ self.assertEqual(self.evaluate(loss_scale._num_good_steps), 0)
+ self.assertEqual(self.evaluate(opt.iterations), 1)
+
+ # Run the model again.
+ run_fn = lambda: opt.minimize(lambda: var / replicas + 1., var_list=[var])
+ opt_op = strategy.experimental_run(run_fn)
+
+ # Assert new values.
+ self.evaluate(variables.global_variables_initializer())
status.run_restore_ops()
- self.assertEqual(self.evaluate(loss_scale()), 1.)
- self.assertEqual(self.evaluate(loss_scale._num_good_steps), 1)
- self.assertAlmostEqual(self.evaluate(slot_var).item(), slot_value)
+ self.evaluate(strategy.experimental_local_results(opt_op))
+ self.assertEqual(self.evaluate(var), -1)
+ slot_var = opt.get_slot(var, 'momentum')
+ self.assertEqual(self.evaluate(slot_var).item(), -2)
+ self.assertEqual(self.evaluate(opt.iterations), 2)
+ self.assertEqual(self.evaluate(inner_opt.my_var), 1)
+
+ # Restore the model again to test restoring after slots are created.
+ status = checkpoint.restore(save_path)
+ if save_with_ls and restore_with_ls:
+ status.assert_consumed()
+ elif save_with_ls:
+ status.assert_existing_objects_matched()
+ elif restore_with_ls:
+ status.assert_nontrivial_match()
+ status.run_restore_ops()
+ self.assertEqual(self.evaluate(var), 1)
+ self.assertEqual(self.evaluate(slot_var).item(), -1)
def testGetConfig(self):
opt = gradient_descent.SGD(2., momentum=0.5)
diff --git a/tensorflow/python/keras/mixed_precision/experimental/testdata/BUILD b/tensorflow/python/keras/mixed_precision/experimental/testdata/BUILD
new file mode 100644
index 0000000..39e8906
--- /dev/null
+++ b/tensorflow/python/keras/mixed_precision/experimental/testdata/BUILD
@@ -0,0 +1,48 @@
+# Description:
+# Contains checkpoints and SavedModels for testing purposes.
+
+package(
+ default_visibility = [
+ "//tensorflow/python/keras:__subpackages__",
+ "//tensorflow/tools/pip_package:__pkg__",
+ ],
+ licenses = ["notice"], # Apache 2.0
+)
+
+exports_files(["LICENSE"])
+
+# These files were generated by running the following program with TensorFlow
+# 2.2rc2 (the final release of TF 2.2 was not out when this change was created):
+
+# import os
+# import numpy as np
+# import tensorflow as tf
+#
+# tf.random.set_seed(1)
+# opt = tf.keras.optimizers.SGD(0.1, momentum=0.1)
+# opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(opt, 'dynamic')
+# model = tf.keras.Sequential([tf.keras.layers.Dense(2)])
+# model.compile(opt, 'mse')
+#
+# x = np.ones((10, 2))
+# y = x * 100
+# model.fit(x, y)
+# weight_dir = os.environ['TF_LSO_WEIGHT_DIR']
+# model_dir = os.environ['TF_LSO_MODEL_DIR']
+# model.save_weights(weight_dir)
+# model.save(model_dir)
+# print(model.get_weights()[0])
+# print(opt._optimizer.get_slot(model.weights[0], 'momentum'))
+# print(opt.loss_scale)
+
+filegroup(
+ name = "lso_ckpt_tf2.2",
+ srcs = glob(["lso_ckpt_tf2.2/**"]),
+ tags = ["no_pip"],
+)
+
+filegroup(
+ name = "lso_savedmodel_tf2.2",
+ srcs = glob(["lso_savedmodel_tf2.2/**"]),
+ tags = ["no_pip"],
+)
diff --git a/tensorflow/python/keras/mixed_precision/experimental/testdata/lso_ckpt_tf2.2/checkpoint b/tensorflow/python/keras/mixed_precision/experimental/testdata/lso_ckpt_tf2.2/checkpoint
new file mode 100644
index 0000000..30b5254
--- /dev/null
+++ b/tensorflow/python/keras/mixed_precision/experimental/testdata/lso_ckpt_tf2.2/checkpoint
@@ -0,0 +1,2 @@
+model_checkpoint_path: "ckpt"
+all_model_checkpoint_paths: "ckpt"
diff --git a/tensorflow/python/keras/mixed_precision/experimental/testdata/lso_ckpt_tf2.2/ckpt.data-00000-of-00002 b/tensorflow/python/keras/mixed_precision/experimental/testdata/lso_ckpt_tf2.2/ckpt.data-00000-of-00002
new file mode 100644
index 0000000..119d528
--- /dev/null
+++ b/tensorflow/python/keras/mixed_precision/experimental/testdata/lso_ckpt_tf2.2/ckpt.data-00000-of-00002
Binary files differ
diff --git a/tensorflow/python/keras/mixed_precision/experimental/testdata/lso_ckpt_tf2.2/ckpt.data-00001-of-00002 b/tensorflow/python/keras/mixed_precision/experimental/testdata/lso_ckpt_tf2.2/ckpt.data-00001-of-00002
new file mode 100644
index 0000000..b3f9682
--- /dev/null
+++ b/tensorflow/python/keras/mixed_precision/experimental/testdata/lso_ckpt_tf2.2/ckpt.data-00001-of-00002
Binary files differ
diff --git a/tensorflow/python/keras/mixed_precision/experimental/testdata/lso_ckpt_tf2.2/ckpt.index b/tensorflow/python/keras/mixed_precision/experimental/testdata/lso_ckpt_tf2.2/ckpt.index
new file mode 100644
index 0000000..123174c
--- /dev/null
+++ b/tensorflow/python/keras/mixed_precision/experimental/testdata/lso_ckpt_tf2.2/ckpt.index
Binary files differ
diff --git a/tensorflow/python/keras/mixed_precision/experimental/testdata/lso_savedmodel_tf2.2/saved_model.pb b/tensorflow/python/keras/mixed_precision/experimental/testdata/lso_savedmodel_tf2.2/saved_model.pb
new file mode 100644
index 0000000..07701c3
--- /dev/null
+++ b/tensorflow/python/keras/mixed_precision/experimental/testdata/lso_savedmodel_tf2.2/saved_model.pb
Binary files differ
diff --git a/tensorflow/python/keras/mixed_precision/experimental/testdata/lso_savedmodel_tf2.2/variables/variables.data-00000-of-00002 b/tensorflow/python/keras/mixed_precision/experimental/testdata/lso_savedmodel_tf2.2/variables/variables.data-00000-of-00002
new file mode 100644
index 0000000..7053750
--- /dev/null
+++ b/tensorflow/python/keras/mixed_precision/experimental/testdata/lso_savedmodel_tf2.2/variables/variables.data-00000-of-00002
Binary files differ
diff --git a/tensorflow/python/keras/mixed_precision/experimental/testdata/lso_savedmodel_tf2.2/variables/variables.data-00001-of-00002 b/tensorflow/python/keras/mixed_precision/experimental/testdata/lso_savedmodel_tf2.2/variables/variables.data-00001-of-00002
new file mode 100644
index 0000000..0136799
--- /dev/null
+++ b/tensorflow/python/keras/mixed_precision/experimental/testdata/lso_savedmodel_tf2.2/variables/variables.data-00001-of-00002
Binary files differ
diff --git a/tensorflow/python/keras/mixed_precision/experimental/testdata/lso_savedmodel_tf2.2/variables/variables.index b/tensorflow/python/keras/mixed_precision/experimental/testdata/lso_savedmodel_tf2.2/variables/variables.index
new file mode 100644
index 0000000..3f1b412
--- /dev/null
+++ b/tensorflow/python/keras/mixed_precision/experimental/testdata/lso_savedmodel_tf2.2/variables/variables.index
Binary files differ
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.mixed_precision.experimental.-loss-scale-optimizer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.mixed_precision.experimental.-loss-scale-optimizer.pbtxt
index b486a58..4f586cf 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.mixed_precision.experimental.-loss-scale-optimizer.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.mixed_precision.experimental.-loss-scale-optimizer.pbtxt
@@ -1,6 +1,7 @@
path: "tensorflow.keras.mixed_precision.experimental.LossScaleOptimizer"
tf_class {
is_instance: "<class \'tensorflow.python.keras.mixed_precision.experimental.loss_scale_optimizer.LossScaleOptimizer\'>"
+ is_instance: "<class \'tensorflow.python.keras.mixed_precision.experimental.loss_scale_optimizer._DelegatingTrackableMixin\'>"
is_instance: "<class \'tensorflow.python.keras.optimizer_v2.optimizer_v2.OptimizerV2\'>"
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<type \'object\'>"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.mixed_precision.experimental.-loss-scale-optimizer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.mixed_precision.experimental.-loss-scale-optimizer.pbtxt
index b486a58..4f586cf 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.mixed_precision.experimental.-loss-scale-optimizer.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.mixed_precision.experimental.-loss-scale-optimizer.pbtxt
@@ -1,6 +1,7 @@
path: "tensorflow.keras.mixed_precision.experimental.LossScaleOptimizer"
tf_class {
is_instance: "<class \'tensorflow.python.keras.mixed_precision.experimental.loss_scale_optimizer.LossScaleOptimizer\'>"
+ is_instance: "<class \'tensorflow.python.keras.mixed_precision.experimental.loss_scale_optimizer._DelegatingTrackableMixin\'>"
is_instance: "<class \'tensorflow.python.keras.optimizer_v2.optimizer_v2.OptimizerV2\'>"
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<type \'object\'>"