| # Copyright 2021 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| """Serialization Registration for SavedModel. |
| |
| revived_types registration will be migrated to this infrastructure. |
| """ |
| import collections |
| import re |
| |
| from tensorflow.python.util import tf_inspect |
| |
| |
| # Only allow valid file/directory characters |
| _VALID_REGISTERED_NAME = re.compile(r"^[a-zA-Z0-9._-]+$") |
| |
| |
| class _PredicateRegistry(object): |
| """Registry with predicate-based lookup. |
| |
| See the documentation for `register_checkpoint_saver` and |
| `register_serializable` for reasons why predicates are required over a |
| class-based registry. |
| |
| Since this class is used for global registries, each object must be registered |
| to unique names (an error is raised if there are naming conflicts). The lookup |
| searches the predicates in reverse order, so that later-registered predicates |
| are executed first. |
| """ |
| __slots__ = ("_registry_name", "_registered_map", "_registered_predicates", |
| "_registered_names") |
| |
| def __init__(self, name): |
| self._registry_name = name |
| # Maps registered name -> object |
| self._registered_map = {} |
| # Maps registered name -> predicate |
| self._registered_predicates = {} |
| # Stores names in the order of registration |
| self._registered_names = [] |
| |
| @property |
| def name(self): |
| return self._registry_name |
| |
| def register(self, package, name, predicate, candidate): |
| """Registers a candidate object under the package, name and predicate.""" |
| if not isinstance(package, str) or not isinstance(name, str): |
| raise TypeError( |
| f"The package and name registered to a {self.name} must be strings, " |
| f"got: package={type(package)}, name={type(name)}") |
| if not callable(predicate): |
| raise TypeError( |
| f"The predicate registered to a {self.name} must be callable, " |
| f"got: {type(predicate)}") |
| registered_name = package + "." + name |
| if not _VALID_REGISTERED_NAME.match(registered_name): |
| raise ValueError( |
| f"Invalid registered {self.name}. Please check that the package and " |
| f"name follow the regex '{_VALID_REGISTERED_NAME.pattern}': " |
| f"(package='{package}', name='{name}')") |
| if registered_name in self._registered_map: |
| raise ValueError( |
| f"The name '{registered_name}' has already been registered to a " |
| f"{self.name}. Found: {self._registered_map[registered_name]}") |
| |
| self._registered_map[registered_name] = candidate |
| self._registered_predicates[registered_name] = predicate |
| self._registered_names.append(registered_name) |
| |
| def lookup(self, obj): |
| """Looks up the registered object using the predicate. |
| |
| Args: |
| obj: Object to pass to each of the registered predicates to look up the |
| registered object. |
| Returns: |
| The object registered with the first passing predicate. |
| Raises: |
| LookupError if the object does not match any of the predicate functions. |
| """ |
| return self._registered_map[self.get_registered_name(obj)] |
| |
| def name_lookup(self, registered_name): |
| """Looks up the registered object using the registered name.""" |
| try: |
| return self._registered_map[registered_name] |
| except KeyError: |
| raise LookupError(f"The {self.name} registry does not have name " |
| f"'{registered_name}' registered.") |
| |
| def get_registered_name(self, obj): |
| for registered_name in reversed(self._registered_names): |
| predicate = self._registered_predicates[registered_name] |
| if predicate(obj): |
| return registered_name |
| raise LookupError(f"Could not find matching {self.name} for {type(obj)}.") |
| |
| def get_predicate(self, registered_name): |
| try: |
| return self._registered_predicates[registered_name] |
| except KeyError: |
| raise LookupError(f"The {self.name} registry does not have name " |
| f"'{registered_name}' registered.") |
| |
| |
# Module-wide registry of serializable classes; `register_serializable` adds
# entries to it.
_class_registry = _PredicateRegistry(name="serializable class")
# Module-wide registry of `(save_fn, restore_fn)` pairs;
# `register_checkpoint_saver` adds entries to it.
_saver_registry = _PredicateRegistry(name="checkpoint saver")
| |
| |
def get_registered_class_name(obj):
  """Returns the registered serializable-class name matching `obj`, or None.

  Args:
    obj: Object to test against the predicates in the serializable-class
      registry.

  Returns:
    The registered name of the first passing predicate, or None if the
    registry lookup raises `LookupError`.
  """
  try:
    registered_name = _class_registry.get_registered_name(obj)
  except LookupError:
    registered_name = None
  return registered_name
| |
| |
def get_registered_class(registered_name):
  """Returns the class registered under `registered_name`, or None.

  Args:
    registered_name: Full `"<package>.<name>"` string of a serializable class.

  Returns:
    The registered class, or None if the name is not in the registry.
  """
  try:
    registered_class = _class_registry.name_lookup(registered_name)
  except LookupError:
    registered_class = None
  return registered_class
| |
| |
def register_serializable(package="Custom", name=None, predicate=None):  # pylint: disable=unused-argument
  """Decorator for registering a serializable class.

  THIS METHOD IS STILL EXPERIMENTAL AND MAY CHANGE AT ANY TIME.

  Registered classes will be saved with a name generated by combining the
  `package` and `name` arguments. When loading a SavedModel, modules saved with
  this registered name will be created using the `_deserialize_from_proto`
  method.

  By default, the registration applies to every object for which
  `isinstance(obj, cls)` is true, i.e. instances of the registered class and of
  its subclasses. To narrow (or otherwise change) the set of objects, pass the
  `predicate` argument. For example, to match direct instances only:

  ```python
  class A(tf.Module):
    pass

  register_serializable(
      package="Example", predicate=lambda obj: type(obj) is A)(A)
  ```

  Args:
    package: The package that this class belongs to.
    name: The name to serialize this class under in this package. If None, the
      class's name will be used.
    predicate: An optional function that takes a single Trackable argument, and
      determines whether that object should be serialized with this `package`
      and `name`. The default predicate checks `isinstance(obj, cls)` against
      the registered class. Predicates are executed in the reverse order
      that they are added (later registrations are checked first).

  Returns:
    A decorator that registers the decorated class with the passed names and
    predicate. The decorator raises `TypeError` if it is applied to something
    that is not a class, and propagates any error raised by the registry
    (e.g. for an invalid or already registered name).
  """
  def decorator(arg):
    """Registers a class with the serialization framework."""
    if not tf_inspect.isclass(arg):
      raise TypeError("Registered serializable must be a class: {}".format(arg))

    class_name = name if name is not None else arg.__name__
    # Resolve the default into a local instead of rebinding the enclosing
    # `predicate`. Rebinding it would make a reused decorator register every
    # later class with the first class's `isinstance` check.
    if predicate is None:
      class_predicate = lambda x: isinstance(x, arg)
    else:
      class_predicate = predicate
    _class_registry.register(package, class_name, class_predicate, arg)
    return arg

  return decorator
| |
| |
| RegisteredSaver = collections.namedtuple( |
| "RegisteredSaver", ["name", "predicate", "save_fn", "restore_fn"]) |
| _REGISTERED_SAVERS = {} |
| _REGISTERED_SAVER_NAMES = [] # Stores names in the order of registration |
| |
| |
def register_checkpoint_saver(package="Custom",
                              name=None,
                              predicate=None,
                              save_fn=None,
                              restore_fn=None):
  """Registers functions which checkpoint & restore objects with custom steps.

  If you have a class that requires complicated coordination between multiple
  objects when checkpointing, then you will need to register a custom saver
  and restore function. An example of this is a custom Variable class that
  splits the variable across different objects and devices, and needs to write
  checkpoints that are compatible with different configurations of devices.

  The registered save and restore functions are used in checkpoints and
  SavedModel.

  Please make sure you are familiar with the concepts in the [Checkpointing
  guide](https://www.tensorflow.org/guide/checkpoint), and ops used to save the
  V2 checkpoint format:

  * io_ops.SaveV2
  * io_ops.MergeV2Checkpoints
  * io_ops.RestoreV2

  **Predicate**

  The predicate is a filter that will run on every `Trackable` object connected
  to the root object. This function determines whether a `Trackable` should use
  the registered functions.

  Example: `lambda x: isinstance(x, CustomClass)`

  **Custom save function**

  This is how checkpoint saving works normally:
  1. Gather all of the Trackables with saveable values.
  2. For each Trackable, gather all of the saveable tensors.
  3. Save checkpoint shards (grouping tensors by device) with SaveV2.
  4. Merge the shards with MergeV2Checkpoints. This combines all of the
     shards' metadata, and renames them to follow the standard shard pattern.

  When a saver is registered, Trackables that pass the registered `predicate`
  are automatically marked as having saveable values. Next, the custom save
  function replaces steps 2 and 3 of the saving process. Finally, the shards
  returned by the custom save function are merged with the other shards.

  The save function takes in a dictionary of `Trackables` and a `file_prefix`
  string. The function should save checkpoint shards using the SaveV2 op, and
  return a list of the shard prefixes. SaveV2 is currently required to work
  correctly, because the code merges all of the returned shards, and the
  `restore_fn` will only be given the prefix of the merged checkpoint. If you
  need to be able to save and restore from unmerged shards, please file a
  feature request.

  Specification and example of the save function:

  ```
  def save_fn(trackables, file_prefix):
    # trackables: A dictionary mapping unique string identifiers to trackables
    # file_prefix: A unique file prefix generated using the registered name.
    ...
    # Gather the tensors to save.
    ...
    io_ops.SaveV2(file_prefix, tensor_names, shapes_and_slices, tensors)
    return file_prefix  # Returns a tensor or a list of string tensors
  ```

  **Custom restore function**

  Normal checkpoint restore behavior:
  1. Gather all of the Trackables that have saveable values.
  2. For each Trackable, get the names of the desired tensors to extract from
     the checkpoint.
  3. Use RestoreV2 to read the saved values, and pass the restored tensors to
     the corresponding Trackables.

  The custom restore function replaces steps 2 and 3.

  The restore function also takes a dictionary of `Trackables` and a
  `merged_prefix` string. The `merged_prefix` is different from the
  `file_prefix`, since it contains the renamed shard paths. To read from the
  merged checkpoint, you must use `RestoreV2(merged_prefix, ...)`.

  Specification:

  ```
  def restore_fn(trackables, merged_prefix):
    # trackables: A dictionary mapping unique string identifiers to Trackables
    # merged_prefix: File prefix of the merged shard names.

    restored_tensors = io_ops.restore_v2(
        merged_prefix, tensor_names, shapes_and_slices, dtypes)
    ...
    # Restore the checkpoint values for the given Trackables.
  ```

  Args:
    package: Optional, the package that this class belongs to.
    name: (Required) The name of this saver, which is saved to the checkpoint.
      When a checkpoint is restored, the name and package are used to find the
      matching restore function. The name and package are also used to
      generate a unique file prefix that is passed to the save_fn.
    predicate: (Required) A function that returns a boolean indicating whether a
      `Trackable` object should be checkpointed with this function. Predicates
      are executed in the reverse order that they are added (later registrations
      are checked first).
    save_fn: (Required) A function that takes a dictionary of trackables and a
      file prefix as the arguments, writes the checkpoint shards for the given
      Trackables, and returns the list of shard prefixes.
    restore_fn: (Required) A function that takes a dictionary of trackables and
      a file prefix as the arguments and restores the trackable values.

  Raises:
    TypeError: if `save_fn` or `restore_fn` is not callable. The registry also
      raises it when `package` or `name` is not a string or `predicate` is not
      callable.
    ValueError: if the package and name are already registered, or do not form
      a valid registered name.
  """
  if not callable(save_fn):
    raise TypeError(f"The save_fn must be callable, got: {type(save_fn)}")
  elif not callable(restore_fn):
    raise TypeError(f"The restore_fn must be callable, got: {type(restore_fn)}")

  # The two functions are stored together as one registry entry; index 0 is
  # the save function and index 1 the restore function.
  saver_fns = (save_fn, restore_fn)
  _saver_registry.register(
      package=package, name=name, predicate=predicate, candidate=saver_fns)
| |
| |
def get_registered_saver_name(trackable):
  """Returns the name of the registered saver to use with Trackable.

  Args:
    trackable: Object to test against the predicates in the checkpoint-saver
      registry.

  Returns:
    The registered name of the first passing predicate, or None if the
    registry lookup raises `LookupError`.
  """
  try:
    saver_name = _saver_registry.get_registered_name(trackable)
  except LookupError:
    saver_name = None
  return saver_name
| |
| |
def get_save_function(registered_name):
  """Returns save function registered to name.

  Args:
    registered_name: Full `"<package>.<name>"` string of a registered saver.

  Returns:
    Element 0 of the entry stored in the checkpoint-saver registry.

  Raises:
    LookupError: if no saver is registered under `registered_name`.
  """
  saver_entry = _saver_registry.name_lookup(registered_name)
  return saver_entry[0]
| |
| |
def get_restore_function(registered_name):
  """Returns restore function registered to name.

  Args:
    registered_name: Full `"<package>.<name>"` string of a registered saver.

  Returns:
    Element 1 of the entry stored in the checkpoint-saver registry.

  Raises:
    LookupError: if no saver is registered under `registered_name`.
  """
  saver_entry = _saver_registry.name_lookup(registered_name)
  return saver_entry[1]
| |
| |
def validate_restore_function(trackable, registered_name):
  """Validates whether the trackable can be restored with the saver.

  When using a checkpoint saved with a registered saver, that same saver must
  also be registered when loading. The name of that saver is saved to the
  checkpoint and set in the `registered_name` arg.

  Args:
    trackable: A `Trackable` object.
    registered_name: String name of the expected registered saver. This argument
      should be set using the name saved in a checkpoint.

  Raises:
    ValueError: if the saver could not be found, or if the predicate associated
      with the saver does not pass.
  """
  # Step 1: the saver named in the checkpoint must exist in the registry.
  try:
    _saver_registry.name_lookup(registered_name)
  except LookupError:
    not_found_message = (
        f"Error when restoring object {trackable} from checkpoint. This "
        "object was saved using a registered saver named "
        f"'{registered_name}', but this saver cannot be found in the "
        "current context.")
    raise ValueError(not_found_message)

  # Step 2: that saver's predicate must accept the object being restored.
  saver_predicate = _saver_registry.get_predicate(registered_name)
  if saver_predicate(trackable):
    return
  raise ValueError(
      f"Object {trackable} was saved with the registered saver named "
      f"'{registered_name}'. However, this saver cannot be used to restore the "
      "object because the predicate does not pass.")