| # Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| """Exports a SavedModel from a Trackable Python object.""" |
| |
| import collections |
| import functools |
| import gc |
| import os |
| import re |
| import sys |
| import traceback |
| |
| from absl import logging |
| import numpy |
| |
| from tensorflow.core.framework import function_pb2 |
| from tensorflow.core.framework import versions_pb2 |
| from tensorflow.core.protobuf import meta_graph_pb2 |
| from tensorflow.core.protobuf import saved_model_pb2 |
| from tensorflow.core.protobuf import saved_object_graph_pb2 |
| from tensorflow.python.eager import context |
| from tensorflow.python.eager import def_function |
| from tensorflow.python.eager import function as defun |
| from tensorflow.python.framework import constant_op |
| from tensorflow.python.framework import dtypes |
| from tensorflow.python.framework import error_interpolation |
| from tensorflow.python.framework import errors |
| from tensorflow.python.framework import function as framework_fn |
| from tensorflow.python.framework import meta_graph |
| from tensorflow.python.framework import ops |
| from tensorflow.python.framework import tensor_util |
| from tensorflow.python.framework import versions |
| from tensorflow.python.lib.io import file_io |
| from tensorflow.python.ops import array_ops |
| from tensorflow.python.ops import control_flow_ops |
| from tensorflow.python.ops import resource_variable_ops |
| from tensorflow.python.saved_model import builder_impl |
| from tensorflow.python.saved_model import function_serialization |
| from tensorflow.python.saved_model import nested_structure_coder |
| from tensorflow.python.saved_model import pywrap_saved_model |
| from tensorflow.python.saved_model import registration |
| from tensorflow.python.saved_model import revived_types |
| from tensorflow.python.saved_model import save_context |
| from tensorflow.python.saved_model import save_options |
| from tensorflow.python.saved_model import signature_constants |
| from tensorflow.python.saved_model import signature_def_utils |
| from tensorflow.python.saved_model import signature_serialization |
| from tensorflow.python.saved_model import tag_constants |
| from tensorflow.python.saved_model import utils_impl |
| from tensorflow.python.saved_model.pywrap_saved_model import constants |
| from tensorflow.python.saved_model.pywrap_saved_model import metrics |
| from tensorflow.python.training.saving import checkpoint_options |
| from tensorflow.python.training.saving import functional_saver |
| from tensorflow.python.training.saving import saveable_object_util |
| from tensorflow.python.training.tracking import base |
| from tensorflow.python.training.tracking import graph_view |
| from tensorflow.python.training.tracking import trackable_utils |
| from tensorflow.python.training.tracking import tracking |
| from tensorflow.python.training.tracking import util |
| from tensorflow.python.util import compat |
| from tensorflow.python.util import object_identity |
| from tensorflow.python.util.tf_export import tf_export |
| |
# Dtypes for which a captured tensor is never copied into the exported graph
# as a constant (see the dtype check in `_SaveableView.map_resources`).
_UNCOPIABLE_DTYPES = frozenset([dtypes.resource, dtypes.variant])

# Pairs an EagerTensor constant with the copy of it that was created in the
# exported Graph.
_CapturedConstant = collections.namedtuple(
    "_CapturedConstant", ("eager_tensor", "graph_tensor"))

# Describes a tensor captured from an external function.
_CapturedTensor = collections.namedtuple(
    "_CapturedTensor", ("name", "concrete_function"))

# Upper bound on how many untraced function names the warning message lists.
_NUM_DISPLAY_UNTRACED_FUNCTIONS = 5

# API label attached to SavedModel metrics.
_SAVE_V2_LABEL = "save_v2"
| |
| |
class _AugmentedGraphView(graph_view.ObjectGraphView):
  """An extendable graph which also tracks functions attached to objects.

  Extensions through `add_object` appear in the object graph and any checkpoints
  generated from it, even if they are not dependencies of the node they were
  attached to in the saving program. For example a `.signatures` attribute is
  added to exported SavedModel root objects without modifying the root object
  itself.

  Also tracks functions attached to objects in the graph, through the caching
  `list_functions` method. Enumerating functions only through this method
  ensures that we get a consistent view of functions, even if object attributes
  create new functions every time they are accessed.
  """

  def __init__(self, root):
    """Creates a view rooted at `root`.

    Args:
      root: The object whose object graph is viewed; forwarded to the parent
        `ObjectGraphView`.
    """
    # A saveables cache is only handed to the parent view when building a graph
    # outside of a function; otherwise no cache is used.
    building_graph = (
        not context.executing_eagerly() and not ops.inside_function())
    saveables_cache = (
        object_identity.ObjectIdentityWeakKeyDictionary()
        if building_graph else None)
    super().__init__(root, saveables_cache)
    # Object -> (name -> dep)
    self._extra_dependencies = object_identity.ObjectIdentityDictionary()
    # Object -> result of its `_list_functions_for_serialization` call.
    self._functions = object_identity.ObjectIdentityDictionary()
    # Cache shared between objects in the same object graph. This is passed to
    # each trackable object's `_list_extra_dependencies_for_serialization` and
    # `_list_functions_for_serialization` function.
    self._serialization_cache = object_identity.ObjectIdentityDictionary()

  def add_object(self, parent_node, name_in_parent, subgraph_root):
    """Attach an object to `parent_node`, overriding any existing dependency."""
    children_of_parent = self._extra_dependencies.setdefault(parent_node, {})
    children_of_parent[name_in_parent] = subgraph_root

  def list_children(self, obj):
    """Overrides parent method to include extra children.

    Args:
      obj: The object whose children are listed.

    Yields:
      `base.TrackableReference` entries: first the children reported by the
      parent view (replaced by the extra dependency of the same name when that
      name is the signatures attribute), then the remaining extra dependencies.

    Raises:
      ValueError: If a child reported by the parent view shares its name with
        an extra dependency other than the signatures attribute.
    """
    # Copy before merging: the mapping returned by `list_extra_children` comes
    # from `obj` itself, and updating it in place would leak the `add_object`
    # overrides into a dict the object may keep or cache.
    extra_dependencies = dict(self.list_extra_children(obj))
    extra_dependencies.update(self._extra_dependencies.get(obj, {}))

    used_names = set()
    for name, dep in super().list_children(obj):
      used_names.add(name)
      if name not in extra_dependencies:
        yield base.TrackableReference(name, dep)
        continue

      # Extra dependencies (except for `.signatures`, which is always added
      # when saving) should not have naming conflicts with dependencies
      # defined by the user.
      if name != signature_serialization.SIGNATURE_ATTRIBUTE_NAME:
        obj_identifier = obj._object_identifier  # pylint: disable=protected-access
        raise ValueError(
            f"Error when exporting object {obj} with identifier "
            f"'{obj_identifier}'. The object has an attribute named "
            f"'{name}', which is reserved. List of all reserved attributes: "
            f"{list(extra_dependencies.keys())}")

      yield base.TrackableReference(name, extra_dependencies[name])

    # Extra dependencies whose names the parent view did not report.
    for name, dep in extra_dependencies.items():
      if name not in used_names:
        yield base.TrackableReference(name, dep)

  def list_dependencies(self, obj):
    """Yields `Trackables` that must be loaded before `obj`.

    Dependencies and children are both dictionaries of `Trackables`. Children
    define the object graph structure (used in both checkpoints and SavedModel),
    while dependency defines the order used to load the SavedModel

    Args:
      obj: A `Trackable` object

    Yields:
      Tuple of dependency names and trackable objects.

    Raises:
      TypeError: if any of the returned dependencies are not instances of
        `Trackable`.
    """
    dependencies = obj._deserialization_dependencies()  # pylint: disable=protected-access
    for name, dep in dependencies.items():
      if not isinstance(dep, base.Trackable):
        raise TypeError(
            f"The dependency of type {type(dep)} is not an instance `Trackable`"
            ", and can't be saved to SavedModel. Please check the "
            "implementation of `_deserialization_dependencies` in the parent "
            f"object {obj}.")
      yield name, dep

  def list_extra_children(self, obj):
    """Returns children that are only added when exporting SavedModel."""
    return obj._list_extra_dependencies_for_serialization(  # pylint: disable=protected-access
        self._serialization_cache)

  def list_functions(self, obj):
    """Returns the functions attached to `obj`, caching them per object.

    The object is only asked for its functions the first time (as long as the
    stored result is not None), so repeated calls return the same mapping.

    Args:
      obj: The object whose functions are listed.

    Returns:
      Whatever `obj._list_functions_for_serialization` returned on first use.
    """
    cached_functions = self._functions.get(obj, None)
    if cached_functions is not None:
      return cached_functions
    obj_functions = obj._list_functions_for_serialization(  # pylint: disable=protected-access
        self._serialization_cache)
    self._functions[obj] = obj_functions
    return obj_functions
