Add checkpoint saver registration.

This is a currently-internal API that allows users to register functions for reading and writing checkpoints directly to and from files.
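
For reference, a minimal sketch of how the new API is intended to be used,
modeled on the Stack/Part test added in registration_saving_test.py below. The
Pair and Child classes, the "/packed" key layout, and the helper names here are
illustrative only, not part of this change:

```
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.saved_model import registration
from tensorflow.python.training.tracking import tracking


class Child(resource_variable_ops.ResourceVariable):
  """Illustrative variable type whose value is checkpointed by its parent."""

  def __init__(self, value):
    self._init_from_args(value)


class Pair(tracking.AutoTrackable):
  """Illustrative object that packs its two Child variables into one tensor."""

  def __init__(self, first, second):
    self.first = Child(first)
    self.second = Child(second)


def _save_pairs(trackables, file_prefix):
  # trackables: {object name in the checkpoint -> Pair or Child}.
  names, slices, tensors = [], [], []
  for name, obj in trackables.items():
    if isinstance(obj, Child):
      continue  # Children are packed into their parent's tensor below.
    names.append(name + "/packed")
    slices.append("")
    tensors.append(array_ops.stack([obj.first, obj.second]))
  io_ops.save_v2(file_prefix, names, slices, tensors)
  return file_prefix  # Returned prefixes are merged into the final checkpoint.


def _restore_pairs(trackables, merged_prefix):
  pairs = [(name, obj) for name, obj in trackables.items()
           if isinstance(obj, Pair)]
  names = [name + "/packed" for name, _ in pairs]
  slices = [""] * len(pairs)
  dtypes = [pair.first.dtype for _, pair in pairs]
  packed_tensors = io_ops.restore_v2(merged_prefix, names, slices, dtypes)
  for (_, pair), packed in zip(pairs, packed_tensors):
    first, second = array_ops.unstack(packed)
    pair.first.assign(first)
    pair.second.assign(second)


registration.register_checkpoint_saver(
    package="Example",
    name="pair_saver",
    predicate=lambda x: isinstance(x, (Pair, Child)),
    save_fn=_save_pairs,
    restore_fn=_restore_pairs)
```

Objects that match the predicate are routed to the registered save and restore
functions by tf.train.Checkpoint and by SavedModel export; the prefixes returned
by the save function are merged into the final checkpoint alongside the default
shards.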

PiperOrigin-RevId: 405506664
Change-Id: I60c7b6d47f3b8d8b0fa8ff8b7c383d8112ba2301
diff --git a/tensorflow/core/protobuf/saved_object_graph.proto b/tensorflow/core/protobuf/saved_object_graph.proto
index 85bf826..bd9303c 100644
--- a/tensorflow/core/protobuf/saved_object_graph.proto
+++ b/tensorflow/core/protobuf/saved_object_graph.proto
@@ -70,12 +70,26 @@
     CapturedTensor captured_tensor = 12;
   }
 
+  // Stores the functions used to save and restore this object. At most one of
+  // `saveable_objects` or `registered_saver` is defined for each SavedObject.
+  // See the comment below for the difference between SaveableObject and
+  // registered savers.
   map<string, SaveableObject> saveable_objects = 11;
 
   // The fields below are filled when the user serializes a registered Trackable
-  // class. Registered classes may save additional metadata and supersede the
-  // default loading process where nodes are recreated from the proto.
+  // class or an object with a registered saver function.
   //
+  // Registered classes may save additional metadata and supersede the
+  // default loading process where nodes are recreated from the proto.
+  // If the registered class cannot be found, then the object will load as
+  // one of the default trackable objects: AutoTrackable (a class similar to
+  // tf.Module), tf.function, or tf.Variable.
+  //
+  // Unlike SaveableObjects, which store the functions for saving and restoring
+  // from tensors, registered savers allow Trackables to write checkpoint shards
+  // directly (e.g. for performance or coordination reasons).
+  // *All registered savers must be available when loading the SavedModel.*
+
   // The name of the registered class of the form "{package}.{class_name}".
   // This field is used to search for the registered class at loading time.
   string registered_name = 13;
@@ -83,6 +97,10 @@
   // the registered classes's _deserialize_from_proto method when this object is
   // loaded from the SavedModel.
   google.protobuf.Any serialized_user_proto = 14;
+
+  // String name of the registered saver. At most one of `saveable_objects` or
+  // `registered_saver` is defined for each SavedObject.
+  string registered_saver = 16;
 }
 
 // A SavedUserObject is an object (in the object-oriented language of the
@@ -226,6 +244,7 @@
 
 message SaveableObject {
   // Node ids of concrete functions for saving and loading from a checkpoint.
+  // These functions save and restore directly from tensors.
   int32 save_function = 2;
   int32 restore_function = 3;
 }
diff --git a/tensorflow/core/protobuf/trackable_object_graph.proto b/tensorflow/core/protobuf/trackable_object_graph.proto
index 4be996b..4ccf19f 100644
--- a/tensorflow/core/protobuf/trackable_object_graph.proto
+++ b/tensorflow/core/protobuf/trackable_object_graph.proto
@@ -54,7 +54,19 @@
     repeated SerializedTensor attributes = 2;
     // Slot variables owned by this object.
     repeated SlotVariableReference slot_variables = 3;
+
+    // The registered saver used to save this object. If this saver is not
+    // present when loading the checkpoint, then loading will fail.
+    RegisteredSaver registered_saver = 4;
   }
 
   repeated TrackableObject nodes = 1;
 }
+
+message RegisteredSaver {
+  // The name of the registered saver/restore function.
+  string name = 1;
+
+  // Unique auto-generated name of the object.
+  string object_name = 2;
+}
diff --git a/tensorflow/python/saved_model/BUILD b/tensorflow/python/saved_model/BUILD
index dfcde1b..9b8e404 100644
--- a/tensorflow/python/saved_model/BUILD
+++ b/tensorflow/python/saved_model/BUILD
@@ -551,9 +551,12 @@
     srcs = ["registration_saving_test.py"],
     deps = [
         ":registration",
+        "//tensorflow/python:array_ops",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:framework_ops",
+        "//tensorflow/python:io_ops",
         "//tensorflow/python:resource_variable_ops",
+        "//tensorflow/python/eager:def_function",
         "//tensorflow/python/eager:test",
         "//tensorflow/python/training/tracking",
         "@absl_py//absl/testing:parameterized",
diff --git a/tensorflow/python/saved_model/load.py b/tensorflow/python/saved_model/load.py
index 3a20b27..2b0755e 100644
--- a/tensorflow/python/saved_model/load.py
+++ b/tensorflow/python/saved_model/load.py
@@ -550,6 +550,13 @@
       for object_id, obj in dict(checkpoint.object_by_proto_id).items():
         position = base.CheckpointPosition(checkpoint=checkpoint,
                                            proto_id=object_id)
+        registered_saver = position.get_registered_saver_name()
+        if registered_saver:
+          raise NotImplementedError(
+              "Loading a SavedModel that uses registered checkpoint saver is "
+              f"not supported in graph mode. The loaded object {obj} uses the "
+              f"saver registered with the name {registered_saver}.")
+
         restore_ops = position.restore_ops()
         if restore_ops:
           if resource_variable_ops.is_resource_variable(obj):
diff --git a/tensorflow/python/saved_model/registration.py b/tensorflow/python/saved_model/registration.py
index cd5d51a..2c61f74 100644
--- a/tensorflow/python/saved_model/registration.py
+++ b/tensorflow/python/saved_model/registration.py
@@ -16,31 +16,124 @@
 
 revived_types registration will be migrated to this infrastructure.
 """
+import collections
+import re
 
 from tensorflow.python.util import tf_inspect
 
-_CLASS_REGISTRY = {}  # string registered name -> (predicate, class)
-_REGISTERED_NAMES = []
+
+# Only allow valid file/directory characters
+_VALID_REGISTERED_NAME = re.compile(r"^[a-zA-Z0-9._-]+$")
 
 
-def get_registered_name(obj):
-  for name in reversed(_REGISTERED_NAMES):
-    predicate, cls = _CLASS_REGISTRY[name]
-    if not predicate and type(obj) == cls:  # pylint: disable=unidiomatic-typecheck
-      return name
-    if predicate and predicate(obj):
-      return name
-  return None
+class _PredicateRegistry(object):
+  """Registry with predicate-based lookup.
+
+  See the documentation for `register_checkpoint_saver` and
+  `register_serializable` for why predicates are used instead of a
+  class-based registry.
+
+  Since this class is used for global registries, each object must be
+  registered under a unique name (an error is raised on naming conflicts). The
+  lookup searches the predicates in reverse order, so that later-registered
+  predicates are checked first.
+  """
+  __slots__ = ("_registry_name", "_registered_map", "_registered_predicates",
+               "_registered_names")
+
+  def __init__(self, name):
+    self._registry_name = name
+    # Maps registered name -> object
+    self._registered_map = {}
+    # Maps registered name -> predicate
+    self._registered_predicates = {}
+    # Stores names in the order of registration
+    self._registered_names = []
+
+  @property
+  def name(self):
+    return self._registry_name
+
+  def register(self, package, name, predicate, candidate):
+    """Registers a candidate object under the package, name and predicate."""
+    if not isinstance(package, str) or not isinstance(name, str):
+      raise TypeError(
+          f"The package and name registered to a {self.name} must be strings, "
+          f"got: package={type(package)}, name={type(name)}")
+    if not callable(predicate):
+      raise TypeError(
+          f"The predicate registered to a {self.name} must be callable, "
+          f"got: {type(predicate)}")
+    registered_name = package + "." + name
+    if not _VALID_REGISTERED_NAME.match(registered_name):
+      raise ValueError(
+          f"Invalid registered {self.name}. Please check that the package and "
+          f"name follow the regex '{_VALID_REGISTERED_NAME.pattern}': "
+          f"(package='{package}', name='{name}')")
+    if registered_name in self._registered_map:
+      raise ValueError(
+          f"The name '{registered_name}' has already been registered to a "
+          f"{self.name}. Found: {self._registered_map[registered_name]}")
+
+    self._registered_map[registered_name] = candidate
+    self._registered_predicates[registered_name] = predicate
+    self._registered_names.append(registered_name)
+
+  def lookup(self, obj):
+    """Looks up the registered object using the predicate.
+
+    Args:
+      obj: Object to pass to each of the registered predicates to look up the
+        registered object.
+    Returns:
+      The object registered with the first passing predicate.
+    Raises:
+      LookupError if the object does not match any of the predicate functions.
+    """
+    return self._registered_map[self.get_registered_name(obj)]
+
+  def name_lookup(self, registered_name):
+    """Looks up the registered object using the registered name."""
+    try:
+      return self._registered_map[registered_name]
+    except KeyError:
+      raise LookupError(f"The {self.name} registry does not have name "
+                        f"'{registered_name}' registered.")
+
+  def get_registered_name(self, obj):
+    for registered_name in reversed(self._registered_names):
+      predicate = self._registered_predicates[registered_name]
+      if predicate(obj):
+        return registered_name
+    raise LookupError(f"Could not find matching {self.name} for {type(obj)}.")
+
+  def get_predicate(self, registered_name):
+    try:
+      return self._registered_predicates[registered_name]
+    except KeyError:
+      raise LookupError(f"The {self.name} registry does not have name "
+                        f"'{registered_name}' registered.")
+
+
+_class_registry = _PredicateRegistry("serializable class")
+_saver_registry = _PredicateRegistry("checkpoint saver")
+
+
+def get_registered_class_name(obj):
+  try:
+    return _class_registry.get_registered_name(obj)
+  except LookupError:
+    return None
 
 
 def get_registered_class(registered_name):
   try:
-    return _CLASS_REGISTRY[registered_name][1]
-  except KeyError:
+    return _class_registry.name_lookup(registered_name)
+  except LookupError:
     return None
 
 
-def register_serializable(package="Custom", name=None, predicate=None):
+def register_serializable(package="Custom", name=None, predicate=None):  # pylint: disable=unused-argument
   """Decorator for registering a serializable class.
 
   THIS METHOD IS STILL EXPERIMENTAL AND MAY CHANGE AT ANY TIME.
@@ -75,30 +168,194 @@
   Returns:
     A decorator that registers the decorated class with the passed names and
     predicate.
-
-  Raises:
-    ValueError if predicate is not callable.
   """
-  if predicate is not None and not callable(predicate):
-    raise ValueError("The `predicate` passed to registered_serializable "
-                     "must be callable.")
-
   def decorator(arg):
     """Registers a class with the serialization framework."""
+    nonlocal predicate
     if not tf_inspect.isclass(arg):
-      raise ValueError(
-          "Registered serializable must be a class: {}".format(arg))
+      raise TypeError("Registered serializable must be a class: {}".format(arg))
 
     class_name = name if name is not None else arg.__name__
-    registered_name = package + "." + class_name
-
-    if registered_name in _CLASS_REGISTRY:
-      raise ValueError("{} has already been registered to {}".format(
-          registered_name, _CLASS_REGISTRY[registered_name]))
-
-    _CLASS_REGISTRY[registered_name] = (predicate, arg)
-    _REGISTERED_NAMES.append(registered_name)
-
+    if predicate is None:
+      predicate = lambda x: isinstance(x, arg)
+    _class_registry.register(package, class_name, predicate, arg)
     return arg
 
   return decorator
+
+
+RegisteredSaver = collections.namedtuple(
+    "RegisteredSaver", ["name", "predicate", "save_fn", "restore_fn"])
+_REGISTERED_SAVERS = {}
+_REGISTERED_SAVER_NAMES = []  # Stores names in the order of registration
+
+
+def register_checkpoint_saver(package="Custom",
+                              name=None,
+                              predicate=None,
+                              save_fn=None,
+                              restore_fn=None):
+  """Registers functions which checkpoints & restores objects with custom steps.
+
+  If you have a class that requires complicated coordination between multiple
+  objects when checkpointing, then you will need to register a custom saver
+  and restore function. An example of this is a custom Variable class that
+  splits the variable across different objects and devices, and needs to write
+  checkpoints that are compatible with different configurations of devices.
+
+  The registered save and restore functions are used in checkpoints and
+  SavedModel.
+
+  Please make sure you are familiar with the concepts in the [Checkpointing
+  guide](https://www.tensorflow.org/guide/checkpoint), and ops used to save the
+  V2 checkpoint format:
+
+  * io_ops.SaveV2
+  * io_ops.MergeV2Checkpoints
+  * io_ops.RestoreV2
+
+  **Predicate**
+
+  The predicate is a filter that will run on every `Trackable` object connected
+  to the root object. This function determines whether a `Trackable` should use
+  the registered functions.
+
+  Example: `lambda x: isinstance(x, CustomClass)`
+
+  **Custom save function**
+
+  This is how checkpoint saving works normally:
+  1. Gather all of the Trackables with saveable values.
+  2. For each Trackable, gather all of the saveable tensors.
+  3. Save checkpoint shards (grouping tensors by device) with SaveV2.
+  4. Merge the shards with MergeV2Checkpoints. This combines all of the shards'
+     metadata and renames them to follow the standard shard pattern.
+
+  When a saver is registered, Trackables that pass the registered `predicate`
+  are automatically marked as having saveable values. Next, the custom save
+  function replaces steps 2 and 3 of the saving process. Finally, the shards
+  returned by the custom save function are merged with the other shards.
+
+  The save function takes in a dictionary of `Trackables` and a `file_prefix`
+  string. The function should save checkpoint shards using the SaveV2 op and
+  return the list of shard prefixes. SaveV2 is currently required for this to
+  work correctly, because the code merges all of the returned shards, so the
+  `restore_fn` will only be given the prefix of the merged checkpoint. If you
+  need to save and restore unmerged shards, please file a feature request.
+
+  Specification and example of the save function:
+
+  ```
+  def save_fn(trackables, file_prefix):
+    # trackables: A dictionary mapping unique string identifiers to trackables
+    # file_prefix: A unique file prefix generated using the registered name.
+    ...
+    # Gather the tensors to save.
+    ...
+    io_ops.save_v2(file_prefix, tensor_names, shapes_and_slices, tensors)
+    return file_prefix  # Return a string tensor or a list of string tensors.
+  ```
+
+  **Custom restore function**
+
+  Normal checkpoint restore behavior:
+  1. Gather all of the Trackables that have saveable values.
+  2. For each Trackable, get the names of the desired tensors to extract from
+     the checkpoint.
+  3. Use RestoreV2 to read the saved values, and pass the restored tensors to
+     the corresponding Trackables.
+
+  The custom restore function replaces steps 2 and 3.
+
+  The restore function also takes a dictionary of `Trackables` and a
+  `merged_prefix` string. The `merged_prefix` is different from the
+  `file_prefix`, since it contains the renamed shard paths. To read from the
+  merged checkpoint, you must use `RestoreV2(merged_prefix, ...)`.
+
+  Specification:
+
+  ```
+  def restore_fn(trackables, merged_prefix):
+    # trackables: A dictionary mapping unique string identifiers to Trackables
+    # merged_prefix: File prefix of the merged shard names.
+
+    restored_tensors = io_ops.restore_v2(
+        merged_prefix, tensor_names, shapes_and_slices, dtypes)
+    ...
+    # Restore the checkpoint values for the given Trackables.
+  ```
+
+  Args:
+    package: Optional, the package that this class belongs to.
+    name: (Required) The name of this saver, which is saved to the checkpoint.
+      When a checkpoint is restored, the name and package are used to find
+      the matching restore function. The name and package are also used to
+      generate a unique file prefix that is passed to the save_fn.
+    predicate: (Required) A function that returns a boolean indicating whether a
+      `Trackable` object should be checkpointed with this function. Predicates
+      are executed in the reverse order that they are added (later registrations
+      are checked first).
+    save_fn: (Required) A function that takes a dictionary of trackables and a
+      file prefix as the arguments, writes the checkpoint shards for the given
+      Trackables, and returns the list of shard prefixes.
+    restore_fn: (Required) A function that takes a dictionary of trackables and
+      a file prefix as the arguments and restores the trackable values.
+
+  Raises:
+    ValueError: if the package and name are already registered.
+  """
+  if not callable(save_fn):
+    raise TypeError(f"The save_fn must be callable, got: {type(save_fn)}")
+  if not callable(restore_fn):
+    raise TypeError(f"The restore_fn must be callable, got: {type(restore_fn)}")
+
+  _saver_registry.register(package, name, predicate, (save_fn, restore_fn))
+
+
+def get_registered_saver_name(trackable):
+  """Returns the name of the registered saver to use with Trackable."""
+  try:
+    return _saver_registry.get_registered_name(trackable)
+  except LookupError:
+    return None
+
+
+def get_save_function(registered_name):
+  """Returns save function registered to name."""
+  return _saver_registry.name_lookup(registered_name)[0]
+
+
+def get_restore_function(registered_name):
+  """Returns restore function registered to name."""
+  return _saver_registry.name_lookup(registered_name)[1]
+
+
+def validate_restore_function(trackable, registered_name):
+  """Validates whether the trackable can be restored with the saver.
+
+  When using a checkpoint saved with a registered saver, that same saver must
+  also be registered when loading. The name of that saver is saved to the
+  checkpoint and set in the `registered_name` arg.
+
+  Args:
+    trackable: A `Trackable` object.
+    registered_name: String name of the expected registered saver. This argument
+      should be set using the name saved in a checkpoint.
+
+  Raises:
+    ValueError if the saver could not be found, or if the predicate associated
+      with the saver does not pass.
+  """
+  try:
+    _saver_registry.name_lookup(registered_name)
+  except LookupError:
+    raise ValueError(
+        f"Error when restoring object {trackable} from checkpoint. This "
+        "object was saved using a registered saver named "
+        f"'{registered_name}', but this saver cannot be found in the "
+        "current context.")
+  if not _saver_registry.get_predicate(registered_name)(trackable):
+    raise ValueError(
+        f"Object {trackable} was saved with the registered saver named "
+        f"'{registered_name}'. However, this saver cannot be used to restore the "
+        "object because the predicate does not pass.")
diff --git a/tensorflow/python/saved_model/registration_saving_test.py b/tensorflow/python/saved_model/registration_saving_test.py
index e245219..8d56488 100644
--- a/tensorflow/python/saved_model/registration_saving_test.py
+++ b/tensorflow/python/saved_model/registration_saving_test.py
@@ -14,17 +14,94 @@
 # ==============================================================================
 """Tests saving with registered Trackable classes and checkpoint functions."""
 
+import os
 import tempfile
+
 from absl.testing import parameterized
 
 from google.protobuf import wrappers_pb2
+from tensorflow.python.eager import context
+from tensorflow.python.eager import def_function
 from tensorflow.python.eager import test
+from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import io_ops
+from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.saved_model import load
 from tensorflow.python.saved_model import registration
 from tensorflow.python.saved_model import save
 from tensorflow.python.training.tracking import tracking
+from tensorflow.python.training.tracking import util
+
+
+@registration.register_serializable()
+class Part(resource_variable_ops.ResourceVariable):
+
+  def __init__(self, value):
+    self._init_from_args(value)
+
+  @classmethod
+  def _deserialize_from_proto(cls, **kwargs):
+    return cls([0, 0])
+
+
+@registration.register_serializable()
+class Stack(tracking.AutoTrackable):
+
+  def __init__(self, parts=None):
+    self.parts = parts
+
+  @def_function.function(input_signature=[])
+  def value(self):
+    return array_ops.stack(self.parts)
+
+
+def get_tensor_slices(trackables):
+  tensor_names = []
+  shapes_and_slices = []
+  tensors = []
+  restored_trackables = []
+  for obj_prefix, obj in trackables.items():
+    if isinstance(obj, Part):
+      continue  # only save stacks
+    tensor_names.append(obj_prefix + "/value")
+    shapes_and_slices.append("")
+    x = obj.value()
+    with ops.device("/device:CPU:0"):
+      tensors.append(array_ops.identity(x))
+    restored_trackables.append(obj)
+
+  return tensor_names, shapes_and_slices, tensors, restored_trackables
+
+
+def save_stacks_and_parts(trackables, file_prefix):
+  """Save stack and part objects to a checkpoint shard."""
+  tensor_names, shapes_and_slices, tensors, _ = get_tensor_slices(trackables)
+  io_ops.save_v2(file_prefix, tensor_names, shapes_and_slices, tensors)
+  return file_prefix
+
+
+def restore_stacks_and_parts(trackables, merged_prefix):
+  tensor_names, shapes_and_slices, tensors, restored_trackables = (
+      get_tensor_slices(trackables))
+  dtypes = [t.dtype for t in tensors]
+  restored_tensors = io_ops.restore_v2(merged_prefix, tensor_names,
+                                       shapes_and_slices, dtypes)
+  for trackable, restored_tensor in zip(restored_trackables, restored_tensors):
+    expected_shape = trackable.value().get_shape()
+    restored_tensor = array_ops.reshape(restored_tensor, expected_shape)
+    parts = array_ops.unstack(restored_tensor)
+    for part, restored_part in zip(trackable.parts, parts):
+      part.assign(restored_part)
+
+
+registration.register_checkpoint_saver(
+    name="stacks",
+    predicate=lambda x: isinstance(x, (Stack, Part)),
+    save_fn=save_stacks_and_parts,
+    restore_fn=restore_stacks_and_parts)
 
 
 def cycle(obj, cycles, signatures=None, options=None):
@@ -45,11 +122,10 @@
 @parameterized.named_parameters(
     dict(testcase_name="ReloadOnce", cycles=1),
     dict(testcase_name="ReloadTwice", cycles=2),
-    dict(testcase_name="ReloadThrice", cycles=3)
-)
+    dict(testcase_name="ReloadThrice", cycles=3))
 class SavedModelTest(test.TestCase, parameterized.TestCase):
 
-  def test_save_and_load(self, cycles):
+  def test_registered_serializable(self, cycles):
 
     @registration.register_serializable(name=f"SaveAndLoad{cycles}")
     class Module(tracking.AutoTrackable):
@@ -123,5 +199,53 @@
     self.assertIsInstance(loaded, Module)
     self.assertEqual(5, loaded.v.numpy())
 
+  def test_registered_saver(self, cycles):
+    p1 = Part([1, 4])
+    p2 = Part([2, 5])
+    p3 = Part([3, 6])
+    s = Stack([p1, p2, p3])
+    loaded = cycle(s, cycles)
+    self.assertAllEqual(s.value(), loaded.value())
+
+
+class SingleCycleTest(test.TestCase):
+
+  @test_util.deprecated_graph_mode_only()
+  def test_registered_saver_fails_in_saved_model_graph_mode(self):
+    with context.eager_mode():
+      p1 = Part([1, 4])
+      p2 = Part([2, 5])
+      p3 = Part([3, 6])
+      s = Stack([p1, p2, p3])
+      save_dir = os.path.join(self.get_temp_dir(), "save_dir")
+      save.save(s, save_dir)
+
+    with self.assertRaisesRegex(
+        NotImplementedError,
+        "registered checkpoint saver is not supported in graph mode"):
+      load.load(save_dir)
+
+  def test_registered_saver_checkpoint(self):
+    p1 = Part([1, 4])
+    p2 = Part([2, 5])
+    p3 = Part([3, 6])
+    s = Stack([p1, p2, p3])
+    s2 = Stack([p3, p1, p2])
+
+    expected_value_s = s.value()
+    expected_value_s2 = s2.value()
+
+    ckpt_path = os.path.join(self.get_temp_dir(), "ckpt")
+    util.Checkpoint(s=s, s2=s2).write(ckpt_path)
+
+    del s, s2, p1, p2, p3
+
+    restore_s = Stack([Part([0, 0]) for _ in range(3)])
+    util.Checkpoint(s=restore_s).read(ckpt_path).expect_partial()
+    self.assertAllEqual(expected_value_s, restore_s.value())
+    util.Checkpoint(s2=restore_s).read(ckpt_path).expect_partial()
+    self.assertAllEqual(expected_value_s2, restore_s.value())
+
+
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/python/saved_model/registration_test.py b/tensorflow/python/saved_model/registration_test.py
index d9c794e..dfa5050 100644
--- a/tensorflow/python/saved_model/registration_test.py
+++ b/tensorflow/python/saved_model/registration_test.py
@@ -1,4 +1,4 @@
-# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -54,7 +54,7 @@
   ])
   def test_registration(self, expected_cls, expected_name):
     obj = expected_cls()
-    self.assertEqual(registration.get_registered_name(obj), expected_name)
+    self.assertEqual(registration.get_registered_class_name(obj), expected_name)
     self.assertIs(
         registration.get_registered_class(expected_name), expected_cls)
 
@@ -67,7 +67,7 @@
       pass
 
     no_register = NotRegistered
-    self.assertIsNone(registration.get_registered_name(no_register))
+    self.assertIsNone(registration.get_registered_class_name(no_register))
 
   def test_duplicate_registration(self):
 
@@ -76,12 +76,13 @@
       pass
 
     dup = Duplicate()
-    self.assertEqual(registration.get_registered_name(dup), "Custom.Duplicate")
+    self.assertEqual(
+        registration.get_registered_class_name(dup), "Custom.Duplicate")
     # Registrations with different names are ok.
     registration.register_serializable(package="duplicate")(Duplicate)
     # Registrations are checked in reverse order.
     self.assertEqual(
-        registration.get_registered_name(dup), "duplicate.Duplicate")
+        registration.get_registered_class_name(dup), "duplicate.Duplicate")
     # Both names should resolve to the same class.
     self.assertIs(
         registration.get_registered_class("Custom.Duplicate"), Duplicate)
@@ -96,12 +97,12 @@
 
   def test_register_non_class_fails(self):
     obj = RegisteredClass()
-    with self.assertRaisesRegex(ValueError, "must be a class"):
+    with self.assertRaisesRegex(TypeError, "must be a class"):
       registration.register_serializable()(obj)
 
   def test_register_bad_predicate_fails(self):
-    with self.assertRaisesRegex(ValueError, "must be callable"):
-      registration.register_serializable(predicate=0)
+    with self.assertRaisesRegex(TypeError, "must be callable"):
+      registration.register_serializable(predicate=0)(RegisteredClass)
 
   def test_predicate(self):
 
@@ -118,8 +119,9 @@
     a = Predicate(True)
     b = Predicate(False)
     self.assertEqual(
-        registration.get_registered_name(a), "Custom.RegisterThisOnlyTrue")
-    self.assertIsNone(registration.get_registered_name(b))
+        registration.get_registered_class_name(a),
+        "Custom.RegisterThisOnlyTrue")
+    self.assertIsNone(registration.get_registered_class_name(b))
 
     registration.register_serializable(
         name="RegisterAllPredicate",
@@ -127,9 +129,90 @@
             Predicate)
 
     self.assertEqual(
-        registration.get_registered_name(a), "Custom.RegisterAllPredicate")
+        registration.get_registered_class_name(a),
+        "Custom.RegisterAllPredicate")
     self.assertEqual(
-        registration.get_registered_name(b), "Custom.RegisterAllPredicate")
+        registration.get_registered_class_name(b),
+        "Custom.RegisterAllPredicate")
+
+
+class CheckpointSaverRegistrationTest(test.TestCase):
+
+  def test_invalid_registration(self):
+    with self.assertRaisesRegex(TypeError, "must be string"):
+      registration.register_checkpoint_saver(
+          package=None,
+          name="test",
+          predicate=lambda: None,
+          save_fn=lambda: None,
+          restore_fn=lambda: None)
+    with self.assertRaisesRegex(TypeError, "must be string"):
+      registration.register_checkpoint_saver(
+          name=None,
+          predicate=lambda: None,
+          save_fn=lambda: None,
+          restore_fn=lambda: None)
+    with self.assertRaisesRegex(ValueError,
+                                "Invalid registered checkpoint saver."):
+      registration.register_checkpoint_saver(
+          package="package",
+          name="t/est",
+          predicate=lambda: None,
+          save_fn=lambda: None,
+          restore_fn=lambda: None)
+    with self.assertRaisesRegex(ValueError,
+                                "Invalid registered checkpoint saver."):
+      registration.register_checkpoint_saver(
+          package="package",
+          name="t/est",
+          predicate=lambda: None,
+          save_fn=lambda: None,
+          restore_fn=lambda: None)
+    with self.assertRaisesRegex(
+        TypeError,
+        "The predicate registered to a checkpoint saver must be callable"
+    ):
+      registration.register_checkpoint_saver(
+          name="test",
+          predicate=None,
+          save_fn=lambda: None,
+          restore_fn=lambda: None)
+    with self.assertRaisesRegex(TypeError, "The save_fn must be callable"):
+      registration.register_checkpoint_saver(
+          name="test",
+          predicate=lambda: None,
+          save_fn=None,
+          restore_fn=lambda: None)
+    with self.assertRaisesRegex(TypeError, "The restore_fn must be callable"):
+      registration.register_checkpoint_saver(
+          name="test",
+          predicate=lambda: None,
+          save_fn=lambda: None,
+          restore_fn=None)
+
+  def test_registration(self):
+    registration.register_checkpoint_saver(
+        package="Testing",
+        name="test_predicate",
+        predicate=lambda x: hasattr(x, "check_attr"),
+        save_fn=lambda: "save",
+        restore_fn=lambda: "restore")
+    x = base.Trackable()
+    self.assertIsNone(registration.get_registered_saver_name(x))
+
+    x.check_attr = 1
+    saver_name = registration.get_registered_saver_name(x)
+    self.assertEqual(saver_name, "Testing.test_predicate")
+
+    self.assertEqual(registration.get_save_function(saver_name)(), "save")
+    self.assertEqual(registration.get_restore_function(saver_name)(), "restore")
+
+    registration.validate_restore_function(x, "Testing.test_predicate")
+    with self.assertRaisesRegex(ValueError, "saver cannot be found"):
+      registration.validate_restore_function(x, "Invalid.name")
+    x2 = base.Trackable()
+    with self.assertRaisesRegex(ValueError, "saver cannot be used"):
+      registration.validate_restore_function(x2, "Testing.test_predicate")
 
 
 if __name__ == "__main__":
diff --git a/tensorflow/python/saved_model/save.py b/tensorflow/python/saved_model/save.py
index 42920b6..598ea3e 100644
--- a/tensorflow/python/saved_model/save.py
+++ b/tensorflow/python/saved_model/save.py
@@ -256,8 +256,12 @@
     of the object loaded from the SavedModel. These functions are recorded in
     the `saveable_objects` map in the `SavedObject` proto.
     """
-    checkpoint_factory_map = graph_view.get_checkpoint_factories_and_keys(
-        self.object_names)
+    checkpoint_factory_map, registered_savers = (
+        graph_view.get_checkpoint_factories_and_keys(self.object_names))
+    self._obj_to_registered_saver = object_identity.ObjectIdentityDictionary()
+    for saver_name, trackables in registered_savers.items():
+      for trackable in trackables.values():
+        self._obj_to_registered_saver[trackable] = saver_name
     self._saveable_objects_map = (
         _gen_save_and_restore_functions(checkpoint_factory_map))
 
@@ -373,14 +377,18 @@
         child_proto.node_id = self.node_ids[ref_function]
         child_proto.local_name = local_name
 
-      if node not in self._saveable_objects_map:
-        continue
+      if node in self._saveable_objects_map:
+        assert node not in self._obj_to_registered_saver, (
+            "Objects can't have both SaveableObjects and a registered saver")
 
-      for local_name, (save_fn, restore_fn) in (
-          self._saveable_objects_map[node].items()):
-        saveable_object_proto = object_proto.saveable_objects[local_name]
-        saveable_object_proto.save_function = self.node_ids[save_fn]
-        saveable_object_proto.restore_function = self.node_ids[restore_fn]
+        for local_name, (save_fn, restore_fn) in (
+            self._saveable_objects_map[node].items()):
+          saveable_object_proto = object_proto.saveable_objects[local_name]
+          saveable_object_proto.save_function = self.node_ids[save_fn]
+          saveable_object_proto.restore_function = self.node_ids[restore_fn]
+
+      elif node in self._obj_to_registered_saver:
+        object_proto.registered_saver = self._obj_to_registered_saver[node]
 
   def map_resources(self):
     """Makes new resource handle ops corresponding to existing resource tensors.
@@ -968,11 +976,15 @@
   # gathering from the eager context so Optimizers save the right set of
   # variables, but want any operations associated with the save/restore to be in
   # the exported graph (thus the `to_graph` argument).
-  saver = functional_saver.MultiDeviceSaver(
-      saveable_view.checkpoint_view.frozen_saveable_objects(
+  call_with_mapped_captures = functools.partial(
+      _call_function_with_mapped_captures, resource_map=resource_map)
+  named_saveable_objects, registered_savers = (
+      saveable_view.checkpoint_view.frozen_saveables_and_savers(
           object_map=object_map, to_graph=exported_graph,
-          call_with_mapped_captures=functools.partial(
-              _call_function_with_mapped_captures, resource_map=resource_map)))
+          call_with_mapped_captures=call_with_mapped_captures))
+  saver = functional_saver.MultiDeviceSaver(named_saveable_objects,
+                                            registered_savers,
+                                            call_with_mapped_captures)
 
   with exported_graph.as_default():
     signatures = _generate_signatures(signature_functions, resource_map)
@@ -1151,7 +1163,7 @@
       # pylint:enable=protected-access
     proto.user_object.CopyFrom(registered_type_proto)
 
-  registered_name = registration.get_registered_name(obj)
+  registered_name = registration.get_registered_class_name(obj)
   if registered_name:
     proto.registered_name = registered_name
     serialized_user_proto = obj._serialize_to_proto()  # pylint: disable=protected-access
diff --git a/tensorflow/python/training/saving/BUILD b/tensorflow/python/training/saving/BUILD
index 646d204..180d6c0 100644
--- a/tensorflow/python/training/saving/BUILD
+++ b/tensorflow/python/training/saving/BUILD
@@ -29,6 +29,7 @@
         ":saveable_object",
         ":saveable_object_util",
         "//tensorflow/python/eager:def_function",
+        "//tensorflow/python/saved_model:registration",
     ],
 )
 
diff --git a/tensorflow/python/training/saving/functional_saver.py b/tensorflow/python/training/saving/functional_saver.py
index bf7c7d1..1c6e803 100644
--- a/tensorflow/python/training/saving/functional_saver.py
+++ b/tensorflow/python/training/saving/functional_saver.py
@@ -21,10 +21,12 @@
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_spec
+from tensorflow.python.framework import tensor_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gen_io_ops
 from tensorflow.python.ops import io_ops
 from tensorflow.python.ops import string_ops
+from tensorflow.python.saved_model import registration
 from tensorflow.python.training.saving import checkpoint_options
 from tensorflow.python.training.saving import saveable_hook
 from tensorflow.python.training.saving import saveable_object
@@ -126,6 +128,44 @@
   return gen_io_ops.sharded_filename(filename_tensor, shard, num_shards)
 
 
+def registered_saver_filename(filename_tensor, saver_name):
+  return string_ops.string_join(
+      [filename_tensor, constant_op.constant(f"-{saver_name}")])
+
+
+def _get_mapped_registered_save_fn(fn, trackables, call_with_mapped_captures):
+  """Converts the function to a python or tf.function with a single file arg."""
+  if call_with_mapped_captures is None:
+    def mapped_fn(file_prefix):
+      return fn(trackables=trackables, file_prefix=file_prefix)
+    return mapped_fn
+  else:
+    tf_fn = def_function.function(fn, autograph=False)
+    concrete = tf_fn.get_concrete_function(
+        trackables=trackables,
+        file_prefix=tensor_spec.TensorSpec(shape=(), dtype=dtypes.string))
+    def mapped_fn(file_prefix):
+      return call_with_mapped_captures(concrete, [file_prefix])
+    return mapped_fn
+
+
+def _get_mapped_registered_restore_fn(fn, trackables,
+                                      call_with_mapped_captures):
+  """Converts the function to a python or tf.function with a single file arg."""
+  if call_with_mapped_captures is None:
+    def mapped_fn(merged_prefix):
+      return fn(trackables=trackables, merged_prefix=merged_prefix)
+    return mapped_fn
+  else:
+    tf_fn = def_function.function(fn, autograph=False)
+    concrete = tf_fn.get_concrete_function(
+        trackables=trackables,
+        merged_prefix=tensor_spec.TensorSpec(shape=(), dtype=dtypes.string))
+    def mapped_fn(merged_prefix):
+      return call_with_mapped_captures(concrete, [merged_prefix])
+    return mapped_fn
+
+
 class MultiDeviceSaver(object):
   """Saves checkpoints directly from multiple devices.
 
@@ -134,7 +174,10 @@
   checkpointing are built on top of it.
   """
 
-  def __init__(self, saveable_objects):
+  def __init__(self,
+               saveable_objects,
+               registered_savers=None,
+               call_with_mapped_captures=None):
     """Specify a list of `SaveableObject`s to save and restore.
 
     Args:
@@ -142,6 +185,11 @@
         Objects extending `SaveableObject` will be saved and restored, and
         objects extending `SaveableHook` will be called into at save and
         restore time.
+      registered_savers: A dictionary mapping registered saver names (strings)
+        to a dictionary of named Trackables. The Trackable dictionary keys are
+        string names that uniquely identify the Trackable in the checkpoint.
+      call_with_mapped_captures: (Optional) A function that calls the
+        registered save/restore tf.functions with their captures remapped.
     """
     self._before_save_callbacks = []
     self._after_restore_callbacks = []
@@ -168,6 +216,17 @@
         device: _SingleDeviceSaver(saveables)
         for device, saveables in saveables_by_device.items()}
 
+    self._registered_savers = {}
+    if registered_savers:
+      for registered_name, trackables in registered_savers.items():
+        save_fn = _get_mapped_registered_save_fn(
+            registration.get_save_function(registered_name),
+            trackables, call_with_mapped_captures)
+        restore_fn = _get_mapped_registered_restore_fn(
+            registration.get_restore_function(registered_name),
+            trackables, call_with_mapped_captures)
+        self._registered_savers[registered_name] = (save_fn, restore_fn)
+
   def to_proto(self):
     """Serializes to a SaverDef referencing the current graph."""
     filename_tensor = array_ops.placeholder(
@@ -247,11 +306,29 @@
           constant_op.constant("_temp/part"))
       tmp_checkpoint_prefix = string_ops.string_join(
           [file_prefix, sharded_suffix])
+      registered_paths = {
+          saver_name: registered_saver_filename(file_prefix, saver_name)
+          for saver_name in self._registered_savers
+      }
 
     def save_fn():
+      saved_prefixes = []
+      # Save with the registered savers.
+      for saver_name, (save_fn, _) in self._registered_savers.items():
+        maybe_saved_prefixes = save_fn(registered_paths[saver_name])
+        if maybe_saved_prefixes is not None:
+          flattened_saved_prefixes = nest.flatten(maybe_saved_prefixes)
+          if not all(
+              tensor_util.is_tf_type(x) and x.dtype == dtypes.string
+              for x in flattened_saved_prefixes):
+            raise ValueError(
+                "Registered saver can only return `None` or "
+                f"string type tensors. Got {maybe_saved_prefixes}.")
+          saved_prefixes.extend(flattened_saved_prefixes)
+
+      # (Default saver) Save with single device savers.
       num_shards = len(self._single_device_savers)
       sharded_saves = []
-      sharded_prefixes = []
       num_shards_tensor = constant_op.constant(num_shards, name="num_shards")
       last_device = None
       for shard, (device, saver) in enumerate(
@@ -260,7 +337,7 @@
         with ops.device(saveable_object_util.set_cpu0(device)):
           shard_prefix = sharded_filename(tmp_checkpoint_prefix, shard,
                                           num_shards_tensor)
-        sharded_prefixes.append(shard_prefix)
+        saved_prefixes.append(shard_prefix)
         with ops.device(device):
           # _SingleDeviceSaver will use the CPU device when necessary, but
           # initial read operations should be placed on the SaveableObject's
@@ -278,7 +355,7 @@
           # merged, attempts to delete the temporary directory,
           # "<user-fed prefix>_temp".
           return gen_io_ops.merge_v2_checkpoints(
-              sharded_prefixes, file_prefix, delete_old_dirs=True)
+              saved_prefixes, file_prefix, delete_old_dirs=True)
 
     # Since this will causes a function re-trace on each save, limit this to the
     # cases where it is needed: eager and when there are multiple tasks/single
@@ -315,7 +392,8 @@
       for device, saver in sorted(self._single_device_savers.items()):
         with ops.device(device):
           restore_ops.update(saver.restore(file_prefix, options))
-
+      for _, (_, restore_fn) in self._registered_savers.items():
+        restore_fn(file_prefix)
       return restore_ops
 
     # Since this will causes a function re-trace on each restore, limit this to
diff --git a/tensorflow/python/training/tracking/BUILD b/tensorflow/python/training/tracking/BUILD
index bb5fd05..0510dc9 100644
--- a/tensorflow/python/training/tracking/BUILD
+++ b/tensorflow/python/training/tracking/BUILD
@@ -28,6 +28,7 @@
         "//tensorflow/python:platform",
         "//tensorflow/python:util",
         "//tensorflow/python/eager:context",
+        "//tensorflow/python/saved_model:registration",
         "//tensorflow/python/training/saving:saveable_object",
         "@six_archive//:six",
     ],
diff --git a/tensorflow/python/training/tracking/base.py b/tensorflow/python/training/tracking/base.py
index b7edf5d..cbbe9e8 100644
--- a/tensorflow/python/training/tracking/base.py
+++ b/tensorflow/python/training/tracking/base.py
@@ -26,6 +26,7 @@
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import gen_io_ops as io_ops
 from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.saved_model import registration
 from tensorflow.python.training.saving import saveable_object
 from tensorflow.python.util import tf_contextlib
 from tensorflow.python.util import tf_decorator
@@ -441,6 +442,9 @@
       A list of operations when graph building, or an empty list when executing
       eagerly.
     """
+    if self._has_registered_saver():
+      raise ValueError("Unable to run individual checkpoint restore for objects"
+                       " with registered savers.")
     (restore_ops, tensor_saveables,
      python_saveables) = self.gather_ops_or_named_saveables()
     restore_ops.extend(
@@ -477,6 +481,16 @@
         return self._checkpoint.shape_map[serialized_tensor.checkpoint_key]
     return None
 
+  def _has_registered_saver(self):
+    return bool(self.object_proto.registered_saver.name)
+
+  def get_registered_saver_name(self):
+    if self._has_registered_saver():
+      saver_name = self.object_proto.registered_saver.name
+      registration.validate_restore_function(self.trackable, saver_name)
+      return saver_name
+    return None
+
   def _queue_slot_variable_for_restoration(self, optimizer_object, variable,
                                            slot_variable_id, slot_name):
     """Adds a slot variable onto the restoration queue.
@@ -999,16 +1013,27 @@
     restore_ops = []
     tensor_saveables = {}
     python_saveables = []
+    registered_savers = collections.defaultdict(dict)
     while visit_queue:
       current_position = visit_queue.popleft()
-      new_restore_ops, new_tensor_saveables, new_python_saveables = (
-          current_position.trackable  # pylint: disable=protected-access
-          ._single_restoration_from_checkpoint_position(
-              checkpoint_position=current_position,
-              visit_queue=visit_queue))
-      restore_ops.extend(new_restore_ops)
-      tensor_saveables.update(new_tensor_saveables)
-      python_saveables.extend(new_python_saveables)
+      trackable = current_position.trackable
+
+      # Restore using the ops defined in a Saveable or registered function.
+      registered_saver = current_position.get_registered_saver_name()
+      if registered_saver:
+        object_name = (
+            current_position.object_proto.registered_saver.object_name)
+        registered_savers[registered_saver][object_name] = trackable
+        trackable._self_update_uid = current_position.checkpoint.restore_uid  # pylint: disable=protected-access
+      else:
+        new_restore_ops, new_tensor_saveables, new_python_saveables = (
+            trackable._single_restoration_from_checkpoint_position(  # pylint: disable=protected-access
+                current_position))
+        restore_ops.extend(new_restore_ops)
+        tensor_saveables.update(new_tensor_saveables)
+        python_saveables.extend(new_python_saveables)
+
+      _queue_children_for_restoration(current_position, visit_queue)
 
     # Restore slot variables first.
     #
@@ -1032,11 +1057,11 @@
 
     restore_ops.extend(
         current_position.checkpoint.restore_saveables(tensor_saveables,
-                                                      python_saveables))
+                                                      python_saveables,
+                                                      registered_savers))
     return restore_ops
 
-  def _single_restoration_from_checkpoint_position(self, checkpoint_position,
-                                                   visit_queue):
+  def _single_restoration_from_checkpoint_position(self, checkpoint_position):
     """Restore this object, and either queue its dependencies or defer them."""
     self._maybe_initialize_trackable()
     checkpoint = checkpoint_position.checkpoint
@@ -1051,22 +1076,6 @@
       restore_ops = ()
       tensor_saveables = {}
       python_saveables = ()
-    for child in checkpoint_position.object_proto.children:
-      child_position = CheckpointPosition(
-          checkpoint=checkpoint, proto_id=child.node_id)
-      local_object = self._lookup_dependency(child.local_name)
-      if local_object is None:
-        # We don't yet have a dependency registered with this name. Save it
-        # in case we do.
-        self._deferred_dependencies.setdefault(child.local_name,
-                                               []).append(child_position)
-      else:
-        if child_position.bind_object(trackable=local_object):
-          # This object's correspondence is new, so dependencies need to be
-          # visited. Delay doing it so that we get a breadth-first dependency
-          # resolution order (shallowest paths first). The caller is responsible
-          # for emptying visit_queue.
-          visit_queue.append(child_position)
     return restore_ops, tensor_saveables, python_saveables
 
   def _gather_saveables_for_checkpoint(self):
@@ -1320,3 +1329,26 @@
       returned must also be in the `_checkpoint_dependencies` dict.
     """
     return {}
+
+
+def _queue_children_for_restoration(checkpoint_position, visit_queue):
+  """Queues the restoration of trackable's children or defers them."""
+  # pylint: disable=protected-access
+  trackable = checkpoint_position.trackable
+  checkpoint = checkpoint_position.checkpoint
+  for child in checkpoint_position.object_proto.children:
+    child_position = CheckpointPosition(
+        checkpoint=checkpoint, proto_id=child.node_id)
+    local_object = trackable._lookup_dependency(child.local_name)
+    if local_object is None:
+      # We don't yet have a dependency registered with this name. Save it
+      # in case we do.
+      trackable._deferred_dependencies.setdefault(child.local_name,
+                                                  []).append(child_position)
+    else:
+      if child_position.bind_object(trackable=local_object):
+        # This object's correspondence is new, so dependencies need to be
+        # visited. Delay doing it so that we get a breadth-first dependency
+        # resolution order (shallowest paths first). The caller is responsible
+        # for emptying visit_queue.
+        visit_queue.append(child_position)
diff --git a/tensorflow/python/training/tracking/graph_view.py b/tensorflow/python/training/tracking/graph_view.py
index f51ac86..a94d41d 100644
--- a/tensorflow/python/training/tracking/graph_view.py
+++ b/tensorflow/python/training/tracking/graph_view.py
@@ -21,6 +21,7 @@
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
+from tensorflow.python.saved_model import registration
 from tensorflow.python.training import optimizer as optimizer_v1
 from tensorflow.python.training.saving import saveable_object as saveable_object_lib
 from tensorflow.python.training.saving import saveable_object_util
@@ -141,27 +142,60 @@
   return slot_variables
 
 
-def get_checkpoint_factories_and_keys(object_names):
+def _get_mapped_trackable(trackable, object_map):
+  """Returns the mapped trackable if possible, otherwise returns trackable."""
+  if object_map is None:
+    return trackable
+  else:
+    return object_map.get(trackable, trackable)
+
+
+def get_checkpoint_factories_and_keys(object_names, object_map=None):
   """Gets a map of saveable factories and corresponding checkpoint keys.
 
   Args:
     object_names: a dictionary that maps `Trackable` objects to auto-generated
       string names.
+    object_map: a dictionary mapping `Trackable` to copied `Trackable` objects.
+      The copied objects are generated from `Trackable._map_resources()` which
+      copies the object into another graph. Generally only resource objects
+      (e.g. Variables, Tables) will be in this map.
   Returns:
-    A dictionary mapping Trackables -> a list of _CheckpointFactoryData.
+    A tuple of (
+      Dictionary mapping trackable -> list of _CheckpointFactoryData,
+      Dictionary mapping registered saver name -> {object name -> trackable})
   """
   checkpoint_factory_map = object_identity.ObjectIdentityDictionary()
+  registered_savers = collections.defaultdict(dict)
   for trackable, object_name in object_names.items():
-    checkpoint_factory_map[trackable] = []
-    for name, saveable_factory in (
-        trackable._gather_saveables_for_checkpoint().items()):  # pylint: disable=protected-access
-      checkpoint_key = "%s/%s/%s" % (
-          object_name, _OBJECT_ATTRIBUTES_NAME, _escape_local_name(name))
-      checkpoint_factory_map[trackable].append(_CheckpointFactoryData(
-          factory=saveable_factory,
-          name=name,
-          checkpoint_key=checkpoint_key))
-  return checkpoint_factory_map
+    # object_to_save is only used to retrieve the saving functionality. For keys
+    # and other data, use the original `trackable`.
+    object_to_save = _get_mapped_trackable(trackable, object_map)
+
+    saver_name = registration.get_registered_saver_name(object_to_save)
+    if saver_name:
+      registered_savers[saver_name][object_name] = trackable
+    else:
+      checkpoint_factory_map[trackable] = []
+      for name, saveable_factory in (
+          object_to_save._gather_saveables_for_checkpoint().items()):  # pylint: disable=protected-access
+        checkpoint_key = "%s/%s/%s" % (
+            object_name, _OBJECT_ATTRIBUTES_NAME, _escape_local_name(name))
+        checkpoint_factory_map[trackable].append(_CheckpointFactoryData(
+            factory=saveable_factory,
+            name=name,
+            checkpoint_key=checkpoint_key))
+  return checkpoint_factory_map, registered_savers
+
+
+def _add_attributes_to_object_graph_for_registered_savers(
+    registered_savers, object_graph_proto, node_ids):
+  """Fills the object graph proto with data about the registered savers."""
+  for saver_name, trackables in registered_savers.items():
+    for object_name, trackable in trackables.items():
+      object_proto = object_graph_proto.nodes[node_ids[trackable]]
+      object_proto.registered_saver.name = saver_name
+      object_proto.registered_saver.object_name = object_name
 
 
 @tf_export("__internal__.tracking.ObjectGraphView", v1=[])
@@ -265,6 +299,25 @@
   def _add_attributes_to_object_graph(
       self, trackable_objects, object_graph_proto, node_ids, object_names,
       object_map, call_with_mapped_captures):
+    """Create saveables/savers and corresponding protos in the object graph."""
+    # The TrackableObject protos were already created in `object_graph_proto`;
+    # the `_add_attributes_to_object_graph_for_*` methods below fill them in.
+    for checkpoint_id, (trackable, unused_object_proto) in enumerate(
+        zip(trackable_objects, object_graph_proto.nodes)):
+      assert node_ids[trackable] == checkpoint_id
+    checkpoint_factory_map, registered_savers = (
+        get_checkpoint_factories_and_keys(object_names, object_map))
+    _add_attributes_to_object_graph_for_registered_savers(
+        registered_savers, object_graph_proto, node_ids)
+    named_saveable_objects, feed_additions = (
+        self._add_attributes_to_object_graph_for_saveable_objects(
+            checkpoint_factory_map, object_graph_proto, node_ids, object_map,
+            call_with_mapped_captures))
+    return named_saveable_objects, feed_additions, registered_savers
+
+  def _add_attributes_to_object_graph_for_saveable_objects(
+      self, checkpoint_factory_map, object_graph_proto, node_ids, object_map,
+      call_with_mapped_captures):
     """Create SaveableObjects and corresponding SerializedTensor protos."""
     named_saveable_objects = []
     if self._saveables_cache is None:
@@ -276,27 +329,15 @@
       # functions computing volatile Python state to be saved with the
       # checkpoint.
       feed_additions = {}
-    if object_map is None:
-      mapped_object_names = object_names
-    else:
-      mapped_object_names = object_identity.ObjectIdentityDictionary()
-      for trackable, name in object_names.items():
-        mapped_object_names[object_map.get(trackable, trackable)] = name
-    checkpoint_factory_map = get_checkpoint_factories_and_keys(
-        mapped_object_names)
-    for checkpoint_id, (trackable, object_proto) in enumerate(
-        zip(trackable_objects, object_graph_proto.nodes)):
-      assert node_ids[trackable] == checkpoint_id
-      if object_map is None:
-        object_to_save = trackable
-      else:
-        object_to_save = object_map.get(trackable, trackable)
+    for trackable, factory_data_list in checkpoint_factory_map.items():
+      object_proto = object_graph_proto.nodes[node_ids[trackable]]
       if self._saveables_cache is not None:
+        object_to_save = _get_mapped_trackable(trackable, object_map)
         cached_attributes = self._saveables_cache.setdefault(object_to_save, {})
       else:
         cached_attributes = None
 
-      for factory_data in checkpoint_factory_map[object_to_save]:
+      for factory_data in factory_data_list:
         attribute = object_proto.attributes.add()
         attribute.name = name = factory_data.name
         attribute.checkpoint_key = key = factory_data.checkpoint_key
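For reference, a rough sketch of the `checkpoint_factory_map` shape the rewritten loop consumes, with `_CheckpointFactoryData` approximated by a namedtuple; the attribute name and checkpoint key below are illustrative, not taken from this change:

import collections

# Stand-in for _CheckpointFactoryData: (factory, name, checkpoint_key).
FactoryData = collections.namedtuple(
    "FactoryData", ["factory", "name", "checkpoint_key"])

v = object()  # stands in for a tf.Variable or other saveable Trackable
checkpoint_factory_map = {
    v: [FactoryData(factory=lambda name: None,  # would build a SaveableObject
                    name="VARIABLE_VALUE",
                    checkpoint_key="model/v/.ATTRIBUTES/VARIABLE_VALUE")],
}

for trackable, factory_data_list in checkpoint_factory_map.items():
  for factory_data in factory_data_list:
    print(factory_data.name, factory_data.checkpoint_key)
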
@@ -410,7 +451,7 @@
         trackable_objects=trackable_objects,
         node_ids=node_ids,
         slot_variables=slot_variables)
-    named_saveable_objects, feed_additions = (
+    named_saveable_objects, feed_additions, registered_savers = (
         self._add_attributes_to_object_graph(
             trackable_objects=trackable_objects,
             object_graph_proto=object_graph_proto,
@@ -418,7 +459,8 @@
             object_names=object_names,
             object_map=object_map,
             call_with_mapped_captures=call_with_mapped_captures))
-    return named_saveable_objects, object_graph_proto, feed_additions
+    return (named_saveable_objects, object_graph_proto, feed_additions,
+            registered_savers)
 
   def serialize_object_graph(self):
     """Determine checkpoint keys for variables and build a serialized graph.
@@ -441,6 +483,12 @@
     Raises:
       ValueError: If there are invalid characters in an optimizer's slot names.
     """
+    named_saveable_objects, object_graph_proto, feed_additions, _ = (
+        self.serialize_object_graph_with_registered_savers())
+    return named_saveable_objects, object_graph_proto, feed_additions
+
+  def serialize_object_graph_with_registered_savers(self):
+    """Determine checkpoint keys for variables and build a serialized graph."""
     trackable_objects, node_paths = self._breadth_first_traversal()
     return self._serialize_gathered_objects(
         trackable_objects, node_paths)
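Callers that need the registered savers can move to the new method; the existing three-tuple API stays as a thin wrapper around it. A sketch of the two call sites, assuming the class above lives in tensorflow/python/training/tracking/graph_view.py and that `root` is an ordinary Trackable:

import tensorflow as tf
from tensorflow.python.training.tracking import graph_view as graph_view_lib

root = tf.Module()
root.v = tf.Variable(1.0)
view = graph_view_lib.ObjectGraphView(root)

# Existing API: registered savers are dropped from the return value.
saveables, graph_proto, feed_additions = view.serialize_object_graph()

# New API: also returns the saver name -> object name -> Trackable mapping.
saveables, graph_proto, feed_additions, registered_savers = (
    view.serialize_object_graph_with_registered_savers())
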
@@ -448,17 +496,23 @@
   def frozen_saveable_objects(self, object_map=None, to_graph=None,
                               call_with_mapped_captures=None):
     """Creates SaveableObjects with the current object graph frozen."""
+    return self.frozen_saveables_and_savers(object_map, to_graph,
+                                            call_with_mapped_captures)[0]
+
+  def frozen_saveables_and_savers(self, object_map=None, to_graph=None,
+                                  call_with_mapped_captures=None):
+    """Generates SaveableObjects and registered savers in the frozen graph."""
     trackable_objects, node_paths = self._breadth_first_traversal()
     if to_graph:
       target_context = to_graph.as_default
     else:
       target_context = ops.NullContextmanager
     with target_context():
-      named_saveable_objects, graph_proto, _ = self._serialize_gathered_objects(
-          trackable_objects,
-          node_paths,
-          object_map,
-          call_with_mapped_captures)
+      named_saveable_objects, graph_proto, _, registered_savers = (
+          self._serialize_gathered_objects(trackable_objects,
+                                           node_paths,
+                                           object_map,
+                                           call_with_mapped_captures))
       with ops.device("/cpu:0"):
         object_graph_tensor = constant_op.constant(
             graph_proto.SerializeToString(), dtype=dtypes.string)
@@ -466,7 +520,7 @@
           base.NoRestoreSaveable(
               tensor=object_graph_tensor,
               name=base.OBJECT_GRAPH_PROTO_KEY))
-    return named_saveable_objects
+    return named_saveable_objects, registered_savers
 
   def objects_ids_and_slot_variables_and_paths(self):
     """Traverse the object graph and list all accessible objects.
diff --git a/tensorflow/python/training/tracking/util.py b/tensorflow/python/training/tracking/util.py
index 74e6007..fbf745c 100644
--- a/tensorflow/python/training/tracking/util.py
+++ b/tensorflow/python/training/tracking/util.py
@@ -302,13 +302,17 @@
     if self.new_restore_ops_callback:
       self.new_restore_ops_callback(new_ops)  # pylint: disable=not-callable
 
-  def restore_saveables(self, tensor_saveables, python_saveables):
+  def restore_saveables(self,
+                        tensor_saveables,
+                        python_saveables,
+                        registered_savers=None):
     """Run or build restore operations for SaveableObjects.
 
     Args:
       tensor_saveables: `SaveableObject`s which correspond to Tensors.
       python_saveables: `PythonStateSaveable`s which correspond to Python
         values.
+      registered_savers: a dict mapping registered saver names to a dict of
+        object names to Trackables, or None if there are no registered savers.
 
     Returns:
       When graph building, a list of restore operations, either cached or newly
@@ -322,7 +326,7 @@
           [self.reader.get_tensor(name) for name in spec_names])
 
     # If we have new SaveableObjects, extract and cache restore ops.
-    if tensor_saveables:
+    if tensor_saveables or registered_savers:
       validated_saveables = saveable_object_util.validate_and_slice_inputs(
           tensor_saveables)
       validated_names = set(saveable.name for saveable in validated_saveables)
@@ -331,7 +335,8 @@
             "Saveable keys changed when validating. Got back "
             f"{tensor_saveables.keys()}, was expecting {validated_names}")
       new_restore_ops = functional_saver.MultiDeviceSaver(
-          validated_saveables).restore(self.save_path_tensor, self.options)
+          validated_saveables,
+          registered_savers).restore(self.save_path_tensor, self.options)
       if not context.executing_eagerly():
         for name, restore_op in sorted(new_restore_ops.items()):
           restore_ops.append(restore_op)
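On the restore side the same optional mapping is threaded through: `MultiDeviceSaver` now takes `registered_savers` as a second argument for both writing and reading a checkpoint. A schematic, assuming `functional_saver` refers to tensorflow/python/training/saving/functional_saver.py as imported in util.py; the checkpoint prefix is a placeholder:

import tensorflow as tf
from tensorflow.python.training.saving import functional_saver
from tensorflow.python.training.tracking import graph_view as graph_view_lib

root = tf.Module()
root.v = tf.Variable(2.0)
saveables, _, _, registered_savers = (
    graph_view_lib.ObjectGraphView(root)
    .serialize_object_graph_with_registered_savers())

# Both directions accept the same (saveables, registered_savers) pair.
saver = functional_saver.MultiDeviceSaver(saveables, registered_savers)
saver.save(tf.constant("/tmp/tf_ckpt_sketch"))
saver.restore(tf.constant("/tmp/tf_ckpt_sketch"))
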
@@ -1142,8 +1147,8 @@
 
   def _gather_saveables(self, object_graph_tensor=None):
     """Wraps _serialize_object_graph to include the object graph proto."""
-    (named_saveable_objects, graph_proto,
-     feed_additions) = self._graph_view.serialize_object_graph()
+    named_saveable_objects, graph_proto, feed_additions, registered_savers = (
+        self._graph_view.serialize_object_graph_with_registered_savers())
     if object_graph_tensor is None:
       with ops.device("/cpu:0"):
         object_graph_tensor = constant_op.constant(
@@ -1155,7 +1160,8 @@
     named_saveable_objects.append(
         base.NoRestoreSaveable(
             tensor=object_graph_tensor, name=base.OBJECT_GRAPH_PROTO_KEY))
-    return named_saveable_objects, graph_proto, feed_additions
+    return (named_saveable_objects, graph_proto, feed_additions,
+            registered_savers)
 
   def _save_cached_when_graph_building(self,
                                        file_prefix,
@@ -1175,8 +1181,8 @@
       current object graph and any Python state to be saved in the
       checkpoint. When executing eagerly only the first argument is meaningful.
     """
-    (named_saveable_objects, graph_proto,
-     feed_additions) = self._gather_saveables(
+    (named_saveable_objects, graph_proto, feed_additions,
+     registered_savers) = self._gather_saveables(
          object_graph_tensor=object_graph_tensor)
     if (self._last_save_object_graph != graph_proto
         # When executing eagerly, we need to re-create SaveableObjects each time
@@ -1184,7 +1190,8 @@
         # constructors. That means the Saver needs to be copied with a new
         # var_list.
         or context.executing_eagerly() or ops.inside_function()):
-      saver = functional_saver.MultiDeviceSaver(named_saveable_objects)
+      saver = functional_saver.MultiDeviceSaver(named_saveable_objects,
+                                                registered_savers)
       save_op = saver.save(file_prefix, options=options)
       with ops.device("/cpu:0"):
         with ops.control_dependencies([save_op]):
@@ -1420,9 +1427,10 @@
     A saver which saves object-based checkpoints for the object graph frozen at
     the time `frozen_saver` was called.
   """
-  named_saveable_objects = graph_view_lib.ObjectGraphView(
-      root_trackable).frozen_saveable_objects()
-  return functional_saver.MultiDeviceSaver(named_saveable_objects)
+  named_saveable_objects, registered_savers = graph_view_lib.ObjectGraphView(
+      root_trackable).frozen_saveables_and_savers()
+  return functional_saver.MultiDeviceSaver(named_saveable_objects,
+                                           registered_savers)
 
 
 def saver_with_op_caching(obj, attached_dependencies=None):
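End to end, the frozen-saver helper modified above now returns a `MultiDeviceSaver` built from both the frozen SaveableObjects and any registered savers. A short sketch, assuming the helper is the module-level `frozen_saver` in tensorflow/python/training/tracking/util.py:

import tensorflow as tf
from tensorflow.python.training.tracking import util as tracking_util

root = tf.Module()
root.v = tf.Variable(3.0)

# Internally: MultiDeviceSaver(named_saveable_objects, registered_savers).
saver = tracking_util.frozen_saver(root)
saver.save(tf.constant("/tmp/tf_frozen_ckpt_sketch"))
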
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.__internal__.tracking.-object-graph-view.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.__internal__.tracking.-object-graph-view.pbtxt
index 9121914..cf2f65f 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.__internal__.tracking.-object-graph-view.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.__internal__.tracking.-object-graph-view.pbtxt
@@ -23,6 +23,10 @@
     argspec: "args=[\'self\', \'object_map\', \'to_graph\', \'call_with_mapped_captures\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
   }
   member_method {
+    name: "frozen_saveables_and_savers"
+    argspec: "args=[\'self\', \'object_map\', \'to_graph\', \'call_with_mapped_captures\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+  }
+  member_method {
     name: "list_children"
     argspec: "args=[\'self\', \'obj\'], varargs=None, keywords=None, defaults=None"
   }
@@ -42,4 +46,8 @@
     name: "serialize_object_graph"
     argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "serialize_object_graph_with_registered_savers"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
 }