| # Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| """Tests for SavedModel.""" |
| |
| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| |
| import os |
| |
| from tensorflow.core.framework import types_pb2 |
| from tensorflow.core.protobuf import config_pb2 |
| from tensorflow.core.protobuf import meta_graph_pb2 |
| from tensorflow.python.client import session |
| from tensorflow.python.framework import constant_op |
| from tensorflow.python.framework import dtypes |
| from tensorflow.python.framework import errors |
| from tensorflow.python.framework import ops |
| from tensorflow.python.framework import test_ops |
| from tensorflow.python.framework import test_util |
| from tensorflow.python.lib.io import file_io |
| from tensorflow.python.ops import control_flow_ops |
| from tensorflow.python.ops import math_ops |
| from tensorflow.python.ops import state_ops |
| from tensorflow.python.ops import variables |
| from tensorflow.python.ops.ragged import ragged_factory_ops |
| from tensorflow.python.platform import test |
| from tensorflow.python.saved_model import builder as saved_model_builder |
| from tensorflow.python.saved_model import constants |
| from tensorflow.python.saved_model import loader |
| from tensorflow.python.saved_model import loader_impl |
| from tensorflow.python.saved_model import main_op |
| from tensorflow.python.saved_model import signature_def_utils |
| from tensorflow.python.saved_model import tag_constants |
| from tensorflow.python.saved_model import utils |
| from tensorflow.python.training import saver_test_utils |
| from tensorflow.python.training import training |
| from tensorflow.python.util import compat |
| |
# Source-tree-relative location of the half_plus_two SavedModel fixture;
# resolved with test.test_src_dir_path() in testMaybeSavedModelDir.
SAVED_MODEL_PATH = "cc/saved_model/testdata/half_plus_two/00000123"
| |
| |
def tearDownModule():
  """Deletes everything the tests in this module wrote to the temp dir."""
  scratch_dir = test.get_temp_dir()
  file_io.delete_recursively(scratch_dir)
| |
| |
class SavedModelTestBase(test.TestCase):
  """Shared helpers for the SavedModel builder/loader tests."""

  def _get_export_dir(self, label):
    """Returns `<temp dir>/<label>`, used as a per-test export directory."""
    scratch_dir = test.get_temp_dir()
    return os.path.join(scratch_dir, label)

  def _init_and_validate_variable(self, sess, variable_name, variable_value):
    """Creates a VariableV1, runs global init and checks its value.

    Args:
      sess: Not referenced by the body; kept for call-site compatibility.
        Evaluation goes through `self.evaluate`.
      variable_name: Name given to the new variable.
      variable_value: Initial value; also the value the variable must
        evaluate to after initialization.
    """
    new_var = variables.VariableV1(variable_value, name=variable_name)
    init_op = variables.global_variables_initializer()
    self.evaluate(init_op)
    self.assertEqual(variable_value, self.evaluate(new_var))

  def _build_asset_collection(self, asset_file_name, asset_file_contents,
                              asset_file_tensor_name, asset_subdir=""):
    """Writes an asset file and registers its path as a graph asset.

    The file `<temp dir>/<asset_subdir>/<asset_file_name>` is written with
    `asset_file_contents`, and a string constant holding that path (named
    `asset_file_tensor_name`) is added to `ops.GraphKeys.ASSET_FILEPATHS`.

    Returns:
      The whole ASSET_FILEPATHS collection after the addition, so it also
      includes assets registered by earlier calls in the same graph.
    """
    as_bytes = compat.as_bytes
    parent_dir = os.path.join(
        as_bytes(test.get_temp_dir()), as_bytes(asset_subdir))
    file_io.recursive_create_dir(parent_dir)
    # parent_dir is already bytes, so only the file name needs converting.
    asset_filepath = os.path.join(parent_dir, as_bytes(asset_file_name))
    file_io.write_string_to_file(asset_filepath, asset_file_contents)

    path_tensor = constant_op.constant(
        asset_filepath, name=asset_file_tensor_name)
    ops.add_to_collection(ops.GraphKeys.ASSET_FILEPATHS, path_tensor)
    return ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS)