| |
| |
| class _SaveableView(object): |
| """Provides a frozen view over a trackable root. |
| |
| This class helps to create a single stable view over an object to save. The |
| saving code should access properties and functions via this class and not via |
| the original object as there are cases where an object construct their |
| trackable attributes and functions dynamically per call and will yield |
| different objects if invoked more than once. |
| |
| Changes to the graph, for example adding objects, must happen in |
| `checkpoint_view` (an `_AugmentedGraphView`) before the `_SaveableView` is |
| constructed. Changes after the `_SaveableView` has been constructed will be |
| ignored. |
| """ |
| |
| def __init__(self, checkpoint_view, options, wrapped_functions=None): |
| """Initializes a SaveableView. |
| |
| Args: |
| checkpoint_view: A GraphView object. |
| options: A SaveOptions instance. |
| wrapped_functions: Dictionary that maps concrete functions to functions |
| that do not capture cached variable values. |
| """ |
| |
| self.checkpoint_view = checkpoint_view |
| self._options = options |
| # Maps functions -> wrapped functions that capture variables |
| self._wrapped_functions = wrapped_functions or {} |
| # Run through the nodes in the object graph first for side effects of |
| # creating variables. |
| self._trace_all_concrete_functions() |
| |
| (self._trackable_objects, self.node_paths, self.node_ids, |
| self._slot_variables, self.object_names) = ( |
| self.checkpoint_view.objects_ids_and_slot_variables_and_paths()) |
| |
| self._initialize_save_and_restore_functions() |
| self._initialize_nodes_and_concrete_functions() |
| |
| # Maps names of concrete functions in the object to names of wrapped |
| # functions. When writing the SavedFunction protos, the names of the |
| # wrapped functions should be used in place of the original functions. |
| self.function_name_map = { |
| compat.as_text(original.name): compat.as_text(wrapped.name) |
| for original, wrapped in self._wrapped_functions.items()} |
| self.captured_tensor_node_ids = object_identity.ObjectIdentityDictionary() |
| |
| def _initialize_save_and_restore_functions(self): |
| """Generates all checkpoint save/restore functions. |
| |
| The save and restore functions are generated in the eager context (or in the |
| user's Graph/Session) before being copied to the exported GraphDef. These |
| functions record the ops for saving/restoring the entire object or |
| individual objects (e.g. variables and hash tables). |
| |
| The global save and restore functions are generated for compatibility with |
| TF1 and loading from C++, and is saved in the `MetaGraphDef.saver_def`. |
| |
| The individual functions are generated for the Python TF2 use case, where |
| users use the loaded SavedModel as-is, or compose new models using parts |
| of the object loaded from the SavedModel. These functions are recorded in |
| the `saveable_objects` map in the `SavedObject` proto. |
| """ |
| checkpoint_factory_map, registered_savers = ( |
| graph_view.get_checkpoint_factories_and_keys(self.object_names)) |
| self._obj_to_registered_saver = object_identity.ObjectIdentityDictionary() |
| for saver_name, trackables in registered_savers.items(): |
| for trackable in trackables.values(): |
| self._obj_to_registered_saver[trackable] = saver_name |
| self._saveable_objects_map = ( |
| _gen_save_and_restore_functions(checkpoint_factory_map)) |
| |
| def _initialize_nodes_and_concrete_functions(self): |
| """Creates graph with nodes for trackable objects and functions. |
| |
| Adds functions for each trackable object to `self.nodes` and associated |
| concrete functions to `self.concrete_functions` for serialization. |
| """ |
| self.nodes = list(self._trackable_objects) |
| self.concrete_functions = [] |
| self.gradient_functions = [] |
| self.gradient_defs = [] |
| self._seen_function_names = set() |
| self._untraced_functions = [] |
| |
| for obj in self._trackable_objects: |
| for function in self.checkpoint_view.list_functions(obj).values(): |
| self._add_function_to_graph(function) |
| |
| if obj in self._saveable_objects_map: |
| for save_fn, restore_fn in self._saveable_objects_map[obj].values(): |
| self._add_function_to_graph(save_fn) |
| self._add_function_to_graph(restore_fn) |
| |
| if self._untraced_functions: |
| logging.warning( |
| "Found untraced functions such as %s while saving (showing %d of %d)." |
| " These functions will not be directly callable after loading.", |
| ", ".join(self._untraced_functions[:_NUM_DISPLAY_UNTRACED_FUNCTIONS]), |
| min(_NUM_DISPLAY_UNTRACED_FUNCTIONS, len(self._untraced_functions)), |
| len(self._untraced_functions)) |
| |
| @property |
| def concrete_and_gradient_functions(self): |
| return self.concrete_functions + self.gradient_functions |
| |
| def _add_function_to_graph(self, function): |
| """Adds a function to serialize to the object graph. |
| |
| If `function` is a concrete function, it will be added to the list of |
| concrete functions tracked by `_SaveableView`. If the function is a |
| tf.function, any underlying concrete functions will be added to the list of |
| concrete functions for later serialization. |
| |
| Args: |
| function: a `def_function.Function` or `ConcreteFunction` |
| """ |
| # Add the function to the graph |
| if function not in self.node_ids: |
| self.node_ids[function] = len(self.nodes) |
| self.nodes.append(function) |
| |
| # Gather the concrete function(s) |
| if isinstance(function, def_function.Function): |
| concrete_functions = ( |
| function._list_all_concrete_functions_for_serialization()) # pylint: disable=protected-access |
| else: |
| concrete_functions = [function] |
| |
| # Keep track of untraced functions for later reporting to the user |
| if not concrete_functions: |
| self._untraced_functions.append(function.name) |
| |
| # Add the concrete functions for later serialization |
| for concrete_function in concrete_functions: |
| # Users can attach the same tf.function to their model multiple times, |
| # so we deduplicate their underlying concrete functions. |
| if concrete_function.name not in self._seen_function_names: |
| self.concrete_functions.append(concrete_function) |
| self._seen_function_names.add(concrete_function.name) |
| |
| def _trace_all_concrete_functions(self): |
| """Trace concrete functions to force side-effects. |
| |
| Lists the concrete functions in order to: |
| - populate the cache for functions that have an input_signature |
| and have not been called |
| - force side effects of creation of concrete functions, e.g. create |
| variables on first run. |
| """ |
| for obj in self.checkpoint_view.list_objects(): |
| for function in self.checkpoint_view.list_functions(obj).values(): |
| if isinstance(function, def_function.Function): |
| function._list_all_concrete_functions_for_serialization() # pylint: disable=protected-access |
| |
| @property |
| def root(self): |
| return self.nodes[0] |
| |
| def fill_object_graph_proto(self, proto): |
| """Populate the nodes, children and slot_variables of a SavedObjectGraph.""" |
| for node_id, node in enumerate(self.nodes): |
| assert self.node_ids[node] == node_id |
| object_proto = proto.nodes.add() |
| object_proto.slot_variables.extend(self._slot_variables.get(node, ())) |
| if isinstance( |
| node, |
| (def_function.Function, defun.ConcreteFunction, _CapturedConstant, |
| _CapturedTensor)): |
| continue |
| for child in self.checkpoint_view.list_children(node): |
| child_proto = object_proto.children.add() |
| child_proto.node_id = self.node_ids[child.ref] |
| child_proto.local_name = child.name |
| for name, ref in self.checkpoint_view.list_dependencies(node): |
| child_proto = object_proto.dependencies.add() |
| child_proto.node_id = self.node_ids[ref] |
| child_proto.local_name = name |
| for local_name, ref_function in ( |
| self.checkpoint_view.list_functions(node).items()): |
| child_proto = object_proto.children.add() |
| child_proto.node_id = self.node_ids[ref_function] |
| child_proto.local_name = local_name |
| |
| if node in self._saveable_objects_map: |
| assert node not in self._obj_to_registered_saver, ( |
| "Objects can't have both SaveableObjects and a registered saver") |
| |
| for local_name, (save_fn, restore_fn) in ( |
| self._saveable_objects_map[node].items()): |
| saveable_object_proto = object_proto.saveable_objects[local_name] |
| saveable_object_proto.save_function = self.node_ids[save_fn] |
| saveable_object_proto.restore_function = self.node_ids[restore_fn] |
| |
| elif node in self._obj_to_registered_saver: |
| object_proto.registered_saver = self._obj_to_registered_saver[node] |
| |
| def map_resources(self): |
| """Makes new resource handle ops corresponding to existing resource tensors. |
| |
| Creates resource handle ops in the current default graph, whereas |
| `accessible_objects` will be from an eager context. Resource mapping adds |
| resource handle ops to the main GraphDef of a SavedModel, which allows the |
| C++ loader API to interact with resources. |
| |
| Returns: |
| A tuple of (object_map, resource_map, asset_info): |
| object_map: A dictionary mapping from object in `accessible_objects` to |
| replacement objects created to hold the new resource tensors. |
| resource_map: A dictionary mapping from resource tensors extracted from |
| `accessible_objects` to newly created resource tensors. |
| asset_info: An _AssetInfo tuple describing external assets referenced |
| from accessible_objects. |
| """ |
| # Only makes sense when adding to the export Graph |
| assert not context.executing_eagerly() |
| # TODO(allenl): Handle MirroredVariables and other types of variables which |
| # may need special casing. |
| object_map = object_identity.ObjectIdentityDictionary() |
| resource_map = {} |
| asset_info = _AssetInfo( |
| asset_defs=[], |
| asset_initializers_by_resource={}, |
| asset_filename_map={}, |
| asset_index={}) |
| |
| for node_id, obj in enumerate(self.nodes): |
| if isinstance(obj, tracking.Asset): |
| _process_asset(obj, asset_info, resource_map) |
| self.captured_tensor_node_ids[obj.asset_path] = node_id |
| elif isinstance(obj, base.Trackable): |
| node_object_map, node_resource_map = obj._map_resources(self._options) # pylint: disable=protected-access |
| for capturable in node_resource_map.keys(): |
| self.captured_tensor_node_ids[capturable] = node_id |
| object_map.update(node_object_map) |
| resource_map.update(node_resource_map) |
| |
| for concrete_function in self.concrete_functions: |
| if not concrete_function.graph.saveable: |
| raise ValueError( |
| (f"Unable to save function {concrete_function.name} for the " |
| "following reason(s):\n" + |
| "\n".join(concrete_function.graph.saving_errors))) |
| for capture in concrete_function.captured_inputs: |
| if (tensor_util.is_tf_type(capture) and |
| capture.dtype not in _UNCOPIABLE_DTYPES and |
| capture not in self.captured_tensor_node_ids): |
| if hasattr(capture, "_cached_variable"): |
| if concrete_function not in self._wrapped_functions: |
| wrapped = self._wrapped_functions[concrete_function] = ( |
| function_serialization.wrap_cached_variables( |
| concrete_function)) |
| self.function_name_map[compat.as_text(concrete_function.name)] = ( |
| compat.as_text(wrapped.name)) |
| continue |
| capture_constant_value = tensor_util.constant_value(capture) |
| if capture_constant_value is None: |
| raise ValueError( |
| f"Unable to save function {concrete_function.name} because it " |
| f"captures graph tensor {capture} from a parent function which " |
| "cannot be converted to a constant with `tf.get_static_value`.") |
| |
| if numpy.prod(capture.shape.as_list()) > 1 and numpy.all( |
| capture_constant_value == capture_constant_value.flat[0]): |
| # For the common case of a constant array filled with the same |
| # value, rebuidling the constant op specifically with the shape arg, |
| # since otherwise the whole array is written into the node def, |
| # causing performance and graph proto size issues (protos cannot be |
| # bigger than 2GB). |
| copied_tensor = constant_op.constant( |
| capture_constant_value.flat[0], |
| dtype=capture.dtype, |
| shape=capture.shape) |
| else: |
| copied_tensor = constant_op.constant(capture_constant_value) |
| |
| node = _CapturedConstant( |
| eager_tensor=capture, graph_tensor=copied_tensor) |
| self.add_capture_and_node(capture, node) |
| resource_map[capture] = copied_tensor |
| |
| self.concrete_functions = [ |
| self._wrapped_functions.get(x, x) for x in self.concrete_functions |
| ] |
| return object_map, resource_map, asset_info |
| |
| def add_capture_and_node(self, capture, node): |
| node_id = len(self.nodes) |
| self.nodes.append(node) |
| self.node_ids[capture] = node_id |
| self.node_ids[node] = node_id |
| self.captured_tensor_node_ids[capture] = node_id |
| return node_id |
| |
| |
def _gen_save_and_restore_functions(checkpoint_factory_map):
  """Traces per-object save/restore concrete functions.

  Intended to run on the checkpoint factory map returned by
  `graph_view.get_checkpoint_factories_and_keys(object_names)`, which maps
  trackable objects to lists of `_CheckpointFactoryData`.

  Args:
    checkpoint_factory_map: A dictionary mapping trackable objects to
      _CheckpointFactoryData.

  Returns:
    saveable_fn_map: Maps obj -> factory name -> (concrete save, restore).
      Objects for which nothing was traced have no entry.
  """
  # obj -> factory attribute name -> (concrete save, concrete restore)
  saveable_fn_map = object_identity.ObjectIdentityDictionary()

  for obj, factory_data_list in checkpoint_factory_map.items():
    for factory_data in factory_data_list:
      saveable_factory = factory_data.factory

      # If object revives as a resource (or TPU/Mirrored) variable,
      # there is no need to trace the save and restore functions.
      if resource_variable_ops.is_resource_variable(obj):
        continue
      if resource_variable_ops.is_resource_variable(saveable_factory):
        continue
      if not callable(saveable_factory):
        continue

      traced_pair = saveable_object_util.trace_save_restore_functions(
          saveable_factory, obj)
      concrete_save, concrete_restore = traced_pair
      if not concrete_save:
        continue

      per_object_fns = saveable_fn_map.setdefault(obj, {})
      per_object_fns[factory_data.name] = (concrete_save, concrete_restore)
  return saveable_fn_map
