blob: 79d2775d1dc81d0b1ef4fbbd416603637dc6172e [file] [log] [blame]
# Lint as: python2, python3
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions used by multiple converter files."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import datetime
import sys
from absl import logging
import six
from six.moves import range
import flatbuffers
from tensorflow.core.protobuf import config_pb2 as _config_pb2
from tensorflow.core.protobuf import graph_debug_info_pb2
from tensorflow.core.protobuf import meta_graph_pb2 as _meta_graph_pb2
from tensorflow.lite.python import lite_constants as _lite_constants
from tensorflow.lite.python import schema_py_generated as schema_fb
from tensorflow.lite.python.op_hint import convert_op_hints_to_stubs
from tensorflow.lite.python.op_hint import find_all_hinted_output_nodes
from tensorflow.lite.toco import types_pb2 as _types_pb2
from tensorflow.python.eager import function
from tensorflow.python.framework import convert_to_constants as _convert_to_constants
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import error_interpolation as _error_interpolation
from tensorflow.python.framework import graph_util as tf_graph_util
from tensorflow.python.grappler import tf_optimizer
from tensorflow.python.training.saver import export_meta_graph as _export_meta_graph
# Map of tf.dtypes to TFLite types_flag_pb2.
# Used by `convert_dtype_to_tflite_type`; dtypes absent from this table are
# rejected there with a ValueError.
_MAP_TF_TO_TFLITE_TYPES = {
    dtypes.float32: _types_pb2.FLOAT,
    dtypes.float16: _types_pb2.FLOAT16,
    dtypes.int32: _types_pb2.INT32,
    dtypes.uint8: _types_pb2.QUANTIZED_UINT8,
    dtypes.int64: _types_pb2.INT64,
    dtypes.string: _types_pb2.STRING,
    dtypes.bool: _types_pb2.BOOL,
    dtypes.int16: _types_pb2.QUANTIZED_INT16,
    dtypes.complex64: _types_pb2.COMPLEX64,
    dtypes.int8: _types_pb2.INT8,
    dtypes.float64: _types_pb2.FLOAT64,
    dtypes.complex128: _types_pb2.COMPLEX128,
}
# Map of TFLite flatbuffer tensor-type enum values (as read from a tensor's
# `type` field) back to tf.dtypes. Used by
# `_convert_tflite_enum_type_to_tf_type` when building error messages.
# NOTE(review): the integer keys are hard-coded and presumably mirror the
# numbering of `schema_fb.TensorType`; confirm they stay in sync with the
# schema if new tensor types are added.
_MAP_TFLITE_ENUM_TO_TF_TYPES = {
    0: dtypes.float32,
    1: dtypes.float16,
    2: dtypes.int32,
    3: dtypes.uint8,
    4: dtypes.int64,
    5: dtypes.string,
    6: dtypes.bool,
    7: dtypes.int16,
    8: dtypes.complex64,
    9: dtypes.int8,
    10: dtypes.float64,
    11: dtypes.complex128,
}
# File identifier passed to `Builder.Finish` when re-serializing a model in
# `_convert_model_from_object_to_bytearray`.
_TFLITE_FILE_IDENTIFIER = b"TFL3"
# Types accepted for `inference_input_type` / `inference_output_type` by
# `modify_integer_quantized_model_io_type`.
_TFLITE_MODEL_INPUT_OUTPUT_TYPES = (_lite_constants.FLOAT, _lite_constants.INT8,
                                    _lite_constants.QUANTIZED_UINT8)
def convert_dtype_to_tflite_type(tf_dtype):
  """Looks up the TFLite `types_pb2` enum value for a TensorFlow dtype.

  Args:
    tf_dtype: tf.dtype to translate.

  Returns:
    The matching types_flag_pb2 enum value.

  Raises:
    ValueError: `tf_dtype` has no TFLite equivalent in this module's table.
  """
  tflite_type = _MAP_TF_TO_TFLITE_TYPES.get(tf_dtype)
  if tflite_type is not None:
    return tflite_type
  raise ValueError("Unsupported tf.dtype {0}".format(tf_dtype))
