| # Lint as: python2, python3 |
| # Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| """Python command line interface for converting TF models to TFLite models.""" |
| |
| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| |
| import argparse |
| import os |
| import sys |
| import warnings |
| |
| import six |
| from six.moves import zip |
| |
| from tensorflow.lite.python import lite |
| from tensorflow.lite.python import lite_constants |
| from tensorflow.lite.toco import toco_flags_pb2 as _toco_flags_pb2 |
| from tensorflow.lite.toco.logging import gen_html |
| from tensorflow.python import keras |
| from tensorflow.python import tf2 |
| from tensorflow.python.platform import app |
| |
| |
| def _parse_array(values, type_fn=str): |
| if values is not None: |
| return [type_fn(val) for val in six.ensure_str(values).split(",") if val] |
| return None |
| |
| |
| def _parse_set(values): |
| if values is not None: |
| return set([item for item in six.ensure_str(values).split(",") if item]) |
| return None |
| |
| |
| def _parse_inference_type(value, flag): |
| """Converts the inference type to the value of the constant. |
| |
| Args: |
| value: str representing the inference type. |
| flag: str representing the flag name. |
| |
| Returns: |
| tf.dtype. |
| |
| Raises: |
| ValueError: Unsupported value. |
| """ |
| if value == "FLOAT": |
| return lite_constants.FLOAT |
| if value == "QUANTIZED_UINT8": |
| return lite_constants.QUANTIZED_UINT8 |
| if value == "INT8": |
| return lite_constants.INT8 |
| raise ValueError("Unsupported value for --{0}. Only FLOAT and " |
| "QUANTIZED_UINT8 are supported.".format(flag)) |
| |
| |
def _get_tflite_converter(flags):
  """Builds a TFLiteConverter from the model source named in the flags.

  The first of --graph_def_file, --saved_model_dir and --keras_model_file that
  is set decides which `lite.TFLiteConverter` factory is called.

  Args:
    flags: argparse.Namespace object containing TFLite flags.

  Returns:
    TFLiteConverter object.

  Raises:
    ValueError: Invalid flags.
  """
  # Parse input and output arrays.
  input_arrays = _parse_array(flags.input_arrays)
  input_shapes = None
  if flags.input_shapes:
    # Shapes are colon-separated; each shape is a comma-separated int list.
    shape_specs = six.ensure_str(flags.input_shapes).split(":")
    parsed_shapes = [_parse_array(spec, type_fn=int) for spec in shape_specs]
    # Shapes are paired with input array names by position.
    input_shapes = dict(zip(input_arrays, parsed_shapes))
  output_arrays = _parse_array(flags.output_arrays)

  converter_kwargs = dict(
      input_arrays=input_arrays,
      input_shapes=input_shapes,
      output_arrays=output_arrays)

  # Create TFLiteConverter.
  if flags.graph_def_file:
    converter_kwargs["graph_def_file"] = flags.graph_def_file
    return lite.TFLiteConverter.from_frozen_graph(**converter_kwargs)

  if flags.saved_model_dir:
    converter_kwargs["saved_model_dir"] = flags.saved_model_dir
    converter_kwargs["tag_set"] = _parse_set(flags.saved_model_tag_set)
    converter_kwargs["signature_key"] = flags.saved_model_signature_key
    return lite.TFLiteConverter.from_saved_model(**converter_kwargs)

  if flags.keras_model_file:
    converter_kwargs["model_file"] = flags.keras_model_file
    return lite.TFLiteConverter.from_keras_model_file(**converter_kwargs)

  raise ValueError("--graph_def_file, --saved_model_dir, or "
                   "--keras_model_file must be specified.")