| |
| |
| def _tensor_dict_to_tensorinfo(tensor_dict): |
| return { |
| key: utils_impl.build_tensor_info_internal(value) |
| for key, value in tensor_dict.items() |
| } |
| |
| |
| def _map_captures_to_created_tensors(original_captures, resource_map): |
| """Maps eager tensors captured by a function to Graph resources for export. |
| |
| Args: |
| original_captures: A dictionary mapping from tensors captured by the |
| function to interior placeholders for those tensors (inside the function |
| body). |
| resource_map: A dictionary mapping from resource tensors owned by the eager |
| context to resource tensors in the exported graph. |
| |
| Returns: |
| A list of stand-in tensors which belong to the exported graph, corresponding |
| to the function's captures. |
| |
| Raises: |
| AssertionError: If the function references a resource which is not part of |
| `resource_map`. |
| """ |
| export_captures = [] |
| for exterior, interior in original_captures: |
| mapped_resource = resource_map.get(exterior, None) |
| if mapped_resource is None: |
| trackable_referrers = [] |
| # Try to figure out where the resource came from by iterating over objects |
| # which reference it. This is slow and doesn't help us figure out how to |
| # match it to other objects when loading the SavedModel as a checkpoint, |
| # so we can't continue saving. But we can at least tell the user what |
| # needs attaching. |
| for primary_referrer in gc.get_referrers(exterior): |
| if isinstance(primary_referrer, base.Trackable): |
| trackable_referrers.append(primary_referrer) |
| for secondary_referrer in gc.get_referrers(primary_referrer): |
| if isinstance(secondary_referrer, base.Trackable): |
| trackable_referrers.append(secondary_referrer) |
| raise AssertionError( |
| "Tried to export a function which references 'untracked' resource " |
| f"{interior}. TensorFlow objects (e.g. tf.Variable) captured by " |
| "functions must be 'tracked' by assigning them to an attribute of a " |
| "tracked object or assigned to an attribute of the main object " |
| "directly.\n\n Trackable Python objects referring to this tensor " |
| "(from gc.get_referrers, limited to two hops):\n{}".format("\n".join( |
| [repr(obj) for obj in trackable_referrers]))) |
| export_captures.append(mapped_resource) |
| return export_captures |
| |
| |
| def _to_safe_name_scope(signature_key, user_input_name): |
| """Creates a sanitized name scope from user signature and input names. |
| |
| Concatenates signature and input names, sanitizing as needed to be a valid |
| scope name. |
| |
| Args: |
| signature_key: The user-provided key for the signature. |
| user_input_name: The user-provided name for the input placeholder. |
| |
| Returns: |
| A name scope that is safe to be used in tf.name_scope(). |
| """ |
| name_scope = "{}_{}".format(signature_key, user_input_name) |
| if re.match(r"^[A-Za-z0-9.][A-Za-z0-9_.\\-]*$", name_scope): |
| return name_scope |
| invalid_prefix_stripped = re.sub(r"^[^A-Za-z0-9.]*", "", name_scope) |
| return re.sub(r"[^A-Za-z0-9_.\\-]", "_", invalid_prefix_stripped) |
| |
| |
| def _map_function_arguments_to_created_inputs(function_arguments, signature_key, |
| function_name): |
| """Creates exterior placeholders in the exported graph for function arguments. |
| |
| Functions have two types of inputs: tensors captured from the outside (eager) |
| context, and arguments to the function which we expect to receive from the |
| user at each call. `_map_captures_to_created_tensors` replaces |
| captured tensors with stand-ins (typically these are resource dtype tensors |
| associated with variables). `_map_function_inputs_to_created_inputs` runs over |
| every argument, creating a new placeholder for each which will belong to the |
| exported graph rather than the function body. |
| |
| Args: |
| function_arguments: A list of argument placeholders in the function body. |
| signature_key: The name of the signature being exported, for error messages. |
| function_name: The name of the function, for error messages. |
| |
| Returns: |
| A tuple of (mapped_inputs, exterior_placeholders) |
| mapped_inputs: A list with entries corresponding to `function_arguments` |
| containing all of the inputs of the function gathered from the exported |
| graph (both captured resources and arguments). |
| exterior_argument_placeholders: A dictionary mapping from argument names |
| to placeholders in the exported graph, containing the explicit arguments |
| to the function which a user is expected to provide. |
| |
| Raises: |
| ValueError: If argument names are not unique. |
| """ |
| # `exterior_argument_placeholders` holds placeholders which are outside the |
| # function body, directly contained in a MetaGraph of the SavedModel. The |
| # function body itself contains nearly identical placeholders used when |
| # running the function, but these exterior placeholders allow Session-based |
| # APIs to call the function using feeds and fetches which name Tensors in the |
| # MetaGraph. |
| exterior_argument_placeholders = {} |
| mapped_inputs = [] |
| for placeholder in function_arguments: |
| # `export_captures` contains an exhaustive set of captures, so if we don't |
| # find the input there then we now know we have an argument. |
| user_input_name = compat.as_str_any( |
| placeholder.op.get_attr("_user_specified_name")) |
| # If the internal placeholders for a function have names which were |
| # uniquified by TensorFlow, then a single user-specified argument name |
| # must refer to multiple Tensors. The resulting signatures would be |
| # confusing to call. Instead, we throw an exception telling the user to |
| # specify explicit names. |
| if user_input_name != placeholder.op.name: |
| # This should be unreachable, since concrete functions may not be |
| # generated with non-unique argument names. |
| raise ValueError( |
| "Got non-flat/non-unique argument names for SavedModel signature " |
| f"'{signature_key}': more than one argument to " |
| f"'{compat.as_str_any(function_name)}' was named " |
| f"'{user_input_name}'. " |
| "Signatures have one Tensor per named input, so to have " |
| "predictable names Python functions used to generate these " |
| "signatures should avoid *args and Tensors in nested " |
| "structures unless unique names are specified for each. Use " |
| "tf.TensorSpec(..., name=...) to provide a name for a Tensor " |
| "input.") |
| arg_placeholder = array_ops.placeholder( |
| shape=placeholder.shape, |
| dtype=placeholder.dtype, |
| name=_to_safe_name_scope(signature_key, user_input_name)) |
| exterior_argument_placeholders[user_input_name] = arg_placeholder |
| mapped_inputs.append(arg_placeholder) |
| return mapped_inputs, exterior_argument_placeholders |
| |
| |
def _call_function_with_mapped_captures(function, args, resource_map):
  """Calls `function` in the exported graph, using mapped resource captures.

  Args:
    function: The concrete function to call.
    args: Flat list of argument tensors to pass to the function.
    resource_map: Mapping used to swap each of the function's captures for its
      stand-in tensor in the exported graph.

  Returns:
    The outputs of the function call.
  """
  mapped_captures = _map_captures_to_created_tensors(
      function.graph.captures, resource_map)
  # `_call_flat` is used because the newly created capture tensors, which were
  # not part of the original function definition, have to be fed in directly.
  return function._call_flat(args, mapped_captures)  # pylint: disable=protected-access
