| # Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| """Tests for trackable object SavedModel loading.""" |
| |
| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| |
| import collections |
| import functools |
| import os |
| import sys |
| import tempfile |
| import weakref |
| |
| from absl.testing import parameterized |
| |
| from tensorflow.python.client import session as session_lib |
| from tensorflow.python.data.ops import dataset_ops |
| from tensorflow.python.eager import backprop |
| from tensorflow.python.eager import context |
| from tensorflow.python.eager import def_function |
| from tensorflow.python.eager import test |
| from tensorflow.python.eager import wrap_function |
| from tensorflow.python.feature_column import feature_column_lib |
| from tensorflow.python.framework import constant_op |
| from tensorflow.python.framework import dtypes |
| from tensorflow.python.framework import errors |
| from tensorflow.python.framework import function as framework_function |
| from tensorflow.python.framework import ops |
| from tensorflow.python.framework import tensor_shape |
| from tensorflow.python.framework import tensor_spec |
| from tensorflow.python.framework import test_util |
| from tensorflow.python.framework import versions |
| from tensorflow.python.keras import keras_parameterized |
| from tensorflow.python.keras.engine import base_layer |
| from tensorflow.python.keras.engine import input_layer |
| from tensorflow.python.keras.engine import sequential |
| from tensorflow.python.keras.engine import training as training_lib |
| from tensorflow.python.keras.layers import convolutional |
| from tensorflow.python.keras.layers import core |
| from tensorflow.python.keras.optimizer_v2 import adam |
| from tensorflow.python.lib.io import file_io |
| from tensorflow.python.module import module |
| from tensorflow.python.ops import array_ops |
| from tensorflow.python.ops import cond_v2 |
| from tensorflow.python.ops import control_flow_ops |
| from tensorflow.python.ops import lookup_ops |
| from tensorflow.python.ops import math_ops |
| from tensorflow.python.ops import resource_variable_ops |
| from tensorflow.python.ops import variable_scope |
| from tensorflow.python.ops import variables |
| from tensorflow.python.ops.ragged import ragged_factory_ops |
| from tensorflow.python.ops.ragged import ragged_tensor |
| from tensorflow.python.saved_model import load |
| from tensorflow.python.saved_model import save |
| from tensorflow.python.saved_model import tag_constants |
| from tensorflow.python.training import monitored_session |
| from tensorflow.python.training.tracking import tracking |
| from tensorflow.python.training.tracking import util |
| from tensorflow.python.util import tf_inspect |
| |
| |
def cycle(obj, cycles, signatures=None):
  """Saves `obj` and loads it back, `cycles` times in a row.

  Every round saves the object produced by the previous round (the first
  round saves `obj` itself) into a fresh directory and loads it again.

  Args:
    obj: The object to export on the first round.
    cycles: Number of save/load rounds to run. Must be at least 1.
    signatures: Optional signatures, passed to `save.save` on every round.

  Returns:
    The object returned by `load.load` on the final round.

  Raises:
    ValueError: If `cycles` is smaller than 1, since there would be nothing
      to return.
  """
  if cycles < 1:
    raise ValueError("cycles must be at least 1, got {!r}.".format(cycles))
  to_save = obj
  # TODO(vbardiovsky): It would be nice if exported protos reached a fixed
  # point w.r.t. saving/restoring, ideally after 2nd saving.
  for _ in range(cycles):
    # `dir=` (not `prefix=`) so the export lands inside the test temp
    # directory instead of in a sibling directory next to it.
    path = tempfile.mkdtemp(dir=test.get_temp_dir())
    # If available, we'll run the save and restore preferring the GPU. This
    # just makes sure we aren't throwing errors and have enough
    # device("CPU") blocks to satisfy the placer.
    with test_util.use_gpu():
      save.save(to_save, path, signatures)
      loaded = load.load(path)
    to_save = loaded
  return loaded
| |
| |
| @parameterized.named_parameters( |
| dict(testcase_name="ReloadOnce", cycles=1), |
| dict(testcase_name="ReloadTwice", cycles=2), |
| dict(testcase_name="ReloadThrice", cycles=3)) |
| class LoadTest(test.TestCase, parameterized.TestCase): |
| |
def test_structure_import(self, cycles):
  """Shared dependencies stay shared and distinct ones stay distinct."""
  exported = tracking.AutoTrackable()
  exported.dep_one = tracking.AutoTrackable()
  exported.dep_two = tracking.AutoTrackable()
  exported.dep_two.dep = tracking.AutoTrackable()
  # Reachable through two different paths from the root object.
  exported.dep_three = exported.dep_two.dep

  restored = cycle(exported, cycles)

  self.assertIs(restored.dep_three, restored.dep_two.dep)
  self.assertIsNot(restored.dep_one, restored.dep_two)
| |
@test_util.run_in_graph_and_eager_modes
def test_variables(self, cycles):
  """Variable values and `trainable` flags are kept by save/load."""
  exported = tracking.AutoTrackable()
  exported.v1 = variables.Variable(1., trainable=True)
  exported.v2 = variables.Variable(2., trainable=False)
  self.evaluate([exported.v1.initializer, exported.v2.initializer])

  # Each round re-exports the original object rather than the loaded one.
  for _ in range(cycles):
    restored = cycle(exported, 1)
    self.evaluate([restored.v1.initializer, restored.v2.initializer])

  if not context.executing_eagerly():
    self.assertIsInstance(restored.v1.initializer, ops.Operation)
    self.assertIsInstance(restored.v2.initializer, ops.Operation)

  first_value = self.evaluate(restored.v1)
  self.assertEqual(first_value, 1.0)
  self.assertTrue(restored.v1.trainable)
  second_value = self.evaluate(restored.v2)
  self.assertEqual(second_value, 2.0)
  self.assertFalse(restored.v2.trainable)
| |
def test_variables_name(self, cycles):
  """Variable names are kept, and pick up an enclosing variable scope."""
  exported = tracking.AutoTrackable()
  # Test 2 variables with same name: should work as the checkpoint
  # is based on object name and not on variable name.
  exported.v1 = variables.Variable(1., trainable=True, name="v1")
  exported.v2 = variables.Variable(2., trainable=False, name="v1")

  restored = cycle(exported, cycles)
  self.assertEqual(restored.v1.numpy(), 1.0)
  self.assertEqual(restored.v2.numpy(), 2.0)
  self.assertEqual(restored.v1.name, exported.v1.name)
  self.assertEqual(restored.v2.name, exported.v2.name)

  with variable_scope.variable_scope("foo"):
    scoped = cycle(exported, cycles)
    self.assertTrue(scoped.v1.name.startswith("foo/"))
    self.assertTrue(scoped.v2.name.startswith("foo/"))
| |
def test_partially_defined_variable_shape(self, cycles):
  """A variable built from a `[None]`-shaped input keeps shape `[None]`."""

  class MakeVariable(module.Module):

    def __init__(self):
      self.v = None

    @def_function.function(
        input_signature=[tensor_spec.TensorSpec([None], dtypes.int64)])
    def make_variable(self, initial_value):
      if self.v is None:
        self.v = variables.Variable(initial_value)

  exported = MakeVariable()
  exported.make_variable([1, 2, 3])
  restored = cycle(exported, cycles)
  # Assign a vector of a different length than the initial value.
  restored.v.assign([1, 2, 3, 4])
  restored_shape = tensor_shape.as_shape(restored.v.shape).as_list()
  self.assertEqual([None], restored_shape)
| |
@test_util.run_in_graph_and_eager_modes
def test_capture_variables(self, cycles):
  """A loaded function reads the loaded copy of a captured variable."""
  exported = tracking.AutoTrackable()
  exported.weights = variables.Variable(2.)
  self.evaluate(exported.weights.initializer)
  exported.f = def_function.function(
      lambda x: exported.weights * x,
      input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])

  # Each round re-exports the original object rather than the loaded one.
  for _ in range(cycles):
    restored = cycle(exported, 1)
    self.evaluate(restored.weights.initializer)

  before_assign = self.evaluate(restored.f(constant_op.constant(2.)))
  self.assertEqual(4., before_assign)
  self.evaluate(restored.weights.assign(4.0))
  after_assign = self.evaluate(restored.f(constant_op.constant(2.)))
  self.assertEqual(8., after_assign)
| |
@test_util.run_in_graph_and_eager_modes
def test_capture_constant(self, cycles):
  """A function that closes over a constant tensor still works once loaded."""
  exported = tracking.AutoTrackable()
  captured_constant = constant_op.constant(2.)
  signature = [tensor_spec.TensorSpec(None, dtypes.float32)]
  exported.f = def_function.function(
      lambda x: captured_constant * x, input_signature=signature)

  restored = cycle(exported, cycles)

  product = self.evaluate(restored.f(constant_op.constant(2.)))
  self.assertEqual(4., product)
| |
def test_control_outputs(self, cycles):
  """The assign op is a control output both before and after a round trip."""
  op_name = "should_be_control_output"
  exported = tracking.AutoTrackable()
  exported.v = variables.Variable(1.)
  exported.f = def_function.function(
      lambda: exported.v.assign(2., name=op_name))

  original_graph = exported.f.get_concrete_function().graph
  original_op = original_graph.get_operation_by_name(op_name)
  self.assertIn(original_op, original_graph.control_outputs)

  restored = cycle(exported, cycles)
  # Calling get_concrete_function wraps in a second call operation; we want to
  # inspect the original function body for the control output; digging into
  # graph.as_graph_def() and its FunctionDefLibrary is another option.
  (restored_concrete,) = restored.f.concrete_functions
  restored_graph = restored_concrete.graph
  restored_op = restored_graph.get_operation_by_name(op_name)
  self.assertIn(restored_op, restored_graph.control_outputs)
| |
def _make_asset(self, contents):
  """Writes `contents` to a newly created file in the test temp directory.

  Args:
    contents: Text to write into the new file.

  Returns:
    The path of the file that was created.
  """
  # `mkstemp` creates the file itself, unlike the deprecated `mktemp`, which
  # only returns a name that another process could claim first. `dir=` puts
  # the file inside the test temp directory rather than next to it.
  fd, filename = tempfile.mkstemp(dir=self.get_temp_dir())
  with os.fdopen(fd, "w") as f:
    f.write(contents)
  return filename
| |
@test_util.run_in_graph_and_eager_modes
def test_assets(self, cycles):
  """Assets load after the source files are deleted and the export is moved."""
  first_source = self._make_asset("contents 1")
  second_source = self._make_asset("contents 2")

  exported = tracking.AutoTrackable()
  exported.asset1 = tracking.Asset(first_source)
  exported.asset2 = tracking.Asset(second_source)

  export_dir = os.path.join(self.get_temp_dir(), "save_dir")
  save.save(exported, export_dir)

  # Delete the source files and rename the export directory before loading.
  file_io.delete_file(first_source)
  file_io.delete_file(second_source)
  renamed_dir = os.path.join(self.get_temp_dir(), "load_dir")
  file_io.rename(export_dir, renamed_dir)

  restored = load.load(renamed_dir)
  with open(self.evaluate(restored.asset1.asset_path), "r") as asset_file:
    self.assertEqual("contents 1", asset_file.read())
  with open(self.evaluate(restored.asset2.asset_path), "r") as asset_file:
    self.assertEqual("contents 2", asset_file.read())