| |
| |
def _convert_tf1_model(flags):
  """Calls function to convert the TensorFlow 1.X model into a TFLite model.

  Builds a converter from the flags, copies every flag that was set onto the
  converter, runs the conversion and writes the result to
  `flags.output_file`.

  Args:
    flags: argparse.Namespace object.

  Raises:
    ValueError: Invalid flags.
  """
  # Create converter.
  converter = _get_tflite_converter(flags)
  if flags.inference_type:
    converter.inference_type = _parse_inference_type(flags.inference_type,
                                                     "inference_type")
  if flags.inference_input_type:
    converter.inference_input_type = _parse_inference_type(
        flags.inference_input_type, "inference_input_type")
  if flags.output_format:
    converter.output_format = _toco_flags_pb2.FileFormat.Value(
        flags.output_format)

  if flags.mean_values and flags.std_dev_values:
    input_arrays = converter.get_input_arrays()
    std_dev_values = _parse_array(flags.std_dev_values, type_fn=float)

    # In quantized inference, mean_value has to be integer so that the real
    # value 0.0 is exactly representable.
    if converter.inference_type == lite_constants.QUANTIZED_UINT8:
      mean_values = _parse_array(flags.mean_values, type_fn=int)
    else:
      mean_values = _parse_array(flags.mean_values, type_fn=float)
    quant_stats = list(zip(mean_values, std_dev_values))
    # Stats are matched to input arrays by position, so the counts must agree
    # and, with several inputs, the names must have been given explicitly.
    if ((not flags.input_arrays and len(input_arrays) > 1) or
        (len(input_arrays) != len(quant_stats))):
      raise ValueError("Mismatching --input_arrays, --std_dev_values, and "
                       "--mean_values. The flags must have the same number of "
                       "items. The current input arrays are '{0}'. "
                       "--input_arrays must be present when specifying "
                       "--std_dev_values and --mean_values with multiple input "
                       "tensors in order to map between names and "
                       "values.".format(",".join(input_arrays)))
    converter.quantized_input_stats = dict(list(zip(input_arrays, quant_stats)))
  if (flags.default_ranges_min is not None) and (flags.default_ranges_max is
                                                 not None):
    converter.default_ranges_stats = (flags.default_ranges_min,
                                      flags.default_ranges_max)

  # Boolean flags are only copied when set, so an absent flag leaves the
  # converter's own default untouched.
  if flags.drop_control_dependency:
    converter.drop_control_dependency = flags.drop_control_dependency
  if flags.reorder_across_fake_quant:
    converter.reorder_across_fake_quant = flags.reorder_across_fake_quant
  if flags.change_concat_input_ranges:
    # The flag is parsed as the upper-cased string "TRUE" or "FALSE".
    converter.change_concat_input_ranges = (
        flags.change_concat_input_ranges == "TRUE")

  if flags.allow_custom_ops:
    converter.allow_custom_ops = flags.allow_custom_ops
  if flags.custom_opdefs:
    converter._custom_opdefs = _parse_array(flags.custom_opdefs)  # pylint: disable=protected-access
  if flags.target_ops:
    ops_set_options = lite.OpsSet.get_options()
    converter.target_spec.supported_ops = set()
    for option in six.ensure_str(flags.target_ops).split(","):
      if option not in ops_set_options:
        raise ValueError("Invalid value for --target_ops. Options: "
                         "{0}".format(",".join(ops_set_options)))
      converter.target_spec.supported_ops.add(lite.OpsSet(option))

  if flags.post_training_quantize:
    converter.optimizations = [lite.Optimize.DEFAULT]
    if converter.inference_type == lite_constants.QUANTIZED_UINT8:
      print("--post_training_quantize quantizes a graph of inference_type "
            "FLOAT. Overriding inference type QUANTIZED_UINT8 to FLOAT.")
      converter.inference_type = lite_constants.FLOAT

  if flags.quantize_to_float16:
    converter.target_spec.supported_types = [lite.constants.FLOAT16]
    if not flags.post_training_quantize:
      print("--quantize_to_float16 will only take effect with the "
            "--post_training_quantize flag enabled.")

  if flags.dump_graphviz_dir:
    converter.dump_graphviz_dir = flags.dump_graphviz_dir
  if flags.dump_graphviz_video:
    # Set the attribute named after the flag; the previous spelling
    # ("dump_graphviz_vode") created an unrelated attribute, so the flag had
    # no effect on the conversion.
    converter.dump_graphviz_video = flags.dump_graphviz_video
  if flags.conversion_summary_dir:
    converter.conversion_summary_dir = flags.conversion_summary_dir

  # None means the flag was not passed at all; keep the converter default then.
  if flags.experimental_new_converter is not None:
    converter.experimental_new_converter = flags.experimental_new_converter

  # Convert model.
  output_data = converter.convert()
  with open(flags.output_file, "wb") as f:
    f.write(six.ensure_binary(output_data))
