Add SaveableObject --> Trackable converter.
Another CL for the SaveableObject deprecation. This converter will allow Trackables that have not yet migrated off of `_gather_saveables_for_checkpoint` to be compatible with checkpointing once the code has moved to using `Trackable._serialize_to_tensors` and `Trackable._serialize_from_tensors`
PiperOrigin-RevId: 461776781
diff --git a/tensorflow/python/training/saving/BUILD b/tensorflow/python/training/saving/BUILD
index d622fa0..2d95343 100644
--- a/tensorflow/python/training/saving/BUILD
+++ b/tensorflow/python/training/saving/BUILD
@@ -1,6 +1,8 @@
# Description:
# Low-level utilities for reading and writing checkpoints.
+load("//tensorflow:tensorflow.bzl", "tf_py_test")
+
package(
default_visibility = [
"//tensorflow:internal",
@@ -55,3 +57,20 @@
"@six_archive//:six",
],
)
+
+tf_py_test(
+ name = "saveable_object_util_test",
+ srcs = ["saveable_object_util_test.py"],
+ deps = [
+ ":saveable_object_util",
+ "//tensorflow/python:resource_variable_ops",
+ "//tensorflow/python:variables",
+ "//tensorflow/python/checkpoint",
+ "//tensorflow/python/eager:context",
+ "//tensorflow/python/eager:test",
+ "//tensorflow/python/framework:dtypes",
+ "//tensorflow/python/framework:ops",
+ "//tensorflow/python/trackable:base",
+ "//tensorflow/python/trackable:resource",
+ ],
+)
diff --git a/tensorflow/python/training/saving/saveable_object_util.py b/tensorflow/python/training/saving/saveable_object_util.py
index 21a9c3e..3836c34 100644
--- a/tensorflow/python/training/saving/saveable_object_util.py
+++ b/tensorflow/python/training/saving/saveable_object_util.py
@@ -613,3 +613,87 @@
obj_serialize_fn = obj_serialize_fn.__func__
return trackable.Trackable._serialize_to_tensors != obj_serialize_fn
# pylint: enable=protected-access
+
+
+class SaveableCompatibilityConverter(trackable.Trackable):
+ """Converts object's `SaveableObjects` to functions used in TF2 checkpointing.
+
+ A class that converts a Trackable object's `SaveableObjects` to save and
+ restore functions with the same signatures as
+ `Trackable._serialize_to_tensors` and `Trackable._restore_from_tensors`.
+ This class also produces a method for filling the object proto.
+ """
+
+ __slots__ = ("_obj", "_cached_saveables")
+
+ def __init__(self, obj):
+ """Constructor.
+
+ Args:
+ obj: A Trackable object which implements the deprecated
+ `_gather_saveables_for_checkpoint`.
+ """
+ self._obj = obj
+ # The following are cached the first time any of the public methods are run.
+ self._cached_saveables = None
+
+ @property
+ def _saveables(self):
+ """Returns a list of SaveableObjects generated from the Trackable object."""
+ if self._cached_saveables is not None:
+ return self._cached_saveables
+
+ self._cached_saveables = []
+ for name, saveable_factory in (
+ saveable_objects_from_trackable(self._obj).items()):
+ if callable(saveable_factory):
+ maybe_saveable = create_saveable_object(
+ saveable_factory, name, call_with_mapped_captures=None)
+ else:
+ maybe_saveable = saveable_factory
+ if isinstance(maybe_saveable, saveable_object.SaveableObject):
+ saveables = (maybe_saveable,)
+ else:
+ saveables = tuple(saveable_objects_for_op(op=maybe_saveable, name=name))
+ self._cached_saveables.extend(saveables)
+ return self._cached_saveables
+
+ def _serialize_to_tensors(self):
+ """Returns a dict of tensors to serialize."""
+ tensor_dict = {}
+ for saveable in self._saveables:
+ for spec in saveable.specs:
+ tensor = spec.tensor
+ if spec.slice_spec:
+ tensor_dict[spec.name][spec.slice_spec] = tensor
+ else:
+ tensor_dict[spec.name] = tensor
+ return tensor_dict
+
+ def _restore_from_tensors(self, restored_tensors):
+ """Returns the restore ops defined in the Saveables."""
+ # Map restored tensors to the corresponding SaveableObjects, then call
+ # restore. There must be an exact match between restored tensors and the
+ # expected attributes.
+ expected_keys = []
+ for saveable in self._saveables:
+ expected_keys.extend(spec.name for spec in saveable.specs)
+ if set(expected_keys) != restored_tensors.keys():
+ raise ValueError(f"Could not restore object {self._obj} because not all "
+ "expected tensors were in the checkpoint."
+ f"\n\tExpected: {expected_keys}"
+ f"\n\tGot: {list(restored_tensors.keys())}")
+
+ restore_ops = {}
+ for saveable in self._saveables:
+ saveable_restored_tensors = []
+ for spec in saveable.specs:
+ if spec.slice_spec:
+ saveable_restored_tensors.append(
+ restored_tensors[spec.name][spec.slice_spec])
+ else:
+ saveable_restored_tensors.append(restored_tensors[spec.name])
+
+ restore_ops[saveable.name] = saveable.restore(
+ saveable_restored_tensors, restored_shapes=None)
+ return restore_ops
diff --git a/tensorflow/python/training/saving/saveable_object_util_test.py b/tensorflow/python/training/saving/saveable_object_util_test.py
new file mode 100644
index 0000000..9ef9ee4
--- /dev/null
+++ b/tensorflow/python/training/saving/saveable_object_util_test.py
@@ -0,0 +1,258 @@
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for saveable_object_util."""
+
+import os
+
+from tensorflow.python.checkpoint import checkpoint
+from tensorflow.python.eager import context
+from tensorflow.python.eager import test
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import gen_resource_variable_ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.trackable import base
+from tensorflow.python.trackable import resource
+from tensorflow.python.training.saving import saveable_object
+from tensorflow.python.training.saving import saveable_object_util
+
+_VAR_SAVEABLE = saveable_object_util.ResourceVariableSaveable
+
+
+class SaveableCompatibilityConverterTest(test.TestCase):
+
+ def test_convert_no_saveable(self):
+ t = base.Trackable()
+ converter = saveable_object_util.SaveableCompatibilityConverter(t)
+ self.assertEmpty(converter._serialize_to_tensors())
+ converter._restore_from_tensors({})
+
+ with self.assertRaisesRegex(ValueError, "Could not restore object"):
+ converter._restore_from_tensors({"": 0})
+
+ def test_convert_single_saveable(self):
+
+ class MyTrackable(base.Trackable):
+
+ def __init__(self):
+ self.a = variables.Variable(5.0)
+
+ def _gather_saveables_for_checkpoint(self):
+ return {"a": lambda name: _VAR_SAVEABLE(self.a, "", name)}
+
+ t = MyTrackable()
+ converter = saveable_object_util.SaveableCompatibilityConverter(t)
+
+ serialized_tensors = converter._serialize_to_tensors()
+ self.assertLen(serialized_tensors, 1)
+ self.assertIn("a", serialized_tensors)
+ self.assertEqual(5, self.evaluate(serialized_tensors["a"]))
+
+ with self.assertRaisesRegex(ValueError, "Could not restore object"):
+ converter._restore_from_tensors({})
+ with self.assertRaisesRegex(ValueError, "Could not restore object"):
+ converter._restore_from_tensors({"not_a": 1.})
+
+ self.assertEqual(5, self.evaluate(t.a))
+ converter._restore_from_tensors({"a": 123.})
+ self.assertEqual(123, self.evaluate(t.a))
+
+ def test_convert_single_saveable_renamed(self):
+
+ class MyTrackable(base.Trackable):
+
+ def __init__(self):
+ self.a = variables.Variable(15.0)
+
+ def _gather_saveables_for_checkpoint(self):
+ return {"a": lambda name: _VAR_SAVEABLE(self.a, "", name + "-value")}
+
+ t = MyTrackable()
+ converter = saveable_object_util.SaveableCompatibilityConverter(t)
+
+ serialized_tensors = converter._serialize_to_tensors()
+
+ self.assertLen(serialized_tensors, 1)
+ self.assertEqual(15, self.evaluate(serialized_tensors["a-value"]))
+
+ with self.assertRaisesRegex(ValueError, "Could not restore object"):
+ converter._restore_from_tensors({"a": 1.})
+
+ self.assertEqual(15, self.evaluate(t.a))
+ converter._restore_from_tensors({"a-value": 456.})
+ self.assertEqual(456, self.evaluate(t.a))
+
+ def test_convert_multiple_saveables(self):
+
+ class MyTrackable(base.Trackable):
+
+ def __init__(self):
+ self.a = variables.Variable(15.0)
+ self.b = variables.Variable(20.0)
+
+ def _gather_saveables_for_checkpoint(self):
+ return {
+ "a": lambda name: _VAR_SAVEABLE(self.a, "", name + "-1"),
+ "b": lambda name: _VAR_SAVEABLE(self.b, "", name + "-2")}
+
+ t = MyTrackable()
+ converter = saveable_object_util.SaveableCompatibilityConverter(t)
+
+ serialized_tensors = converter._serialize_to_tensors()
+ self.assertLen(serialized_tensors, 2)
+ self.assertEqual(15, self.evaluate(serialized_tensors["a-1"]))
+ self.assertEqual(20, self.evaluate(serialized_tensors["b-2"]))
+
+ with self.assertRaisesRegex(ValueError, "Could not restore object"):
+ converter._restore_from_tensors({"a": 1., "b": 2.})
+ with self.assertRaisesRegex(ValueError, "Could not restore object"):
+ converter._restore_from_tensors({"b-2": 2.})
+
+ converter._restore_from_tensors({"a-1": -123., "b-2": -456.})
+ self.assertEqual(-123, self.evaluate(t.a))
+ self.assertEqual(-456, self.evaluate(t.b))
+
+ def test_convert_variables(self):
+ # The method `_gather_saveables_for_checkpoint` allowed the users to pass
+ # Variables instead of Saveables.
+
+ class MyTrackable(base.Trackable):
+
+ def __init__(self):
+ self.a = variables.Variable(25.)
+ self.b = resource_variable_ops.UninitializedVariable(
+ dtype=dtypes.float32)
+
+ def _gather_saveables_for_checkpoint(self):
+ return {"a": self.a, "b": self.b}
+
+ t = MyTrackable()
+ converter = saveable_object_util.SaveableCompatibilityConverter(t)
+ serialized_tensors = converter._serialize_to_tensors()
+
+ self.assertLen(serialized_tensors, 2)
+ self.assertEqual(25, self.evaluate(serialized_tensors["a"]))
+ self.assertIsNone(serialized_tensors["b"])
+
+ with self.assertRaisesRegex(ValueError, "Could not restore object"):
+ converter._restore_from_tensors({"a": 5.})
+
+ converter._restore_from_tensors({"a": 5., "b": 6.})
+ self.assertEqual(5, self.evaluate(t.a))
+ self.assertEqual(6, self.evaluate(t.b))
+
+
+class State(resource.TrackableResource):
+
+ def __init__(self, initial_value):
+ super().__init__()
+ self._initial_value = initial_value
+ self._initialize()
+
+ def _create_resource(self):
+ return gen_resource_variable_ops.var_handle_op(
+ shape=[],
+ dtype=dtypes.float32,
+ shared_name=context.anonymous_name(),
+ name="StateVar",
+ container="")
+
+ def _initialize(self):
+ gen_resource_variable_ops.assign_variable_op(self.resource_handle,
+ self._initial_value)
+
+ def _destroy_resource(self):
+ gen_resource_variable_ops.destroy_resource_op(self.resource_handle,
+ ignore_lookup_error=True)
+
+ def read(self):
+ return gen_resource_variable_ops.read_variable_op(self.resource_handle,
+ dtypes.float32)
+
+ def assign(self, value):
+ gen_resource_variable_ops.assign_variable_op(self.resource_handle, value)
+
+
+class _StateSaveable(saveable_object.SaveableObject):
+
+ def __init__(self, obj, name):
+ spec = saveable_object.SaveSpec(obj.read(), "", name)
+ self.obj = obj
+ super(_StateSaveable, self).__init__(obj, [spec], name)
+
+ def restore(self, restored_tensors, restored_shapes):
+ del restored_shapes # Unused.
+ self.obj.assign(restored_tensors[0])
+
+
+class SaveableState(State):
+
+ def _gather_saveables_for_checkpoint(self):
+ return {
+ "value": lambda name: _StateSaveable(self, name)
+ }
+
+
+class TrackableState(State):
+
+ def _serialize_to_tensors(self):
+ return {
+ "value": self.read()
+ }
+
+ def _restore_from_tensors(self, restored_tensors):
+ self.assign(restored_tensors["value"])
+
+
+class SaveableCompatibilityEndToEndTest(test.TestCase):
+
+ def test_checkpoint_comparison(self):
+ saveable_state = SaveableState(5.)
+ trackable_state = TrackableState(10.)
+
+ # First test that SaveableState and TrackableState are equivalent by
+ # saving a checkpoint with both objects and swapping values.
+
+ self.assertEqual(5, self.evaluate(saveable_state.read()))
+ self.assertEqual(10, self.evaluate(trackable_state.read()))
+
+ ckpt_path = os.path.join(self.get_temp_dir(), "ckpt")
+ checkpoint.Checkpoint(a=saveable_state, b=trackable_state).write(ckpt_path)
+
+ status = checkpoint.Checkpoint(b=saveable_state,
+ a=trackable_state).read(ckpt_path)
+ status.assert_consumed()
+
+ self.assertEqual(10, self.evaluate(saveable_state.read()))
+ self.assertEqual(5, self.evaluate(trackable_state.read()))
+
+ # Test that the converted SaveableState is compatible with the checkpoint
+ # saved above.
+ to_convert = SaveableState(0.0)
+
+ converted_saveable_state = (
+ saveable_object_util.SaveableCompatibilityConverter(to_convert))
+
+ checkpoint.Checkpoint(a=converted_saveable_state).read(
+ ckpt_path).assert_existing_objects_matched()
+ self.assertEqual(5, self.evaluate(to_convert.read()))
+
+ checkpoint.Checkpoint(b=converted_saveable_state).read(
+ ckpt_path).assert_existing_objects_matched()
+ self.assertEqual(10, self.evaluate(to_convert.read()))
+
+
+if __name__ == "__main__":
+ test.main()