# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Import a TF v1-style SavedModel when executing eagerly."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
from tensorflow.python.eager import context
from tensorflow.python.eager import lift_to_graph
from tensorflow.python.eager import wrap_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import function_deserialization
from tensorflow.python.saved_model import loader_impl
from tensorflow.python.saved_model import signature_serialization
from tensorflow.python.training import monitored_session
from tensorflow.python.training import saver as tf_saver
from tensorflow.python.training.tracking import tracking


class _Initializer(tracking.CapturableResource):
"""Represents an initialization operation restored from a SavedModel.

  Without this object, re-exporting an imported 1.x SavedModel would omit the
  original SavedModel's initialization procedure.

Created when `tf.saved_model.load` loads a TF 1.x-style SavedModel with an
initialization op. This object holds a function which runs the
initialization. It does not require any manual user intervention;
`tf.saved_model.save` will see this object and automatically add it to the
exported SavedModel, and `tf.saved_model.load` runs the initialization
function automatically.
"""
def __init__(self, init_fn, asset_paths):
super(_Initializer, self).__init__()
self._asset_paths = asset_paths
self._init_fn = init_fn
def _create_resource(self):
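    # CapturableResource expects a resource tensor, but this object has no
    # real resource of its own, so return an unused placeholder to satisfy
    # the interface.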
return array_ops.placeholder(
dtype=dtypes.resource, shape=[], name="unused_resource")
def _initialize(self):
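    # Call the wrapped initialization function, feeding in the tracked asset
    # file paths (for example, vocabulary files used by table initializers).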
return self._init_fn(*[path.asset_path for path in self._asset_paths])


class _EagerSavedModelLoader(loader_impl.SavedModelLoader):
"""Loads a SavedModel without using Sessions."""
def get_meta_graph_def_from_tags(self, tags):
"""Override to support implicit one-MetaGraph loading with tags=None."""
if tags is None:
if len(self._saved_model.meta_graphs) != 1:
tag_sets = [mg.meta_info_def.tags
for mg in self._saved_model.meta_graphs]
raise ValueError(
("Importing a SavedModel with tf.saved_model.load requires a "
"'tags=' argument if there is more than one MetaGraph. Got "
"'tags=None', but there are {} MetaGraphs in the SavedModel with "
"tag sets {}. Pass a 'tags=' argument to load this SavedModel.")
.format(len(self._saved_model.meta_graphs), tag_sets))
return self._saved_model.meta_graphs[0]
return super(_EagerSavedModelLoader, self).get_meta_graph_def_from_tags(
tags)
def load_graph(self, returns, meta_graph_def):
"""Called from wrap_function to import `meta_graph_def`."""
# pylint: disable=protected-access
saver, _ = tf_saver._import_meta_graph_with_return_elements(
meta_graph_def)
# pylint: enable=protected-access
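    # Hand the imported v1 Saver back to the caller through the mutable
    # `returns` list; it is used later to restore variables from the
    # checkpoint.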
returns[0] = saver
def restore_variables(self, wrapped, saver):
"""Restores variables from the checkpoint."""
if saver is not None:
saver_def = saver.saver_def
filename_tensor = wrapped.graph.as_graph_element(
saver_def.filename_tensor_name)
      # We both feed and fetch filename_tensor so that we have an operation to
      # feed into variable initializers (only relevant for v1 graph building).
restore_fn = wrapped.prune(
feeds=[filename_tensor],
fetches=[filename_tensor,
wrapped.graph.as_graph_element(saver_def.restore_op_name)])
initializer, _ = restore_fn(constant_op.constant(self._variables_path))
if not ops.executing_eagerly_outside_functions():
# Add the initialization operation to the table initializers collection
# in case we don't have any lifted variables to attach it to. There
# isn't another great place to put it.
ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, initializer)
one_unlifted = False
for variable in wrapped.graph.get_collection_ref(
ops.GraphKeys.GLOBAL_VARIABLES):
if variable.graph is wrapped.graph:
one_unlifted = True
# pylint: disable=protected-access
variable._initializer_op = initializer
# pylint: enable=protected-access
if one_unlifted:
logging.warning(
"Some variables could not be lifted out of a loaded function. "
"Run the tf.initializers.tables_initializer() operation to "
"restore these variables.")
def _extract_signatures(self, wrapped, meta_graph_def):
"""Creates ConcreteFunctions for signatures in `meta_graph_def`."""
signature_functions = {}
for signature_key, signature_def in meta_graph_def.signature_def.items():
if signature_def.inputs:
input_names, input_specs = zip(*signature_def.inputs.items())
else:
input_names = []
input_specs = []
# TODO(allenl): Support optional arguments
feeds = [wrapped.graph.as_graph_element(inp.name) for inp in input_specs]
fetches = {name: out for name, out in signature_def.outputs.items()}
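      # Prune the wrapped graph from the signature's inputs to its outputs to
      # build a ConcreteFunction implementing this signature.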
try:
signature_fn = wrapped.prune(feeds=feeds, fetches=fetches)
except lift_to_graph.UnliftableError as ex:
# Mutate the exception to add a bit more detail.
args = ex.args
if not args:
message = ""
else:
message = args[0]
message = (
("A SavedModel signature needs an input for each placeholder the "
"signature's outputs use. An output for signature '{}' depends on "
"a placeholder which is not an input (i.e. the placeholder is not "
"fed a value).\n\n").format(signature_key)
+ message)
ex.args = (message,) + args[1:]
raise
# pylint: disable=protected-access
signature_fn._arg_keywords = input_names
if len(input_names) == 1:
# Allowing positional arguments does not create any ambiguity if there's
# only one.
signature_fn._num_positional_args = 1
else:
signature_fn._num_positional_args = 0
# pylint: enable=protected-access
signature_functions[signature_key] = signature_fn
return signature_functions
def load(self, tags):
"""Creates an object from the MetaGraph identified by `tags`."""
meta_graph_def = self.get_meta_graph_def_from_tags(tags)
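    # Deserialize the MetaGraph's function library, appending a load-specific
    # suffix to shared names so repeated loads in one process don't collide.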
load_shared_name_suffix = "_load_{}".format(ops.uid())
functions = function_deserialization.load_function_def_library(
meta_graph_def.graph_def.library,
load_shared_name_suffix=load_shared_name_suffix)
# Replace existing functions in the MetaGraphDef with renamed functions so
# we don't have duplicates or name collisions.
meta_graph_def.graph_def.library.Clear()
for function in functions.values():
meta_graph_def.graph_def.library.function.append(function.function_def)
# We've renamed functions and shared names. We need the same operation on
# the GraphDef itself for consistency.
for node_def in meta_graph_def.graph_def.node:
function_deserialization.fix_node_def(node_def, functions,
load_shared_name_suffix,
debug_name="MetaGraph import")
load_graph_returns = [None]
wrapped = wrap_function.wrap_function(
functools.partial(self.load_graph, load_graph_returns, meta_graph_def),
signature=[])
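    # `wrapped` holds the imported MetaGraph in a FuncGraph; pruning subsets
    # of it below takes the place of the Session.run calls a v1 loader makes.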
saver, = load_graph_returns
self.restore_variables(wrapped, saver)
with wrapped.graph.as_default():
init_op = loader_impl.get_init_op(
meta_graph_def) or monitored_session.Scaffold.default_local_init_op()
# Add a dummy Tensor we know we can fetch to add control dependencies to.
init_anchor = constant_op.constant(0., name="dummy_fetch")
root = tracking.AutoTrackable()
asset_feed_tensors = []
asset_paths = []
for tensor_name, value in loader_impl.get_asset_tensors(
self._export_dir, meta_graph_def).items():
asset_feed_tensors.append(wrapped.graph.as_graph_element(tensor_name))
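      # Track each asset file so that re-exporting this object copies the
      # asset along with the SavedModel.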
asset_paths.append(tracking.TrackableAsset(value))
init_fn = wrapped.prune(
feeds=asset_feed_tensors,
fetches=[init_anchor, wrapped.graph.as_graph_element(init_op)])
initializer = _Initializer(init_fn, asset_paths)
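    # Run the initialization function now so tables and other state are
    # populated at load time; in v1 graph mode the resulting op is also added
    # to collections below.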
# pylint: disable=protected-access
local_init_op, _ = initializer._initialize()
# pylint: enable=protected-access
with ops.init_scope():
if not context.executing_eagerly():
ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, local_init_op)
for variable in wrapped.graph.get_collection_ref(
ops.GraphKeys.LOCAL_VARIABLES):
# pylint: disable=protected-access
variable._initializer_op = local_init_op
# pylint: enable=protected-access
root.initializer = initializer
root.asset_paths = asset_paths
signature_functions = self._extract_signatures(wrapped, meta_graph_def)
root.signatures = signature_serialization.create_signature_map(
signature_functions)
root.variables = list(wrapped.graph.variables)
root.tensorflow_version = (
meta_graph_def.meta_info_def.tensorflow_version)
root.tensorflow_git_version = (
meta_graph_def.meta_info_def.tensorflow_git_version)
root.graph = wrapped.graph
root.prune = wrapped.prune
return root


def load(export_dir, tags):
"""Load a v1-style SavedModel as an object."""
loader = _EagerSavedModelLoader(export_dir)
return loader.load(tags=tags)
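

# Example usage (a sketch of user code, not part of this module; the export
# directory, tags, signature key, and input name below are placeholders):
#
#   imported = load("/tmp/v1_saved_model", tags=["serve"])
#   serving_fn = imported.signatures["serving_default"]
#   outputs = serving_fn(x=constant_op.constant([[1.0, 2.0]]))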