| |
| |
| def _convert_tf2_model(flags): |
| """Calls function to convert the TensorFlow 2.0 model into a TFLite model. |
| |
| Args: |
| flags: argparse.Namespace object. |
| |
| Raises: |
| ValueError: Unsupported file format. |
| """ |
| # Load the model. |
| if flags.saved_model_dir: |
| converter = lite.TFLiteConverterV2.from_saved_model(flags.saved_model_dir) |
| elif flags.keras_model_file: |
| model = keras.models.load_model(flags.keras_model_file) |
| converter = lite.TFLiteConverterV2.from_keras_model(model) |
| |
| if flags.experimental_new_converter is not None: |
| converter.experimental_new_converter = flags.experimental_new_converter |
| |
| # Convert the model. |
| tflite_model = converter.convert() |
| with open(flags.output_file, "wb") as f: |
| f.write(six.ensure_binary(tflite_model)) |
| |
| |
| def _check_tf1_flags(flags, unparsed): |
| """Checks the parsed and unparsed flags to ensure they are valid in 1.X. |
| |
| Raises an error if previously support unparsed flags are found. Raises an |
| error for parsed flags that don't meet the required conditions. |
| |
| Args: |
| flags: argparse.Namespace object containing TFLite flags. |
| unparsed: List of unparsed flags. |
| |
| Raises: |
| ValueError: Invalid flags. |
| """ |
| |
| # Check unparsed flags for common mistakes based on previous TOCO. |
| def _get_message_unparsed(flag, orig_flag, new_flag): |
| if six.ensure_str(flag).startswith(orig_flag): |
| return "\n Use {0} instead of {1}".format(new_flag, orig_flag) |
| return "" |
| |
| if unparsed: |
| output = "" |
| for flag in unparsed: |
| output += _get_message_unparsed(flag, "--input_file", "--graph_def_file") |
| output += _get_message_unparsed(flag, "--savedmodel_directory", |
| "--saved_model_dir") |
| output += _get_message_unparsed(flag, "--std_value", "--std_dev_values") |
| output += _get_message_unparsed(flag, "--batch_size", "--input_shapes") |
| output += _get_message_unparsed(flag, "--dump_graphviz", |
| "--dump_graphviz_dir") |
| if output: |
| raise ValueError(output) |
| |
| # Check that flags are valid. |
| if flags.graph_def_file and (not flags.input_arrays or |
| not flags.output_arrays): |
| raise ValueError("--input_arrays and --output_arrays are required with " |
| "--graph_def_file") |
| |
| if flags.input_shapes: |
| if not flags.input_arrays: |
| raise ValueError("--input_shapes must be used with --input_arrays") |
| if flags.input_shapes.count(":") != flags.input_arrays.count(","): |
| raise ValueError("--input_shapes and --input_arrays must have the same " |
| "number of items") |
| |
| if flags.std_dev_values or flags.mean_values: |
| if bool(flags.std_dev_values) != bool(flags.mean_values): |
| raise ValueError("--std_dev_values and --mean_values must be used " |
| "together") |
| if flags.std_dev_values.count(",") != flags.mean_values.count(","): |
| raise ValueError("--std_dev_values, --mean_values must have the same " |
| "number of items") |
| |
| if (flags.default_ranges_min is None) != (flags.default_ranges_max is None): |
| raise ValueError("--default_ranges_min and --default_ranges_max must be " |
| "used together") |
| |
| if flags.dump_graphviz_video and not flags.dump_graphviz_dir: |
| raise ValueError("--dump_graphviz_video must be used with " |
| "--dump_graphviz_dir") |
| |
| if flags.custom_opdefs and not flags.experimental_new_converter: |
| raise ValueError("--custom_opdefs must be used with " |
| "--experimental_new_converter") |
| if flags.custom_opdefs and not flags.allow_custom_ops: |
| raise ValueError("--custom_opdefs must be used with --allow_custom_ops") |
| |
| |
| def _check_tf2_flags(flags): |
| """Checks the parsed and unparsed flags to ensure they are valid in 2.X. |
| |
| Args: |
| flags: argparse.Namespace object containing TFLite flags. |
| |
| Raises: |
| ValueError: Invalid flags. |
| """ |
| if not flags.keras_model_file and not flags.saved_model_dir: |
| raise ValueError("one of the arguments --saved_model_dir " |
| "--keras_model_file is required") |
| |
| |
def _get_tf1_flags(parser):
  """Registers the tflite_convert flags for TensorFlow 1.X on `parser`.

  The parser is modified in place; nothing is returned. Flags are registered
  in a fixed order, which is also the order argparse lists them in.

  Args:
    parser: ArgumentParser
  """
  # Input file flags. Exactly one of them has to be given.
  model_source = parser.add_mutually_exclusive_group(required=True)
  for flag_name, description in (
      ("--graph_def_file",
       "Full filepath of file containing frozen TensorFlow GraphDef."),
      ("--saved_model_dir",
       "Full filepath of directory containing the SavedModel."),
      ("--keras_model_file",
       "Full filepath of HDF5 file containing tf.Keras model."),
  ):
    model_source.add_argument(flag_name, type=str, help=description)

  # Model format flags. Values are upper-cased before being matched against
  # the allowed choices.
  for flag_name, allowed, description in (
      ("--output_format",
       ["TFLITE", "GRAPHVIZ_DOT"],
       "Output file format."),
      ("--inference_type",
       ["FLOAT", "QUANTIZED_UINT8", "INT8"],
       "Target data type of real-number arrays in the output file."),
      ("--inference_input_type",
       ["FLOAT", "QUANTIZED_UINT8", "INT8"],
       ("Target data type of real-number input arrays. Allows for a "
        "different type for input arrays in the case of quantization.")),
  ):
    parser.add_argument(
        flag_name, type=str.upper, choices=allowed, help=description)

  # Plain string flags: input and output arrays, SavedModel selection and the
  # per-input quantization statistics.
  for flag_name, description in (
      ("--input_arrays",
       "Names of the input arrays, comma-separated."),
      ("--input_shapes",
       "Shapes corresponding to --input_arrays, colon-separated."),
      ("--output_arrays",
       "Names of the output arrays, comma-separated."),
      ("--saved_model_tag_set",
       ("Comma-separated set of tags identifying the MetaGraphDef within "
        "the SavedModel to analyze. All tags must be present. In order to "
        "pass in an empty tag set, pass in \"\". (default \"serve\")")),
      ("--saved_model_signature_key",
       ("Key identifying the SignatureDef containing inputs and outputs. "
        "(default DEFAULT_SERVING_SIGNATURE_DEF_KEY)")),
      ("--std_dev_values",
       ("Standard deviation of training data for each input tensor, "
        "comma-separated floats. Used for quantized input tensors. "
        "(default None)")),
      ("--mean_values",
       ("Mean of training data for each input tensor, comma-separated "
        "floats. Used for quantized input tensors. (default None)")),
  ):
    parser.add_argument(flag_name, type=str, help=description)

  # Default min/max range flags.
  for flag_name, description in (
      ("--default_ranges_min",
       ("Default value for min bound of min/max range values used for all "
        "arrays without a specified range, Intended for experimenting with "
        "quantization via \"dummy quantization\". (default None)")),
      ("--default_ranges_max",
       ("Default value for max bound of min/max range values used for all "
        "arrays without a specified range, Intended for experimenting with "
        "quantization via \"dummy quantization\". (default None)")),
  ):
    parser.add_argument(flag_name, type=float, help=description)

  # Boolean quantization flags with an explicit destination.
  # quantize_weights is DEPRECATED; it stores into post_training_quantize and
  # is hidden from the help output.
  for flag_name, destination, description in (
      ("--quantize_weights", "post_training_quantize", argparse.SUPPRESS),
      ("--post_training_quantize", "post_training_quantize",
       ("Boolean indicating whether to quantize the weights of the "
        "converted float model. Model size will be reduced and there will "
        "be latency improvements (at the cost of accuracy). (default False)")),
      ("--quantize_to_float16", "quantize_to_float16",
       ("Boolean indicating whether to quantize weights to fp16 instead of "
        "the default int8 when post-training quantization "
        "(--post_training_quantize) is enabled. (default False)")),
  ):
    parser.add_argument(
        flag_name, dest=destination, action="store_true", help=description)

  # Graph manipulation flags.
  for flag_name, description in (
      ("--drop_control_dependency",
       ("Boolean indicating whether to drop control dependencies silently. "
        "This is due to TensorFlow not supporting control dependencies. "
        "(default True)")),
      ("--reorder_across_fake_quant",
       ("Boolean indicating whether to reorder FakeQuant nodes in "
        "unexpected locations. Used when the location of the FakeQuant "
        "nodes is preventing graph transformations necessary to convert "
        "the graph. Results in a graph that differs from the quantized "
        "training graph, potentially causing differing arithmetic "
        "behavior. (default False)")),
  ):
    parser.add_argument(flag_name, action="store_true", help=description)

  # Usage for this flag is --change_concat_input_ranges=true or
  # --change_concat_input_ranges=false in order to make it clear what the flag
  # is set to. This keeps the usage consistent with other usages of the flag
  # where the default is different. The default value here is False.
  concat_ranges_help = (
      "Boolean to change behavior of min/max ranges for inputs and "
      "outputs of the concat operator for quantized models. Changes the "
      "ranges of concat operator overlap when true. (default False)")
  parser.add_argument(
      "--change_concat_input_ranges",
      type=str.upper,
      choices=["TRUE", "FALSE"],
      help=concat_ranges_help)

  # Permitted ops flags.
  allow_custom_ops_help = (
      "Boolean indicating whether to allow custom operations. When false "
      "any unknown operation is an error. When true, custom ops are "
      "created for any op that is unknown. The developer will need to "
      "provide these to the TensorFlow Lite runtime with a custom "
      "resolver. (default False)")
  parser.add_argument(
      "--allow_custom_ops", action="store_true", help=allow_custom_ops_help)

  custom_opdefs_help = (
      "String representing a list of custom ops OpDefs delineated with "
      "commas that are included in the GraphDef. Required when using "
      "custom operations with --experimental_new_converter.")
  parser.add_argument("--custom_opdefs", type=str, help=custom_opdefs_help)

  # The help text lists the options reported by lite.OpsSet.
  target_ops_help = (
      "Experimental flag, subject to change. Set of OpsSet options "
      "indicating which converter to use. Options: {0}. One or more "
      "option may be specified. (default set([OpsSet.TFLITE_BUILTINS]))"
      "".format(",".join(lite.OpsSet.get_options())))
  parser.add_argument("--target_ops", type=str, help=target_ops_help)

  # Logging flags.
  dump_dir_help = (
      "Full filepath of folder to dump the graphs at various stages of "
      "processing GraphViz .dot files. Preferred over --output_format="
      "GRAPHVIZ_DOT in order to keep the requirements of the output "
      "file.")
  parser.add_argument("--dump_graphviz_dir", type=str, help=dump_dir_help)

  dump_video_help = (
      "Boolean indicating whether to dump the graph after every graph "
      "transformation")
  parser.add_argument(
      "--dump_graphviz_video", action="store_true", help=dump_video_help)

  summary_dir_help = (
      "Full filepath to store the conversion logs, which includes "
      "graphviz of the model before/after the conversion, an HTML report "
      "and the conversion proto buffers. This will only be generated "
      "when passing --experimental_new_converter")
  parser.add_argument(
      "--conversion_summary_dir", type=str, help=summary_dir_help)
