# blob: 161bf7c6a0dea612593bb1c1e22e43db1df89c49 [file] [log] [blame]
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Generate a series of TensorFlow graphs that become tflite test cases.

Usage:

  generate_examples <output directory> --zip_to_output <zip name> --toco <path to toco>

  bazel run //tensorflow/lite/testing:generate_examples

To more easily debug failures, use (or override) the --save_graphdefs flag to
place text proto graphdefs into the generated zip files.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import argparse
import os
import sys
from tensorflow.lite.testing import generate_examples_lib
from tensorflow.lite.testing import toco_convert
# TODO(aselle): Disable GPU for now
# Setting CUDA_VISIBLE_DEVICES to "-1" hides all CUDA devices from this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

# Command-line interface. The parsed namespace is stored in the module-level
# FLAGS global by the __main__ guard and read back inside main().
parser = argparse.ArgumentParser(description="Script to generate TFLite tests.")
parser.add_argument(
    "output_path",
    help="Directory where the outputs will go.")
parser.add_argument(
    "--zip_to_output",
    type=str,
    help="Particular zip to output.",
    required=True)
parser.add_argument(
    "--toco",
    type=str,
    help="Path to toco tool.",
    required=True)
parser.add_argument(
    "--known_bugs_are_errors",
    action="store_true",
    help=("If a particular model is affected by a known bug,"
          " count it as a converter error."))
# The flag name says "ignore"; the help text used to describe the opposite.
parser.add_argument(
    "--ignore_converter_errors",
    action="store_true",
    help=("Ignore converter errors instead of raising an exception when one"
          " is encountered."))
parser.add_argument(
    "--save_graphdefs",
    action="store_true",
    help="Include intermediate graphdefs in the output zip files.")
parser.add_argument(
    "--run_with_flex",
    action="store_true",
    help="Whether the TFLite Flex converter is being used.")
parser.add_argument(
    "--make_edgetpu_tests",
    action="store_true",
    help="Whether to generate test cases for edgetpu.")
parser.add_argument(
    "--make_forward_compat_test",
    action="store_true",
    help="Make tests by setting TF forward compatibility horizon to the future")
parser.add_argument(
    "--no_tests_limit",
    action="store_true",
    help="Remove the limit of the number of tests.")
parser.add_argument(
    "--no_conversion_report",
    action="store_true",
    help="Do not create conversion report.")
parser.add_argument(
    "--test_sets",
    type=str,
    help=("Comma-separated list of test set names to generate. "
          "If not specified, a test set is selected by parsing the name of "
          "'zip_to_output' file."))

# Toco binary path provided by the generate rule.
# NOTE(review): not referenced anywhere in the visible part of this file;
# kept because other modules may read it.
bin_path = None
def main(unused_args):
  """Build generator options from the parsed flags and run example generation.

  Every setting is copied from the module-level FLAGS namespace onto a fresh
  `generate_examples_lib.Options` object, and `toco_convert.toco_convert` is
  installed as the conversion function. When --test_sets was supplied, its
  comma-separated names are handed to `generate_multi_set_examples`;
  otherwise `generate_examples` is called.

  Args:
    unused_args: Arguments forwarded by the app runner; not used.
  """
  options = generate_examples_lib.Options()

  # (attribute name, value) pairs, applied in order.
  settings = (
      ("output_path", FLAGS.output_path),
      ("zip_to_output", FLAGS.zip_to_output),
      ("toco", FLAGS.toco),
      ("known_bugs_are_errors", FLAGS.known_bugs_are_errors),
      ("ignore_converter_errors", FLAGS.ignore_converter_errors),
      ("save_graphdefs", FLAGS.save_graphdefs),
      ("run_with_flex", FLAGS.run_with_flex),
      ("make_edgetpu_tests", FLAGS.make_edgetpu_tests),
      ("make_forward_compat_test", FLAGS.make_forward_compat_test),
      ("tflite_convert_function", toco_convert.toco_convert),
      ("no_tests_limit", FLAGS.no_tests_limit),
      ("no_conversion_report", FLAGS.no_conversion_report),
  )
  for attr_name, attr_value in settings:
    setattr(options, attr_name, attr_value)

  requested_sets = FLAGS.test_sets
  if not requested_sets:
    generate_examples_lib.generate_examples(options)
    return

  generate_examples_lib.generate_multi_set_examples(
      options, requested_sets.split(","))
if __name__ == "__main__":
  # FLAGS is deliberately a module-level global: main() reads it directly.
  FLAGS, unparsed = parser.parse_known_args()

  if unparsed:
    # Unknown arguments are a usage error: report on stderr and exit non-zero.
    # sys.exit is used instead of the `exit` helper, which is only injected by
    # the `site` module and is not guaranteed to exist in every interpreter.
    parser.print_usage(sys.stderr)
    print("\nGot the following unparsed args, %r please fix.\n" % unparsed,
          file=sys.stderr)
    sys.exit(1)
  else:
    # `unparsed` is empty on this path, so only the program name is forwarded.
    tf.compat.v1.app.run(main=main, argv=[sys.argv[0]] + unparsed)