# blob: a5200b01b41895a9db578c54d189799efde1ee63 [file] [log] [blame]
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for trackable object SavedModel save."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
from tensorflow.python.client import session as session_lib
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import backprop
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.framework import versions
from tensorflow.python.keras.engine import sequential
from tensorflow.python.keras.layers import core
from tensorflow.python.keras.optimizer_v2 import adam
from tensorflow.python.lib.io import file_io
from tensorflow.python.module import module
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.saved_model import loader
from tensorflow.python.saved_model import loader_impl
from tensorflow.python.saved_model import save
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.training.tracking import tracking
from tensorflow.python.training.tracking import util
from tensorflow.python.util import compat
class _ModelWithOptimizer(util.Checkpoint):
  """A `Checkpoint` holding a `Dense(1)` layer and an Adam optimizer.

  `call` is wrapped in `def_function.function` with a fixed input signature
  and performs one optimizer step each time it runs.
  """

  def __init__(self):
    self.dense = core.Dense(1)
    self.optimizer = adam.Adam(0.01)

  @def_function.function(
      input_signature=(
          tensor_spec.TensorSpec([None, 2], dtypes.float32),
          tensor_spec.TensorSpec([None], dtypes.float32)))
  def call(self, x, y):
    """Apply one gradient update for `(x, y)`.

    Args:
      x: float32 tensor matching the `[None, 2]` input signature.
      y: float32 tensor matching the `[None]` input signature.

    Returns:
      A dict `{"loss": loss}` where `loss` is the mean squared difference
      between the dense layer's output and `y`, computed before the update.
    """
    # Only the loss computation needs to be recorded on the tape.
    with backprop.GradientTape() as tape:
      error = self.dense(x) - y
      loss = math_ops.reduce_mean(error ** 2.)
    weights = self.dense.trainable_variables
    grads = tape.gradient(loss, weights)
    self.optimizer.apply_gradients(zip(grads, weights))
    return {"loss": loss}
def _import_and_infer(
    save_dir, inputs,
    signature_key=signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY):
  """Load a SavedModel into a fresh TF 1.x-style graph and run a signature.

  Args:
    save_dir: Directory the SavedModel is loaded from.
    inputs: Dict mapping signature input names to the values to feed. Its keys
      must be exactly the input names of the selected signature.
    signature_key: Key of the signature to run.

  Returns:
    The result of `session.run` on a dict mapping each signature output name
    to its output tensor.
  """
  graph = ops.Graph()
  with graph.as_default(), session_lib.Session() as session:
    loaded = loader.load(session, [tag_constants.SERVING], save_dir)
    signature = loaded.signature_def[signature_key]
    assert set(inputs) == set(signature.inputs.keys())
    # Resolve the tensor names recorded in the signature to graph tensors.
    feeds = {
        graph.get_tensor_by_name(signature.inputs[name].name): value
        for name, value in inputs.items()}
    fetches = {
        name: graph.get_tensor_by_name(tensor_info.name)
        for name, tensor_info in signature.outputs.items()}
    return session.run(fetches, feed_dict=feeds)
