Add SaveableObject --> Trackable converter.

Another CL for the SaveableObject deprecation. This converter will allow Trackables that have not yet migrated off of `_gather_saveables_for_checkpoint` to be compatible with checkpointing once the code has moved to using `Trackable._serialize_to_tensors` and `Trackable._serialize_from_tensors`

PiperOrigin-RevId: 461776781
diff --git a/tensorflow/python/training/saving/BUILD b/tensorflow/python/training/saving/BUILD
index d622fa0..2d95343 100644
--- a/tensorflow/python/training/saving/BUILD
+++ b/tensorflow/python/training/saving/BUILD
@@ -1,6 +1,8 @@
 # Description:
 #   Low-level utilities for reading and writing checkpoints.
 
+load("//tensorflow:tensorflow.bzl", "tf_py_test")
+
 package(
     default_visibility = [
         "//tensorflow:internal",
@@ -55,3 +57,20 @@
         "@six_archive//:six",
     ],
 )
+
+tf_py_test(
+    name = "saveable_object_util_test",
+    srcs = ["saveable_object_util_test.py"],
+    deps = [
+        ":saveable_object_util",
+        "//tensorflow/python:resource_variable_ops",
+        "//tensorflow/python:variables",
+        "//tensorflow/python/checkpoint",
+        "//tensorflow/python/eager:context",
+        "//tensorflow/python/eager:test",
+        "//tensorflow/python/framework:dtypes",
+        "//tensorflow/python/framework:ops",
+        "//tensorflow/python/trackable:base",
+        "//tensorflow/python/trackable:resource",
+    ],
+)
diff --git a/tensorflow/python/training/saving/saveable_object_util.py b/tensorflow/python/training/saving/saveable_object_util.py
index 21a9c3e..3836c34 100644
--- a/tensorflow/python/training/saving/saveable_object_util.py
+++ b/tensorflow/python/training/saving/saveable_object_util.py
@@ -613,3 +613,87 @@
     obj_serialize_fn = obj_serialize_fn.__func__
   return trackable.Trackable._serialize_to_tensors != obj_serialize_fn
   # pylint: enable=protected-access