| |
| |
| def _get_tf2_flags(parser): |
| """Returns ArgumentParser for tflite_convert for TensorFlow 2.0. |
| |
| Args: |
| parser: ArgumentParser |
| """ |
| # Input file flags. |
| input_file_group = parser.add_mutually_exclusive_group() |
| input_file_group.add_argument( |
| "--saved_model_dir", |
| type=str, |
| help="Full path of the directory containing the SavedModel.") |
| input_file_group.add_argument( |
| "--keras_model_file", |
| type=str, |
| help="Full filepath of HDF5 file containing tf.Keras model.") |
| |
| # Enables 1.X converter in 2.X. |
| parser.add_argument( |
| "--enable_v1_converter", |
| action="store_true", |
| help=("Enables the TensorFlow V1 converter in 2.0")) |
| |
| |
| class _ParseExperimentalNewConverter(argparse.Action): |
| """Helper class to parse --experimental_new_converter argument.""" |
| |
| def __init__(self, option_strings, dest, nargs=None, **kwargs): |
| if nargs != "?": |
| # This should never happen. This class is only used once below with |
| # nargs="?". |
| raise ValueError( |
| "This parser only supports nargs='?' (0 or 1 additional arguments)") |
| super(_ParseExperimentalNewConverter, self).__init__( |
| option_strings, dest, nargs=nargs, **kwargs) |
| |
| def __call__(self, parser, namespace, values, option_string=None): |
| if values is None: |
| # Handling `--experimental_new_converter`. |
| # Without additional arguments, it implies enabling the new converter. |
| experimental_new_converter = True |
| elif values.lower() == "true": |
| # Handling `--experimental_new_converter=true`. |
| # (Case insensitive after the equal sign) |
| experimental_new_converter = True |
| elif values.lower() == "false": |
| # Handling `--experimental_new_converter=false`. |
| # (Case insensitive after the equal sign) |
| experimental_new_converter = False |
| else: |
| raise ValueError("Invalid --experimental_new_converter argument.") |
| setattr(namespace, self.dest, experimental_new_converter) |
| |
| |
def _get_parser(use_v2_converter):
  """Builds the ArgumentParser for tflite_convert.

  Args:
    use_v2_converter: Whether to register the TensorFlow 2.0 flags instead of
      the TensorFlow 1.X flags.

  Returns:
    ArgumentParser.
  """
  tool_description = "Command line tool to run TensorFlow Lite Converter."
  parser = argparse.ArgumentParser(description=tool_description)

  # Output file flag.
  parser.add_argument(
      "--output_file",
      type=str,
      help="Full filepath of the output file.",
      required=True)

  # Version-specific flags.
  register_flags = _get_tf2_flags if use_v2_converter else _get_tf1_flags
  register_flags(parser)

  # Registered last, for both converter versions.
  new_converter_help = (
      "Experimental flag, subject to change. Enables MLIR-based "
      "conversion instead of TOCO conversion. (default True)")
  parser.add_argument(
      "--experimental_new_converter",
      action=_ParseExperimentalNewConverter,
      nargs="?",
      help=new_converter_help)
  return parser