class SaveTest(test.TestCase):
  """Tests that export trackable objects with `save.save` and re-import them."""

  def _saved_model_dir(self):
    """Return the "saved_model" path inside this test's temp directory."""
    return os.path.join(self.get_temp_dir(), "saved_model")

  def test_method_save_signature(self):
    """A function with an input signature can be passed as the signature."""
    trackable = tracking.AutoTrackable()
    float_spec = tensor_spec.TensorSpec(None, dtypes.float32)
    trackable.f = def_function.function(
        lambda x: 2. * x, input_signature=[float_spec])
    trackable.f(constant_op.constant(1.))
    export_dir = self._saved_model_dir()
    save.save(trackable, export_dir, trackable.f)
    inferred = _import_and_infer(export_dir, {"x": 1.})
    self.assertEqual({"output_0": 2.}, inferred)

  def test_method_save_concrete(self):
    """A concrete function can be exported under a custom signature key."""
    trackable = tracking.AutoTrackable()
    trackable.f = def_function.function(lambda z: {"out": 2. * z})
    trackable.f(constant_op.constant(1.))
    export_dir = self._saved_model_dir()
    concrete = trackable.f.get_concrete_function(
        tensor_spec.TensorSpec(None, dtypes.float32))
    save.save(trackable, export_dir, {"non_default_key": concrete})
    inferred = _import_and_infer(
        export_dir, {"z": 1.}, signature_key="non_default_key")
    self.assertEqual({"out": 2.}, inferred)

  def test_unbuilt_model_does_not_prevent_saving(self):
    """Saving works for a Sequential model that has never been called."""
    model = sequential.Sequential([core.Dense(2)])
    checkpoint = util.Checkpoint(model=model)
    save.save(checkpoint, self._saved_model_dir())

  def test_captured_symbolic_tensor_exception(self):
    """Returning a tensor leaked from another function fails with its name."""
    root = module.Module()
    leaked = []

    @def_function.function
    def captured_intermediate(x):
      leaked.append(math_ops.add(x, x, name="a_tensor"))
      return leaked[-1] * 2

    captured_intermediate(constant_op.constant(1.))
    # This function returns the tensor created while tracing the one above.
    root.f = def_function.function(lambda: leaked[-1], input_signature=[])
    export_dir = self._saved_model_dir()
    with self.assertRaisesRegexp(ValueError, "a_tensor"):
      save.save(root, export_dir, signatures=root.f)

  def test_unsaveable_func_graph(self):
    """A graph marked unsaveable in a nested function makes saving fail."""
    root = module.Module()

    @def_function.function(input_signature=[])
    def nested_f():
      ops.get_default_graph().mark_as_unsaveable("ERROR MSG")
      return 1

    @def_function.function(input_signature=[])
    def f():
      return nested_f()

    root.f = f
    export_dir = self._saved_model_dir()
    with self.assertRaisesRegexp(ValueError, "ERROR MSG"):
      save.save(root, export_dir)

  def test_version_information_included(self):
    """The saved MetaGraph records the TensorFlow version and git version."""
    export_dir = self._saved_model_dir()
    save.save(tracking.AutoTrackable(), export_dir)
    saved_model_proto = loader_impl.parse_saved_model(export_dir)
    meta_info = saved_model_proto.meta_graphs[0].meta_info_def
    self.assertEqual(versions.__version__, meta_info.tensorflow_version)
    self.assertEqual(
        versions.__git_version__, meta_info.tensorflow_git_version)

  def test_non_concrete_error(self):
    """A function created without an input signature is rejected."""
    trackable = tracking.AutoTrackable()
    trackable.f = def_function.function(lambda x: 2. * x)
    trackable.f(constant_op.constant(1.))
    export_dir = self._saved_model_dir()
    with self.assertRaisesRegexp(
        ValueError, "Expected a TensorFlow function"):
      save.save(trackable, export_dir, trackable.f)

  def test_captures_unreachable_variable(self):
    """Saving fails when a function captures a variable not attached to root."""
    root = tracking.AutoTrackable()
    # Deliberately kept as a local only; it is never assigned to `root`.
    untracked = variables.Variable([5.0, 2.0])
    root.reachable_variable = variables.Variable([1.0, 3.0])

    @def_function.function
    def increase_variable(x):
      return 2 * untracked * x + root.reachable_variable

    root.f = increase_variable
    eager_result = root.f(constant_op.constant([10.0, 20.0])).numpy()
    self.assertAllEqual([101.0, 83.0], eager_result)
    export_dir = self._saved_model_dir()
    with self.assertRaisesRegexp(KeyError, "not reachable from root"):
      save.save(root, export_dir)

  def test_nested_inputs(self):
    """A function whose signature is a list of two tensors can be called."""
    trackable = tracking.AutoTrackable()
    first_spec = tensor_spec.TensorSpec(None, dtypes.float32)
    second_spec = tensor_spec.TensorSpec(None, dtypes.float32)
    trackable.f = def_function.function(
        lambda x: 2. * x[0],
        input_signature=([first_spec, second_spec],))
    trackable.f([constant_op.constant(1.), constant_op.constant(1.)])

  def test_nested_outputs(self):
    """A concrete function returning nested tuples is rejected."""
    trackable = tracking.AutoTrackable()
    trackable.f = def_function.function(
        lambda x: (2. * x, (3. * x, 4. * x)))
    trackable.f(constant_op.constant(1.))
    concrete = trackable.f.get_concrete_function(constant_op.constant(1.))
    export_dir = self._saved_model_dir()
    with self.assertRaisesRegexp(ValueError, "non-flat outputs"):
      save.save(trackable, export_dir, concrete)

  def test_nested_dict_outputs(self):
    """A dict output holding a tuple value is rejected."""
    nested_fn = def_function.function(
        lambda x: {"a": 2. * x, "b": (3. * x, 4. * x)})
    checkpoint = util.Checkpoint(f=nested_fn)
    checkpoint.f(constant_op.constant(1.))
    concrete = checkpoint.f.get_concrete_function(constant_op.constant(1.))
    export_dir = self._saved_model_dir()
    with self.assertRaisesRegexp(
        ValueError, "dictionary containing non-Tensor value"):
      save.save(checkpoint, export_dir, concrete)

  def test_variable(self):
    """Variables attached to the root are usable from the imported signature."""
    root = tracking.AutoTrackable()
    root.v1 = variables.Variable(3.)
    root.v2 = variables.Variable(2.)
    root.f = def_function.function(lambda x: root.v1 * root.v2 * x)
    root.f(constant_op.constant(1.))
    concrete = root.f.get_concrete_function(constant_op.constant(1.))
    export_dir = self._saved_model_dir()
    save.save(root, export_dir, concrete)
    inferred = _import_and_infer(export_dir, {"x": 2.})
    self.assertAllEqual({"output_0": 12.}, inferred)

  def test_optimizer(self):
    """The imported training step reproduces the model's next loss."""
    features = constant_op.constant([[3., 4.]])
    labels = constant_op.constant([2.])
    model = _ModelWithOptimizer()
    first_loss = model.call(features, labels)
    export_dir = self._saved_model_dir()
    save.save(model, export_dir, model.call)
    # The model is saved after the first step; this is its second step.
    second_loss = model.call(features, labels)
    self.assertNotEqual(first_loss, second_loss)
    imported_loss = _import_and_infer(
        export_dir, {"x": [[3., 4.]], "y": [2.]})
    self.assertAllClose(second_loss, imported_loss)

  def test_single_method_default_signature(self):
    """With no explicit signatures, the model's `call` output is served."""
    model = _ModelWithOptimizer()
    features = constant_op.constant([[3., 4.]])
    labels = constant_op.constant([2.])
    model.call(features, labels)
    export_dir = self._saved_model_dir()
    save.save(model, export_dir)
    inferred = _import_and_infer(export_dir, {"x": [[3., 4.]], "y": [2.]})
    self.assertIn("loss", inferred)

  def test_single_function_default_signature(self):
    """A lone zero-argument function is exported as the default signature."""
    model = tracking.AutoTrackable()
    model.f = def_function.function(lambda: 3., input_signature=())
    model.f()
    export_dir = self._saved_model_dir()
    save.save(model, export_dir)
    inferred = _import_and_infer(export_dir, {})
    self.assertAllClose({"output_0": 3.}, inferred)

  def test_single_function_no_signature(self):
    """Saving succeeds when the only function has no input signature."""
    model = tracking.AutoTrackable()
    model.f = def_function.function(lambda: 3.)
    save.save(model, self._saved_model_dir())

  def test_find_default_save_function(self):
    """A `_default_save_signature` method is found and exported by default."""

    class ObjWithDefaultSignature(util.Checkpoint):

      @def_function.function(
          input_signature=[
              tensor_spec.TensorSpec(shape=None, dtype=dtypes.float32)])
      def _default_save_signature(self, x):
        return x + x + 1

    obj = ObjWithDefaultSignature()
    export_dir = self._saved_model_dir()
    save.save(obj, export_dir)
    inferred = _import_and_infer(export_dir, {"x": 3.})
    self.assertAllClose({"output_0": 7.}, inferred)

  def test_docstring(self):
    """A Module's single signature-carrying method is served by default."""

    class Adder(module.Module):

      @def_function.function(
          input_signature=[
              tensor_spec.TensorSpec(shape=None, dtype=dtypes.float32)])
      def add(self, x):
        return x + x + 1.

    adder = Adder()
    adder.add(constant_op.constant(1.))
    export_dir = self._saved_model_dir()
    save.save(adder, export_dir)
    inferred = _import_and_infer(export_dir, {"x": 3.})
    self.assertAllClose({"output_0": 7.}, inferred)

  def test_datastructures(self):
    """Variables held in list and dict attributes survive a save and import."""

    class HasDatastructures(util.Checkpoint):

      def __init__(self):
        self.a = [1.]
        self.a.append(variables.Variable(2.))
        self.b = {"a": variables.Variable(3.)}

      @def_function.function(
          input_signature=[
              tensor_spec.TensorSpec(shape=None, dtype=dtypes.float32)])
      def add(self, x):
        return x + math_ops.add_n(self.a) + self.b["a"]

    holder = HasDatastructures()
    holder.add(constant_op.constant(1.))
    export_dir = self._saved_model_dir()
    save.save(holder, export_dir)
    inferred = _import_and_infer(export_dir, {"x": 4.})
    self.assertAllClose({"output_0": 10.}, inferred)

  def test_default_attr_stripping(self):
    """The saved `Complex` node does not carry its `T` and `Tout` attrs."""

    class Complex(util.Checkpoint):

      @def_function.function(input_signature=[])
      def __call__(self):
        return math_ops.complex(
            constant_op.constant(1.),
            constant_op.constant(2.),
            name="complex")

    complex_obj = Complex()
    complex_obj()
    export_dir = self._saved_model_dir()
    save.save(complex_obj, export_dir)
    graph = ops.Graph()
    with graph.as_default(), self.session(graph) as session:
      loader.load(session, [tag_constants.SERVING], export_dir)
      # Exactly one imported function is expected to have "call" in its name.
      call_functions = [
          f for name, f in graph._functions.items() if "call" in name]
      func, = call_functions
      complex_nodes = [
          node for node in func.definition.node_def if node.op == "Complex"]
      complex_node, = complex_nodes
      self.assertNotIn("T", complex_node.attr)
      self.assertNotIn("Tout", complex_node.attr)

  def test_signature_attribute_reserved(self):
    """A `signatures` attribute blocks saving until it is deleted."""
    checkpoint = util.Checkpoint(signatures=variables.Variable(1.))
    export_dir = self._saved_model_dir()
    with self.assertRaisesRegexp(ValueError, "del obj.signatures"):
      save.save(checkpoint, export_dir)
    del checkpoint.signatures
    save.save(checkpoint, export_dir)

  def test_function_with_captured_dataset(self):
    """A function iterating over a dataset attribute can be saved and run."""
    if test_util.is_gpu_available():
      self.skipTest("Currently broken when a GPU is available.")

    class HasDataset(module.Module):

      def __init__(self):
        super(HasDataset, self).__init__()
        self.dataset = (
            dataset_ops.Dataset.range(5)
            .map(lambda x: x ** 2))

      @def_function.function
      def __call__(self, x):
        current_sum = array_ops.zeros([], dtype=dtypes.int64)
        for element in self.dataset:
          current_sum += x * element
        return current_sum

    summer = HasDataset()
    export_dir = self._saved_model_dir()
    concrete = summer.__call__.get_concrete_function(
        tensor_spec.TensorSpec(None, dtypes.int64))
    save.save(summer, export_dir, signatures=concrete)
    inferred = _import_and_infer(export_dir, {"x": 3})
    self.assertAllClose({"output_0": 3 * (1 + 4 + 9 + 16)}, inferred)