def _convert_tflite_enum_type_to_tf_type(tflite_enum_type):
  """Maps a TFLite tensor-type enum value (e.g. 0) to a tf.dtype (e.g. tf.float32).

  Args:
    tflite_enum_type: Integer enum value taken from a TFLite tensor's `type`
      (e.g. 0, which stands for float32).

  Returns:
    The corresponding tf.dtype.

  Raises:
    ValueError: `tflite_enum_type` is not a known enum value; the message lists
      the full supported mapping.
  """
  known_types = _MAP_TFLITE_ENUM_TO_TF_TYPES
  if tflite_enum_type not in known_types:
    raise ValueError(
        "Unsupported enum {}. The valid map of enum to tf types is : {}"
        .format(tflite_enum_type, known_types))
  return known_types[tflite_enum_type]
def _get_dtype_name(tf_type):
"""Converts tf.dtype (eg: tf.float32) to str (eg: "tf.float32")."""
return "tf." + tf_type.name
def get_tensor_name(tensor):
  """Returns the canonical name of `tensor`.

  TensorFlow names tensors "<op_name>:<output_index>". To stay consistent with
  TensorFlow's own naming scheme, the ":0" suffix of an op's first output is
  dropped; any other output index keeps its full name.

  Args:
    tensor: tf.Tensor

  Returns:
    str

  Raises:
    ValueError: The tensor name contains more than one colon.
  """
  name_parts = six.ensure_str(tensor.name).split(":")
  colon_count = len(name_parts) - 1
  if colon_count > 1:
    raise ValueError("Tensor name invalid. Expect 0 or 1 colon, got {0}".format(
        colon_count))
  keeps_output_index = colon_count == 1 and name_parts[1] != "0"
  return tensor.name if keeps_output_index else name_parts[0]
def get_tensors_from_tensor_names(graph, tensor_names):
  """Looks up tensors in `graph` by their canonical names.

  Names are matched against `get_tensor_name(tensor)` for every output of
  every operation in the graph, so a first output is addressed as "op" rather
  than "op:0".

  Args:
    graph: TensorFlow Graph.
    tensor_names: List of strings naming tensors in the graph.

  Returns:
    The matching Tensor objects, in the same order as `tensor_names`.

  Raises:
    ValueError:
      An entry of `tensor_names` is not a string, or one or more names do not
      match any tensor in the graph (all unknown names are reported together).
  """
  name_to_tensor = {
      get_tensor_name(tensor): tensor
      for op in graph.get_operations()
      for tensor in op.values()
  }

  matched = []
  unknown = []
  for name in tensor_names:
    if not isinstance(name, six.string_types):
      raise ValueError("Invalid type for a tensor name in the provided graph. "
                       "Expected type for a tensor name is 'str', instead got "
                       "type '{}' for tensor name '{}'".format(
                           type(name), name))
    tensor = name_to_tensor.get(name)
    if tensor is not None:
      matched.append(tensor)
    else:
      unknown.append(name)

  # Report every unknown name at once instead of failing on the first one.
  if unknown:
    raise ValueError("Invalid tensors '{}' were found.".format(
        ",".join(unknown)))
  return matched
def set_tensor_shapes(tensors, shapes):
  """Applies user-supplied static shapes to the matching tensors.

  Args:
    tensors: TensorFlow ops.Tensor objects whose shapes may be set.
    shapes: Dict mapping tensor names (as returned by `get_tensor_name`) to
      lists of integers giving their shapes (e.g., {"foo": [1, 16, 16, 3]}).
      Entries whose shape is None are skipped; a falsy `shapes` is a no-op.

  Raises:
    ValueError:
      `shapes` contains an invalid tensor.
      `shapes` contains an invalid shape for a valid tensor.
  """
  if not shapes:
    return

  by_name = {get_tensor_name(tensor): tensor for tensor in tensors}
  for name, shape in shapes.items():
    if name not in by_name:
      raise ValueError("Invalid tensor \'{}\' found in tensor shapes "
                       "map.".format(name))
    if shape is None:
      continue
    target = by_name[name]
    try:
      target.set_shape(shape)
    except ValueError as error:
      # Re-raise with the tensor name and both shapes for a clearer message.
      raise ValueError(
          "The shape of tensor '{0}' cannot be changed from {1} to "
          "{2}. {3}".format(name, target.shape, shape, str(error)))