| |
def test_cond_prune(self, cycles):
  """Exports a function that calls a pruned wrapped function with a cond."""
  # Filled in by `f` when it runs, so the tensors it saw can be handed to
  # `prune` below.
  x_in = []
  x_out = []

  def f(x, y):
    x_in.append(x)
    xx = cond_v2.cond_v2(
        math_ops.less(1, 2),
        lambda: x + 1,
        lambda: x + 2,
    )
    x_out.append(xx)
    return xx, 2 * y

  f_wrapped = wrap_function.wrap_function(
      f, [tensor_spec.TensorSpec((), dtypes.float32)] * 2)
  # Prune to the first recorded input and the recorded cond output only; the
  # `2 * y` output is not requested.
  f_pruned = f_wrapped.prune(x_in[0], [x_out[0]])

  class Adder(module.Module):

    @def_function.function(input_signature=[
        tensor_spec.TensorSpec(shape=None, dtype=dtypes.float32)])
    def add(self, x):
      return f_pruned(x)

  root = Adder()
  root.add(constant_op.constant(1.))
  root = cycle(root, cycles)
  # The test only checks that calling the loaded function does not raise.
  root.add(constant_op.constant(1.))
| |
def test_capture_assets(self, cycles):
  """A loaded function returns a new asset path with the same file contents."""
  exported = tracking.AutoTrackable()
  exported.vocab = tracking.Asset(self._make_asset("contents"))
  exported.f = def_function.function(
      lambda: exported.vocab.asset_path,
      input_signature=[])

  restored = cycle(exported, cycles)

  path_before = exported.f().numpy()
  path_after = restored.f().numpy()
  self.assertNotEqual(path_before, path_after)
  with open(path_after, "r") as asset_file:
    self.assertEqual("contents", asset_file.read())
| |
def test_capture_assets_in_graph(self, cycles):
  """Loads a function capturing an asset into a graph and runs it in a session.

  The last save/load round is done by hand so that the final load happens
  inside an explicit `ops.Graph`.
  """
  root = tracking.AutoTrackable()
  root.vocab = tracking.Asset(self._make_asset("contents"))
  root.f = def_function.function(
      lambda: root.vocab.asset_path,
      input_signature=[])

  original_output = root.f().numpy()

  if cycles > 1:
    root = cycle(root, cycles - 1)
  # `dir=` (not `prefix=`) keeps the export inside the test temp directory
  # instead of creating a sibling directory next to it.
  path = tempfile.mkdtemp(dir=self.get_temp_dir())
  save.save(root, path)

  with ops.Graph().as_default():
    imported = load.load(path)
    imported_tensor = imported.f()
    with monitored_session.MonitoredSession() as sess:
      imported_output = sess.run(imported_tensor)
      self.assertNotEqual(original_output, imported_output)
      with open(imported_output, "r") as f:
        self.assertEqual("contents", f.read())
| |
def test_dedup_assets(self, cycles):
  """Two assets built from one source file load with equal asset paths."""
  shared_source = self._make_asset("contents")
  exported = tracking.AutoTrackable()
  exported.asset1 = tracking.Asset(shared_source)
  exported.asset2 = tracking.Asset(shared_source)

  restored = cycle(exported, cycles)

  first_path = restored.asset1.asset_path.numpy()
  second_path = restored.asset2.asset_path.numpy()
  self.assertEqual(first_path, second_path)
| |
def test_implicit_input_signature(self, cycles):
  """Both traces of a function without an input signature remain callable."""

  @def_function.function
  def func(x):
    return 2 * x

  exported = tracking.AutoTrackable()
  exported.f = func

  # Add two traces.
  exported.f(constant_op.constant(1.))
  exported.f(constant_op.constant(1))

  restored = cycle(exported, cycles)

  float_result = restored.f(constant_op.constant(2.)).numpy()
  self.assertEqual(4., float_result)
  int_result = restored.f(constant_op.constant(7)).numpy()
  self.assertEqual(14, int_result)