class AssetTests(test.TestCase):
  """Tests for saving objects that reference an on-disk vocabulary file."""

  def setUp(self):
    """Write a three-line vocabulary file into the test's temp directory."""
    super(AssetTests, self).setUp()
    self._vocab_path = self._temp_path("vocab.txt")
    with open(self._vocab_path, "w") as vocab_file:
      vocab_file.write("alpha\nbeta\ngamma\n")

  def _temp_path(self, name):
    """Return the path of `name` inside this test's temp directory."""
    return os.path.join(self.get_temp_dir(), name)

  def test_asset_path_returned(self):
    """The returned asset path points into the directory loaded from."""
    root = tracking.AutoTrackable()
    root.path = tracking.TrackableAsset(self._vocab_path)
    export_dir = self._temp_path("saved_model")
    root.get_asset = def_function.function(lambda: root.path.asset_path)
    concrete = root.get_asset.get_concrete_function()
    save.save(root, export_dir, signatures=concrete)
    # Move the SavedModel before importing it.
    moved_dir = self._temp_path("second_dir")
    file_io.rename(export_dir, moved_dir)
    imported_path = _import_and_infer(moved_dir, {})["output_0"]
    self.assertIn(
        compat.as_str_any(moved_dir), compat.as_str_any(imported_path))

  def test_table(self):
    """A file-initialized hash table still works after save, move and import."""
    initializer = lookup_ops.TextFileInitializer(
        self._vocab_path,
        key_dtype=dtypes.string,
        key_index=lookup_ops.TextFileIndex.WHOLE_LINE,
        value_dtype=dtypes.int64,
        value_index=lookup_ops.TextFileIndex.LINE_NUMBER)
    table = lookup_ops.HashTable(initializer, default_value=-1)
    root = util.Checkpoint(table=table)
    root.table_user = def_function.function(
        root.table.lookup,
        input_signature=[tensor_spec.TensorSpec(None, dtypes.string)])
    eager_lookup = self.evaluate(
        root.table_user(constant_op.constant("gamma")))
    self.assertEqual(2, eager_lookup)
    export_dir = self._temp_path("saved_model")
    save.save(root, export_dir)
    # The original vocabulary file is removed before the lookups below run.
    file_io.delete_file(self._vocab_path)
    first_lookup = _import_and_infer(export_dir, {"keys": ["gamma", "alpha"]})
    self.assertAllClose({"output_0": [2, 0]}, first_lookup)
    moved_dir = self._temp_path("second_dir")
    # Asset paths should track the location the SavedModel is loaded from.
    file_io.rename(export_dir, moved_dir)
    second_lookup = _import_and_infer(moved_dir, {"keys": ["gamma", "beta"]})
    self.assertAllClose({"output_0": [2, 1]}, second_lookup)

  def test_unused_asset(self):
    """An asset no function uses does not break saving or inference."""
    root = tracking.AutoTrackable()
    float_spec = tensor_spec.TensorSpec(None, dtypes.float32)
    root.f = def_function.function(
        lambda x: 2. * x, input_signature=[float_spec])
    root.asset = tracking.TrackableAsset(self._vocab_path)
    export_dir = self._temp_path("save_dir")
    save.save(root, export_dir)
    inferred = _import_and_infer(export_dir, {"x": [0.1]})
    self.assertAllClose({"output_0": [0.2]}, inferred)

  def test_sensible_function_building_exception(self):
    """Calling `save.save` inside a traced function raises an AssertionError."""
    root = util.Checkpoint(v=variables.Variable(2.))
    float_spec = tensor_spec.TensorSpec(None, dtypes.float32)
    root.f = def_function.function(
        lambda x: 2. * root.v, input_signature=[float_spec])
    export_dir = self._temp_path("save_dir")

    @def_function.function
    def _calls_save():
      save.save(root, export_dir)

    with self.assertRaisesRegexp(AssertionError, "tf.function"):
      _calls_save()
