blob: 90555043f12ab7151086838884fa410a5ebabc10 [file] [log] [blame]
# Lint as: python2, python3
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorFlow Lite tooling helper functionality."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import enum
import functools
import pprint
import shutil
import tempfile
import time
import warnings
from absl import logging
import six
from six import PY2
from google.protobuf import text_format as _text_format
from google.protobuf.message import DecodeError
from tensorflow.core.framework import graph_pb2 as _graph_pb2
from tensorflow.lite.experimental.microfrontend.python.ops import audio_microfrontend_op # pylint: disable=unused-import
from tensorflow.lite.python import lite_constants as constants
from tensorflow.lite.python.convert import build_toco_convert_protos # pylint: disable=unused-import
from tensorflow.lite.python.convert import convert_saved_model as _convert_saved_model
from tensorflow.lite.python.convert import ConverterError # pylint: disable=unused-import
from tensorflow.lite.python.convert import mlir_quantize as _mlir_quantize
from tensorflow.lite.python.convert import mlir_sparsify as _mlir_sparsify
from tensorflow.lite.python.convert import OpsSet
from tensorflow.lite.python.convert import toco_convert # pylint: disable=unused-import
from tensorflow.lite.python.convert import toco_convert_graph_def as _toco_convert_graph_def
from tensorflow.lite.python.convert import toco_convert_impl as _toco_convert_impl
from tensorflow.lite.python.convert import toco_convert_protos # pylint: disable=unused-import
from tensorflow.lite.python.convert_phase import Component
from tensorflow.lite.python.convert_phase import convert_phase
from tensorflow.lite.python.convert_phase import SubComponent
from tensorflow.lite.python.convert_saved_model import freeze_saved_model as _freeze_saved_model
from tensorflow.lite.python.interpreter import Interpreter # pylint: disable=unused-import
from tensorflow.lite.python.interpreter import load_delegate # pylint: disable=unused-import
from tensorflow.lite.python.interpreter import OpResolverType # pylint: disable=unused-import
from tensorflow.lite.python.op_hint import convert_op_hints_to_stubs # pylint: disable=unused-import
from tensorflow.lite.python.op_hint import is_ophint_converted as _is_ophint_converted
from tensorflow.lite.python.op_hint import OpHint # pylint: disable=unused-import
from tensorflow.lite.python.optimize import calibrator as _calibrator
from tensorflow.lite.python.util import build_debug_info_func as _build_debug_info_func
from tensorflow.lite.python.util import convert_debug_info_func as _convert_debug_info_func
from tensorflow.lite.python.util import freeze_graph as _freeze_graph
from tensorflow.lite.python.util import get_debug_info as _get_debug_info
from tensorflow.lite.python.util import get_grappler_config as _get_grappler_config
from tensorflow.lite.python.util import get_tensor_name as _get_tensor_name
from tensorflow.lite.python.util import get_tensors_from_tensor_names as _get_tensors_from_tensor_names
from tensorflow.lite.python.util import get_tf_type_name as _get_tf_type_name
from tensorflow.lite.python.util import is_frozen_graph as _is_frozen_graph
from tensorflow.lite.python.util import model_input_signature as _model_input_signature
from tensorflow.lite.python.util import modify_model_io_type as _modify_model_io_type
from tensorflow.lite.python.util import run_graph_optimizations as _run_graph_optimizations
from tensorflow.lite.python.util import set_tensor_shapes as _set_tensor_shapes
from tensorflow.lite.python.util import trace_model_call as _trace_model_call
from tensorflow.python import saved_model as _saved_model
from tensorflow.python.client import session as _session
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function as _def_function
from tensorflow.python.eager import function as _function
from tensorflow.python.framework import convert_to_constants as _convert_to_constants
from tensorflow.python.framework import dtypes as _dtypes
from tensorflow.python.framework import ops as _ops
from tensorflow.python.framework.errors_impl import NotFoundError as _NotFoundError
from tensorflow.python.framework.importer import import_graph_def as _import_graph_def
from tensorflow.python.platform import gfile
from tensorflow.python.saved_model import loader_impl as _loader_impl
from tensorflow.python.saved_model import save_options as _save_options
from tensorflow.python.saved_model import signature_constants as _signature_constants
from tensorflow.python.saved_model import tag_constants as _tag_constants
from tensorflow.python.saved_model.load import load as _load
from tensorflow.python.saved_model.loader_impl import parse_saved_model_with_debug_info as _parse_saved_model_with_debug_info
from tensorflow.python.util import deprecation as _deprecation
from tensorflow.python.util import keras_deps
from tensorflow.python.util.tf_export import tf_export as _tf_export
# pylint: disable=g-import-not-at-top
try:
from tensorflow.lite.python import metrics_portable as metrics
except ImportError:
from tensorflow.lite.python import metrics_nonportable as metrics
# pylint: enable=g-import-not-at-top
@_tf_export("lite.Optimize")
class Optimize(enum.Enum):
  """Optimizations that can be requested when generating a TFLite model.

  DEFAULT
      Default strategy: quantize the model weights. Supplying a
      representative dataset additionally lets biases and activations be
      quantized. The converter tries to reduce size and latency while keeping
      the loss in accuracy as small as possible.

  OPTIMIZE_FOR_SIZE
      Deprecated. Does the same as DEFAULT.

  OPTIMIZE_FOR_LATENCY
      Deprecated. Does the same as DEFAULT.

  EXPERIMENTAL_SPARSITY
      Experimental flag, subject to change.
      Takes advantage of sparse model weights produced by pruning: the
      converter inspects the sparsity pattern of the weights and tries to
      improve size and latency. It can be used on its own for float32 models
      with sparse weights, or together with DEFAULT for quantized models with
      sparse weights.
  """

  # Quantize weights; with a representative dataset, biases and activations
  # are quantized as well.
  DEFAULT = "DEFAULT"

  # Deprecated; both behave like DEFAULT.
  OPTIMIZE_FOR_SIZE = "OPTIMIZE_FOR_SIZE"
  OPTIMIZE_FOR_LATENCY = "OPTIMIZE_FOR_LATENCY"

  # Experimental, subject to change: exploit sparse (pruned) weights, alone
  # for float32 models or combined with DEFAULT for quantized models.
  # TODO(b/161560631): Add log message when this optimization is applied.
  EXPERIMENTAL_SPARSITY = "EXPERIMENTAL_SPARSITY"

  def __str__(self):
    """Returns the member's string value, e.g. "DEFAULT"."""
    return "{}".format(self.value)
@_tf_export("lite.RepresentativeDataset")
class RepresentativeDataset(object):
  """Representative dataset used to optimize the model.

  Wraps a generator function that yields a small set of input samples. The
  samples are used to calibrate, i.e. to estimate the (min, max) range of all
  floating-point arrays in the model (model inputs, activation outputs of
  intermediate layers, and model outputs) for quantization. Usually this is a
  small subset of a few hundred samples chosen at random, in no particular
  order, from the training or evaluation dataset.
  """

  def __init__(self, input_gen):
    """Creates a representative dataset.

    Args:
      input_gen: A generator function producing input samples for the model,
        matching the order, type and shape of the model's inputs. Usually a
        small subset of a few hundred samples chosen at random, in no
        particular order, from the training or evaluation dataset.
    """
    # Kept as-is; no validation is performed here.
    self.input_gen = input_gen
@_tf_export("lite.TargetSpec")
class TargetSpec(object):
  """Specification of target device used to optimize the model.

  Attributes:
    supported_ops: Experimental flag, subject to change. Set of
      `tf.lite.OpsSet` options, where each option represents a set of
      operators supported by the target device.
      (default {tf.lite.OpsSet.TFLITE_BUILTINS})
    supported_types: Set of `tf.dtypes.DType` data types supported on the
      target device. If initialized, optimization might be driven by the
      smallest type in this set. (default set())
    experimental_select_user_tf_ops: Experimental flag, subject to change.
      Set of the user's TensorFlow operator names that are required in the
      TensorFlow Lite runtime. These ops are exported as select TensorFlow
      ops in the model (in conjunction with the tf.lite.OpsSet.SELECT_TF_OPS
      flag). This is an advanced feature that should only be used if the
      client is using TF ops that may not be linked in by default with the TF
      ops provided on the SELECT_TF_OPS path. The client is responsible for
      linking these ops into the target runtime.
  """

  def __init__(self,
               supported_ops=None,
               supported_types=None,
               experimental_select_user_tf_ops=None):
    # Each default is a fresh container, so instances never share state.
    self.supported_ops = ({OpsSet.TFLITE_BUILTINS}
                          if supported_ops is None else supported_ops)
    self.supported_types = (set()
                            if supported_types is None else supported_types)
    self.experimental_select_user_tf_ops = (
        set() if experimental_select_user_tf_ops is None
        else experimental_select_user_tf_ops)
    # Custom op registerers; the converter splits these into string names and
    # non-string entries when quantizing.
    self._experimental_custom_op_registerers = []
    # Hint for the supported accumulation type used for inference. Typically
    # used for fp16 post-training quantization, where some models can use fp16
    # accumulators instead of the typical fp32 type.
    # TODO(b/188185962): Provide full API and authoring support for
    # reduced precision accumulation types.
    self._experimental_supported_accumulation_type = None