def get_grappler_config(optimizers_list):
  """Builds a tf.compat.v1.ConfigProto that runs the given Grappler optimizers.

  Args:
    optimizers_list: Iterable of Grappler optimizer names (e.g. "function"),
      appended in order to the rewriter's `optimizers` list.

  Returns:
    tf.ConfigProto.
  """
  config = _config_pb2.ConfigProto()
  optimizer_names = config.graph_options.rewrite_options.optimizers
  for name in optimizers_list:
    optimizer_names.append(name)
  return config
def run_graph_optimizations(graph_def,
                            input_arrays,
                            output_arrays,
                            config,
                            graph=None):
  """Runs Grappler, configured by `config`, over `graph_def`.

  The graph is wrapped in a MetaGraphDef whose signature lists the given input
  and output tensors, and all of them are also registered as fetch nodes.

  Args:
    graph_def: Frozen GraphDef to be optimized.
    input_arrays: List of tensors that are considered inputs of the graph.
    output_arrays: List of tensors that are considered outputs of the graph.
    config: tf.ConfigProto.
    graph: TensorFlow Graph. Required when Eager mode is enabled. (default None)

  Returns:
    A new, optimized GraphDef.
  """
  meta_graph = _export_meta_graph(graph_def=graph_def, graph=graph)

  signature = _meta_graph_pb2.SignatureDef()
  for tensor_infos, arrays in ((signature.inputs, input_arrays),
                               (signature.outputs, output_arrays)):
    for array in arrays:
      tensor_info = tensor_infos[array.name]
      tensor_info.name = array.name
      tensor_info.dtype = array.dtype.as_datatype_enum
      tensor_info.tensor_shape.CopyFrom(array.shape.as_proto())
  meta_graph.signature_def["not_used_key"].CopyFrom(signature)

  # Grappler learns what the outputs are from a collection named 'train_op',
  # so every input and output tensor is listed there.
  fetch_nodes = _meta_graph_pb2.CollectionDef()
  fetch_names = fetch_nodes.node_list.value
  for array in input_arrays + output_arrays:
    fetch_names.append(array.name)
  meta_graph.collection_def["train_op"].CopyFrom(fetch_nodes)

  return tf_optimizer.OptimizeGraph(config, meta_graph)
def _convert_op_hints_if_present(sess, graph_def, output_tensors,
                                 hinted_outputs_nodes):
  """Freezes `graph_def` and then converts its OpHints to stubs.

  Args:
    sess: TensorFlow Session whose graph must still contain variables.
    graph_def: GraphDef to freeze and convert.
    output_tensors: Output tensors; their names are kept while freezing.
    hinted_outputs_nodes: Hinted output nodes (as found by
      `find_all_hinted_output_nodes`), also kept while freezing.

  Returns:
    The frozen GraphDef with OpHints converted by `convert_op_hints_to_stubs`.

  Raises:
    ValueError: The session's graph is already frozen.
  """
  if is_frozen_graph(sess):
    raise ValueError("Try to convert op hints, needs unfrozen graph.")
  names_to_keep = ([get_tensor_name(tensor) for tensor in output_tensors] +
                   hinted_outputs_nodes)
  frozen_graph_def = tf_graph_util.convert_variables_to_constants(
      sess, graph_def, names_to_keep)
  return convert_op_hints_to_stubs(graph_def=frozen_graph_def)
def freeze_graph(sess, input_tensors, output_tensors):
  """Returns a frozen GraphDef.

  A Grappler "function" pass is always run first to inline functions. Its
  result is then used as follows:
    * If OpHints are present, the optimized graph is frozen and the OpHints
      are converted to stubs.
    * Otherwise, if the session graph still contains variables, the optimized
      graph is frozen by folding the variables into constants.
    * Otherwise the graph is already frozen and `sess.graph_def` is returned
      as-is (the Grappler output is not used).

  Args:
    sess: TensorFlow Session.
    input_tensors: List of input tensors.
    output_tensors: List of output tensors (only .name is used from this).

  Returns:
    Frozen GraphDef.
  """
  # Besides inlining functions, Grappler would also lower while loops into a
  # Switch/Merge representation, which is undesired for OpHints; stripping the
  # lowering attributes first prevents it.
  lowering_disabled = _convert_to_constants.disable_lower_using_switch_merge(
      sess.graph_def)
  inline_config = get_grappler_config(["function"])
  optimized_graph_def = run_graph_optimizations(
      lowering_disabled, input_tensors, output_tensors, inline_config,
      graph=sess.graph)

  hinted_outputs_nodes = find_all_hinted_output_nodes(sess)
  if hinted_outputs_nodes:
    return _convert_op_hints_if_present(sess, optimized_graph_def,
                                        output_tensors, hinted_outputs_nodes)

  if is_frozen_graph(sess):
    return sess.graph_def

  # convert_variables_to_constants expects node names, so strip ":<index>".
  output_node_names = [
      tensor.name.split(":")[0] for tensor in output_tensors
  ]
  return tf_graph_util.convert_variables_to_constants(
      sess, optimized_graph_def, output_node_names)