class _ModelWithOptimizerUsingDefun(util.Checkpoint):
  """A `Checkpoint` with a `Dense(1)` layer whose `call` uses `function.defun`.

  `call` performs one Adam step per invocation and returns the loss.
  """

  def __init__(self):
    self.dense = core.Dense(1)
    self.optimizer = adam.Adam(0.01)

  # `function.defun` is used here because of control flow v2 cycles,
  # b/121159261: def_function gates variable initialization with conds, which
  # triggers cond reference cycles. The method wrapped here does not use cond
  # itself.
  @function.defun(
      input_signature=(
          tensor_spec.TensorSpec([None, 2], dtypes.float32),
          tensor_spec.TensorSpec([None], dtypes.float32)),
  )
  def call(self, x, y):
    """Apply one gradient update for `(x, y)`.

    Args:
      x: float32 tensor matching the `[None, 2]` input signature.
      y: float32 tensor matching the `[None]` input signature.

    Returns:
      A dict `{"loss": loss}` where `loss` is the mean squared difference
      between the dense layer's output and `y`, computed before the update.
    """
    # Only the loss computation needs to be recorded on the tape.
    with backprop.GradientTape() as tape:
      error = self.dense(x) - y
      loss = math_ops.reduce_mean(error ** 2.)
    weights = self.dense.trainable_variables
    grads = tape.gradient(loss, weights)
    self.optimizer.apply_gradients(zip(grads, weights))
    return {"loss": loss}
class MemoryTests(test.TestCase):
  """Checks that saving a model does not leave garbage behind."""

  def setUp(self):
    """Run the base test-case setup, then build the model under test."""
    # Chain to the base class so overriding `setUp` does not skip the
    # per-test setup that `test.TestCase` performs.
    super(MemoryTests, self).setUp()
    # The model is built here, outside the decorated test method.
    self._model = _ModelWithOptimizerUsingDefun()

  @test_util.assert_no_garbage_created
  def test_no_reference_cycles(self):
    """Saving a model after one training step creates no garbage."""
    x = constant_op.constant([[3., 4.]])
    y = constant_op.constant([2.])
    self._model.call(x, y)
    if sys.version_info[0] < 3:
      # TODO(allenl): debug reference cycles in Python 2.x
      self.skipTest("This test only works in Python 3+. Reference cycles are "
                    "created in older Python versions.")
    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    save.save(self._model, save_dir, self._model.call)
if __name__ == "__main__":
test.main()