class QuantizationMode(object):
  """QuantizationMode determines the quantization type from user options.

  Constructor args:
    optimizations: Iterable of `Optimize` values requested by the user.
    target_spec: `TargetSpec` describing the target device.
    representative_dataset: A `RepresentativeDataset`, a bare generator
      function (wrapped on demand when int8 is required), or None.
    graph_def: Optional `GraphDef`; only scanned for training-time
      quantization ops.

  Raises:
    ValueError: If int8 is required by `target_spec` but the supported types
      or representative dataset are incompatible with it.
  """

  # Ops whose presence in the graph means training-time quantization was used.
  _TRAINING_QUANT_OPS = frozenset({
      "FakeQuantWithMinMaxVars", "FakeQuantWithMinMaxVarsPerChannel",
      "FakeQuantWithMinMaxArgs", "FakeQuantWithMinMaxArgsPerChannel",
      "QuantizeAndDequantizeV2", "QuantizeAndDequantizeV3"
  })

  def __init__(self, optimizations, target_spec, representative_dataset,
               graph_def):
    self._optimizations = optimizations
    for deprecated_optimization in [
        Optimize.OPTIMIZE_FOR_SIZE, Optimize.OPTIMIZE_FOR_LATENCY
    ]:
      if deprecated_optimization in self._optimizations:
        logging.warning(
            "Optimization option %s is deprecated, please use optimizations="
            "[Optimize.DEFAULT] instead.", deprecated_optimization)

    self._target_spec = target_spec
    self._representative_dataset = representative_dataset
    self._graph_def = graph_def

    self._validate_int8_required()

  # TODO(b/162537905): Refactor the following quantization functions -
  # re-organize and refactor for better readability.
  def post_training_int8_no_float(self):
    """Full int8 PTQ: int8 required, no int16x8, no float ops, has dataset."""
    return (self.any_optimization_enabled() and
            self._is_int8_target_required() and
            not self._is_int16x8_target_required() and
            not self.is_allow_float() and
            self._representative_dataset is not None)

  def post_training_int8_allow_float(self):
    """Int8 PTQ with a dataset, no int16x8, and int8 as smallest type."""
    return (self.any_optimization_enabled() and
            not self._is_int16x8_target_required() and
            self._representative_dataset is not None and
            self._smallest_supported_type() == _dtypes.int8)

  def is_post_training_integer_quantize_8(self):
    """True if either int8 post-training quantization mode applies."""
    return (self.post_training_int8_no_float() or
            self.post_training_int8_allow_float())

  def is_post_training_integer_quantize_16x8(self):
    """True if either 16x8 post-training quantization mode applies."""
    return (self.post_training_int16x8_no_float() or
            self.post_training_int16x8_allow_float())

  def is_post_training_integer_quantize(self):
    """True if any int8 or 16x8 post-training integer quantization applies."""
    return (self.is_post_training_integer_quantize_8() or
            self.is_post_training_integer_quantize_16x8())

  def is_integer_quantize(self):
    """True for post-training integer or training-time (QAT) quantization."""
    return (self.is_post_training_integer_quantize() or
            self.is_training_time_int8_allow_float())

  def is_training_time_int8_allow_float(self):
    """True if optimizations are on and the graph has training quant ops."""
    return (self.any_optimization_enabled() and
            self.contains_training_quant_op())

  def is_bfloat16_inference_allowed(self):
    """True if the smallest supported type is 2 bytes and bfloat16 is listed."""
    return (self.any_optimization_enabled() and
            self._smallest_supported_type().size == 2 and
            _dtypes.bfloat16 in self._target_spec.supported_types)

  def post_training_int16x8_no_float(self):
    """16x8 PTQ without int8 ops, without float ops, with a dataset."""
    return (self.any_optimization_enabled() and
            not self._is_int8_target_required() and
            self._is_int16x8_target_required() and
            not self.is_allow_float() and
            self._representative_dataset is not None)

  def post_training_int16x8_allow_float(self):
    """16x8 PTQ where builtin float or select TF ops are also allowed."""
    return (self.any_optimization_enabled() and
            self._is_int16x8_target_required() and
            self.is_allow_float())

  def post_training_dynamic_range_int8(self):
    """True if post-training dynamic range (weight-only) int8 applies."""
    # Post-training dynamic range quantization is only enabled if post-training
    # int8 quantization and training time quantization was not done.
    return (self.any_optimization_enabled() and
            self._representative_dataset is None and
            not self.contains_training_quant_op() and
            self._smallest_supported_type() == _dtypes.int8)

  def post_training_fp16(self):
    """True if the smallest supported type is 2 bytes and float16 is listed."""
    return (self.any_optimization_enabled() and
            self._smallest_supported_type().size == 2 and
            _dtypes.float16 in self._target_spec.supported_types)

  def fp32_execution(self):
    """If none of the above are true."""
    return not (self.is_integer_quantize() or
                self.post_training_dynamic_range_int8() or
                self.post_training_fp16())

  def activations_type(self):
    """Returns int16/int8 for integer quantization, otherwise float32."""
    if self.is_integer_quantize():
      if self._is_int16x8_target_required():
        return _dtypes.int16
      else:
        return _dtypes.int8
    else:
      return _dtypes.float32

  def converter_flags(self, inference_ty=None, inference_input_ty=None):
    """Flags to the converter.

    Args:
      inference_ty: Optional override for "inference_type"; only used by the
        integer-quantization and the fallback (last) branch.
      inference_input_ty: Optional "inference_input_type"; only used by the
        fallback (last) branch.

    Returns:
      A dict of converter keyword arguments for the detected mode.
    """

    if self.is_integer_quantize():
      return {
          "inference_type": (
              inference_ty if inference_ty else self.activations_type()),
          "inference_input_type": _dtypes.float32,
          "post_training_quantize": False,  # disable dynamic range quantization
          "quantize_to_float16": False  # disable float16 quantization
      }
    elif self.post_training_dynamic_range_int8():
      return {
          "inference_type": _dtypes.float32,
          "inference_input_type": _dtypes.float32,
          "post_training_quantize": True,  # enable dynamic range quantization
          "quantize_to_float16": False  # disable float16 quantization
      }
    elif self.post_training_fp16():
      return {
          "inference_type": _dtypes.float32,
          "inference_input_type": _dtypes.float32,
          "post_training_quantize": True,
          "quantize_to_float16": True,  # enable float16 quantization
          # pylint: disable=protected-access
          "accumulation_type":
              self._target_spec._experimental_supported_accumulation_type,
          # pylint: enable=protected-access
          "allow_bfloat16":
              self.is_bfloat16_inference_allowed()
      }
    else:
      # Note this might still trigger (uint8) quantization to be compatible with
      # TOCO.
      return {
          "inference_type": inference_ty if inference_ty else _dtypes.float32,
          "inference_input_type": inference_input_ty,
          "post_training_quantize": False,  # disable dynamic range quantization
          "quantize_to_float16": False,  # disable float16 quantization
          "allow_bfloat16": self.is_bfloat16_inference_allowed()
      }

  # Below are helpers for the above functions.

  def _validate_int8_required(self):
    """Int8 mode requires certain parameters to exist and be compatible."""
    if not self._is_int8_target_required():
      return

    if self._target_spec.supported_types and (self._smallest_supported_type() !=
                                              _dtypes.int8):
      raise ValueError("TFLITE_BUILTINS_INT8 requires smallest supported "
                       "type to be INT8.")

    if self._representative_dataset:
      # A bare generator function is accepted and wrapped here.
      if not isinstance(self._representative_dataset, RepresentativeDataset):
        self._representative_dataset = RepresentativeDataset(
            self._representative_dataset)
      if self._representative_dataset.input_gen is None:
        raise ValueError(
            "Provide an input generator for representative_dataset")
    else:
      # TODO(b/162537905): Relax this check for QAT.
      raise ValueError("representative_dataset is required when specifying "
                       "TFLITE_BUILTINS_INT8 or INT8 supported types.")

  def _is_int8_target_required(self):
    """True if TFLITE_BUILTINS_INT8 is requested or int8 is the only type."""
    return (OpsSet.TFLITE_BUILTINS_INT8 in set(self._target_spec.supported_ops)
            or set(self._target_spec.supported_types) == {_dtypes.int8})

  def _is_int16x8_target_required(self):
    """True if the 16-bit activations / 8-bit weights op set is requested."""
    return (OpsSet.EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8
            in set(self._target_spec.supported_ops))

  def is_allow_float(self):
    """True if builtin float ops or select TF ops are supported."""
    supported_ops = set(self._target_spec.supported_ops)
    return (OpsSet.TFLITE_BUILTINS in supported_ops or
            OpsSet.SELECT_TF_OPS in supported_ops)

  def any_optimization_enabled(self):
    """True if DEFAULT or one of its deprecated equivalents is requested."""
    return bool(
        set(self._optimizations).intersection([
            Optimize.OPTIMIZE_FOR_LATENCY, Optimize.OPTIMIZE_FOR_SIZE,
            Optimize.DEFAULT
        ]))

  def _smallest_supported_type(self):
    """Smallest supported type by `DType.size`; int8 when none are given."""
    if self._target_spec.supported_types:
      return min(self._target_spec.supported_types, key=lambda x: x.size)
    else:
      # The default smallest supported type is INT8.
      return _dtypes.int8

  def contains_training_quant_op(self):
    """Checks if the graph contains any training-time quantization ops.

    Both top-level nodes and nodes inside library functions are scanned.
    Returns False when no graph_def was provided.
    """
    if not self._graph_def:
      return False
    if any(node_def.op in self._TRAINING_QUANT_OPS
           for node_def in self._graph_def.node):
      return True
    return any(node_def.op in self._TRAINING_QUANT_OPS
               for function in self._graph_def.library.function
               for node_def in function.node_def)