| |
| |
| def _generate_signatures(signature_functions, resource_map): |
| """Validates and calls `signature_functions` in the default graph. |
| |
| Args: |
| signature_functions: A dictionary mapping string keys to concrete TensorFlow |
| functions (e.g. from `signature_serialization.canonicalize_signatures`) |
| which will be used to generate SignatureDefs. |
| resource_map: A dictionary mapping from resource tensors in the eager |
| context to resource tensors in the Graph being exported. This dictionary |
| is used to re-bind resources captured by functions to tensors which will |
| exist in the SavedModel. |
| |
| Returns: |
| Each function in the `signature_functions` dictionary is called with |
| placeholder Tensors, generating a function call operation and output |
| Tensors. The placeholder Tensors, the function call operation, and the |
| output Tensors from the function call are part of the default Graph. |
| |
| This function then returns a dictionary with the same structure as |
| `signature_functions`, with the concrete functions replaced by SignatureDefs |
| implicitly containing information about how to call each function from a |
| TensorFlow 1.x Session / the C++ Loader API. These SignatureDefs reference |
| the generated placeholders and Tensor outputs by name. |
| |
| The caller is expected to include the default Graph set while calling this |
| function as a MetaGraph in a SavedModel, including the returned |
| SignatureDefs as part of that MetaGraph. |
| """ |
| signatures = {} |
| for signature_key, function in sorted(signature_functions.items()): |
| if function.graph.captures: |
| argument_inputs = function.graph.inputs[:-len(function.graph.captures)] |
| else: |
| argument_inputs = function.graph.inputs |
| mapped_inputs, exterior_argument_placeholders = ( |
| _map_function_arguments_to_created_inputs(argument_inputs, |
| signature_key, function.name)) |
| outputs = _call_function_with_mapped_captures( |
| function, mapped_inputs, resource_map) |
| signatures[signature_key] = signature_def_utils.build_signature_def( |
| _tensor_dict_to_tensorinfo(exterior_argument_placeholders), |
| _tensor_dict_to_tensorinfo(outputs), |
| method_name=signature_constants.PREDICT_METHOD_NAME) |
| return signatures |
| |
| |
| def _trace_resource_initializers(accessible_objects): |
| """Create concrete functions from `CapturableResource` objects.""" |
| resource_initializers = [] |
| |
| def _wrap_initializer(obj): |
| obj._initialize() # pylint: disable=protected-access |
| return constant_op.constant(1.) # Dummy control output |
| |
| def _wrap_obj_initializer(obj): |
| return lambda: _wrap_initializer(obj) |
| |
| for obj in accessible_objects: |
| if isinstance(obj, tracking.CapturableResource): |
| resource_initializers.append( |
| def_function.function( |
| _wrap_obj_initializer(obj), |
| # All inputs are captures. |
| input_signature=[]).get_concrete_function()) |
| return resource_initializers |
| |
| |
| _AssetInfo = collections.namedtuple( |
| "_AssetInfo", |
| [ |
| # List of AssetFileDef protocol buffers |
| "asset_defs", |
| # Map from asset variable resource Tensors to their init ops |
| "asset_initializers_by_resource", |
| # Map from base asset filenames to full paths |
| "asset_filename_map", |
| # Map from Asset to index of corresponding AssetFileDef |
| "asset_index" |
| ]) |
| |
| |
def _process_asset(trackable_asset, asset_info, resource_map):
  """Registers `trackable_asset` in `asset_info` and `resource_map`.

  A string placeholder plus a resource variable initialized from it are
  created for the asset, an AssetFileDef pointing at the placeholder is
  appended to `asset_info.asset_defs`, and the remaining `asset_info` maps and
  `resource_map` are updated to refer to the new objects.

  Args:
    trackable_asset: The asset object; its `asset_path` tensor is read.
    asset_info: The `_AssetInfo` record to update in place.
    resource_map: Dictionary updated in place with an entry from the asset's
      path tensor to the newly created variable.
  """
  path_tensor = trackable_asset.asset_path
  source_path = tensor_util.constant_value(path_tensor)
  try:
    source_path = str(source_path.astype(str))
  except AttributeError:
    # Already a string rather than a numpy array
    pass
  exported_filename = builder_impl.get_asset_filename_to_add(
      asset_filepath=source_path,
      asset_filename_map=asset_info.asset_filename_map)
  # TODO(andresp): Instead of mapping 1-1 between trackable asset
  # and asset in the graph def consider deduping the assets that
  # point to the same file.
  path_placeholder = array_ops.placeholder(
      shape=path_tensor.shape,
      dtype=dtypes.string,
      name="asset_path_initializer")
  asset_variable = resource_variable_ops.ResourceVariable(path_placeholder)
  asset_info.asset_filename_map[exported_filename] = source_path

  file_def = meta_graph_pb2.AssetFileDef()
  file_def.filename = exported_filename
  file_def.tensor_info.name = path_placeholder.name
  asset_info.asset_defs.append(file_def)

  asset_info.asset_initializers_by_resource[path_tensor] = (
      asset_variable.initializer)
  # The AssetFileDef was just appended, so it sits at the last index.
  asset_info.asset_index[trackable_asset] = len(asset_info.asset_defs) - 1
  resource_map[path_tensor] = asset_variable
| |
| |
def _iterate_op_types(fn):
  """Yields `(op_type, op)` for ops in `fn` that name a gradient op type.

  Args:
    fn: Either a `framework_fn._DefinedFunction` or an object exposing
      `graph.get_operations()`.

  Yields:
    `(op_type, op)` for every op of `fn.graph` for which reading the
    `_gradient_op_type` attribute succeeds. Nothing is yielded for a
    `_DefinedFunction`.

  Raises:
    ValueError: If `fn` is a `_DefinedFunction` with a node whose
      `_gradient_op_type` attribute is non-empty. Because this is a generator,
      the error surfaces when the result is iterated.
  """
  if not isinstance(fn, framework_fn._DefinedFunction):  # pylint: disable=protected-access
    for graph_op in fn.graph.get_operations():
      try:
        gradient_type = graph_op.get_attr("_gradient_op_type")
      except ValueError:
        # Reading the attribute failed; skip this op.
        continue
      yield gradient_type, graph_op
    return

  # `_DefinedFunction`s are only checked: a custom gradient on any node makes
  # the function unsaveable.
  for node_def in fn.definition.node_def:
    if node_def.attr["_gradient_op_type"].s:
      raise ValueError(
          "Unable to save gradient functions when exporting a "
          "_DefinedFunction (generally created through graph freezing utils "
          "or through V1 graph importers). Please save with "
          "`options=tf.SaveOptions(experimental_custom_gradients=False)`")
