| # Lint as: python2, python3 |
| # Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| """Functions used by multiple converter files.""" |
| |
| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| |
| import copy |
| import datetime |
| import sys |
| |
| from absl import logging |
| import six |
| from six.moves import range |
| |
| import flatbuffers |
| from tensorflow.core.protobuf import config_pb2 as _config_pb2 |
| from tensorflow.core.protobuf import graph_debug_info_pb2 |
| from tensorflow.core.protobuf import meta_graph_pb2 as _meta_graph_pb2 |
| from tensorflow.lite.python import lite_constants as _lite_constants |
| from tensorflow.lite.python import schema_py_generated as schema_fb |
| from tensorflow.lite.python.op_hint import convert_op_hints_to_stubs |
| from tensorflow.lite.python.op_hint import find_all_hinted_output_nodes |
| from tensorflow.lite.toco import types_pb2 as _types_pb2 |
| from tensorflow.python.eager import function |
| from tensorflow.python.framework import convert_to_constants as _convert_to_constants |
| from tensorflow.python.framework import dtypes |
| from tensorflow.python.framework import error_interpolation as _error_interpolation |
| from tensorflow.python.framework import graph_util as tf_graph_util |
| from tensorflow.python.grappler import tf_optimizer |
| from tensorflow.python.training.saver import export_meta_graph as _export_meta_graph |
| |
# Map of tf.dtypes to TFLite proto types (_types_pb2).
| _MAP_TF_TO_TFLITE_TYPES = { |
| dtypes.float32: _types_pb2.FLOAT, |
| dtypes.float16: _types_pb2.FLOAT16, |
| dtypes.int32: _types_pb2.INT32, |
| dtypes.uint8: _types_pb2.QUANTIZED_UINT8, |
| dtypes.int64: _types_pb2.INT64, |
| dtypes.string: _types_pb2.STRING, |
| dtypes.bool: _types_pb2.BOOL, |
| dtypes.int16: _types_pb2.QUANTIZED_INT16, |
| dtypes.complex64: _types_pb2.COMPLEX64, |
| dtypes.int8: _types_pb2.INT8, |
| dtypes.float64: _types_pb2.FLOAT64, |
| dtypes.complex128: _types_pb2.COMPLEX128, |
| } |
| |
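# Map of TFLite TensorType enum values (from the schema) to tf.dtypes.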
| _MAP_TFLITE_ENUM_TO_TF_TYPES = { |
| 0: dtypes.float32, |
| 1: dtypes.float16, |
| 2: dtypes.int32, |
| 3: dtypes.uint8, |
| 4: dtypes.int64, |
| 5: dtypes.string, |
| 6: dtypes.bool, |
| 7: dtypes.int16, |
| 8: dtypes.complex64, |
| 9: dtypes.int8, |
| 10: dtypes.float64, |
| 11: dtypes.complex128, |
| } |
| |
| _TFLITE_FILE_IDENTIFIER = b"TFL3" |
| |
| _TFLITE_MODEL_INPUT_OUTPUT_TYPES = (_lite_constants.FLOAT, _lite_constants.INT8, |
| _lite_constants.QUANTIZED_UINT8) |
| |
| |
| def convert_dtype_to_tflite_type(tf_dtype): |
| """Converts tf.dtype to TFLite proto type. |
| |
  Args:
    tf_dtype: The tf.dtype to convert.

  Raises:
    ValueError: If the tf.dtype is unsupported.

  Returns:
    The corresponding TFLite proto type (`_types_pb2`).
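
  Example (illustrative):
    convert_dtype_to_tflite_type(dtypes.float32)  # -> _types_pb2.FLOAT
    convert_dtype_to_tflite_type(dtypes.uint8)  # -> _types_pb2.QUANTIZED_UINT8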
| """ |
| result = _MAP_TF_TO_TFLITE_TYPES.get(tf_dtype) |
| if result is None: |
| raise ValueError("Unsupported tf.dtype {0}".format(tf_dtype)) |
| return result |
| |
| |
| def _convert_tflite_enum_type_to_tf_type(tflite_enum_type): |
| """Converts tflite enum type (eg: 0) to tf type (eg: tf.float32). |
| |
| Args: |
| tflite_enum_type: tflite enum type (eg: 0, that corresponds to float32) |
| |
| Raises: |
| ValueError: If an invalid tflite enum type is provided. |
| |
| Returns: |
| tf type (eg: tf.float32) |
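
  Example (illustrative):
    _convert_tflite_enum_type_to_tf_type(0)  # -> dtypes.float32
    _convert_tflite_enum_type_to_tf_type(9)  # -> dtypes.int8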
| """ |
| tf_type = _MAP_TFLITE_ENUM_TO_TF_TYPES.get(tflite_enum_type) |
| if tf_type is None: |
    raise ValueError(
        "Unsupported enum {}. The valid map of enum to tf types is: {}"
        .format(tflite_enum_type, _MAP_TFLITE_ENUM_TO_TF_TYPES))
| return tf_type |
| |
| |
| def _get_dtype_name(tf_type): |
| """Converts tf.dtype (eg: tf.float32) to str (eg: "tf.float32").""" |
| return "tf." + tf_type.name |
| |
| |
| def get_tensor_name(tensor): |
| """Returns name of the input tensor. |
| |
| Args: |
| tensor: tf.Tensor |
| |
| Returns: |
| str |
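
  Raises:
    ValueError: If the tensor name contains more than one colon.

  Example (illustrative):
    get_tensor_name(t)  # "foo:0" -> "foo"; "split:1" -> "split:1"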
| """ |
| parts = six.ensure_str(tensor.name).split(":") |
| if len(parts) > 2: |
| raise ValueError("Tensor name invalid. Expect 0 or 1 colon, got {0}".format( |
| len(parts) - 1)) |
| |
  # To be consistent with the tensor naming scheme in TensorFlow, we need to
  # drop the ':0' suffix for the first tensor.
| if len(parts) > 1 and parts[1] != "0": |
| return tensor.name |
| return parts[0] |
| |
| |
| def get_tensors_from_tensor_names(graph, tensor_names): |
| """Gets the Tensors associated with the `tensor_names` in the provided graph. |
| |
| Args: |
| graph: TensorFlow Graph. |
| tensor_names: List of strings that represent names of tensors in the graph. |
| |
| Returns: |
| A list of Tensor objects in the same order the names are provided. |
| |
| Raises: |
    ValueError: If `tensor_names` contains an invalid tensor name.
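
  Example (a minimal sketch; assumes `tf` is `tensorflow.compat.v1`):
    with tf.Graph().as_default() as graph:
      tf.placeholder(tf.float32, shape=[1, 4], name="input")
    tensors = get_tensors_from_tensor_names(graph, ["input"])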
| """ |
| # Get the list of all of the tensors. |
| tensor_name_to_tensor = {} |
| for op in graph.get_operations(): |
| for tensor in op.values(): |
| tensor_name_to_tensor[get_tensor_name(tensor)] = tensor |
| |
| # Get the tensors associated with tensor_names. |
| tensors = [] |
| invalid_tensors = [] |
| for name in tensor_names: |
| if not isinstance(name, six.string_types): |
| raise ValueError("Invalid type for a tensor name in the provided graph. " |
| "Expected type for a tensor name is 'str', instead got " |
| "type '{}' for tensor name '{}'".format( |
| type(name), name)) |
| |
| tensor = tensor_name_to_tensor.get(name) |
| if tensor is None: |
| invalid_tensors.append(name) |
| else: |
| tensors.append(tensor) |
| |
  # Raise a ValueError if any requested names do not map to valid tensors.
| if invalid_tensors: |
| raise ValueError("Invalid tensors '{}' were found.".format( |
| ",".join(invalid_tensors))) |
| return tensors |
| |
| |
| def set_tensor_shapes(tensors, shapes): |
| """Sets Tensor shape for each tensor if the shape is defined. |
| |
| Args: |
    tensors: List of TensorFlow ops.Tensor objects.
    shapes: Dict mapping input tensor names (strings) to lists of integers
      representing input shapes (e.g., {"foo": [1, 16, 16, 3]}).
| |
| Raises: |
| ValueError: |
| `shapes` contains an invalid tensor. |
| `shapes` contains an invalid shape for a valid tensor. |
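
  Example (a minimal sketch; assumes `graph` is a tf.Graph containing a
  placeholder named "foo" with an unspecified shape):
    tensors = get_tensors_from_tensor_names(graph, ["foo"])
    set_tensor_shapes(tensors, {"foo": [1, 16, 16, 3]})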
| """ |
| if shapes: |
| tensor_names_to_tensor = { |
| get_tensor_name(tensor): tensor for tensor in tensors |
| } |
| for name, shape in shapes.items(): |
| if name not in tensor_names_to_tensor: |
| raise ValueError("Invalid tensor \'{}\' found in tensor shapes " |
| "map.".format(name)) |
| if shape is not None: |
| tensor = tensor_names_to_tensor[name] |
| try: |
| tensor.set_shape(shape) |
| except ValueError as error: |
| message = ("The shape of tensor '{0}' cannot be changed from {1} to " |
| "{2}. {3}".format(name, tensor.shape, shape, str(error))) |
| raise ValueError(message) |
| |
| |
| def get_grappler_config(optimizers_list): |
| """Creates a tf.compat.v1.ConfigProto for configuring Grappler. |
| |
| Args: |
    optimizers_list: List of strings representing the names of the Grappler
      optimizers to run.

  Returns:
    tf.compat.v1.ConfigProto.
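
  Example (illustrative; "constfold" and "function" are standard Grappler
  optimizer names):
    config = get_grappler_config(["constfold", "function"])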
| """ |
| config = _config_pb2.ConfigProto() |
| rewrite_options = config.graph_options.rewrite_options |
| for optimizer in optimizers_list: |
| rewrite_options.optimizers.append(optimizer) |
| return config |
| |
| |
| def run_graph_optimizations(graph_def, |
| input_arrays, |
| output_arrays, |
| config, |
| graph=None): |
| """Apply standard TensorFlow optimizations to the graph_def. |
| |
| Args: |
| graph_def: Frozen GraphDef to be optimized. |
| input_arrays: List of arrays that are considered inputs of the graph. |
| output_arrays: List of arrays that are considered outputs of the graph. |
| config: tf.ConfigProto. |
| graph: TensorFlow Graph. Required when Eager mode is enabled. (default None) |
| |
| Returns: |
| A new, optimized GraphDef. |
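
  Example (a minimal sketch; `sess`, `in_tensor` and `out_tensor` are assumed
  to come from an existing TF1-style model):
    config = get_grappler_config(["constfold"])
    optimized_graph_def = run_graph_optimizations(
        sess.graph_def, [in_tensor], [out_tensor], config, graph=sess.graph)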
| """ |
| meta_graph = _export_meta_graph(graph_def=graph_def, graph=graph) |
| |
| signature = _meta_graph_pb2.SignatureDef() |
| for array in input_arrays: |
| signature.inputs[array.name].name = array.name |
| signature.inputs[array.name].dtype = array.dtype.as_datatype_enum |
| signature.inputs[array.name].tensor_shape.CopyFrom(array.shape.as_proto()) |
| |
| for array in output_arrays: |
| signature.outputs[array.name].name = array.name |
| signature.outputs[array.name].dtype = array.dtype.as_datatype_enum |
| signature.outputs[array.name].tensor_shape.CopyFrom(array.shape.as_proto()) |
| |
| meta_graph.signature_def["not_used_key"].CopyFrom(signature) |
| |
  # Add a collection called 'train_op' so that Grappler knows what the
  # outputs are.
| fetch_collection = _meta_graph_pb2.CollectionDef() |
| for array in input_arrays + output_arrays: |
| fetch_collection.node_list.value.append(array.name) |
| meta_graph.collection_def["train_op"].CopyFrom(fetch_collection) |
| |
| return tf_optimizer.OptimizeGraph(config, meta_graph) |
| |
| |
def _convert_op_hints_if_present(sess, graph_def, output_tensors,
                                 hinted_outputs_nodes):
  """Converts any OpHints in the graph to stubs after freezing variables."""
  if is_frozen_graph(sess):
    raise ValueError("Op hint conversion requires an unfrozen graph; the "
                     "provided graph is already frozen.")
| output_arrays = [get_tensor_name(tensor) for tensor in output_tensors] |
| graph_def = tf_graph_util.convert_variables_to_constants( |
| sess, graph_def, output_arrays + hinted_outputs_nodes) |
| graph_def = convert_op_hints_to_stubs(graph_def=graph_def) |
| return graph_def |
| |
| |
| def freeze_graph(sess, input_tensors, output_tensors): |
| """Returns a frozen GraphDef. |
| |
  Runs a Grappler pass to inline functions in the graph, then freezes the
  graph by converting Variables to constants. If the graph is already frozen,
  the optimized GraphDef is returned as-is. If OpHints are present, the
  OpHint graph is converted instead.
| |
| Args: |
| sess: TensorFlow Session. |
| input_tensors: List of input tensors. |
| output_tensors: List of output tensors (only .name is used from this). |
| |
| Returns: |
| Frozen GraphDef. |
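
  Example (a minimal sketch; `sess`, `in_tensor` and `out_tensor` are assumed
  to come from an existing TF1-style model):
    frozen_graph_def = freeze_graph(sess, [in_tensor], [out_tensor])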
| """ |
  # Runs a Grappler pass in order to inline any functions in the graph.
  # Aside from inlining simple functions, Grappler will also try to lower
  # while loops into their switch/merge representation, which is undesirable
  # for OpHints, so we simply remove those attributes to prevent Grappler
  # from doing so.
| graph_def = _convert_to_constants.disable_lower_using_switch_merge( |
| sess.graph_def) |
| config = get_grappler_config(["function"]) |
| graph_def = run_graph_optimizations( |
| graph_def, input_tensors, output_tensors, config, graph=sess.graph) |
| |
  # If OpHints are present, convert them.
| hinted_outputs_nodes = find_all_hinted_output_nodes(sess) |
| if hinted_outputs_nodes: |
| return _convert_op_hints_if_present(sess, graph_def, output_tensors, |
| hinted_outputs_nodes) |
| |
| if not is_frozen_graph(sess): |
| output_node_names = [tensor.name.split(":")[0] for tensor in output_tensors] |
| return tf_graph_util.convert_variables_to_constants(sess, graph_def, |
| output_node_names) |
| else: |
| return sess.graph_def |
| |
| |
| def is_frozen_graph(sess): |
| """Determines if the graph is frozen. |
| |
| Determines if a graph has previously been frozen by checking for any |
| operations of type Variable*. If variables are found, the graph is not frozen. |
| |
| Args: |
| sess: TensorFlow Session. |
| |
| Returns: |
| Bool. |
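
  Example (illustrative; `output_node_names` is assumed to be a list of
  output node name strings):
    if not is_frozen_graph(sess):
      graph_def = tf_graph_util.convert_variables_to_constants(
          sess, sess.graph_def, output_node_names)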
| """ |
| for op in sess.graph.get_operations(): |
| if six.ensure_str(op.type).startswith("Variable") or six.ensure_str( |
| op.type).endswith("VariableOp"): |
| return False |
| return True |
| |
| |
| def build_debug_info_func(original_graph): |
| """Returns a method to retrieve the `GraphDebugInfo` from the original graph. |
| |
| Args: |
| original_graph: The original `Graph` containing all the op stack traces. |
| |
| Returns: |
| A function which retrieves the stack traces from the original graph and |
| converts them to a `GraphDebugInfo` for a given set of nodes. |
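
  Example (a minimal sketch; `sess` and `converted_graph_def` are assumed to
  exist):
    debug_info_func = build_debug_info_func(sess.graph)
    debug_info = get_debug_info(debug_info_func, converted_graph_def)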
| """ |
| |
| def f(original_nodes): |
| """Function to create `GraphDebugInfo` for the given `original_nodes`.""" |
| if not original_graph: |
| return None |
| # For the given nodes, gets all the op definitions in the original graph. |
| useful_ops = [] |
| for func, name in original_nodes: |
| try: |
| if not func: |
| useful_ops.append((func, original_graph.get_operation_by_name(name))) |
| else: |
| sub_func = original_graph._get_function(func) # pylint: disable=protected-access |
| if isinstance(sub_func, function._EagerDefinedFunction): # pylint: disable=protected-access |
| useful_ops.append( |
| (func, sub_func.graph.get_operation_by_name(name))) |
| else: |
          sys.stderr.write(
              "Use '@tf.function' or '@defun' to decorate the function.\n")
| continue |
| except KeyError: |
| # New node created by graph optimizer. No stack trace from source code. |
| continue |
| # Convert all the op definitions to stack traces in terms of GraphDebugInfo. |
| return _error_interpolation.create_graph_debug_info_def(useful_ops) |
| |
| return f |
| |
| |
| def convert_debug_info_func(saved_debug_info): |
| """Returns a method to retrieve the `GraphDebugInfo` from the original graph. |
| |
| Args: |
| saved_debug_info: The `GraphDebugInfo` containing all the debug info. |
| |
| Returns: |
    A function which retrieves the stack traces from the saved debug info and
    converts them to a `GraphDebugInfo` for a given set of nodes.
| """ |
| |
| def f(original_nodes): |
| """Function to create `GraphDebugInfo` for the given `original_nodes`.""" |
| if not saved_debug_info: |
| return None |
| |
| output_debug_info = graph_debug_info_pb2.GraphDebugInfo() |
    # All the files are copied over, so the indices remain valid.
    output_debug_info.files[:] = saved_debug_info.files
    # Only copy over the debug info for the input nodes.
| for func, node in original_nodes: |
| debug_key = node + "@" + func |
| output_debug_info.traces[debug_key].CopyFrom( |
| saved_debug_info.traces[debug_key]) |
| return output_debug_info |
| |
| return f |
| |
| |
| def get_debug_info(nodes_to_debug_info_func, converted_graph): |
| """Returns the debug info for the original nodes in the `converted_graph`. |
| |
| Args: |
| nodes_to_debug_info_func: The method to collect the op debug info for the |
| nodes. |
| converted_graph: A `GraphDef` after optimization and transformation. |
| |
| Returns: |
| `GraphDebugInfo` for all the original nodes in `converted_graph`. |
| """ |
| if not nodes_to_debug_info_func: |
| return None |
| |
| # Collect all the debug info nodes from the converted_graph |
| original_nodes = set() |
| for node in converted_graph.node: |
| debug_nodes = node.experimental_debug_info.original_node_names |
| debug_funcs = node.experimental_debug_info.original_func_names |
    # If `original_node_names` is empty, use the node name directly.
| if not debug_nodes: |
| original_nodes.add(("", node.name)) |
| else: |
| for i in range(len(debug_nodes)): |
| debug_func = "" if i >= len(debug_funcs) else debug_funcs[i] |
| original_nodes.add((debug_func, debug_nodes[i])) |
| |
| # Convert the nodes to the debug info proto object. |
| return nodes_to_debug_info_func(original_nodes) |
| |
| |
| def convert_bytes_to_c_source(data, |
| array_name, |
| max_line_width=80, |
| include_guard=None, |
| include_path=None, |
| use_tensorflow_license=False): |
| """Returns strings representing a C constant array containing `data`. |
| |
| Args: |
| data: Byte array that will be converted into a C constant. |
| array_name: String to use as the variable name for the constant array. |
| max_line_width: The longest line length, for formatting purposes. |
| include_guard: Name to use for the include guard macro definition. |
| include_path: Optional path to include in the source file. |
| use_tensorflow_license: Whether to include the standard TensorFlow Apache2 |
| license in the generated files. |
| |
| Returns: |
| Text that can be compiled as a C source file to link in the data as a |
| literal array of values. |
| Text that can be used as a C header file to reference the literal array. |
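
  Example (a minimal sketch; `tflite_model` is assumed to hold the serialized
  model bytes):
    source_text, header_text = convert_bytes_to_c_source(
        tflite_model, "g_model", include_guard="MODEL_DATA_H_")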
| """ |
| |
| starting_pad = " " |
| array_lines = [] |
| array_line = starting_pad |
| for value in bytearray(data): |
| if (len(array_line) + 4) > max_line_width: |
| array_lines.append(array_line + "\n") |
| array_line = starting_pad |
| array_line += " 0x%02x," % (value) |
| if len(array_line) > len(starting_pad): |
| array_lines.append(array_line + "\n") |
| array_values = "".join(array_lines) |
| |
| if include_guard is None: |
| include_guard = "TENSORFLOW_LITE_UTIL_" + array_name.upper() + "_DATA_H_" |
| |
| if include_path is not None: |
| include_line = "#include \"{include_path}\"\n".format( |
| include_path=include_path) |
| else: |
| include_line = "" |
| |
| if use_tensorflow_license: |
| license_text = """ |
| /* Copyright {year} The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| """.format(year=datetime.date.today().year) |
| else: |
| license_text = "" |
| |
| source_template = """{license_text} |
| // This is a TensorFlow Lite model file that has been converted into a C data |
| // array using the tensorflow.lite.util.convert_bytes_to_c_source() function. |
| // This form is useful for compiling into a binary for devices that don't have a |
| // file system. |
| |
| {include_line} |
| // We need to keep the data array aligned on some architectures. |
| #ifdef __has_attribute |
| #define HAVE_ATTRIBUTE(x) __has_attribute(x) |
| #else |
| #define HAVE_ATTRIBUTE(x) 0 |
| #endif |
| #if HAVE_ATTRIBUTE(aligned) || (defined(__GNUC__) && !defined(__clang__)) |
| #define DATA_ALIGN_ATTRIBUTE __attribute__((aligned(4))) |
| #else |
| #define DATA_ALIGN_ATTRIBUTE |
| #endif |
| |
| const unsigned char {array_name}[] DATA_ALIGN_ATTRIBUTE = {{ |
| {array_values}}}; |
| const int {array_name}_len = {array_length}; |
| """ |
| |
| source_text = source_template.format( |
| array_name=array_name, |
| array_length=len(data), |
| array_values=array_values, |
| license_text=license_text, |
| include_line=include_line) |
| |
| header_template = """ |
| {license_text} |
| |
| // This is a TensorFlow Lite model file that has been converted into a C data |
| // array using the tensorflow.lite.util.convert_bytes_to_c_source() function. |
| // This form is useful for compiling into a binary for devices that don't have a |
| // file system. |
| |
| #ifndef {include_guard} |
| #define {include_guard} |
| |
| extern const unsigned char {array_name}[]; |
| extern const int {array_name}_len; |
| |
| #endif // {include_guard} |
| """ |
| |
| header_text = header_template.format( |
| array_name=array_name, |
| include_guard=include_guard, |
| license_text=license_text) |
| |
| return source_text, header_text |
| |
| |
| def _convert_model_from_bytearray_to_object(model_bytearray): |
| """Converts a tflite model from a bytearray into a parsable object.""" |
| model_object = schema_fb.Model.GetRootAsModel(model_bytearray, 0) |
| model_object = schema_fb.ModelT.InitFromObj(model_object) |
| model_object = copy.deepcopy(model_object) |
  # This self-assignment is intentional (kept from the original); it appears
  # to force materialization of the numpy-backed input vector before use.
  model_object.subgraphs[0].inputs[0] = model_object.subgraphs[0].inputs[0]
| return model_object |
| |
| |
| def _convert_model_from_object_to_bytearray(model_object): |
| """Converts a tflite model from a parsable object into a bytearray.""" |
| # Initial size of the buffer, which will grow automatically if needed |
| builder = flatbuffers.Builder(1024) |
| model_offset = model_object.Pack(builder) |
| builder.Finish(model_offset, file_identifier=_TFLITE_FILE_IDENTIFIER) |
| return bytes(builder.Output()) |
| |
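# Note: the two helpers above are inverses of each other. An illustrative
# round trip (a sketch; `tflite_model` holds serialized model bytes):
#   model_obj = _convert_model_from_bytearray_to_object(tflite_model)
#   # ... mutate model_obj ...
#   tflite_model = _convert_model_from_object_to_bytearray(model_obj)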
| |
| def _remove_tensors_from_model(model, remove_tensors_idxs): |
| """Remove tensors from model.""" |
| if not remove_tensors_idxs: |
| return |
| if len(model.subgraphs) > 1: |
| raise ValueError("Model must only have one subgraph. Instead, it has " |
| "{} subgraphs.".format(len(model.subgraphs))) |
| subgraph = model.subgraphs[0] |
| tensors = subgraph.tensors |
| operators = subgraph.operators |
| |
| logging.debug("Removing tensors at indices : %s", remove_tensors_idxs) |
  # Fast path: if "remove_tensors_idxs" (e.g., [4, 5, 6]) is exactly the tail
  # of the tensor index range (e.g., [0, 1, ..., 6]), simply truncate the
  # tensor list.
| if min(remove_tensors_idxs) == len(tensors) - len(remove_tensors_idxs): |
| logging.debug("Removing tensors only at the end of the tensor list") |
| del tensors[min(remove_tensors_idxs):] |
| else: |
| logging.debug("Removing tensors requires updating the model") |
| # Map the old tensor indices to new tensor indices |
| d_old_to_new_tensors = {} |
| left_shift_by = 0 |
| for idx in range(len(tensors)): |
| if idx in remove_tensors_idxs: |
| left_shift_by += 1 |
| else: |
| d_old_to_new_tensors[idx] = idx - left_shift_by |
| logging.debug("Old to new tensors map: %s", d_old_to_new_tensors.__str__()) |
| # Update tensor indices referenced throughout the model |
| def update_tensors(tensor_idxs): |
| for i, ti in enumerate(tensor_idxs): |
| tensor_idxs[i] = d_old_to_new_tensors.get(ti, -1) |
| update_tensors(subgraph.inputs) |
| update_tensors(subgraph.outputs) |
| for op in operators: |
| update_tensors(op.inputs) |
| update_tensors(op.outputs) |
| # Delete the tensors |
| for idx in sorted(remove_tensors_idxs, reverse=True): |
| tensors.pop(idx) |
| logging.debug("Removed tensors marked for deletion") |
| |
| |
| def _validate_and_find_int8_quantized_inputs_outputs(model): |
| """Validate that model input is quantized and output is dequantized.""" |
| if len(model.subgraphs) > 1: |
| raise ValueError("Model must only have one subgraph. Instead, it has " |
| "{} subgraphs.".format(len(model.subgraphs))) |
| subgraph = model.subgraphs[0] |
| tensors = subgraph.tensors |
| operators = subgraph.operators |
| |
  # Ensure the model has at least one quantize and one dequantize operator.
| quant_opcode_idx, dequant_opcode_idx = None, None |
| for idx, opcode in enumerate(model.operatorCodes): |
| if opcode.builtinCode == schema_fb.BuiltinOperator.QUANTIZE: |
| quant_opcode_idx = idx |
| elif opcode.builtinCode == schema_fb.BuiltinOperator.DEQUANTIZE: |
| dequant_opcode_idx = idx |
| if quant_opcode_idx is not None and dequant_opcode_idx is not None: |
| break |
| if quant_opcode_idx is None and dequant_opcode_idx is None: |
| raise ValueError("Model is not integer quantized as it does not " |
| "contain quantize/dequantize operators.") |
| |
| # Ensure model inputs and outputs are integer quantized |
| input_quant_ops, output_dequant_ops = [], [] |
| for op in operators: |
| # Find input quantize operator |
| if op.opcodeIndex == quant_opcode_idx and op.inputs[0] in subgraph.inputs: |
| pos, float_tensor, int_tensor = \ |
| "input", tensors[op.inputs[0]], tensors[op.outputs[0]] |
| input_quant_ops.append(op) |
| # Find output dequantize operator |
| elif op.opcodeIndex == dequant_opcode_idx and \ |
| op.outputs[0] in subgraph.outputs: |
| pos, float_tensor, int_tensor = \ |
| "output", tensors[op.outputs[0]], tensors[op.inputs[0]] |
| output_dequant_ops.append(op) |
| # Otherwise, ignore |
| else: |
| continue |
| # If found, validate the input/output tensor type |
| if float_tensor.type != schema_fb.TensorType.FLOAT32: |
      raise ValueError(
          "Model {} type must be tf.float32. Expected tf.float32 for tensor "
          "'{}', but got tf.{} instead.".format(
              pos, float_tensor.name,
              _convert_tflite_enum_type_to_tf_type(float_tensor.type).name))
| if int_tensor.type != schema_fb.TensorType.INT8: |
      raise ValueError(
          "Model is not integer quantized. Expected tf.int8 for tensor '{}', "
          "but got tf.{} instead.".format(
              int_tensor.name,
              _convert_tflite_enum_type_to_tf_type(int_tensor.type).name))
| |
| return input_quant_ops, output_dequant_ops |
| |
| |
| def modify_integer_quantized_model_io_type( |
| model, inference_input_type=_lite_constants.FLOAT, |
| inference_output_type=_lite_constants.FLOAT): |
| """Modify the float input/output type of an integer quantized model. |
| |
| Args: |
| model: An int8 quantized tflite model with float input and output. |
| inference_input_type: tf.DType representing final input type. |
| (default tf.float32) |
| inference_output_type: tf.DType representing final output type. |
| (default tf.float32) |
| |
| Returns: |
| An int8 quantized tflite model with modified input and/or output type. |
| |
  Raises:
    ValueError: If the model is not int8 quantized or the
      inference_input_type and/or inference_output_type is unsupported.
    RuntimeError: If the modification was unsuccessful.
| |
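  Example (a minimal sketch; `tflite_model` is assumed to be an int8
  quantized model with float input/output):
    tflite_model = modify_integer_quantized_model_io_type(
        tflite_model, inference_input_type=_lite_constants.INT8,
        inference_output_type=_lite_constants.INT8)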
| """ |
| # Return if input and output types default to float |
| if inference_input_type == _lite_constants.FLOAT and \ |
| inference_output_type == _lite_constants.FLOAT: |
| return model |
| |
| # Validate input and output types |
| if inference_input_type not in _TFLITE_MODEL_INPUT_OUTPUT_TYPES: |
| raise ValueError("The `inference_input_type` should be in {}".format( |
| tuple(_get_dtype_name(t) for t in _TFLITE_MODEL_INPUT_OUTPUT_TYPES))) |
| if inference_output_type not in _TFLITE_MODEL_INPUT_OUTPUT_TYPES: |
| raise ValueError("The `inference_output_type` should be in {}".format( |
| tuple(_get_dtype_name(t) for t in _TFLITE_MODEL_INPUT_OUTPUT_TYPES))) |
| |
| logging.debug(("Attempting to modify the model input from tf.float32 to %s " |
| "and output from tf.float32 to %s"), |
| _get_dtype_name(inference_input_type), |
| _get_dtype_name(inference_output_type)) |
| # Convert the model to an object |
| model = _convert_model_from_bytearray_to_object(model) |
| |
| # Validate the integer quantized model |
| input_quant_ops, output_dequant_ops = \ |
| _validate_and_find_int8_quantized_inputs_outputs(model) |
| |
| # Initialize references and variables |
| if len(model.subgraphs) > 1: |
| raise ValueError("Model must only have one subgraph. Instead, it has " |
| "{} subgraphs.".format(len(model.subgraphs))) |
| subgraph = model.subgraphs[0] |
| tensors = subgraph.tensors |
| operators = subgraph.operators |
| remove_tensors_idxs = set() |
| |
| # Modify model input type |
| if inference_input_type == _lite_constants.QUANTIZED_UINT8: |
| # Change quant op (float to int8) to quant op (uint8 to int8) |
| for op in input_quant_ops: |
| int8_quantization = tensors[op.outputs[0]].quantization |
| uint8_quantization = schema_fb.QuantizationParametersT() |
| uint8_quantization.scale = [int8_quantization.scale[0]] |
| uint8_quantization.zeroPoint = [int8_quantization.zeroPoint[0] + 128] |
| tensors[op.inputs[0]].quantization = uint8_quantization |
| tensors[op.inputs[0]].type = schema_fb.TensorType.UINT8 |
| elif inference_input_type == _lite_constants.INT8: |
| # Remove the inputs and the quant operator |
| for op in input_quant_ops: |
| subgraph.inputs[subgraph.inputs == op.inputs[0]] = op.outputs[0] |
| remove_tensors_idxs.add(op.inputs[0]) |
| operators.remove(op) |
| |
| # Modify model output type |
| if inference_output_type == _lite_constants.QUANTIZED_UINT8: |
| # Change dequant op (int8 to float) to quant op (int8 to uint8) |
| for op in output_dequant_ops: |
| op.opcodeIndex = input_quant_ops[0].opcodeIndex |
| int8_quantization = tensors[op.inputs[0]].quantization |
| uint8_quantization = schema_fb.QuantizationParametersT() |
| uint8_quantization.scale = [int8_quantization.scale[0]] |
| uint8_quantization.zeroPoint = [int8_quantization.zeroPoint[0] + 128] |
| tensors[op.outputs[0]].quantization = uint8_quantization |
| tensors[op.outputs[0]].type = schema_fb.TensorType.UINT8 |
| elif inference_output_type == _lite_constants.INT8: |
| # Remove the outputs and the dequant operator |
| for op in output_dequant_ops: |
| subgraph.outputs[subgraph.outputs == op.outputs[0]] = op.inputs[0] |
| remove_tensors_idxs.add(op.outputs[0]) |
| operators.remove(op) |
| |
| # Remove tensors marked for deletion. |
| _remove_tensors_from_model(model, remove_tensors_idxs) |
| |
| # Convert the model to a bytearray |
| model = _convert_model_from_object_to_bytearray(model) |
| |
| return model |