| |
def test_explicit_input_signature(self, cycles):
  """A function with an explicit input signature is callable after loading."""

  @def_function.function(
      input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
  def func(x):
    return 2 * x

  exported = tracking.AutoTrackable()
  exported.f = func

  restored = cycle(exported, cycles)

  doubled = restored.f(constant_op.constant(2.0)).numpy()
  self.assertEqual(4., doubled)
| |
def test_explicit_save_signature(self, cycles):
  """Saves with a concrete function passed as signature `"f"`."""

  @def_function.function
  def func(x):
    return 2 * x

  exported = tracking.AutoTrackable()
  exported.f = func

  float_spec = tensor_spec.TensorSpec(None, dtypes.float32)
  signatures = {"f": exported.f.get_concrete_function(float_spec)}
  restored = cycle(exported, cycles, signatures)

  doubled = restored.f(constant_op.constant(2.0)).numpy()
  self.assertEqual(4., doubled)
| |
def test_nested_functions(self, cycles):
  """A function that calls another function can be called after loading."""
  inner = def_function.function(
      lambda x: x*2.0,
      input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
  outer = def_function.function(
      lambda x: inner(x) + 1.0,
      input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])

  exported = tracking.AutoTrackable()
  exported.g = outer

  restored = cycle(exported, cycles)
  # Only checks that the call goes through; the result is not inspected.
  restored.g(constant_op.constant([1.0]))
| |
def test_function_with_default_bool_input(self, cycles):
  """Both values of a defaulted Python bool argument work after loading."""

  def func(x, training=False):
    if training:
      return 2 * x
    else:
      return 7

  exported = tracking.AutoTrackable()
  exported.f = def_function.function(func)

  # Call with and without `training` before exporting.
  self.assertEqual(20, exported.f(constant_op.constant(10), True).numpy())
  self.assertEqual(7, exported.f(constant_op.constant(1)).numpy())
  self.assertEqual(2, exported.f(constant_op.constant(1), True).numpy())

  restored = cycle(exported, cycles)

  with_training = restored.f(constant_op.constant(2), True).numpy()
  self.assertEqual(4, with_training)
  without_training = restored.f(constant_op.constant(2)).numpy()
  self.assertEqual(7, without_training)
| |
def test_function_with_default_none_input(self, cycles):
  """An argument defaulting to `None` works whether omitted or passed."""

  def func(x, dtype=None):
    if dtype:
      return array_ops.zeros(shape=x.shape, dtype=dtype)
    else:
      return array_ops.zeros(shape=x.shape, dtype=dtypes.float32)

  exported = tracking.AutoTrackable()
  exported.f = def_function.function(func)

  # Four calls that differ in input dtype, input length or `dtype` argument.
  self.assertAllEqual(
      [0.0, 0.0, 0.0],
      exported.f(constant_op.constant([1, 2, 3])).numpy())
  self.assertAllEqual(
      [0.0, 0.0, 0.0],
      exported.f(constant_op.constant([1.0, 2.0, 3.0])).numpy())
  self.assertAllEqual(
      [0.0, 0.0, 0.0, 0.0],
      exported.f(constant_op.constant([1, 2, 3, 4])).numpy())
  self.assertAllEqual(
      [0, 0, 0],
      exported.f(
          constant_op.constant([1.0, 2.0, 3.0]), dtype=dtypes.int32).numpy())

  traced = exported.f._list_all_concrete_functions_for_serialization()  # pylint: disable=protected-access
  self.assertEqual(4, len(traced))

  restored = cycle(exported, cycles)

  # Passing `None` explicitly must behave like omitting the argument.
  self.assertAllEqual(
      [0.0, 0.0, 0.0],
      restored.f(constant_op.constant([1, 2, 3]), None).numpy())
  self.assertAllEqual(
      [0.0, 0.0, 0.0],
      restored.f(constant_op.constant([1.0, 2.0, 3.0])).numpy())
  self.assertAllEqual(
      [0.0, 0.0, 0.0, 0.0],
      restored.f(constant_op.constant([1, 2, 3, 4])).numpy())
  self.assertAllEqual(
      [0, 0, 0],
      restored.f(
          constant_op.constant([1.0, 2.0, 3.0]), dtype=dtypes.int32).numpy())
| |
def test_function_no_return(self, cycles):
  """A method that only updates a variable keeps working after loading."""

  class TrackableWithOneVariable(tracking.AutoTrackable):

    def __init__(self, initial_value=0.0):
      super(TrackableWithOneVariable, self).__init__()
      self.variable = variables.Variable(initial_value)

    @def_function.function
    def increase(self, by=1.0):
      self.variable.assign_add(by)

  exported = TrackableWithOneVariable(5.0)

  exported.increase(constant_op.constant(10.0))
  self.assertEqual(15.0, exported.variable.numpy())
  exported.increase()
  self.assertEqual(16.0, exported.variable.numpy())

  restored = cycle(exported, cycles)

  # The loaded variable starts from the saved value of 16.0.
  restored.increase(constant_op.constant(10.0))
  self.assertEqual(26.0, restored.variable.numpy())
  restored.increase(constant_op.constant(1.0))
  self.assertEqual(27.0, restored.variable.numpy())
| |
def test_structured_inputs(self, cycles):
  """Nested inputs are accepted only if they match the saved structure."""

  def func(x, training=True):
    # x is a nested structure, we care about one particular tensor.
    _, (a, b) = x
    if training:
      return 2 * a["a"] + b
    else:
      return 7

  exported = tracking.AutoTrackable()
  exported.f = def_function.function(func)

  ten = constant_op.constant(10)
  eleven = constant_op.constant(11)

  traced_input = [6, ({"a": ten}, eleven)]
  # Not compatible with traced_input signature.
  mismatched_input = [7, ({"a": ten}, eleven)]
  # Compatible with traced_input signature.
  swapped_input = [6, ({"a": eleven}, ten)]

  # Note: by only calling f(traced_input) before serialization, only inputs
  # with matching signature will be valid on the loaded model.
  self.assertEqual(31, exported.f(traced_input).numpy())

  restored = cycle(exported, cycles)

  with self.assertRaisesRegexp(ValueError,
                               "Could not find matching function to call"):
    restored.f(mismatched_input)

  self.assertEqual(31, restored.f(traced_input).numpy())
  self.assertEqual(32, restored.f(swapped_input).numpy())
| |
def test_structured_output(self, cycles):
  """Nested outputs keep their structure and namedtuple field order."""

  # Use fields with non-alphabetical order
  named_tuple_type = collections.namedtuple("NamedTupleHello", ["b", "a"])

  def func(input1, input2):
    named_tuple = named_tuple_type(a=input1 + input2, b=input1 * input2)
    return [named_tuple, input2, {"x": 0.5}]

  exported = tracking.AutoTrackable()
  exported.f = def_function.function(func)

  before = exported.f(constant_op.constant(2), constant_op.constant(3))

  self.assertEqual(5, before[0].a.numpy())
  self.assertEqual(6, before[0].b.numpy())
  self.assertEqual(["b", "a"], list(before[0]._asdict().keys()))
  self.assertEqual(3, before[1].numpy())
  self.assertEqual(0.5, before[2]["x"].numpy())

  restored = cycle(exported, cycles)

  after = restored.f(constant_op.constant(2), constant_op.constant(5))
  self.assertEqual(7, after[0].a.numpy())
  self.assertEqual(10, after[0].b.numpy())
  self.assertEqual(["b", "a"], list(after[0]._asdict().keys()))
  self.assertEqual(5, after[1].numpy())
  self.assertEqual(0.5, after[2]["x"].numpy())
| |
def test_optimizer(self, cycles):
  """Compares a module holding a layer and optimizer with its loaded copy."""

  class _HasOptimizer(module.Module):

    def __init__(self):
      super(_HasOptimizer, self).__init__()
      self.layer = core.Dense(1)
      self.optimizer = adam.Adam(0.01)

    @def_function.function
    def __call__(self, x):
      return self.layer(x)

    @def_function.function
    def train(self, x, y):
      # One training step: sum of absolute errors, then an optimizer update.
      with backprop.GradientTape() as tape:
        predicted = self(x)
        loss = math_ops.reduce_sum(math_ops.abs(y - predicted))
      train_vars = self.layer.trainable_variables
      grads = tape.gradient(loss, train_vars)
      self.optimizer.apply_gradients(zip(grads, train_vars))

  root = _HasOptimizer()
  train_input = dict(x=constant_op.constant([[1.]]),
                     y=constant_op.constant([[2.]]))
  # Run one training step before exporting.
  root.train(**train_input)
  imported = cycle(root, cycles)
  self.assertAllClose(root.optimizer.learning_rate.numpy(),
                      imported.optimizer.learning_rate.numpy())
  self.assertAllClose(root(constant_op.constant([[-0.5]])),
                      imported(constant_op.constant([[-0.5]])))
  # Train both objects one more step; their predictions must still agree.
  root.train(**train_input)
  imported.train(**train_input)
  self.assertAllClose(root(constant_op.constant([[-0.5]])),
                      imported(constant_op.constant([[-0.5]])))
| |
def test_positional_arguments(self, cycles):
  """Calls with several defaulted arguments work after loading."""

  def func(x, training=False, abc=7.1, defg=7.7):
    del abc
    if training:
      return 2 * x
    if defg == 7:
      return 6
    else:
      return 7

  exported = tracking.AutoTrackable()
  exported.f = def_function.function(func)

  # Call with each argument combination that is used again after loading.
  self.assertEqual(20, exported.f(constant_op.constant(10), True).numpy())
  self.assertEqual(7, exported.f(constant_op.constant(1)).numpy())
  self.assertEqual(2, exported.f(constant_op.constant(1), True).numpy())
  self.assertEqual(6, exported.f(constant_op.constant(1), defg=7.0).numpy())

  restored = cycle(exported, cycles)

  with_training = restored.f(constant_op.constant(2), True).numpy()
  self.assertEqual(4, with_training)
  defaults_only = restored.f(constant_op.constant(2)).numpy()
  self.assertEqual(7, defaults_only)
  with_defg = restored.f(constant_op.constant(1), defg=7.0).numpy()
  self.assertEqual(6, with_defg)
| |
def test_additional_kwargs(self, cycles):
  """Extra keyword arguments must equal the values used before saving."""

  def func(x, training=False, **options):
    del options
    if training:
      return 2 * x
    else:
      return 7

  exported = tracking.AutoTrackable()
  exported.f = def_function.function(func)

  tensor_input = constant_op.constant(10)
  traced_result = exported.f(tensor_input, learning_rate=0.5, epochs=3)
  self.assertEqual(7, traced_result.numpy())

  restored = cycle(exported, cycles)

  # `epochs=4` was never used before saving, so the call is rejected.
  with self.assertRaisesRegexp(ValueError,
                               "Could not find matching function to call.*"):
    restored.f(tensor_input, learning_rate=0.5, epochs=4)

  restored_result = restored.f(tensor_input, learning_rate=0.5, epochs=3)
  self.assertEqual(7, restored_result.numpy())
| |
def test_member_function(self, cycles):
  """A decorated method that reads a Python attribute works after loading."""

  class TrackableWithMember(tracking.AutoTrackable):

    def __init__(self):
      super(TrackableWithMember, self).__init__()
      self._some_value = 20

    @def_function.function
    def f(self, x, training=False):
      if training:
        return 2 * x
      else:
        return 7 + self._some_value

  exported = TrackableWithMember()

  self.assertEqual(20, exported.f(constant_op.constant(10), True).numpy())
  self.assertEqual(27, exported.f(constant_op.constant(1)).numpy())
  self.assertEqual(2, exported.f(constant_op.constant(1), True).numpy())

  restored = cycle(exported, cycles)

  with_training = restored.f(constant_op.constant(2), True).numpy()
  self.assertEqual(4, with_training)
  without_training = restored.f(constant_op.constant(2)).numpy()
  self.assertEqual(27, without_training)
| |
def test_side_effect_listing(self, cycles):
  """The original object is still usable after it has been exported."""

  class M(tracking.AutoTrackable):

    def __init__(self):
      super(M, self).__init__()
      self.var = None

    @def_function.function(
        input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
    def f(self, x):
      if self.var is None:
        self.var = variables.Variable(2.)
      return x * self.var

  original = M()
  # The loaded copy is discarded on purpose; the assertion below calls the
  # original object.
  cycle(original, cycles)
  product = original.f(constant_op.constant(2.0)).numpy()
  self.assertEqual(4.0, product)
| |
def test_basic_backprop(self, cycles):
  """Gradients flow from a loaded function to the loaded variables."""
  weight = variables.Variable(1., trainable=True)
  bias = variables.Variable(0., trainable=True)
  affine = def_function.function(
      lambda x: x*weight + bias,
      input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])

  exported = tracking.AutoTrackable()
  exported.weight = weight
  exported.bias = bias
  exported.g = affine

  restored = cycle(exported, cycles)
  with backprop.GradientTape() as tape:
    x = constant_op.constant([3.5])
    loss = restored.g(x)
    # d(x*w + b)/dw == x and d(x*w + b)/db == 1.
    grad = tape.gradient(loss, [restored.weight, restored.bias])
    self.assertAllClose(grad, [3.5, 1.0])
| |
def test_nested_backprop(self, cycles):
  """Checks gradients through several levels of nested function calls."""
  weight = variables.Variable(1., trainable=True)
  bias = variables.Variable(0., trainable=True)

  # Note: this function gets called from other function defs via a
  # "PartitionedCall" op node.
  @def_function.function(input_signature=[
      tensor_spec.TensorSpec(None, dtypes.float32),
      tensor_spec.TensorSpec(None, dtypes.float32)])
  def mul(x, y):
    return x * y

  # Note: this function gets called from other function defs via a
  # "StatefulPartitionedCall" op node.
  @def_function.function(input_signature=[
      tensor_spec.TensorSpec(None, dtypes.float32)])
  def f(x):
    return mul(weight.read_value(), x)

  # Note the trailing comma: `g` returns a one-element tuple.
  @def_function.function(input_signature=[
      tensor_spec.TensorSpec(None, dtypes.float32)])
  def g(x):
    return f(x) + bias,

  # Note the trailing comma: `h` also returns a one-element tuple.
  @def_function.function(input_signature=[
      tensor_spec.TensorSpec(None, dtypes.float32)])
  def h(x):
    return g(x) + bias,

  root = tracking.AutoTrackable()
  root.weight = weight
  root.bias = bias
  # Only the outermost function is attached; it is exposed under the name `g`.
  root.g = h

  imported = cycle(root, cycles)
  with backprop.GradientTape() as t:
    x = constant_op.constant([3.5])
    loss = imported.g(x)
  grad = t.gradient(loss, [imported.weight, imported.bias])
  # `bias` is added once in `g` and once in `h`, hence the expected 2.0.
  self.assertAllClose(grad, [3.5, 2.0])
| |
def test_while_loop_backprop(self, cycles):
  """Gradients through a while loop match for original and loaded objects."""
  weight = variables.Variable(2., trainable=True)

  @def_function.function(input_signature=[
      tensor_spec.TensorSpec(dtype=dtypes.float32, shape=(None, None))])
  def g(x):
    """Adds rows of matrix x after multiplying each entry by v."""
    i_0 = constant_op.constant(0)
    s_0 = constant_op.constant([0., 0.])
    cond = lambda i, _: i < array_ops.shape(x)[1]
    body = lambda i, s: (i + 1, s + weight * x[:, i])
    i_end, s_end = control_flow_ops.while_loop(cond, body, (i_0, s_0))
    del i_end
    return s_end

  exported = tracking.AutoTrackable()
  exported.weight = weight
  exported.g = g
  restored = cycle(exported, cycles)

  def get_gradient(obj):
    """Returns d(sum(obj.g(x)))/d(obj.weight) for a fixed 2x3 matrix x."""
    with backprop.GradientTape() as tape:
      matrix = constant_op.constant([[1., 2., 3.], [1., -2, 3.]])
      row_sums = obj.g(matrix)
      self.assertAllClose(row_sums, obj.weight * [6., 2.])
      loss = math_ops.reduce_sum(row_sums)  # weight * 8.
    self.assertAllEqual(tape.watched_variables(), [obj.weight])
    return tape.gradient(loss, obj.weight)

  # The loaded object is evaluated before the original one.
  restored_gradient = get_gradient(restored)
  exported_gradient = get_gradient(exported)
  self.assertIsNotNone(exported_gradient)
  self.assertAllClose(exported_gradient, 8.)
  self.assertIsNotNone(restored_gradient)
  self.assertAllClose(restored_gradient, 8.)
| |
def _test_restored_func_with_captured_var_backprop(self, cycles, dtype):
  """Checks gradients w.r.t. a captured variable of the given dtype.

  Args:
    cycles: Number of save/load rounds to run.
    dtype: The dtype used for the variable, the function's input signature
      and the input tensor.
  """
  weight = variables.Variable(2., trainable=True, dtype=dtype)

  @def_function.function(input_signature=[
      tensor_spec.TensorSpec(dtype=dtype, shape=())])
  def g(x):
    return x * weight

  root = tracking.AutoTrackable()
  root.weight = weight
  root.g = g
  imported = cycle(root, cycles)

  def get_gradient(obj):
    """Returns d(obj.g(2))/d(obj.weight)."""
    with backprop.GradientTape() as t:
      # Build the input with `dtype` so it matches the input signature of
      # `g`; a default float32 constant would not match a float64 signature.
      x = constant_op.constant(2., dtype=dtype)
      y = obj.g(x)
      self.assertAllClose(y, obj.weight * 2.)
      self.assertAllEqual(t.watched_variables(), [obj.weight])
    return t.gradient(y, obj.weight)

  imported_gradient = get_gradient(imported)
  original_gradient = get_gradient(root)
  self.assertIsNotNone(original_gradient)
  self.assertAllClose(original_gradient, 2.)
  self.assertIsNotNone(imported_gradient)
  self.assertAllClose(imported_gradient, 2.)
| |
def test_restored_func_with_captured_var_backprop_float32(self, cycles):
  """Runs the captured-variable backprop check with float32."""
  self._test_restored_func_with_captured_var_backprop(
      cycles, dtype=dtypes.float32)
| |
def test_restored_func_with_captured_var_backprop_float64(self, cycles):
  """Float64 variant of the backprop check; skipped, see b/144573917."""
  self.skipTest("b/144573917")
  # Kept after the skip so the check is in place once the skip is removed.
  self._test_restored_func_with_captured_var_backprop(
      cycles, dtype=dtypes.float64)
| |
def test_callable(self, cycles):
  """Loaded objects are callable exactly when they carry a `__call__`."""

  class M1(tracking.AutoTrackable):

    @def_function.function(
        input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
    def __call__(self, x):
      return x

  exported = tracking.AutoTrackable()
  exported.m1 = M1()
  exported.m2 = tracking.AutoTrackable()
  exported.m2.__call__ = def_function.function(
      input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])(
          lambda x: x*3.0)

  restored = cycle(exported, cycles)
  one = constant_op.constant(1.0)

  self.assertTrue(callable(restored.m1))
  self.assertAllEqual(exported.m1(one), restored.m1(one))

  # Note: `exported.m2` was not callable since `__call__` attribute was set
  # into the instance and not on the class. But after a serialization cycle
  # that starts to work.
  self.assertTrue(callable(restored.m2))
  self.assertAllEqual(exported.m2.__call__(one), restored.m2(one))

  # Verify that user objects without `__call__` attribute are not callable.
  self.assertFalse(callable(restored))
| |
def test_chain_callable(self, cycles):
  """A chain of nested `__call__` attributes resolves to the function."""
  triple = def_function.function(
      input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])(
          lambda x: x*3.0)

  exported = tracking.AutoTrackable()
  # Two intermediate trackable objects sit between the root and the function.
  exported.__call__ = tracking.AutoTrackable()
  exported.__call__.__call__ = tracking.AutoTrackable()
  exported.__call__.__call__.__call__ = triple

  restored = cycle(exported, cycles)

  self.assertTrue(callable(restored))
  one = constant_op.constant(1.0)
  self.assertAllEqual(restored(one).numpy(), 3.0)