def is_frozen_graph(sess):
  """Reports whether the session's graph contains no variable operations.

  An operation counts as a variable when its type starts with "Variable" or
  ends with "VariableOp"; the graph is frozen when no operation matches.

  Args:
    sess: TensorFlow Session.

  Returns:
    Bool.
  """
  def _is_variable_op(op):
    op_type = six.ensure_str(op.type)
    return op_type.startswith("Variable") or op_type.endswith("VariableOp")

  return not any(_is_variable_op(op) for op in sess.graph.get_operations())
def build_debug_info_func(original_graph):
  """Returns a method to retrieve the `GraphDebugInfo` from the original graph.

  Args:
    original_graph: The original `Graph` containing all the op stack traces.

  Returns:
    A function which retrieves the stack traces from the original graph and
    converts them to a `GraphDebugInfo` for a given set of nodes.
  """

  def f(original_nodes):
    """Creates `GraphDebugInfo` for the given `original_nodes`.

    Args:
      original_nodes: Iterable of (function_name, node_name) pairs. An empty
        function name means the node is looked up in the top-level graph.

    Returns:
      A `GraphDebugInfo`, or None when there is no original graph.
    """
    if not original_graph:
      return None

    # For the given nodes, gets all the op definitions in the original graph.
    useful_ops = []
    for func, name in original_nodes:
      try:
        if not func:
          useful_ops.append((func, original_graph.get_operation_by_name(name)))
          continue
        sub_func = original_graph._get_function(func)  # pylint: disable=protected-access
        if isinstance(sub_func, function._EagerDefinedFunction):  # pylint: disable=protected-access
          useful_ops.append(
              (func, sub_func.graph.get_operation_by_name(name)))
        else:
          # Terminate the warning with a newline so repeated warnings (one per
          # affected node) and subsequent output do not run together.
          sys.stderr.write(
              "Use '@tf.function' or '@defun' to decorate the function.\n")
      except KeyError:
        # New node created by graph optimizer. No stack trace from source code.
        continue

    # Convert all the op definitions to stack traces in terms of GraphDebugInfo.
    return _error_interpolation.create_graph_debug_info_def(useful_ops)

  return f
def convert_debug_info_func(saved_debug_info):
  """Returns a method to retrieve the `GraphDebugInfo` from saved debug info.

  Args:
    saved_debug_info: The `GraphDebugInfo` containing all the debug info.

  Returns:
    A function which extracts, from `saved_debug_info`, the traces of a given
    set of nodes and returns them as a new `GraphDebugInfo`.
  """

  def f(original_nodes):
    """Creates `GraphDebugInfo` for the given `original_nodes`.

    Args:
      original_nodes: Iterable of (function_name, node_name) pairs. Traces are
        looked up under the key "<node_name>@<function_name>"; nodes without a
        saved trace are skipped.

    Returns:
      A new `GraphDebugInfo`, or None when there is no saved debug info.
    """
    if not saved_debug_info:
      return None

    output_debug_info = graph_debug_info_pb2.GraphDebugInfo()
    # All the files are copied over, so the index wouldn't be changed.
    output_debug_info.files[:] = saved_debug_info.files
    saved_traces = saved_debug_info.traces
    # We only copy over the debug info for the input nodes.
    for func, node in original_nodes:
      debug_key = node + "@" + func
      # Indexing a protobuf message map with a missing key inserts an empty
      # entry, which would silently mutate `saved_debug_info` on every call.
      # Test membership first and copy only traces that actually exist.
      if debug_key in saved_traces:
        output_debug_info.traces[debug_key].CopyFrom(saved_traces[debug_key])
    return output_debug_info

  return f
