Add checkpoint saver registration.
This is a currently-internal API that allows users to register functions for reading and writing checkpoints directly from file.
PiperOrigin-RevId: 405506664
Change-Id: I60c7b6d47f3b8d8b0fa8ff8b7c383d8112ba2301
diff --git a/tensorflow/core/protobuf/saved_object_graph.proto b/tensorflow/core/protobuf/saved_object_graph.proto
index 85bf826..bd9303c 100644
--- a/tensorflow/core/protobuf/saved_object_graph.proto
+++ b/tensorflow/core/protobuf/saved_object_graph.proto
@@ -70,12 +70,26 @@
CapturedTensor captured_tensor = 12;
}
+ // Stores the functions used to save and restore this object. At most one of
+ // `saveable_objects` or `registered_saver` is defined for each SavedObject.
+ // See the comment below for the difference between SaveableObject and
+ // registered savers.
map<string, SaveableObject> saveable_objects = 11;
// The fields below are filled when the user serializes a registered Trackable
- // class. Registered classes may save additional metadata and supersede the
- // default loading process where nodes are recreated from the proto.
+ // class or an object with a registered saver function.
//
+ // Registered classes may save additional metadata and supersede the
+ // default loading process where nodes are recreated from the proto.
+ // If the registered class cannot be found, then the object will load as
+ // one of the default trackable objects: AutoTrackable (a class similar to
+ // tf.Module), tf.function, or tf.Variable.
+ //
+ // Unlike SaveableObjects, which store the functions for saving and restoring
+ // from tensors, registered savers allow Trackables to write checkpoint shards
+ // directly (e.g. for performance or coordination reasons).
+ // *All registered savers must be available when loading the SavedModel.*
+
// The name of the registered class of the form "{package}.{class_name}".
// This field is used to search for the registered class at loading time.
string registered_name = 13;
@@ -83,6 +97,10 @@
// the registered classes's _deserialize_from_proto method when this object is
// loaded from the SavedModel.
google.protobuf.Any serialized_user_proto = 14;
+
+ // String name of the registered saver. At most one of `saveable_objects` or
+ // `registered_saver` is defined for each SavedObject.
+ string registered_saver = 16;
}
// A SavedUserObject is an object (in the object-oriented language of the
@@ -226,6 +244,7 @@
message SaveableObject {
// Node ids of concrete functions for saving and loading from a checkpoint.
+ // These functions save and restore directly from tensors.
int32 save_function = 2;
int32 restore_function = 3;
}
diff --git a/tensorflow/core/protobuf/trackable_object_graph.proto b/tensorflow/core/protobuf/trackable_object_graph.proto
index 4be996b..4ccf19f 100644
--- a/tensorflow/core/protobuf/trackable_object_graph.proto
+++ b/tensorflow/core/protobuf/trackable_object_graph.proto
@@ -54,7 +54,19 @@
repeated SerializedTensor attributes = 2;
// Slot variables owned by this object.
repeated SlotVariableReference slot_variables = 3;
+
+ // The registered saver used to save this object. If this saver is not
+ // present when loading the checkpoint, then loading will fail.
+ RegisteredSaver registered_saver = 4;
}
repeated TrackableObject nodes = 1;
}
+
+message RegisteredSaver {
+ // The name of the registered saver/restore function.
+ string name = 1;
+
+ // Unique auto-generated name of the object.
+ string object_name = 2;
+}
diff --git a/tensorflow/python/saved_model/BUILD b/tensorflow/python/saved_model/BUILD
index dfcde1b..9b8e404 100644
--- a/tensorflow/python/saved_model/BUILD
+++ b/tensorflow/python/saved_model/BUILD
@@ -551,9 +551,12 @@
srcs = ["registration_saving_test.py"],
deps = [
":registration",
+ "//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_ops",
+ "//tensorflow/python:io_ops",
"//tensorflow/python:resource_variable_ops",
+ "//tensorflow/python/eager:def_function",
"//tensorflow/python/eager:test",
"//tensorflow/python/training/tracking",
"@absl_py//absl/testing:parameterized",
diff --git a/tensorflow/python/saved_model/load.py b/tensorflow/python/saved_model/load.py
index 3a20b27..2b0755e 100644
--- a/tensorflow/python/saved_model/load.py
+++ b/tensorflow/python/saved_model/load.py
@@ -550,6 +550,13 @@
for object_id, obj in dict(checkpoint.object_by_proto_id).items():
position = base.CheckpointPosition(checkpoint=checkpoint,
proto_id=object_id)
+ registered_saver = position.get_registered_saver_name()
+ if registered_saver:
+ raise NotImplementedError(
+ "Loading a SavedModel that uses registered checkpoint saver is "
+ f"not supported in graph mode. The loaded object {obj} uses the "
+ f"saver registered with the name {registered_saver}.")
+
restore_ops = position.restore_ops()
if restore_ops:
if resource_variable_ops.is_resource_variable(obj):
diff --git a/tensorflow/python/saved_model/registration.py b/tensorflow/python/saved_model/registration.py
index cd5d51a..2c61f74 100644
--- a/tensorflow/python/saved_model/registration.py
+++ b/tensorflow/python/saved_model/registration.py
@@ -16,31 +16,124 @@
revived_types registration will be migrated to this infrastructure.
"""
+import collections
+import re
from tensorflow.python.util import tf_inspect
-_CLASS_REGISTRY = {} # string registered name -> (predicate, class)
-_REGISTERED_NAMES = []
+
+# Only allow valid file/directory characters
+_VALID_REGISTERED_NAME = re.compile(r"^[a-zA-Z0-9._-]+$")
-def get_registered_name(obj):
- for name in reversed(_REGISTERED_NAMES):
- predicate, cls = _CLASS_REGISTRY[name]
- if not predicate and type(obj) == cls: # pylint: disable=unidiomatic-typecheck
- return name
- if predicate and predicate(obj):
- return name
- return None
+class _PredicateRegistry(object):
+ """Registry with predicate-based lookup.
+
+ See the documentation for `register_checkpoint_saver` and
+ `register_serializable` for reasons why predicates are required over a
+ class-based registry.
+
+ Since this class is used for global registries, each object must be registered
+ to unique names (an error is raised if there are naming conflicts). The lookup
+ searches the predicates in reverse order, so that later-registered predicates
+ are executed first.
+ """
+ __slots__ = ("_registry_name", "_registered_map", "_registered_predicates",
+ "_registered_names")
+
+ def __init__(self, name):
+ self._registry_name = name
+ # Maps registered name -> object
+ self._registered_map = {}
+ # Maps registered name -> predicate
+ self._registered_predicates = {}
+ # Stores names in the order of registration
+ self._registered_names = []
+
+ @property
+ def name(self):
+ return self._registry_name
+
+ def register(self, package, name, predicate, candidate):
+ """Registers a candidate object under the package, name and predicate."""
+ if not isinstance(package, str) or not isinstance(name, str):
+ raise TypeError(
+ f"The package and name registered to a {self.name} must be strings, "
+ f"got: package={type(package)}, name={type(name)}")
+ if not callable(predicate):
+ raise TypeError(
+ f"The predicate registered to a {self.name} must be callable, "
+ f"got: {type(predicate)}")
+ registered_name = package + "." + name
+ if not _VALID_REGISTERED_NAME.match(registered_name):
+ raise ValueError(
+ f"Invalid registered {self.name}. Please check that the package and "
+ f"name follow the regex '{_VALID_REGISTERED_NAME.pattern}': "
+ f"(package='{package}', name='{name}')")
+ if registered_name in self._registered_map:
+ raise ValueError(
+ f"The name '{registered_name}' has already been registered to a "
+ f"{self.name}. Found: {self._registered_map[registered_name]}")
+
+ self._registered_map[registered_name] = candidate
+ self._registered_predicates[registered_name] = predicate
+ self._registered_names.append(registered_name)
+
+ def lookup(self, obj):
+ """Looks up the registered object using the predicate.
+
+ Args:
+ obj: Object to pass to each of the registered predicates to look up the
+ registered object.
+ Returns:
+ The object registered with the first passing predicate.
+ Raises:
+ LookupError if the object does not match any of the predicate functions.
+ """
+ return self._registered_map[self.get_registered_name(obj)]
+
+ def name_lookup(self, registered_name):
+ """Looks up the registered object using the registered name."""
+ try:
+ return self._registered_map[registered_name]
+ except KeyError:
+ raise LookupError(f"The {self.name} registry does not have name "
+ f"'{registered_name}' registered.")
+
+ def get_registered_name(self, obj):
+ for registered_name in reversed(self._registered_names):
+ predicate = self._registered_predicates[registered_name]
+ if predicate(obj):
+ return registered_name
+ raise LookupError(f"Could not find matching {self.name} for {type(obj)}.")
+
+ def get_predicate(self, registered_name):
+ try:
+ return self._registered_predicates[registered_name]
+ except KeyError:
+ raise LookupError(f"The {self.name} registry does not have name "
+ f"'{registered_name}' registered.")
+
+
+_class_registry = _PredicateRegistry("serializable class")
+_saver_registry = _PredicateRegistry("checkpoint saver")
+
+
+def get_registered_class_name(obj):
+ try:
+ return _class_registry.get_registered_name(obj)
+ except LookupError:
+ return None
def get_registered_class(registered_name):
try:
- return _CLASS_REGISTRY[registered_name][1]
- except KeyError:
+ return _class_registry.name_lookup(registered_name)
+ except LookupError:
return None
-def register_serializable(package="Custom", name=None, predicate=None):
+def register_serializable(package="Custom", name=None, predicate=None): # pylint: disable=unused-argument
"""Decorator for registering a serializable class.
THIS METHOD IS STILL EXPERIMENTAL AND MAY CHANGE AT ANY TIME.
@@ -75,30 +168,194 @@
Returns:
A decorator that registers the decorated class with the passed names and
predicate.
-
- Raises:
- ValueError if predicate is not callable.
"""
- if predicate is not None and not callable(predicate):
- raise ValueError("The `predicate` passed to registered_serializable "
- "must be callable.")
-
def decorator(arg):
"""Registers a class with the serialization framework."""
+ nonlocal predicate
if not tf_inspect.isclass(arg):
- raise ValueError(
- "Registered serializable must be a class: {}".format(arg))
+ raise TypeError("Registered serializable must be a class: {}".format(arg))
class_name = name if name is not None else arg.__name__
- registered_name = package + "." + class_name
-
- if registered_name in _CLASS_REGISTRY:
- raise ValueError("{} has already been registered to {}".format(
- registered_name, _CLASS_REGISTRY[registered_name]))
-
- _CLASS_REGISTRY[registered_name] = (predicate, arg)
- _REGISTERED_NAMES.append(registered_name)
-
+ if predicate is None:
+ predicate = lambda x: isinstance(x, arg)
+ _class_registry.register(package, class_name, predicate, arg)
return arg
return decorator
+
+
+RegisteredSaver = collections.namedtuple(
+ "RegisteredSaver", ["name", "predicate", "save_fn", "restore_fn"])
+_REGISTERED_SAVERS = {}
+_REGISTERED_SAVER_NAMES = [] # Stores names in the order of registration
+
+
+def register_checkpoint_saver(package="Custom",
+ name=None,
+ predicate=None,
+ save_fn=None,
+ restore_fn=None):
+ """Registers functions that checkpoint & restore objects with custom steps.
+
+ If you have a class that requires complicated coordination between multiple
+ objects when checkpointing, then you will need to register a custom saver
+ and restore function. An example of this is a custom Variable class that
+ splits the variable across different objects and devices, and needs to write
+ checkpoints that are compatible with different configurations of devices.
+
+ The registered save and restore functions are used in checkpoints and
+ SavedModel.
+
+ Please make sure you are familiar with the concepts in the [Checkpointing
+ guide](https://www.tensorflow.org/guide/checkpoint), and ops used to save the
+ V2 checkpoint format:
+
+ * io_ops.SaveV2
+ * io_ops.MergeV2Checkpoints
+ * io_ops.RestoreV2
+
+ **Predicate**
+
+ The predicate is a filter that will run on every `Trackable` object connected
+ to the root object. This function determines whether a `Trackable` should use
+ the registered functions.
+
+ Example: `lambda x: isinstance(x, CustomClass)`
+
+ **Custom save function**
+
+ This is how checkpoint saving works normally:
+ 1. Gather all of the Trackables with saveable values.
+ 2. For each Trackable, gather all of the saveable tensors.
+ 3. Save checkpoint shards (grouping tensors by device) with SaveV2
+ 4. Merge the shards with MergeV2Checkpoints. This combines all of the shards'
+ metadata, and renames them to follow the standard shard pattern.
+
+ When a saver is registered, Trackables that pass the registered `predicate`
+ are automatically marked as having saveable values. Next, the custom save
+ function replaces steps 2 and 3 of the saving process. Finally, the shards
+ returned by the custom save function are merged with the other shards.
+
+ The save function takes in a dictionary of `Trackables` and a `file_prefix`
+ string. The function should save checkpoint shards using the SaveV2 op, and
+ return the shard prefixes. SaveV2 is currently required to work correctly,
+ because the code merges all of the returned shards, and the `restore_fn` will
+ only be given the prefix of the merged checkpoint. If you need to be able to
+ save and restore from unmerged shards, please file a feature request.
+
+ Specification and example of the save function:
+
+ ```
+ def save_fn(trackables, file_prefix):
+ # trackables: A dictionary mapping unique string identifiers to trackables
+ # file_prefix: A unique file prefix generated using the registered name.
+ ...
+ # Gather the tensors to save.
+ ...
+ io_ops.SaveV2(file_prefix, tensor_names, shapes_and_slices, tensors)
+ return file_prefix # Returns a tensor or a list of string tensors
+ ```
+
+ **Custom restore function**
+
+ Normal checkpoint restore behavior:
+ 1. Gather all of the Trackables that have saveable values.
+ 2. For each Trackable, get the names of the desired tensors to extract from
+ the checkpoint.
+ 3. Use RestoreV2 to read the saved values, and pass the restored tensors to
+ the corresponding Trackables.
+
+ The custom restore function replaces steps 2 and 3.
+
+ The restore function also takes a dictionary of `Trackables` and a
+ `merged_prefix` string. The `merged_prefix` is different from the
+ `file_prefix`, since it contains the renamed shard paths. To read from the
+ merged checkpoint, you must use `RestoreV2(merged_prefix, ...)`.
+
+ Specification:
+
+ ```
+ def restore_fn(trackables, merged_prefix):
+ # trackables: A dictionary mapping unique string identifiers to Trackables
+ # merged_prefix: File prefix of the merged shard names.
+
+ restored_tensors = io_ops.restore_v2(
+ merged_prefix, tensor_names, shapes_and_slices, dtypes)
+ ...
+ # Restore the checkpoint values for the given Trackables.
+ ```
+
+ Args:
+ package: Optional, the package that this class belongs to.
+ name: (Required) The name of this saver, which is saved to the checkpoint.
+ When a checkpoint is restored, the name and package are used to find the
+ matching restore function. The name and package are also used to
+ generate a unique file prefix that is passed to the save_fn.
+ predicate: (Required) A function that returns a boolean indicating whether a
+ `Trackable` object should be checkpointed with this function. Predicates
+ are executed in the reverse order that they are added (later registrations
+ are checked first).
+ save_fn: (Required) A function that takes a dictionary of trackables and a
+ file prefix as the arguments, writes the checkpoint shards for the given
+ Trackables, and returns the list of shard prefixes.
+ restore_fn: (Required) A function that takes a dictionary of trackables and
+ a file prefix as the arguments and restores the trackable values.
+
+ Raises:
+ ValueError: if the package and name are already registered.
+ """
+ if not callable(save_fn):
+ raise TypeError(f"The save_fn must be callable, got: {type(save_fn)}")
+ if not callable(restore_fn):
+ raise TypeError(f"The restore_fn must be callable, got: {type(restore_fn)}")
+
+ _saver_registry.register(package, name, predicate, (save_fn, restore_fn))
+
+
+def get_registered_saver_name(trackable):
+ """Returns the name of the registered saver to use with Trackable."""
+ try:
+ return _saver_registry.get_registered_name(trackable)
+ except LookupError:
+ return None
+
+
+def get_save_function(registered_name):
+ """Returns save function registered to name."""
+ return _saver_registry.name_lookup(registered_name)[0]
+
+
+def get_restore_function(registered_name):
+ """Returns restore function registered to name."""
+ return _saver_registry.name_lookup(registered_name)[1]
+
+
+def validate_restore_function(trackable, registered_name):
+ """Validates whether the trackable can be restored with the saver.
+
+ When using a checkpoint saved with a registered saver, that same saver must
+ also be registered when loading. The name of that saver is saved to the
+ checkpoint and set in the `registered_name` arg.
+
+ Args:
+ trackable: A `Trackable` object.
+ registered_name: String name of the expected registered saver. This argument
+ should be set using the name saved in a checkpoint.
+
+ Raises:
+ ValueError if the saver could not be found, or if the predicate associated
+ with the saver does not pass.
+ """
+ try:
+ _saver_registry.name_lookup(registered_name)
+ except LookupError:
+ raise ValueError(
+ f"Error when restoring object {trackable} from checkpoint. This "
+ "object was saved using a registered saver named "
+ f"'{registered_name}', but this saver cannot be found in the "
+ "current context.")
+ if not _saver_registry.get_predicate(registered_name)(trackable):
+ raise ValueError(
+ f"Object {trackable} was saved with the registered saver named "
+ f"'{registered_name}'. However, this saver cannot be used to restore the "
+ "object because the predicate does not pass.")
diff --git a/tensorflow/python/saved_model/registration_saving_test.py b/tensorflow/python/saved_model/registration_saving_test.py
index e245219..8d56488 100644
--- a/tensorflow/python/saved_model/registration_saving_test.py
+++ b/tensorflow/python/saved_model/registration_saving_test.py
@@ -14,17 +14,94 @@
# ==============================================================================
"""Tests saving with registered Trackable classes and checkpoint functions."""
+import os
import tempfile
+
from absl.testing import parameterized
from google.protobuf import wrappers_pb2
+from tensorflow.python.eager import context
+from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
+from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import io_ops
+from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.saved_model import load
from tensorflow.python.saved_model import registration
from tensorflow.python.saved_model import save
from tensorflow.python.training.tracking import tracking
+from tensorflow.python.training.tracking import util
+
+
+@registration.register_serializable()
+class Part(resource_variable_ops.ResourceVariable):
+
+ def __init__(self, value):
+ self._init_from_args(value)
+
+ @classmethod
+ def _deserialize_from_proto(cls, **kwargs):
+ return cls([0, 0])
+
+
+@registration.register_serializable()
+class Stack(tracking.AutoTrackable):
+
+ def __init__(self, parts=None):
+ self.parts = parts
+
+ @def_function.function(input_signature=[])
+ def value(self):
+ return array_ops.stack(self.parts)
+
+
+def get_tensor_slices(trackables):
+ tensor_names = []
+ shapes_and_slices = []
+ tensors = []
+ restored_trackables = []
+ for obj_prefix, obj in trackables.items():
+ if isinstance(obj, Part):
+ continue # only save stacks
+ tensor_names.append(obj_prefix + "/value")
+ shapes_and_slices.append("")
+ x = obj.value()
+ with ops.device("/device:CPU:0"):
+ tensors.append(array_ops.identity(x))
+ restored_trackables.append(obj)
+
+ return tensor_names, shapes_and_slices, tensors, restored_trackables
+
+
+def save_stacks_and_parts(trackables, file_prefix):
+ """Save stack and part objects to a checkpoint shard."""
+ tensor_names, shapes_and_slices, tensors, _ = get_tensor_slices(trackables)
+ io_ops.save_v2(file_prefix, tensor_names, shapes_and_slices, tensors)
+ return file_prefix
+
+
+def restore_stacks_and_parts(trackables, merged_prefix):
+ tensor_names, shapes_and_slices, tensors, restored_trackables = (
+ get_tensor_slices(trackables))
+ dtypes = [t.dtype for t in tensors]
+ restored_tensors = io_ops.restore_v2(merged_prefix, tensor_names,
+ shapes_and_slices, dtypes)
+ for trackable, restored_tensor in zip(restored_trackables, restored_tensors):
+ expected_shape = trackable.value().get_shape()
+ restored_tensor = array_ops.reshape(restored_tensor, expected_shape)
+ parts = array_ops.unstack(restored_tensor)
+ for part, restored_part in zip(trackable.parts, parts):
+ part.assign(restored_part)
+
+
+registration.register_checkpoint_saver(
+ name="stacks",
+ predicate=lambda x: isinstance(x, (Stack, Part)),
+ save_fn=save_stacks_and_parts,
+ restore_fn=restore_stacks_and_parts)
def cycle(obj, cycles, signatures=None, options=None):
@@ -45,11 +122,10 @@
@parameterized.named_parameters(
dict(testcase_name="ReloadOnce", cycles=1),
dict(testcase_name="ReloadTwice", cycles=2),
- dict(testcase_name="ReloadThrice", cycles=3)
-)
+ dict(testcase_name="ReloadThrice", cycles=3))
class SavedModelTest(test.TestCase, parameterized.TestCase):
- def test_save_and_load(self, cycles):
+ def test_registered_serializable(self, cycles):
@registration.register_serializable(name=f"SaveAndLoad{cycles}")
class Module(tracking.AutoTrackable):
@@ -123,5 +199,53 @@
self.assertIsInstance(loaded, Module)
self.assertEqual(5, loaded.v.numpy())
+ def test_registered_saver(self, cycles):
+ p1 = Part([1, 4])
+ p2 = Part([2, 5])
+ p3 = Part([3, 6])
+ s = Stack([p1, p2, p3])
+ loaded = cycle(s, cycles)
+ self.assertAllEqual(s.value(), loaded.value())
+
+
+class SingleCycleTest(test.TestCase):
+
+ @test_util.deprecated_graph_mode_only()
+ def test_registered_saver_fails_in_saved_model_graph_mode(self):
+ with context.eager_mode():
+ p1 = Part([1, 4])
+ p2 = Part([2, 5])
+ p3 = Part([3, 6])
+ s = Stack([p1, p2, p3])
+ save_dir = os.path.join(self.get_temp_dir(), "save_dir")
+ save.save(s, save_dir)
+
+ with self.assertRaisesRegex(
+ NotImplementedError,
+ "registered checkpoint saver is not supported in graph mode"):
+ load.load(save_dir)
+
+ def test_registered_saver_checkpoint(self):
+ p1 = Part([1, 4])
+ p2 = Part([2, 5])
+ p3 = Part([3, 6])
+ s = Stack([p1, p2, p3])
+ s2 = Stack([p3, p1, p2])
+
+ expected_value_s = s.value()
+ expected_value_s2 = s2.value()
+
+ ckpt_path = os.path.join(self.get_temp_dir(), "ckpt")
+ util.Checkpoint(s=s, s2=s2).write(ckpt_path)
+
+ del s, s2, p1, p2, p3
+
+ restore_s = Stack([Part([0, 0]) for _ in range(3)])
+ util.Checkpoint(s=restore_s).read(ckpt_path).expect_partial()
+ self.assertAllEqual(expected_value_s, restore_s.value())
+ util.Checkpoint(s2=restore_s).read(ckpt_path).expect_partial()
+ self.assertAllEqual(expected_value_s2, restore_s.value())
+
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/saved_model/registration_test.py b/tensorflow/python/saved_model/registration_test.py
index d9c794e..dfa5050 100644
--- a/tensorflow/python/saved_model/registration_test.py
+++ b/tensorflow/python/saved_model/registration_test.py
@@ -1,4 +1,4 @@
-# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -54,7 +54,7 @@
])
def test_registration(self, expected_cls, expected_name):
obj = expected_cls()
- self.assertEqual(registration.get_registered_name(obj), expected_name)
+ self.assertEqual(registration.get_registered_class_name(obj), expected_name)
self.assertIs(
registration.get_registered_class(expected_name), expected_cls)
@@ -67,7 +67,7 @@
pass
no_register = NotRegistered
- self.assertIsNone(registration.get_registered_name(no_register))
+ self.assertIsNone(registration.get_registered_class_name(no_register))
def test_duplicate_registration(self):
@@ -76,12 +76,13 @@
pass
dup = Duplicate()
- self.assertEqual(registration.get_registered_name(dup), "Custom.Duplicate")
+ self.assertEqual(
+ registration.get_registered_class_name(dup), "Custom.Duplicate")
# Registrations with different names are ok.
registration.register_serializable(package="duplicate")(Duplicate)
# Registrations are checked in reverse order.
self.assertEqual(
- registration.get_registered_name(dup), "duplicate.Duplicate")
+ registration.get_registered_class_name(dup), "duplicate.Duplicate")
# Both names should resolve to the same class.
self.assertIs(
registration.get_registered_class("Custom.Duplicate"), Duplicate)
@@ -96,12 +97,12 @@
def test_register_non_class_fails(self):
obj = RegisteredClass()
- with self.assertRaisesRegex(ValueError, "must be a class"):
+ with self.assertRaisesRegex(TypeError, "must be a class"):
registration.register_serializable()(obj)
def test_register_bad_predicate_fails(self):
- with self.assertRaisesRegex(ValueError, "must be callable"):
- registration.register_serializable(predicate=0)
+ with self.assertRaisesRegex(TypeError, "must be callable"):
+ registration.register_serializable(predicate=0)(RegisteredClass)
def test_predicate(self):
@@ -118,8 +119,9 @@
a = Predicate(True)
b = Predicate(False)
self.assertEqual(
- registration.get_registered_name(a), "Custom.RegisterThisOnlyTrue")
- self.assertIsNone(registration.get_registered_name(b))
+ registration.get_registered_class_name(a),
+ "Custom.RegisterThisOnlyTrue")
+ self.assertIsNone(registration.get_registered_class_name(b))
registration.register_serializable(
name="RegisterAllPredicate",
@@ -127,9 +129,90 @@
Predicate)
self.assertEqual(
- registration.get_registered_name(a), "Custom.RegisterAllPredicate")
+ registration.get_registered_class_name(a),
+ "Custom.RegisterAllPredicate")
self.assertEqual(
- registration.get_registered_name(b), "Custom.RegisterAllPredicate")
+ registration.get_registered_class_name(b),
+ "Custom.RegisterAllPredicate")
+
+
+class CheckpointSaverRegistrationTest(test.TestCase):
+
+ def test_invalid_registration(self):
+ with self.assertRaisesRegex(TypeError, "must be string"):
+ registration.register_checkpoint_saver(
+ package=None,
+ name="test",
+ predicate=lambda: None,
+ save_fn=lambda: None,
+ restore_fn=lambda: None)
+ with self.assertRaisesRegex(TypeError, "must be string"):
+ registration.register_checkpoint_saver(
+ name=None,
+ predicate=lambda: None,
+ save_fn=lambda: None,
+ restore_fn=lambda: None)
+ with self.assertRaisesRegex(ValueError,
+ "Invalid registered checkpoint saver."):
+ registration.register_checkpoint_saver(
+ package="package",
+ name="t/est",
+ predicate=lambda: None,
+ save_fn=lambda: None,
+ restore_fn=lambda: None)
+ with self.assertRaisesRegex(ValueError,
+ "Invalid registered checkpoint saver."):
+ registration.register_checkpoint_saver(
+ package="package",
+ name="t/est",
+ predicate=lambda: None,
+ save_fn=lambda: None,
+ restore_fn=lambda: None)
+ with self.assertRaisesRegex(
+ TypeError,
+ "The predicate registered to a checkpoint saver must be callable"
+ ):
+ registration.register_checkpoint_saver(
+ name="test",
+ predicate=None,
+ save_fn=lambda: None,
+ restore_fn=lambda: None)
+ with self.assertRaisesRegex(TypeError, "The save_fn must be callable"):
+ registration.register_checkpoint_saver(
+ name="test",
+ predicate=lambda: None,
+ save_fn=None,
+ restore_fn=lambda: None)
+ with self.assertRaisesRegex(TypeError, "The restore_fn must be callable"):
+ registration.register_checkpoint_saver(
+ name="test",
+ predicate=lambda: None,
+ save_fn=lambda: None,
+ restore_fn=None)
+
+ def test_registration(self):
+ registration.register_checkpoint_saver(
+ package="Testing",
+ name="test_predicate",
+ predicate=lambda x: hasattr(x, "check_attr"),
+ save_fn=lambda: "save",
+ restore_fn=lambda: "restore")
+ x = base.Trackable()
+ self.assertIsNone(registration.get_registered_saver_name(x))
+
+ x.check_attr = 1
+ saver_name = registration.get_registered_saver_name(x)
+ self.assertEqual(saver_name, "Testing.test_predicate")
+
+ self.assertEqual(registration.get_save_function(saver_name)(), "save")
+ self.assertEqual(registration.get_restore_function(saver_name)(), "restore")
+
+ registration.validate_restore_function(x, "Testing.test_predicate")
+ with self.assertRaisesRegex(ValueError, "saver cannot be found"):
+ registration.validate_restore_function(x, "Invalid.name")
+ x2 = base.Trackable()
+ with self.assertRaisesRegex(ValueError, "saver cannot be used"):
+ registration.validate_restore_function(x2, "Testing.test_predicate")
if __name__ == "__main__":
diff --git a/tensorflow/python/saved_model/save.py b/tensorflow/python/saved_model/save.py
index 42920b6..598ea3e 100644
--- a/tensorflow/python/saved_model/save.py
+++ b/tensorflow/python/saved_model/save.py
@@ -256,8 +256,12 @@
of the object loaded from the SavedModel. These functions are recorded in
the `saveable_objects` map in the `SavedObject` proto.
"""
- checkpoint_factory_map = graph_view.get_checkpoint_factories_and_keys(
- self.object_names)
+ checkpoint_factory_map, registered_savers = (
+ graph_view.get_checkpoint_factories_and_keys(self.object_names))
+ self._obj_to_registered_saver = object_identity.ObjectIdentityDictionary()
+ for saver_name, trackables in registered_savers.items():
+ for trackable in trackables.values():
+ self._obj_to_registered_saver[trackable] = saver_name
self._saveable_objects_map = (
_gen_save_and_restore_functions(checkpoint_factory_map))
@@ -373,14 +377,18 @@
child_proto.node_id = self.node_ids[ref_function]
child_proto.local_name = local_name
- if node not in self._saveable_objects_map:
- continue
+ if node in self._saveable_objects_map:
+ assert node not in self._obj_to_registered_saver, (
+ "Objects can't have both SaveableObjects and a registered saver")
- for local_name, (save_fn, restore_fn) in (
- self._saveable_objects_map[node].items()):
- saveable_object_proto = object_proto.saveable_objects[local_name]
- saveable_object_proto.save_function = self.node_ids[save_fn]
- saveable_object_proto.restore_function = self.node_ids[restore_fn]
+ for local_name, (save_fn, restore_fn) in (
+ self._saveable_objects_map[node].items()):
+ saveable_object_proto = object_proto.saveable_objects[local_name]
+ saveable_object_proto.save_function = self.node_ids[save_fn]
+ saveable_object_proto.restore_function = self.node_ids[restore_fn]
+
+ elif node in self._obj_to_registered_saver:
+ object_proto.registered_saver = self._obj_to_registered_saver[node]
def map_resources(self):
"""Makes new resource handle ops corresponding to existing resource tensors.
@@ -968,11 +976,15 @@
# gathering from the eager context so Optimizers save the right set of
# variables, but want any operations associated with the save/restore to be in
# the exported graph (thus the `to_graph` argument).
- saver = functional_saver.MultiDeviceSaver(
- saveable_view.checkpoint_view.frozen_saveable_objects(
+ call_with_mapped_captures = functools.partial(
+ _call_function_with_mapped_captures, resource_map=resource_map)
+ named_saveable_objects, registered_savers = (
+ saveable_view.checkpoint_view.frozen_saveables_and_savers(
object_map=object_map, to_graph=exported_graph,
- call_with_mapped_captures=functools.partial(
- _call_function_with_mapped_captures, resource_map=resource_map)))
+ call_with_mapped_captures=call_with_mapped_captures))
+ saver = functional_saver.MultiDeviceSaver(named_saveable_objects,
+ registered_savers,
+ call_with_mapped_captures)
with exported_graph.as_default():
signatures = _generate_signatures(signature_functions, resource_map)
@@ -1151,7 +1163,7 @@
# pylint:enable=protected-access
proto.user_object.CopyFrom(registered_type_proto)
- registered_name = registration.get_registered_name(obj)
+ registered_name = registration.get_registered_class_name(obj)
if registered_name:
proto.registered_name = registered_name
serialized_user_proto = obj._serialize_to_proto() # pylint: disable=protected-access
diff --git a/tensorflow/python/training/saving/BUILD b/tensorflow/python/training/saving/BUILD
index 646d204..180d6c0 100644
--- a/tensorflow/python/training/saving/BUILD
+++ b/tensorflow/python/training/saving/BUILD
@@ -29,6 +29,7 @@
":saveable_object",
":saveable_object_util",
"//tensorflow/python/eager:def_function",
+ "//tensorflow/python/saved_model:registration",
],
)
diff --git a/tensorflow/python/training/saving/functional_saver.py b/tensorflow/python/training/saving/functional_saver.py
index bf7c7d1..1c6e803 100644
--- a/tensorflow/python/training/saving/functional_saver.py
+++ b/tensorflow/python/training/saving/functional_saver.py
@@ -21,10 +21,12 @@
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
+from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_io_ops
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import string_ops
+from tensorflow.python.saved_model import registration
from tensorflow.python.training.saving import checkpoint_options
from tensorflow.python.training.saving import saveable_hook
from tensorflow.python.training.saving import saveable_object
@@ -126,6 +128,44 @@
return gen_io_ops.sharded_filename(filename_tensor, shard, num_shards)
+def registered_saver_filename(filename_tensor, saver_name):
+ return string_ops.string_join(
+ [filename_tensor, constant_op.constant(f"-{saver_name}")])
+
+
+def _get_mapped_registered_save_fn(fn, trackables, call_with_mapped_captures):
+ """Converts the function to a python or tf.function with a single file arg."""
+ if call_with_mapped_captures is None:
+ def mapped_fn(file_prefix):
+ return fn(trackables=trackables, file_prefix=file_prefix)
+ return mapped_fn
+ else:
+ tf_fn = def_function.function(fn, autograph=False)
+ concrete = tf_fn.get_concrete_function(
+ trackables=trackables,
+ file_prefix=tensor_spec.TensorSpec(shape=(), dtype=dtypes.string))
+ def mapped_fn(file_prefix):
+ return call_with_mapped_captures(concrete, [file_prefix])
+ return mapped_fn
+
+
+def _get_mapped_registered_restore_fn(fn, trackables,
+ call_with_mapped_captures):
+ """Converts the function to a python or tf.function with a single file arg."""
+ if call_with_mapped_captures is None:
+ def mapped_fn(merged_prefix):
+ return fn(trackables=trackables, merged_prefix=merged_prefix)
+ return mapped_fn
+ else:
+ tf_fn = def_function.function(fn, autograph=False)
+ concrete = tf_fn.get_concrete_function(
+ trackables=trackables,
+ merged_prefix=tensor_spec.TensorSpec(shape=(), dtype=dtypes.string))
+ def mapped_fn(merged_prefix):
+ return call_with_mapped_captures(concrete, [merged_prefix])
+ return mapped_fn
+
+
class MultiDeviceSaver(object):
"""Saves checkpoints directly from multiple devices.
@@ -134,7 +174,10 @@
checkpointing are built on top of it.
"""
- def __init__(self, saveable_objects):
+ def __init__(self,
+ saveable_objects,
+ registered_savers=None,
+ call_with_mapped_captures=None):
"""Specify a list of `SaveableObject`s to save and restore.
Args:
@@ -142,6 +185,11 @@
Objects extending `SaveableObject` will be saved and restored, and
objects extending `SaveableHook` will be called into at save and
restore time.
+ registered_savers: A dictionary mapping `registration.RegisteredSaver`
+ namedtuples to a dictionary of named Trackables. The keys of the
+ Trackable dictionary are string names that uniquely identify the
+ Trackable in the checkpoint.
+ call_with_mapped_captures: Optional function for calling traced savers.
"""
self._before_save_callbacks = []
self._after_restore_callbacks = []
@@ -168,6 +216,17 @@
device: _SingleDeviceSaver(saveables)
for device, saveables in saveables_by_device.items()}
+ self._registered_savers = {}
+ if registered_savers:
+ for registered_name, trackables in registered_savers.items():
+ save_fn = _get_mapped_registered_save_fn(
+ registration.get_save_function(registered_name),
+ trackables, call_with_mapped_captures)
+ restore_fn = _get_mapped_registered_restore_fn(
+ registration.get_restore_function(registered_name),
+ trackables, call_with_mapped_captures)
+ self._registered_savers[registered_name] = (save_fn, restore_fn)
+
def to_proto(self):
"""Serializes to a SaverDef referencing the current graph."""
filename_tensor = array_ops.placeholder(
@@ -247,11 +306,29 @@
constant_op.constant("_temp/part"))
tmp_checkpoint_prefix = string_ops.string_join(
[file_prefix, sharded_suffix])
+ registered_paths = {
+ saver_name: registered_saver_filename(file_prefix, saver_name)
+ for saver_name in self._registered_savers
+ }
def save_fn():
+ saved_prefixes = []
+ # Save with the registered savers.
+ for saver_name, (save_fn, _) in self._registered_savers.items():
+ maybe_saved_prefixes = save_fn(registered_paths[saver_name])
+ if maybe_saved_prefixes is not None:
+ flattened_saved_prefixes = nest.flatten(maybe_saved_prefixes)
+ if not all(
+ tensor_util.is_tf_type(x) and x.dtype == dtypes.string
+ for x in flattened_saved_prefixes):
+ raise ValueError(
+ "Registered saver can only return `None` or "
+ f"string type tensors. Got {maybe_saved_prefixes}.")
+ saved_prefixes.extend(flattened_saved_prefixes)
+
+ # (Default saver) Save with single device savers.
num_shards = len(self._single_device_savers)
sharded_saves = []
- sharded_prefixes = []
num_shards_tensor = constant_op.constant(num_shards, name="num_shards")
last_device = None
for shard, (device, saver) in enumerate(
@@ -260,7 +337,7 @@
with ops.device(saveable_object_util.set_cpu0(device)):
shard_prefix = sharded_filename(tmp_checkpoint_prefix, shard,
num_shards_tensor)
- sharded_prefixes.append(shard_prefix)
+ saved_prefixes.append(shard_prefix)
with ops.device(device):
# _SingleDeviceSaver will use the CPU device when necessary, but
# initial read operations should be placed on the SaveableObject's
@@ -278,7 +355,7 @@
# merged, attempts to delete the temporary directory,
# "<user-fed prefix>_temp".
return gen_io_ops.merge_v2_checkpoints(
- sharded_prefixes, file_prefix, delete_old_dirs=True)
+ saved_prefixes, file_prefix, delete_old_dirs=True)
# Since this will causes a function re-trace on each save, limit this to the
# cases where it is needed: eager and when there are multiple tasks/single
@@ -315,7 +392,8 @@
for device, saver in sorted(self._single_device_savers.items()):
with ops.device(device):
restore_ops.update(saver.restore(file_prefix, options))
-
+ for _, (_, restore_fn) in self._registered_savers.items():
+ restore_fn(file_prefix)
return restore_ops
# Since this will causes a function re-trace on each restore, limit this to
diff --git a/tensorflow/python/training/tracking/BUILD b/tensorflow/python/training/tracking/BUILD
index bb5fd05..0510dc9 100644
--- a/tensorflow/python/training/tracking/BUILD
+++ b/tensorflow/python/training/tracking/BUILD
@@ -28,6 +28,7 @@
"//tensorflow/python:platform",
"//tensorflow/python:util",
"//tensorflow/python/eager:context",
+ "//tensorflow/python/saved_model:registration",
"//tensorflow/python/training/saving:saveable_object",
"@six_archive//:six",
],
diff --git a/tensorflow/python/training/tracking/base.py b/tensorflow/python/training/tracking/base.py
index b7edf5d..cbbe9e8 100644
--- a/tensorflow/python/training/tracking/base.py
+++ b/tensorflow/python/training/tracking/base.py
@@ -26,6 +26,7 @@
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_io_ops as io_ops
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.saved_model import registration
from tensorflow.python.training.saving import saveable_object
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util import tf_decorator
@@ -441,6 +442,9 @@
A list of operations when graph building, or an empty list when executing
eagerly.
"""
+ if self._has_registered_saver():
+ raise ValueError("Unable to run individual checkpoint restore for objects"
+ " with registered savers.")
(restore_ops, tensor_saveables,
python_saveables) = self.gather_ops_or_named_saveables()
restore_ops.extend(
@@ -477,6 +481,16 @@
return self._checkpoint.shape_map[serialized_tensor.checkpoint_key]
return None
+ def _has_registered_saver(self):
+ return bool(self.object_proto.registered_saver.name)
+
+ def get_registered_saver_name(self):
+ if self._has_registered_saver():
+ saver_name = self.object_proto.registered_saver.name
+ registration.validate_restore_function(self.trackable, saver_name)
+ return saver_name
+ return None
+
def _queue_slot_variable_for_restoration(self, optimizer_object, variable,
slot_variable_id, slot_name):
"""Adds a slot variable onto the restoration queue.
@@ -999,16 +1013,27 @@
restore_ops = []
tensor_saveables = {}
python_saveables = []
+ registered_savers = collections.defaultdict(dict)
while visit_queue:
current_position = visit_queue.popleft()
- new_restore_ops, new_tensor_saveables, new_python_saveables = (
- current_position.trackable # pylint: disable=protected-access
- ._single_restoration_from_checkpoint_position(
- checkpoint_position=current_position,
- visit_queue=visit_queue))
- restore_ops.extend(new_restore_ops)
- tensor_saveables.update(new_tensor_saveables)
- python_saveables.extend(new_python_saveables)
+ trackable = current_position.trackable
+
+ # Restore using the ops defined in a Saveable or registered function.
+ registered_saver = current_position.get_registered_saver_name()
+ if registered_saver:
+ object_name = (
+ current_position.object_proto.registered_saver.object_name)
+ registered_savers[registered_saver][object_name] = trackable
+ trackable._self_update_uid = current_position.checkpoint.restore_uid # pylint: disable=protected-access
+ else:
+ new_restore_ops, new_tensor_saveables, new_python_saveables = (
+ trackable._single_restoration_from_checkpoint_position( # pylint: disable=protected-access
+ current_position))
+ restore_ops.extend(new_restore_ops)
+ tensor_saveables.update(new_tensor_saveables)
+ python_saveables.extend(new_python_saveables)
+
+ _queue_children_for_restoration(current_position, visit_queue)
# Restore slot variables first.
#
@@ -1032,11 +1057,11 @@
restore_ops.extend(
current_position.checkpoint.restore_saveables(tensor_saveables,
- python_saveables))
+ python_saveables,
+ registered_savers))
return restore_ops
- def _single_restoration_from_checkpoint_position(self, checkpoint_position,
- visit_queue):
+ def _single_restoration_from_checkpoint_position(self, checkpoint_position):
"""Restore this object, and either queue its dependencies or defer them."""
self._maybe_initialize_trackable()
checkpoint = checkpoint_position.checkpoint
@@ -1051,22 +1076,6 @@
restore_ops = ()
tensor_saveables = {}
python_saveables = ()
- for child in checkpoint_position.object_proto.children:
- child_position = CheckpointPosition(
- checkpoint=checkpoint, proto_id=child.node_id)
- local_object = self._lookup_dependency(child.local_name)
- if local_object is None:
- # We don't yet have a dependency registered with this name. Save it
- # in case we do.
- self._deferred_dependencies.setdefault(child.local_name,
- []).append(child_position)
- else:
- if child_position.bind_object(trackable=local_object):
- # This object's correspondence is new, so dependencies need to be
- # visited. Delay doing it so that we get a breadth-first dependency
- # resolution order (shallowest paths first). The caller is responsible
- # for emptying visit_queue.
- visit_queue.append(child_position)
return restore_ops, tensor_saveables, python_saveables
def _gather_saveables_for_checkpoint(self):
@@ -1320,3 +1329,26 @@
returned must also be in the `_checkpoint_dependencies` dict.
"""
return {}
+
+
+def _queue_children_for_restoration(checkpoint_position, visit_queue):
+ """Queues the restoration of trackable's children or defers them."""
+ # pylint: disable=protected-access
+ trackable = checkpoint_position.trackable
+ checkpoint = checkpoint_position.checkpoint
+ for child in checkpoint_position.object_proto.children:
+ child_position = CheckpointPosition(
+ checkpoint=checkpoint, proto_id=child.node_id)
+ local_object = trackable._lookup_dependency(child.local_name)
+ if local_object is None:
+ # We don't yet have a dependency registered with this name. Save it
+ # in case we do.
+ trackable._deferred_dependencies.setdefault(child.local_name,
+ []).append(child_position)
+ else:
+ if child_position.bind_object(trackable=local_object):
+ # This object's correspondence is new, so dependencies need to be
+ # visited. Delay doing it so that we get a breadth-first dependency
+ # resolution order (shallowest paths first). The caller is responsible
+ # for emptying visit_queue.
+ visit_queue.append(child_position)
diff --git a/tensorflow/python/training/tracking/graph_view.py b/tensorflow/python/training/tracking/graph_view.py
index f51ac86..a94d41d 100644
--- a/tensorflow/python/training/tracking/graph_view.py
+++ b/tensorflow/python/training/tracking/graph_view.py
@@ -21,6 +21,7 @@
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.saved_model import registration
from tensorflow.python.training import optimizer as optimizer_v1
from tensorflow.python.training.saving import saveable_object as saveable_object_lib
from tensorflow.python.training.saving import saveable_object_util
@@ -141,27 +142,60 @@
return slot_variables
-def get_checkpoint_factories_and_keys(object_names):
+def _get_mapped_trackable(trackable, object_map):
+ """Returns the mapped trackable if possible, otherwise returns trackable."""
+ if object_map is None:
+ return trackable
+ else:
+ return object_map.get(trackable, trackable)
+
+
+def get_checkpoint_factories_and_keys(object_names, object_map=None):
"""Gets a map of saveable factories and corresponding checkpoint keys.
Args:
object_names: a dictionary that maps `Trackable` objects to auto-generated
string names.
+ object_map: a dictionary mapping `Trackable` to copied `Trackable` objects.
+ The copied objects are generated from `Trackable._map_resources()` which
+ copies the object into another graph. Generally only resource objects
+ (e.g. Variables, Tables) will be in this map.
Returns:
- A dictionary mapping Trackables -> a list of _CheckpointFactoryData.
+ A tuple of (
+ Dictionary mapping trackable -> list of _CheckpointFactoryData,
+ Dictionary mapping registered saver name -> {object name -> trackable})
"""
checkpoint_factory_map = object_identity.ObjectIdentityDictionary()
+ registered_savers = collections.defaultdict(dict)
for trackable, object_name in object_names.items():
- checkpoint_factory_map[trackable] = []
- for name, saveable_factory in (
- trackable._gather_saveables_for_checkpoint().items()): # pylint: disable=protected-access
- checkpoint_key = "%s/%s/%s" % (
- object_name, _OBJECT_ATTRIBUTES_NAME, _escape_local_name(name))
- checkpoint_factory_map[trackable].append(_CheckpointFactoryData(
- factory=saveable_factory,
- name=name,
- checkpoint_key=checkpoint_key))
- return checkpoint_factory_map
+ # object_to_save is only used to retrieve the saving functionality. For keys
+ # and other data, use the original `trackable`.
+ object_to_save = _get_mapped_trackable(trackable, object_map)
+
+ saver_name = registration.get_registered_saver_name(object_to_save)
+ if saver_name:
+ registered_savers[saver_name][object_name] = trackable
+ else:
+ checkpoint_factory_map[trackable] = []
+ for name, saveable_factory in (
+ object_to_save._gather_saveables_for_checkpoint().items()): # pylint: disable=protected-access
+ checkpoint_key = "%s/%s/%s" % (
+ object_name, _OBJECT_ATTRIBUTES_NAME, _escape_local_name(name))
+ checkpoint_factory_map[trackable].append(_CheckpointFactoryData(
+ factory=saveable_factory,
+ name=name,
+ checkpoint_key=checkpoint_key))
+ return checkpoint_factory_map, registered_savers
+
+
+def _add_attributes_to_object_graph_for_registered_savers(
+ registered_savers, object_graph_proto, node_ids):
+ """Fills the object graph proto with data about the registered savers."""
+ for saver_name, trackables in registered_savers.items():
+ for object_name, trackable in trackables.items():
+ object_proto = object_graph_proto.nodes[node_ids[trackable]]
+ object_proto.registered_saver.name = saver_name
+ object_proto.registered_saver.object_name = object_name
@tf_export("__internal__.tracking.ObjectGraphView", v1=[])
@@ -265,6 +299,25 @@
def _add_attributes_to_object_graph(
self, trackable_objects, object_graph_proto, node_ids, object_names,
object_map, call_with_mapped_captures):
+ """Create saveables/savers and corresponding protos in the object graph."""
+ # The loop below creates TrackableObject protos in the TrackableObjectGraph,
+ # which are filled in the `_add_attributes_to_object_graph_for_*` methods.
+ for checkpoint_id, (trackable, unused_object_proto) in enumerate(
+ zip(trackable_objects, object_graph_proto.nodes)):
+ assert node_ids[trackable] == checkpoint_id
+ checkpoint_factory_map, registered_savers = (
+ get_checkpoint_factories_and_keys(object_names, object_map))
+ _add_attributes_to_object_graph_for_registered_savers(
+ registered_savers, object_graph_proto, node_ids)
+ named_saveable_objects, feed_additions = (
+ self._add_attributes_to_object_graph_for_saveable_objects(
+ checkpoint_factory_map, object_graph_proto, node_ids, object_map,
+ call_with_mapped_captures))
+ return named_saveable_objects, feed_additions, registered_savers
+
+ def _add_attributes_to_object_graph_for_saveable_objects(
+ self, checkpoint_factory_map, object_graph_proto, node_ids, object_map,
+ call_with_mapped_captures):
"""Create SaveableObjects and corresponding SerializedTensor protos."""
named_saveable_objects = []
if self._saveables_cache is None:
@@ -276,27 +329,15 @@
# functions computing volatile Python state to be saved with the
# checkpoint.
feed_additions = {}
- if object_map is None:
- mapped_object_names = object_names
- else:
- mapped_object_names = object_identity.ObjectIdentityDictionary()
- for trackable, name in object_names.items():
- mapped_object_names[object_map.get(trackable, trackable)] = name
- checkpoint_factory_map = get_checkpoint_factories_and_keys(
- mapped_object_names)
- for checkpoint_id, (trackable, object_proto) in enumerate(
- zip(trackable_objects, object_graph_proto.nodes)):
- assert node_ids[trackable] == checkpoint_id
- if object_map is None:
- object_to_save = trackable
- else:
- object_to_save = object_map.get(trackable, trackable)
+ for trackable, factory_data_list in checkpoint_factory_map.items():
+ object_proto = object_graph_proto.nodes[node_ids[trackable]]
if self._saveables_cache is not None:
+ object_to_save = _get_mapped_trackable(trackable, object_map)
cached_attributes = self._saveables_cache.setdefault(object_to_save, {})
else:
cached_attributes = None
- for factory_data in checkpoint_factory_map[object_to_save]:
+ for factory_data in factory_data_list:
attribute = object_proto.attributes.add()
attribute.name = name = factory_data.name
attribute.checkpoint_key = key = factory_data.checkpoint_key
@@ -410,7 +451,7 @@
trackable_objects=trackable_objects,
node_ids=node_ids,
slot_variables=slot_variables)
- named_saveable_objects, feed_additions = (
+ named_saveable_objects, feed_additions, registered_savers = (
self._add_attributes_to_object_graph(
trackable_objects=trackable_objects,
object_graph_proto=object_graph_proto,
@@ -418,7 +459,8 @@
object_names=object_names,
object_map=object_map,
call_with_mapped_captures=call_with_mapped_captures))
- return named_saveable_objects, object_graph_proto, feed_additions
+ return (named_saveable_objects, object_graph_proto, feed_additions,
+ registered_savers)
def serialize_object_graph(self):
"""Determine checkpoint keys for variables and build a serialized graph.
@@ -441,6 +483,12 @@
Raises:
ValueError: If there are invalid characters in an optimizer's slot names.
"""
+ named_saveable_objects, object_graph_proto, feed_additions, _ = (
+ self.serialize_object_graph_with_registered_savers())
+ return named_saveable_objects, object_graph_proto, feed_additions
+
+ def serialize_object_graph_with_registered_savers(self):
+ """Like serialize_object_graph, but also returns registered savers."""
trackable_objects, node_paths = self._breadth_first_traversal()
return self._serialize_gathered_objects(
trackable_objects, node_paths)
@@ -448,17 +496,23 @@
def frozen_saveable_objects(self, object_map=None, to_graph=None,
call_with_mapped_captures=None):
"""Creates SaveableObjects with the current object graph frozen."""
+ return self.frozen_saveables_and_savers(object_map, to_graph,
+ call_with_mapped_captures)[0]
+
+ def frozen_saveables_and_savers(self, object_map=None, to_graph=None,
+ call_with_mapped_captures=None):
+ """Generates SaveableObjects and registered savers in the frozen graph."""
trackable_objects, node_paths = self._breadth_first_traversal()
if to_graph:
target_context = to_graph.as_default
else:
target_context = ops.NullContextmanager
with target_context():
- named_saveable_objects, graph_proto, _ = self._serialize_gathered_objects(
- trackable_objects,
- node_paths,
- object_map,
- call_with_mapped_captures)
+ named_saveable_objects, graph_proto, _, registered_savers = (
+ self._serialize_gathered_objects(trackable_objects,
+ node_paths,
+ object_map,
+ call_with_mapped_captures))
with ops.device("/cpu:0"):
object_graph_tensor = constant_op.constant(
graph_proto.SerializeToString(), dtype=dtypes.string)
@@ -466,7 +520,7 @@
base.NoRestoreSaveable(
tensor=object_graph_tensor,
name=base.OBJECT_GRAPH_PROTO_KEY))
- return named_saveable_objects
+ return named_saveable_objects, registered_savers
def objects_ids_and_slot_variables_and_paths(self):
"""Traverse the object graph and list all accessible objects.
diff --git a/tensorflow/python/training/tracking/util.py b/tensorflow/python/training/tracking/util.py
index 74e6007..fbf745c 100644
--- a/tensorflow/python/training/tracking/util.py
+++ b/tensorflow/python/training/tracking/util.py
@@ -302,13 +302,17 @@
if self.new_restore_ops_callback:
self.new_restore_ops_callback(new_ops) # pylint: disable=not-callable
- def restore_saveables(self, tensor_saveables, python_saveables):
+ def restore_saveables(self,
+ tensor_saveables,
+ python_saveables,
+ registered_savers=None):
"""Run or build restore operations for SaveableObjects.
Args:
tensor_saveables: `SaveableObject`s which correspond to Tensors.
python_saveables: `PythonStateSaveable`s which correspond to Python
values.
+ registered_savers: a dict mapping saver name -> object name -> Trackable.
Returns:
When graph building, a list of restore operations, either cached or newly
@@ -322,7 +326,7 @@
[self.reader.get_tensor(name) for name in spec_names])
# If we have new SaveableObjects, extract and cache restore ops.
- if tensor_saveables:
+ if tensor_saveables or registered_savers:
validated_saveables = saveable_object_util.validate_and_slice_inputs(
tensor_saveables)
validated_names = set(saveable.name for saveable in validated_saveables)
@@ -331,7 +335,8 @@
"Saveable keys changed when validating. Got back "
f"{tensor_saveables.keys()}, was expecting {validated_names}")
new_restore_ops = functional_saver.MultiDeviceSaver(
- validated_saveables).restore(self.save_path_tensor, self.options)
+ validated_saveables,
+ registered_savers).restore(self.save_path_tensor, self.options)
if not context.executing_eagerly():
for name, restore_op in sorted(new_restore_ops.items()):
restore_ops.append(restore_op)
@@ -1142,8 +1147,8 @@
def _gather_saveables(self, object_graph_tensor=None):
"""Wraps _serialize_object_graph to include the object graph proto."""
- (named_saveable_objects, graph_proto,
- feed_additions) = self._graph_view.serialize_object_graph()
+ named_saveable_objects, graph_proto, feed_additions, registered_savers = (
+ self._graph_view.serialize_object_graph_with_registered_savers())
if object_graph_tensor is None:
with ops.device("/cpu:0"):
object_graph_tensor = constant_op.constant(
@@ -1155,7 +1160,8 @@
named_saveable_objects.append(
base.NoRestoreSaveable(
tensor=object_graph_tensor, name=base.OBJECT_GRAPH_PROTO_KEY))
- return named_saveable_objects, graph_proto, feed_additions
+ return (named_saveable_objects, graph_proto, feed_additions,
+ registered_savers)
def _save_cached_when_graph_building(self,
file_prefix,
@@ -1175,8 +1181,8 @@
current object graph and any Python state to be saved in the
checkpoint. When executing eagerly only the first argument is meaningful.
"""
- (named_saveable_objects, graph_proto,
- feed_additions) = self._gather_saveables(
+ (named_saveable_objects, graph_proto, feed_additions,
+ registered_savers) = self._gather_saveables(
object_graph_tensor=object_graph_tensor)
if (self._last_save_object_graph != graph_proto
# When executing eagerly, we need to re-create SaveableObjects each time
@@ -1184,7 +1190,8 @@
# constructors. That means the Saver needs to be copied with a new
# var_list.
or context.executing_eagerly() or ops.inside_function()):
- saver = functional_saver.MultiDeviceSaver(named_saveable_objects)
+ saver = functional_saver.MultiDeviceSaver(named_saveable_objects,
+ registered_savers)
save_op = saver.save(file_prefix, options=options)
with ops.device("/cpu:0"):
with ops.control_dependencies([save_op]):
@@ -1420,9 +1427,10 @@
A saver which saves object-based checkpoints for the object graph frozen at
the time `frozen_saver` was called.
"""
- named_saveable_objects = graph_view_lib.ObjectGraphView(
- root_trackable).frozen_saveable_objects()
- return functional_saver.MultiDeviceSaver(named_saveable_objects)
+ named_saveable_objects, registered_savers = graph_view_lib.ObjectGraphView(
+ root_trackable).frozen_saveables_and_savers()
+ return functional_saver.MultiDeviceSaver(named_saveable_objects,
+ registered_savers)
def saver_with_op_caching(obj, attached_dependencies=None):
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.__internal__.tracking.-object-graph-view.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.__internal__.tracking.-object-graph-view.pbtxt
index 9121914..cf2f65f 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.__internal__.tracking.-object-graph-view.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.__internal__.tracking.-object-graph-view.pbtxt
@@ -23,6 +23,10 @@
argspec: "args=[\'self\', \'object_map\', \'to_graph\', \'call_with_mapped_captures\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
}
member_method {
+ name: "frozen_saveables_and_savers"
+ argspec: "args=[\'self\', \'object_map\', \'to_graph\', \'call_with_mapped_captures\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "list_children"
argspec: "args=[\'self\', \'obj\'], varargs=None, keywords=None, defaults=None"
}
@@ -42,4 +46,8 @@
name: "serialize_object_graph"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
+ member_method {
+ name: "serialize_object_graph_with_registered_savers"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
}