class TFLiteConverterBase(object):
  """Converter subclass to share functionality between V1 and V2 converters."""

  def __init__(self):
    self.optimizations = set()
    self.representative_dataset = None
    self.target_spec = TargetSpec()
    self.allow_custom_ops = False
    self.experimental_new_converter = True
    self.experimental_new_quantizer = True
    self.experimental_enable_resource_variables = False
    # Must remain None: _validate_experimental_new_quantizer_flag rejects any
    # other value and points users to `experimental_new_quantizer`.
    self._experimental_new_quantizer = None
    self._experimental_calibrate_only = False
    self._experimental_sparsify_model = False
    self._experimental_disable_per_channel = False
    # Contains the stack traces of all the original nodes in the `GraphDef`
    # to the converter.
    self._debug_info = None
    self.saved_model_dir = None
    self._saved_model_tags = None
    self._saved_model_version = 0
    self._saved_model_exported_names = []
    self._tflite_metrics = metrics.TFLiteConverterMetrics()
    self._collected_converter_params = {}
    self._experimental_disable_batchmatmul_unfold = False
    self._experimental_lower_tensor_list_ops = True

  def _grappler_config(self, optimizers=None):
    """Creates a tf.compat.v1.ConfigProto for configuring Grappler.

    Args:
      optimizers: List of strings that represents the list of optimizers.
        The caller's list is copied and never modified.

    Returns:
      tf.ConfigProto.
    """
    # Copy first: the appends below must not leak into the caller's list.
    optimizers = list(optimizers) if optimizers else []
    # MLIR converter will take care of constant folding instead of grappler.
    if not self.experimental_new_converter:
      optimizers.append("constfold")

    is_only_flex_enabled = (
        set([OpsSet.SELECT_TF_OPS]) == set(self.target_spec.supported_ops))
    if is_only_flex_enabled:
      # The layout optimizer turns NHWC to NCHW. This provides performance
      # optimizations when Flex mode is enabled. However, this is not
      # compatible with builtin ops.
      optimizers.append("layout")
    return _get_grappler_config(optimizers)

  def _quantize(self, result, input_type, output_type, activations_type,
                allow_float):
    """Quantize the model.

    Args:
      result: The converted TFLite model to calibrate/quantize.
      input_type: Input data type of the quantized model.
      output_type: Output data type of the quantized model.
      activations_type: Quantized activation type (e.g. int8 or int16).
      allow_float: Only used by the `calibrate_and_quantize` path.

    Returns:
      The calibrated model if `_experimental_calibrate_only` is set, otherwise
      the quantized model.
    """
    # pylint: disable=protected-access
    custom_op_registerers_by_name = [
        x for x in self.target_spec._experimental_custom_op_registerers
        if isinstance(x, str)
    ]
    custom_op_registerers_by_func = [
        x for x in self.target_spec._experimental_custom_op_registerers
        if not isinstance(x, str)
    ]
    # pylint: enable=protected-access
    if not isinstance(self.representative_dataset, RepresentativeDataset):
      self.representative_dataset = RepresentativeDataset(
          self.representative_dataset)

    # Add intermediate tensors to the model if needed.
    result = _calibrator.add_intermediate_tensors(result)
    calibrate_quantize = _calibrator.Calibrator(result,
                                                custom_op_registerers_by_name,
                                                custom_op_registerers_by_func)
    if self._experimental_calibrate_only or self.experimental_new_quantizer:
      # NOTE(review): with experimental_new_quantizer and int16 activations the
      # result below is unused, and calibrate_and_quantize iterates the
      # representative dataset a second time.
      calibrated = calibrate_quantize.calibrate(
          self.representative_dataset.input_gen)

    if self._experimental_calibrate_only:
      return calibrated
    elif self.experimental_new_quantizer and (
        activations_type != _dtypes.int16):
      # TODO(b/175659372): remove the activations_type restriction and enable
      # it for all the activation types.
      return _mlir_quantize(
          calibrated,
          self._experimental_disable_per_channel,
          input_data_type=input_type,
          output_data_type=output_type)
    else:
      return calibrate_quantize.calibrate_and_quantize(
          self.representative_dataset.input_gen, input_type, output_type,
          allow_float, activations_type,
          disable_per_channel=self._experimental_disable_per_channel)

  def _is_unknown_shapes_allowed(self):
    # Unknown dimensions are only allowed with the new converter.
    return self.experimental_new_converter

  def _get_base_converter_args(self):
    """Returns the base converter args.

    Returns:
      {key str: val}
    """
    args = {
        "input_format": constants.TENSORFLOW_GRAPHDEF,
        "allow_custom_ops": self.allow_custom_ops,
        "debug_info": self._debug_info,
        "target_ops": self.target_spec.supported_ops,
        "enable_mlir_converter": self.experimental_new_converter,
        "select_user_tf_ops": self.target_spec.experimental_select_user_tf_ops,
        "unfold_batchmatmul": not self._experimental_disable_batchmatmul_unfold,
        "lower_tensor_list_ops": self._experimental_lower_tensor_list_ops,
    }
    if self.saved_model_dir:
      args.update({
          "saved_model_dir": self.saved_model_dir,
          "saved_model_version": self._saved_model_version,
          "saved_model_tags": self._saved_model_tags,
          "saved_model_exported_names": self._saved_model_exported_names,
      })
    return args

  def _contains_function_with_implements_attr(self, saved_model_proto):
    """True if a function in the first MetaGraph has an implements attr."""
    meta_graph = saved_model_proto.meta_graphs[0]
    for function in meta_graph.graph_def.library.function:
      if function.attr.get("_implements", None) or function.attr.get(
          "api_implements", None):
        return True
    return False

  def _parse_saved_model_args(self, always_enable_saved_model_import=False):
    """Parses SavedModel arguments from the given Keras/RNN SavedModel.

    Clears `saved_model_dir` (falling back to the frozen graph path) when the
    SavedModel import path cannot or should not be used.

    Args:
      always_enable_saved_model_import: Bool. When the value is true, it enables
        MLIR saved model import path regardless of checking the conditions.

    Raises:
      ValueError: If the SavedModel schema version is not 0, 1 or 2.
    """
    if not self.experimental_new_converter:
      self.saved_model_dir = None
      return
    if self.saved_model_dir:
      try:
        saved_model_proto, _ = (
            _parse_saved_model_with_debug_info(self.saved_model_dir))
      except OSError:
        # If it fails to read the given saved model, it will fall back to the
        # frozen graph def path.
        self.saved_model_dir = None
        return
      if (not always_enable_saved_model_import and
          not self._contains_function_with_implements_attr(saved_model_proto)):
        self.saved_model_dir = None
        return

      if not self._saved_model_exported_names:
        self._saved_model_exported_names = []
      self._saved_model_version = saved_model_proto.saved_model_schema_version
      if self._saved_model_version == 0:
        self.saved_model_dir = None
        logging.warning("SavedModel schema version is zero.")
        return
      if self._saved_model_version not in [1, 2]:
        raise ValueError("SavedModel file format({0}) is not supported".format(
            self._saved_model_version))

  def _sparsify_model(self):
    """True if EXPERIMENTAL_SPARSITY is among the requested optimizations."""
    return Optimize.EXPERIMENTAL_SPARSITY in self.optimizations

  def _validate_experimental_new_quantizer_flag(self):
    """Rejects use of the private `_experimental_new_quantizer` attribute."""
    if self._experimental_new_quantizer is not None:
      raise ValueError("Please use 'experimental_new_quantizer' instead.")

  def _increase_conversion_attempt_metric(self):
    self._tflite_metrics.increase_counter_converter_attempt()

  def _increase_conversion_success_metric(self):
    self._tflite_metrics.increase_counter_converter_success()

  def _save_conversion_params_metric(self,
                                     graph_def=None,
                                     inference_type=None,
                                     inference_input_type=None):
    """Set conversion parameter metrics.

    Note: `converter_kwargs` aliases `self._collected_converter_params`, so the
    collected params are updated in place and keys persist across calls.
    """
    converter_kwargs = self._collected_converter_params
    converter_kwargs.update(self._get_base_converter_args())

    # Optimization parameters.
    quant_mode = QuantizationMode(self.optimizations, self.target_spec,
                                  self.representative_dataset, graph_def)
    converter_kwargs.update({
        "optimization_default":
            quant_mode.any_optimization_enabled(),
        "optimization_post_training_dynamic_range":
            quant_mode.post_training_dynamic_range_int8(),
        "optimization_post_training_float16":
            quant_mode.post_training_fp16(),
        "optimization_post_training_integer_quantize":
            quant_mode.is_post_training_integer_quantize(),
        "optimization_qat":
            quant_mode.is_training_time_int8_allow_float(),
        "optimization_sparsify":
            self._sparsify_model(),
        "activations_type":
            quant_mode.activations_type()
    })
    converter_kwargs.update(
        quant_mode.converter_flags(inference_type, inference_input_type))

    # pylint: disable=protected-access
    if self.target_spec._experimental_supported_accumulation_type:
      converter_kwargs.update({
          "accumulation_type":
              self.target_spec._experimental_supported_accumulation_type
      })
    # pylint: enable=protected-access

    def format_element(elem):
      # Enums are reported by value; everything else via pprint.
      if isinstance(elem, enum.Enum):
        return str(elem.value)
      return pprint.pformat(elem)

    def format_param(param):
      # Collections become a sorted, comma-separated string ("None" if empty).
      if isinstance(param, (list, tuple, set)):
        if not param:
          return "None"  # Return None if empty.
        string_list = [format_element(x) for x in param]
        return ",".join(sorted(string_list))
      return format_element(param)

    for key, value in converter_kwargs.items():
      self._tflite_metrics.set_converter_param(key, format_param(value))
    self._tflite_metrics.set_export_required()

  def _set_conversion_latency_metric(self, value):
    self._tflite_metrics.set_converter_latency(value)

  @convert_phase(Component.OPTIMIZE_TFLITE_MODEL)
  def _optimize_tflite_model(self, model, quant_mode, quant_io=True):
    """Apply optimizations on a TFLite model.

    Relies on `inference_input_type`/`inference_output_type`, which this base
    class does not set itself (subclasses are expected to).
    """

    if quant_mode.is_integer_quantize():

      in_type, out_type = self.inference_input_type, self.inference_output_type

      if quant_mode.is_post_training_integer_quantize():
        q_in_type = in_type if in_type and quant_io else _dtypes.float32
        q_out_type = out_type if out_type and quant_io else _dtypes.float32
        q_activations_type = quant_mode.activations_type()
        q_allow_float = quant_mode.is_allow_float()
        model = self._quantize(
            model, q_in_type, q_out_type, q_activations_type, q_allow_float)

      m_in_type = in_type if in_type else _dtypes.float32
      m_out_type = out_type if out_type else _dtypes.float32
      model = _modify_model_io_type(model, m_in_type, m_out_type)

    if self._sparsify_model():
      model = _mlir_sparsify(model)

    return model

  def _convert_and_export_metrics(self, convert_func, *args, **kwargs):
    """Calls `convert_func` and exports conversion metrics around it.

    Args:
      convert_func: The convert function to wrap.
      *args: Positional arguments of the convert function.
      **kwargs: The keyword arguments of the convert function.

    Returns:
      The result of `convert_func`.
    """
    self._increase_conversion_attempt_metric()
    self._save_conversion_params_metric()
    start_time = time.process_time()
    result = convert_func(self, *args, **kwargs)
    # Measured in CPU (process) time, reported in milliseconds.
    elapsed_time_ms = (time.process_time() - start_time) * 1000
    if result:
      self._increase_conversion_success_metric()
    self._set_conversion_latency_metric(round(elapsed_time_ms))
    self._tflite_metrics.export_metrics()
    return result
def _export_metrics(convert_func):
  """Decorates a converter's convert method so conversion metrics are exported.

  The returned wrapper forwards to `self._convert_and_export_metrics`, passing
  the undecorated `convert_func` along with all call arguments.
  """

  def _metrics_exporting_convert(self, *args, **kwargs):
    # pylint: disable=protected-access
    return self._convert_and_export_metrics(convert_func, *args, **kwargs)
    # pylint: enable=protected-access

  # Copy __name__, __doc__, __wrapped__, etc. from the decorated function.
  return functools.wraps(convert_func)(_metrics_exporting_convert)