def get_debug_info(nodes_to_debug_info_func, converted_graph):
  """Collects debug info for the original nodes behind `converted_graph`.

  Each node contributes the (function_name, node_name) pairs listed in its
  `experimental_debug_info`. A missing function name becomes "". A node
  without recorded original names contributes ("", node.name).

  Args:
    nodes_to_debug_info_func: Callable that turns a set of
      (function_name, node_name) pairs into debug info, or a falsy value.
    converted_graph: A `GraphDef` after optimization and transformation.

  Returns:
    Whatever `nodes_to_debug_info_func` returns for the collected pairs, or
    None when no such function is given.
  """
  if not nodes_to_debug_info_func:
    return None

  def _origins_of(node):
    debug_info = node.experimental_debug_info
    node_names = debug_info.original_node_names
    func_names = debug_info.original_func_names
    if not node_names:
      return [("", node.name)]
    return [(func_names[pos] if pos < len(func_names) else "", node_name)
            for pos, node_name in enumerate(node_names)]

  original_nodes = set()
  for node in converted_graph.node:
    original_nodes.update(_origins_of(node))
  return nodes_to_debug_info_func(original_nodes)
def convert_bytes_to_c_source(data,
array_name,
max_line_width=80,
include_guard=None,
include_path=None,
use_tensorflow_license=False):
"""Returns strings representing a C constant array containing `data`.
Args:
data: Byte array that will be converted into a C constant.
array_name: String to use as the variable name for the constant array.
max_line_width: The longest line length, for formatting purposes.
include_guard: Name to use for the include guard macro definition.
include_path: Optional path to include in the source file.
use_tensorflow_license: Whether to include the standard TensorFlow Apache2
license in the generated files.
Returns:
Text that can be compiled as a C source file to link in the data as a
literal array of values.
Text that can be used as a C header file to reference the literal array.
"""
starting_pad = " "
array_lines = []
array_line = starting_pad
for value in bytearray(data):
if (len(array_line) + 4) > max_line_width:
array_lines.append(array_line + "\n")
array_line = starting_pad
array_line += " 0x%02x," % (value)
if len(array_line) > len(starting_pad):
array_lines.append(array_line + "\n")
array_values = "".join(array_lines)
if include_guard is None:
include_guard = "TENSORFLOW_LITE_UTIL_" + array_name.upper() + "_DATA_H_"
if include_path is not None:
include_line = "#include \"{include_path}\"\n".format(
include_path=include_path)
else:
include_line = ""
if use_tensorflow_license:
license_text = """
/* Copyright {year} The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
""".format(year=datetime.date.today().year)
else:
license_text = ""
source_template = """{license_text}
// This is a TensorFlow Lite model file that has been converted into a C data
// array using the tensorflow.lite.util.convert_bytes_to_c_source() function.
// This form is useful for compiling into a binary for devices that don't have a
// file system.
{include_line}
// We need to keep the data array aligned on some architectures.
#ifdef __has_attribute
#define HAVE_ATTRIBUTE(x) __has_attribute(x)
#else
#define HAVE_ATTRIBUTE(x) 0
#endif
#if HAVE_ATTRIBUTE(aligned) || (defined(__GNUC__) && !defined(__clang__))
#define DATA_ALIGN_ATTRIBUTE __attribute__((aligned(4)))
#else
#define DATA_ALIGN_ATTRIBUTE
#endif
const unsigned char {array_name}[] DATA_ALIGN_ATTRIBUTE = {{
{array_values}}};
const int {array_name}_len = {array_length};
"""
source_text = source_template.format(
array_name=array_name,
array_length=len(data),
array_values=array_values,
license_text=license_text,
include_line=include_line)
header_template = """
{license_text}
// This is a TensorFlow Lite model file that has been converted into a C data
// array using the tensorflow.lite.util.convert_bytes_to_c_source() function.
// This form is useful for compiling into a binary for devices that don't have a
// file system.
#ifndef {include_guard}
#define {include_guard}
extern const unsigned char {array_name}[];
extern const int {array_name}_len;
#endif // {include_guard}
"""
header_text = header_template.format(
array_name=array_name,
include_guard=include_guard,
license_text=license_text)
return source_text, header_text
def _convert_model_from_bytearray_to_object(model_bytearray):
  """Converts a tflite model from a bytearray into a parsable object.

  Args:
    model_bytearray: Serialized TFLite flatbuffer model.

  Returns:
    A `schema_fb.ModelT` object-API tree that callers in this module edit in
    place before re-serializing it.
  """
  # Read the flatbuffer root table, then unpack it into the object API.
  model_object = schema_fb.Model.GetRootAsModel(model_bytearray, 0)
  model_object = schema_fb.ModelT.InitFromObj(model_object)
  # NOTE(review): presumably the deep copy detaches the object tree from
  # `model_bytearray` (e.g. array views backed by the input buffer) so that it
  # can be modified safely -- confirm.
  model_object = copy.deepcopy(model_object)
  # NOTE(review): this self-assignment leaves the value unchanged. It does
  # raise if the model has no subgraph or the first subgraph has no inputs,
  # and may be intended to fail fast if `inputs` is not writable -- confirm
  # its purpose before removing it.
  model_object.subgraphs[0].inputs[0] = model_object.subgraphs[0].inputs[0]
  return model_object