| |
| |
def run_main(_):
  """Main in tflite_convert.py.

  Parses the command line from `sys.argv`, validates the flags and runs the
  1.X or 2.X conversion. On invalid flags the usage and an error line are
  printed and the process exits with status 1.

  Args:
    _: Unused.
  """
  use_v2_converter = tf2.enabled()
  parser = _get_parser(use_v2_converter)
  tflite_flags, unparsed = parser.parse_known_args(args=sys.argv[1:])

  # If the user is running TensorFlow 2.X but has passed in enable_v1_converter
  # then parse the flags again with the 1.X converter flags.
  if tf2.enabled() and tflite_flags.enable_v1_converter:
    use_v2_converter = False
    parser = _get_parser(use_v2_converter)
    tflite_flags, unparsed = parser.parse_known_args(args=sys.argv[1:])

  # Checks if the flags are valid.
  try:
    if not use_v2_converter:
      _check_tf1_flags(tflite_flags, unparsed)
    else:
      _check_tf2_flags(tflite_flags)
  except ValueError as e:
    parser.print_usage()
    file_name = os.path.basename(sys.argv[0])
    error_line = "{0}: error: {1}\n".format(file_name, str(e))
    sys.stderr.write(error_line)
    sys.exit(1)

  # Convert the model according to the user provided flag.
  if use_v2_converter:
    _convert_tf2_model(tflite_flags)
    return

  try:
    _convert_tf1_model(tflite_flags)
  finally:
    # Runs whether or not the conversion succeeded.
    summary_dir = tflite_flags.conversion_summary_dir
    if summary_dir and tflite_flags.experimental_new_converter:
      gen_html.gen_conversion_log_html(summary_dir,
                                       tflite_flags.post_training_quantize,
                                       tflite_flags.output_file)
    elif summary_dir:
      warnings.warn(
          "Conversion summary will only be generated when enabling"
          " the new converter via --experimental_new_converter. ")
| |
| |
def main():
  """Runs `run_main` through `app.run`.

  Only the program name is forwarded as argv; `run_main` reads the remaining
  command-line arguments from `sys.argv` itself.
  """
  program_name_only = sys.argv[:1]
  app.run(main=run_main, argv=program_name_only)


if __name__ == "__main__":
  main()