| |
def test_load_in_graph_mode(self, cycles):
  """Loads into an explicit graph and checks variables and collections.

  The last save/load round is done by hand so that the final load happens
  inside an explicit `ops.Graph`.
  """
  root = tracking.AutoTrackable()
  root.v1 = variables.Variable(1., name="v_one", trainable=False)
  root.v2 = variables.Variable(2., name="v_two", trainable=True)
  root.f = def_function.function(
      lambda x: root.v2 * x,
      input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])

  if cycles > 1:
    root = cycle(root, cycles - 1)
  # `dir=` (not `prefix=`) keeps the export inside the test temp directory
  # instead of creating a sibling directory next to it.
  path = tempfile.mkdtemp(dir=self.get_temp_dir())
  save.save(root, path)

  with ops.Graph().as_default() as g:
    imported = load.load(path)
    var_v1 = imported.v1
    self.assertFalse(var_v1.trainable)
    var_v2 = imported.v2
    self.assertTrue(var_v2.trainable)
    output = imported.f(constant_op.constant(2.))
    with monitored_session.MonitoredSession() as sess:
      self.assertEqual(1.0, sess.run(var_v1))
      self.assertEqual(4.0, sess.run(output))
    self.assertCountEqual([var_v1, var_v2],
                          g.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
    # load() should not add to TRAINABLE_VARIABLES. Higher levels of model
    # building control retraining or frozen use of imported SavedModels.
    self.assertCountEqual([],
                          g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES))
| |
def test_load_in_func_graph(self, cycles):
  """Loads a SavedModel from inside a `def_function.function` body.

  The last save is done by hand so that the final load happens inside the
  decorated function.
  """
  root = tracking.AutoTrackable()
  root.v1 = variables.Variable(1.)
  root.v2 = variables.Variable(2.)
  root.f = def_function.function(
      lambda x: root.v2 * x,
      input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])

  if cycles > 1:
    root = cycle(root, cycles - 1)
  # `dir=` (not `prefix=`) keeps the export inside the test temp directory
  # instead of creating a sibling directory next to it.
  path = tempfile.mkdtemp(dir=self.get_temp_dir())
  save.save(root, path)

  closure = tracking.AutoTrackable()

  @def_function.function
  def func(x):
    # Load only once; later executions of this body reuse the loaded model.
    if not hasattr(closure, "model"):
      closure.model = load.load(path)
    return closure.model.f(x)

  inputs = constant_op.constant(2.)
  self.assertEqual(4.0, func(inputs).numpy())
| |
def test_soft_matching(self, cycles):
  """A `[None]` signature accepts any vector length but rejects scalars."""

  @def_function.function(
      input_signature=[tensor_spec.TensorSpec([None], dtypes.int32)])
  def func(x):
    return 2 * x

  exported = tracking.AutoTrackable()
  exported.f = func

  self.assertAllEqual([2], exported.f(constant_op.constant([1])).numpy())
  self.assertAllEqual([2, 4],
                      exported.f(constant_op.constant([1, 2])).numpy())

  # Both calls above are served by a single concrete function.
  traced = exported.f._list_all_concrete_functions_for_serialization()  # pylint: disable=protected-access
  self.assertEqual(1, len(traced))

  restored = cycle(exported, cycles)

  with self.assertRaisesRegexp(ValueError, "Python inputs incompatible"):
    # We cannot call the function with a constant of shape ().
    restored.f(constant_op.constant(2)).numpy()

  # TODO(vbardiovsky): When classes are revived with input_signatures, we
  # should also check that the calls below are not generating any more
  # concrete functions.
  four_elements = restored.f(constant_op.constant([1, 2, 3, 4])).numpy()
  self.assertAllEqual([2, 4, 6, 8], four_elements)
  three_elements = restored.f(constant_op.constant([1, 2, 3])).numpy()
  self.assertAllEqual([2, 4, 6], three_elements)
| |
def test_get_concrete_function(self, cycles):
  """Checks `get_concrete_function` on a reloaded function finds saved traces.

  Two traces are saved: (int32 vector, training=True) and (float32 vector,
  default training). Requests matching neither must raise.
  """

  @def_function.function
  def func(x, training=False):
    if training:
      return 2 * x
    return 3 * x

  int_vector = tensor_spec.TensorSpec([None], dtypes.int32)
  float_vector = tensor_spec.TensorSpec([None], dtypes.float32)
  func.get_concrete_function(int_vector, True)
  func.get_concrete_function(float_vector)

  exported = tracking.AutoTrackable()
  exported.f = func

  restored = cycle(exported, cycles)

  # Keyword form of the (int32, training=True) trace.
  concrete = restored.f.get_concrete_function(training=True, x=int_vector)
  doubled = concrete(x=constant_op.constant([1, 2, 3, 4])).numpy()
  self.assertAllEqual([2, 4, 6, 8], doubled)

  with self.assertRaisesRegexp(ValueError,
                               "Could not find matching function to call"):
    # int32 with the default training=False was never traced.
    restored.f.get_concrete_function(int_vector)

  # The positional form of the saved (int32, True) trace must still resolve.
  restored.f.get_concrete_function(int_vector, True)
| |
def test_concrete_function(self, cycles):
  """Round-trips a ConcreteFunction attribute and calls it with new shapes."""
  vector_spec = tensor_spec.TensorSpec([None], dtypes.int32)

  @def_function.function(input_signature=[vector_spec])
  def func(x):
    return 2 * x

  exported = tracking.AutoTrackable()
  exported.f = func.get_concrete_function()

  for values, doubled in (([1], [2]), ([1, 2], [2, 4])):
    self.assertAllEqual(
        doubled, exported.f(constant_op.constant(values)).numpy())

  # TODO(andresp): Fix exporting of loaded concrete functions as signatures.
  restored = cycle(exported, cycles, signatures={})

  for values, doubled in (([1, 2, 3, 4], [2, 4, 6, 8]),
                          ([1, 2, 3], [2, 4, 6])):
    self.assertAllEqual(
        doubled, restored.f(constant_op.constant(values)).numpy())