def _convert_model_from_object_to_bytearray(model_object):
  """Serializes an object-API model back into TFLite flatbuffer bytes.

  Args:
    model_object: Object-API model (e.g. `schema_fb.ModelT`) to pack.

  Returns:
    `bytes` holding the finished flatbuffer, stamped with the TFLite file
    identifier.
  """
  # 1024 is just the starting capacity; the builder grows as needed.
  fb_builder = flatbuffers.Builder(1024)
  root_offset = model_object.Pack(fb_builder)
  fb_builder.Finish(root_offset, file_identifier=_TFLITE_FILE_IDENTIFIER)
  serialized = fb_builder.Output()
  return bytes(serialized)
def _remove_tensors_from_model(model, remove_tensors_idxs):
"""Remove tensors from model."""
if not remove_tensors_idxs:
return
if len(model.subgraphs) > 1:
raise ValueError("Model must only have one subgraph. Instead, it has "
"{} subgraphs.".format(len(model.subgraphs)))
subgraph = model.subgraphs[0]
tensors = subgraph.tensors
operators = subgraph.operators
logging.debug("Removing tensors at indices : %s", remove_tensors_idxs)
# An optimized check to validate if "remove_tensors_idxs" (eg: [4,5,6]) is an
# exact subset, with ordering, of "tensors" indices (eg: [0,1,2,3,4,5,6]).
if min(remove_tensors_idxs) == len(tensors) - len(remove_tensors_idxs):
logging.debug("Removing tensors only at the end of the tensor list")
del tensors[min(remove_tensors_idxs):]
else:
logging.debug("Removing tensors requires updating the model")
# Map the old tensor indices to new tensor indices
d_old_to_new_tensors = {}
left_shift_by = 0
for idx in range(len(tensors)):
if idx in remove_tensors_idxs:
left_shift_by += 1
else:
d_old_to_new_tensors[idx] = idx - left_shift_by
logging.debug("Old to new tensors map: %s", d_old_to_new_tensors.__str__())
# Update tensor indices referenced throughout the model
def update_tensors(tensor_idxs):
for i, ti in enumerate(tensor_idxs):
tensor_idxs[i] = d_old_to_new_tensors.get(ti, -1)
update_tensors(subgraph.inputs)
update_tensors(subgraph.outputs)
for op in operators:
update_tensors(op.inputs)
update_tensors(op.outputs)
# Delete the tensors
for idx in sorted(remove_tensors_idxs, reverse=True):
tensors.pop(idx)
logging.debug("Removed tensors marked for deletion")
def _validate_and_find_int8_quantized_inputs_outputs(model):
  """Finds the model's boundary QUANTIZE/DEQUANTIZE ops and checks their types.

  A QUANTIZE op whose first input is a subgraph input, and a DEQUANTIZE op
  whose first output is a subgraph output, are collected. For each one, the
  tensor at the model boundary must be float32 and the tensor on the other
  side must be int8.

  Args:
    model: Object-API model with a single subgraph.

  Returns:
    Tuple (input_quant_ops, output_dequant_ops) of operator lists.

  Raises:
    ValueError: The model has more than one subgraph, has neither a QUANTIZE
      nor a DEQUANTIZE operator code, or a boundary tensor has the wrong type.
  """
  if len(model.subgraphs) > 1:
    raise ValueError("Model must only have one subgraph. Instead, it has "
                     "{} subgraphs.".format(len(model.subgraphs)))
  subgraph = model.subgraphs[0]
  tensors = subgraph.tensors

  # Locate the operator-code indices of QUANTIZE and DEQUANTIZE, stopping as
  # soon as both are known.
  quant_opcode_idx = dequant_opcode_idx = None
  for code_idx, opcode in enumerate(model.operatorCodes):
    builtin = opcode.builtinCode
    if builtin == schema_fb.BuiltinOperator.QUANTIZE:
      quant_opcode_idx = code_idx
    elif builtin == schema_fb.BuiltinOperator.DEQUANTIZE:
      dequant_opcode_idx = code_idx
    if quant_opcode_idx is not None and dequant_opcode_idx is not None:
      break
  if quant_opcode_idx is None and dequant_opcode_idx is None:
    raise ValueError("Model is not integer quantized as it does not "
                     "contain quantize/dequantize operators.")

  def _check_boundary(pos, float_tensor, int_tensor):
    # The side facing the user must be float32, the side inside must be int8.
    if float_tensor.type != schema_fb.TensorType.FLOAT32:
      raise ValueError(
          "Model {} type must be tf.float32. Expected type for tensor with "
          "name '{}' is tf.float32, instead type is tf.{}".format(
              pos, float_tensor.name,
              _convert_tflite_enum_type_to_tf_type(float_tensor.type).name))
    if int_tensor.type != schema_fb.TensorType.INT8:
      raise ValueError(
          "Model is not integer quantized. Expected type for tensor with "
          "name '{}' is tf.int8, instead type is tf.{}".format(
              int_tensor.name,
              _convert_tflite_enum_type_to_tf_type(int_tensor.type).name))

  input_quant_ops = []
  output_dequant_ops = []
  for op in subgraph.operators:
    if op.opcodeIndex == quant_opcode_idx and op.inputs[0] in subgraph.inputs:
      input_quant_ops.append(op)
      _check_boundary("input", tensors[op.inputs[0]], tensors[op.outputs[0]])
    elif (op.opcodeIndex == dequant_opcode_idx and
          op.outputs[0] in subgraph.outputs):
      output_dequant_ops.append(op)
      _check_boundary("output", tensors[op.outputs[0]], tensors[op.inputs[0]])
  return input_quant_ops, output_dequant_ops