| |
| |
| def _get_outer_most_capture(fn, capture, func_graph_map): |
| """Tries to find the original captured tensor if capture more than once.""" |
| outer_fn = fn |
| while outer_fn is not None and not isinstance(capture, ops.EagerTensor): |
| if capture.graph is not outer_fn.graph: |
| outer_fn = func_graph_map.get(outer_fn.graph.outer_graph) |
| else: |
| try: |
| capture_index = outer_fn.graph.internal_captures.index(capture) |
| except ValueError: |
| break # Capture is a tensor inside function, and not captured from |
| # another external function |
| capture = outer_fn.graph.external_captures[capture_index] |
| outer_fn = func_graph_map.get(outer_fn.graph.outer_graph) |
| return outer_fn, capture |
| |
| |
| def _trace_gradient_functions(graph, saveable_view): |
| """Traces gradient functions and records them in the SaveableView.""" |
| functions = list(graph._functions.values()) # pylint: disable=protected-access |
| func_graph_map = {f.graph: f for f in functions if hasattr(f, "graph")} |
| seen_op_types = set() |
| |
| for fn in functions: |
| for op_type, op in _iterate_op_types(fn): |
| if op_type in seen_op_types: |
| continue |
| seen_op_types.add(op_type) |
| |
| try: |
| custom_gradient = ops.gradient_registry.lookup(op_type) |
| except LookupError: |
| continue |
| |
| try: |
| grad_fn = ( |
| def_function.function(custom_gradient).get_concrete_function( |
| None, *op.inputs)) |
| except Exception as exc: |
| traceback.print_exc() |
| raise ValueError( |
| "Error when tracing gradients for SavedModel.\n\n" |
| "Check the error log to see the error that was raised when " |
| "converting a gradient function to a concrete function. You may " |
| "need to update the custom gradient, or disable saving gradients " |
| "with the option tf.saved_model.SaveOptions(custom_gradients=False)" |
| f".\n\tProblematic op name: {op.name}\n\tGradient inputs: " |
| f"{op.inputs}") from exc |
| |
| # The gradient function will capture all intermediate values. These |
| # captures be serialized so that they can be re-bound to the function when |
| # loading. |
| bad_captures = [] |
| for capture in grad_fn.captured_inputs: |
| if capture.dtype in _UNCOPIABLE_DTYPES: |
| continue |
| # Tries to find the outermost capture in case the tensor is a constant |
| # or not actually captured in the current function (this could happen if |
| # the function is a while loop body, in which case the captured input |
| # is not the internal captured tensor). |
| outer_fn, outer_capture = _get_outer_most_capture( |
| fn, capture, func_graph_map) |
| if outer_fn is None or isinstance(outer_capture, ops.EagerTensor): |
| if outer_capture not in saveable_view.captured_tensor_node_ids: |
| raise ValueError(f"Found invalid capture {outer_capture} when " |
| "saving custom gradients.") |
| saveable_view.captured_tensor_node_ids[capture] = ( |
| saveable_view.captured_tensor_node_ids[outer_capture]) |
| elif outer_capture.graph is outer_fn.graph: |
| capture_name = outer_capture.name |
| # It's possible for EagerDefinedFunctions to save different names for |
| # input tensors when serialized to FunctionDef (all non-alphanumeric |
| # characters are converted to '_'). |
| if isinstance(outer_fn, defun._EagerDefinedFunction): # pylint:disable=protected-access |
| try: |
| arg_index = outer_fn.graph.inputs.index(outer_capture) |
| capture_name = outer_fn.signature.input_arg[arg_index].name + ":0" |
| except ValueError: |
| pass |
| |
| node = _CapturedTensor(capture_name, outer_fn.name) |
| saveable_view.add_capture_and_node(capture, node) |
| else: |
| bad_captures.append(capture.name) |
| if not bad_captures: |
| grad_fn.add_to_graph(graph) |
| else: |
| raise ValueError( |
| f"Cannot save custom gradient {op_type} called in function {fn} " |
| "because SavedModel is unable to serialize the captured " |
| f"inputs: {bad_captures}") |
| |
| saveable_view.gradient_functions.append(grad_fn) |
| func_graph_map[grad_fn.graph] = grad_fn |
| |
| grad_def = function_pb2.RegisteredGradient() |
| grad_def.gradient_func = grad_fn.name |
| grad_def.registered_op_type = op_type |
| saveable_view.gradient_defs.append(grad_def) |
| |
| |
def _fill_meta_graph_def(meta_graph_def, saveable_view, signature_functions,
                         namespace_whitelist, save_custom_gradients):
  """Generates a MetaGraph which calls `signature_functions`.

  A fresh `ops.Graph` is created and populated with resource initializers, an
  init op, the signature function calls, the view's concrete functions, the
  saver ops and (optionally) traced custom gradients. The resulting GraphDef,
  signatures, saver, asset definitions and version metadata are then written
  into `meta_graph_def`, which is modified in place.

  Args:
    meta_graph_def: The MetaGraphDef proto to fill.
    saveable_view: The _SaveableView being exported.
    signature_functions: A dictionary mapping signature keys to concrete
      functions containing signatures to add to the MetaGraph.
    namespace_whitelist: List of strings containing whitelisted op namespaces.
    save_custom_gradients: Whether to save custom gradients.

  Returns:
    A tuple of (_AssetInfo, Graph) containing the captured assets and
    exported Graph generated from tracing the saveable_view.
  """
  # List objects from the eager context to make sure Optimizers give us the
  # right Graph-dependent variables.
  accessible_objects = saveable_view.nodes
  resource_initializer_functions = _trace_resource_initializers(
      accessible_objects)
  exported_graph = ops.Graph()
  resource_initializer_ops = []
  with exported_graph.as_default():
    object_map, resource_map, asset_info = saveable_view.map_resources()
    for resource_initializer_function in resource_initializer_functions:
      # A resource initializer that captures an asset path must run after that
      # asset's initializer, so collect those as control dependencies.
      asset_dependencies = []
      for capture in resource_initializer_function.graph.external_captures:
        asset_initializer = asset_info.asset_initializers_by_resource.get(
            capture, None)
        if asset_initializer is not None:
          asset_dependencies.append(asset_initializer)
      with ops.control_dependencies(asset_dependencies):
        resource_initializer_ops.append(
            _call_function_with_mapped_captures(resource_initializer_function,
                                                [], resource_map))
    resource_initializer_ops.extend(
        asset_info.asset_initializers_by_resource.values())
    # A single no-op that depends on every initializer serves as the init op.
    with ops.control_dependencies(resource_initializer_ops):
      init_op = control_flow_ops.no_op()
    # Add the same op to the main_op collection and to the init_op
    # signature. The collection is for compatibility with older loader APIs;
    # only one will be executed.
    meta_graph_def.collection_def[constants.MAIN_OP_KEY].node_list.value.append(
        init_op.name)
    meta_graph_def.signature_def[constants.INIT_OP_SIGNATURE_KEY].CopyFrom(
        signature_def_utils.op_signature_def(init_op,
                                             constants.INIT_OP_SIGNATURE_KEY))

  # Saving an object-based checkpoint again gathers variables. We need to do the
  # gathering from the eager context so Optimizers save the right set of
  # variables, but want any operations associated with the save/restore to be in
  # the exported graph (thus the `to_graph` argument).
  call_with_mapped_captures = functools.partial(
      _call_function_with_mapped_captures, resource_map=resource_map)
  named_saveable_objects, registered_savers = (
      saveable_view.checkpoint_view.frozen_saveables_and_savers(
          object_map=object_map, to_graph=exported_graph,
          call_with_mapped_captures=call_with_mapped_captures))
  saver = functional_saver.MultiDeviceSaver(named_saveable_objects,
                                            registered_savers,
                                            call_with_mapped_captures)

  with exported_graph.as_default():
    signatures = _generate_signatures(signature_functions, resource_map)
    for concrete_function in saveable_view.concrete_functions:
      concrete_function.add_to_graph()
    # Gradient tracing runs after the concrete functions were added because it
    # scans the functions registered in the exported graph.
    if save_custom_gradients:
      _trace_gradient_functions(exported_graph, saveable_view)
    saver_def = saver.to_proto()
    meta_graph_def.saver_def.CopyFrom(saver_def)

  # At this point all nodes that can be added to the SavedObjectGraph have been
  # added, so run the deserialization dependency validation.
  _validate_dependencies(saveable_view)

  graph_def = exported_graph.as_graph_def(add_shapes=True)
  graph_def.library.registered_gradients.extend(saveable_view.gradient_defs)
  _verify_ops(graph_def, namespace_whitelist)

  meta_graph_def.graph_def.CopyFrom(graph_def)
  meta_graph_def.meta_info_def.tags.append(tag_constants.SERVING)
  meta_graph_def.meta_info_def.tensorflow_version = versions.__version__
  meta_graph_def.meta_info_def.tensorflow_git_version = (
      versions.__git_version__)
  # We currently always strip default attributes.
  meta_graph_def.meta_info_def.stripped_default_attrs = True
  meta_graph_def.meta_info_def.stripped_op_list.MergeFrom(
      meta_graph.stripped_op_list_for_graph(meta_graph_def.graph_def))
  meta_graph_def.asset_file_def.extend(asset_info.asset_defs)
  for signature_key, signature in signatures.items():
    meta_graph_def.signature_def[signature_key].CopyFrom(signature)
  meta_graph.strip_graph_default_valued_attrs(meta_graph_def)
  # Store tensor_content in little-endian format.
  if sys.byteorder == "big":
    utils_impl.swap_function_tensor_content(meta_graph_def, "big", "little")
  return asset_info, exported_graph
