# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Import a trackable object from a SavedModel."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import os
import sys
from tensorflow.core.protobuf import graph_debug_info_pb2
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import values_util
from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.saved_model import function_deserialization
from tensorflow.python.saved_model import load_options
from tensorflow.python.saved_model import load_v1_in_v2
from tensorflow.python.saved_model import loader_impl
from tensorflow.python.saved_model import nested_structure_coder
from tensorflow.python.saved_model import revived_types
from tensorflow.python.saved_model import utils_impl as saved_model_utils
from tensorflow.python.training.saving import checkpoint_options
from tensorflow.python.training.saving import saveable_object_util
from tensorflow.python.training.tracking import base
from tensorflow.python.training.tracking import data_structures
from tensorflow.python.training.tracking import graph_view
from tensorflow.python.training.tracking import tracking
from tensorflow.python.training.tracking import util
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


def _unused_handle():
  """Returns a placeholder as a handle that is not supposed to be accessed."""
  error_message = ("Trying to access a placeholder that is not supposed to be "
                   "executed. This means you are executing a graph generated "
                   "from the cross-replica context in an in-replica context.")
  # The predicate defaults to False, so the Assert fails (with the message
  # above) if the placeholder is ever actually executed.
  assert_op = control_flow_ops.Assert(
      array_ops.placeholder_with_default(False, shape=()),
      [error_message])

  with ops.control_dependencies([assert_op]):
    return array_ops.placeholder(dtype=dtypes.resource)


class _WrapperFunction(function.ConcreteFunction):
  """Wraps a concrete function to handle different distributed contexts.

  The reason for wrapping a concrete function is that the _captured_inputs
  fields used for the in-replica context and the cross-replica context are
  different. When `load()` is called from within a tf.distribute.strategy
  scope, the captured inputs are distributed variables. When using these
  distributed variables while calling the function, we need different
  approaches depending on whether we are in-replica. When in-replica, we
  naturally should use the corresponding component of the distributed
  variable; when not in-replica, calling the function should mean that it is
  constructing a graph that is not actually going to be used. A typical use
  case is when constructing a functional model. In this case, return a
  placeholder with a control dependency to ensure that it is never accessed.
  """

  def __init__(self, concrete_function):
    # Shallow copy the concrete_function
    self.__dict__.update(vars(concrete_function))

  def _call_flat(self, args, captured_inputs, cancellation_manager=None):

    def get_handle(x):
      return x.handle if distribute_utils.is_distributed_variable(x) else x

    def get_unused_handle(x):
      return _unused_handle() if distribute_utils.is_distributed_variable(x) \
          else x

    if (ds_context.get_replica_context() is not None or
        values_util.is_saving_non_distributed()):
      # If we're in the replica context or are saving a non-distributed version
      # of the model, we resolve the captured variables to the corresponding
      # resource handle. In both situations we call var.handle, but it has
      # different behavior. In the replica context, var.handle resolves to the
      # replica-local variable handle if the variable is replicated. When
      # saving a non-distributed version of the model, var.handle resolves to
      # the primary variable handle, since we only save one copy of a
      # replicated variable.
      captured_inputs = list(map(get_handle, captured_inputs))
    else:  # cross-replica context
      captured_inputs = list(map(get_unused_handle, captured_inputs))
    return super(_WrapperFunction, self)._call_flat(args, captured_inputs,
                                                    cancellation_manager)
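

# A minimal sketch (not used by the loader, and not part of TensorFlow's API;
# the helper name is ours): restates how `_WrapperFunction._call_flat` above
# resolves a single captured distributed variable in each context.
def _illustrate_capture_resolution(captured_variable):
  """Illustration only: mirrors the branch logic in `_WrapperFunction`."""
  if (ds_context.get_replica_context() is not None or
      values_util.is_saving_non_distributed()):
    # In-replica (or saving a non-distributed copy): a concrete handle.
    return captured_variable.handle
  # Cross-replica: the graph is only being traced, so return a guarded
  # placeholder that fails loudly if it is ever executed.
  return _unused_handle()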


class Loader(object):
  """Helper class to load an object-based SavedModel."""

  def __init__(self, object_graph_proto, saved_model_proto, export_dir,
               ckpt_options, filters):
    meta_graph = saved_model_proto.meta_graphs[0]
    self._asset_file_def = meta_graph.asset_file_def
    self._operation_attributes = {
        node.name: node.attr for node in meta_graph.graph_def.node}
    self._proto = object_graph_proto
    self._export_dir = export_dir
    self._concrete_functions = (
        function_deserialization.load_function_def_library(
            meta_graph.graph_def.library))
    self._checkpoint_options = ckpt_options

    # Stores user-defined node_filters argument.
    self._node_filters = filters
    # Stores map of string paths to integers.
    self._node_path_to_id = self._convert_node_paths_to_ints()
    self._loaded_nodes = {}
    if isinstance(filters, dict):
      # If node_filters is a dict, then the values may contain already created
      # trackable objects. In this case, create a dictionary mapping node IDs
      # to the already created nodes. This dict will be updated in
      # `_retrieve_all_filtered_nodes` with tracked dependencies.
      for node_path, node in filters.items():
        if isinstance(node, tuple):
          self._loaded_nodes[self._node_path_to_id[node_path]] = node
        else:
          self._loaded_nodes[self._node_path_to_id[node_path]] = (node, setattr)

    # Get a list of all integer node ids to load, or None if all nodes should
    # be loaded. This list includes ids of child nodes.
    self._filtered_nodes = self._retrieve_all_filtered_nodes()

    for name, concrete_function in self._concrete_functions.items():
      # Wrap all the concrete functions so that they are capable of dealing
      # with both in-replica and cross-replica cases.
      self._concrete_functions[name] = _WrapperFunction(concrete_function)

    self._load_all()
    self._restore_checkpoint()

    for node in self._nodes:
      if isinstance(node, tracking.CapturableResource):
        init_op = node._initialize()  # pylint: disable=protected-access
        if not context.executing_eagerly():
          ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)

  def _convert_node_paths_to_ints(self):
    """Maps all string node paths in node_filters to the int node ids."""
    if self._node_filters is None:
      return None
    path_to_int = {}
    for node_id in self._node_filters:
      int_node_id = None
      if isinstance(node_id, str):
        node_path = node_id.split(".")
        if node_path[0] != "root":
          raise ValueError(
              "When passing string identifiers to node_filters, the first name"
              " must be root.")
        int_node_id = 0
        for n, name in enumerate(node_path[1:]):
          int_node_id = self._find_node_child(
              int_node_id, name, ".".join(node_path[:n+2]))
        path_to_int[node_id] = int_node_id
      else:
        raise TypeError("Elements in node_filters must be strings.")
    return path_to_int
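
  # For example (hypothetical object graph): the filter string
  # "root.child_layer.v" resolves by starting at node 0 ("root"), following
  # the child reference named "child_layer", then that node's child "v",
  # producing an entry such as {"root.child_layer.v": 7} (the id 7 is
  # illustrative).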

  def _retrieve_all_filtered_nodes(self):
    """Traverses the object graph to get the IDs of all nodes to load.

    As a side effect, if node_filters is a dictionary that contains
    already-created objects, then the dependencies tracked by those objects
    will be added to node_filters.

    Returns:
      List of all nodes to load, or None if all nodes should be loaded.
    """
    if self._node_filters is None:
      return None  # All nodes should be loaded.

    all_filtered_nodes = set()
    nodes_to_visit = list(self._node_filters)

    while nodes_to_visit:
      node_path = nodes_to_visit.pop(0)
      node_id = self._node_path_to_id[node_path]
      if node_id in all_filtered_nodes:
        continue
      all_filtered_nodes.add(node_id)

      node, setter = self._loaded_nodes.get(node_id, (None, None))
      if node is not None:
        if not isinstance(node, base.Trackable):
          raise TypeError(
              "Error when processing dictionary values passed to "
              "nodes_to_load. Object at {} is expected to be a checkpointable "
              "TensorFlow object (e.g. tf.Variable, tf.Module or Keras layer)."
              .format(node_path))
        node._maybe_initialize_trackable()  # pylint: disable=protected-access

      for reference in self._proto.nodes[node_id].children:
        child_object, _ = self._loaded_nodes.get(
            reference.node_id, (None, None))

        # See if node already tracks the child reference, in which case add
        # the child to the loaded_nodes dict.
        if child_object is None and node is not None:
          child_object = node._lookup_dependency(reference.local_name)  # pylint: disable=protected-access
          if isinstance(child_object, data_structures.TrackableDataStructure):
            # Make setattr a noop to avoid overwriting already existing data
            # structures.
            setter = lambda *args: None
          self._loaded_nodes[reference.node_id] = (child_object, setter)

        child_path = "{}.{}".format(node_path, reference.local_name)
        self._node_path_to_id[child_path] = reference.node_id
        nodes_to_visit.append(child_path)

    if 0 in all_filtered_nodes:
      return None
    return all_filtered_nodes
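
  # E.g. (hypothetical): filtering on {"root.child": my_module} also queues
  # "root.child.v" and every other tracked child of `my_module`, so a filter
  # always selects a closed subgraph rather than a single detached node.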

  def _find_node_child(self, node_id, child_name, path):
    for reference in self._proto.nodes[node_id].children:
      if reference.local_name == child_name:
        return reference.node_id
    raise ValueError("unable to find node {}".format(path))

  def _load_all(self):
    """Loads all nodes and functions from the SavedModel and their edges."""
    self._load_nodes()
    self._load_edges()
    # TODO(b/124045874): There are limitations with functions whose captures
    # trigger other functions to be executed. For now it is only guaranteed to
    # work if the captures of a function only trigger functions without
    # captures.
    self._setup_functions_structures()
    self._setup_functions_captures()

    self._create_saveable_object_factories()

  def _create_saveable_object_factories(self):
    for node_id, proto in self._iter_all_nodes():
      node = self.get(node_id)
      node._self_saveable_object_factories = {}  # pylint: disable=protected-access
      for name, saveable_object_proto in proto.saveable_objects.items():
        node._self_saveable_object_factories[name] = (  # pylint: disable=protected-access
            saveable_object_util.restored_saved_object_factory(
                self.get(saveable_object_proto.save_function),
                self.get(saveable_object_proto.restore_function)))

  def _load_edges(self):
    """Adds edges from objects to other objects and functions."""
    for node_id, object_proto in self._iter_all_nodes():
      self._add_object_graph_edges(object_proto, node_id)

    # If root object isn't loaded, then create edges from the root for
    # checkpoint compatibility.
    if self._filtered_nodes is not None and 0 not in self._filtered_nodes:
      root = self.get(0)
      for node_path in self._node_filters:
        loaded_node = self._nodes[self._node_path_to_id[node_path]]
        path = node_path.split(".")
        current_node = root
        # Walk down to the parent of the filtered node, creating dummy user
        # objects for any intermediate attributes that were not loaded.
        for name in path[1:-1]:
          if not hasattr(current_node, name):
            setattr(current_node, name, self._recreate_base_user_object()[0])
          current_node = getattr(current_node, name)
        if not hasattr(current_node, path[-1]):
          setattr(current_node, path[-1], loaded_node)

  def _add_object_graph_edges(self, proto, node_id):
    """Adds edges from an object to its children."""
    obj = self._nodes[node_id]
    setter = self._node_setters[node_id]

    for reference in proto.children:
      setter(obj, reference.local_name, self._nodes[reference.node_id])
      # Note: if an object has an attribute `__call__` add a class method
      # that allows `obj()` syntax to work. This is done per-instance to
      # allow `callable` to be used to find out if an object is callable.
      if reference.local_name == "__call__" and not callable(obj):
        setattr(type(obj), "__call__", _call_attribute)

  def _setup_functions_structures(self):
    """Setup structure for inputs and outputs of restored functions."""
    coder = nested_structure_coder.StructureCoder()
    for name, proto in sorted(self._proto.concrete_functions.items()):
      concrete_function = self._concrete_functions[name]
      # By setting the structured_outputs directly, we can rely on this
      # function_lib.ConcreteFunction object to perform the output repacking
      # logic. The only limitation of that logic is that it only works
      # with output that is convertible to Tensors and the conversion
      # always happens. For example tf.TensorShape([2, 3]) will be
      # converted to Tensor representing [2, 3].
      original_outputs = coder.decode_proto(proto.output_signature)
      # The original_outputs here had Tensors converted to TensorSpecs, so
      # the restored function's structured_outputs field will not be
      # exactly the same. Fortunately the repacking logic cares only about
      # the structure; and the unpacking logic cares only about structure
      # and types.
      concrete_function._func_graph.structured_outputs = original_outputs  # pylint: disable=protected-access
      concrete_function._func_graph.structured_input_signature = (  # pylint: disable=protected-access
          coder.decode_proto(proto.canonicalized_input_signature))
      concrete_function._initialize_function_spec()  # pylint: disable=protected-access

  def _setup_functions_captures(self):
    """Setup captures and variables in restored functions."""
    concrete_functions = sorted(self._proto.concrete_functions.items())
    for name, proto in concrete_functions:
      concrete_function = self._concrete_functions[name]
      bound_inputs = [
          self._get_tensor_from_node(node_id, name)
          for node_id in proto.bound_inputs]
      bound_variables = [
          self._nodes[node_id]
          for node_id in proto.bound_inputs
          if self._proto.nodes[node_id].WhichOneof("kind") == "variable"
      ]
      # TODO(andresp): This is only injecting the captured inputs into the
      # concrete function, note that we did not modify the FuncGraph
      # itself.
      concrete_function._captured_inputs = bound_inputs  # pylint: disable=protected-access
      concrete_function._func_graph.variables = bound_variables  # pylint: disable=protected-access
      if bound_inputs:
        for bound_input, internal_capture in zip(
            bound_inputs, concrete_function.inputs[-len(bound_inputs):]):
          if distribute_utils.is_distributed_variable(bound_input):
            concrete_function.graph.capture_distributed_variable(
                bound_input, internal_capture)
          else:
            concrete_function.graph.replace_capture(bound_input,
                                                    internal_capture)
            if internal_capture.dtype == dtypes.resource:
              if resource_variable_ops.is_resource_variable(bound_input):
                try:
                  handle = bound_input.handle
                except ValueError:
                  # For mirrored variables we'll copy handle data for
                  # components as they get captured.
                  pass
                else:
                  custom_gradient.copy_handle_data(handle, internal_capture)
              else:
                custom_gradient.copy_handle_data(bound_input, internal_capture)
            # Setting "captures" first means "capture" won't create a new
            # placeholder for this input.
            concrete_function.graph.capture(bound_input)

  def _get_tensor_from_node(self, node_id, fn_name):
    """Resolves a node id into a tensor to be captured for a function."""
    if self._node_filters is not None and self._nodes[node_id] is None:
      raise ValueError(
          "Error when processing nodes_to_load. Function \"{}\" requires "
          "inputs/variables that are not loaded when nodes_to_load={}"
          .format(fn_name, self._node_filters))

    with ops.init_scope():
      obj = self._nodes[node_id]
      if distribute_utils.is_distributed_variable(obj):
        return obj
      elif resource_variable_ops.is_resource_variable(obj):
        return obj.handle
      elif isinstance(obj, tracking.Asset):
        return obj.asset_path
      elif tensor_util.is_tensor(obj):
        return obj
      elif isinstance(obj, tracking.CapturableResource):
        # Note: this executes restored functions in the CapturableResource.
        return obj.resource_handle
      raise ValueError("Can't convert node %s to tensor" % (type(obj)))

  def _initialize_loaded_nodes(self):
    nodes = {}
    node_setters = {}
    for node_id, (node, setter) in self._loaded_nodes.items():
      nodes[node_id] = node
      node_setters[node_id] = setter
    return nodes, node_setters

  def _iter_all_nodes(self):
    if self._filtered_nodes is None:
      return enumerate(self._proto.nodes)
    else:
      return [(node_id, self._proto.nodes[node_id])
              for node_id in self._filtered_nodes]

  def _load_nodes(self):
    """Load all saved objects."""
    # `nodes` maps from node ids to recreated objects
    # `node_setters` maps from node ids to setter functions
    # (same signature as setattr) for setting dependencies.
    nodes, node_setters = self._initialize_loaded_nodes()

    # Figure out which objects are slot variables. These objects are created
    # with Optimizer.add_slot rather than _recreate_variable.
    slot_variable_node_ids = set()
    for _, proto in self._iter_all_nodes():
      for slot_variable_proto in proto.slot_variables:
        slot_variable_node_ids.add(slot_variable_proto.slot_variable_node_id)

    # Re-create everything except slot variables.
    for node_id, proto in self._iter_all_nodes():
      if node_id in slot_variable_node_ids or nodes.get(node_id) is not None:
        # Defer recreating slot variables so we can use the public Optimizer
        # interface.
        continue
      node, setter = self._recreate(proto, node_id)
      nodes[node_id] = node
      node_setters[node_id] = setter

    # Now that we have created the variables being optimized, we have enough
    # information to re-create slot variables for them.
    for node_id, proto in self._iter_all_nodes():
      optimizer_object = nodes[node_id]
      for slot_variable_proto in proto.slot_variables:
        optimized_variable = nodes[
            slot_variable_proto.original_variable_node_id]
        slot_variable = optimizer_object.add_slot(
            var=optimized_variable,
            slot_name=slot_variable_proto.slot_name)
        nodes[slot_variable_proto.slot_variable_node_id] = slot_variable
        node_setters[slot_variable_proto.slot_variable_node_id] = setattr

    # If root object is not loaded, add a dummy root object for checkpoint
    # compatibility.
    if 0 not in nodes:
      nodes[0] = self._recreate_base_user_object()[0]

    self._nodes = [nodes.get(node_id)
                   for node_id in range(len(self._proto.nodes))]
    self._node_setters = node_setters

  @property
  def _expect_partial_checkpoint(self):
    """Whether to expect that some objects aren't loaded.

    This should be set to True in subclasses of the Loader class which generate
    a trackable object with an object graph that is different from the graph
    in the SavedModel. Setting this property to True suppresses the warnings
    that are printed out when there are unused parts of the checkpoint or
    object.

    Returns:
      boolean
    """
    return False
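
  # Hypothetical subclass sketch (illustration only; `_MyLoader` is not a real
  # class here): a loader that rebuilds a different object graph would opt out
  # of the unused-object warnings like so.
  #
  #   class _MyLoader(Loader):
  #
  #     @property
  #     def _expect_partial_checkpoint(self):
  #       return True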

  def _restore_checkpoint(self):
    """Load state from checkpoint into the deserialized objects."""
    variables_path = saved_model_utils.get_variables_path(self._export_dir)
    # TODO(andresp): Clean use of private methods of TrackableSaver.
    # pylint: disable=protected-access
    saver = util.TrackableSaver(graph_view.ObjectGraphView(self.get(0)))
    with ops.device("CPU"):
      saver._file_prefix_placeholder = constant_op.constant(variables_path)
    if self._expect_partial_checkpoint:
      load_status = saver.restore(variables_path,
                                  self._checkpoint_options).expect_partial()
    else:
      load_status = saver.restore(variables_path, self._checkpoint_options)
    load_status.assert_existing_objects_matched()
    checkpoint = load_status._checkpoint

    # When running in eager mode, the `restore` call above has already run and
    # restored the state of trackables, so calling `position.restore_ops()`
    # will return an empty list as there is nothing left to do. In graph mode,
    # it will return the list of ops that must run to restore the object at
    # that position. We have to wire them into the initializers of the objects
    # so that they get initialized properly when using common practices (e.g.
    # the ones used by ManagedSession) without further user action.
    for object_id, obj in dict(checkpoint.object_by_proto_id).items():
      position = base.CheckpointPosition(checkpoint=checkpoint,
                                         proto_id=object_id)
      restore_ops = position.restore_ops()
      if restore_ops:
        if resource_variable_ops.is_resource_variable(obj):
          if len(restore_ops) == 1:
            obj._initializer_op = restore_ops[0]
          else:
            obj._initializer_op = control_flow_ops.group(*restore_ops)
        elif isinstance(obj, lookup_ops.LookupInterface):
          # We don't need to check for eager execution here, since this code
          # path should only be taken if we are restoring in graph mode.
          ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, restore_ops)
        else:
          raise NotImplementedError(
              ("Missing functionality to restore state of object "
               "%r from the checkpoint." % obj))

  def adjust_debug_info_func_names(self, debug_info):
    """Rewrite func names in the debug info by using the concrete func names."""
    output_debug_info = graph_debug_info_pb2.GraphDebugInfo()
    output_debug_info.files[:] = debug_info.files
    for key in debug_info.traces:
      node, func = key.split("@")
      new_func = ""
      if func in self._concrete_functions:
        new_func = self._concrete_functions[func].function_def.signature.name
      output_debug_info.traces[node + "@" + new_func].CopyFrom(
          debug_info.traces[key])
    return output_debug_info
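
  # For example (hypothetical names): a trace keyed "MatMul@my_fn" in the
  # saved debug info is re-keyed to "MatMul@__inference_my_fn_123" when the
  # freshly deserialized concrete function carries that signature name.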

  def get(self, node_id):
    if isinstance(node_id, str):
      node_id = self._node_path_to_id[node_id]
    return self._nodes[node_id]

  def _recreate(self, proto, node_id):
    """Creates a Python object from a SavedObject protocol buffer."""
    factory = {
        "user_object": (
            lambda: self._recreate_user_object(proto.user_object, node_id)),
        "asset": lambda: self._recreate_asset(proto.asset),
        "function": lambda: self._recreate_function(proto.function),
        "bare_concrete_function": functools.partial(
            self._recreate_bare_concrete_function,
            proto.bare_concrete_function),
        "variable": lambda: self._recreate_variable(proto.variable),
        "constant": lambda: self._recreate_constant(proto.constant),
        "resource": lambda: self._recreate_resource(proto.resource),
    }
    kind = proto.WhichOneof("kind")
    if kind not in factory:
      raise ValueError("Unknown SavedObject type: %r" % kind)
    return factory[kind]()
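
  # E.g. a SavedObject whose `kind` oneof is set to "variable" dispatches to
  # `_recreate_variable(proto.variable)` above; an unset or unrecognized kind
  # raises the ValueError instead.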

  def _recreate_user_object(self, proto, node_id):
    """Instantiates a SavedUserObject."""
    looked_up = revived_types.deserialize(proto)
    if looked_up is None:
      return self._recreate_base_user_object(proto, node_id)
    return looked_up

  def _recreate_base_user_object(self, proto=None, node_id=None):
    del proto, node_id
    # Note: each user object has its own class. This allows making each one
    # individually callable by adding a `__call__` method to the classes of
    # the object instances that have a `__call__` property.

    class _UserObject(tracking.AutoTrackable):
      pass

    return _UserObject(), setattr

  def _recreate_asset(self, proto):
    filename = os.path.join(
        saved_model_utils.get_assets_dir(self._export_dir),
        self._asset_file_def[proto.asset_file_def_index].filename)
    return tracking.Asset(filename), setattr

  def _recreate_function(self, proto):
    return function_deserialization.recreate_function(
        proto, self._concrete_functions), setattr

  def _recreate_bare_concrete_function(self, proto):
    return function_deserialization.setup_bare_concrete_function(
        proto, self._concrete_functions), setattr

  def _recreate_variable(self, proto):
    name = proto.name if proto.name else None
    if name is not None:
      dbg_name = name
    else:
      dbg_name = "<variable loaded from saved model>"
    synchronization, aggregation, trainable = (
        variables.validate_synchronization_aggregation_trainable(
            proto.synchronization, proto.aggregation, proto.trainable,
            name=dbg_name))

    def uninitialized_variable_creator(next_creator, **kwargs):
      """A variable creator that creates uninitialized variables."""
      del next_creator
      return resource_variable_ops.UninitializedVariable(**kwargs)

    # Create a variable_creator_scope that creates uninitialized variables with
    # a lower priority such that a potential distributed variable_creator_scope
    # can take precedence.
    with ops.get_default_graph()._variable_creator_scope(  # pylint: disable=protected-access
        uninitialized_variable_creator,
        priority=50):
      return variables.Variable(
          shape=proto.shape,
          dtype=proto.dtype,
          name=name,
          trainable=trainable,
          synchronization=synchronization,
          aggregation=aggregation), setattr
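
  # Note: the variable is created uninitialized here; its value arrives later
  # via `_restore_checkpoint`. Under a tf.distribute strategy scope, the
  # strategy's variable creator outranks the priority=50 creator above, which
  # is how a restored variable can come back as a distributed variable.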

  def _recreate_constant(self, proto):
    tensor_proto = self._operation_attributes[proto.operation]["value"].tensor
    ndarray = tensor_util.MakeNdarray(tensor_proto)
    if dtypes.as_dtype(tensor_proto.dtype) == dtypes.string:
      with ops.device("CPU"):
        imported_constant = constant_op.constant(ndarray)
    else:
      imported_constant = constant_op.constant(ndarray)
    return imported_constant, setattr

  def _recreate_resource(self, proto):
    return _RestoredResource(device=proto.device), setattr


# TODO(b/124205571,b/124092991): Solve destruction of resources.
class _RestoredResource(tracking.TrackableResource):
  """Restored SavedResource."""

  def __init__(self, device=""):
    super(_RestoredResource, self).__init__(device=device)
    self._destroy_resource_fn = None

  # `_create_resource` and `_initialize` raise unless replaced: the concrete
  # functions restored from the SavedModel are attached as attributes during
  # edge loading, shadowing these methods.

  def _create_resource(self):
    raise RuntimeError()

  def _initialize(self):
    raise RuntimeError()

  @property
  def _destroy_resource(self):
    return self._destroy_resource_fn

  @_destroy_resource.setter
  def _destroy_resource(self, destroy_resource_fn):
    self._resource_deleter = tracking.CapturableResourceDeleter(
        destroy_resource_fn)
    self._destroy_resource_fn = destroy_resource_fn

  def _list_functions_for_serialization(self, unused_serialization_cache):
    # Override this method to keep the base class implementation from
    # re-wrapping the polymorphic functions in another layer of `tf.function`.
    functions = {
        "_create_resource": self._create_resource,
        "_initialize": self._initialize,
    }
    if self._destroy_resource:
      functions.update(_destroy_resource=self._destroy_resource)
    return functions


def _call_attribute(instance, *args, **kwargs):
  return instance.__call__(*args, **kwargs)
@tf_export("__internal__.saved_model.load_partial", v1=[])
def load_partial(export_dir, filters, tags=None, options=None):
"""Partially load a SavedModel (saved from V2).
Similar to `tf.saved_model.load`, but with an additional argument that
lets you specify which nodes to load.
`tf.saved_model.load_partial(export_dir, ["root"])` and
`tf.saved_model.load(export_dir)` are equivalent.
Note: This only works for SavedModels saved with TensorFlow V2 from
`tf.saved_model.save` or Keras. This will not load SavedModels save from
the Estimator API.
In Tensorflow V2, SavedModel stores the **object graph** of the saved object.
The graph contains nodes (`tf.Module`, `tf.Variable`, `tf.function`, Keras
layers, etc.) and edges that are the name of the attributes connecting the
objects.
*Example 1*
```
model = tf.Module()
model.child_layer = tf.Module()
model.child_layer.v = tf.Variable(5.)
tf.saved_model.save(model, '/tmp/model')
loaded = tf.__internal__.saved_model.load_partial(
... '/tmp/model',
... ['root.child_layer', 'root.child_layer.v'])
loaded['root.child_layer'].v.numpy()
5.
loaded['root.child_layer'].v is loaded['root.child_layer.v']
True
*Example 2*
model = tf.Module()
model.child_layer = tf.Module()
model.child_layer.v = tf.Variable(5.)
>>>
tf.saved_model.save(model, '/tmp/model')
# Create a variable
new_variable = tf.Variable(0.)
loaded = tf.__internal__.saved_model.load_partial(
... '/tmp/model',
... {'root.child_layer': None, 'root.child_layer.v': new_variable})
loaded['root.child_layer'].v.numpy()
5.
new_variable.numpy()
5.
```
**Loading under different distribution strategies**
You can load different parts of the model under different distribution
strategies. Note that this is very experimental so use with care.
```
model = tf.Module()
model.layer_1 = tf.Module()
model.layer_1.v = tf.Variable(5.)
model.layer_2 = tf.Module()
model.layer_2.v = tf.Variable(7.)
tf.saved_model.save(model, '/tmp/model')
# Load with no strategy
loaded = tf.__internal__.saved_model.load_partial(
... '/tmp/model',
... ['root.layer_1'])
loaded['root.layer_1'].v
<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=5.0>
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
... loaded2 = tf.__internal__.saved_model.load_partial(
... '/tmp/model',
... ['root.layer_2'])
loaded2['root.layer_2'].v
MirroredVariable:{
0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=7.0>
}
```
Args:
export_dir: The SavedModel directory to load from.
filters: A list or dictionary where each element or key is a string
path to nodes that should be loaded. Node paths consist of all the child
attribute names to reach that node in the form: `root.{attribute_name}`.
The loader will load all of the specified nodes and their recursive
descendants. When this option is defined, the loader will return a
dictionary mapping the node paths to the loaded objects.
tags: A tag or sequence of tags identifying the MetaGraph to load. Optional
if the SavedModel contains a single MetaGraph, as for those exported from
`tf.saved_model.save`.
options: `tf.saved_model.LoadOptions` object that specifies options for
loading.
Returns:
A dictionary mapping node paths from the filter to loaded objects.
"""
return load_internal(export_dir, tags, options, filters=filters)
@tf_export("saved_model.load", v1=["saved_model.load_v2"])
def load(export_dir, tags=None, options=None):
"""Load a SavedModel from `export_dir`.
Signatures associated with the SavedModel are available as functions:
```python
imported = tf.saved_model.load(path)
f = imported.signatures["serving_default"]
print(f(x=tf.constant([[1.]])))
```
Objects exported with `tf.saved_model.save` additionally have trackable
objects and functions assigned to attributes:
```python
exported = tf.train.Checkpoint(v=tf.Variable(3.))
exported.f = tf.function(
lambda x: exported.v * x,
input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
tf.saved_model.save(exported, path)
imported = tf.saved_model.load(path)
assert 3. == imported.v.numpy()
assert 6. == imported.f(x=tf.constant(2.)).numpy()
```
_Loading Keras models_
Keras models are trackable, so they can be saved to SavedModel. The object
returned by `tf.saved_model.load` is not a Keras object (i.e. doesn't have
`.fit`, `.predict`, etc. methods). A few attributes and functions are still
available: `.variables`, `.trainable_variables` and `.__call__`.
```python
model = tf.keras.Model(...)
tf.saved_model.save(model, path)
imported = tf.saved_model.load(path)
outputs = imported(inputs)
```
Use `tf.keras.models.load_model` to restore the Keras model.
_Importing SavedModels from TensorFlow 1.x_
SavedModels from `tf.estimator.Estimator` or 1.x SavedModel APIs have a flat
graph instead of `tf.function` objects. These SavedModels will be loaded with
the following attributes:
* `.signatures`: A dictionary mapping signature names to functions.
* `.prune(feeds, fetches) `: A method which allows you to extract
functions for new subgraphs. This is equivalent to importing the SavedModel
and naming feeds and fetches in a Session from TensorFlow 1.x.
```python
imported = tf.saved_model.load(path_to_v1_saved_model)
pruned = imported.prune("x:0", "out:0")
pruned(tf.ones([]))
```
See `tf.compat.v1.wrap_function` for details.
* `.variables`: A list of imported variables.
* `.graph`: The whole imported graph.
* `.restore(save_path)`: A function that restores variables from a checkpoint
saved from `tf.compat.v1.Saver`.
_Consuming SavedModels asynchronously_
When consuming SavedModels asynchronously (the producer is a separate
process), the SavedModel directory will appear before all files have been
written, and `tf.saved_model.load` will fail if pointed at an incomplete
SavedModel. Rather than checking for the directory, check for
"saved_model_dir/saved_model.pb". This file is written atomically as the last
`tf.saved_model.save` file operation.
Args:
export_dir: The SavedModel directory to load from.
tags: A tag or sequence of tags identifying the MetaGraph to load. Optional
if the SavedModel contains a single MetaGraph, as for those exported from
`tf.saved_model.save`.
options: `tf.saved_model.LoadOptions` object that specifies options for
loading.
Returns:
A trackable object with a `signatures` attribute mapping from signature
keys to functions. If the SavedModel was exported by `tf.saved_model.load`,
it also points to trackable objects, functions, debug info which it has been
saved.
Raises:
ValueError: If `tags` don't match a MetaGraph in the SavedModel.
"""
return load_internal(export_dir, tags, options)["root"]


def load_internal(export_dir, tags=None, options=None, loader_cls=Loader,
                  filters=None):
  """Loader implementation."""
  options = options or load_options.LoadOptions()
  if tags is not None and not isinstance(tags, set):
    # Supports e.g. tags=SERVING and tags=[SERVING]. Sets aren't considered
    # sequences for nest.flatten, so we put those through as-is.
    tags = nest.flatten(tags)
  saved_model_proto, debug_info = (
      loader_impl.parse_saved_model_with_debug_info(export_dir))

  if (len(saved_model_proto.meta_graphs) == 1 and
      saved_model_proto.meta_graphs[0].HasField("object_graph_def")):
    meta_graph_def = saved_model_proto.meta_graphs[0]
    # The tensor_content field contains raw bytes in little-endian format,
    # which causes problems when loaded on big-endian systems and therefore
    # requires a byteswap.
    if sys.byteorder == "big":
      saved_model_utils.swap_function_tensor_content(
          meta_graph_def, "little", "big")
    if (tags is not None
        and set(tags) != set(meta_graph_def.meta_info_def.tags)):
      raise ValueError(
          ("The SavedModel at {} has one MetaGraph with tags {}, but got an "
           "incompatible argument tags={} to tf.saved_model.load. You may omit "
           "it, pass 'None', or pass matching tags.")
          .format(export_dir, meta_graph_def.meta_info_def.tags, tags))
    object_graph_proto = meta_graph_def.object_graph_def

    ckpt_options = checkpoint_options.CheckpointOptions(
        experimental_io_device=options.experimental_io_device)
    with ops.init_scope():
      try:
        loader = loader_cls(object_graph_proto, saved_model_proto, export_dir,
                            ckpt_options, filters)
      except errors.NotFoundError as err:
        raise FileNotFoundError(
            str(err) + "\n If trying to load on a different device from the "
            "computational device, consider setting the "
            "`experimental_io_device` option on tf.saved_model.LoadOptions "
            "to the io_device such as '/job:localhost'.")
      root = loader.get(0)
      if isinstance(loader, Loader):
        root.graph_debug_info = loader.adjust_debug_info_func_names(debug_info)
    root.tensorflow_version = meta_graph_def.meta_info_def.tensorflow_version
    root.tensorflow_git_version = (
        meta_graph_def.meta_info_def.tensorflow_git_version)
  else:
    if filters:
      raise ValueError("SavedModels saved from TensorFlow V1 or Estimator "
                       "(any version) cannot be loaded with node filters.")
    with ops.init_scope():
      root = load_v1_in_v2.load(export_dir, tags)
      root.graph_debug_info = debug_info

  if filters:
    return {node_id: loader.get(node_id) for node_id in filters}
  else:
    return {"root": root}