def _replace_tensor_index(tensor_idxs, old_idx, new_idx):
  """Replaces every `old_idx` in the index vector `tensor_idxs` with `new_idx`.

  Uses element-wise assignment so it behaves the same whether the object API
  produced a Python list or a numpy array.
  """
  for position, idx in enumerate(tensor_idxs):
    if idx == old_idx:
      tensor_idxs[position] = new_idx


def _uint8_quantization_from_int8(int8_quantization):
  """Returns uint8 quantization params equivalent to per-tensor int8 ones.

  The scale is kept and the zero point is shifted by +128.
  """
  uint8_quantization = schema_fb.QuantizationParametersT()
  uint8_quantization.scale = [int8_quantization.scale[0]]
  uint8_quantization.zeroPoint = [int8_quantization.zeroPoint[0] + 128]
  return uint8_quantization


def _find_quantize_opcode_index(model, input_quant_ops):
  """Returns an operator-code index for QUANTIZE in `model`.

  Reuses the opcode of an existing input QUANTIZE op when there is one;
  otherwise searches `model.operatorCodes`.

  Raises:
    ValueError: The model has no QUANTIZE operator code.
  """
  if input_quant_ops:
    return input_quant_ops[0].opcodeIndex
  for idx, opcode in enumerate(model.operatorCodes):
    if opcode.builtinCode == schema_fb.BuiltinOperator.QUANTIZE:
      return idx
  raise ValueError(
      "Model does not contain a QUANTIZE operator, so its output type cannot "
      "be modified to tf.uint8.")