| |
| |
| class SavedModelTest(SavedModelTestBase): |
| |
| def _validate_assets(self, |
| export_dir, |
| asset_file_def, |
| expected_asset_file_name, |
| expected_asset_file_contents, |
| expected_asset_tensor_name, |
| asset_id=0): |
| assets_path = os.path.join( |
| compat.as_bytes(export_dir), |
| compat.as_bytes(constants.ASSETS_DIRECTORY), |
| compat.as_bytes(expected_asset_file_name)) |
| actual_asset_contents = file_io.read_file_to_string(assets_path) |
| self.assertEqual(expected_asset_file_contents, |
| compat.as_text(actual_asset_contents)) |
| self.assertEqual(expected_asset_file_name, |
| asset_file_def[asset_id].filename) |
| self.assertEqual(expected_asset_tensor_name, |
| asset_file_def[asset_id].tensor_info.name) |
| |
| def _validate_inputs_tensor_info_fail(self, builder, tensor_info): |
| with self.session(graph=ops.Graph()) as sess: |
| self._init_and_validate_variable(sess, "v", 42) |
| |
| foo_signature = signature_def_utils.build_signature_def({ |
| "foo_inputs": tensor_info |
| }, dict(), "foo") |
| self.assertRaises( |
| AssertionError, |
| builder.add_meta_graph_and_variables, |
| sess, ["foo"], |
| signature_def_map={"foo_key": foo_signature}) |
| |
| def _validate_inputs_tensor_info_accept(self, builder, tensor_info): |
| with self.session(graph=ops.Graph()) as sess: |
| self._init_and_validate_variable(sess, "v", 42) |
| |
| foo_signature = signature_def_utils.build_signature_def({ |
| "foo_inputs": tensor_info |
| }, dict(), "foo") |
| builder.add_meta_graph_and_variables( |
| sess, ["foo"], |
| signature_def_map={"foo_key": foo_signature}) |
| |
| def _validate_outputs_tensor_info_fail(self, builder, tensor_info): |
| with self.session(graph=ops.Graph()) as sess: |
| self._init_and_validate_variable(sess, "v", 42) |
| |
| foo_signature = signature_def_utils.build_signature_def( |
| dict(), {"foo_outputs": tensor_info}, "foo") |
| self.assertRaises( |
| AssertionError, |
| builder.add_meta_graph_and_variables, |
| sess, ["foo"], |
| signature_def_map={"foo_key": foo_signature}) |
| |
| def _validate_outputs_tensor_info_accept(self, builder, tensor_info): |
| with self.session(graph=ops.Graph()) as sess: |
| self._init_and_validate_variable(sess, "v", 42) |
| |
| foo_signature = signature_def_utils.build_signature_def( |
| dict(), {"foo_outputs": tensor_info}, "foo") |
| builder.add_meta_graph_and_variables( |
| sess, ["foo"], |
| signature_def_map={"foo_key": foo_signature}) |
| |
| def _validate_sig_def_keys(self, builder, valid_tensor_info, invalid_key): |
| with self.session(graph=ops.Graph()) as sess: |
| self._init_and_validate_variable(sess, "v", 42) |
| |
| foo_signature = signature_def_utils.build_signature_def( |
| dict(), {"foo_key": valid_tensor_info}, "foo") |
| self.assertRaises( |
| KeyError, |
| builder.add_meta_graph_and_variables, |
| sess, ["foo"], |
| signature_def_map={invalid_key: foo_signature}) |
| |
| def testMaybeSavedModelDir(self): |
| base_path = test.test_src_dir_path("/python/saved_model") |
| self.assertFalse(loader.maybe_saved_model_directory(base_path)) |
| base_path = test.test_src_dir_path(SAVED_MODEL_PATH) |
| self.assertTrue(loader.maybe_saved_model_directory(base_path)) |
| base_path = "complete_garbage" |
| self.assertFalse(loader.maybe_saved_model_directory(base_path)) |
| |
| def testBadSavedModelFileFormat(self): |
| export_dir = self._get_export_dir("test_bad_saved_model_file_format") |
| # Attempt to load a SavedModel from an export directory that does not exist. |
| with self.session(graph=ops.Graph()) as sess: |
| with self.assertRaisesRegexp(IOError, |
| "SavedModel file does not exist at: %s" % |
| export_dir): |
| loader.load(sess, ["foo"], export_dir) |
| |
| os.makedirs(export_dir) |
| # Write an invalid binary proto to saved_model.pb. |
| path_to_pb = os.path.join(export_dir, constants.SAVED_MODEL_FILENAME_PB) |
| with open(path_to_pb, "w") as f: |
| f.write("invalid content") |
| with self.session(graph=ops.Graph()) as sess: |
| with self.assertRaisesRegexp(IOError, "Cannot parse file.*%s" % |
| constants.SAVED_MODEL_FILENAME_PB): |
| loader.load(sess, ["foo"], export_dir) |
| |
| # Cleanup the directory and start again. |
| file_io.delete_recursively(export_dir) |
| |
| os.makedirs(export_dir) |
| # Write an invalid text proto to saved_model.pbtxt |
| path_to_pbtxt = os.path.join(export_dir, |
| constants.SAVED_MODEL_FILENAME_PBTXT) |
| with open(path_to_pbtxt, "w") as f: |
| f.write("invalid content") |
| with self.session(graph=ops.Graph()) as sess: |
| with self.assertRaisesRegexp(IOError, "Cannot parse file.*%s" % |
| constants.SAVED_MODEL_FILENAME_PBTXT): |
| loader.load(sess, ["foo"], export_dir) |
| |
| @test_util.run_deprecated_v1 |
| def testVerifySessionGraphUsage(self): |
| export_dir = self._get_export_dir("test_verify_session_graph_usage") |
| builder = saved_model_builder._SavedModelBuilder(export_dir) |
| |
| with self.session(graph=ops.Graph()) as sess: |
| self._init_and_validate_variable(sess, "v", 42) |
| builder.add_meta_graph_and_variables(sess, [tag_constants.TRAINING]) |
| |
| # Save the SavedModel to disk. |
| builder.save() |
| |
| # Build a session and supply it to the load operation. |
| sess = session.Session(graph=ops.Graph()) |
| loader.load(sess, [tag_constants.TRAINING], export_dir) |
| |
| # Check the variable within the scope of the session and its graph. |
| with sess: |
| self.assertEqual( |
| 42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval()) |
| |
| @test_util.run_deprecated_v1 |
| def testSequence(self): |
| export_dir = self._get_export_dir("test_sequence") |
| builder = saved_model_builder._SavedModelBuilder(export_dir) |
| |
| # Expect an assertion error since add_meta_graph_and_variables() should be |
| # invoked before any add_meta_graph() calls. |
| with self.session(graph=ops.Graph()) as sess: |
| self.assertRaises(AssertionError, builder.add_meta_graph, ["foo"]) |
| |
| # Expect an assertion error for multiple calls of |
| # add_meta_graph_and_variables() since weights should be saved exactly once. |
| with self.session(graph=ops.Graph()) as sess: |
| self._init_and_validate_variable(sess, "v", 42) |
| builder.add_meta_graph_and_variables(sess, ["bar"]) |
| self.assertRaises(AssertionError, builder.add_meta_graph_and_variables, |
| sess, ["baz"]) |
| |
| @test_util.run_deprecated_v1 |
| def testTags(self): |
| export_dir = self._get_export_dir("test_tags") |
| builder = saved_model_builder._SavedModelBuilder(export_dir) |
| |
| # Graph with a single variable. SavedModel invoked to: |
| # - add with weights. |
| # - a single tag (from predefined constants). |
| with self.session(graph=ops.Graph()) as sess: |
| self._init_and_validate_variable(sess, "v", 42) |
| builder.add_meta_graph_and_variables(sess, [tag_constants.TRAINING]) |
| |
| # Graph that updates the single variable. SavedModel invoked to: |
| # - simply add the model (weights are not updated). |
| # - a single tag (from predefined constants). |
| with self.session(graph=ops.Graph()) as sess: |
| self._init_and_validate_variable(sess, "v", 43) |
| builder.add_meta_graph([tag_constants.SERVING]) |
| |
| # Graph that updates the single variable. SavedModel invoked to: |
| # - simply add the model (weights are not updated). |
| # - multiple tags (from predefined constants). |
| with self.session(graph=ops.Graph()) as sess: |
| self._init_and_validate_variable(sess, "v", 45) |
| builder.add_meta_graph([tag_constants.SERVING, tag_constants.GPU]) |
| |
| # Graph that updates the single variable. SavedModel invoked to: |
| # - simply add the model (weights are not updated). |
| # - multiple tags (from predefined constants for serving on TPU). |
| with self.session(graph=ops.Graph()) as sess: |
| self._init_and_validate_variable(sess, "v", 45) |
| builder.add_meta_graph([tag_constants.SERVING, tag_constants.TPU]) |
| |
| # Graph that updates the single variable. SavedModel is invoked: |
| # - to add the model (weights are not updated). |
| # - multiple custom tags. |
| with self.session(graph=ops.Graph()) as sess: |
| self._init_and_validate_variable(sess, "v", 44) |
| builder.add_meta_graph(["foo", "bar"]) |
| |
| # Save the SavedModel to disk. |
| builder.save() |
| |
| # Restore the graph with a single predefined tag whose variables were saved. |
| with self.session(graph=ops.Graph()) as sess: |
| loader.load(sess, [tag_constants.TRAINING], export_dir) |
| self.assertEqual( |
| 42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval()) |
| |
| # Restore the graph with a single predefined tag whose variables were not |
| # saved. |
| with self.session(graph=ops.Graph()) as sess: |
| loader.load(sess, [tag_constants.SERVING], export_dir) |
| self.assertEqual( |
| 42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval()) |
| |
| # Restore the graph with multiple predefined tags whose variables were not |
| # saved. |
| with self.session(graph=ops.Graph()) as sess: |
| loader.load(sess, [tag_constants.SERVING, tag_constants.GPU], export_dir) |
| self.assertEqual( |
| 42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval()) |
| |
| # Restore the graph with multiple predefined tags (for serving on TPU) |
| # whose variables were not saved. |
| with self.session(graph=ops.Graph()) as sess: |
| loader.load(sess, [tag_constants.SERVING, tag_constants.TPU], export_dir) |
| self.assertEqual( |
| 42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval()) |
| |
| # Restore the graph with multiple tags. Provide duplicate tags to test set |
| # semantics. |
| with self.session(graph=ops.Graph()) as sess: |
| loader.load(sess, ["foo", "bar", "foo"], export_dir) |
| self.assertEqual( |
| 42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval()) |
| |
| # Try restoring a graph with a non-existent tag. This should yield a runtime |
| # error. |
| with self.session(graph=ops.Graph()) as sess: |
| self.assertRaises(RuntimeError, loader.load, sess, ["INVALID"], |
| export_dir) |
| |
| # Try restoring a graph where a subset of the tags match. Since tag matching |
| # for meta graph defs follows "all" semantics, this should yield a runtime |
| # error. |
| with self.session(graph=ops.Graph()) as sess: |
| self.assertRaises(RuntimeError, loader.load, sess, ["foo", "baz"], |
| export_dir) |
| |
| @test_util.run_v1_only("b/120545219") |
| def testVariables(self): |
| export_dir = self._get_export_dir("test_variables") |
| builder = saved_model_builder._SavedModelBuilder(export_dir) |
| |
| # Graph with two variables. SavedModel invoked to: |
| # - add with weights. |
| with self.session(graph=ops.Graph()) as sess: |
| self._init_and_validate_variable(sess, "v1", 1) |
| self._init_and_validate_variable(sess, "v2", 2) |
| builder.add_meta_graph_and_variables(sess, ["foo"]) |
| |
| # Graph with a single variable (subset of the variables from the previous |
| # graph whose weights were saved). SavedModel invoked to: |
| # - simply add the model (weights are not updated). |
| with self.session(graph=ops.Graph()) as sess: |
| self._init_and_validate_variable(sess, "v2", 3) |
| builder.add_meta_graph(["bar"]) |
| |
| # Graph with a single variable (disjoint set of variables from the previous |
| # graph whose weights were saved). SavedModel invoked to: |
| # - simply add the model (weights are not updated). |
| with self.session(graph=ops.Graph()) as sess: |
| self._init_and_validate_variable(sess, "v3", 4) |
| builder.add_meta_graph(["baz"]) |
| |
| # Save the SavedModel to disk. |
| builder.save() |
| |
| # Restore the graph with tag "foo", whose variables were saved. |
| with self.session(graph=ops.Graph()) as sess: |
| loader.load(sess, ["foo"], export_dir) |
| collection_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) |
| self.assertEqual(len(collection_vars), 2) |
| self.assertEqual(1, collection_vars[0].eval()) |
| self.assertEqual(2, collection_vars[1].eval()) |
| |
| # Restore the graph with tag "bar", whose variables were not saved. Only the |
| # subset of the variables added to the graph will be restored with the |
| # checkpointed value. |
| with self.session(graph=ops.Graph()) as sess: |
| loader.load(sess, ["bar"], export_dir) |
| collection_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) |
| self.assertEqual(len(collection_vars), 1) |
| self.assertEqual(2, collection_vars[0].eval()) |
| |
| # Try restoring the graph with tag "baz", whose variables were not saved. |
| # Since this graph has a disjoint set of variables from the set that was |
| # saved, this should raise an error. |
| with self.session(graph=ops.Graph()) as sess: |
| self.assertRaises(errors.NotFoundError, loader.load, sess, ["baz"], |
| export_dir) |
| |
| @test_util.run_deprecated_v1 |
| def testGraphWithoutVariables(self): |
| export_dir = self._get_export_dir("test_graph_has_variables") |
| builder = saved_model_builder._SavedModelBuilder(export_dir) |
| |
| # Graph with no variables. |
| with self.session(graph=ops.Graph()) as sess: |
| constant_5_name = constant_op.constant(5.0).name |
| builder.add_meta_graph_and_variables(sess, ["foo"]) |
| |
| # Second graph with no variables |
| with self.session(graph=ops.Graph()) as sess: |
| constant_6_name = constant_op.constant(6.0).name |
| builder.add_meta_graph(["bar"]) |
| |
| # Save the SavedModel to disk. |
| builder.save() |
| |
| # Restore the graph with tag "foo". |
| with self.session(graph=ops.Graph()) as sess: |
| loader.load(sess, ["foo"], export_dir) |
| # Read the constant a from the graph. |
| a = ops.get_default_graph().get_tensor_by_name(constant_5_name) |
| b = constant_op.constant(6.0) |
| c = a * b |
| self.assertEqual(30.0, self.evaluate(c)) |
| |
| # Restore the graph with tag "bar". |
| with self.session(graph=ops.Graph()) as sess: |
| loader.load(sess, ["bar"], export_dir) |
| # Read the constant a from the graph. |
| a = ops.get_default_graph().get_tensor_by_name(constant_6_name) |
| b = constant_op.constant(5.0) |
| c = a * b |
| self.assertEqual(30.0, self.evaluate(c)) |
| |
| @test_util.run_deprecated_v1 |
| def testNoOverwrite(self): |
| export_dir = self._get_export_dir("test_no_overwrite") |
| builder = saved_model_builder._SavedModelBuilder(export_dir) |
| |
| # Graph with a single variable. SavedModel invoked to: |
| # - add with weights. |
| with self.session(graph=ops.Graph()) as sess: |
| self._init_and_validate_variable(sess, "v", 42) |
| builder.add_meta_graph_and_variables(sess, ["foo"]) |
| |
| # Save the SavedModel to disk in text format. |
| builder.save(as_text=True) |
| |
| # Restore the graph with tag "foo", whose variables were saved. |
| with self.session(graph=ops.Graph()) as sess: |
| loader.load(sess, ["foo"], export_dir) |
| self.assertEqual( |
| 42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval()) |
| |
| # An attempt to create another builder with the same export directory should |
| # result in an assertion error. |
| self.assertRaises(AssertionError, saved_model_builder._SavedModelBuilder, |
| export_dir) |
| |
| @test_util.run_deprecated_v1 |
| def testSaveAsText(self): |
| export_dir = self._get_export_dir("test_astext") |
| builder = saved_model_builder._SavedModelBuilder(export_dir) |
| |
| # Graph with a single variable. SavedModel invoked to: |
| # - add with weights. |
| with self.session(graph=ops.Graph()) as sess: |
| self._init_and_validate_variable(sess, "v", 42) |
| builder.add_meta_graph_and_variables(sess, ["foo"]) |
| |
| # Graph with the same single variable. SavedModel invoked to: |
| # - simply add the model (weights are not updated). |
| with self.session(graph=ops.Graph()) as sess: |
| self._init_and_validate_variable(sess, "v", 43) |
| builder.add_meta_graph(["bar"]) |
| |
| # Save the SavedModel to disk in text format. |
| builder.save(as_text=True) |
| |
| # Restore the graph with tag "foo", whose variables were saved. |
| with self.session(graph=ops.Graph()) as sess: |
| loader.load(sess, ["foo"], export_dir) |
| self.assertEqual( |
| 42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval()) |
| |
| # Restore the graph with tag "bar", whose variables were not saved. |
| with self.session(graph=ops.Graph()) as sess: |
| loader.load(sess, ["bar"], export_dir) |
| self.assertEqual( |
| 42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval()) |
| |
| @test_util.run_v1_only("b/120545219") |
| def testCollections(self): |
| export_dir = self._get_export_dir("test_collections") |
| builder = saved_model_builder._SavedModelBuilder(export_dir) |
| |
| # Graph with a single variable added to a collection. SavedModel invoked to: |
| # - add with weights. |
| with self.session(graph=ops.Graph()) as sess: |
| v = variables.VariableV1(42, name="v") |
| ops.add_to_collection("foo_vars", v) |
| self.evaluate(variables.global_variables_initializer()) |
| self.assertEqual(42, self.evaluate(v)) |
| builder.add_meta_graph_and_variables(sess, ["foo"]) |
| |
| # Graph with the same single variable added to a different collection. |
| # SavedModel invoked to: |
| # - simply add the model (weights are not updated). |
| with self.session(graph=ops.Graph()) as sess: |
| v = variables.VariableV1(43, name="v") |
| ops.add_to_collection("bar_vars", v) |
| self.evaluate(variables.global_variables_initializer()) |
| self.assertEqual(43, self.evaluate(v)) |
| builder.add_meta_graph(["bar"]) |
| |
| # Save the SavedModel to disk. |
| builder.save() |
| |
| # Restore the graph with tag "foo", whose variables were saved. The |
| # collection 'foo_vars' should contain a single element. The collection |
| # 'bar_vars' should not be found. |
| with self.session(graph=ops.Graph()) as sess: |
| loader.load(sess, ["foo"], export_dir) |
| collection_foo_vars = ops.get_collection("foo_vars") |
| self.assertEqual(len(collection_foo_vars), 1) |
| self.assertEqual(42, collection_foo_vars[0].eval()) |
| |
| self.assertEqual(len(ops.get_collection("bar_vars")), 0) |
| |
| # Restore the graph with tag "bar", whose variables were not saved. The |
| # collection-def exported as part of the meta graph def is updated to |
| # reflect the new collection. The value of the variable in the |
| # collection-def corresponds to the saved value (from the previous graph |
| # with tag "foo"). |
| with self.session(graph=ops.Graph()) as sess: |
| loader.load(sess, ["bar"], export_dir) |
| collection_bar_vars = ops.get_collection("bar_vars") |
| self.assertEqual(len(collection_bar_vars), 1) |
| self.assertEqual(42, collection_bar_vars[0].eval()) |
| |
| self.assertEqual(len(ops.get_collection("foo_vars")), 0) |
| |
| @test_util.run_deprecated_v1 |
| def testSignatureDefs(self): |
| export_dir = self._get_export_dir("test_signature_defs") |
| builder = saved_model_builder._SavedModelBuilder(export_dir) |
| |
| # Graph with a single variable and a single entry in the signature def map. |
| # SavedModel is invoked to add with weights. |
| with self.session(graph=ops.Graph()) as sess: |
| self._init_and_validate_variable(sess, "v", 42) |
| # Build and populate an empty SignatureDef for testing. |
| foo_signature = signature_def_utils.build_signature_def(dict(), |
| dict(), "foo") |
| builder.add_meta_graph_and_variables( |
| sess, ["foo"], signature_def_map={"foo_key": foo_signature}) |
| |
| # Graph with the same single variable and multiple entries in the signature |
| # def map. No weights are saved by SavedModel. |
| with self.session(graph=ops.Graph()) as sess: |
| self._init_and_validate_variable(sess, "v", 43) |
| # Build and populate a different SignatureDef for testing. |
| bar_signature = signature_def_utils.build_signature_def(dict(), |
| dict(), "bar") |
| # Also, build a different SignatureDef corresponding to "foo_key" defined |
| # in the previous graph. |
| foo_new_signature = signature_def_utils.build_signature_def(dict(), |
| dict(), |
| "foo_new") |
| builder.add_meta_graph( |
| ["bar"], |
| signature_def_map={ |
| "bar_key": bar_signature, |
| "foo_key": foo_new_signature |
| }) |
| |
| # Save the SavedModel to disk. |
| builder.save() |
| |
| # Restore the graph with tag "foo". The single entry in the SignatureDef map |
| # corresponding to "foo_key" should exist. |
| with self.session(graph=ops.Graph()) as sess: |
| foo_graph = loader.load(sess, ["foo"], export_dir) |
| self.assertEqual( |
| 42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval()) |
| |
| foo_signature = foo_graph.signature_def |
| self.assertEqual(len(foo_signature), 1) |
| self.assertEqual("foo", foo_signature["foo_key"].method_name) |
| |
| # Restore the graph with tag "bar". The SignatureDef map should have two |
| # entries. One corresponding to "bar_key" and another corresponding to the |
| # new value of "foo_key". |
| with self.session(graph=ops.Graph()) as sess: |
| bar_graph = loader.load(sess, ["bar"], export_dir) |
| self.assertEqual( |
| 42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval()) |
| |
| bar_signature = bar_graph.signature_def |
| self.assertEqual(len(bar_signature), 2) |
| self.assertEqual("bar", bar_signature["bar_key"].method_name) |
| self.assertEqual("foo_new", bar_signature["foo_key"].method_name) |
| |
| def testSignatureDefValidationFails(self): |
| export_dir = self._get_export_dir("test_signature_def_validation_fail") |
| builder = saved_model_builder._SavedModelBuilder(export_dir) |
| |
| tensor_without_encoding = meta_graph_pb2.TensorInfo() |
| tensor_without_encoding.dtype = types_pb2.DT_FLOAT |
| self._validate_inputs_tensor_info_fail(builder, tensor_without_encoding) |
| self._validate_outputs_tensor_info_fail(builder, tensor_without_encoding) |
| |
| tensor_without_dtype = meta_graph_pb2.TensorInfo() |
| tensor_without_dtype.name = "x" |
| self._validate_inputs_tensor_info_fail(builder, tensor_without_dtype) |
| self._validate_outputs_tensor_info_fail(builder, tensor_without_dtype) |
| |
| tensor_empty = meta_graph_pb2.TensorInfo() |
| self._validate_inputs_tensor_info_fail(builder, tensor_empty) |
| self._validate_outputs_tensor_info_fail(builder, tensor_empty) |
| |
| valid_tensor_info = meta_graph_pb2.TensorInfo() |
| valid_tensor_info.name = "foo" |
| valid_tensor_info.dtype = types_pb2.DT_FLOAT |
| |
| self._validate_sig_def_keys(builder, valid_tensor_info, |
| constants.INIT_OP_SIGNATURE_KEY) |
| self._validate_sig_def_keys(builder, valid_tensor_info, |
| constants.TRAIN_OP_SIGNATURE_KEY) |
| |
| @test_util.run_deprecated_v1 |
| def testSignatureDefValidationSucceedsWithName(self): |
| tensor_with_name = meta_graph_pb2.TensorInfo() |
| tensor_with_name.name = "foo" |
| tensor_with_name.dtype = types_pb2.DT_FLOAT |
| |
| export_dir = self._get_export_dir("test_signature_def_validation_name_1") |
| builder = saved_model_builder._SavedModelBuilder(export_dir) |
| self._validate_inputs_tensor_info_accept(builder, tensor_with_name) |
| |
| export_dir = self._get_export_dir("test_signature_def_validation_name_2") |
| builder = saved_model_builder._SavedModelBuilder(export_dir) |
| self._validate_outputs_tensor_info_accept(builder, tensor_with_name) |
| |
| @test_util.run_deprecated_v1 |
| def testSignatureDefValidationSucceedsWithCoo(self): |
| tensor_with_coo = meta_graph_pb2.TensorInfo() |
| # TODO(soergel) test validation of each of the fields of coo_sparse |
| tensor_with_coo.coo_sparse.values_tensor_name = "foo" |
| tensor_with_coo.dtype = types_pb2.DT_FLOAT |
| |
| export_dir = self._get_export_dir("test_signature_def_validation_coo_1") |
| builder = saved_model_builder._SavedModelBuilder(export_dir) |
| self._validate_inputs_tensor_info_accept(builder, tensor_with_coo) |
| |
| export_dir = self._get_export_dir("test_signature_def_validation_coo_2") |
| builder = saved_model_builder._SavedModelBuilder(export_dir) |
| self._validate_outputs_tensor_info_accept(builder, tensor_with_coo) |
| |
| @test_util.run_deprecated_v1 |
| def testSignatureDefValidationSucceedsWithRagged(self): |
| ragged_tensor = ragged_factory_ops.constant([[1, 2], [3]]) |
| tensor_with_ragged = utils.build_tensor_info(ragged_tensor) |
| |
| export_dir = self._get_export_dir("test_signature_def_validation_ragged_1") |
| builder = saved_model_builder._SavedModelBuilder(export_dir) |
| self._validate_inputs_tensor_info_accept(builder, tensor_with_ragged) |
| |
| export_dir = self._get_export_dir("test_signature_def_validation_ragged_2") |
| builder = saved_model_builder._SavedModelBuilder(export_dir) |
| self._validate_outputs_tensor_info_accept(builder, tensor_with_ragged) |
| |
| @test_util.run_deprecated_v1 |
| def testAssets(self): |
| export_dir = self._get_export_dir("test_assets") |
| builder = saved_model_builder._SavedModelBuilder(export_dir) |
| |
| with self.session(graph=ops.Graph()) as sess: |
| self._init_and_validate_variable(sess, "v", 42) |
| |
| # Build an asset collection. |
| ignored_filepath = os.path.join( |
| compat.as_bytes(test.get_temp_dir()), compat.as_bytes("ignored.txt")) |
| file_io.write_string_to_file(ignored_filepath, "will be ignored") |
| |
| asset_list = self._build_asset_collection("hello42.txt", "foo bar baz", |
| "asset_file_tensor") |
| |
| builder.add_meta_graph_and_variables( |
| sess, ["foo"], assets_list=asset_list) |
| |
| # Save the SavedModel to disk. |
| builder.save() |
| |
| with self.session(graph=ops.Graph()) as sess: |
| foo_graph = loader.load(sess, ["foo"], export_dir) |
| self._validate_assets(export_dir, foo_graph.asset_file_def, "hello42.txt", |
| "foo bar baz", "asset_file_tensor:0") |
| ignored_asset_path = os.path.join( |
| compat.as_bytes(export_dir), |
| compat.as_bytes(constants.ASSETS_DIRECTORY), |
| compat.as_bytes("ignored.txt")) |
| self.assertFalse(file_io.file_exists(ignored_asset_path)) |
| |
| @test_util.run_deprecated_v1 |
| def testAssetsNameCollisionDiffFile(self): |
| export_dir = self._get_export_dir("test_assets_name_collision_diff_file") |
| builder = saved_model_builder._SavedModelBuilder(export_dir) |
| |
| with self.session(graph=ops.Graph()) as sess: |
| self._init_and_validate_variable(sess, "v", 42) |
| |
| asset_list = self._build_asset_collection( |
| "hello42.txt", "foo bar bak", "asset_file_tensor", asset_subdir="1") |
| |
| asset_list = self._build_asset_collection( |
| "hello42.txt", "foo bar baz", "asset_file_tensor_1", asset_subdir="2") |
| |
| builder.add_meta_graph_and_variables( |
| sess, ["foo"], assets_list=asset_list) |
| |
| # Save the SavedModel to disk. |
| builder.save() |
| |
| with self.session(graph=ops.Graph()) as sess: |
| foo_graph = loader.load(sess, ["foo"], export_dir) |
| self._validate_assets(export_dir, foo_graph.asset_file_def, "hello42.txt", |
| "foo bar bak", "asset_file_tensor:0") |
| self._validate_assets( |
| export_dir, |
| foo_graph.asset_file_def, |
| "hello42.txt_1", |
| "foo bar baz", |
| "asset_file_tensor_1:0", |
| asset_id=1) |
| |
| @test_util.run_deprecated_v1 |
| def testAssetsNameCollisionSameFilepath(self): |
| export_dir = self._get_export_dir("test_assets_name_collision_same_path") |
| builder = saved_model_builder._SavedModelBuilder(export_dir) |
| |
| with self.session(graph=ops.Graph()) as sess: |
| self._init_and_validate_variable(sess, "v", 42) |
| |
| asset_list = self._build_asset_collection("hello42.txt", "foo bar baz", |
| "asset_file_tensor") |
| |
| asset_list = self._build_asset_collection("hello42.txt", "foo bar baz", |
| "asset_file_tensor_1") |
| |
| builder.add_meta_graph_and_variables( |
| sess, ["foo"], assets_list=asset_list) |
| |
| # Save the SavedModel to disk. |
| builder.save() |
| |
| with self.session(graph=ops.Graph()) as sess: |
| foo_graph = loader.load(sess, ["foo"], export_dir) |
| self._validate_assets(export_dir, foo_graph.asset_file_def, "hello42.txt", |
| "foo bar baz", "asset_file_tensor:0") |
| # The second tensor should be recorded, but the same. |
| self._validate_assets( |
| export_dir, |
| foo_graph.asset_file_def, |
| "hello42.txt", |
| "foo bar baz", |
| "asset_file_tensor_1:0", |
| asset_id=1) |
| ignored_asset_path = os.path.join( |
| compat.as_bytes(export_dir), |
| compat.as_bytes(constants.ASSETS_DIRECTORY), |
| compat.as_bytes("hello42.txt_1")) |
| self.assertFalse(file_io.file_exists(ignored_asset_path)) |
| |
| @test_util.run_deprecated_v1 |
| def testAssetsNameCollisionSameFile(self): |
| export_dir = self._get_export_dir("test_assets_name_collision_same_file") |
| builder = saved_model_builder._SavedModelBuilder(export_dir) |
| |
| with self.session(graph=ops.Graph()) as sess: |
| self._init_and_validate_variable(sess, "v", 42) |
| |
| asset_list = self._build_asset_collection( |
| "hello42.txt", "foo bar baz", "asset_file_tensor", asset_subdir="1") |
| |
| asset_list = self._build_asset_collection( |
| "hello42.txt", "foo bar baz", "asset_file_tensor_1", asset_subdir="2") |
| |
| builder.add_meta_graph_and_variables( |
| sess, ["foo"], assets_list=asset_list) |
| |
| # Save the SavedModel to disk. |
| builder.save() |
| |
| with self.session(graph=ops.Graph()) as sess: |
| foo_graph = loader.load(sess, ["foo"], export_dir) |
| self._validate_assets(export_dir, foo_graph.asset_file_def, "hello42.txt", |
| "foo bar baz", "asset_file_tensor:0") |
| # The second tensor should be recorded, but the same. |
| self._validate_assets( |
| export_dir, |
| foo_graph.asset_file_def, |
| "hello42.txt", |
| "foo bar baz", |
| "asset_file_tensor_1:0", |
| asset_id=1) |
| ignored_asset_path = os.path.join( |
| compat.as_bytes(export_dir), |
| compat.as_bytes(constants.ASSETS_DIRECTORY), |
| compat.as_bytes("hello42.txt_1")) |
| self.assertFalse(file_io.file_exists(ignored_asset_path)) |
| |
| @test_util.run_deprecated_v1 |
| def testAssetsNameCollisionManyFiles(self): |
| export_dir = self._get_export_dir("test_assets_name_collision_many_files") |
| builder = saved_model_builder._SavedModelBuilder(export_dir) |
| |
| with self.session(graph=ops.Graph()) as sess: |
| self._init_and_validate_variable(sess, "v", 42) |
| |
| for i in range(5): |
| idx = str(i) |
| asset_list = self._build_asset_collection( |
| "hello42.txt", |
| "foo bar baz " + idx, |
| "asset_file_tensor_" + idx, |
| asset_subdir=idx) |
| |
| builder.add_meta_graph_and_variables( |
| sess, ["foo"], assets_list=asset_list) |
| |
| # Save the SavedModel to disk. |
| builder.save() |
| |
| with self.session(graph=ops.Graph()) as sess: |
| foo_graph = loader.load(sess, ["foo"], export_dir) |
| for i in range(1, 5): |
| idx = str(i) |
| self._validate_assets( |
| export_dir, |
| foo_graph.asset_file_def, |
| "hello42.txt_" + idx, |
| "foo bar baz " + idx, |
| "asset_file_tensor_{}:0".format(idx), |
| asset_id=i) |
| |
| self._validate_assets(export_dir, foo_graph.asset_file_def, "hello42.txt", |
| "foo bar baz 0", "asset_file_tensor_0:0") |
| |
| @test_util.run_v1_only("b/120545219") |
| def testCustomInitOp(self): |
| export_dir = self._get_export_dir("test_main_op") |
| builder = saved_model_builder._SavedModelBuilder(export_dir) |
| |
| with self.session(graph=ops.Graph()) as sess: |
| # Add `v1` and `v2` variables to the graph. |
| v1 = variables.VariableV1(1, name="v1") |
| ops.add_to_collection("v", v1) |
| v2 = variables.VariableV1(2, name="v2") |
| ops.add_to_collection("v", v2) |
| |
| # Initialize another variable `v3` to 42. |
| v3 = variables.VariableV1(42, name="v3") |
| ops.add_to_collection("v", v3) |
| |
| # Set up an assignment op to be run as part of the main_op. |
| with ops.control_dependencies([main_op.main_op()]): |
| add_v1_v2 = math_ops.add(v1._ref(), v2._ref()) |
| custom_init_op = control_flow_ops.group(state_ops.assign(v3, add_v1_v2)) |
| |
| self.evaluate(custom_init_op) |
| builder.add_meta_graph_and_variables( |
| sess, ["foo"], init_op=custom_init_op) |
| |
| # Save the SavedModel to disk. |
| builder.save() |
| |
| with self.session(graph=ops.Graph()) as sess: |
| loader.load(sess, ["foo"], export_dir) |
| self.assertEqual(1, ops.get_collection("v")[0].eval()) |
| self.assertEqual(2, ops.get_collection("v")[1].eval()) |
| # Evaluates to the sum of the first two variables and assigned as part of |
| # the main_op, following a restore. |
| self.assertEqual(3, ops.get_collection("v")[2].eval()) |
| |
| @test_util.run_v1_only("b/120545219") |
| def testTrainOp(self): |
| export_dir = self._get_export_dir("test_train_op") |
| builder = saved_model_builder._SavedModelBuilder(export_dir) |
| |
| with self.session(graph=ops.Graph()) as sess: |
| # Add `v1` and `v2` variables to the graph. |
| v1 = variables.VariableV1(1, name="v1") |
| ops.add_to_collection("v", v1) |
| v2 = variables.VariableV1(2, name="v2") |
| ops.add_to_collection("v", v2) |
| |
| self.evaluate(variables.global_variables_initializer()) |
| train_op = state_ops.assign_add(v1, v2) |
| |
| self.evaluate(train_op) |
| builder.add_meta_graph_and_variables(sess, ["foo"], train_op=train_op) |
| |
| # Save the SavedModel to disk. |
| builder.save() |
| |
| with self.session(graph=ops.Graph()) as sess: |
| meta_graph_def = loader.load(sess, ["foo"], export_dir) |
| self.assertEqual(3, ops.get_collection("v")[0].eval()) |
| self.assertEqual(2, ops.get_collection("v")[1].eval()) |
| self.assertIsInstance( |
| loader_impl.get_train_op(meta_graph_def), ops.Tensor) |
| |
| @test_util.run_v1_only("b/120545219") |
| def testTrainOpGroup(self): |
| export_dir = self._get_export_dir("test_train_op_group") |
| builder = saved_model_builder._SavedModelBuilder(export_dir) |
| |
| with self.session(graph=ops.Graph()) as sess: |
| # Add `v1` and `v2` variables to the graph. |
| v1 = variables.VariableV1(1, name="v1") |
| ops.add_to_collection("v", v1) |
| v2 = variables.VariableV1(2, name="v2") |
| ops.add_to_collection("v", v2) |
| |
| self.evaluate(variables.global_variables_initializer()) |
| train_op = control_flow_ops.group() |
| |
| self.evaluate(train_op) |
| builder.add_meta_graph_and_variables(sess, ["foo"], train_op=train_op) |
| |
| # Save the SavedModel to disk. |
| builder.save() |
| |
| with self.session(graph=ops.Graph()) as sess: |
| meta_graph_def = loader.load(sess, ["foo"], export_dir) |
| self.assertEqual(1, ops.get_collection("v")[0].eval()) |
| self.assertEqual(2, ops.get_collection("v")[1].eval()) |
| self.assertIsInstance( |
| loader_impl.get_train_op(meta_graph_def), ops.Operation) |
| |
| @test_util.run_v1_only("b/120545219") |
| def testTrainOpAfterVariables(self): |
| export_dir = self._get_export_dir("test_train_op_after_variables") |
| builder = saved_model_builder._SavedModelBuilder(export_dir) |
| |
| with self.session(graph=ops.Graph()) as sess: |
| # Add `v1` and `v2` variables to the graph. |
| v1 = variables.VariableV1(1, name="v1") |
| ops.add_to_collection("v", v1) |
| v2 = variables.VariableV1(2, name="v2") |
| ops.add_to_collection("v", v2) |
| |
| self.evaluate(variables.global_variables_initializer()) |
| builder.add_meta_graph_and_variables(sess, ["pre_foo"]) |
| |
| train_op = state_ops.assign_add(v1, v2) |
| self.evaluate(train_op) |
| builder.add_meta_graph(["foo"], train_op=train_op) |
| |
| # Save the SavedModel to disk. |
| builder.save() |
| |
| with self.session(graph=ops.Graph()) as sess: |
| meta_graph_def = loader.load(sess, ["foo"], export_dir) |
| self.assertIsInstance( |
| loader_impl.get_train_op(meta_graph_def), ops.Tensor) |
| |
| with self.session(graph=ops.Graph()) as sess: |
| loader.load(sess, ["pre_foo"], export_dir) |
| self.assertFalse(ops.get_collection(constants.TRAIN_OP_KEY)) |
| |
| @test_util.run_deprecated_v1 |
| def testMultipleAssets(self): |
| export_dir = self._get_export_dir("test_multiple_assets") |
| builder = saved_model_builder._SavedModelBuilder(export_dir) |
| |
| with self.session(graph=ops.Graph()) as sess: |
| self._init_and_validate_variable(sess, "v", 42) |
| |
| # Build an asset collection specific to `foo` graph. |
| asset_list = self._build_asset_collection("foo.txt", "content_foo", |
| "asset_file_tensor") |
| |
| # Add the asset collection as part of the graph with tag "foo". |
| builder.add_meta_graph_and_variables( |
| sess, ["foo"], assets_list=asset_list) |
| |
| with self.session(graph=ops.Graph()) as sess: |
| self._init_and_validate_variable(sess, "v", 42) |
| |
| # Build an asset collection specific to `bar` graph. |
| asset_list = self._build_asset_collection("bar.txt", "content_bar", |
| "asset_file_tensor") |
| |
| # Add the asset collection as part of the graph with tag "bar". |
| builder.add_meta_graph(["bar"], assets_list=asset_list) |
| |
| # Save the SavedModel to disk. |
| builder.save() |
| |
| # Check assets restored for graph with tag "foo". |
| with self.session(graph=ops.Graph()) as sess: |
| foo_graph = loader.load(sess, ["foo"], export_dir) |
| self._validate_assets(export_dir, foo_graph.asset_file_def, "foo.txt", |
| "content_foo", "asset_file_tensor:0") |
| |
| # Check assets restored for graph with tag "bar". |
| with self.session(graph=ops.Graph()) as sess: |
| bar_graph = loader.load(sess, ["bar"], export_dir) |
| self._validate_assets(export_dir, bar_graph.asset_file_def, "bar.txt", |
| "content_bar", "asset_file_tensor:0") |
| |
| @test_util.run_deprecated_v1 |
| def testDuplicateAssets(self): |
| export_dir = self._get_export_dir("test_duplicate_assets") |
| builder = saved_model_builder._SavedModelBuilder(export_dir) |
| |
| with self.session(graph=ops.Graph()) as sess: |
| self._init_and_validate_variable(sess, "v", 42) |
| |
| # Build an asset collection with `foo.txt` that has `foo` specific |
| # content. |
| asset_list = self._build_asset_collection("foo.txt", "content_foo", |
| "asset_file_tensor") |
| |
| # Add the asset collection as part of the graph with tag "foo". |
| builder.add_meta_graph_and_variables( |
| sess, ["foo"], assets_list=asset_list) |
| |
| with self.session(graph=ops.Graph()) as sess: |
| self._init_and_validate_variable(sess, "v", 42) |
| |
| # Build an asset collection with `foo.txt` that has `bar` specific |
| # content. |
| asset_list = self._build_asset_collection("foo.txt", "content_bar", |
| "asset_file_tensor") |
| |
| # Add the asset collection as part of the graph with tag "bar". |
| builder.add_meta_graph(["bar"], assets_list=asset_list) |
| |
| # Save the SavedModel to disk. |
| builder.save() |
| |
| # Check assets restored for graph with tag "foo". |
| with self.session(graph=ops.Graph()) as sess: |
| foo_graph = loader.load(sess, ["foo"], export_dir) |
| self._validate_assets(export_dir, foo_graph.asset_file_def, "foo.txt", |
| "content_foo", "asset_file_tensor:0") |
| |
| # Check assets restored for graph with tag "bar". |
| with self.session(graph=ops.Graph()) as sess: |
| bar_graph = loader.load(sess, ["bar"], export_dir) |
| |
| # Validate the assets for `bar` graph. `foo.txt` should contain the |
| # original contents corresponding to `foo` graph since an asset with the |
| # same name across multiple graphs is only stored the first time |
| self._validate_assets(export_dir, bar_graph.asset_file_def, "foo.txt", |
| "content_foo", "asset_file_tensor:0") |
| |
| @test_util.run_v1_only("b/120545219") |
| def testOp(self): |
| export_dir = self._get_export_dir("test_op") |
| builder = saved_model_builder._SavedModelBuilder(export_dir) |
| |
| with session.Session( |
| graph=ops.Graph(), |
| config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess: |
| with sess.graph.device("/cpu:0"): |
| v1 = variables.VariableV1(1, name="v1") |
| with sess.graph.device("/cpu:1"): |
| v2 = variables.VariableV1(2, name="v2") |
| |
| # v3 is an unsaved variable derived from v1 and v2. It is used to |
| # exercise the ability to run an init op when restoring a graph. |
| v3 = variables.VariableV1(1, name="v3", trainable=False, collections=[]) |
| assign_v3 = state_ops.assign(v3, math_ops.add(v1, v2)) |
| init_op = control_flow_ops.group(assign_v3, name="init_op") |
| |
| ops.add_to_collection("v", v1) |
| ops.add_to_collection("v", v2) |
| ops.add_to_collection("v", v3) |
| ops.add_to_collection("init_op", init_op) |
| |
| self.evaluate(variables.global_variables_initializer()) |
| self.assertEqual(1, ops.get_collection("v")[0].eval()) |
| self.assertEqual(2, ops.get_collection("v")[1].eval()) |
| |
| builder.add_meta_graph_and_variables(sess, ["foo"]) |
| |
| # Save the SavedModel to disk. |
| builder.save() |
| |
| with session.Session( |
| graph=ops.Graph(), |
| config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess: |
| loader.load(sess, ["foo"], export_dir) |
| |
| # Validate variables, run the init op and verify result. |
| self.assertEqual(1, ops.get_collection("v")[0].eval()) |
| self.assertEqual(2, ops.get_collection("v")[1].eval()) |
| ops.get_collection("init_op")[0].run() |
| self.assertEqual(3, ops.get_collection("v")[2].eval()) |
| |
| def testCustomSaveable(self): |
| export_dir = self._get_export_dir("custom_saveable") |
| builder = saved_model_builder._SavedModelBuilder(export_dir) |
| |
| with session.Session( |
| graph=ops.Graph(), |
| config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess: |
| # CheckpointedOp is a key-value table that can be saved across sessions. |
| # The table register itself in SAVEABLE_OBJECTS collection. |
| v1 = saver_test_utils.CheckpointedOp(name="v1") |
| self.evaluate(variables.global_variables_initializer()) |
| v1.insert("k1", 3.0).run() |
| # Once the table is restored, we can access it through this reference. |
| ops.add_to_collection("table_ref", v1.table_ref) |
| builder.add_meta_graph_and_variables(sess, ["foo"]) |
| |
| # Save the SavedModel to disk. |
| builder.save() |
| |
| with session.Session( |
| graph=ops.Graph(), |
| config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess: |
| loader.load(sess, ["foo"], export_dir) |
| # Instantiate a wrapper object from the checkpointed reference. |
| v1 = saver_test_utils.CheckpointedOp( |
| name="v1", table_ref=ops.get_collection("table_ref")[0]) |
| self.assertEqual(b"k1", v1.keys().eval()) |
| self.assertEqual(3.0, v1.values().eval()) |
| |
| @test_util.run_deprecated_v1 |
| def testCustomSaver(self): |
| export_dir = self._get_export_dir("test_custom_saver") |
| builder = saved_model_builder._SavedModelBuilder(export_dir) |
| |
| with self.session(graph=ops.Graph()) as sess: |
| variables.VariableV1(1, name="v1") |
| self.evaluate(variables.global_variables_initializer()) |
| custom_saver = training.Saver(name="my_saver") |
| builder.add_meta_graph_and_variables(sess, ["tag"], saver=custom_saver) |
| |
| # Save the SavedModel to disk. |
| builder.save() |
| |
| with ops.Graph().as_default() as graph: |
| with self.session(graph=graph) as sess: |
| saved_graph = loader.load(sess, ["tag"], export_dir) |
| graph_ops = [x.name for x in graph.get_operations()] |
| self.assertTrue("my_saver/restore_all" in graph_ops) |
| self.assertFalse("save/restore_all" in graph_ops) |
| self.assertEqual( |
| saved_graph.saver_def.restore_op_name, "my_saver/restore_all") |
| |
| @test_util.run_deprecated_v1 |
| def testNoCustomSaver(self): |
| export_dir = self._get_export_dir("test_no_custom_saver") |
| builder = saved_model_builder._SavedModelBuilder(export_dir) |
| |
| with self.session(graph=ops.Graph()) as sess: |
| variables.VariableV1(1, name="v1") |
| self.evaluate(variables.global_variables_initializer()) |
| training.Saver(name="my_saver") |
| builder.add_meta_graph_and_variables(sess, ["tag"]) |
| |
| # Save the SavedModel to disk. |
| builder.save() |
| |
| with ops.Graph().as_default() as graph: |
| with self.session(graph=graph) as sess: |
| saved_graph = loader.load(sess, ["tag"], export_dir) |
| graph_ops = [x.name for x in graph.get_operations()] |
| self.assertTrue("my_saver/restore_all" in graph_ops) |
| self.assertTrue("save/restore_all" in graph_ops) |
| self.assertEqual( |
| saved_graph.saver_def.restore_op_name, "save/restore_all") |
| |
| @test_util.run_deprecated_v1 |
| def testMultipleCustomSavers(self): |
| export_dir = self._get_export_dir("test_multiple_custom_savers") |
| builder = saved_model_builder._SavedModelBuilder(export_dir) |
| |
| with self.session(graph=ops.Graph()) as sess: |
| variables.VariableV1(1, name="v1") |
| self.evaluate(variables.global_variables_initializer()) |
| builder.add_meta_graph_and_variables(sess, ["tag_0"]) |
| |
| saver_1 = training.Saver() |
| builder.add_meta_graph(["tag_1"], saver=saver_1) |
| |
| saver_2 = training.Saver() |
| builder.add_meta_graph(["tag_2"], saver=saver_2) |
| |
| # Save the SavedModel to disk. |
| builder.save() |
| |
| def _validate_custom_saver(tag_name, saver_name): |
| with ops.Graph().as_default() as graph: |
| with self.session(graph=graph) as sess: |
| saved_graph = loader.load(sess, [tag_name], export_dir) |
| self.assertEqual( |
| saved_graph.saver_def.restore_op_name, |
| saver_name) |
| |
| _validate_custom_saver("tag_0", "save/restore_all") |
| _validate_custom_saver("tag_1", "save_1/restore_all") |
| _validate_custom_saver("tag_2", "save_2/restore_all") |
| |
| @test_util.run_deprecated_v1 |
| def testImportScope(self): |
| export_dir = self._get_export_dir("test_scoped_assets") |
| builder = saved_model_builder._SavedModelBuilder(export_dir) |
| |
| # Build a SavedModel with a variable, an asset, and a constant tensor. |
| with self.session(graph=ops.Graph()) as sess: |
| self._init_and_validate_variable(sess, "v", 42) |
| asset_list = self._build_asset_collection("foo.txt", "content_foo", |
| "asset_file_tensor") |
| constant_op.constant("constant value", name="constant_tensor_name") |
| builder.add_meta_graph_and_variables( |
| sess, ["tag_name"], assets_list=asset_list) |
| |
| # Save the asset file path for later comparison. |
| asset_file_path = asset_list[0].eval() |
| |
| # Save the SavedModel to disk. |
| builder.save() |
| |
| with self.session(graph=ops.Graph()) as sess: |
| # Restore the SavedModel under an import_scope in a new graph/session. |
| graph_proto = loader.load( |
| sess, ["tag_name"], export_dir, import_scope="scope_name") |
| |
| # The loaded variable tensor should be scoped, but its contents should be |
| # unchanged. |
| self.assertEqual( |
| "scope_name/v:0", |
| ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].name) |
| self.assertEqual( |
| 42, |
| ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval()) |
| |
| # The loaded asset tensor should be scoped, but the asset file path and |
| # contents should be unchanged. |
| asset_list = ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS) |
| self.assertEqual(1, len(asset_list)) |
| self.assertEqual(asset_file_path, asset_list[0].eval()) |
| self.assertEqual("scope_name/asset_file_tensor:0", asset_list[0].name) |
| # The static asset data inside graph_proto.collection_def should not be |
| # scoped. |
| self._validate_assets(export_dir, graph_proto.asset_file_def, "foo.txt", |
| "content_foo", "asset_file_tensor:0") |
| |
| # The constant tensor should be scoped, but its contents should be |
| # unchanged. |
| self.assertEqual( |
| compat.as_bytes("constant value"), |
| ops.get_default_graph().get_tensor_by_name( |
| "scope_name/constant_tensor_name:0").eval()) |
| |
| @test_util.run_deprecated_v1 |
| def testClearDevices(self): |
| export_dir = self._get_export_dir("test_clear_devices") |
| builder = saved_model_builder._SavedModelBuilder(export_dir) |
| |
| # Specify a device and save a variable. |
| ops.reset_default_graph() |
| with session.Session( |
| target="", |
| config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess: |
| with sess.graph.device("/cpu:0"): |
| self._init_and_validate_variable(sess, "v", 42) |
| builder.add_meta_graph_and_variables( |
| sess, [tag_constants.TRAINING], clear_devices=True) |
| |
| # Save the SavedModel to disk. |
| builder.save() |
| |
| # Restore the graph with a single predefined tag whose variables were saved |
| # without any device information. |
| with self.session(graph=ops.Graph()) as sess: |
| loader.load(sess, [tag_constants.TRAINING], export_dir) |
| self.assertEqual( |
| 42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval()) |
| |
| # Tests the behavior of loading SavedModels that having missing attrs or attrs |
| # with incorrect types. |
| def testInconsistentConsumerDefaultAttrs(self): |
| export_dir = self._get_export_dir( |
| "test_strip_default_attrs_no_consumer_defaults") |
| builder = saved_model_builder._SavedModelBuilder(export_dir) |
| |
| # Add a graph with a single variable and a test op with a defaultless |
| # float32 attr, "test_attr". |
| with session.Session(graph=ops.Graph()) as sess: |
| variables.VariableV1(1.0, dtype=dtypes.float64, name="var") |
| test_ops.test_attr(T=dtypes.float32, name="test_attr") |
| self.evaluate(variables.global_variables_initializer()) |
| builder.add_meta_graph_and_variables(sess, ["foo"]) |
| |
| # Save the SavedModel to disk in text format. |
| builder.save(as_text=True) |
| |
| # Rewrite the SavedModel to remove the T attr from "test_attr". |
| saved_model_file = os.path.join( |
| export_dir, constants.SAVED_MODEL_FILENAME_PBTXT) |
| with open(saved_model_file) as f: |
| original_saved_model = f.read() |
| |
| no_attr_saved_model = original_saved_model.replace(""" |
| attr { |
| key: "T" |
| value { |
| type: DT_FLOAT |
| } |
| }""", "") |
| with open(saved_model_file, "w") as f: |
| f.write(no_attr_saved_model) |
| |
| # Loading the SavedModel via the loader must fail because the SavedModel |
| # does not have any attr values for the "TestAttr" node, and there is no |
| # default specified in the TestAttr OpDef. |
| sess = session.Session(graph=ops.Graph()) |
| with self.assertRaisesRegexp( |
| ValueError, "NodeDef missing attr 'T' from Op<name=TestAttr"): |
| loader.load(sess, ["foo"], export_dir) |
| |
| # Rewrite the SavedModel to change the type of the T attr in "test_attr" |
| bad_type_saved_model = original_saved_model.replace(""" |
| attr { |
| key: "T" |
| value { |
| type: DT_FLOAT |
| } |
| }""", """ |
| attr { |
| key: "T" |
| value { |
| type: DT_DOUBLE |
| } |
| }""") |
| with open(saved_model_file, "w") as f: |
| f.write(bad_type_saved_model) |
| |
| # Loading the SavedModel via the loader must fail because there is no |
| # OpKernel registered to handle T = double. |
| sess = session.Session(graph=ops.Graph()) |
| with self.assertRaisesRegexp( |
| errors.InvalidArgumentError, |
| "No OpKernel was registered to support Op 'TestAttr' used by node " |
| "test_attr \\(defined at .*\\) with these attrs: \\[.*\\]\n" |
| "Registered devices:.*\n" |
| "Registered kernels:.*" |
| ): |
| loader.load(sess, ["foo"], export_dir) |
| |
| |
class SavedModelV1Test(SavedModelTestBase):
  """Tests that exercise the public `saved_model_builder.SavedModelBuilder`."""

  def _validate_asset_collection(self,
                                 export_dir,
                                 graph_collection_def,
                                 expected_asset_file_name,
                                 expected_asset_file_contents,
                                 expected_asset_tensor_name,
                                 asset_id=0):
    """Checks one AssetFileDef stored in a MetaGraphDef's asset collection.

    Args:
      export_dir: Directory the SavedModel was written to.
      graph_collection_def: The loaded MetaGraphDef's `collection_def` map.
      expected_asset_file_name: Basename expected under the assets directory
        and in the AssetFileDef's `filename`.
      expected_asset_file_contents: Expected text contents of that file.
      expected_asset_tensor_name: Expected `tensor_info.name` of the entry.
      asset_id: Index of the entry within the collection's `any_list`.
    """
    assets_any = graph_collection_def[constants.ASSETS_KEY].any_list.value
    asset = meta_graph_pb2.AssetFileDef()
    assets_any[asset_id].Unpack(asset)
    assets_path = os.path.join(
        compat.as_bytes(export_dir),
        compat.as_bytes(constants.ASSETS_DIRECTORY),
        compat.as_bytes(expected_asset_file_name))
    actual_asset_contents = file_io.read_file_to_string(assets_path)
    self.assertEqual(expected_asset_file_contents,
                     compat.as_text(actual_asset_contents))
    self.assertEqual(expected_asset_file_name, asset.filename)
    self.assertEqual(expected_asset_tensor_name, asset.tensor_info.name)

  @test_util.run_deprecated_v1
  def testWritingAssetsToCollection(self):
    """Assets passed via `assets_collection` are exported; others are not."""
    export_dir = self._get_export_dir("test_writing_assets_to_collection")
    builder = saved_model_builder.SavedModelBuilder(export_dir)

    with self.session(graph=ops.Graph()) as sess:
      self._init_and_validate_variable(sess, "v", 42)

      # Write a file that is never added to the asset collection.
      ignored_filepath = os.path.join(
          compat.as_bytes(test.get_temp_dir()), compat.as_bytes("ignored.txt"))
      file_io.write_string_to_file(ignored_filepath, "will be ignored")

      asset_collection = self._build_asset_collection(
          "hello42.txt", "foo bar baz", "asset_file_tensor")

      builder.add_meta_graph_and_variables(
          sess, ["foo"], assets_collection=asset_collection)

    # Save the SavedModel to disk.
    builder.save()

    with self.session(graph=ops.Graph()) as sess:
      foo_graph = loader.load(sess, ["foo"], export_dir)
      self._validate_asset_collection(export_dir, foo_graph.collection_def,
                                      "hello42.txt", "foo bar baz",
                                      "asset_file_tensor:0")
      # The file outside the collection must not be copied into the export.
      ignored_asset_path = os.path.join(
          compat.as_bytes(export_dir),
          compat.as_bytes(constants.ASSETS_DIRECTORY),
          compat.as_bytes("ignored.txt"))
      self.assertFalse(file_io.file_exists(ignored_asset_path))

  @test_util.run_deprecated_v1
  def testLegacyInitOpWithNonEmptyCollection(self):
    export_dir = self._get_export_dir(
        "test_legacy_init_op_with_non_empty_collection")
    self._testInitOpsWithNonEmptyCollection(export_dir,
                                            constants.LEGACY_INIT_OP_KEY)

  @test_util.run_deprecated_v1
  def testMainOpWithNonEmptyCollection(self):
    export_dir = self._get_export_dir("test_main_op_with_non_empty_collection")
    self._testInitOpsWithNonEmptyCollection(export_dir, constants.MAIN_OP_KEY)

  def _testInitOpsWithNonEmptyCollection(self, export_dir, key):
    """Checks that init ops are rejected when collection `key` is populated.

    Args:
      export_dir: Directory handed to the SavedModelBuilder.
      key: Graph collection key that is pre-populated with a no-op before
        `legacy_init_op` / `main_op` are passed to the builder.
    """
    builder = saved_model_builder.SavedModelBuilder(export_dir)

    g = ops.Graph()
    with self.session(graph=g) as sess:
      # Initialize variable `v1` to 1.
      v1 = variables.VariableV1(1, name="v1")
      ops.add_to_collection("v", v1)

      # Initialize another variable `v2` to 42.
      v2 = variables.VariableV1(42, name="v2", trainable=False, collections=[])
      ops.add_to_collection("v", v2)

      # Set up an assignment op to be run as part of the init op.
      assign_v2 = state_ops.assign(v2, v1)
      init_op = control_flow_ops.group(assign_v2, name="init_op")

      self.evaluate(variables.global_variables_initializer())

      ops.add_to_collection(key, control_flow_ops.no_op())
      # ValueError should be raised since the `key` collection is already
      # non-empty and multiple init ops are not supported.
      with self.assertRaisesRegexp(ValueError, "Graph already contains"):
        builder.add_meta_graph_and_variables(
            sess, ["foo"], legacy_init_op=init_op)
      # Passing the op as `main_op` must be rejected as well.
      with self.assertRaisesRegexp(ValueError, "Graph already contains"):
        builder.add_meta_graph_and_variables(sess, ["foo"], main_op=init_op)

  def testStripDefaultAttrs(self):
    """strip_default_attrs drops default-valued attrs from the saved graph."""
    export_dir = self._get_export_dir("test_strip_default_attrs")
    builder = saved_model_builder.SavedModelBuilder(export_dir)

    # Add a graph with two float32 variables and a Complex Op composing them
    # with strip_default_attrs enabled.
    with session.Session(graph=ops.Graph()) as sess:
      real_num = variables.VariableV1(1.0, dtype=dtypes.float32, name="real")
      imag_num = variables.VariableV1(2.0, dtype=dtypes.float32, name="imag")
      math_ops.complex(real_num, imag_num, name="complex")
      self.evaluate(variables.global_variables_initializer())
      builder.add_meta_graph_and_variables(
          sess, ["foo"], strip_default_attrs=True)

    # Add a graph with the same float32 variables and a Complex Op composing
    # them with strip_default_attrs disabled.
    with session.Session(graph=ops.Graph()) as sess:
      real_num = variables.VariableV1(1.0, dtype=dtypes.float32, name="real")
      imag_num = variables.VariableV1(2.0, dtype=dtypes.float32, name="imag")
      math_ops.complex(real_num, imag_num, name="complex")
      self.evaluate(variables.global_variables_initializer())
      builder.add_meta_graph(["bar"], strip_default_attrs=False)

    # Save the SavedModel to disk in text format.
    builder.save(as_text=True)

    # Loading graph "foo" via the loader must restore the defaults for the
    # "Complex" node based on the "Complex" OpDef in the Op registry. The
    # session is closed once loading is done so it does not leak.
    with session.Session(graph=ops.Graph()) as sess:
      meta_graph_def = loader.load(sess, ["foo"], export_dir)
    complex_node = test_util.get_node_def_from_graph("complex",
                                                     meta_graph_def.graph_def)
    self.assertIn("T", complex_node.attr)
    self.assertIn("Tout", complex_node.attr)

    # Load graph "foo" from disk as-is to verify default attrs are stripped.
    saved_model_pb = loader_impl.parse_saved_model(export_dir)
    self.assertIsNotNone(saved_model_pb)

    meta_graph_foo_def = None
    meta_graph_bar_def = None
    for meta_graph_def in saved_model_pb.meta_graphs:
      if set(meta_graph_def.meta_info_def.tags) == {"foo"}:
        meta_graph_foo_def = meta_graph_def
      elif set(meta_graph_def.meta_info_def.tags) == {"bar"}:
        meta_graph_bar_def = meta_graph_def

    self.assertIsNotNone(meta_graph_foo_def)
    self.assertIsNotNone(meta_graph_bar_def)

    # "Complex" Op has 2 attributes with defaults:
    #   o "T"    : float32.   (input type)
    #   o "Tout" : complex64. (output type)

    # "Complex" Op in graph "foo" shouldn't have attributes "T" and "Tout".
    # Graph "foo" was saved with strip_default_attrs set to True.
    node_def = test_util.get_node_def_from_graph("complex",
                                                 meta_graph_foo_def.graph_def)
    self.assertNotIn("T", node_def.attr)
    self.assertNotIn("Tout", node_def.attr)

    # "Complex" Op in graph "bar" must have attributes "T" and "Tout".
    # Graph "bar" was saved with strip_default_attrs set to False.
    node_def = test_util.get_node_def_from_graph("complex",
                                                 meta_graph_bar_def.graph_def)
    self.assertIn("T", node_def.attr)
    self.assertIn("Tout", node_def.attr)

  @test_util.run_v1_only("b/120545219")
  def testLegacyInitOp(self):
    """A legacy_init_op runs after restore and fills an unsaved variable."""
    export_dir = self._get_export_dir("test_legacy_init_op")
    builder = saved_model_builder.SavedModelBuilder(export_dir)

    with self.session(graph=ops.Graph()) as sess:
      # Add `v1` and `v2` variables to the graph.
      v1 = variables.VariableV1(1, name="v1")
      ops.add_to_collection("v", v1)
      v2 = variables.VariableV1(2, name="v2")
      ops.add_to_collection("v", v2)

      # Initialize another variable `v3` to 42.
      v3 = variables.VariableV1(42, name="v3", trainable=False, collections=[])
      ops.add_to_collection("v", v3)

      # Set up an assignment op to be run as part of the init_op.
      assign_v3 = state_ops.assign(v3, math_ops.add(v1, v2))
      legacy_init_op = control_flow_ops.group(assign_v3, name="legacy_init_op")

      self.evaluate(variables.global_variables_initializer())
      builder.add_meta_graph_and_variables(
          sess, ["foo"], legacy_init_op=legacy_init_op)

    # Save the SavedModel to disk.
    builder.save()

    with self.session(graph=ops.Graph()) as sess:
      loader.load(sess, ["foo"], export_dir)
      self.assertEqual(1, ops.get_collection("v")[0].eval())
      self.assertEqual(2, ops.get_collection("v")[1].eval())
      # Evaluates to the sum of the first two variables and assigned as part of
      # the legacy_init_op, following a restore.
      self.assertEqual(3, ops.get_collection("v")[2].eval())
| |
| |
if __name__ == "__main__":
  # Run this module's test cases through the TensorFlow test runner.
  test.main()