| |
| |
| def _verify_ops(graph_def, namespace_whitelist): |
| """Verifies that all namespaced ops in the graph are whitelisted. |
| |
| Args: |
| graph_def: the GraphDef to validate. |
| namespace_whitelist: a list of namespaces to allow. If `None`, all will be |
| allowed. If an op does not have a namespace, it will be allowed. |
| |
| Raises: |
| ValueError: If the graph contains ops that violate the whitelist. |
| """ |
| # By default, if the user has not specified a whitelist, we want to allow |
| # everything. We check for None directly rather than falseness, since the |
| # user may instead want to pass an empty list to disallow all custom |
| # namespaced ops. |
| if namespace_whitelist is None: |
| return |
| |
| invalid_ops = [] |
| invalid_namespaces = set() |
| |
| all_operations = [] |
| all_operations.extend(meta_graph.ops_used_by_graph_def(graph_def)) |
| |
| for op in all_operations: |
| if ">" in op: |
| namespace = op.split(">")[0] |
| if namespace not in namespace_whitelist: |
| invalid_ops.append(op) |
| invalid_namespaces.add(namespace) |
| if invalid_ops: |
| raise ValueError( |
| "Attempted to save ops from non-whitelisted namespaces to SavedModel: " |
| f"{invalid_ops}.\nPlease verify that these ops should be saved, since " |
| "they must be available when loading the SavedModel. If loading from " |
| "Python, you must import the library defining these ops. From C++, " |
| "link the custom ops to the serving binary. Once you've confirmed this," |
| " add the following namespaces to the `namespace_whitelist` " |
| f"argument in tf.saved_model.SaveOptions: {invalid_namespaces}.") |
| |
| |
def _validate_dependencies(saveble_view):
  """Ensures that the dependencies can be topologically sorted for loading.

  Builds a map from each node id to the ids of the nodes it depends on, then
  asks `trackable_utils.order_by_dependency` to order it.

  Args:
    saveble_view: The view being exported. Its `nodes`, `node_ids`,
      `node_paths` and `checkpoint_view` are read.

  Raises:
    ValueError: If a node depends on an object that is not part of the view,
      or if the dependencies contain a cycle.
  """
  dependency_map = {}
  for node in saveble_view.nodes:
    node_id = saveble_view.node_ids[node]
    # Every node gets an entry, even when it ends up with no dependencies.
    deps = dependency_map[node_id] = []
    # TODO(kathywu): Remove once all of these have been converted to trackable.
    if isinstance(
        node,
        (def_function.Function, defun.ConcreteFunction, _CapturedConstant,
         _CapturedTensor)):
      continue  # These are not `Trackable` and therefore have no dependencies.
    for _, dep in saveble_view.checkpoint_view.list_dependencies(node):
      if dep not in saveble_view.node_ids:
        node_path = trackable_utils.pretty_print_node_path(
            saveble_view.node_paths[node])
        raise ValueError(
            f"Found an untracked dependency. Object {node_path} depends "
            f"on {dep}, but this dependency isn't listed as a child. "
            "Please track this child by overriding `_checkpoint_dependencies` "
            "or use `._track_trackable`.")
      deps.append(saveble_view.node_ids[dep])
  try:
    trackable_utils.order_by_dependency(dependency_map)
  except trackable_utils.CyclicDependencyError as err:
    pretty_printed_nodes = []
    pretty_printed_dependencies = []

    # Describe every node left unresolved by the ordering attempt.
    for x, leftover_deps in err.leftover_dependency_map.items():
      node_path = trackable_utils.pretty_print_node_path(
          saveble_view.node_paths[saveble_view.nodes[x]])
      pretty_printed_nodes.append(
          f"\tNode {x} = {node_path} (type {type(saveble_view.nodes[x])})")
      pretty_printed_dependencies.append(
          f"\tNode {x} depends on nodes {leftover_deps}")
    pretty_printed_nodes = "\n".join(pretty_printed_nodes)
    pretty_printed_dependencies = "\n".join(pretty_printed_dependencies)
    # Chain explicitly so the original cycle error is reported as the cause.
    raise ValueError(
        "There is one or more dependency cycle in the saved Trackable object. "
        "Saving cannot continue until this cycle is resolved."
        f"\n>> Unresolved nodes:\n{pretty_printed_nodes}"
        f"\n>> Unresolved cyclic dependencies:\n{pretty_printed_dependencies}"
    ) from err
| |
| |
def _serialize_object_graph(saveable_view, asset_file_def_index):
  """Builds the SavedObjectGraph proto describing `saveable_view`.

  Args:
    saveable_view: The view being exported.
    asset_file_def_index: Lookup passed through to `_write_object_proto`,
      which indexes it by asset object.

  Returns:
    The populated `SavedObjectGraph` proto.
  """
  # SavedObjectGraph is similar to the TrackableObjectGraph proto in the
  # checkpoint. It will eventually go into the SavedModel.
  object_graph = saved_object_graph_pb2.SavedObjectGraph()
  saveable_view.fill_object_graph_proto(object_graph)

  structure_coder = nested_structure_coder.StructureCoder()
  for fn in saveable_view.concrete_and_gradient_functions:
    fn_name = compat.as_text(fn.name)
    # Use the mapped name when one exists, otherwise the function's own name.
    exported_name = saveable_view.function_name_map.get(fn_name, fn_name)
    serialized_fn = function_serialization.serialize_concrete_function(
        fn, saveable_view.captured_tensor_node_ids, structure_coder)
    if serialized_fn is None:
      continue
    object_graph.concrete_functions[exported_name].CopyFrom(serialized_fn)

  # `object_graph.nodes` is walked in lockstep with the view's nodes.
  for node, node_proto in zip(saveable_view.nodes, object_graph.nodes):
    _write_object_proto(node, node_proto, asset_file_def_index,
                        saveable_view.function_name_map)
  return object_graph
| |
| |
def _write_object_proto(obj, proto, asset_file_def_index, function_name_map):
  """Saves an object into SavedObject proto.

  The first matching branch of the type checks below decides which field of
  `proto` describes `obj`. Independently of that, objects with a registered
  class name also get `registered_name` and, if available, a packed
  user-defined proto.

  Args:
    obj: The object to describe.
    proto: The SavedObject proto to fill; modified in place.
    asset_file_def_index: Lookup from asset objects to the index of their
      AssetFileDef.
    function_name_map: Passed through to the function serializers.
  """
  # NOTE(review): the branch order presumably matters because some objects
  # satisfy more than one check -- confirm before reordering.
  if isinstance(obj, tracking.Asset):
    proto.asset.SetInParent()
    proto.asset.asset_file_def_index = asset_file_def_index[obj]
  elif resource_variable_ops.is_resource_variable(obj):
    options = save_context.get_save_options()
    obj._write_object_proto(proto, options)  # pylint: disable=protected-access
  elif isinstance(obj, def_function.Function):
    proto.function.CopyFrom(function_serialization.serialize_function(
        obj, function_name_map))
  elif isinstance(obj, defun.ConcreteFunction):
    proto.bare_concrete_function.CopyFrom(
        function_serialization.serialize_bare_concrete_function(
            obj, function_name_map))
  elif isinstance(obj, _CapturedConstant):
    proto.constant.operation = obj.graph_tensor.op.name
  elif isinstance(obj, _CapturedTensor):
    proto.captured_tensor.name = obj.name
    proto.captured_tensor.concrete_function = obj.concrete_function
  elif isinstance(obj, tracking.CapturableResource):
    proto.resource.device = obj._resource_device  # pylint: disable=protected-access
  else:
    # Anything else is stored as a user object.
    registered_type_proto = revived_types.serialize(obj)
    if registered_type_proto is None:
      # Fallback for types with no matching registration
      # pylint:disable=protected-access
      registered_type_proto = saved_object_graph_pb2.SavedUserObject(
          identifier=obj._object_identifier,
          version=versions_pb2.VersionDef(
              producer=1, min_consumer=1, bad_consumers=[]))
      # pylint:enable=protected-access
    proto.user_object.CopyFrom(registered_type_proto)

  # This part runs for every object, whichever branch was taken above.
  registered_name = registration.get_registered_class_name(obj)
  if registered_name:
    proto.registered_name = registered_name
    serialized_user_proto = obj._serialize_to_proto()  # pylint: disable=protected-access
    if serialized_user_proto is not None:
      proto.serialized_user_proto.Pack(serialized_user_proto)
| |
| |
def _export_debug_info(exported_graph, export_dir):
  """Writes debug information for the graph's functions to `export_dir`.

  Collects the ops of every `_EagerDefinedFunction` registered in
  `exported_graph`, builds a GraphDebugInfo from them and writes it, serialized,
  into the SavedModel's debug directory.

  Args:
    exported_graph: A Graph that has been created by tracing a saveable view.
    export_dir: SavedModel directory in which to write the debug info.
  """
  # pylint: disable=protected-access
  named_ops = []
  for function_name in exported_graph._functions:
    function = exported_graph._get_function(function_name)
    # Other kinds of registered functions are not included.
    if isinstance(function, defun._EagerDefinedFunction):
      named_ops.extend(
          (function_name, function_op)
          for function_op in function.graph.get_operations())
  # pylint: enable=protected-access

  debug_info = error_interpolation.create_graph_debug_info_def(named_ops)
  debug_info_path = file_io.join(
      utils_impl.get_or_create_debug_dir(export_dir),
      constants.DEBUG_INFO_FILENAME_PB)
  file_io.atomic_write_string_to_file(
      debug_info_path, debug_info.SerializeToString(deterministic=True))
