# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Exports a SavedModel from a Trackable Python object."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import os
from tensorflow.core.framework import versions_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import saved_model_pb2
from tensorflow.core.protobuf import saved_object_graph_pb2
from tensorflow.python.distribute import values as ds_values
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function as defun
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import versions
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.saved_model import builder_impl
from tensorflow.python.saved_model import constants
from tensorflow.python.saved_model import function_serialization
from tensorflow.python.saved_model import nested_structure_coder
from tensorflow.python.saved_model import revived_types
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import signature_def_utils
from tensorflow.python.saved_model import signature_serialization
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.saved_model import utils_impl
from tensorflow.python.training.saving import functional_saver
from tensorflow.python.training.tracking import base
from tensorflow.python.training.tracking import graph_view
from tensorflow.python.training.tracking import tracking
from tensorflow.python.training.tracking import util
from tensorflow.python.util import compat
from tensorflow.python.util import object_identity
from tensorflow.python.util.tf_export import tf_export
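
# Tensors with these dtypes cannot be copied into the exported graph as
# constants; resource captures are instead remapped to handles created in the
# exported graph (see `_SaveableView.map_resources`).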
_UNCOPIABLE_DTYPES = frozenset((dtypes.resource, dtypes.variant))


# A container for an EagerTensor constant which has been copied to the exported
# Graph.
_CapturedConstant = collections.namedtuple(
    "_CapturedConstant", ["eager_tensor", "graph_tensor"])


class _AugmentedGraphView(graph_view.ObjectGraphView):
  """An extendable graph which also tracks functions attached to objects.

  Extensions through `add_object` appear in the object graph and any
  checkpoints generated from it, even if they are not dependencies of the node
  they were attached to in the saving program. For example a `.signatures`
  attribute is added to exported SavedModel root objects without modifying the
  root object itself.

  Also tracks functions attached to objects in the graph, through the caching
  `list_functions` method. Enumerating functions only through this method
  ensures that we get a consistent view of functions, even if object attributes
  create new functions every time they are accessed.
  """

  def __init__(self, root):
    if (not context.executing_eagerly()
        and not ops.inside_function()):
      saveables_cache = object_identity.ObjectIdentityWeakKeyDictionary()
    else:
      saveables_cache = None
    super(_AugmentedGraphView, self).__init__(root, saveables_cache)
    # Object -> (name -> dep)
    self._extra_dependencies = object_identity.ObjectIdentityDictionary()
    self._functions = object_identity.ObjectIdentityDictionary()
    # Cache shared between objects in the same object graph. This is passed to
    # each trackable object's `_list_extra_dependencies_for_serialization` and
    # `_list_functions_for_serialization` function.
    self._serialization_cache = object_identity.ObjectIdentityDictionary()

  def add_object(self, parent_node, name_in_parent, subgraph_root):
    """Attach an object to `parent_node`, overriding any existing dependency."""
    self._extra_dependencies.setdefault(
        parent_node, {})[name_in_parent] = subgraph_root

  def list_dependencies(self, obj):
    """Overrides a parent method to include `add_object` objects."""
    extra_dependencies = self.list_extra_dependencies(obj)
    extra_dependencies.update(self._extra_dependencies.get(obj, {}))

    used_names = set()
    for name, dep in super(_AugmentedGraphView, self).list_dependencies(obj):
      used_names.add(name)
      if name in extra_dependencies:
        # Extra dependencies (except for `.signatures`, which is always added
        # when saving) should not have naming conflicts with dependencies
        # defined by the user.
        if name != signature_serialization.SIGNATURE_ATTRIBUTE_NAME:
          raise ValueError(
              "Error when exporting object {} with identifier={}. The object"
              " has an attribute named {}, which is reserved. List of all "
              "reserved attributes: {}".format(
                  obj, obj._object_identifier,  # pylint: disable=protected-access
                  name, extra_dependencies.keys()))
        yield base.TrackableReference(name, extra_dependencies[name])
      else:
        yield base.TrackableReference(name, dep)
    for name, dep in extra_dependencies.items():
      if name in used_names:
        continue
      yield base.TrackableReference(name, dep)

  def list_extra_dependencies(self, obj):
    return obj._list_extra_dependencies_for_serialization(  # pylint: disable=protected-access
        self._serialization_cache)

  def list_functions(self, obj):
    obj_functions = self._functions.get(obj, None)
    if obj_functions is None:
      obj_functions = obj._list_functions_for_serialization(  # pylint: disable=protected-access
          self._serialization_cache)
      self._functions[obj] = obj_functions
    return obj_functions


class _SaveableView(object):
  """Provides a frozen view over a trackable root.

  This class helps create a single stable view over an object to save. The
  saving code should access properties and functions via this class and not
  via the original object, as there are cases where an object constructs its
  trackable attributes and functions dynamically per call, yielding different
  objects each time it is invoked.

  Changes to the graph, for example adding objects, must happen in
  `checkpoint_view` (an `_AugmentedGraphView`) before the `_SaveableView` is
  constructed. Changes after the `_SaveableView` has been constructed will be
  ignored.
  """

  def __init__(self, checkpoint_view):
    self.checkpoint_view = checkpoint_view
    trackable_objects, node_ids, slot_variables = (
        self.checkpoint_view.objects_ids_and_slot_variables())
    self.nodes = trackable_objects
    self.node_ids = node_ids
    self.captured_tensor_node_ids = object_identity.ObjectIdentityDictionary()
    self.slot_variables = slot_variables
    self.concrete_functions = []

    # Also add `Function`s as nodes.
    nodes_without_functions = list(self.nodes)
    seen_function_names = set()
    for node in nodes_without_functions:
      for function in checkpoint_view.list_functions(node).values():
        if function not in self.node_ids:
          self.node_ids[function] = len(self.nodes)
          self.nodes.append(function)
        if isinstance(function, def_function.Function):
          # Force listing the concrete functions for the side effects:
          #  - populate the cache for functions that have an input_signature
          #    and have not been called.
          #  - force side effects of creation of concrete functions, e.g.
          #    create variables on first run.
          concrete_functions = (
              function._list_all_concrete_functions_for_serialization())  # pylint: disable=protected-access
        else:
          concrete_functions = [function]
        for concrete_function in concrete_functions:
          if concrete_function.name not in seen_function_names:
            seen_function_names.add(concrete_function.name)
            self.concrete_functions.append(concrete_function)

  @property
  def root(self):
    return self.nodes[0]

  def fill_object_graph_proto(self, proto):
    """Populate the nodes, children and slot_variables of a SavedObjectGraph."""
    for node_id, node in enumerate(self.nodes):
      assert self.node_ids[node] == node_id
      object_proto = proto.nodes.add()
      object_proto.slot_variables.extend(self.slot_variables.get(node, ()))
      if isinstance(node, (def_function.Function, defun.ConcreteFunction,
                           _CapturedConstant)):
        continue
      for child in self.checkpoint_view.list_dependencies(node):
        child_proto = object_proto.children.add()
        child_proto.node_id = self.node_ids[child.ref]
        child_proto.local_name = child.name
      for local_name, ref_function in (
          self.checkpoint_view.list_functions(node).items()):
        child_proto = object_proto.children.add()
        child_proto.node_id = self.node_ids[ref_function]
        child_proto.local_name = local_name

  def map_resources(self):
    """Makes new resource handle ops corresponding to existing resource tensors.

    Creates resource handle ops in the current default graph, whereas
    `accessible_objects` will be from an eager context. Resource mapping adds
    resource handle ops to the main GraphDef of a SavedModel, which allows the
    C++ loader API to interact with variables.

    Returns:
      A tuple of (object_map, resource_map, asset_info):
        object_map: A dictionary mapping from object in `accessible_objects` to
          replacement objects created to hold the new resource tensors.
        resource_map: A dictionary mapping from resource tensors extracted from
          `accessible_objects` to newly created resource tensors.
        asset_info: An _AssetInfo tuple describing external assets referenced
          from accessible_objects.
    """
    # Only makes sense when adding to the export Graph
    assert not context.executing_eagerly()
    # TODO(allenl): Handle MirroredVariables and other types of variables which
    # may need special casing.
    object_map = object_identity.ObjectIdentityDictionary()
    resource_map = {}
    asset_info = _AssetInfo(
        asset_defs=[],
        asset_initializers_by_resource={},
        asset_filename_map={},
        asset_index={})

    for node_id, obj in enumerate(self.nodes):
      if isinstance(obj, tracking.CapturableResource):
        # pylint: disable=protected-access
        with ops.device(obj._resource_device):
          new_resource = obj._create_resource()
        # pylint: enable=protected-access
        resource_map[obj.resource_handle] = new_resource
        self.captured_tensor_node_ids[obj.resource_handle] = node_id
      elif ds_values.is_distributed_variable(obj):
        # Put both the distributed variable and component variable handles in
        # `captured_tensor_node_ids`.
        # Also create a new distributed variable for `object_map` with newly
        # created component variables.
        new_vars = []
        for v in obj.values:
          new_variable = resource_variable_ops.copy_to_graph_uninitialized(v)
          object_map[v] = new_variable
          new_vars.append(new_variable)
          resource_map[v.handle] = new_variable.handle
          self.captured_tensor_node_ids[v.handle] = node_id
        object_map[obj] = obj._clone_with_new_values(new_vars)  # pylint: disable=protected-access
        self.captured_tensor_node_ids[obj] = node_id
      elif resource_variable_ops.is_resource_variable(obj):
        new_variable = resource_variable_ops.copy_to_graph_uninitialized(obj)
        object_map[obj] = new_variable
        resource_map[obj.handle] = new_variable.handle
        self.captured_tensor_node_ids[obj.handle] = node_id
      elif isinstance(obj, tracking.TrackableAsset):
        _process_asset(obj, asset_info, resource_map)
        self.captured_tensor_node_ids[obj.asset_path] = node_id

    for concrete_function in self.concrete_functions:
      if not concrete_function.graph.saveable:
        raise ValueError(
            ("Unable to save function {name} for the following reason(s):\n" +
             "\n".join(concrete_function.graph.saving_errors))
            .format(name=concrete_function.name))
      for capture in concrete_function.captured_inputs:
        if (tensor_util.is_tensor(capture)
            and capture.dtype not in _UNCOPIABLE_DTYPES
            and capture not in self.captured_tensor_node_ids):
          capture_constant_value = tensor_util.constant_value(capture)
          if capture_constant_value is None:
            raise ValueError(
                ("Attempted to save a function {} which references a symbolic "
                 "Tensor {} that is not a simple constant. This is not "
                 "supported.").format(concrete_function.name, capture))
          copied_tensor = constant_op.constant(capture_constant_value)
          node_id = len(self.nodes)
          node = _CapturedConstant(
              eager_tensor=capture, graph_tensor=copied_tensor)
          self.nodes.append(node)
          self.node_ids[capture] = node_id
          self.node_ids[node] = node_id
          self.captured_tensor_node_ids[capture] = node_id
          resource_map[capture] = copied_tensor

    return object_map, resource_map, asset_info


def _tensor_dict_to_tensorinfo(tensor_dict):
  return {key: utils_impl.build_tensor_info_internal(value)
          for key, value in tensor_dict.items()}


def _map_captures_to_created_tensors(
    original_captures, resource_map):
  """Maps eager tensors captured by a function to Graph resources for export.

  Args:
    original_captures: A sequence of (exterior, interior) tuples, mapping
      tensors captured by the function to the interior placeholders for those
      tensors (inside the function body).
    resource_map: A dictionary mapping from resource tensors owned by the eager
      context to resource tensors in the exported graph.

  Returns:
    A list of stand-in tensors which belong to the exported graph,
    corresponding to the function's captures.

  Raises:
    AssertionError: If the function references a resource which is not part of
      `resource_map`.
  """
  export_captures = []
  for exterior, interior in original_captures:
    mapped_resource = resource_map.get(exterior, None)
    if mapped_resource is None:
      raise AssertionError(
          ("Tried to export a function which references untracked object {}. "
           "TensorFlow objects (e.g. tf.Variable) captured by functions must "
           "be tracked by assigning them to an attribute of a tracked object "
           "or assigned to an attribute of the main object directly.")
          .format(interior))
    export_captures.append(mapped_resource)
  return export_captures


def _map_function_arguments_to_created_inputs(
    function_arguments, signature_key, function_name):
  """Creates exterior placeholders in the exported graph for function arguments.

  Functions have two types of inputs: tensors captured from the outside (eager)
  context, and arguments to the function which we expect to receive from the
  user at each call. `_map_captures_to_created_tensors` replaces
  captured tensors with stand-ins (typically these are resource dtype tensors
  associated with variables). `_map_function_arguments_to_created_inputs` runs
  over every argument, creating a new placeholder for each which will belong to
  the exported graph rather than the function body.

  Args:
    function_arguments: A list of argument placeholders in the function body.
    signature_key: The name of the signature being exported, for error
      messages.
    function_name: The name of the function, for error messages.

  Returns:
    A tuple of (mapped_inputs, exterior_placeholders)
      mapped_inputs: A list with entries corresponding to `function_arguments`
        containing all of the inputs of the function gathered from the exported
        graph (both captured resources and arguments).
      exterior_argument_placeholders: A dictionary mapping from argument names
        to placeholders in the exported graph, containing the explicit
        arguments to the function which a user is expected to provide.

  Raises:
    ValueError: If argument names are not unique.
  """
  # `exterior_argument_placeholders` holds placeholders which are outside the
  # function body, directly contained in a MetaGraph of the SavedModel. The
  # function body itself contains nearly identical placeholders used when
  # running the function, but these exterior placeholders allow Session-based
  # APIs to call the function using feeds and fetches which name Tensors in the
  # MetaGraph.
  exterior_argument_placeholders = {}
  mapped_inputs = []
  for placeholder in function_arguments:
    # `export_captures` contains an exhaustive set of captures, so if we don't
    # find the input there then we now know we have an argument.
    user_input_name = compat.as_str_any(
        placeholder.op.get_attr("_user_specified_name"))
    # If the internal placeholders for a function have names which were
    # uniquified by TensorFlow, then a single user-specified argument name
    # must refer to multiple Tensors. The resulting signatures would be
    # confusing to call. Instead, we throw an exception telling the user to
    # specify explicit names.
    if user_input_name != placeholder.op.name:
      # This should be unreachable, since concrete functions may not be
      # generated with non-unique argument names.
      raise ValueError(
          ("Got non-flat/non-unique argument names for SavedModel "
           "signature '{}': more than one argument to '{}' was named '{}'. "
           "Signatures have one Tensor per named input, so to have "
           "predictable names Python functions used to generate these "
           "signatures should avoid *args and Tensors in nested "
           "structures unless unique names are specified for each. Use "
           "tf.TensorSpec(..., name=...) to provide a name for a Tensor "
           "input.")
          .format(signature_key, compat.as_str_any(function_name),
                  user_input_name))
    arg_placeholder = array_ops.placeholder(
        shape=placeholder.shape,
        dtype=placeholder.dtype,
        name="{}_{}".format(signature_key, user_input_name))
    exterior_argument_placeholders[user_input_name] = arg_placeholder
    mapped_inputs.append(arg_placeholder)
  return mapped_inputs, exterior_argument_placeholders


def _call_function_with_mapped_captures(function, args, resource_map):
  """Calls `function` in the exported graph, using mapped resource captures."""
  export_captures = _map_captures_to_created_tensors(
      function.graph.captures, resource_map)
  # Calls the function quite directly, since we have new captured resource
  # tensors we need to feed in which weren't part of the original function
  # definition.
  # pylint: disable=protected-access
  outputs = function._call_flat(args, export_captures)
  # pylint: enable=protected-access
  return outputs


def _generate_signatures(signature_functions, resource_map):
  """Validates and calls `signature_functions` in the default graph.

  Args:
    signature_functions: A dictionary mapping string keys to concrete
      TensorFlow functions (e.g. from
      `signature_serialization.canonicalize_signatures`) which will be used to
      generate SignatureDefs.
    resource_map: A dictionary mapping from resource tensors in the eager
      context to resource tensors in the Graph being exported. This dictionary
      is used to re-bind resources captured by functions to tensors which will
      exist in the SavedModel.

  Returns:
    Each function in the `signature_functions` dictionary is called with
    placeholder Tensors, generating a function call operation and output
    Tensors. The placeholder Tensors, the function call operation, and the
    output Tensors from the function call are part of the default Graph.

    This function then returns a dictionary with the same structure as
    `signature_functions`, with the concrete functions replaced by
    SignatureDefs implicitly containing information about how to call each
    function from a TensorFlow 1.x Session / the C++ Loader API. These
    SignatureDefs reference the generated placeholders and Tensor outputs by
    name.

    The caller is expected to include the default Graph set while calling this
    function as a MetaGraph in a SavedModel, including the returned
    SignatureDefs as part of that MetaGraph.
  """
  signatures = {}
  for signature_key, function in sorted(signature_functions.items()):
    if function.graph.captures:
      argument_inputs = function.graph.inputs[:-len(function.graph.captures)]
    else:
      argument_inputs = function.graph.inputs
    mapped_inputs, exterior_argument_placeholders = (
        _map_function_arguments_to_created_inputs(
            argument_inputs, signature_key, function.name))
    outputs = _call_function_with_mapped_captures(
        function, mapped_inputs, resource_map)
    signatures[signature_key] = signature_def_utils.build_signature_def(
        _tensor_dict_to_tensorinfo(exterior_argument_placeholders),
        _tensor_dict_to_tensorinfo(outputs),
        method_name=signature_constants.PREDICT_METHOD_NAME)
  return signatures


def _trace_resource_initializers(accessible_objects):
  """Create concrete functions from `CapturableResource` objects."""
  resource_initializers = []

  def _wrap_initializer(obj):
    obj._initialize()  # pylint: disable=protected-access
    return constant_op.constant(1.)  # Dummy control output

  def _wrap_obj_initializer(obj):
    return lambda: _wrap_initializer(obj)

  for obj in accessible_objects:
    if isinstance(obj, tracking.CapturableResource):
      resource_initializers.append(def_function.function(
          _wrap_obj_initializer(obj),
          # All inputs are captures.
          input_signature=[]).get_concrete_function())
  return resource_initializers


_AssetInfo = collections.namedtuple(
    "_AssetInfo", [
        # List of AssetFileDef protocol buffers
        "asset_defs",
        # Map from asset variable resource Tensors to their init ops
        "asset_initializers_by_resource",
        # Map from base asset filenames to full paths
        "asset_filename_map",
        # Map from TrackableAsset to index of corresponding AssetFileDef
        "asset_index"])


def _process_asset(trackable_asset, asset_info, resource_map):
  """Add `trackable_asset` to `asset_info` and `resource_map`."""
  original_path_tensor = trackable_asset.asset_path
  original_path = tensor_util.constant_value(original_path_tensor)
  try:
    original_path = str(original_path.astype(str))
  except AttributeError:
    # Already a string rather than a numpy array
    pass
  path = builder_impl.get_asset_filename_to_add(
      asset_filepath=original_path,
      asset_filename_map=asset_info.asset_filename_map)
  # TODO(andresp): Instead of mapping 1-1 between trackable asset
  # and asset in the graph def consider deduping the assets that
  # point to the same file.
  asset_path_initializer = array_ops.placeholder(
      shape=original_path_tensor.shape,
      dtype=dtypes.string,
      name="asset_path_initializer")
  asset_variable = resource_variable_ops.ResourceVariable(
      asset_path_initializer)
  asset_info.asset_filename_map[path] = original_path
  asset_def = meta_graph_pb2.AssetFileDef()
  asset_def.filename = path
  asset_def.tensor_info.name = asset_path_initializer.name
  asset_info.asset_defs.append(asset_def)
  asset_info.asset_initializers_by_resource[original_path_tensor] = (
      asset_variable.initializer)
  asset_info.asset_index[trackable_asset] = len(asset_info.asset_defs) - 1
  resource_map[original_path_tensor] = asset_variable


def _fill_meta_graph_def(meta_graph_def, saveable_view, signature_functions):
  """Generates a MetaGraph which calls `signature_functions`.

  Args:
    meta_graph_def: The MetaGraphDef proto to fill.
    saveable_view: The _SaveableView being exported.
    signature_functions: A dictionary mapping signature keys to concrete
      functions containing signatures to add to the MetaGraph.

  Returns:
    A tuple of (asset_info, exported_graph): an _AssetInfo, which contains
    information to help create the SavedModel, and the exported `ops.Graph`.
  """
  # List objects from the eager context to make sure Optimizers give us the
  # right Graph-dependent variables.
  accessible_objects = saveable_view.nodes
  resource_initializer_functions = _trace_resource_initializers(
      accessible_objects)
  exported_graph = ops.Graph()
  resource_initializer_ops = []
  with exported_graph.as_default():
    object_map, resource_map, asset_info = saveable_view.map_resources()
    for resource_initializer_function in resource_initializer_functions:
      asset_dependencies = []
      for capture in resource_initializer_function.graph.external_captures:
        asset_initializer = asset_info.asset_initializers_by_resource.get(
            capture, None)
        if asset_initializer is not None:
          asset_dependencies.append(asset_initializer)
      with ops.control_dependencies(asset_dependencies):
        resource_initializer_ops.append(
            _call_function_with_mapped_captures(
                resource_initializer_function, [], resource_map))
    resource_initializer_ops.extend(
        asset_info.asset_initializers_by_resource.values())
    with ops.control_dependencies(resource_initializer_ops):
      init_op = control_flow_ops.no_op()
    # Add the same op to the main_op collection and to the init_op
    # signature. The collection is for compatibility with older loader APIs;
    # only one will be executed.
    meta_graph_def.collection_def[constants.MAIN_OP_KEY].node_list.value.append(
        init_op.name)
    meta_graph_def.signature_def[constants.INIT_OP_SIGNATURE_KEY].CopyFrom(
        signature_def_utils.op_signature_def(
            init_op, constants.INIT_OP_SIGNATURE_KEY))

  # Saving an object-based checkpoint again gathers variables. We need to do
  # the gathering from the eager context so Optimizers save the right set of
  # variables, but want any operations associated with the save/restore to be
  # in the exported graph (thus the `to_graph` argument).
  saver = functional_saver.MultiDeviceSaver(
      saveable_view.checkpoint_view.frozen_saveable_objects(
          object_map=object_map, to_graph=exported_graph))

  with exported_graph.as_default():
    signatures = _generate_signatures(signature_functions, resource_map)
    for concrete_function in saveable_view.concrete_functions:
      concrete_function.add_to_graph()
    saver_def = saver.to_proto()
    meta_graph_def.saver_def.CopyFrom(saver_def)
  graph_def = exported_graph.as_graph_def(add_shapes=True)

  meta_graph_def.graph_def.CopyFrom(graph_def)
  meta_graph_def.meta_info_def.tags.append(tag_constants.SERVING)
  meta_graph_def.meta_info_def.tensorflow_version = versions.__version__
  meta_graph_def.meta_info_def.tensorflow_git_version = (
      versions.__git_version__)
  # We currently always strip default attributes.
  meta_graph_def.meta_info_def.stripped_default_attrs = True
  meta_graph_def.meta_info_def.stripped_op_list.MergeFrom(
      meta_graph.stripped_op_list_for_graph(meta_graph_def.graph_def))
  meta_graph_def.asset_file_def.extend(asset_info.asset_defs)
  for signature_key, signature in signatures.items():
    meta_graph_def.signature_def[signature_key].CopyFrom(signature)
  meta_graph.strip_graph_default_valued_attrs(meta_graph_def)
  return asset_info, exported_graph


def _serialize_object_graph(saveable_view, asset_file_def_index):
  """Save a SavedObjectGraph proto for `root`."""
  # SavedObjectGraph is similar to the TrackableObjectGraph proto in the
  # checkpoint. It will eventually go into the SavedModel.
  proto = saved_object_graph_pb2.SavedObjectGraph()
  saveable_view.fill_object_graph_proto(proto)

  coder = nested_structure_coder.StructureCoder()
  for concrete_function in saveable_view.concrete_functions:
    serialized = function_serialization.serialize_concrete_function(
        concrete_function, saveable_view.captured_tensor_node_ids, coder)
    if serialized is not None:
      proto.concrete_functions[concrete_function.name].CopyFrom(
          serialized)

  for obj, obj_proto in zip(saveable_view.nodes, proto.nodes):
    _write_object_proto(obj, obj_proto, asset_file_def_index)
  return proto


def _write_object_proto(obj, proto, asset_file_def_index):
  """Saves an object into SavedObject proto."""
  if isinstance(obj, tracking.TrackableAsset):
    proto.asset.SetInParent()
    proto.asset.asset_file_def_index = asset_file_def_index[obj]
  elif resource_variable_ops.is_resource_variable(obj):
    proto.variable.SetInParent()
    if not obj.name.endswith(":0"):
      raise ValueError("Cowardly refusing to save variable %s because of"
                       " unexpected suffix which won't be restored."
                       % obj.name)
    proto.variable.name = meta_graph._op_name(obj.name)  # pylint: disable=protected-access
    proto.variable.trainable = obj.trainable
    proto.variable.dtype = obj.dtype.as_datatype_enum
    proto.variable.synchronization = obj.synchronization.value
    proto.variable.aggregation = obj.aggregation.value
    proto.variable.shape.CopyFrom(obj.shape.as_proto())
  elif isinstance(obj, def_function.Function):
    proto.function.CopyFrom(
        function_serialization.serialize_function(obj))
  elif isinstance(obj, defun.ConcreteFunction):
    proto.bare_concrete_function.CopyFrom(
        function_serialization.serialize_bare_concrete_function(obj))
  elif isinstance(obj, _CapturedConstant):
    proto.constant.operation = obj.graph_tensor.op.name
  elif isinstance(obj, tracking.CapturableResource):
    proto.resource.device = obj._resource_device  # pylint: disable=protected-access
  else:
    registered_type_proto = revived_types.serialize(obj)
    if registered_type_proto is None:
      # Fallback for types with no matching registration
      # pylint:disable=protected-access
      registered_type_proto = saved_object_graph_pb2.SavedUserObject(
          identifier=obj._object_identifier,
          version=versions_pb2.VersionDef(
              producer=1, min_consumer=1, bad_consumers=[]),
          metadata=obj._tracking_metadata)
      # pylint:enable=protected-access
    proto.user_object.CopyFrom(registered_type_proto)
@tf_export("saved_model.save",
v1=["saved_model.save", "saved_model.experimental.save"])
def save(obj, export_dir, signatures=None):
# pylint: disable=line-too-long
"""Exports the Trackable object `obj` to [SavedModel format](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md).
Example usage:
```python
class Adder(tf.Module):
@tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
def add(self, x):
return x + x + 1.
to_export = Adder()
tf.saved_model.save(to_export, '/tmp/adder')
```
The resulting SavedModel is then servable with an input named "x", its value
having any shape and dtype float32.
The optional `signatures` argument controls which methods in `obj` will be
available to programs which consume `SavedModel`s, for example serving
APIs. Python functions may be decorated with
`@tf.function(input_signature=...)` and passed as signatures directly, or
lazily with a call to `get_concrete_function` on the method decorated with
`@tf.function`.
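
  For example, a dictionary of signatures might look like this (a minimal
  sketch reusing the `Adder` above; 'serving_default' is the conventional
  default serving signature key):

  ```python
  tf.saved_model.save(
      to_export, '/tmp/adder',
      signatures={'serving_default': to_export.add})
  ```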

  If the `signatures` argument is omitted, `obj` will be searched for
  `@tf.function`-decorated methods. If exactly one `@tf.function` is found,
  that method will be used as the default signature for the SavedModel. This
  behavior is expected to change in the future, when a corresponding
  `tf.saved_model.load` symbol is added. At that point signatures will be
  completely optional, and any `@tf.function` attached to `obj` or its
  dependencies will be exported for use with `load`.

  When invoking a signature in an exported SavedModel, `Tensor` arguments are
  identified by name. These names will come from the Python function's argument
  names by default. They may be overridden by specifying a `name=...` argument
  in the corresponding `tf.TensorSpec` object. Explicit naming is required if
  multiple `Tensor`s are passed through a single argument to the Python
  function.

  The outputs of functions used as `signatures` must either be flat lists, in
  which case outputs will be numbered, or a dictionary mapping string keys to
  `Tensor`, in which case the keys will be used to name outputs.
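
  For example, a signature function returning a dictionary (a sketch; the
  class and output names here are illustrative, not part of the API) names its
  output Tensor "doubled":

  ```python
  class Doubler(tf.Module):

    @tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
    def double(self, x):
      # The dictionary key "doubled" becomes the output name in the signature.
      return {'doubled': x * 2.}
  ```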

  Signatures are available in objects returned by `tf.saved_model.load` as a
  `.signatures` attribute. This is a reserved attribute: `tf.saved_model.save`
  on an object with a custom `.signatures` attribute will raise an exception.
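
  For example, a loading sketch (assuming the `Adder` export above and a
  `tf.saved_model.load` as described earlier):

  ```python
  imported = tf.saved_model.load('/tmp/adder')
  add_fn = imported.signatures['serving_default']
  outputs = add_fn(x=tf.constant(2.))  # A dictionary of output Tensors.
  ```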

  Since `tf.keras.Model` objects are also Trackable, this function can be
  used to export Keras models. For example, exporting with a signature
  specified:

  ```python
  class Model(tf.keras.Model):

    @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.string)])
    def serve(self, serialized):
      ...

  m = Model()
  tf.saved_model.save(m, '/tmp/saved_model/')
  ```

  Exporting from a function without a fixed signature:

  ```python
  class Model(tf.keras.Model):

    @tf.function
    def call(self, x):
      ...

  m = Model()
  tf.saved_model.save(
      m, '/tmp/saved_model/',
      signatures=m.call.get_concrete_function(
          tf.TensorSpec(shape=[None, 3], dtype=tf.float32, name="inp")))
  ```

  `tf.keras.Model` instances constructed from inputs and outputs already have a
  signature and so do not require a `@tf.function` decorator or a `signatures`
  argument. If neither is specified, the model's forward pass is exported.

  ```python
  x = input_layer.Input((4,), name="x")
  y = core.Dense(5, name="out")(x)
  model = training.Model(x, y)
  tf.saved_model.save(model, '/tmp/saved_model/')
  # The exported SavedModel takes "x" with shape [None, 4] and returns "out"
  # with shape [None, 5]
  ```

  Variables must be tracked by assigning them to an attribute of a tracked
  object or to an attribute of `obj` directly. TensorFlow objects (e.g. layers
  from `tf.keras.layers`, optimizers from `tf.train`) track their variables
  automatically. This is the same tracking scheme that `tf.train.Checkpoint`
  uses, and an exported `Checkpoint` object may be restored as a training
  checkpoint by pointing `tf.train.Checkpoint.restore` to the SavedModel's
  "variables/" subdirectory. Currently variables are the only stateful objects
  supported by `tf.saved_model.save`, but others (e.g. tables) will be
  supported in the future.
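
  For example (a sketch; assumes the exported object was itself a
  `tf.train.Checkpoint`, and that `/tmp/module` is the export directory):

  ```python
  exported = tf.train.Checkpoint(v=tf.Variable(3.))
  tf.saved_model.save(exported, '/tmp/module')
  # Restore the variable values as a training checkpoint. save() writes the
  # checkpoint with prefix "variables/variables" under the export directory.
  restored = tf.train.Checkpoint(v=tf.Variable(0.))
  restored.restore('/tmp/module/variables/variables')
  ```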

  `tf.function` does not hard-code device annotations from outside the function
  body, instead using the calling context's device. This means for example that
  exporting a model which runs on a GPU and serving it on a CPU will generally
  work, with some exceptions. `tf.device` annotations inside the body of the
  function will be hard-coded in the exported model; this type of annotation is
  discouraged. Device-specific operations, e.g. with "cuDNN" in the name or
  with device-specific layouts, may cause issues. Currently a
  `DistributionStrategy` is another exception: active distribution strategies
  will cause device placements to be hard-coded in a function. Exporting a
  single-device computation and importing under a `DistributionStrategy` is
  not currently supported, but may be in the future.

  SavedModels exported with `tf.saved_model.save` [strip default-valued
  attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes)
  automatically, which removes one source of incompatibilities when the
  consumer of a SavedModel is running an older TensorFlow version than the
  producer. There are however other sources of incompatibilities which are not
  handled automatically, such as when the exported model contains operations
  which the consumer does not have definitions for.

  Args:
    obj: A trackable object to export.
    export_dir: A directory in which to write the SavedModel.
    signatures: Optional, either a `tf.function` with an input signature
      specified or the result of `f.get_concrete_function` on a
      `@tf.function`-decorated function `f`, in which case `f` will be used to
      generate a signature for the SavedModel under the default serving
      signature key. `signatures` may also be a dictionary, in which case it
      maps from signature keys to either `tf.function` instances with input
      signatures or concrete functions. The keys of such a dictionary may be
      arbitrary strings, but will typically be from the
      `tf.saved_model.signature_constants` module.

  Raises:
    ValueError: If `obj` is not trackable.

  @compatibility(eager)
  Not well supported when graph building. From TensorFlow 1.x,
  `tf.compat.v1.enable_eager_execution()` should run first. Calling
  tf.saved_model.save in a loop when graph building from TensorFlow 1.x will
  add new save operations to the default graph each iteration.

  May not be called from within a function body.
  @end_compatibility
  """
  if ops.inside_function():
    raise AssertionError(
        "tf.saved_model.save is not supported inside a traced "
        "@tf.function. Move the call to the outer eagerly-executed "
        "context.")
  # pylint: enable=line-too-long
  if not isinstance(obj, base.Trackable):
    raise ValueError(
        "Expected a Trackable object for export, got {}.".format(obj))

  checkpoint_graph_view = _AugmentedGraphView(obj)
  if signatures is None:
    signatures = signature_serialization.find_function_to_export(
        checkpoint_graph_view)

  signatures = signature_serialization.canonicalize_signatures(signatures)
  signature_serialization.validate_saveable_view(checkpoint_graph_view)
  signature_map = signature_serialization.create_signature_map(signatures)
  checkpoint_graph_view.add_object(
      parent_node=checkpoint_graph_view.root,
      name_in_parent=signature_serialization.SIGNATURE_ATTRIBUTE_NAME,
      subgraph_root=signature_map)

  # Use _SaveableView to provide a frozen listing of properties and functions.
  # Note we run this twice since, while constructing the view the first time
  # there can be side effects of creating variables.
  _ = _SaveableView(checkpoint_graph_view)
  saveable_view = _SaveableView(checkpoint_graph_view)

  # TODO(allenl): Factor out some subset of SavedModelBuilder which is 2.x
  # compatible (no sessions) and share it with this export API rather than
  # making a SavedModel proto and writing it directly.
  saved_model = saved_model_pb2.SavedModel()
  meta_graph_def = saved_model.meta_graphs.add()
  object_saver = util.TrackableSaver(checkpoint_graph_view)
  asset_info, exported_graph = _fill_meta_graph_def(
      meta_graph_def, saveable_view, signatures)
  saved_model.saved_model_schema_version = (
      constants.SAVED_MODEL_SCHEMA_VERSION)
  # So far we've just been generating protocol buffers with no I/O. Now we
  # write the checkpoint, copy assets into the assets directory, and write out
  # the SavedModel proto itself.
  utils_impl.get_or_create_variables_dir(export_dir)
  object_saver.save(utils_impl.get_variables_path(export_dir))
  builder_impl.copy_assets_to_destination_dir(asset_info.asset_filename_map,
                                              export_dir)
  path = os.path.join(
      compat.as_bytes(export_dir),
      compat.as_bytes(constants.SAVED_MODEL_FILENAME_PB))
  object_graph_proto = _serialize_object_graph(
      saveable_view, asset_info.asset_index)
  meta_graph_def.object_graph_def.CopyFrom(object_graph_proto)
  file_io.write_string_to_file(path, saved_model.SerializeToString())
  # Clean reference cycles so repeated export()s don't make work for the
  # garbage collector. Before this point we need to keep references to captured
  # constants in the saved graph.
  ops.dismantle_graph(exported_graph)