class TFLiteConverterBaseV2(TFLiteConverterBase):
"""Converter subclass to share functionality between V2 converters."""
def __init__(self):
"""Constructor for TFLiteConverter."""
super(TFLiteConverterBaseV2, self).__init__()
self.inference_input_type = _dtypes.float32
self.inference_output_type = _dtypes.float32
self._collected_converter_params.update({"api_version": 2})
def _validate_inference_input_output_types(self, quant_mode):
"""Validate inference_input_type and inference_output_type flags."""
default_types = [_dtypes.float32]
# We support integer input/output for integer quantized models only.
if quant_mode.is_integer_quantize():
if quant_mode.is_post_training_integer_quantize_16x8():
all_types = default_types + [_dtypes.int16]
else:
all_types = default_types + [_dtypes.int8, _dtypes.uint8]
if (self.inference_input_type not in all_types or
self.inference_output_type not in all_types):
all_types_names = ["tf." + t.name for t in all_types]
raise ValueError("The inference_input_type and inference_output_type "
"must be in {}.".format(all_types_names))
elif (self.inference_input_type not in default_types or
self.inference_output_type not in default_types):
raise ValueError("The inference_input_type and inference_output_type "
"must be tf.float32.")
@convert_phase(Component.PREPARE_TF_MODEL, SubComponent.LOAD_SAVED_MODEL)
def _load_saved_model(self, saved_model_dir, saved_model_tags):
"""Load graph_def from saved model with the default serving signature key.
Args:
saved_model_dir: Directory of the SavedModel.
saved_model_tags: Set of tags identifying the MetaGraphDef within the
SavedModel to analyze.
Returns:
graph_def: The loaded GraphDef.
input_tensors: List of input tensors.
output_tensors: List of output tensors.
"""
graph = _ops.Graph()
saved_model = _loader_impl.SavedModelLoader(saved_model_dir)
saved_model.load_graph(graph, tags=saved_model_tags)
meta_graph = saved_model.get_meta_graph_def_from_tags(saved_model_tags)
graph_def = meta_graph.graph_def
signature_def = meta_graph.signature_def[
_signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
input_tensors = [
graph.get_tensor_by_name(signature_def.inputs[key].name)
for key in signature_def.inputs
]
output_tensors = [
graph.get_tensor_by_name(signature_def.outputs[key].name)
for key in signature_def.outputs
]
return graph_def, input_tensors, output_tensors
@convert_phase(Component.PREPARE_TF_MODEL, SubComponent.VALIDATE_INPUTS)
def _validate_inputs(self, graph_def, input_tensors):
"""Validate the input parameters.
Args:
graph_def: The TensorFlow GraphDef.
input_tensors: List of input tensors.
Raise:
ValueError:
Input shape is not specified.
Invalid quantization parameters.
"""
# Update conversion params with graph_def.
self._save_conversion_params_metric(graph_def)
self._quant_mode = QuantizationMode(self.optimizations, self.target_spec,
self.representative_dataset, graph_def)
self._validate_inference_input_output_types(self._quant_mode)
self._validate_experimental_new_quantizer_flag()
if not self._is_unknown_shapes_allowed():
# Checks dimensions in input tensor.
for tensor in input_tensors:
# Note that shape_list might be empty for scalar shapes.
shape_list = tensor.shape.as_list()
if None in shape_list[1:]:
raise ValueError(
"None is only supported in the 1st dimension. Tensor '{0}' has "
"invalid shape '{1}'.".format(
_get_tensor_name(tensor), shape_list))
elif shape_list and shape_list[0] is None:
# Set the batch size to 1 if undefined.
shape = tensor.shape.as_list()
shape[0] = 1
tensor.set_shape(shape)
if (self._trackable_obj is None or
not hasattr(self._trackable_obj, "graph_debug_info")):
self._debug_info = _get_debug_info(
_build_debug_info_func(self._funcs[0].graph), graph_def)
else:
self._debug_info = _get_debug_info(
_convert_debug_info_func(self._trackable_obj.graph_debug_info),
graph_def)
@convert_phase(Component.PREPARE_TF_MODEL, SubComponent.OPTIMIZE_TF_MODEL)
def _optimize_tf_model(self, graph_def, input_tensors, output_tensors,
                       frozen_func):
  """Optimizes a frozen TensorFlow graph with a Grappler pass.

  Args:
    graph_def: Frozen GraphDef to be optimized.
    input_tensors: List of input tensors.
    output_tensors: List of output tensors.
    frozen_func: The frozen ConcreteFunction; its `.graph` is passed to
      Grappler.

  Returns:
    The optimized graph, or `graph_def` itself when the Grappler config
    enables no optimizers.
  """
  config = self._grappler_config()
  # With an empty optimizer list Grappler would fall back to its default
  # optimizer set, which can cause unexpected behavior, so skip the pass.
  if not config.graph_options.rewrite_options.optimizers:
    return graph_def
  return _run_graph_optimizations(
      graph_def,
      input_tensors,
      output_tensors,
      config=config,
      graph=frozen_func.graph)
def convert(self, graph_def, input_tensors, output_tensors):
  """Converts a frozen TensorFlow GraphDef into a TensorFlow Lite model.

  Args:
    graph_def: Frozen TensorFlow GraphDef.
    input_tensors: List of input tensors.
    output_tensors: List of output tensors.

  Returns:
    The converted data in serialized format.

  Raises:
    ValueError: Raised by input validation, e.g. when an input shape is not
      specified or the quantization parameters are invalid.
  """
  # Validation also populates `self._quant_mode`, which is read below.
  self._validate_inputs(graph_def, input_tensors)
  kwargs = self._get_base_converter_args()
  kwargs.update(self._quant_mode.converter_flags())

  if self.experimental_new_converter:
    logging.info("Using new converter: If you encounter a problem "
                 "please file a bug. You can opt-out "
                 "by setting experimental_new_converter=False")
  else:
    logging.warning(
        "Please consider switching to the new converter by setting "
        "experimental_new_converter=True. "
        "The old converter (TOCO) is deprecated.")

  # Converts model.
  tflite_model = _toco_convert_impl(
      input_data=graph_def,
      input_tensors=input_tensors,
      output_tensors=output_tensors,
      **kwargs)
  return self._optimize_tflite_model(
      tflite_model,
      self._quant_mode,
      quant_io=self.experimental_new_quantizer)
class TFLiteSavedModelConverterV2(TFLiteConverterBaseV2):
  """Converts a SavedModel stored on disk into a TensorFlow Lite model.

  Attributes:
    saved_model_dir: Directory of the SavedModel.
  """

  def __init__(self,
               saved_model_dir,
               saved_model_tags=None,
               saved_model_exported_names=None,
               trackable_obj=None):
    """Creates a converter for the SavedModel in `saved_model_dir`.

    Args:
      saved_model_dir: Directory of the SavedModel.
      saved_model_tags: Set of tags identifying the MetaGraphDef within the
        SavedModel to analyze. All tags in the tag set must be present.
        (default {tf.saved_model.SERVING}).
      saved_model_exported_names: Names to be exported when the saved model
        import path is on.
      trackable_obj: tf.AutoTrackable object for the SavedModel. Holding a
        reference keeps its Variables from being garbage collected, because
        functions only hold weak references to Variables. Only needed when
        the caller does not keep the object alive itself (e.g.
        `from_saved_model`).
    """
    super(TFLiteSavedModelConverterV2, self).__init__()
    self.saved_model_dir = saved_model_dir
    self._saved_model_tags = saved_model_tags
    self._saved_model_exported_names = saved_model_exported_names
    self._trackable_obj = trackable_obj
    # This converter always requests the SavedModel import path.
    self._parse_saved_model_args(always_enable_saved_model_import=True)

  @_export_metrics
  def convert(self):
    """Converts the SavedModel into a TensorFlow Lite model.

    The SavedModel importer is used when `saved_model_dir` is set and the new
    converter is enabled. Otherwise the default serving signature is frozen
    and the resulting GraphDef is converted instead.

    Returns:
      The converted data in serialized format.

    Raises:
      ValueError:
        Input shape is not specified.
        Invalid quantization parameters.
    """
    graph_def, input_tensors, output_tensors = self._load_saved_model(
        self.saved_model_dir, self._saved_model_tags)

    can_import_saved_model = (
        self.saved_model_dir is not None and self.experimental_new_converter)
    if not can_import_saved_model:
      # Fall back to the frozen graph conversion path.
      graph_def, _, _, _ = _freeze_saved_model(
          self.saved_model_dir, None, None, None, self._saved_model_tags,
          _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY)
      # Legacy code further down the caller chain checks `saved_model_dir`,
      # so it is cleared before converting the frozen graph.
      # TODO(b/162537905): Clean these indirect dependencies.
      self.saved_model_dir = None
      return super(TFLiteSavedModelConverterV2, self).convert(
          graph_def, input_tensors, output_tensors)

    if self._trackable_obj is not None:
      debug_info_func = _convert_debug_info_func(
          self._trackable_obj.graph_debug_info)
    else:
      debug_info_func = _build_debug_info_func(self._funcs[0].graph)
    self._debug_info = _get_debug_info(debug_info_func, graph_def)

    # Record the conversion parameters (with graph_def) for metrics.
    self._save_conversion_params_metric(graph_def)
    # Work out the quantization mode and sanity-check the I/O types.
    quant_mode = QuantizationMode(self.optimizations, self.target_spec,
                                  self.representative_dataset, graph_def)
    self._validate_inference_input_output_types(quant_mode)

    converter_kwargs = {
        "enable_tflite_resource_variables":
            self.experimental_enable_resource_variables
    }
    converter_kwargs.update(self._get_base_converter_args())
    converter_kwargs.update(quant_mode.converter_flags())

    tflite_model = _convert_saved_model(**converter_kwargs)
    return self._optimize_tflite_model(
        tflite_model, quant_mode, quant_io=self.experimental_new_quantizer)
class TFLiteKerasModelConverterV2(TFLiteConverterBaseV2):
  """Converts the given Keras model into TensorFlow Lite model."""

  def __init__(self, keras_model, trackable_obj=None):
    """Constructor for TFLiteConverter.

    Args:
      keras_model: tf.Keras.Model.
      trackable_obj: tf.AutoTrackable object to keep alive alongside the
        model. A reference to this object needs to be maintained so that
        Variables do not get garbage collected since functions have a weak
        reference to Variables. This is only required when the
        tf.AutoTrackable object is not maintained by the user.
    """
    super(TFLiteKerasModelConverterV2, self).__init__()
    self._keras_model = keras_model
    self._trackable_obj = trackable_obj
    # Passed as `always_enable_saved_model_import` to `_parse_saved_model_args`
    # once the model has been exported to a temporary SavedModel.
    self.experimental_lower_to_saved_model = True

  @convert_phase(Component.PREPARE_TF_MODEL,
                 SubComponent.CONVERT_KERAS_TO_SAVED_MODEL)
  def _convert_keras_to_saved_model(self, output_dir):
    """Saves the Keras model as a SavedModel and loads it back.

    On success this sets `saved_model_dir`, `_saved_model_tags`,
    `_saved_model_exported_names` and `_trackable_obj` for the SavedModel
    import path.

    Args:
      output_dir: The output directory to save the SavedModel.

    Returns:
      A `(graph_def, input_tensors, output_tensors)` tuple loaded from the
      saved model, or `(None, None, None)` when saving failed or
      `_parse_saved_model_args` cleared `saved_model_dir`.
    """
    try:
      _saved_model.save(
          self._keras_model,
          output_dir,
          options=_save_options.SaveOptions(save_debug_info=True))
    except Exception:  # pylint: disable=broad-except
      # Saving is best-effort: when storing the given Keras model as a
      # SavedModel fails, the caller falls back to the original Keras model
      # conversion pipeline.
      return None, None, None

    self.saved_model_dir = output_dir
    self._saved_model_tags = set([_tag_constants.SERVING])
    self._saved_model_exported_names = [
        _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
    ]
    self._parse_saved_model_args(
        always_enable_saved_model_import=self.experimental_lower_to_saved_model)
    if self.saved_model_dir:
      graph_def, input_tensors, output_tensors = self._load_saved_model(
          self.saved_model_dir, self._saved_model_tags)
      self._trackable_obj = _load(self.saved_model_dir, self._saved_model_tags)
      return graph_def, input_tensors, output_tensors
    return None, None, None

  @convert_phase(Component.PREPARE_TF_MODEL, SubComponent.FREEZE_KERAS_MODEL)
  def _freeze_keras_model(self):
    """Freezes the Keras model into a frozen graph.

    Returns:
      graph_def: The frozen GraphDef.
      input_tensors: List of input tensors (resource-typed inputs excluded).
      output_tensors: List of output tensors.
      frozen_func: The frozen ConcreteFunction.
    """
    input_signature = None
    # If the model's call is not a `tf.function`, then we need to first get its
    # input signature from `model_input_signature` method. We can't directly
    # call `trace_model_call` because otherwise the batch dimension is set
    # to None.
    # Once we have better support for dynamic shapes, we can remove this.
    if not isinstance(self._keras_model.call, _def_function.Function):
      # Pass `keep_original_batch_size=True` will ensure that we get an input
      # signature including the batch dimension specified by the user.
      # TODO(b/169898786): Use the Keras public API when TFLite moves out of TF
      input_signature = _model_input_signature(
          self._keras_model, keep_original_batch_size=True)

    # TODO(b/169898786): Use the Keras public API when TFLite moves out of TF
    func = _trace_model_call(self._keras_model, input_signature)
    concrete_func = func.get_concrete_function()
    self._funcs = [concrete_func]

    frozen_func, graph_def = (
        _convert_to_constants.convert_variables_to_constants_v2_as_graph(
            self._funcs[0], lower_control_flow=False))

    input_tensors = [
        tensor for tensor in frozen_func.inputs
        if tensor.dtype != _dtypes.resource
    ]
    output_tensors = frozen_func.outputs
    return graph_def, input_tensors, output_tensors, frozen_func

  def _convert_as_saved_model(self):
    """Converts the Keras model through a temporary SavedModel.

    Returns:
      The converted data in serialized format, or None when the model could
      not be converted through the SavedModel path.
    """
    temp_dir = tempfile.mkdtemp()
    try:
      graph_def, input_tensors, output_tensors = (
          self._convert_keras_to_saved_model(temp_dir))
      if self.saved_model_dir:
        return super(TFLiteKerasModelConverterV2,
                     self).convert(graph_def, input_tensors, output_tensors)
    finally:
      shutil.rmtree(temp_dir, True)
      # The temporary SavedModel has just been deleted. Don't leave
      # `saved_model_dir` pointing at it; otherwise a later `convert()` whose
      # export fails would see the stale path and take it for a fresh export.
      self.saved_model_dir = None
    return None

  @_export_metrics
  def convert(self):
    """Converts a keras model based on instance variables.

    The SavedModel path is tried first; if it yields no result the Keras
    model is frozen and the resulting GraphDef is converted.

    Returns:
      The converted data in serialized format.

    Raises:
      ValueError:
        Multiple concrete functions are specified.
        Input shape is not specified.
        Invalid quantization parameters.
    """
    saved_model_convert_result = self._convert_as_saved_model()
    if saved_model_convert_result:
      return saved_model_convert_result

    # Fall back to freezing the Keras model into a GraphDef.
    graph_def, input_tensors, output_tensors, frozen_func = (
        self._freeze_keras_model())

    graph_def = self._optimize_tf_model(graph_def, input_tensors,
                                        output_tensors, frozen_func)

    return super(TFLiteKerasModelConverterV2,
                 self).convert(graph_def, input_tensors, output_tensors)
class TFLiteFrozenGraphConverterV2(TFLiteConverterBaseV2):
  """Converts the given frozen graph into TensorFlow Lite model."""

  def __init__(self, funcs, trackable_obj=None):
    """Constructor for TFLiteConverter.

    Args:
      funcs: List of TensorFlow ConcreteFunctions. The list should not contain
        duplicate elements.
      trackable_obj: tf.AutoTrackable object associated with `funcs`. A
        reference to this object needs to be maintained so that Variables do not
        get garbage collected since functions have a weak reference to
        Variables. This is only required when the tf.AutoTrackable object is not
        maintained by the user (e.g. `from_saved_model`).
    """
    super(TFLiteFrozenGraphConverterV2, self).__init__()
    self._funcs = funcs
    self._trackable_obj = trackable_obj
    # When False, `convert` skips the SavedModel path entirely.
    self.experimental_lower_to_saved_model = True

  def _check_single_concrete_function(self):
    """Raises ValueError unless exactly one ConcreteFunction was provided."""
    # TODO(b/130297984): Add support for converting multiple function.
    if len(self._funcs) == 0:  # pylint: disable=g-explicit-length-test
      raise ValueError("No ConcreteFunction is specified.")
    if len(self._funcs) > 1:
      raise ValueError("This converter can only convert a single "
                       "ConcreteFunction. Converting multiple functions is "
                       "under development.")

  @convert_phase(Component.PREPARE_TF_MODEL,
                 SubComponent.FREEZE_CONCRETE_FUNCTION)
  def _freeze_concrete_function(self):
    """Converts the given ConcreteFunction to a frozen graph.

    Returns:
      graph_def: The frozen GraphDef.
      input_tensors: List of input tensors (resource-typed inputs excluded).
      output_tensors: List of output tensors.
      frozen_func: The frozen ConcreteFunction.

    Raises:
      ValueError: none or multiple ConcreteFunctions provided.
    """
    self._check_single_concrete_function()

    frozen_func, graph_def = (
        _convert_to_constants.convert_variables_to_constants_v2_as_graph(
            self._funcs[0], lower_control_flow=False))

    input_tensors = [
        tensor for tensor in frozen_func.inputs
        if tensor.dtype != _dtypes.resource
    ]
    output_tensors = frozen_func.outputs
    return graph_def, input_tensors, output_tensors, frozen_func

  @convert_phase(Component.PREPARE_TF_MODEL,
                 SubComponent.CONVERT_CONCRETE_FUNCTIONS_TO_SAVED_MODEL)
  def _convert_concrete_functions_to_saved_model(self, output_dir):
    """Saves the concrete function as a SavedModel and loads it back.

    Args:
      output_dir: The output directory to save the SavedModel.

    Returns:
      A `(graph_def, input_tensors, output_tensors)` tuple loaded from the
      saved model, or `(None, None, None)` when the SavedModel path is
      disabled, no trackable object is available, saving failed, or
      `_parse_saved_model_args` cleared `saved_model_dir`.

    Raises:
      ValueError: none or multiple ConcreteFunctions provided.
    """
    self._check_single_concrete_function()

    func = self._funcs[0]

    if not self.experimental_lower_to_saved_model:
      return None, None, None

    # Without the provided trackable obj, it is not able to serialize the given
    # concrete functions as a saved model format.
    if not self._trackable_obj:
      return None, None, None

    try:
      _saved_model.save(
          self._trackable_obj,
          output_dir,
          # Same key as the exported name set below.
          signatures={
              _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: func
          },
          options=_save_options.SaveOptions(save_debug_info=True))
    except Exception:  # pylint: disable=broad-except
      # Saving is best-effort: when storing the given concrete function as a
      # SavedModel fails, the caller falls back to the original concrete
      # function conversion pipeline.
      return None, None, None

    self.saved_model_dir = output_dir
    self._saved_model_tags = set([_tag_constants.SERVING])
    self._saved_model_exported_names = [
        _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
    ]
    self._parse_saved_model_args(always_enable_saved_model_import=True)
    if self.saved_model_dir:
      graph_def, input_tensors, output_tensors = self._load_saved_model(
          self.saved_model_dir, self._saved_model_tags)
      self._trackable_obj = _load(self.saved_model_dir, self._saved_model_tags)
      return graph_def, input_tensors, output_tensors
    return None, None, None

  def _convert_as_saved_model(self):
    """Converts the concrete function through a temporary SavedModel.

    Returns:
      The converted data in serialized format, or None when the function
      could not be converted through the SavedModel path.
    """
    temp_dir = tempfile.mkdtemp()
    try:
      graph_def, input_tensors, output_tensors = (
          self._convert_concrete_functions_to_saved_model(temp_dir))
      if self.saved_model_dir:
        return super(TFLiteFrozenGraphConverterV2,
                     self).convert(graph_def, input_tensors, output_tensors)
    finally:
      shutil.rmtree(temp_dir, True)
      # The temporary SavedModel has just been deleted. Don't leave
      # `saved_model_dir` pointing at it; otherwise a later `convert()` whose
      # export fails would see the stale path and take it for a fresh export.
      self.saved_model_dir = None
    return None

  @_export_metrics
  def convert(self):
    """Converts a TensorFlow GraphDef based on instance variables.

    When `experimental_lower_to_saved_model` is set, the SavedModel path is
    tried first. If it yields no result, the function is frozen and the
    resulting GraphDef is converted.

    Returns:
      The converted data in serialized format.

    Raises:
      ValueError:
        No concrete functions is specified.
        Multiple concrete functions are specified.
        Input shape is not specified.
        Invalid quantization parameters.
    """
    if self.experimental_lower_to_saved_model:
      saved_model_convert_result = self._convert_as_saved_model()
      if saved_model_convert_result:
        return saved_model_convert_result

    graph_def, input_tensors, output_tensors, frozen_func = (
        self._freeze_concrete_function())

    graph_def = self._optimize_tf_model(graph_def, input_tensors,
                                        output_tensors, frozen_func)

    return super(TFLiteFrozenGraphConverterV2,
                 self).convert(graph_def, input_tensors, output_tensors)
@_tf_export("lite.TFLiteConverter", v1=[])
class TFLiteConverterV2(TFLiteFrozenGraphConverterV2):
  """Converts a TensorFlow model into TensorFlow Lite model.

  Attributes:
    optimizations: Experimental flag, subject to change. Set of optimizations
      to apply. e.g. {tf.lite.Optimize.DEFAULT}. (default None, must be None
      or a set of values of type `tf.lite.Optimize`)
    representative_dataset: A generator function used for integer quantization
      where each generated sample has the same order, type and shape as the
      inputs to the model. Usually, this is a small subset of a few hundred
      samples randomly chosen, in no particular order, from the training or
      evaluation dataset. This is an optional attribute, but required for full
      integer quantization, i.e., if `tf.int8` is the only supported type in
      `target_spec.supported_types`. Refer to `tf.lite.RepresentativeDataset`.
      (default None)
    target_spec: Experimental flag, subject to change. Specifications of target
      device, including supported ops set, supported types and a set of user's
      defined TensorFlow operators required in the TensorFlow Lite runtime.
      Refer to `tf.lite.TargetSpec`.
    inference_input_type: Data type of the input layer. Note that integer types
      (tf.int8 and tf.uint8) are currently only supported for post training
      integer quantization and quantization aware training. (default tf.float32,
      must be in {tf.float32, tf.int8, tf.uint8})
    inference_output_type: Data type of the output layer. Note that integer
      types (tf.int8 and tf.uint8) are currently only supported for post
      training integer quantization and quantization aware training. (default
      tf.float32, must be in {tf.float32, tf.int8, tf.uint8})
    allow_custom_ops: Boolean indicating whether to allow custom operations.
      When False, any unknown operation is an error. When True, custom ops are
      created for any op that is unknown. The developer needs to provide these
      to the TensorFlow Lite runtime with a custom resolver. (default False)
    experimental_new_converter: Experimental flag, subject to change. Enables
      MLIR-based conversion instead of TOCO conversion. (default True)
    experimental_new_quantizer: Experimental flag, subject to change. Enables
      MLIR-based quantization conversion instead of Flatbuffer-based conversion.
      (default True)
    experimental_enable_resource_variables: Experimental flag, subject to
      change. Enables resource variables to be converted by this converter.
      This is only allowed if from_saved_model interface is used.
      (default False)

  Example usage:

    ```python
    # Converting a SavedModel to a TensorFlow Lite model.
      converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
      tflite_model = converter.convert()

    # Converting a tf.Keras model to a TensorFlow Lite model.
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    tflite_model = converter.convert()

    # Converting ConcreteFunctions to a TensorFlow Lite model.
    converter = tf.lite.TFLiteConverter.from_concrete_functions([func], model)
    tflite_model = converter.convert()
    ```
  """

  # pylint: disable=useless-super-delegation
  def __init__(self, funcs, trackable_obj=None):
    """Constructor for TFLiteConverter.

    Args:
      funcs: List of TensorFlow ConcreteFunctions. The list should not contain
        duplicate elements.
      trackable_obj: tf.AutoTrackable object associated with `funcs`. A
        reference to this object needs to be maintained so that Variables do not
        get garbage collected since functions have a weak reference to
        Variables. This is only required when the tf.AutoTrackable object is not
        maintained by the user (e.g. `from_saved_model`).
    """
    super(TFLiteConverterV2, self).__init__(funcs, trackable_obj)

  @classmethod
  def from_concrete_functions(cls, funcs, trackable_obj=None):
    """Creates a TFLiteConverter object from ConcreteFunctions.

    Args:
      funcs: List of TensorFlow ConcreteFunctions. The list should not contain
        duplicate elements. Currently converter can only convert a single
        ConcreteFunction. Converting multiple functions is under development.
      trackable_obj:   An `AutoTrackable` object (typically `tf.module`)
        associated with `funcs`. A reference to this object needs to be
        maintained so that Variables do not get garbage collected since
        functions have a weak reference to Variables.

    Returns:
      TFLiteConverter object.

    Raises:
      ValueError: Invalid input type, i.e. an element of `funcs` is not a
        ConcreteFunction.
    """
    for func in funcs:
      if not isinstance(func, _function.ConcreteFunction):
        message = "This function takes in a list of ConcreteFunction."
        if isinstance(func, _def_function.Function):
          message += (" To get the ConcreteFunction from a Function,"
                      " call get_concrete_function.")
        raise ValueError(message)
    return cls(funcs, trackable_obj)

  @classmethod
  def from_saved_model(cls, saved_model_dir, signature_keys=None, tags=None):
    """Creates a TFLiteConverter object from a SavedModel directory.

    Args:
      saved_model_dir: SavedModel directory to convert.
      signature_keys: List of keys identifying SignatureDef containing inputs
        and outputs. Elements should not be duplicated. By default the
        `signatures` attribute of the MetaGraphdef is used. (default
        saved_model.signatures)
      tags: Set of tags identifying the MetaGraphDef within the SavedModel to
        analyze. All tags in the tag set must be present. (default
        {tf.saved_model.SERVING} or {'serve'})

    Returns:
      TFLiteConverter object. When eager execution is disabled, the TF1
      implementation's converter is returned instead.

    Raises:
      ValueError: Invalid signature keys.
    """
    # When run without eager enabled, this will return the legacy
    # TFLiteConverter.
    if not context.executing_eagerly():
      signature_key = None
      if signature_keys:
        if len(signature_keys) != 1:
          raise ValueError("Only support a single signature key.")
        signature_key = signature_keys[0]
      logging.warning("Invoking the TF1 implementation of TFLiteConverter "
                      "because eager is disabled. Consider enabling eager.")
      return TFLiteConverter.from_saved_model(
          saved_model_dir, signature_key=signature_key, tag_set=tags)

    if tags is None:
      tags = set([_tag_constants.SERVING])

    # Ensures any graphs created in Eager mode are able to run. This is required
    # in order to create a tf.estimator.Exporter that exports a TFLite model.
    with context.eager_mode():
      saved_model = _load(saved_model_dir, tags)
    if not signature_keys:
      signature_keys = saved_model.signatures

    if not signature_keys:
      raise ValueError("Only support at least one signature key.")

    funcs = []
    for key in signature_keys:
      if key not in saved_model.signatures:
        raise ValueError("Invalid signature key '{}' found. Valid keys are "
                         "'{}'.".format(key, ",".join(saved_model.signatures)))
      funcs.append(saved_model.signatures[key])

    saved_model_converter = TFLiteSavedModelConverterV2(saved_model_dir, tags,
                                                        signature_keys,
                                                        saved_model)
    if saved_model_converter.saved_model_dir:
      return saved_model_converter

    # The SavedModel import path was not enabled; convert the signature
    # functions directly instead.
    return cls(funcs, saved_model)

  @classmethod
  def from_keras_model(cls, model):
    """Creates a TFLiteConverter object from a Keras model.

    Args:
      model: tf.Keras.Model

    Returns:
      TFLiteConverter object.
    """
    return TFLiteKerasModelConverterV2(model)

  # pylint: disable=useless-super-delegation
  def convert(self):
    """Converts a TensorFlow GraphDef based on instance variables.

    Returns:
      The converted data in serialized format.

    Raises:
      ValueError:
        No concrete functions is specified.
        Multiple concrete functions are specified.
        Input shape is not specified.
        Invalid quantization parameters.
    """
    return super(TFLiteConverterV2, self).convert()
class TFLiteConverterBaseV1(TFLiteConverterBase):
  """Converter subclass to share functionality between V1 converters."""

  def __init__(self, experimental_debug_info_func):
    """Constructor for TFLiteConverter.

    Args:
      experimental_debug_info_func: An experimental function to retrieve the
        graph debug info for a set of nodes from the `graph_def`.
    """
    super(TFLiteConverterBaseV1, self).__init__()
    self.inference_type = _dtypes.float32
    self.inference_input_type = None
    self.inference_output_type = None
    self.output_format = constants.TFLITE
    self.quantized_input_stats = {}
    self.default_ranges_stats = None
    self.drop_control_dependency = True
    self.reorder_across_fake_quant = False
    self.change_concat_input_ranges = False
    self.dump_graphviz_dir = None
    self.dump_graphviz_video = False
    self.conversion_summary_dir = None
    self._debug_info_func = experimental_debug_info_func
    self._experimental_allow_all_select_tf_ops = False

  def __setattr__(self, name, value):
    """Redirects writes of deprecated properties to their replacements."""
    if name == "post_training_quantize":
      warnings.warn("Property %s is deprecated, "
                    "please use optimizations=[Optimize.DEFAULT]"
                    " instead." % name)
      if value:
        self.optimizations = [Optimize.DEFAULT]
      else:
        self.optimizations = []
      return
    if name == "target_ops":
      warnings.warn("Property %s is deprecated, please use "
                    "target_spec.supported_ops instead." % name)
      self.target_spec.supported_ops = value
      return
    object.__setattr__(self, name, value)

  def __getattribute__(self, name):
    """Redirects reads of deprecated properties to their replacements."""
    if name == "post_training_quantize":
      warnings.warn("Property %s is deprecated, "
                    "please use optimizations=[Optimize.DEFAULT]"
                    " instead." % name)
      return Optimize.DEFAULT in set(self.optimizations)
    if name == "target_ops":
      warnings.warn("Property %s is deprecated, please use "
                    "target_spec.supported_ops instead." % name)
      return self.target_spec.supported_ops
    return object.__getattribute__(self, name)

  def _validate_quantized_input_stats(self, converter_kwargs, quant_mode):
    """Ensure the `quantized_input_stats` flag is provided if required."""

    quantized_types = frozenset({_dtypes.int8, _dtypes.uint8})

    requires_quantized_input_stats = (
        (converter_kwargs["inference_type"] in quantized_types or
         converter_kwargs["inference_input_type"] in quantized_types) and
        not quant_mode.is_post_training_integer_quantize())

    if (requires_quantized_input_stats and
        not converter_kwargs["quantized_input_stats"]):
      raise ValueError(
          "The `quantized_input_stats` flag must be defined when either "
          "`inference_type` flag or `inference_input_type` flag is set to "
          "tf.int8 or tf.uint8. Currently, `inference_type={}` and "
          "`inference_input_type={}`.".format(
              _get_tf_type_name(converter_kwargs["inference_type"]),
              _get_tf_type_name(converter_kwargs["inference_input_type"])))

  @convert_phase(Component.PREPARE_TF_MODEL, SubComponent.VALIDATE_INPUTS)
  def _validate_inputs(self, input_tensors, quantized_input_stats):
    """Validate input parameters.

    Args:
      input_tensors: List of input tensors.
      quantized_input_stats: Map of input tensor names to a tuple of floats
        representing the mean and standard deviation of the training data.

    Raises:
      ValueError:
        Input shape is not specified.
        Quantization input stats is required but not provided.
    """

    if (not self._is_unknown_shapes_allowed() and self._has_valid_tensors()):
      # Checks dimensions in input tensor.
      for tensor in input_tensors:
        shape = tensor.shape
        if not shape:
          raise ValueError("Provide an input shape for input array "
                           "'{0}'.".format(_get_tensor_name(tensor)))
        # Note that shape_list might be empty for scalar shapes.
        shape_list = shape.as_list()
        if None in shape_list[1:]:
          raise ValueError(
              "None is only supported in the 1st dimension. Tensor '{0}' has "
              "invalid shape '{1}'.".format(
                  _get_tensor_name(tensor), shape_list))
        elif shape_list and shape_list[0] is None:
          self._set_batch_size(batch_size=1)

    # Get quantization stats. Ensures there is one stat per name if the stats
    # are specified.
    if quantized_input_stats:
      self._quantized_stats = []
      invalid_stats = []
      for name in self.get_input_arrays():
        if name in quantized_input_stats:
          self._quantized_stats.append(quantized_input_stats[name])
        else:
          invalid_stats.append(name)

      if invalid_stats:
        raise ValueError("Quantization input stats are not available for input "
                         "tensors '{0}'.".format(",".join(invalid_stats)))
    else:
      self._quantized_stats = None

  @convert_phase(Component.PREPARE_TF_MODEL, SubComponent.OPTIMIZE_TF_MODEL)
  def _optimize_tf_model(self, graph_def, input_tensors, output_tensors,
                         quant_mode):
    """Run a Grappler pass to optimize the TensorFlow graph.

    Args:
      graph_def: Frozen GraphDef to be optimized.
      input_tensors: List of input tensors.
      output_tensors: List of output tensors.
      quant_mode: the quantization mode.

    Returns:
      The optimized TensorFlow graph, or `graph_def` unchanged when the pass
      is skipped or fails.
    """
    # Skip Grappler when converting from a SavedModel directory, and disable
    # its constant folding when there are training quant ops.
    if self.saved_model_dir or quant_mode.contains_training_quant_op():
      return graph_def

    try:
      # TODO(b/150163103): Merge `disabling lower using switch merge' calls.
      # Grappler will also try to lower while loop into switch merge
      # representation which is undesired for Ophints, so we simply remove
      # those attributes to prevent Grappler from doing so.
      graph = _convert_to_constants.disable_lower_using_switch_merge(graph_def)
      # Run function inlining optimization to ensure any models generated
      # through the from_frozen_graph path have been inlined.
      optimized_graph = _run_graph_optimizations(
          graph,
          input_tensors,
          output_tensors,
          config=self._grappler_config(["function"]))
      return optimized_graph
    except Exception:  # pylint: disable=broad-except
      # Optimization is best-effort; convert the unoptimized graph instead.
      return graph_def

  def convert(self):
    """Converts a TensorFlow GraphDef based on instance variables.

    Returns:
      The converted data in serialized format. Either a TFLite Flatbuffer or a
      Graphviz graph depending on value in `output_format`.

    Raises:
      ValueError:
        Input shape is not specified.
        None value for dimension in input_tensor.
    """
    self._validate_inputs(self._input_tensors, self.quantized_input_stats)
    quant_mode = QuantizationMode(self.optimizations, self.target_spec,
                                  self.representative_dataset, self._graph_def)

    optimized_graph = self._optimize_tf_model(self._graph_def,
                                              self._input_tensors,
                                              self._output_tensors, quant_mode)

    self._debug_info = _get_debug_info(self._debug_info_func, optimized_graph)

    converter_kwargs = self._get_base_converter_args()
    converter_kwargs.update(
        quant_mode.converter_flags(self.inference_type,
                                   self.inference_input_type))
    converter_kwargs.update({
        "output_format": self.output_format,
        "quantized_input_stats": self._quantized_stats,
        "default_ranges_stats": self.default_ranges_stats,
        "drop_control_dependency": self.drop_control_dependency,
        "reorder_across_fake_quant": self.reorder_across_fake_quant,
        "change_concat_input_ranges": self.change_concat_input_ranges,
        "dump_graphviz_dir": self.dump_graphviz_dir,
        "dump_graphviz_video": self.dump_graphviz_video,
        "conversion_summary_dir": self.conversion_summary_dir,
        "allow_all_select_tf_ops": self._experimental_allow_all_select_tf_ops,
    })

    if not self.experimental_new_converter:
      logging.warning(
          "Please consider switching to the new converter by setting "
          "experimental_new_converter=True. "
          "The old converter (TOCO) is deprecated.")
    else:
      logging.info("Using experimental converter: If you encountered a problem "
                   "please file a bug. You can opt-out "
                   "by setting experimental_new_converter=False")

    self._validate_quantized_input_stats(converter_kwargs, quant_mode)
    self._validate_experimental_new_quantizer_flag()

    # Converts model.
    if self._has_valid_tensors():
      result = _toco_convert_impl(
          input_data=optimized_graph,
          input_tensors=self._input_tensors,
          output_tensors=self._output_tensors,
          **converter_kwargs)
    else:
      result = _toco_convert_graph_def(
          input_data=optimized_graph,
          input_arrays_with_shape=self._input_arrays_with_shape,
          output_arrays=self._output_arrays,
          control_output_arrays=self._control_output_arrays,
          **converter_kwargs)

    return self._optimize_tflite_model(
        result, quant_mode, quant_io=not self.experimental_new_converter)

  def get_input_arrays(self):
    """Returns a list of the names of the input tensors.

    Returns:
      List of strings.
    """
    if self._has_valid_tensors():
      return [_get_tensor_name(tensor) for tensor in self._input_tensors]
    else:
      return [name for name, _ in self._input_arrays_with_shape]

  def _has_valid_tensors(self):
    """Checks if the input and output tensors have been initialized.

    Returns:
      Bool.
    """
    # Coerce to bool: the raw `and` expression would return the output tensor
    # list (or None) rather than the documented boolean.
    return bool(self._input_tensors is not None and self._output_tensors)

  def _set_batch_size(self, batch_size):
    """Sets the first dimension of the input tensor to `batch_size`.

    Inputs whose first dimension is already defined, and scalar inputs (which
    have no dimensions), are left unchanged.

    Args:
      batch_size: Batch size for the model. Replaces the first dimension of an
        input size array if undefined. (default 1)

    Raises:
      ValueError: input_tensor is not defined.
    """
    if not self._has_valid_tensors():
      raise ValueError("The batch size cannot be set for this model. Please "
                       "use input_shapes parameter.")

    for tensor in self._input_tensors:
      shape = tensor.shape.as_list()
      # A scalar input has an empty shape list; indexing it would raise
      # IndexError when another input has an undefined batch dimension.
      if shape and shape[0] is None:
        shape[0] = batch_size
        tensor.set_shape(shape)

  def _is_unknown_shapes_allowed(self):
    # Ophint Converted nodes will need the shapes to be known.
    if _is_ophint_converted(self._graph_def):
      return False

    if not super(TFLiteConverterBaseV1, self)._is_unknown_shapes_allowed():
      return False

    # `conversion_summary_dir` calls TOCO. Unknown shapes are only supported by
    # the MLIR converter.
    if self.conversion_summary_dir:
      logging.warning(
          "`conversion_summary_dir` does not work with unknown shapes. "
          "Graphs with unknown shapes might be different than when this flag "
          "is disabled.")
      return False
    return True

  def _save_conversion_params_metric(self):
    """Records the V1-specific converter parameters, then the base ones."""
    self._collected_converter_params.update({
        "output_format": self.output_format,
        "default_ranges_stats": self.default_ranges_stats,
        "drop_control_dependency": self.drop_control_dependency,
        "reorder_across_fake_quant": self.reorder_across_fake_quant,
        "change_concat_input_ranges": self.change_concat_input_ranges,
        "dump_graphviz_dir": self.dump_graphviz_dir,
        "dump_graphviz_video": self.dump_graphviz_video,
        "conversion_summary_dir": self.conversion_summary_dir,
        "api_version": 1,
    })
    super(TFLiteConverterBaseV1,
          self)._save_conversion_params_metric(self._graph_def,
                                               self.inference_type,
                                               self.inference_input_type)
class TFLiteSavedModelConverter(TFLiteConverterBaseV1):
  """Converts the given SavedModel into TensorFlow Lite model.

  Attributes:
    saved_model_dir: Directory of the SavedModel.
  """

  def __init__(self,
               saved_model_dir,
               saved_model_tags,
               saved_model_exported_names,
               experimental_debug_info_func=None):
    """Constructor for TFLiteConverter.

    Args:
      saved_model_dir: Directory of the SavedModel.
      saved_model_tags: Set of tags identifying the MetaGraphDef within the
        SavedModel to analyze. All tags in the tag set must be present.
        (default {tf.saved_model.SERVING}).
      saved_model_exported_names: Names to be exported when the saved model
        import path is on. Exactly one name (used as the signature key) is
        supported.
      experimental_debug_info_func: An experimental function to retrieve the
        graph debug info for a set of nodes from the `graph_def`.

    Raises:
      ValueError: Invalid arguments, e.g. `saved_model_exported_names` does
        not contain exactly one signature key.
    """
    super(TFLiteSavedModelConverter,
          self).__init__(experimental_debug_info_func)
    self.saved_model_dir = saved_model_dir
    self._saved_model_tags = saved_model_tags
    self._saved_model_exported_names = saved_model_exported_names

    # Only a single signature is supported. The `not` check also rejects None
    # so callers get the documented ValueError rather than a TypeError from
    # len().
    if (not self._saved_model_exported_names or
        len(self._saved_model_exported_names) != 1):
      raise ValueError("Only support a single signature key.")
    signature_key = self._saved_model_exported_names[0]

    graph_def, input_tensors, output_tensors, _ = _freeze_saved_model(
        self.saved_model_dir, None, None, None, self._saved_model_tags,
        signature_key)
    self._graph_def = graph_def
    self._input_tensors = input_tensors
    self._output_tensors = output_tensors
    self._parse_saved_model_args()

  @_export_metrics
  def convert(self):
    """Converts a TensorFlow GraphDef based on instance variables.

    Returns:
      The converted data in serialized format. Either a TFLite Flatbuffer or a
      Graphviz graph depending on value in `output_format`.

    Raises:
      ValueError:
        Input shape is not specified.
        None value for dimension in input_tensor.
    """
    return super(TFLiteSavedModelConverter, self).convert()
class TFLiteKerasModelConverter(TFLiteConverterBaseV1):
  """Converts the given tf.keras model file into TensorFlow Lite model."""

  def __init__(self,
               model_file,
               input_arrays=None,
               input_shapes=None,
               output_arrays=None,
               custom_objects=None):
    """Constructor for TFLiteConverter.

    Args:
      model_file: Full filepath of HDF5 file containing the tf.keras model.
      input_arrays: List of input tensors to freeze graph with. Uses input
        arrays from SignatureDef when none are provided. (default None)
      input_shapes: Dict of strings representing input tensor names to list of
        integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).
        Automatically determined when input shapes is None (e.g., {"foo" :
        None}). (default None)
      output_arrays: List of output tensors to freeze graph with. Uses output
        arrays from SignatureDef when none are provided. (default None)
      custom_objects: Dict mapping names (strings) to custom classes or
        functions to be considered during model deserialization. (default None)

    Raises:
      ValueError: Invalid arguments, e.g. `input_arrays` or `output_arrays`
        are given while eager execution is enabled.
    """
    super(TFLiteKerasModelConverter,
          self).__init__(experimental_debug_info_func=None)
    # Handles Keras when Eager mode is enabled.
    if context.executing_eagerly():
      if input_arrays or output_arrays:
        raise ValueError("`input_arrays` and `output_arrays` are unsupported "
                         "with Eager mode. If your model requires any of these "
                         "parameters, please use disable_eager_execution().")

      keras_model = keras_deps.get_load_model_function()(model_file,
                                                         custom_objects)
      # Trace the model call and freeze its variables into constants.
      function = _trace_model_call(keras_model)
      concrete_func = function.get_concrete_function()

      frozen_func = _convert_to_constants.convert_variables_to_constants_v2(
          concrete_func, lower_control_flow=False)
      _set_tensor_shapes(frozen_func.inputs, input_shapes)
      self._keras_model = keras_model
      self._graph_def = frozen_func.graph.as_graph_def()
      self._input_tensors = frozen_func.inputs
      self._output_tensors = frozen_func.outputs
      self._debug_info_func = _build_debug_info_func(frozen_func.graph)
      return

    # Handles Keras when Eager mode is disabled.
    keras_deps.get_clear_session_function()()
    keras_model = keras_deps.get_load_model_function()(model_file,
                                                       custom_objects)
    sess = keras_deps.get_get_session_function()()

    # Get input and output tensors.
    if input_arrays:
      input_tensors = _get_tensors_from_tensor_names(sess.graph, input_arrays)
    else:
      input_tensors = keras_model.inputs

    if output_arrays:
      output_tensors = _get_tensors_from_tensor_names(sess.graph, output_arrays)
    else:
      output_tensors = keras_model.outputs
    _set_tensor_shapes(input_tensors, input_shapes)

    graph_def = _freeze_graph(sess, input_tensors, output_tensors)
    self._keras_model = keras_model
    self._graph_def = graph_def
    self._input_tensors = input_tensors
    self._output_tensors = output_tensors
    self._debug_info_func = _build_debug_info_func(sess.graph)

  @convert_phase(Component.PREPARE_TF_MODEL, SubComponent.FREEZE_KERAS_MODEL)
  def _freeze_keras_model(self, output_dir):
    """Save Keras model to Saved Model format and reload it frozen.

    On a successful save, the SavedModel at `output_dir` is frozen with the
    default serving tag and signature, and the saved-model attributes are set.
    If `saved_model_dir` is still set after `_parse_saved_model_args()`, the
    converter's graph def, tensors and debug info are replaced with the ones
    from the frozen SavedModel.

    Args:
      output_dir: The output directory to save the SavedModel.

    Returns:
      None.
    """
    try:
      self._keras_model.save(output_dir, save_format="tf")
    except Exception:  # pylint: disable=broad-except
      # When storing the given keras model to a saved model is failed, let's
      # use original keras model conversion pipeline.
      return None
    tag_set = {_tag_constants.SERVING}
    signature_key = _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
    graph_def, input_tensors, output_tensors, sess_graph = _freeze_saved_model(
        output_dir, None, None, None, tag_set, signature_key)

    self.saved_model_dir = output_dir
    self._saved_model_tags = tag_set
    self._saved_model_exported_names = [signature_key]
    self._parse_saved_model_args()
    # `_parse_saved_model_args` may clear `saved_model_dir`; only adopt the
    # frozen SavedModel graph when the saved-model path is still enabled.
    if self.saved_model_dir:
      self._graph_def = graph_def
      self._input_tensors = input_tensors
      self._output_tensors = output_tensors
      self._debug_info_func = _build_debug_info_func(sess_graph)

  def _convert_as_saved_model(self):
    """Converts a Keras model as a saved model.

    The model is saved into a temporary directory that is always removed
    afterwards.

    Returns:
      The converted data in serialized format, or None when the saved-model
      path is not used (save failed or `saved_model_dir` was cleared).
    """
    temp_dir = tempfile.mkdtemp()
    try:
      self._freeze_keras_model(temp_dir)
      if self.saved_model_dir:
        return super(TFLiteKerasModelConverter, self).convert()
    finally:
      shutil.rmtree(temp_dir, True)
    # Signals `convert()` to fall back to the original Keras pipeline.
    return None

  @_export_metrics
  def convert(self):
    """Converts a Keras model based on instance variables.

    The SavedModel-based path is tried first; if it yields no result, the
    graph captured in the constructor is converted instead.

    Returns:
      The converted data in serialized format. Either a TFLite Flatbuffer or a
      Graphviz graph depending on value in `output_format`.

    Raises:
      ValueError:
        Input shape is not specified.
        None value for dimension in input_tensor.
    """
    saved_model_convert_result = self._convert_as_saved_model()
    if saved_model_convert_result:
      return saved_model_convert_result

    return super(TFLiteKerasModelConverter, self).convert()
class TFLiteFrozenGraphConverter(TFLiteConverterBaseV1):
  """Converts the given frozen graph def into TensorFlow Lite model."""

  def __init__(self,
               graph_def,
               input_tensors,
               output_tensors,
               input_arrays_with_shape=None,
               output_arrays=None,
               experimental_debug_info_func=None):
    """Constructor for TFLiteConverter.

    Args:
      graph_def: Frozen TensorFlow GraphDef.
      input_tensors: List of input tensors. Type and shape are computed using
        `foo.shape` and `foo.dtype`.
      output_tensors: List of output tensors (only .name is used from this).
      input_arrays_with_shape: Tuple of strings representing input tensor names
        and list of integers representing input shapes
        (e.g., [("foo", [1, 16, 16, 3])]). Use only when graph cannot be loaded
          into TensorFlow and when `input_tensors` and `output_tensors` are
          None. (default None)
      output_arrays: List of output tensors to freeze graph with. Use only when
        graph cannot be loaded into TensorFlow and when `input_tensors` and
        `output_tensors` are None. (default None)
      experimental_debug_info_func: An experimental function to retrieve the
        graph debug info for a set of nodes from the `graph_def`.

    Raises:
      ValueError: Invalid arguments.
    """
    super(TFLiteFrozenGraphConverter,
          self).__init__(experimental_debug_info_func)
    self._graph_def = graph_def
    self._input_tensors = input_tensors
    self._output_tensors = output_tensors
    self._control_output_arrays = None

    if not self._has_valid_tensors():
      # Attributes are used by models that cannot be loaded into TensorFlow.
      self._input_arrays_with_shape = input_arrays_with_shape
      self._output_arrays = output_arrays
    else:
      # The array-based arguments are dropped only when the tensors are
      # usable. Warn in exactly that case; when the tensors are incomplete the
      # arrays above are kept and used, so warning would be misleading.
      if input_arrays_with_shape is not None:
        logging.warning("input_arrays_with_shape will be ignored when both the "
                        "given input_tensors and input_arrays_with_shape are not "
                        "None.")

      if output_arrays is not None:
        logging.warning("output_arrays will be ignored when both the given "
                        "output_tensors and output_arrays are not None.")

  @_export_metrics
  def convert(self):
    """Converts a TensorFlow GraphDef based on instance variables.

    Returns:
      The converted data in serialized format. Either a TFLite Flatbuffer or a
      Graphviz graph depending on value in `output_format`.

    Raises:
      ValueError:
        Input shape is not specified.
        None value for dimension in input_tensor.
    """
    if not self._has_valid_tensors():
      if not self._input_arrays_with_shape or not (self._output_arrays or
                                                   self._control_output_arrays):
        raise ValueError(
            "If input_tensors and output_tensors are None, both "
            "input_arrays_with_shape and output_arrays|control_output_arrays "
            "must be defined.")
    return super(TFLiteFrozenGraphConverter, self).convert()
@_tf_export(v1=["lite.TFLiteConverter"])
class TFLiteConverter(TFLiteFrozenGraphConverter):
  """Convert a TensorFlow model into `output_format`.

  This is used to convert from a TensorFlow GraphDef, SavedModel or tf.keras
  model into either a TFLite FlatBuffer or graph visualization.

  Attributes:
    optimizations: Experimental flag, subject to change. Set of optimizations to
      apply. e.g {tf.lite.Optimize.DEFAULT}. (default None, must be None or a
      set of values of type `tf.lite.Optimize`)
    representative_dataset: A generator function used for integer quantization
      where each generated sample has the same order, type and shape as the
      inputs to the model. Usually, this is a small subset of a few hundred
      samples randomly chosen, in no particular order, from the training or
      evaluation dataset. This is an optional attribute, but required for full
      integer quantization, i.e, if `tf.int8` is the only supported type in
      `target_spec.supported_types`. Refer to `tf.lite.RepresentativeDataset`.
      (default None)
    target_spec: Experimental flag, subject to change. Specifications of target
      device, including supported ops set, supported types and a set of user's
      defined TensorFlow operators required in the TensorFlow Lite runtime.
      Refer to `tf.lite.TargetSpec`.
    inference_type: Data type of numeric arrays, excluding the input layer.
      (default tf.float32, must be in {tf.float32, tf.int8, tf.uint8})
    inference_input_type: Data type of the numeric arrays in the input layer. If
      `inference_input_type` is in {tf.int8, tf.uint8}, then
      `quantized_input_stats` must be provided. (default is the value assigned
      to `inference_type`, must be in {tf.float32, tf.int8, tf.uint8})
    inference_output_type: Data type of the numeric arrays in the output layer.
      (default is the value assigned to `inference_type`, must be in
      {tf.float32, tf.int8, tf.uint8})
    quantized_input_stats: Map of input tensor names to a tuple of floats
      representing the mean and standard deviation of the training data.
      (e.g., {"foo" : (0., 1.)}). Required if `inference_input_type` is tf.int8
      or tf.uint8. (default None)
    default_ranges_stats: Tuple of integers (min, max) representing range values
      for all numeric arrays without a specified range. Intended for
      experimenting with quantization via "dummy quantization". (default None)
    allow_custom_ops: Boolean indicating whether to allow custom operations.
      When False any unknown operation is an error. When True, custom ops are
      created for any op that is unknown. The developer will need to provide
      these to the TensorFlow Lite runtime with a custom resolver. (default
      False)
    drop_control_dependency: Boolean indicating whether to drop control
      dependencies silently. This is due to TFLite not supporting control
      dependencies. (default True)
    reorder_across_fake_quant: Boolean indicating whether to reorder FakeQuant
      nodes in unexpected locations. Used when the location of the FakeQuant
      nodes is preventing graph transformations necessary to convert the graph.
      Results in a graph that differs from the quantized training graph,
      potentially causing differing arithmetic behavior. (default False)
    change_concat_input_ranges: Boolean to change behavior of min/max ranges for
      inputs and outputs of the concat operator for quantized models. Changes
      the ranges of concat operator overlap when true. (default False)
    output_format: Output file format. (default
      tf.compat.v1.lite.constants.TFLITE, must be in
      {tf.compat.v1.lite.constants.TFLITE,
      tf.compat.v1.lite.constants.GRAPHVIZ_DOT})
    dump_graphviz_dir: Full filepath of folder to dump the graphs at various
      stages of processing GraphViz .dot files. Preferred over
      `output_format=tf.compat.v1.lite.constants.GRAPHVIZ_DOT` in order to keep
      the requirements of the output file. (default None)
    dump_graphviz_video: Boolean indicating whether to dump the GraphViz .dot
      files after every graph transformation. Requires the `dump_graphviz_dir`
      flag to be specified. (default False)
    conversion_summary_dir: Full path of the directory to store conversion logs.
      (default None)
    target_ops: Deprecated. Please use `target_spec.supported_ops` instead.
    post_training_quantize: Deprecated. Please use `optimizations` instead and
      set it to `{tf.lite.Optimize.DEFAULT}`. (default False)
    experimental_new_converter: Experimental flag, subject to change. Enables
      MLIR-based conversion instead of TOCO conversion. (default True)
    experimental_new_quantizer: Experimental flag, subject to change. Enables
      MLIR-based quantization conversion instead of Flatbuffer-based conversion.
      (default True)

  Example usage:

    ```python
    # Converting a GraphDef from session.
    converter = tf.compat.v1.lite.TFLiteConverter.from_session(
      sess, in_tensors, out_tensors)
    tflite_model = converter.convert()
    open("converted_model.tflite", "wb").write(tflite_model)

    # Converting a GraphDef from file.
    converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
      graph_def_file, input_arrays, output_arrays)
    tflite_model = converter.convert()
    open("converted_model.tflite", "wb").write(tflite_model)

    # Converting a SavedModel.
    converter = tf.compat.v1.lite.TFLiteConverter.from_saved_model(
        saved_model_dir)
    tflite_model = converter.convert()
    open("converted_model.tflite", "wb").write(tflite_model)

    # Converting a tf.keras model file.
    converter = tf.compat.v1.lite.TFLiteConverter.from_keras_model_file(
        keras_model_file)
    tflite_model = converter.convert()
    open("converted_model.tflite", "wb").write(tflite_model)
    ```
  """

  # pylint: disable=useless-super-delegation
  def __init__(self,
               graph_def,
               input_tensors,
               output_tensors,
               input_arrays_with_shape=None,
               output_arrays=None,
               experimental_debug_info_func=None):
    """Constructor for TFLiteConverter.

    Args:
      graph_def: Frozen TensorFlow GraphDef.
      input_tensors: List of input tensors. Type and shape are computed using
        `foo.shape` and `foo.dtype`.
      output_tensors: List of output tensors (only .name is used from this).
      input_arrays_with_shape: Tuple of strings representing input tensor names
        and list of integers representing input shapes
        (e.g., [("foo", [1, 16, 16, 3])]). Use only when graph cannot be loaded
          into TensorFlow and when `input_tensors` and `output_tensors` are
          None. (default None)
      output_arrays: List of output tensors to freeze graph with. Use only when
        graph cannot be loaded into TensorFlow and when `input_tensors` and
        `output_tensors` are None. (default None)
      experimental_debug_info_func: An experimental function to retrieve the
        graph debug info for a set of nodes from the `graph_def`.

    Raises:
      ValueError: Invalid arguments.
    """
    super(TFLiteConverter,
          self).__init__(graph_def, input_tensors, output_tensors,
                         input_arrays_with_shape, output_arrays,
                         experimental_debug_info_func)

  @classmethod
  def from_session(cls, sess, input_tensors, output_tensors):
    """Creates a TFLiteConverter class from a TensorFlow Session.

    Args:
      sess: TensorFlow Session.
      input_tensors: List of input tensors. Type and shape are computed using
        `foo.shape` and `foo.dtype`.
      output_tensors: List of output tensors (only .name is used from this).

    Returns:
      TFLiteConverter class.
    """
    graph_def = _freeze_graph(sess, input_tensors, output_tensors)
    return cls(
        graph_def,
        input_tensors,
        output_tensors,
        experimental_debug_info_func=_build_debug_info_func(sess.graph))

  @classmethod
  def from_frozen_graph(cls,
                        graph_def_file,
                        input_arrays,
                        output_arrays,
                        input_shapes=None):
    """Creates a TFLiteConverter class from a file containing a frozen GraphDef.

    Args:
      graph_def_file: Full filepath of file containing frozen GraphDef.
      input_arrays: List of input tensors to freeze graph with.
      output_arrays: List of output tensors to freeze graph with.
      input_shapes: Dict of strings representing input tensor names to list of
        integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).
        Automatically determined when input shapes is None (e.g., {"foo" :
        None}). (default None)

    Returns:
      TFLiteConverter class.

    Raises:
      IOError:
        File not found.
        Unable to parse input file.
      ValueError:
        The graph is not frozen.
        input_arrays or output_arrays contains an invalid tensor name.
        input_shapes is not correctly defined when required
    """
    with _ops.Graph().as_default():
      with _session.Session() as sess:
        # Read GraphDef from file.
        if not gfile.Exists(graph_def_file):
          raise IOError("File '{0}' does not exist.".format(graph_def_file))
        with gfile.GFile(graph_def_file, "rb") as f:
          file_content = f.read()

        # Try the binary protobuf format first, then fall back to text format.
        try:
          graph_def = _graph_pb2.GraphDef()
          graph_def.ParseFromString(file_content)
        except (_text_format.ParseError, DecodeError):
          try:
            print("Ignore 'tcmalloc: large alloc' warnings.")

            if not isinstance(file_content, str):
              if PY2:
                file_content = six.ensure_binary(file_content, "utf-8")
              else:
                file_content = six.ensure_text(file_content, "utf-8")
            graph_def = _graph_pb2.GraphDef()
            _text_format.Merge(file_content, graph_def)
          except (_text_format.ParseError, DecodeError, UnicodeError):
            # UnicodeError covers content that is neither a valid binary
            # GraphDef nor UTF-8 text; report it as the documented IOError.
            raise IOError(
                "Unable to parse input file '{}'.".format(graph_def_file))

        # Handles models with custom TFLite ops that cannot be resolved in
        # TensorFlow.
        load_model_in_session = True
        try:
          _import_graph_def(graph_def, name="")
        except _NotFoundError:
          load_model_in_session = False

        if load_model_in_session:
          # Check if graph is frozen.
          if not _is_frozen_graph(sess):
            raise ValueError("Please freeze the graph using freeze_graph.py.")

          # Get input and output tensors.
          input_tensors = _get_tensors_from_tensor_names(
              sess.graph, input_arrays)
          output_tensors = _get_tensors_from_tensor_names(
              sess.graph, output_arrays)
          _set_tensor_shapes(input_tensors, input_shapes)

          return cls(sess.graph_def, input_tensors, output_tensors)
        else:
          if not input_shapes:
            raise ValueError("input_shapes must be defined for this model.")
          if set(input_arrays) != set(input_shapes.keys()):
            raise ValueError("input_shapes must contain a value for each item "
                             "in input_array.")

          input_arrays_with_shape = [
              (name, input_shapes[name]) for name in input_arrays
          ]
          return cls(
              graph_def,
              input_tensors=None,
              output_tensors=None,
              input_arrays_with_shape=input_arrays_with_shape,
              output_arrays=output_arrays)

  @classmethod
  def from_saved_model(cls,
                       saved_model_dir,
                       input_arrays=None,
                       input_shapes=None,
                       output_arrays=None,
                       tag_set=None,
                       signature_key=None):
    """Creates a TFLiteConverter class from a SavedModel.

    Args:
      saved_model_dir: SavedModel directory to convert.
      input_arrays: List of input tensors to freeze graph with. Uses input
        arrays from SignatureDef when none are provided. (default None)
      input_shapes: Dict of strings representing input tensor names to list of
        integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).
        Automatically determined when input shapes is None (e.g., {"foo" :
        None}). (default None)
      output_arrays: List of output tensors to freeze graph with. Uses output
        arrays from SignatureDef when none are provided. (default None)
      tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to
        analyze. All tags in the tag set must be present. (default
        {tf.saved_model.SERVING})
      signature_key: Key identifying SignatureDef containing inputs and outputs.
        (default tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY)

    Returns:
      TFLiteConverter class.
    """
    if tag_set is None:
      tag_set = set([_tag_constants.SERVING])
    if signature_key is None:
      signature_key = _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY

    # Prefer the SavedModel import path when it stays enabled after the
    # converter parses the SavedModel arguments.
    saved_model_converter = TFLiteSavedModelConverter(saved_model_dir, tag_set,
                                                      [signature_key])
    if saved_model_converter.saved_model_dir:
      return saved_model_converter

    result = _freeze_saved_model(saved_model_dir, input_arrays, input_shapes,
                                 output_arrays, tag_set, signature_key)
    return cls(
        graph_def=result[0],
        input_tensors=result[1],
        output_tensors=result[2],
        experimental_debug_info_func=_build_debug_info_func(result[3]))

  @classmethod
  def from_keras_model_file(cls,
                            model_file,
                            input_arrays=None,
                            input_shapes=None,
                            output_arrays=None,
                            custom_objects=None):
    """Creates a TFLiteConverter class from a tf.keras model file.

    Args:
      model_file: Full filepath of HDF5 file containing the tf.keras model.
      input_arrays: List of input tensors to freeze graph with. Uses input
        arrays from SignatureDef when none are provided. (default None)
      input_shapes: Dict of strings representing input tensor names to list of
        integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).
        Automatically determined when input shapes is None (e.g., {"foo" :
        None}). (default None)
      output_arrays: List of output tensors to freeze graph with. Uses output
        arrays from SignatureDef when none are provided. (default None)
      custom_objects: Dict mapping names (strings) to custom classes or
        functions to be considered during model deserialization. (default None)

    Returns:
      TFLiteConverter class.
    """
    return TFLiteKerasModelConverter(model_file, input_arrays, input_shapes,
                                     output_arrays, custom_objects)

  # pylint: disable=useless-super-delegation
  def convert(self):
    """Converts a TensorFlow GraphDef based on instance variables.

    Returns:
      The converted data in serialized format. Either a TFLite Flatbuffer or a
      Graphviz graph depending on value in `output_format`.

    Raises:
      ValueError:
        Input shape is not specified.
        None value for dimension in input_tensor.
    """
    return super(TFLiteConverter, self).convert()
@_tf_export(v1=["lite.TocoConverter"])
class TocoConverter(object):
  """Convert a TensorFlow model into `output_format` using TOCO.

  This class has been deprecated. Please use `lite.TFLiteConverter` instead.
  Every factory method simply forwards to the matching `TFLiteConverter`
  factory method.
  """

  @classmethod
  @_deprecation.deprecated(None,
                           "Use `lite.TFLiteConverter.from_session` instead.")
  def from_session(cls, sess, input_tensors, output_tensors):
    """Creates a TocoConverter class from a TensorFlow Session."""
    return TFLiteConverter.from_session(sess, input_tensors, output_tensors)

  @classmethod
  @_deprecation.deprecated(
      None, "Use `lite.TFLiteConverter.from_frozen_graph` instead.")
  def from_frozen_graph(cls,
                        graph_def_file,
                        input_arrays,
                        output_arrays,
                        input_shapes=None):
    """Creates a TocoConverter class from a file containing a frozen graph."""
    return TFLiteConverter.from_frozen_graph(graph_def_file, input_arrays,
                                             output_arrays, input_shapes)

  @classmethod
  @_deprecation.deprecated(
      None, "Use `lite.TFLiteConverter.from_saved_model` instead.")
  def from_saved_model(cls,
                       saved_model_dir,
                       input_arrays=None,
                       input_shapes=None,
                       output_arrays=None,
                       tag_set=None,
                       signature_key=None):
    """Creates a TocoConverter class from a SavedModel."""
    return TFLiteConverter.from_saved_model(saved_model_dir, input_arrays,
                                            input_shapes, output_arrays,
                                            tag_set, signature_key)

  @classmethod
  @_deprecation.deprecated(
      None, "Use `lite.TFLiteConverter.from_keras_model_file` instead.")
  def from_keras_model_file(cls,
                            model_file,
                            input_arrays=None,
                            input_shapes=None,
                            output_arrays=None,
                            custom_objects=None):
    """Creates a TocoConverter class from a tf.keras model file.

    Args:
      model_file: Full filepath of HDF5 file containing the tf.keras model.
      input_arrays: List of input tensors to freeze graph with. (default None)
      input_shapes: Dict of input tensor names to input shapes. (default None)
      output_arrays: List of output tensors to freeze graph with.
        (default None)
      custom_objects: Dict mapping names (strings) to custom classes or
        functions to be considered during model deserialization; forwarded to
        `TFLiteConverter.from_keras_model_file`. (default None)

    Returns:
      The converter returned by `TFLiteConverter.from_keras_model_file`.
    """
    return TFLiteConverter.from_keras_model_file(model_file, input_arrays,
                                                 input_shapes, output_arrays,
                                                 custom_objects)