| |
| |
@tf_export(
    "saved_model.save",
    v1=["saved_model.save", "saved_model.experimental.save"])
def save(obj, export_dir, signatures=None, options=None):
  # pylint: disable=line-too-long
  """Exports a [tf.Module](https://www.tensorflow.org/api_docs/python/tf/Module) (and subclasses) `obj` to [SavedModel format](https://www.tensorflow.org/guide/saved_model#the_savedmodel_format_on_disk).

  The `obj` must inherit from the [`Trackable` class](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/training/tracking/base.py#L591).

  Example usage:

  >>> class Adder(tf.Module):
  ...   @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.float32)])
  ...   def add(self, x):
  ...     return x + x

  >>> model = Adder()
  >>> tf.saved_model.save(model, '/tmp/adder')

  The resulting SavedModel is then servable with an input named "x", a scalar
  with dtype float32.

  _Signatures_

  Signatures define the input and output types for a computation. The optional
  save `signatures` argument controls which methods in `obj` will be
  available to programs which consume `SavedModel`s, for example, serving
  APIs. Python functions may be decorated with
  `@tf.function(input_signature=...)` and passed as signatures directly, or
  lazily with a call to `get_concrete_function` on the method decorated with
  `@tf.function`.

  Example:

  >>> class Adder(tf.Module):
  ...   @tf.function
  ...   def add(self, x):
  ...     return x + x

  >>> model = Adder()
  >>> tf.saved_model.save(
  ...   model, '/tmp/adder',signatures=model.add.get_concrete_function(
  ...     tf.TensorSpec([], tf.float32)))

  If a `@tf.function` does not have an input signature and
  `get_concrete_function` is not called on that method, the function will not
  be directly callable in the restored SavedModel.

  Example:

  >>> class Adder(tf.Module):
  ...   @tf.function
  ...   def add(self, x):
  ...     return x + x

  >>> model = Adder()
  >>> tf.saved_model.save(model, '/tmp/adder')
  >>> restored = tf.saved_model.load('/tmp/adder')
  >>> restored.add(1.)
  Traceback (most recent call last):
  ...
  ValueError: Found zero restored functions for caller function.

  If the `signatures` argument is omitted, `obj` will be searched for
  `@tf.function`-decorated methods. If exactly one traced `@tf.function` is
  found, that method will be used as the default signature for the SavedModel.
  Else, any `@tf.function` attached to `obj` or its dependencies will be
  exported for use with `tf.saved_model.load`.

  When invoking a signature in an exported SavedModel, `Tensor` arguments are
  identified by name. These names will come from the Python function's argument
  names by default. They may be overridden by specifying a `name=...` argument
  in the corresponding `tf.TensorSpec` object. Explicit naming is required if
  multiple `Tensor`s are passed through a single argument to the Python
  function.

  The outputs of functions used as `signatures` must either be flat lists, in
  which case outputs will be numbered, or a dictionary mapping string keys to
  `Tensor`, in which case the keys will be used to name outputs.

  Signatures are available in objects returned by `tf.saved_model.load` as a
  `.signatures` attribute. This is a reserved attribute: `tf.saved_model.save`
  on an object with a custom `.signatures` attribute will raise an exception.

  _Using `tf.saved_model.save` with Keras models_

  While Keras has its own [saving and loading API](https://www.tensorflow.org/guide/keras/save_and_serialize),
  this function can be used to export Keras models. For example, exporting with
  a signature specified:

  >>> class Adder(tf.keras.Model):
  ...   @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.string)])
  ...   def concat(self, x):
  ...      return x + x

  >>> model = Adder()
  >>> tf.saved_model.save(model, '/tmp/adder')

  Exporting from a function without a fixed signature:

  >>> class Adder(tf.keras.Model):
  ...   @tf.function
  ...   def concat(self, x):
  ...      return x + x

  >>> model = Adder()
  >>> tf.saved_model.save(
  ...   model, '/tmp/adder',
  ...   signatures=model.concat.get_concrete_function(
  ...     tf.TensorSpec(shape=[], dtype=tf.string, name="string_input")))

  `tf.keras.Model` instances constructed from inputs and outputs already have a
  signature and so do not require a `@tf.function` decorator or a `signatures`
  argument. If neither are specified, the model's forward pass is exported.

  >>> x = tf.keras.layers.Input((4,), name="x")
  >>> y = tf.keras.layers.Dense(5, name="out")(x)
  >>> model = tf.keras.Model(x, y)
  >>> tf.saved_model.save(model, '/tmp/saved_model/')

  The exported SavedModel takes "x" with shape [None, 4] and returns "out"
  with shape [None, 5]

  _Variables and Checkpoints_

  Variables must be tracked by assigning them to an attribute of a tracked
  object or to an attribute of `obj` directly. TensorFlow objects (e.g. layers
  from `tf.keras.layers`, optimizers from `tf.train`) track their variables
  automatically. This is the same tracking scheme that `tf.train.Checkpoint`
  uses, and an exported `Checkpoint` object may be restored as a training
  checkpoint by pointing `tf.train.Checkpoint.restore` to the SavedModel's
  "variables/" subdirectory.

  `tf.function` does not hard-code device annotations from outside the function
  body, instead using the calling context's device. This means for example
  that exporting a model that runs on a GPU and serving it on a CPU will
  generally work, with some exceptions:

    * `tf.device` annotations inside the body of the function will be hard-coded
      in the exported model; this type of annotation is discouraged.
    * Device-specific operations, e.g. with "cuDNN" in the name or with
      device-specific layouts, may cause issues.
    * For `ConcreteFunctions`, active distribution strategies will cause device
      placements to be hard-coded in the function.

  SavedModels exported with `tf.saved_model.save` [strip default-valued
  attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes)
  automatically, which removes one source of incompatibilities when the consumer
  of a SavedModel is running an older TensorFlow version than the
  producer. There are however other sources of incompatibilities which are not
  handled automatically, such as when the exported model contains operations
  which the consumer does not have definitions for.

  Args:
    obj: A trackable object (e.g. tf.Module or tf.train.Checkpoint) to export.
    export_dir: A directory in which to write the SavedModel.
    signatures: Optional, one of three types:
      * a `tf.function` with an input signature specified, which will use the
        default serving signature key,
      * the result of `f.get_concrete_function` on a `@tf.function`-decorated
        function `f`, in which case `f` will be used to generate a signature for
        the SavedModel under the default serving signature key,
      * a dictionary, which maps signature keys to either `tf.function`
        instances with input signatures or concrete functions. Keys of such a
        dictionary may be arbitrary strings, but will typically be from the
        `tf.saved_model.signature_constants` module.
    options: `tf.saved_model.SaveOptions` object for configuring save options.

  Raises:
    ValueError: If `obj` is not trackable.

  @compatibility(eager)
  Not well supported when graph building. From TensorFlow 1.x,
  `tf.compat.v1.enable_eager_execution()` should run first. Calling
  tf.saved_model.save in a loop when graph building from TensorFlow 1.x will
  add new save operations to the default graph each iteration.

  May not be called from within a function body.
  @end_compatibility
  """
  # pylint: enable=line-too-long
  # The API-call metric is bumped before saving; the write metric is only
  # reached when `save_and_return_nodes` returns without raising. The saved
  # nodes and paths it returns are discarded.
  metrics.IncrementWriteApi(_SAVE_V2_LABEL)
  save_and_return_nodes(obj, export_dir, signatures, options)
  metrics.IncrementWrite(write_version="2")
