# blob: 232a5b5e1ba43643875a949ebffce08581096359 [file] [log] [blame]
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""SavedModel integration tests."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from absl.testing import parameterized
import tensorflow.compat.v2 as tf
from tensorflow.examples.saved_model.integration_tests import distribution_strategy_utils as ds_utils
from tensorflow.examples.saved_model.integration_tests import integration_scripts as scripts
from tensorflow.python.distribute import combinations
class SavedModelTest(scripts.TestCase, parameterized.TestCase):
  """Integration tests that export SavedModels and then consume them.

  Every test runs named commands through `assertCommandSucceeded` and stores
  its artifacts under `get_temp_dir()`; neither helper is defined in this
  class, so both are expected to come from a base class.
  """

  def __init__(self, method_name="runTest", has_extra_deps=False):
    """Creates the test case.

    Args:
      method_name: Name of the test method to run; forwarded unchanged to the
        base class constructor.
      has_extra_deps: Whether the extra dependencies are available. Tests that
        call `skipIfMissingExtraDeps` are skipped when this is false.
    """
    super(SavedModelTest, self).__init__(method_name)
    self.has_extra_deps = has_extra_deps

  def skipIfMissingExtraDeps(self):
    """Skips the current test unless the extra dependencies are present.

    b/132234211: Not every environment that runs these tests provides the
    extra dependencies; for instance, "tensorflow_hub" cannot be used from
    tests that live within "tensorflow" alone. A separate internal test target
    runs the affected tests instead.
    """
    if self.has_extra_deps:
      return
    self.skipTest("Missing extra dependencies")

  def test_text_rnn(self):
    # Export the text RNN model, then consume it from the same directory.
    model_path = self.get_temp_dir()
    self.assertCommandSucceeded(
        "export_text_rnn_model", export_dir=model_path)
    self.assertCommandSucceeded("use_text_rnn_model", model_dir=model_path)

  def test_rnn_cell(self):
    # Export the RNN cell, then consume it from the same directory.
    model_path = self.get_temp_dir()
    self.assertCommandSucceeded("export_rnn_cell", export_dir=model_path)
    self.assertCommandSucceeded("use_rnn_cell", model_dir=model_path)

  def test_text_embedding_in_sequential_keras(self):
    # Only runs when the extra dependencies are available.
    self.skipIfMissingExtraDeps()
    model_path = self.get_temp_dir()
    self.assertCommandSucceeded("export_simple_text_embedding",
                                export_dir=model_path)
    self.assertCommandSucceeded("use_model_in_sequential_keras",
                                model_dir=model_path)

  def test_text_embedding_in_dataset(self):
    # Same export step as the sequential Keras test, different consumer.
    model_path = self.get_temp_dir()
    self.assertCommandSucceeded("export_simple_text_embedding",
                                export_dir=model_path)
    self.assertCommandSucceeded("use_text_embedding_in_dataset",
                                model_dir=model_path)

  # Keyword arguments for `combinations.generate` on `test_mnist_cnn`.
  TEST_MNIST_CNN_GENERATE_KWARGS = {
      "combinations": (
          # Exhaustive sweep using tf.keras.models.save_model() for both the
          # reusable and the final full model: every named strategy crossed
          # with every retrain flag and loss multiplier.
          combinations.combine(
              use_keras_save_api=True,
              named_strategy=list(ds_utils.named_strategies.values()),
              retrain_flag_value=["true", "false"],
              # The non-default multiplier is the test for b/134528831.
              regularization_loss_multiplier=[None, 2],
          )
          # A few critical combinations with raw tf.saved_model.save(),
          # including export of a reusable SavedModel that gets assembled
          # manually, including support for adjustable hparams.
          + combinations.combine(
              use_keras_save_api=False,
              named_strategy=None,
              retrain_flag_value=["true", "false"],
              # The non-default multiplier is the test for b/134528831.
              regularization_loss_multiplier=[None, 2],
          )
      ),
      "test_combinations": [combinations.NamedGPUCombination()],
  }

  @combinations.generate(**TEST_MNIST_CNN_GENERATE_KWARGS)
  def test_mnist_cnn(self, use_keras_save_api, named_strategy,
                     retrain_flag_value, regularization_loss_multiplier):
    # Three commands in sequence: export the MNIST feature extractor, use it
    # to produce a full model, then deploy that full model.
    self.skipIfMissingExtraDeps()

    fast_mode = True
    workspace = self.get_temp_dir()
    feature_extractor_dir = os.path.join(workspace, "mnist_feature_extractor")
    full_model_dir = os.path.join(workspace, "full_model")

    self.assertCommandSucceeded(
        "export_mnist_cnn",
        fast_test_mode=fast_mode,
        export_dir=feature_extractor_dir,
        use_keras_save_api=use_keras_save_api)

    retrain_flags = {
        "fast_test_mode": fast_mode,
        "input_saved_model_dir": feature_extractor_dir,
        "retrain": retrain_flag_value,
        "output_saved_model_dir": full_model_dir,
        "use_keras_save_api": use_keras_save_api,
    }
    # Optional flags are passed only when the combination supplies a value.
    if named_strategy:
      retrain_flags["strategy"] = str(named_strategy)
    if regularization_loss_multiplier is not None:
      retrain_flags["regularization_loss_multiplier"] = (
          regularization_loss_multiplier)
    self.assertCommandSucceeded("use_mnist_cnn", **retrain_flags)

    self.assertCommandSucceeded(
        "deploy_mnist_cnn",
        fast_test_mode=fast_mode,
        saved_model_dir=full_model_dir)
def _main():
  """Entry point used when this file is executed directly."""
  # Must come first. NOTE(review): judging by its name, this hands control to
  # a helper script instead of the tests when one is requested -- confirm in
  # integration_scripts.
  scripts.MaybeRunScriptInstead()
  tf.test.main()


if __name__ == "__main__":
  _main()