def modify_integer_quantized_model_io_type(
    model, inference_input_type=_lite_constants.FLOAT,
    inference_output_type=_lite_constants.FLOAT):
  """Modify the float input/output type of an integer quantized model.

  Args:
    model: An int8 quantized tflite model (serialized bytes) with float input
      and output.
    inference_input_type: tf.DType representing final input type.
      (default tf.float32)
    inference_output_type: tf.DType representing final output type.
      (default tf.float32)

  Returns:
    An int8 quantized tflite model with modified input and/or output type.
    When both types are tf.float32 the input `model` is returned unchanged.

  Raises:
    ValueError: If the model is not int8 quantized, the inference_input_type
      and/or inference_output_type is unsupported, or a uint8 output is
      requested for a model without a QUANTIZE operator.
  """
  # Return if input and output types default to float
  if inference_input_type == _lite_constants.FLOAT and \
      inference_output_type == _lite_constants.FLOAT:
    return model

  # Validate input and output types
  if inference_input_type not in _TFLITE_MODEL_INPUT_OUTPUT_TYPES:
    raise ValueError("The `inference_input_type` should be in {}".format(
        tuple(_get_dtype_name(t) for t in _TFLITE_MODEL_INPUT_OUTPUT_TYPES)))
  if inference_output_type not in _TFLITE_MODEL_INPUT_OUTPUT_TYPES:
    raise ValueError("The `inference_output_type` should be in {}".format(
        tuple(_get_dtype_name(t) for t in _TFLITE_MODEL_INPUT_OUTPUT_TYPES)))

  logging.debug(("Attempting to modify the model input from tf.float32 to %s "
                 "and output from tf.float32 to %s"),
                _get_dtype_name(inference_input_type),
                _get_dtype_name(inference_output_type))

  # Convert the model to an object
  model = _convert_model_from_bytearray_to_object(model)

  # Validate the integer quantized model (this also rejects multi-subgraph
  # models, so subgraphs[0] below is the only subgraph).
  input_quant_ops, output_dequant_ops = \
      _validate_and_find_int8_quantized_inputs_outputs(model)

  subgraph = model.subgraphs[0]
  tensors = subgraph.tensors
  operators = subgraph.operators
  remove_tensors_idxs = set()

  # Modify model input type
  if inference_input_type == _lite_constants.QUANTIZED_UINT8:
    # Change quant op (float to int8) to quant op (uint8 to int8)
    for op in input_quant_ops:
      model_input = tensors[op.inputs[0]]
      model_input.quantization = _uint8_quantization_from_int8(
          tensors[op.outputs[0]].quantization)
      model_input.type = schema_fb.TensorType.UINT8
  elif inference_input_type == _lite_constants.INT8:
    # Remove the inputs and the quant operator; the int8 tensor the quant op
    # produced becomes the model input.
    for op in input_quant_ops:
      _replace_tensor_index(subgraph.inputs, op.inputs[0], op.outputs[0])
      remove_tensors_idxs.add(op.inputs[0])
      operators.remove(op)

  # Modify model output type
  if (inference_output_type == _lite_constants.QUANTIZED_UINT8 and
      output_dequant_ops):
    # Change dequant op (int8 to float) to quant op (int8 to uint8)
    quant_opcode_idx = _find_quantize_opcode_index(model, input_quant_ops)
    for op in output_dequant_ops:
      op.opcodeIndex = quant_opcode_idx
      model_output = tensors[op.outputs[0]]
      model_output.quantization = _uint8_quantization_from_int8(
          tensors[op.inputs[0]].quantization)
      model_output.type = schema_fb.TensorType.UINT8
  elif inference_output_type == _lite_constants.INT8:
    # Remove the outputs and the dequant operator; the int8 tensor the dequant
    # op consumed becomes the model output.
    for op in output_dequant_ops:
      _replace_tensor_index(subgraph.outputs, op.outputs[0], op.inputs[0])
      remove_tensors_idxs.add(op.outputs[0])
      operators.remove(op)

  # Remove tensors marked for deletion.
  _remove_tensors_from_model(model, remove_tensors_idxs)

  # Convert the model to a bytearray
  return _convert_model_from_object_to_bytearray(model)