| |
| |
def save_and_return_nodes(obj,
                          export_dir,
                          signatures=None,
                          options=None,
                          experimental_skip_checkpoint=False):
  """Saves a SavedModel while returning all saved nodes and their paths.

  Please see `tf.saved_model.save` for details.

  Args:
    obj: A trackable object to export.
    export_dir: A directory in which to write the SavedModel.
    signatures: A function or dictionary of functions to save in the SavedModel
      as signatures.
    options: `tf.saved_model.SaveOptions` object for configuring save options.
      A default `SaveOptions()` is used when this is falsy.
    experimental_skip_checkpoint: If set to `True`, the checkpoint will not
      be written.

  Returns:
    A tuple of (a list of saved nodes in the order they are serialized to the
      `SavedObjectGraph`, dictionary mapping nodes to one possible path from
      the root node to the key node)

  Raises:
    FileNotFoundError: If, when executing eagerly, waiting for pending save
      operations fails with `errors.NotFoundError`. The original error is
      attached as the cause.
  """
  options = options or save_options.SaveOptions()
  # TODO(allenl): Factor out some subset of SavedModelBuilder which is 2.x
  # compatible (no sessions) and share it with this export API rather than
  # making a SavedModel proto and writing it directly.
  saved_model = saved_model_pb2.SavedModel()
  meta_graph_def = saved_model.meta_graphs.add()

  _, exported_graph, object_saver, asset_info, saved_nodes, node_paths = (
      _build_meta_graph(obj, signatures, options, meta_graph_def))
  saved_model.saved_model_schema_version = (
      constants.SAVED_MODEL_SCHEMA_VERSION)

  # Write the checkpoint, copy assets into the assets directory, and write out
  # the SavedModel proto itself.
  if not experimental_skip_checkpoint:
    utils_impl.get_or_create_variables_dir(export_dir)
    ckpt_options = checkpoint_options.CheckpointOptions(
        experimental_io_device=options.experimental_io_device)
    object_saver.save(
        utils_impl.get_variables_path(export_dir), options=ckpt_options)
  builder_impl.copy_assets_to_destination_dir(asset_info.asset_filename_map,
                                              export_dir)
  # Note that writing saved_model.pb (below) needs to be the last file
  # operation when saving the SavedModel. Users rely on checking
  # saved_model_dir/saved_model.pb as an indication that the SavedModel is
  # completely written.
  if context.executing_eagerly():
    try:
      context.async_wait()  # Ensure save operations have completed.
    except errors.NotFoundError as err:
      # Chain explicitly so the underlying NotFoundError stays attached.
      raise FileNotFoundError(
          f"{err}\n You may be trying to save on a different device from the "
          "computational device. Consider setting the "
          "`experimental_io_device` option in `tf.saved_model.SaveOptions` "
          "to the io_device such as '/job:localhost'.") from err

  # We will slowly migrate code in this function to pywrap_saved_model.Save
  # as we build up the C++ API.
  pywrap_saved_model.Save(export_dir)

  path = file_io.join(
      compat.as_str(export_dir),
      compat.as_str(constants.SAVED_MODEL_FILENAME_PB))
  file_io.atomic_write_string_to_file(
      path, saved_model.SerializeToString(deterministic=True))
  # Save debug info, if requested.
  # NOTE(review): this writes a file after saved_model.pb, which looks at odds
  # with the "last file operation" note above -- confirm this is intended.
  if options.save_debug_info:
    _export_debug_info(exported_graph, export_dir)

  # Clean reference cycles so repeated export()s don't make work for the garbage
  # collector. Before this point, we need to keep references to captured
  # constants in the saved graph.
  ops.dismantle_graph(exported_graph)

  return saved_nodes, node_paths
| |
| |
def export_meta_graph(obj, filename, signatures=None, options=None):
  """Writes the MetaGraph proto built from `obj` to `filename`.

  The MetaGraph is built through `_build_meta_graph`, the same step that
  `saved_model.save` uses, and is then serialized to the given file. No
  checkpoint information is written, which makes this useful when only the
  graph defining the model is wanted.

  Args:
    obj: A trackable object to build the MetaGraph from.
    filename: The file into which to write the MetaGraph.
    signatures: Optional. One of:
      * a `tf.function` with an input signature specified, or the result of
        `f.get_concrete_function` on a `@tf.function`-decorated function `f`;
        `f` is then used to generate a signature under the default serving
        signature key;
      * a dictionary mapping signature keys to either `tf.function` instances
        with input signatures or concrete functions. Keys may be arbitrary
        strings, but will typically come from the
        `tf.saved_model.signature_constants` module.
    options: Optional, `tf.saved_model.SaveOptions` object that specifies
      options for saving. A default `SaveOptions()` is used when falsy.
  """
  if not options:
    options = save_options.SaveOptions()
  # The directory containing `filename` is what gets handed to
  # `_export_debug_info` below when debug info is requested.
  debug_info_dir = os.path.dirname(filename)

  # Only the proto and the graph are needed here; the saver, asset info, and
  # node listings returned by the build step are ignored.
  meta_graph_def, exported_graph, *_ = _build_meta_graph(
      obj, signatures, options)

  serialized_meta_graph = meta_graph_def.SerializeToString(deterministic=True)
  file_io.atomic_write_string_to_file(filename, serialized_meta_graph)

  # Save debug info, if requested.
  if options.save_debug_info:
    _export_debug_info(exported_graph, debug_info_dir)

  # Break reference cycles so that repeated exports do not pile up work for
  # the garbage collector. This has to wait until now because references to
  # captured constants in the saved graph are needed up to this point.
  ops.dismantle_graph(exported_graph)
| |
| |
def _build_meta_graph_impl(obj,
                           signatures,
                           options,
                           meta_graph_def=None):
  """Builds a MetaGraph holding the resources and functions of an object.

  Args:
    obj: The object to export; must be an instance of `base.Trackable`.
    signatures: Signatures to export. When `None`, they are looked up with
      `signature_serialization.find_function_to_export`.
    options: Save options object. This function reads its
      `namespace_whitelist`, `experimental_custom_gradients`, and
      `function_aliases` attributes.
    meta_graph_def: Optional MetaGraphDef proto to fill. A new
      `meta_graph_pb2.MetaGraphDef` is created when this is falsy.

  Returns:
    A 6-tuple of
      (meta_graph_def, exported_graph, object_saver, asset_info,
       saveable_view.nodes, saveable_view.node_paths).

  Raises:
    AssertionError: If called while `ops.inside_function()` is true.
    ValueError: If `obj` is not a `base.Trackable`.
  """
  if ops.inside_function():
    raise AssertionError(
        "`tf.saved_model.save` is not supported inside a traced @tf.function. "
        "Move the call to the outer eagerly-executed context.")
  # pylint: enable=line-too-long
  if not isinstance(obj, base.Trackable):
    raise ValueError(
        "Expected an object of type `Trackable`, such as `tf.Module` or a "
        f"subclass of the `Trackable` class, for export. Got {obj} "
        f"with type {type(obj)}.")
  if not meta_graph_def:
    meta_graph_def = meta_graph_pb2.MetaGraphDef()

  graph_view = _AugmentedGraphView(obj)
  if signatures is None:
    signatures = signature_serialization.find_function_to_export(graph_view)

  canonicalized = signature_serialization.canonicalize_signatures(signatures)
  signatures, wrapped_functions = canonicalized
  signature_serialization.validate_saveable_view(graph_view)
  # Attach the signature map to the root object under the reserved signature
  # attribute name.
  graph_view.add_object(
      parent_node=graph_view.root,
      name_in_parent=signature_serialization.SIGNATURE_ATTRIBUTE_NAME,
      subgraph_root=signature_serialization.create_signature_map(signatures))

  # Use _SaveableView to provide a frozen listing of properties and functions.
  saveable_view = _SaveableView(graph_view, options, wrapped_functions)
  object_saver = util.TrackableSaver(graph_view)
  asset_info, exported_graph = _fill_meta_graph_def(
      meta_graph_def,
      saveable_view,
      signatures,
      options.namespace_whitelist,
      options.experimental_custom_gradients)

  if options.function_aliases:
    alias_by_function_name = meta_graph_def.meta_info_def.function_aliases
    for alias, func in options.function_aliases.items():
      # Record the alias for every cached function of the stateful variant
      # first, then of the stateless variant.
      for variant_attr in ("_stateful_fn", "_stateless_fn"):
        variant = getattr(func, variant_attr)
        for fdef in variant._function_cache.all_values():  # pylint: disable=protected-access
          alias_by_function_name[fdef.name] = alias

  meta_graph_def.object_graph_def.CopyFrom(
      _serialize_object_graph(saveable_view, asset_info.asset_index))

  return (meta_graph_def, exported_graph, object_saver, asset_info,
          saveable_view.nodes, saveable_view.node_paths)
| |
| |
def _build_meta_graph(obj,
                      signatures,
                      options,
                      meta_graph_def=None):
  """Builds a MetaGraph for `obj` while a save context is active.

  This wraps `_build_meta_graph_impl` in `save_context.save_context(options)`
  and returns its result unchanged.

  Args:
    obj: A trackable object to build the MetaGraph from.
    signatures: Can be a `tf.function` with an input signature specified or the
      result of `f.get_concrete_function` on a `@tf.function`-decorated function
      `f`. `signatures` may also be a dictionary, in which case it maps from
      signature keys to `tf.function` instances. If None, finds signature to
      export from the `@tf.function`-decorated methods in `obj`.
    options: `tf.saved_model.SaveOptions` object that specifies options for
      saving.
    meta_graph_def: Optional, the MetaGraphDef proto to fill.

  Raises:
    AssertionError: If called while executing inside a `tf.function`.
    ValueError: If `obj` is not trackable.

  Returns:
    A 6-tuple whose elements are, in order:
      meta_graph_def: Filled MetaGraphDef proto.
      exported_graph: `tf.Graph` object generated from `obj`.
      object_saver: `util.TrackableSaver` of the `obj` and its dependencies.
      asset_info: `_AssetInfo` tuple containing external assets in the `obj`.
      saved nodes: the `nodes` attribute of the `_SaveableView` built for
        `obj`.
      node paths: the `node_paths` attribute of the same `_SaveableView`.
  """
  with save_context.save_context(options):
    return _build_meta_graph_impl(
        obj=obj,
        signatures=signatures,
        options=options,
        meta_graph_def=meta_graph_def)