blob: a94d41da88c7fac9ff58fce30569b2d72b33ac96 [file] [log] [blame]
"""Manages a graph of Trackable objects."""
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import collections
import copy
import weakref
from tensorflow.core.protobuf import trackable_object_graph_pb2
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.saved_model import registration
from tensorflow.python.training import optimizer as optimizer_v1
from tensorflow.python.training.saving import saveable_object as saveable_object_lib
from tensorflow.python.training.saving import saveable_object_util
from tensorflow.python.training.tracking import base
from tensorflow.python.util import object_identity
from tensorflow.python.util.tf_export import tf_export
_ESCAPE_CHAR = "." # For avoiding conflicts with user-specified names.
# Keyword for identifying that the next bit of a checkpoint variable name is a
# slot name. Checkpoint names for slot variables look like:
#
# <path to variable>/<_OPTIMIZER_SLOTS_NAME>/<path to optimizer>/<slot name>
#
# Where <path to variable> is a full path from the checkpoint root to the
# variable being slotted for.
_OPTIMIZER_SLOTS_NAME = _ESCAPE_CHAR + "OPTIMIZER_SLOT"
# Keyword for separating the path to an object from the name of an
# attribute in checkpoint names. Used like:
# <path to variable>/<_OBJECT_ATTRIBUTES_NAME>/<name of attribute>
_OBJECT_ATTRIBUTES_NAME = _ESCAPE_CHAR + "ATTRIBUTES"
# Factory and related info used to build a SaveableObject that saves a Trackable
# to checkpoint.
_CheckpointFactoryData = collections.namedtuple(
"_CheckpointFactoryData", ["factory", "name", "checkpoint_key"])
def _escape_local_name(name):
# We need to support slashes in local names for compatibility, since this
# naming scheme is being patched in to things like Layer.add_variable where
# slashes were previously accepted. We also want to use slashes to indicate
# edges traversed to reach the variable, so we escape forward slashes in
# names.
return (name.replace(_ESCAPE_CHAR, _ESCAPE_CHAR + _ESCAPE_CHAR)
.replace(r"/", _ESCAPE_CHAR + "S"))
def _object_prefix_from_path(node_paths):
return "/".join(
(_escape_local_name(trackable.name)
for trackable in node_paths))
def _slot_variable_naming_for_optimizer(optimizer_path):
"""Make a function for naming slot variables in an optimizer."""
# Name slot variables:
#
# <variable name>/<_OPTIMIZER_SLOTS_NAME>/<optimizer path>/<slot name>
#
# where <variable name> is exactly the checkpoint name used for the original
# variable, including the path from the checkpoint root and the local name in
# the object which owns it. Note that we only save slot variables if the
# variable it's slotting for is also being saved.
optimizer_identifier = "/%s/%s/" % (_OPTIMIZER_SLOTS_NAME, optimizer_path)
def _name_slot_variable(variable_path, slot_name):
"""With an optimizer specified, name a slot variable."""
return (variable_path
+ optimizer_identifier
+ _escape_local_name(slot_name))
return _name_slot_variable
def _serialize_slot_variables(trackable_objects, node_ids, object_names):
"""Gather and name slot variables."""
non_slot_objects = list(trackable_objects)
slot_variables = object_identity.ObjectIdentityDictionary()
for trackable in non_slot_objects:
if (isinstance(trackable, optimizer_v1.Optimizer)
# TODO(b/110718070): Fix Keras imports.
# Note: dir() is used rather than hasattr() here to avoid triggering
# custom __getattr__ code, see b/152031870 for context.
or "_create_or_restore_slot_variable" in dir(trackable)):
naming_scheme = _slot_variable_naming_for_optimizer(
optimizer_path=object_names[trackable])
slot_names = trackable.get_slot_names()
for slot_name in slot_names:
for original_variable_node_id, original_variable in enumerate(
non_slot_objects):
try:
slot_variable = trackable.get_slot(
original_variable, slot_name)
except (AttributeError, KeyError):
slot_variable = None
if slot_variable is None:
continue
slot_variable._maybe_initialize_trackable() # pylint: disable=protected-access
if slot_variable._checkpoint_dependencies: # pylint: disable=protected-access
# TODO(allenl): Gather dependencies of slot variables.
raise NotImplementedError(
"Currently only variables with no dependencies can be saved as "
"slot variables. File a feature request if this limitation "
"bothers you.")
if slot_variable in node_ids:
raise NotImplementedError(
"A slot variable was re-used as a dependency of a Trackable "
f"object: {slot_variable}. This is not currently allowed. "
"File a feature request if this limitation bothers you.")
checkpoint_name = naming_scheme(
variable_path=object_names[original_variable],
slot_name=slot_name)
object_names[slot_variable] = checkpoint_name
slot_variable_node_id = len(trackable_objects)
node_ids[slot_variable] = slot_variable_node_id
trackable_objects.append(slot_variable)
slot_variable_proto = (
trackable_object_graph_pb2.TrackableObjectGraph
.TrackableObject.SlotVariableReference(
slot_name=slot_name,
original_variable_node_id=original_variable_node_id,
slot_variable_node_id=slot_variable_node_id))
slot_variables.setdefault(trackable, []).append(
slot_variable_proto)
return slot_variables
def _get_mapped_trackable(trackable, object_map):
"""Returns the mapped trackable if possible, otherwise returns trackable."""
if object_map is None:
return trackable
else:
return object_map.get(trackable, trackable)
def get_checkpoint_factories_and_keys(object_names, object_map=None):
"""Gets a map of saveable factories and corresponding checkpoint keys.
Args:
object_names: a dictionary that maps `Trackable` objects to auto-generated
string names.
object_map: a dictionary mapping `Trackable` to copied `Trackable` objects.
The copied objects are generated from `Trackable._map_resources()` which
copies the object into another graph. Generally only resource objects
(e.g. Variables, Tables) will be in this map.
Returns:
A tuple of (
Dictionary mapping trackable -> list of _CheckpointFactoryData,
Dictionary mapping registered saver name -> {object name -> trackable})
"""
checkpoint_factory_map = object_identity.ObjectIdentityDictionary()
registered_savers = collections.defaultdict(dict)
for trackable, object_name in object_names.items():
# object_to_save is only used to retrieve the saving functionality. For keys
# and other data, use the original `trackable`.
object_to_save = _get_mapped_trackable(trackable, object_map)
saver_name = registration.get_registered_saver_name(object_to_save)
if saver_name:
registered_savers[saver_name][object_name] = trackable
else:
checkpoint_factory_map[trackable] = []
for name, saveable_factory in (
object_to_save._gather_saveables_for_checkpoint().items()): # pylint: disable=protected-access
checkpoint_key = "%s/%s/%s" % (
object_name, _OBJECT_ATTRIBUTES_NAME, _escape_local_name(name))
checkpoint_factory_map[trackable].append(_CheckpointFactoryData(
factory=saveable_factory,
name=name,
checkpoint_key=checkpoint_key))
return checkpoint_factory_map, registered_savers
def _add_attributes_to_object_graph_for_registered_savers(
registered_savers, object_graph_proto, node_ids):
"""Fills the object graph proto with data about the registered savers."""
for saver_name, trackables in registered_savers.items():
for object_name, trackable in trackables.items():
object_proto = object_graph_proto.nodes[node_ids[trackable]]
object_proto.registered_saver.name = saver_name
object_proto.registered_saver.object_name = object_name
@tf_export("__internal__.tracking.ObjectGraphView", v1=[])
class ObjectGraphView(object):
"""Gathers and serializes an object graph."""
def __init__(self, root, saveables_cache=None, attached_dependencies=None):
"""Configure the graph view.
Args:
root: A `Trackable` object whose variables (including the variables
of dependencies, recursively) should be saved. May be a weak reference.
saveables_cache: A dictionary mapping `Trackable` objects ->
attribute names -> SaveableObjects, used to avoid re-creating
SaveableObjects when graph building.
attached_dependencies: Dependencies to attach to the root object. Used
when saving a Checkpoint with a defined root object.
"""
self._root_ref = root
self._saveables_cache = saveables_cache
self._attached_dependencies = attached_dependencies
def __deepcopy__(self, memo):
if isinstance(self._root_ref, weakref.ref):
# By default, weak references are not copied, which leads to surprising
# deepcopy behavior. To fix, we first we copy the object itself, then we
# make a weak reference to the copy.
strong_root = self._root_ref()
if strong_root is not None:
strong_copy = copy.deepcopy(strong_root, memo)
memo[id(self._root_ref)] = weakref.ref(strong_copy)
# super() does not have a __deepcopy__, so we need to re-implement it
copied = super().__new__(type(self))
memo[id(self)] = copied
for key, value in vars(self).items():
setattr(copied, key, copy.deepcopy(value, memo))
return copied
def list_children(self, obj):
# pylint: disable=protected-access
obj._maybe_initialize_trackable()
dependencies = obj._checkpoint_dependencies
# pylint: enable=protected-access
if obj is self.root and self._attached_dependencies:
dependencies = dependencies.copy()
dependencies.extend(self._attached_dependencies)
return dependencies
@property
def saveables_cache(self):
"""Maps Trackable objects -> attribute names -> list(SaveableObjects).
Used to avoid re-creating SaveableObjects when graph building. None when
executing eagerly.
Returns:
The cache (an object-identity dictionary), or None if caching is disabled.
"""
return self._saveables_cache
@property
def attached_dependencies(self):
"""Returns list of dependencies that should be saved in the checkpoint.
These dependencies are not tracked by root, but are in the checkpoint.
This is defined when the user creates a Checkpoint with both root and kwargs
set.
Returns:
A list of TrackableReferences.
"""
return self._attached_dependencies
@property
def root(self):
if isinstance(self._root_ref, weakref.ref):
derefed = self._root_ref()
assert derefed is not None
return derefed
else:
return self._root_ref
def _breadth_first_traversal(self):
"""Find shortest paths to all dependencies of self.root."""
bfs_sorted = []
to_visit = collections.deque([self.root])
node_paths = object_identity.ObjectIdentityDictionary()
node_paths[self.root] = ()
while to_visit:
current_trackable = to_visit.popleft()
bfs_sorted.append(current_trackable)
for name, dependency in self.list_children(current_trackable):
if dependency not in node_paths:
node_paths[dependency] = (
node_paths[current_trackable] + (
base.TrackableReference(name, dependency),))
to_visit.append(dependency)
return bfs_sorted, node_paths
def _add_attributes_to_object_graph(
self, trackable_objects, object_graph_proto, node_ids, object_names,
object_map, call_with_mapped_captures):
"""Create saveables/savers and corresponding protos in the object graph."""
# The loop below creates TrackableObject protos in the TrackableObjectGraph,
# which are filled in the `_add_attributes_to_object_graph_for_*` methods.
for checkpoint_id, (trackable, unused_object_proto) in enumerate(
zip(trackable_objects, object_graph_proto.nodes)):
assert node_ids[trackable] == checkpoint_id
checkpoint_factory_map, registered_savers = (
get_checkpoint_factories_and_keys(object_names, object_map))
_add_attributes_to_object_graph_for_registered_savers(
registered_savers, object_graph_proto, node_ids)
named_saveable_objects, feed_additions = (
self._add_attributes_to_object_graph_for_saveable_objects(
checkpoint_factory_map, object_graph_proto, node_ids, object_map,
call_with_mapped_captures))
return named_saveable_objects, feed_additions, registered_savers
def _add_attributes_to_object_graph_for_saveable_objects(
self, checkpoint_factory_map, object_graph_proto, node_ids, object_map,
call_with_mapped_captures):
"""Create SaveableObjects and corresponding SerializedTensor protos."""
named_saveable_objects = []
if self._saveables_cache is None:
# No SaveableObject caching. Either we're executing eagerly, or building a
# static save which is specialized to the current Python state.
feed_additions = None
else:
# If we are caching SaveableObjects, we need to build up a feed_dict with
# functions computing volatile Python state to be saved with the
# checkpoint.
feed_additions = {}
for trackable, factory_data_list in checkpoint_factory_map.items():
object_proto = object_graph_proto.nodes[node_ids[trackable]]
if self._saveables_cache is not None:
object_to_save = _get_mapped_trackable(trackable, object_map)
cached_attributes = self._saveables_cache.setdefault(object_to_save, {})
else:
cached_attributes = None
for factory_data in factory_data_list:
attribute = object_proto.attributes.add()
attribute.name = name = factory_data.name
attribute.checkpoint_key = key = factory_data.checkpoint_key
saveable_factory = factory_data.factory
# See if we can skip saving this checkpoint key.
saveables = cached_attributes.get(name) if cached_attributes else None
if saveables is not None:
for saveable in saveables:
if key not in saveable.name:
# The checkpoint key for this SaveableObject is different. We
# need to re-create it.
saveables = None
del cached_attributes[name]
break
if saveables is None:
if callable(saveable_factory):
maybe_saveable = saveable_object_util.create_saveable_object(
saveable_factory, key, call_with_mapped_captures)
else:
maybe_saveable = saveable_factory
if isinstance(maybe_saveable, saveable_object_lib.SaveableObject):
saveables = (maybe_saveable,)
else:
# Figure out the name-based Saver's name for this variable. If it's
# already a SaveableObject we'd just get the checkpoint key back, so
# we leave full_name blank.
saver_dict = saveable_object_util.op_list_to_dict(
[maybe_saveable], convert_variable_to_tensor=False)
full_name, = saver_dict.keys()
saveables = tuple(saveable_object_util.saveable_objects_for_op(
op=maybe_saveable, name=key))
for saveable in saveables:
saveable.full_name = full_name
for saveable in saveables:
if key not in saveable.name:
raise AssertionError(
f"The object {trackable} produced a SaveableObject with name "
f"'{saveable.name}' for attribute '{name}'. Expected a name"
f" containing '{key}'.")
if cached_attributes is not None:
cached_attributes[name] = saveables
optional_restore = None
for saveable in saveables:
if optional_restore is None:
optional_restore = saveable.optional_restore
else:
optional_restore = optional_restore and saveable.optional_restore
if hasattr(saveable, "full_name"):
attribute.full_name = saveable.full_name
if isinstance(saveable, base.PythonStateSaveable):
if feed_additions is None:
assert self._saveables_cache is None
# If we're not caching saveables, then we're either executing
# eagerly or building a static save/restore (e.g. for a
# SavedModel). In either case, we should embed the current Python
# state in the graph rather than relying on a feed dict.
saveable = saveable.freeze()
else:
saveable_feed_dict = saveable.feed_dict_additions()
for new_feed_key in saveable_feed_dict.keys():
if new_feed_key in feed_additions:
raise AssertionError(
f"The object {trackable} tried to feed a value for the "
f"Tensor {new_feed_key} when saving, but another object "
"is already feeding a value.")
feed_additions.update(saveable_feed_dict)
named_saveable_objects.append(saveable)
if optional_restore is None:
optional_restore = False
attribute.optional_restore = optional_restore
return named_saveable_objects, feed_additions
def _fill_object_graph_proto(self, trackable_objects,
node_ids,
slot_variables,
object_graph_proto=None):
"""Name non-slot `Trackable`s and add them to `object_graph_proto`."""
if object_graph_proto is None:
object_graph_proto = (
trackable_object_graph_pb2.TrackableObjectGraph())
for checkpoint_id, trackable in enumerate(trackable_objects):
assert node_ids[trackable] == checkpoint_id
object_proto = object_graph_proto.nodes.add()
object_proto.slot_variables.extend(slot_variables.get(trackable, ()))
for child in self.list_children(trackable):
child_proto = object_proto.children.add()
child_proto.node_id = node_ids[child.ref]
child_proto.local_name = child.name
return object_graph_proto
def _serialize_gathered_objects(self, trackable_objects, node_paths,
object_map=None,
call_with_mapped_captures=None):
"""Create SaveableObjects and protos for gathered objects."""
object_names = object_identity.ObjectIdentityDictionary()
for obj, path in node_paths.items():
object_names[obj] = _object_prefix_from_path(path)
node_ids = object_identity.ObjectIdentityDictionary()
for node_id, node in enumerate(trackable_objects):
node_ids[node] = node_id
slot_variables = _serialize_slot_variables(
trackable_objects=trackable_objects,
node_ids=node_ids,
object_names=object_names)
object_graph_proto = self._fill_object_graph_proto(
trackable_objects=trackable_objects,
node_ids=node_ids,
slot_variables=slot_variables)
named_saveable_objects, feed_additions, registered_savers = (
self._add_attributes_to_object_graph(
trackable_objects=trackable_objects,
object_graph_proto=object_graph_proto,
node_ids=node_ids,
object_names=object_names,
object_map=object_map,
call_with_mapped_captures=call_with_mapped_captures))
return (named_saveable_objects, object_graph_proto, feed_additions,
registered_savers)
def serialize_object_graph(self):
"""Determine checkpoint keys for variables and build a serialized graph.
Non-slot variables are keyed based on a shortest path from the root saveable
to the object which owns the variable (i.e. the one which called
`Trackable._add_variable` to create it).
Slot variables are keyed based on a shortest path to the variable being
slotted for, a shortest path to their optimizer, and the slot name.
Returns:
A tuple of (named_variables, object_graph_proto, feed_additions):
named_variables: A dictionary mapping names to variable objects.
object_graph_proto: A TrackableObjectGraph protocol buffer
containing the serialized object graph and variable references.
feed_additions: A dictionary mapping from Tensors to values which should
be fed when saving.
Raises:
ValueError: If there are invalid characters in an optimizer's slot names.
"""
named_saveable_objects, object_graph_proto, feed_additions, _ = (
self.serialize_object_graph_with_registered_savers())
return named_saveable_objects, object_graph_proto, feed_additions
def serialize_object_graph_with_registered_savers(self):
"""Determine checkpoint keys for variables and build a serialized graph."""
trackable_objects, node_paths = self._breadth_first_traversal()
return self._serialize_gathered_objects(
trackable_objects, node_paths)
def frozen_saveable_objects(self, object_map=None, to_graph=None,
call_with_mapped_captures=None):
"""Creates SaveableObjects with the current object graph frozen."""
return self.frozen_saveables_and_savers(object_map, to_graph,
call_with_mapped_captures)[0]
def frozen_saveables_and_savers(self, object_map=None, to_graph=None,
call_with_mapped_captures=None):
"""Generates SaveableObjects and registered savers in the frozen graph."""
trackable_objects, node_paths = self._breadth_first_traversal()
if to_graph:
target_context = to_graph.as_default
else:
target_context = ops.NullContextmanager
with target_context():
named_saveable_objects, graph_proto, _, registered_savers = (
self._serialize_gathered_objects(trackable_objects,
node_paths,
object_map,
call_with_mapped_captures))
with ops.device("/cpu:0"):
object_graph_tensor = constant_op.constant(
graph_proto.SerializeToString(), dtype=dtypes.string)
named_saveable_objects.append(
base.NoRestoreSaveable(
tensor=object_graph_tensor,
name=base.OBJECT_GRAPH_PROTO_KEY))
return named_saveable_objects, registered_savers
def objects_ids_and_slot_variables_and_paths(self):
"""Traverse the object graph and list all accessible objects.
Looks for `Trackable` objects which are dependencies of
`root_trackable`. Includes slot variables only if the variable they are
slotting for and the optimizer are dependencies of `root_trackable`
(i.e. if they would be saved with a checkpoint).
Returns:
A tuple of (trackable objects, paths from root for each object,
object -> node id, slot variables, object_names)
"""
trackable_objects, node_paths = self._breadth_first_traversal()
object_names = object_identity.ObjectIdentityDictionary()
for obj, path in node_paths.items():
object_names[obj] = _object_prefix_from_path(path)
node_ids = object_identity.ObjectIdentityDictionary()
for node_id, node in enumerate(trackable_objects):
node_ids[node] = node_id
slot_variables = _serialize_slot_variables(
trackable_objects=trackable_objects,
node_ids=node_ids,
object_names=object_names)
return (trackable_objects, node_paths, node_ids, slot_variables,
object_names)
def objects_ids_and_slot_variables(self):
trackable_objects, _, node_ids, slot_variables, _ = (
self.objects_ids_and_slot_variables_and_paths())
return trackable_objects, node_ids, slot_variables
def list_objects(self):
"""Traverse the object graph and list all accessible objects."""
trackable_objects, _, _ = self.objects_ids_and_slot_variables()
return trackable_objects