+
+
+class SaveableCompatibilityConverter(trackable.Trackable):
+  """Converts object's `SaveableObjects` to functions used in TF2 checkpointing.
+
+  A class that converts a Trackable object's `SaveableObjects` to save and
+  restore functions with the same signatures as
+  `Trackable._serialize_to_tensors` and `Trackable._restore_from_tensors`.
+  This class also produces a method for filling the object proto.
+  """
+
+  __slots__ = ("_obj", "_cached_saveables")
+
+  def __init__(self, obj):
+    """Constructor.
+
+    Args:
+      obj: A Trackable object which implements the deprecated
+        `_gather_saveables_for_checkpoint`.
+    """
+    self._obj = obj
+    # The following are cached the first time any of the public methods are run.
+    self._cached_saveables = None
+
+  @property
+  def _saveables(self):
+    """Returns a list of SaveableObjects generated from the Trackable object."""
+    if self._cached_saveables is not None:
+      return self._cached_saveables
+
+    self._cached_saveables = []
+    for name, saveable_factory in (
+        saveable_objects_from_trackable(self._obj).items()):
+      if callable(saveable_factory):
+        maybe_saveable = create_saveable_object(
+            saveable_factory, name, call_with_mapped_captures=None)
+      else:
+        maybe_saveable = saveable_factory
+      if isinstance(maybe_saveable, saveable_object.SaveableObject):
+        saveables = (maybe_saveable,)
+      else:
+        saveables = tuple(saveable_objects_for_op(op=maybe_saveable, name=name))
+      self._cached_saveables.extend(saveables)
+    return self._cached_saveables
+
+  def _serialize_to_tensors(self):
+    """Returns a dict of tensors to serialize."""
+    tensor_dict = {}
+    for saveable in self._saveables:
+      for spec in saveable.specs:
+        tensor = spec.tensor
+        if spec.slice_spec:
+          tensor_dict[spec.name][spec.slice_spec] = tensor
+        else:
+          tensor_dict[spec.name] = tensor
+    return tensor_dict
+
+  def _restore_from_tensors(self, restored_tensors):
+    """Returns the restore ops defined in the Saveables."""
+    # Map restored tensors to the corresponding SaveableObjects, then call
+    # restore. There must be an exact match between restored tensors and the
+    # expected attributes.
+    expected_keys = []
+    for saveable in self._saveables:
+      expected_keys.extend(spec.name for spec in saveable.specs)
+    if set(expected_keys) != restored_tensors.keys():
+      raise ValueError(f"Could not restore object {self._obj} because not all "
+                       "expected tensors were in the checkpoint."
+                       f"\n\tExpected: {expected_keys}"
+                       f"\n\tGot: {list(restored_tensors.keys())}")
+
+    restore_ops = {}
+    for saveable in self._saveables:
+      saveable_restored_tensors = []
+      for spec in saveable.specs:
+        if spec.slice_spec:
+          saveable_restored_tensors.append(
+              restored_tensors[spec.name][spec.slice_spec])
+        else:
+          saveable_restored_tensors.append(restored_tensors[spec.name])
+
+      restore_ops[saveable.name] = saveable.restore(
+          saveable_restored_tensors, restored_shapes=None)
+    return restore_ops
diff --git a/tensorflow/python/training/saving/saveable_object_util_test.py b/tensorflow/python/training/saving/saveable_object_util_test.py
new file mode 100644
index 0000000..9ef9ee4
--- /dev/null
+++ b/tensorflow/python/training/saving/saveable_object_util_test.py
@@ -0,0 +1,258 @@
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for saveable_object_util."""
+
+import os
+
+from tensorflow.python.checkpoint import checkpoint
+from tensorflow.python.eager import context
+from tensorflow.python.eager import test
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import gen_resource_variable_ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.trackable import base
+from tensorflow.python.trackable import resource
+from tensorflow.python.training.saving import saveable_object
+from tensorflow.python.training.saving import saveable_object_util
+
+_VAR_SAVEABLE = saveable_object_util.ResourceVariableSaveable
+
+
+class SaveableCompatibilityConverterTest(test.TestCase):
+
+  def test_convert_no_saveable(self):
+    t = base.Trackable()
+    converter = saveable_object_util.SaveableCompatibilityConverter(t)
+    self.assertEmpty(converter._serialize_to_tensors())
+    converter._restore_from_tensors({})
+
+    with self.assertRaisesRegex(ValueError, "Could not restore object"):
+      converter._restore_from_tensors({"": 0})
+
+  def test_convert_single_saveable(self):
+
+    class MyTrackable(base.Trackable):
+
+      def __init__(self):
+        self.a = variables.Variable(5.0)
+
+      def _gather_saveables_for_checkpoint(self):
+        return {"a": lambda name: _VAR_SAVEABLE(self.a, "", name)}
+
+    t = MyTrackable()
+    converter = saveable_object_util.SaveableCompatibilityConverter(t)
+
+    serialized_tensors = converter._serialize_to_tensors()
+    self.assertLen(serialized_tensors, 1)
+    self.assertIn("a", serialized_tensors)
+    self.assertEqual(5, self.evaluate(serialized_tensors["a"]))
+
+    with self.assertRaisesRegex(ValueError, "Could not restore object"):
+      converter._restore_from_tensors({})
+    with self.assertRaisesRegex(ValueError, "Could not restore object"):
+      converter._restore_from_tensors({"not_a": 1.})
+
+    self.assertEqual(5, self.evaluate(t.a))
+    converter._restore_from_tensors({"a": 123.})
+    self.assertEqual(123, self.evaluate(t.a))
+
+  def test_convert_single_saveable_renamed(self):
+
+    class MyTrackable(base.Trackable):
+
+      def __init__(self):
+        self.a = variables.Variable(15.0)
+
+      def _gather_saveables_for_checkpoint(self):
+        return {"a": lambda name: _VAR_SAVEABLE(self.a, "", name + "-value")}
+
+    t = MyTrackable()
+    converter = saveable_object_util.SaveableCompatibilityConverter(t)
+
+    serialized_tensors = converter._serialize_to_tensors()
+
+    self.assertLen(serialized_tensors, 1)
+    self.assertEqual(15, self.evaluate(serialized_tensors["a-value"]))
+
+    with self.assertRaisesRegex(ValueError, "Could not restore object"):
+      converter._restore_from_tensors({"a": 1.})
+
+    self.assertEqual(15, self.evaluate(t.a))
+    converter._restore_from_tensors({"a-value": 456.})
+    self.assertEqual(456, self.evaluate(t.a))
+
+  def test_convert_multiple_saveables(self):
+
+    class MyTrackable(base.Trackable):
+
+      def __init__(self):
+        self.a = variables.Variable(15.0)
+        self.b = variables.Variable(20.0)
+
+      def _gather_saveables_for_checkpoint(self):
+        return {
+            "a": lambda name: _VAR_SAVEABLE(self.a, "", name + "-1"),
+            "b": lambda name: _VAR_SAVEABLE(self.b, "", name + "-2")}
+
+    t = MyTrackable()
+    converter = saveable_object_util.SaveableCompatibilityConverter(t)
+
+    serialized_tensors = converter._serialize_to_tensors()
+    self.assertLen(serialized_tensors, 2)
+    self.assertEqual(15, self.evaluate(serialized_tensors["a-1"]))
+    self.assertEqual(20, self.evaluate(serialized_tensors["b-2"]))
+
+    with self.assertRaisesRegex(ValueError, "Could not restore object"):
+      converter._restore_from_tensors({"a": 1., "b": 2.})
+    with self.assertRaisesRegex(ValueError, "Could not restore object"):
+      converter._restore_from_tensors({"b-2": 2.})
+
+    converter._restore_from_tensors({"a-1": -123., "b-2": -456.})
+    self.assertEqual(-123, self.evaluate(t.a))
+    self.assertEqual(-456, self.evaluate(t.b))
+
+  def test_convert_variables(self):
+    # The method `_gather_saveables_for_checkpoint` allowed the users to pass
+    # Variables instead of Saveables.
+
+    class MyTrackable(base.Trackable):
+
+      def __init__(self):
+        self.a = variables.Variable(25.)
+        self.b = resource_variable_ops.UninitializedVariable(
+            dtype=dtypes.float32)
+
+      def _gather_saveables_for_checkpoint(self):
+        return {"a": self.a, "b": self.b}
+
+    t = MyTrackable()
+    converter = saveable_object_util.SaveableCompatibilityConverter(t)
+    serialized_tensors = converter._serialize_to_tensors()
+
+    self.assertLen(serialized_tensors, 2)
+    self.assertEqual(25, self.evaluate(serialized_tensors["a"]))
+    self.assertIsNone(serialized_tensors["b"])
+
+    with self.assertRaisesRegex(ValueError, "Could not restore object"):
+      converter._restore_from_tensors({"a": 5.})
+
+    converter._restore_from_tensors({"a": 5., "b": 6.})
+    self.assertEqual(5, self.evaluate(t.a))
+    self.assertEqual(6, self.evaluate(t.b))
+
+
+class State(resource.TrackableResource):
+
+  def __init__(self, initial_value):
+    super().__init__()
+    self._initial_value = initial_value
+    self._initialize()
+
+  def _create_resource(self):
+    return gen_resource_variable_ops.var_handle_op(
+        shape=[],
+        dtype=dtypes.float32,
+        shared_name=context.anonymous_name(),
+        name="StateVar",
+        container="")
+
+  def _initialize(self):
+    gen_resource_variable_ops.assign_variable_op(self.resource_handle,
+                                                 self._initial_value)
+
+  def _destroy_resource(self):
+    gen_resource_variable_ops.destroy_resource_op(self.resource_handle,
+                                                  ignore_lookup_error=True)
+
+  def read(self):
+    return gen_resource_variable_ops.read_variable_op(self.resource_handle,
+                                                      dtypes.float32)
+
+  def assign(self, value):
+    gen_resource_variable_ops.assign_variable_op(self.resource_handle, value)
+
+
+class _StateSaveable(saveable_object.SaveableObject):
+
+  def __init__(self, obj, name):
+    spec = saveable_object.SaveSpec(obj.read(), "", name)
+    self.obj = obj
+    super(_StateSaveable, self).__init__(obj, [spec], name)
+
+  def restore(self, restored_tensors, restored_shapes):
+    del restored_shapes  # Unused.
+    self.obj.assign(restored_tensors[0])
+
+
+class SaveableState(State):
+
+  def _gather_saveables_for_checkpoint(self):
+    return {
+        "value": lambda name: _StateSaveable(self, name)
+    }
+
+
+class TrackableState(State):
+
+  def _serialize_to_tensors(self):
+    return {
+        "value": self.read()
+    }
+
+  def _restore_from_tensors(self, restored_tensors):
+    self.assign(restored_tensors["value"])
+
+
+class SaveableCompatibilityEndToEndTest(test.TestCase):
+
+  def test_checkpoint_comparison(self):
+    saveable_state = SaveableState(5.)
+    trackable_state = TrackableState(10.)
+
+    # First test that SaveableState and TrackableState are equivalent by
+    # saving a checkpoint with both objects and swapping values.
+
+    self.assertEqual(5, self.evaluate(saveable_state.read()))
+    self.assertEqual(10, self.evaluate(trackable_state.read()))
+
+    ckpt_path = os.path.join(self.get_temp_dir(), "ckpt")
+    checkpoint.Checkpoint(a=saveable_state, b=trackable_state).write(ckpt_path)
+
+    status = checkpoint.Checkpoint(b=saveable_state,
+                                   a=trackable_state).read(ckpt_path)
+    status.assert_consumed()
+
+    self.assertEqual(10, self.evaluate(saveable_state.read()))
+    self.assertEqual(5, self.evaluate(trackable_state.read()))
+
+    # Test that the converted SaveableState is compatible with the checkpoint
+    # saved above.
+    to_convert = SaveableState(0.0)
+
+    converted_saveable_state = (
+        saveable_object_util.SaveableCompatibilityConverter(to_convert))
+
+    checkpoint.Checkpoint(a=converted_saveable_state).read(
+        ckpt_path).assert_existing_objects_matched()
+    self.assertEqual(5, self.evaluate(to_convert.read()))
+
+    checkpoint.Checkpoint(b=converted_saveable_state).read(
+        ckpt_path).assert_existing_objects_matched()
+    self.assertEqual(10, self.evaluate(to_convert.read()))
+
+
+if __name__ == "__main__":
+  test.main()