| |
def test_concrete_function_captures(self, cycles):
  """Checks reloaded functions capture the reloaded variables' own handles."""

  class Root(module.Module):

    def __init__(self):
      self.v = variables.Variable(1.)
      self.v1 = variables.Variable(1.)

    @def_function.function(
        input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
    def use_v(self, x):
      return self.v + self.v1 + 1.

  def _assert_captures_both_variables(obj, captures):
    # Identity, not equality: the captures must be the very handle objects.
    self.assertLen(captures, 2)
    for handle in (obj.v.handle, obj.v1.handle):
      self.assertTrue(any(handle is t for t in captures))

  model = Root()
  self.assertIn(model.v.handle,
                model.use_v.get_concrete_function().graph.external_captures)
  for _ in range(cycles):
    model = cycle(model, 1, signatures=model.use_v.get_concrete_function())

  _assert_captures_both_variables(
      model, model.use_v.get_concrete_function().graph.external_captures)
  _assert_captures_both_variables(
      model, model.signatures["serving_default"].graph.external_captures)
| |
def test_concrete_function_arg_names(self, cycles):
  """Checks a reloaded ConcreteFunction accepts its argument by keyword."""
  vector_spec = tensor_spec.TensorSpec([None], dtypes.int32)

  @def_function.function(input_signature=[vector_spec])
  def func(x):
    return 2 * x

  exported = tracking.AutoTrackable()
  exported.f = func.get_concrete_function()

  before = exported.f(constant_op.constant([1])).numpy()
  self.assertAllEqual([2], before)

  # TODO(andresp): Fix exporting of loaded concrete functions as signatures.
  restored = cycle(exported, cycles, signatures={})

  after = restored.f(x=constant_op.constant([1, 2, 3])).numpy()
  self.assertAllEqual([2, 4, 6], after)
| |
def test_concrete_function_no_signature(self, cycles):
  """Round-trips a ConcreteFunction traced from an example tensor."""

  @def_function.function
  def func(x):
    return 2 * x

  exported = tracking.AutoTrackable()
  example_input = constant_op.constant([1])
  exported.f = func.get_concrete_function(example_input)
  before = exported.f(constant_op.constant([2])).numpy()
  self.assertAllEqual([4], before)

  # TODO(andresp): Fix exporting of loaded concrete functions as signatures.
  restored = cycle(exported, cycles, signatures={})
  after = restored.f(constant_op.constant([3])).numpy()
  self.assertAllEqual([6], after)
| |
@test_util.run_in_graph_and_eager_modes
def test_concrete_function_backprop(self, cycles):
  """Checks gradients through a ConcreteFunction before and after reload."""
  scalar_spec = tensor_spec.TensorSpec([], dtypes.float32)

  @def_function.function(input_signature=[scalar_spec])
  def func(x):
    return x ** 2.

  exported = tracking.AutoTrackable()
  exported.f = func.get_concrete_function()

  def _compute_gradient(function):
    # d(x ** 2)/dx evaluated at x = 1.
    with backprop.GradientTape() as tape:
      inp = constant_op.constant(1.)
      tape.watch(inp)
      output = function(inp)
    return tape.gradient(output, inp)

  self.assertAllEqual(2., _compute_gradient(exported.f))
  # TODO(andresp): Fix exporting of loaded concrete functions as signatures.
  restored = cycle(exported, cycles, signatures={})
  self.assertAllEqual(2., _compute_gradient(restored.f))
| |
def test_revived_concrete_function_kwargs(self, cycles):
  """Checks keyword calls to a ConcreteFunction work after reload."""

  @def_function.function
  def func(x, y):
    return x * (y + 1.)

  scalar_spec = tensor_spec.TensorSpec([], dtypes.float32)
  exported = tracking.AutoTrackable()
  exported.f = func.get_concrete_function(scalar_spec, scalar_spec)

  # Keywords are deliberately given out of order: 2 * (3 + 1) == 8.
  before = exported.f(y=constant_op.constant(3.), x=constant_op.constant(2.))
  self.assertEqual(8., before.numpy())

  # TODO(andresp): Fix exporting of loaded concrete functions as signatures.
  restored = cycle(exported, cycles, signatures={})
  after = restored.f(y=constant_op.constant(3.), x=constant_op.constant(2.))
  self.assertEqual(8., after.numpy())
| |
def test_revived_concrete_function_tensorspec_kwargs(self, cycles):
  """Checks keyword names taken from TensorSpec names survive a reload.

  `func` only declares `*args`; the keywords `x` and `y` used in the calls
  come from the `name=` of the specs it was traced with.
  """

  @def_function.function
  def func(*args):
    x, y = args
    return x * (y + 1.)

  x_spec = tensor_spec.TensorSpec([], dtypes.float32, name="x")
  y_spec = tensor_spec.TensorSpec([], dtypes.float32, name="y")
  exported = tracking.AutoTrackable()
  exported.f = func.get_concrete_function(x_spec, y_spec)

  before = exported.f(y=constant_op.constant(3.), x=constant_op.constant(2.))
  self.assertEqual(8., before.numpy())

  restored = cycle(exported, cycles, signatures={})
  after = restored.f(y=constant_op.constant(3.), x=constant_op.constant(2.))
  self.assertEqual(8., after.numpy())
| |
def test_concrete_function_variable_argument(self, cycles):
  """Checks ConcreteFunctions taking a variable argument survive a reload.

  Each function mutates both its variable argument and a captured variable,
  so the expected values are running totals over the calls made so far.
  """
  capture = variables.Variable(0)

  @def_function.function
  def func(v):
    v.assign_add(1)
    capture.assign_sub(1)

  variable_spec = resource_variable_ops.VariableSpec(
      shape=[], dtype=dtypes.int32)

  @def_function.function(input_signature=[variable_spec])
  def func_with_input_signature(v):
    v.assign_add(5)
    capture.assign_sub(5)
    return 1

  vsave = variables.Variable(1)
  exported = tracking.AutoTrackable()
  exported.f = func.get_concrete_function(vsave)
  exported.f_sig = func_with_input_signature.get_concrete_function()
  exported.capture = capture

  # Before saving: argument 1 -> 2 -> 7, capture 0 -> -1 -> -6.
  self.assertEqual(1, vsave.numpy())
  exported.f(vsave)
  self.assertEqual(2, vsave.numpy())
  self.assertEqual(-1, capture.numpy())

  exported.f_sig(vsave)
  self.assertEqual(7, vsave.numpy())
  self.assertEqual(-6, capture.numpy())

  restored = cycle(exported, cycles)

  # After loading: a fresh argument goes 1 -> 2 -> 3 -> 8 while the restored
  # capture continues from -6 down to -8 and then -13.
  vload = variables.Variable(1)
  restored.f(vload)
  self.assertEqual(2, vload.numpy())
  restored.f(v=vload)
  self.assertEqual(3, vload.numpy())
  self.assertEqual(-8, restored.capture.numpy())

  restored.f_sig(v=vload)
  self.assertEqual(8, vload.numpy())
  self.assertEqual(-13, restored.capture.numpy())

  # The original captured variable is not touched by the restored functions.
  self.assertEqual(-6, capture.numpy())
| |
def test_function_and_component(self, cycles):
  """Saves a function together with one of its own concrete functions."""

  @def_function.function
  def func(v):
    return v + 1

  exported = tracking.AutoTrackable()
  exported.func = func
  exported.concrete_func = func.get_concrete_function(
      tensor_spec.TensorSpec(None, dtypes.int32))

  one = constant_op.constant(1)
  self.assertEqual(2, exported.func(one).numpy())
  self.assertEqual(2, exported.concrete_func(one).numpy())

  restored = cycle(exported, cycles)
  self.assertEqual(2, restored.func(one).numpy())
  self.assertEqual(2, restored.concrete_func(one).numpy())
| |
def test_dict(self, cycles):
  """Checks dict attributes holding variables and functions are restored.

  The non-trackable entry `"c": 1` is expected to be absent after loading.
  """
  exported = tracking.AutoTrackable()
  exported.variables = dict(a=variables.Variable(1.))
  exported.variables["b"] = variables.Variable(2.)
  exported.variables["c"] = 1
  exported.funcs = dict(
      a=def_function.function(lambda: constant_op.constant(100.)))
  exported.funcs["conc"] = exported.funcs["a"].get_concrete_function()

  restored = cycle(exported, cycles)

  self.assertEqual(1., restored.variables["a"].numpy())
  self.assertEqual(2., restored.variables["b"].numpy())
  self.assertEqual({"a", "b"}, set(restored.variables.keys()))
  self.assertEqual(100., restored.funcs["a"]().numpy())
  self.assertEqual(100., restored.funcs["conc"]().numpy())
| |
def test_list(self, cycles):
  """Checks list attributes are restored, with `None` for non-trackables."""
  exported = tracking.AutoTrackable()
  exported.variables = [variables.Variable(1.)]
  exported.variables.append(1)
  exported.variables.append(variables.Variable(3.))

  restored = cycle(exported, cycles)

  first, middle, last = (restored.variables[i] for i in (0, 2, 1))
  self.assertEqual(1., first.numpy())
  self.assertEqual(3., middle.numpy())
  # The plain Python int at index 1 is expected to come back as None.
  self.assertIs(None, last)
  self.assertEqual(3, len(restored.variables))
| |
def test_tuple(self, cycles):
  """Checks tuple attributes are restored, with `None` for non-trackables."""
  exported = tracking.AutoTrackable()
  exported.variables = (variables.Variable(1.), 1, variables.Variable(3.))

  restored = cycle(exported, cycles)
  restored_items = restored.variables

  self.assertEqual(1., restored_items[0].numpy())
  self.assertEqual(3., restored_items[2].numpy())
  # The plain Python int at index 1 is expected to come back as None.
  self.assertIs(None, restored_items[1])
  self.assertLen(restored_items, 3)
| |
def test_functions_list(self, cycles):
  """Checks a list of loss functions stays tied to its variables on reload."""
  exported = tracking.AutoTrackable()
  v1 = variables.Variable(1.)
  exported.losses = [
      def_function.function(lambda: math_ops.reduce_sum(v1 ** 2))]
  exported.variables = [v1]

  @def_function.function
  def _v2_loss():
    # Create the second variable only while the list still has one entry.
    if len(exported.variables) == 1:
      v2 = variables.Variable(2.)
      exported.variables.append(v2)
    return math_ops.reduce_sum(exported.variables[1] ** 2)

  exported.losses.append(_v2_loss)
  self.assertAllClose([1., 4.], [loss() for loss in exported.losses])

  restored = cycle(exported, cycles)
  self.assertAllClose([1., 4.], [loss() for loss in restored.losses])

  # Updating the restored variables must change the restored losses.
  restored.variables[0].assign(3.)
  restored.variables[1].assign(4.)
  self.assertAllClose([9., 16.], [loss() for loss in restored.losses])
| |
def test_captured_constant(self, cycles):
  """Checks a constant captured by two functions is shared after reload."""
  const = array_ops.zeros([100])
  exported = tracking.AutoTrackable()
  exported.f = def_function.function(lambda: const + 1.)
  exported.g = def_function.function(lambda: const + 2.)
  self.assertAllClose(array_ops.ones([100]), exported.f())
  self.assertAllClose(2. * array_ops.ones([100]), exported.g())

  restored = cycle(exported, cycles)
  self.assertAllClose(array_ops.ones([100]), restored.f())
  self.assertAllClose(2. * array_ops.ones([100]), restored.g())

  # TODO(b/123408994): Use the public get_concrete_function.
  f_concrete, g_concrete = (
      fn._list_all_concrete_functions_for_serialization()[0]
      for fn in (restored.f, restored.g))
  self.assertLen(f_concrete.captured_inputs, 1)
  self.assertLen(g_concrete.captured_inputs, 1)
  # We should be using the same captured EagerTensor in both functions, not
  # duplicating the constant.
  self.assertIs(f_concrete.captured_inputs[0],
                g_concrete.captured_inputs[0])
| |
def test_functions_accessed_once(self, cycles):
  """Checks exporting reads a function-valued property exactly once.

  `make_func` builds a new traced function on every access and bumps a
  counter, so the constant baked into each function records how many times
  the property had been read before it.
  """

  class Exported(tracking.AutoTrackable):

    def __init__(self):
      self._counter = 0

    @property
    def make_func(self):
      @def_function.function
      def f():
        return constant_op.constant(self._counter)
      f.get_concrete_function()  # force a trace
      self._counter += 1
      return f

  exported = Exported()
  imported = cycle(exported, cycles)
  # The saved function holds counter value 0, i.e. it came from the first
  # access of the property.
  self.assertEqual(0, imported.make_func().numpy())
  # This direct access sees counter value 1, so exactly one access happened
  # before it.
  self.assertEqual(1, exported.make_func().numpy())
| |
def test_overwritten_signatures_error(self, cycles):
  """Checks saving fails once the `signatures` attribute is overwritten."""
  exported = tracking.AutoTrackable()
  exported.f = def_function.function(lambda: constant_op.constant(1.))
  signature_map = {"key": exported.f.get_concrete_function()}

  restored = cycle(exported, cycles, signatures=signature_map)
  self.assertEqual(1., restored.signatures["key"]()["output_0"].numpy())

  # Replace the loaded signatures with a plain dict, then try to save.
  restored.signatures = {"key1": restored.signatures["key"]}
  target_dir = tempfile.mkdtemp(prefix=self.get_temp_dir())
  with self.assertRaisesRegexp(ValueError, "signatures"):
    save.save(restored, target_dir)
| |
def test_signature_loading(self, cycles):
  """Checks the default serving signature is loaded, callable and read-only."""

  class Exported(tracking.AutoTrackable):

    def __init__(self):
      self.v = variables.Variable(3.)

    @def_function.function
    def do(self, x):
      return self.v * x

  exported = Exported()
  serving_fn = exported.do.get_concrete_function(
      tensor_spec.TensorSpec(None, dtypes.float32))
  restored = cycle(exported, cycles=1, signatures=serving_fn)
  # Later cycles re-export the signatures that were just loaded.
  for _ in range(cycles - 1):
    restored = cycle(restored, cycles=1, signatures=restored.signatures)

  self.assertEqual(["serving_default"], list(restored.signatures.keys()))
  signature_fn = restored.signatures["serving_default"]
  two = constant_op.constant(2.)
  self.assertEqual(6., signature_fn(x=two)["output_0"].numpy())

  # The signature reads the restored variable, so new values show up.
  restored.v.assign(4.)
  self.assertEqual(8., signature_fn(x=two)["output_0"].numpy())
  self.assertEqual(8., signature_fn(two)["output_0"].numpy())

  with self.assertRaises(TypeError):
    # The signatures mapping is immutable
    restored.signatures["random_key"] = 3
| |
def test_multiple_argument_signatures_no_positional(self, cycles):
  """Checks a two-input signature rejects positional arguments."""

  class Exported(tracking.AutoTrackable):

    @def_function.function
    def do(self, x, y):
      return x + y

  exported = Exported()
  float_spec = tensor_spec.TensorSpec(None, dtypes.float32)
  serving_fn = exported.do.get_concrete_function(
      float_spec, tensor_spec.TensorSpec(None, dtypes.float32))
  restored = cycle(exported, cycles=1, signatures=serving_fn)
  for _ in range(cycles - 1):
    restored = cycle(restored, cycles=1, signatures=restored.signatures)

  # Mixing a positional argument with a keyword one must be refused.
  with self.assertRaises(TypeError):
    restored.signatures["serving_default"](
        constant_op.constant(1.),
        y=constant_op.constant(2.))

  keyword_result = restored.signatures["serving_default"](
      x=constant_op.constant(1.),
      y=constant_op.constant(2.))
  self.assertEqual({"output_0": 3.}, self.evaluate(keyword_result))
| |
def _make_model_with_tables(self):
  """Builds a trackable object with two hash tables and lookup functions.

  `table1` is initialized from in-memory key/value tensors and `table2` from
  a text-file asset; both fall back to -1 for missing keys.

  Returns:
    An AutoTrackable exposing `table1`, `lookup1`, `table2` and `lookup2`.
  """
  missing_value = -1

  vocab = constant_op.constant(["brain", "salad", "surgery"])
  ids = constant_op.constant([0, 1, 2], dtypes.int64)
  table1 = lookup_ops.HashTable(
      lookup_ops.KeyValueTensorInitializer(vocab, ids), missing_value)

  asset_path = self._make_asset("test\nfoo\nbrain\n")
  table2 = lookup_ops.HashTable(
      lookup_ops.TextFileIdTableInitializer(asset_path), missing_value)

  def _make_lookup_function(table):
    string_spec = tensor_spec.TensorSpec(None, dtypes.string)
    wrap = def_function.function(input_signature=[string_spec])
    return wrap(lambda x: table.lookup(x))  # pylint: disable=unnecessary-lambda

  model = tracking.AutoTrackable()
  model.table1 = table1
  model.lookup1 = _make_lookup_function(table1)
  model.table2 = table2
  model.lookup2 = _make_lookup_function(table2)
  return model
| |
def test_table(self, cycles):
  """Checks both lookup tables return the expected ids after reload."""
  restored = cycle(self._make_model_with_tables(), cycles, signatures={})
  keys = constant_op.constant(["brain", "test", "foo", "surgery"])

  # Keys missing from a table map to the default value -1.
  ids_from_tensors = restored.lookup1(keys).numpy()
  self.assertAllEqual([0, -1, -1, 2], ids_from_tensors)
  ids_from_file = restored.lookup2(keys).numpy()
  self.assertAllEqual([2, 0, 1, -1], ids_from_file)
| |
def test_table_collections_untouched_eager(self, cycles):
  """Checks extra save/load cycles leave graph collections unchanged."""

  def _gather_nonempty_collections():
    # Snapshot every collection of the default graph that has contents.
    graph = ops.get_default_graph()
    contents_by_name = (
        (name, graph.get_collection(name)) for name in graph.collections)
    return {name: contents for name, contents in contents_by_name if contents}

  model = self._make_model_with_tables()
  # Warm up collections to ignore those that don't expand every iteration,
  # e.g. the __varscope collection.
  cycle(model, 1)
  baseline = _gather_nonempty_collections()
  cycle(model, cycles)
  self.assertEqual(baseline, _gather_nonempty_collections())
| |
def test_table_in_graph(self, cycles):
  """Checks table lookups work when the model is loaded into a graph.

  The model is saved after `cycles - 1` eager save/load round trips, then
  loaded inside an explicit `ops.Graph` and evaluated through a session.
  """
  root = self._make_model_with_tables()

  if cycles > 1:
    root = cycle(root, cycles - 1)
  path = tempfile.mkdtemp(prefix=self.get_temp_dir())
  save.save(root, path)
  # No extra eager `cycle(root, 1)` here: its result used to be overwritten
  # by the in-graph load below before anything read it.

  with ops.Graph().as_default():
    imported = load.load(path)
    keys = constant_op.constant(["brain", "test", "foo", "surgery"])
    output1 = imported.lookup1(keys)
    output2 = imported.lookup2(keys)
    # NOTE(review): presumably MonitoredSession runs the table initializers
    # when it is created -- confirm.
    with monitored_session.MonitoredSession() as sess:
      self.assertAllEqual([0, -1, -1, 2], sess.run(output1))
      self.assertAllEqual([2, 0, 1, -1], sess.run(output2))
| |
def test_preserve_argspec(self, cycles):
  """Checks the Python argspec of a function is unchanged by a reload."""

  def f(a, b, c):  # pylint: disable=unused-argument
    return None

  expected_argspec = tf_inspect.getfullargspec(f)

  exported = tracking.AutoTrackable()
  exported.f = def_function.function(f)
  restored = cycle(exported, cycles)

  self.assertEqual(expected_argspec, tf_inspect.getfullargspec(restored.f))
| |
def test_canonicalize_inputs(self, cycles):
  """Checks defaults and keyword arguments resolve to the saved traces."""

  @def_function.function(autograph=False)
  def func(a=1, b=2, c=3, training=True):
    if training:
      return [a, b, c, training]
    return [c, b, a, training]

  # TODO(b/123501567): Work-around to trigger generic traces of a function
  # with extra non tensor args.
  float_spec = tensor_spec.TensorSpec(None, dtypes.float32)
  signature = 3 * [float_spec]

  @def_function.function(input_signature=signature)
  def trigger(a, b, c):
    func(a, b, c, True)
    func(a, b, c, False)

  trigger.get_concrete_function()

  model = tracking.AutoTrackable()
  model.f = func
  model = cycle(model, cycles)

  self.assertAllEqual(model.f(), [1.0, 2.0, 3.0, True])
  self.assertAllEqual(model.f(-1.0, training=False), [3.0, 2.0, -1.0, False])

  with self.assertRaisesRegexp(ValueError,
                               "Could not find matching function"):
    # A list argument has no matching saved trace.
    model.f(["hello", 1.0])
| |
def test_prefer_specific_trace(self, cycles):
  """Checks the trace for the Python int 2 wins over the tensor trace.

  `func` returns Python ints unchanged and adds one to anything else. Only
  the int `2` and an int tensor are traced before saving.
  """

  @def_function.function(autograph=False)
  def func(a):
    if isinstance(a, int):
      return a
    return a + 1

  self.assertAllEqual(2, func(2).numpy())
  self.assertAllEqual(3, func(constant_op.constant(2)).numpy())

  model = tracking.AutoTrackable()
  model.f = func
  model = cycle(model, cycles)

  # The literal 2 hits its own trace; 3 was never traced as a Python int, so
  # the expected 4 is what the tensor trace (a + 1) produces.
  self.assertAllEqual(2, model.f(2).numpy())
  self.assertAllEqual(4, model.f(3).numpy())
  for value, incremented in ((2, 3), (3, 4)):
    self.assertAllEqual(
        incremented, model.f(constant_op.constant(value)).numpy())
| |
def test_partial(self, cycles):
  """Round-trips a function built from a partial with tensor keywords."""

  def f(x, y):
    return x + y

  bound = functools.partial(
      f, x=array_ops.zeros([1]), y=array_ops.ones([1]))

  model = tracking.AutoTrackable()
  model.f = def_function.function(bound)
  self.assertAllEqual(model.f(), [1.0])

  model = cycle(model, cycles)
  self.assertAllEqual(model.f(), [1.0])
| |
def test_partial_with_non_tensor_defaults(self, cycles):
  """Round-trips a partial that overrides a Python default with `y=5`."""

  def f(x, y=3):
    return x + y

  bound = functools.partial(f, y=5)

  model = tracking.AutoTrackable()
  model.f = def_function.function(bound)
  # 1 + 5: the partial's keyword wins over the declared default of 3.
  self.assertAllEqual(model.f(1), 6)

  model = cycle(model, cycles)
  self.assertAllEqual(model.f(1), 6)
| |
def test_partial_with_positional(self, cycles):
  """Round-trips a partial that binds a constant tensor positionally."""

  def f(x, y):
    return x + y

  bound = functools.partial(f, constant_op.constant(5))

  model = tracking.AutoTrackable()
  model.f = def_function.function(bound)
  self.assertAllEqual(model.f(1), 6)

  model = cycle(model, cycles)
  self.assertAllEqual(model.f(1), 6)
| |
def test_partial_with_positional_captured_tensors(self, cycles):
  """Round-trips a partial bound to a tensor computed from two constants."""

  def f(x, y):
    return x + y

  # 5 + 7 == 12, bound as the first positional argument.
  tensor = constant_op.constant(5) + constant_op.constant(7)
  bound = functools.partial(f, tensor)

  model = tracking.AutoTrackable()
  model.f = def_function.function(bound)
  self.assertAllEqual(model.f(1), 13)

  model = cycle(model, cycles)
  self.assertAllEqual(model.f(1), 13)
| |
def test_partial_keyword_hiding_default(self, cycles):
  """Round-trips a partial whose keyword follows another default argument."""

  def f(x=3, training=True, y=7):
    if training:
      return x + y
    else:
      return x + y + 2

  bound = functools.partial(f, y=6)

  model = tracking.AutoTrackable()
  model.f = def_function.function(bound)

  def _check(obj):
    # x=3 and y=6 give 9 when training, and 9 + 2 otherwise.
    self.assertEqual(obj.f().numpy(), 9)
    self.assertEqual(obj.f(training=False).numpy(), 11)

  _check(model)
  model = cycle(model, cycles)
  _check(model)
| |
def test_partial_with_kwargs(self, cycles):
  """Round-trips a partial mixing positionals, `*args` and `**kwargs`."""

  def f(a, b, *args, **kwargs):
    args_sum = sum(args)
    return a + b + kwargs["some_tensor"] * kwargs["learning_rate"] + args_sum

  constant_tensor = constant_op.constant(10)
  bound = functools.partial(
      f, 7, 1, 2, learning_rate=3, some_tensor=constant_tensor)

  model = tracking.AutoTrackable()
  model.f = def_function.function(bound)
  # 7 + 1 + 10 * 3 + (2 + call argument) == 40 + call argument.
  self.assertEqual(model.f(constant_op.constant(4)).numpy(), 44)

  model = cycle(model, cycles)
  self.assertEqual(model.f(constant_op.constant(5)).numpy(), 45)
| |
def test_partial_bind_only_first_argument(self, cycles):
  """Round-trips a partial that binds only the first argument by keyword."""
  running_python2 = sys.version_info[0] < 3
  if running_python2:
    self.skipTest("Test is only valid in python3. Only then we get some more "
                  "advanced inspection of partials where this is allowed.")

  def f(x, y):
    return x + y

  tf_func = def_function.function(functools.partial(f, x=5))

  model = tracking.AutoTrackable()
  model.f = tf_func
  self.assertAllEqual(model.f(y=constant_op.constant(7)), 12)

  model = cycle(model, cycles)
  self.assertAllEqual(model.f(y=constant_op.constant(9)), 14)
| |
def test_partial_with_passed_fn_as_default(self, cycles):
  """Round-trips a partial whose bound argument is a Python callable."""

  def f(x, y):
    return x(3) + y

  def my_func(a):
    return 2 * a

  bound = functools.partial(f, my_func)

  model = tracking.AutoTrackable()
  model.f = def_function.function(bound)
  # my_func(3) + 3 == 9.
  self.assertEqual(model.f(constant_op.constant(3)).numpy(), 9)

  model = cycle(model, cycles)
  self.assertEqual(model.f(constant_op.constant(3)).numpy(), 9)
| |
def test_partial_with_input_signature(self, cycles):
  """Round-trips a partial wrapped with a one-element input signature."""

  def full_function(a, b, c=3.0):
    return a, b, c

  # `a` is bound positionally and `c` by keyword, leaving only `b` free.
  bound = functools.partial(full_function, 1, c=4)
  self.assertAllEqual((1, 2.0, 4), bound(2.0))

  scalar_spec = tensor_spec.TensorSpec([], dtypes.float32)
  model = tracking.AutoTrackable()
  model.f = def_function.function(bound, input_signature=[scalar_spec])

  a, b, c = model.f(2.0)
  self.assertAllEqual([a.numpy(), b.numpy(), c.numpy()], (1, 2.0, 4))

  model = cycle(model, cycles)
  a, b, c = model.f(3.0)
  self.assertAllEqual([a.numpy(), b.numpy(), c.numpy()], (1, 3.0, 4))
| |
def test_convert_to_input_signature(self, cycles):
  """Checks a Python list is accepted for an int32 vector signature."""
  vector_spec = tensor_spec.TensorSpec([None], dtypes.int32)

  @def_function.function(input_signature=[vector_spec])
  def func(x):
    return x

  exported = tracking.AutoTrackable()
  exported.f = func

  restored = cycle(exported, cycles)

  echoed = restored.f([2])
  self.assertEqual([2], echoed.numpy())
| |
def test_named_tuple(self, cycles):
  """Checks a namedtuple argument still matches its trace after reload.

  The function is traced with a namedtuple of TensorSpecs, attached as the
  object's `__call__`, and invoked on the reloaded object with a namedtuple
  of constants.
  """
  # Skip before tracing or building anything on unsupported interpreters.
  if sys.version_info.major == 3 and sys.version_info.minor < 5:
    # TODO(allenl): figure out why this doesn't work in Python3.4
    self.skipTest("Not working in Python 3.4")

  class NamedTupleType(collections.namedtuple("NamedTupleType", ["a", "b"])):
    pass

  @def_function.function
  def f(x):
    return x.a + x.b

  f.get_concrete_function(
      NamedTupleType(
          a=tensor_spec.TensorSpec(None, dtypes.float32, name="a"),
          b=tensor_spec.TensorSpec(None, dtypes.float32, name="b")))
  obj = tracking.AutoTrackable()
  obj.__call__ = f
  imported = cycle(obj, cycles)
  self.assertAllClose(3.,
                      imported(NamedTupleType(a=constant_op.constant(1.),
                                              b=constant_op.constant(2.))))
| |
def test_extra_args(self, cycles):
  """Checks a dict input with an extra key does not match the saved trace."""

  @def_function.function
  def f(x):
    return math_ops.add(x["a"], 1.)

  # Trigger a trace.
  f({"a": constant_op.constant(2.0)})

  exported = tracking.AutoTrackable()
  exported.__call__ = f
  restored = cycle(exported, cycles)

  self.assertEqual(4.0, restored({"a": 3.0}).numpy())

  with self.assertRaisesRegexp(ValueError,
                               "Could not find matching function to call"):
    # Only the structure {"a": ...} was traced.
    restored({"a": 2.0, "b": 3.0})
| |
def test_shapes_available(self, cycles):
  """Checks static shapes are kept on the reloaded graph and signature."""
  x_spec = tensor_spec.TensorSpec([None, 3], dtypes.int32)
  y_spec = tensor_spec.TensorSpec([None, 2], dtypes.int32)

  @def_function.function(input_signature=[x_spec, y_spec])
  def func(x, y):
    return array_ops.concat([x, y], axis=1)

  model = tracking.AutoTrackable()
  model.f = func

  model = cycle(model, cycles)

  # Shapes on the concrete function's graph.
  imported_graph = model.f.get_concrete_function().graph
  input_x, input_y = imported_graph.inputs
  self.assertEqual([None, 3], input_x.shape.as_list())
  self.assertEqual([None, 2], input_y.shape.as_list())
  output, = imported_graph.outputs
  self.assertEqual([None, 5], output.shape.as_list())

  # Shapes on the serving signature.
  serving = model.signatures["serving_default"]
  self.assertEqual([None, 3], serving.inputs[0].shape.as_list())
  self.assertEqual([None, 2], serving.inputs[1].shape.as_list())
  self.assertEqual([None, 5], serving.outputs[0].shape.as_list())
| |
def test_variables_destroyed(self, cycles):
  """Checks neither the saved nor the loaded variable is kept alive.

  Weak references are used so the test itself does not keep the variables
  alive; each must go dead once the last local name is deleted.
  """
  v1 = variables.Variable(1.)
  weak_v1 = weakref.ref(v1)
  root = util.Checkpoint(v=v1)
  # Rebinding `root` drops the original checkpoint object.
  root = cycle(root, cycles)
  del v1
  # `v1` was the last local name for the original variable.
  self.assertIsNone(weak_v1())
  weak_v2 = weakref.ref(root.v)
  del root
  # The loaded variable was reachable only through `root`.
  self.assertIsNone(weak_v2())
| |
def test_variable_attributes_preserved(self, cycles):
  """Checks trainable, synchronization and aggregation survive a reload."""
  no_sync = variables.VariableSynchronization.NONE
  first_replica = variables.VariableAggregation.ONLY_FIRST_REPLICA

  v = variables.Variable(
      1.,
      trainable=False,
      synchronization=no_sync,
      aggregation=first_replica)
  self.assertEqual(no_sync, v.synchronization)
  self.assertEqual(first_replica, v.aggregation)

  model = tracking.AutoTrackable()
  model.v = v
  model = cycle(model, cycles)

  self.assertEqual(False, model.v.trainable)
  self.assertEqual(no_sync, model.v.synchronization)
  self.assertEqual(first_replica, model.v.aggregation)
| |
def test_captured_dataset(self, cycles):
  """Checks a function iterating over a captured dataset survives a reload."""

  class HasDataset(module.Module):

    def __init__(self):
      super(HasDataset, self).__init__()
      squares = dataset_ops.Dataset.range(5).map(lambda x: x ** 2)
      self.dataset = squares

    @def_function.function
    def __call__(self, x):
      current_sum = array_ops.zeros([], dtype=dtypes.int64)
      for element in self.dataset:
        current_sum += x * element
      return current_sum

  # x times the sum of the squares of 0..4, with x == 3.
  expected = 3 * (1 + 4 + 9 + 16)

  model = HasDataset()
  before = model(constant_op.constant(3, dtype=dtypes.int64))
  self.assertEqual(expected, before.numpy())

  model = cycle(model, cycles)
  after = model(constant_op.constant(3, dtype=dtypes.int64))
  self.assertEqual(expected, after.numpy())
| |
def test_tuple_signature(self, cycles):
  """Checks a tuple-returning signature yields `output_0` and `output_1`."""
  model = util.Checkpoint()
  model.f = def_function.function(
      lambda: (array_ops.ones([]), array_ops.zeros([])),
      input_signature=())
  # Each cycle re-exports the function from the previously loaded object.
  for _ in range(cycles):
    model = cycle(model, 1, signatures=model.f)

  outputs = self.evaluate(model.signatures["serving_default"]())
  self.assertEqual({"output_0": 1., "output_1": 0.}, outputs)
| |
def test_model_with_custom_function_attached(self, cycles):
  """Checks a function attached to a Keras model is saved with the model."""
  checkpoint = util.Checkpoint(model=sequential.Sequential([core.Dense(2)]))

  @def_function.function
  def _use_sequential(x):
    return checkpoint.model.call(x)

  checkpoint.model.traced_call = _use_sequential

  expected = checkpoint.model.traced_call(array_ops.zeros([1, 1])).numpy()
  checkpoint = cycle(checkpoint, cycles)
  actual = checkpoint.model.traced_call(array_ops.zeros([1, 1])).numpy()
  self.assertAllEqual(expected, actual)
| |
def test_version_info(self, cycles):
  """Checks the loaded object reports the current TensorFlow versions."""
  restored = cycle(util.Checkpoint(), cycles)
  self.assertEqual(versions.__version__, restored.tensorflow_version)
  self.assertEqual(versions.__git_version__,
                   restored.tensorflow_git_version)
| |
def test_load_grad_save(self, cycles):
  """Takes a gradient through a nested function before every save/load."""
  model = util.Checkpoint()
  model.v = variables.Variable(2.)
  model.f = def_function.function(lambda x: model.v * x)
  model.g = def_function.function(model.f)

  for _ in range(cycles):
    with backprop.GradientTape() as tape:
      inp = constant_op.constant(2.)
      tape.watch(inp)
      output = model.g(inp)
      self.assertAllClose(4., output)
    # d(v * x)/dx == v == 2.
    self.assertAllClose(2., tape.gradient(output, inp))
    model = cycle(model, 1)
| |
def test_destroy_resource(self, cycles):
  """Checks deleting a loaded object runs its custom resource deleter.

  A custom TrackableResource with an explicit deleter is saved and loaded.
  After the loaded object is deleted, destroying the same handle again must
  fail with NotFoundError.
  """

  def get_handle():
    # Always built with the same shared_name and container.
    return resource_variable_ops.var_handle_op(
        shape=tensor_shape.as_shape([]),
        dtype=dtypes.float32,
        shared_name="my_var_name",
        name="my_var",
        container="my_container")

  class MyResourceDeleter(tracking.CapturableResourceDeleter):

    def destroy_resource(self):
      # Lookup errors are ignored here, so this call tolerates a resource
      # that is already gone.
      handle = get_handle()
      resource_variable_ops.destroy_resource_op(
          handle, ignore_lookup_error=True)

  class MyResource(tracking.TrackableResource):

    def __init__(self):
      # Set the resource deleter, so when the resource object goes out of
      # scope it will be deleted automatically.
      super(MyResource, self).__init__(deleter=MyResourceDeleter())

    def _create_resource(self):
      return get_handle()

    def _initialize(self):
      resource_variable_ops.assign_variable_op(
          self.resource_handle, 1.0, name="assign")

  class MyModel(tracking.AutoTrackable):

    def __init__(self):
      super(MyModel, self).__init__()
      self.resource = MyResource()

    @def_function.function(input_signature=[])
    def increase(self):
      # Adds 10 to the resource and returns the value read back.
      handle = self.resource.resource_handle
      resource_variable_ops.assign_add_variable_op(
          handle, 10.0, name="assign_add")
      return resource_variable_ops.read_variable_op(handle, dtypes.float32)

  root = MyModel()
  imported = cycle(root, cycles)
  # Initial value 1.0 plus the 10.0 added by `increase`.
  self.assertEqual(11, imported.increase().numpy())  # Create the resource.

  handle = imported.resource.resource_handle

  # Delete the imported SaveModel. Since we explicitly set the deleter, it
  # should destroy the resource automatically.
  del imported

  # Try to destroy the resource again, should fail.
  with self.assertRaisesRegexp(errors.NotFoundError,
                               r"Resource .* does not exist."):
    resource_variable_ops.destroy_resource_op(
        handle, ignore_lookup_error=False)
| |
def test_function_called_as_operation(self, cycles):
  """Round-trips a function whose body calls a `Defun`-defined function."""

  @framework_function.Defun(dtypes.float32)
  def inner(x):
    return x + 1.

  scalar_spec = tensor_spec.TensorSpec([], dtypes.float32)

  @def_function.function(input_signature=[scalar_spec])
  def outer(x):
    return inner(x)

  exported = module.Module()
  exported.f = outer
  restored = cycle(exported, cycles)

  incremented = restored.f(constant_op.constant(1.))
  self.assertAllClose(2., incremented)
| |
def test_ragged(self, cycles):
  """Checks traces made with a RaggedTensorSpec are callable after reload.

  The function is traced for the Python constants 1, 2 and 3, then loaded
  once with `signatures={}` and once with the default signatures.
  """

  @def_function.function
  def f(x, c=1):
    """Returns Tensor x incremented by Python constant c."""
    return math_ops.add(x, c)

  for c in (1, 2, 3):
    f.get_concrete_function(
        ragged_tensor.RaggedTensorSpec([None, None], dtype=dtypes.int32),
        c)

  exported = tracking.AutoTrackable()
  exported.f = f

  without_signatures = cycle(exported, cycles, signatures={})
  rt = ragged_factory_ops.constant([[1, 2], [3]])
  # The first call relies on the default c=1.
  self.assertAllEqual(without_signatures.f(rt), [[2, 3], [4]])
  self.assertAllEqual(without_signatures.f(rt, 2), [[3, 4], [5]])
  self.assertAllEqual(without_signatures.f(rt, 3), [[4, 5], [6]])

  with_signatures = cycle(exported, cycles)
  rt = ragged_factory_ops.constant([[1, 2], [3]])
  self.assertAllEqual(with_signatures.f(rt, 1), [[2, 3], [4]])
  self.assertAllEqual(with_signatures.f(rt, 2), [[3, 4], [5]])
  self.assertAllEqual(with_signatures.f(rt, 3), [[4, 5], [6]])
| |
| |
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
@parameterized.named_parameters(
    dict(testcase_name="ReloadOnce", cycles=1),
    dict(testcase_name="ReloadTwice", cycles=2),
    dict(testcase_name="ReloadThrice", cycles=3))
class KerasLoadTest(test.TestCase, parameterized.TestCase):
  """Keras models passed through `cycle`, parameterized over `cycles`.

  Every test method receives `cycles` (1, 2 or 3) from the
  `named_parameters` decorator above and forwards it to the `cycle` helper
  defined elsewhere in this file.
  """

  def test_dense_features_layer(self, cycles):
    """A `DenseFeatures`-only model yields [[1., 2.]] before and after."""
    expected = [[1., 2.]]
    feature_columns = [
        feature_column_lib.numeric_column("x"),
        feature_column_lib.numeric_column("y"),
    ]
    dense_features = feature_column_lib.DenseFeatures(feature_columns)
    model = sequential.Sequential([dense_features])
    features = {
        "x": constant_op.constant([[1.]]),
        "y": constant_op.constant([[2.]]),
    }
    self.assertAllClose(expected, model.predict(features, steps=1))

    loaded = cycle(model, cycles)

    # Both entry points are expected to return exactly one output.
    (direct_output,) = loaded._default_save_signature(features).values()
    self.assertAllClose(expected, direct_output)

    serving_fn = loaded.signatures["serving_default"]
    (signature_output,) = serving_fn(**features).values()
    self.assertAllClose(expected, signature_output)

  def test_dense_features_layer_fit(self, cycles):
    """A fitted `DenseFeatures` + `Dense` model is callable after `cycle`.

    No output values are asserted; the test only requires that both calls
    on the loaded object complete.
    """
    dense_features = feature_column_lib.DenseFeatures(
        [feature_column_lib.numeric_column("x")])
    model = sequential.Sequential([dense_features, core.Dense(1)])
    features = {"x": constant_op.constant([[1.]])}
    model.compile(optimizer="adam", loss="mse", run_eagerly=True)
    target = constant_op.constant([[3.]])
    model.fit(features, target)

    loaded = cycle(model, cycles)
    loaded._default_save_signature(features)
    serving_fn = loaded.signatures["serving_default"]
    serving_fn(**features)

  def test_multi_output_layer(self, cycles):
    """A layer returning two tensors gives outputs `out` and `out_1`."""
    inp = input_layer.Input(name="inp", shape=(None,), dtype=dtypes.float32)

    class _MultiOutput(base_layer.Layer):
      """Layer whose call returns the pair (x + 1, x + 2)."""

      def call(self, x):
        return x + 1., x + 2.

    outputs = _MultiOutput(name="out")(inp)
    model = training_lib.Model(inp, outputs)

    loaded = cycle(model, cycles)
    serving_fn = loaded.signatures["serving_default"]
    self.assertAllClose(
        {"out": 2., "out_1": 3.},
        serving_fn(constant_op.constant(1.)))

  def test_functional_model_with_conv(self, cycles):
    """A functional dilated `Conv2D` model predicts the same after `cycle`."""
    x = input_layer.Input(name="x", shape=(None, None, 3), dtype=dtypes.float32)
    conv_layer = convolutional.Conv2D(
        filters=3, kernel_size=3, dilation_rate=2)
    model = training_lib.Model([x], conv_layer(x))
    model_input = array_ops.ones((1, 10, 10, 3))
    initial_output = model.predict([model_input])

    # Rebind `model` so only the cycled object is used from here on.
    model = cycle(model, cycles)
    serving_fn = model.signatures["serving_default"]
    reloaded_outputs = list(serving_fn(model_input).values())
    self.assertAllClose([initial_output], reloaded_outputs)
| |
| |
class SingleCycleTests(test.TestCase, parameterized.TestCase):
  """Save/load tests that are not parameterized over a reload count."""

  def test_load_with_tags(self):
    """`load.load` rejects an EVAL tag and accepts SERVING in three forms."""
    root = tracking.AutoTrackable()
    # `dir=` creates the directory *inside* the test's temp directory. Passing
    # the directory as `prefix=` would create a sibling whose name merely
    # starts with the temp directory's path.
    path = tempfile.mkdtemp(dir=self.get_temp_dir())
    save.save(root, path)
    with self.assertRaises(ValueError):
      load.load(path, tags=[tag_constants.EVAL])
    # The SERVING tag is accepted as a list, a bare string and a set.
    load.load(path, tags=[tag_constants.SERVING])
    load.load(path, tags=tag_constants.SERVING)
    load.load(path, tags={tag_constants.SERVING})

  def test_docstring_examples(self):
    """Saves a checkpoint with a function, with and without a signature."""
    # See test_load_with_tags for why `dir=` is used rather than `prefix=`.
    path = tempfile.mkdtemp(dir=self.get_temp_dir())
    exported = util.Checkpoint(v=variables.Variable(3.))
    exported.f = def_function.function(
        lambda x: exported.v * x,
        input_signature=[
            tensor_spec.TensorSpec(shape=None, dtype=dtypes.float32)])
    save.save(exported, path)
    imported = load.load(path)
    self.assertEqual(3., imported.v.numpy())
    self.assertEqual(6., imported.f(x=constant_op.constant(2.)).numpy())

    # Save again to the same path, this time passing a concrete function as
    # the third argument, then call it through the loaded signature map.
    save.save(exported, path, exported.f.get_concrete_function())
    imported = load.load(path)
    f = imported.signatures["serving_default"]
    self.assertAllEqual(
        [[-3.]],
        f(x=constant_op.constant([[-1.]]))["output_0"].numpy())

  def test_object_with_extra_dependencies(self):
    """Extra serialization dependencies load as attributes; clashes raise."""

    class Extra(tracking.AutoTrackable):
      """Trackable reporting a variable named `a` as an extra dependency."""

      def _list_extra_dependencies_for_serialization(self, cache):
        # Create the variable once per cache so repeated calls with the same
        # cache return the same object.
        if self not in cache:
          cache[self] = {"a": variables.Variable(5.)}
        return cache[self]

    root = Extra()
    # See test_load_with_tags for why `dir=` is used rather than `prefix=`.
    path = tempfile.mkdtemp(dir=self.get_temp_dir())
    save.save(root, path)
    imported = load.load(path)
    self.assertEqual(5, self.evaluate(imported.a))

    # A real attribute with the same name as the extra dependency must make
    # the save fail.
    root.a = variables.Variable(3.)
    with self.assertRaisesRegexp(
        ValueError,
        "object has an attribute named a, which is reserved."):
      save.save(root, path)

  def test_save_cached_variable(self):
    """Saves from a graph with a `caching_device` variable, then reloads.

    The object is built and saved inside an explicit graph and session, and
    loaded once inside that graph. It is loaded a second time after the graph
    context has exited.
    """
    with ops.Graph().as_default(), session_lib.Session() as session:
      obj = tracking.AutoTrackable()
      obj.v = variables.Variable(2., caching_device=lambda op: op.device)
      obj.w = variables.Variable(3.)
      session.run([obj.v.initializer, obj.w.initializer])

      @def_function.function
      def total():
        return obj.v + obj.w

      @def_function.function(input_signature=[tensor_spec.TensorSpec([])])
      def wrapped_total(x):
        return total() + x

      @def_function.function
      def increment_v(x):
        obj.v.assign_add(x)

      session.run(increment_v(constant_op.constant(3.)))  # generate signatures
      # v is now 2 + 3 = 5 and w is 3.
      self.assertAllClose(8, total())
      self.assertAllClose(13, wrapped_total(constant_op.constant(5.)))

      obj.total = total
      obj.wrapped_total = wrapped_total.get_concrete_function()
      obj.increment_v = increment_v

      save_dir = os.path.join(self.get_temp_dir(), "saved_model")
      save.save(obj, save_dir, signatures=total.get_concrete_function())
      imported = load.load(save_dir)
      session.run(variables.global_variables_initializer())
      self.assertAllClose(8, imported.total())
      session.run(imported.increment_v(4))
      self.assertAllClose(12, imported.total())
      self.assertAllClose(15, imported.wrapped_total(constant_op.constant(3.)))
      self.assertAllClose({"output_0": 12},
                          imported.signatures["serving_default"]())

    # Try loading and running the function in eager mode
    imported = load.load(save_dir)
    self.assertAllClose(8, imported.total())
    imported.increment_v(5)
    self.assertAllClose(13, imported.total())
    self.assertAllClose(13.5, imported.wrapped_total(constant_op.constant(.5)))
    self.assertAllClose({"output_0": 13},
                        imported.signatures["serving_default"]())

  # TODO(allenl, kkb): Use the new memory checker here once it's fast enough (3
  # iterations took hundreds of seconds). It would be really nice to check
  # allocations at a lower level.
  @test_util.assert_no_new_pyobjects_executing_eagerly
  def test_functions_cleaned(self):
    """Runs one `cycle` of a module with a function under the leak decorator.

    The method makes no assertions itself; any check comes from the
    `assert_no_new_pyobjects_executing_eagerly` decorator.
    """
    if sys.version_info.major < 3:
      self.skipTest("Not working in Python 2")
    root = module.Module()
    root.v = variables.Variable(1.)
    root.f = def_function.function(
        lambda x: x + root.v,
        input_signature=[
            tensor_spec.TensorSpec(shape=[], dtype=dtypes.float32)])
    cycle(root, 1)
| |
| |
# Hand control to the TensorFlow test runner when executed as a script.
if __name__ == "__main__":
  test.main()