Fix comments and validation for int8 quantization.

This CL makes the following changes:
- Add int8 to all the related argument docstrings.
- Disable grappler graph optimizations for quantization-aware training,
  i.e. when the "inference_type" is quantized (uint8, or int8 without
  post-training quantization).
- Use "inference_type" instead of "inference_input_type" to verify that
  quantization stats are specified when post-training quantization is not
  used.
PiperOrigin-RevId: 285229735
Change-Id: Ie8da5c4d79fb60100c1041bd4573fe603cd304e6
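For illustration, the user-visible effect of the stats check is sketched below
against the TF1 converter API. This is a sketch only; the session, tensors, and
the input name "input" are assumptions for the example, not part of this change:

    import tensorflow as tf

    # Quantization-aware-trained graph: a quantized inference type requires
    # per-input (mean, std_dev) stats, otherwise convert() raises ValueError.
    converter = tf.compat.v1.lite.TFLiteConverter.from_session(
        sess, [in_tensor], [out_tensor])
    converter.inference_type = tf.uint8
    converter.quantized_input_stats = {"input": (127.5, 127.5)}
    tflite_model = converter.convert()

    # Post-training quantization: after this change the stats may be omitted.
    converter = tf.compat.v1.lite.TFLiteConverter.from_session(
        sess, [in_tensor], [out_tensor])
    converter.inference_type = tf.int8
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    tflite_model = converter.convert()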
diff --git a/tensorflow/lite/python/convert.py b/tensorflow/lite/python/convert.py
index 78bfdc8..bac1cb6 100644
--- a/tensorflow/lite/python/convert.py
+++ b/tensorflow/lite/python/convert.py
@@ -255,10 +255,10 @@
`foo.shape` and `foo.dtype`.
output_tensors: List of output tensors (only .name is used from this).
inference_type: Target data type of real-number arrays in the output file.
-      Must be `{tf.float32, tf.uint8}`. (default tf.float32)
+      Must be `{tf.float32, tf.uint8, tf.int8}`. (default tf.float32)
    inference_input_type: Target data type of real-number input arrays. Allows
      for a different type for input arrays in the case of quantization.
-      Must be `{tf.float32, tf.uint8}`. (default `inference_type`)
+      Must be `{tf.float32, tf.uint8, tf.int8}`. (default `inference_type`)
    input_format: Type of data to read. Currently must be
`{TENSORFLOW_GRAPHDEF}`. (default TENSORFLOW_GRAPHDEF)
input_shapes: Input array shape. It needs to be a list of the same length
@@ -267,7 +267,7 @@
GRAPHVIZ_DOT}`. (default TFLITE)
quantized_input_stats: List of tuples of floats representing the mean and
standard deviation. Each tuple maps to the corresponding input tensor.
- Only need if `inference_input_type` is `QUANTIZED_UINT8`.
+      Only needed if `inference_input_type` is `QUANTIZED_UINT8` or `INT8`.
real_input_value = (quantized_input_value - mean_value) / std_dev_value.
(default None)
default_ranges_stats: Tuple of integers representing (min, max) range values
@@ -363,11 +363,12 @@
input_array.data_type = util.convert_dtype_to_tflite_type(
input_tensor.dtype)
-    if toco.inference_input_type in \
-        [_types_pb2.QUANTIZED_UINT8, _types_pb2.INT8]:
-      if not quantized_input_stats:
-        raise ValueError("std_dev and mean must be defined when "
-                         "inference_input_type is QUANTIZED_UINT8.")
-      input_array.mean_value, input_array.std_value = quantized_input_stats[idx]
+    if toco.inference_type in [_types_pb2.QUANTIZED_UINT8, _types_pb2.INT8]:
+      if not quantized_input_stats and not post_training_quantize:
+        raise ValueError("std_dev and mean must be defined when inference_type "
+                         "is QUANTIZED_UINT8 or INT8.")
+      # Stats may be absent when post-training quantization is enabled.
+      if quantized_input_stats:
+        input_array.mean_value, input_array.std_value = quantized_input_stats[idx]
if input_shapes is None:
shape = input_tensor.shape
@@ -418,11 +419,15 @@
for idx, (name, shape) in enumerate(input_arrays_with_shape):
input_array = model_flags.input_arrays.add()
-      if toco_flags.inference_input_type == _types_pb2.QUANTIZED_UINT8:
-        if (("quantized_input_stats" not in kwargs) or
-            (not kwargs["quantized_input_stats"])):
-          raise ValueError("std_dev and mean must be defined when "
-                           "inference_input_type is QUANTIZED_UINT8.")
-        input_array.mean_value, input_array.std_value = kwargs[
-            "quantized_input_stats"][idx]
+      if toco_flags.inference_type in (
+          [_types_pb2.QUANTIZED_UINT8, _types_pb2.INT8]):
+        if ((("quantized_input_stats" not in kwargs) or
+             (not kwargs["quantized_input_stats"])) and
+            not toco_flags.post_training_quantize):
+          raise ValueError("std_dev and mean must be defined when "
+                           "inference_type is QUANTIZED_UINT8 or INT8.")
+        # Stats may be absent when post-training quantization is enabled.
+        if kwargs.get("quantized_input_stats"):
+          input_array.mean_value, input_array.std_value = kwargs[
+              "quantized_input_stats"][idx]
input_array.name = name
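The docstring above gives real_input_value = (quantized_input_value -
mean_value) / std_dev_value. As a worked example, the stats can be derived
from the float range of an input. The helper below is illustrative, not part
of the API, and uses the uint8 convention of mapping onto [0, 255]:

    def quant_stats_from_range(float_min, float_max, num_bits=8):
      # Map [float_min, float_max] onto [0, 2**num_bits - 1] so that
      # real = (quantized - mean) / std_dev holds at both endpoints.
      quant_max = 2.0**num_bits - 1.0  # 255 for uint8
      std_dev = quant_max / (float_max - float_min)
      mean = -float_min * std_dev
      return mean, std_dev

    assert quant_stats_from_range(0.0, 1.0) == (0.0, 255.0)
    assert quant_stats_from_range(-1.0, 1.0) == (127.5, 127.5)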
diff --git a/tensorflow/lite/python/convert_test.py b/tensorflow/lite/python/convert_test.py
index 543ddda..fcd3128 100644
--- a/tensorflow/lite/python/convert_test.py
+++ b/tensorflow/lite/python/convert_test.py
@@ -76,8 +76,17 @@
sess.graph_def, [in_tensor], [out_tensor],
inference_type=lite_constants.QUANTIZED_UINT8)
self.assertEqual(
- "std_dev and mean must be defined when inference_input_type is "
- "QUANTIZED_UINT8.", str(error.exception))
+ "std_dev and mean must be defined when inference_type is "
+ "QUANTIZED_UINT8 or INT8.", str(error.exception))
+
+ with self.assertRaises(ValueError) as error:
+ convert.toco_convert(
+ sess.graph_def, [in_tensor], [out_tensor],
+ inference_type=lite_constants.QUANTIZED_UINT8,
+ inference_input_type=lite_constants.FLOAT)
+ self.assertEqual(
+ "std_dev and mean must be defined when inference_type is "
+ "QUANTIZED_UINT8 or INT8.", str(error.exception))
def testGraphDefBasic(self):
with ops.Graph().as_default():
@@ -176,8 +185,8 @@
enable_mlir_converter=False,
inference_type=lite_constants.QUANTIZED_UINT8)
self.assertEqual(
- "std_dev and mean must be defined when inference_input_type is "
- "QUANTIZED_UINT8.", str(error.exception))
+ "std_dev and mean must be defined when inference_type is "
+ "QUANTIZED_UINT8 or INT8.", str(error.exception))
class ConvertTestOpHint(test_util.TensorFlowTestCase):
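Distilled from the tests above: the check now keys off inference_type, so a
float inference_input_type no longer bypasses it. A standalone sketch of the
rule, with plain strings standing in for the _types_pb2 enums:

    def stats_missing(inference_type, quantized_input_stats,
                      post_training_quantize):
      # Quantized inference types need (mean, std_dev) stats unless
      # post-training quantization supplies the ranges instead.
      quantized = inference_type in ("QUANTIZED_UINT8", "INT8")
      return (quantized and not quantized_input_stats and
              not post_training_quantize)

    assert stats_missing("QUANTIZED_UINT8", None, False)        # ValueError path
    assert not stats_missing("INT8", None, True)                # post-training ok
    assert not stats_missing("QUANTIZED_UINT8", [(127.5, 127.5)], False)
    assert not stats_missing("FLOAT", None, False)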
diff --git a/tensorflow/lite/python/lite.py b/tensorflow/lite/python/lite.py
index 3ae456e..83e97f1 100644
--- a/tensorflow/lite/python/lite.py
+++ b/tensorflow/lite/python/lite.py
@@ -998,7 +998,14 @@
"are not enabled.")
optimized_graph = self._graph_def
-    if self.inference_type != constants.QUANTIZED_UINT8:
+    # Apply graph optimizations unless the model uses quantization-aware
+    # training, i.e. a quantized inference type (uint8, or int8 without
+    # post-training or weight-only quantization). Graph optimizations are
+    # disabled for quantization-aware training.
+    if (self.inference_type not in
+        [constants.QUANTIZED_UINT8, constants.INT8] or
+        (self.inference_type == constants.INT8 and
+         (post_training_optimize or weight_only_quantize))):
try:
optimized_graph = _run_graph_optimizations(
self._graph_def,
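The gating above reduces to "run grappler unless quantization-aware training".
A standalone sketch of the decision table, with illustrative names and plain
strings in place of the constants:

    def should_run_grappler(inference_type, post_training_optimize,
                            weight_only_quantize):
      # Quantization-aware training: a quantized inference type with no
      # post-training or weight-only quantization to supply the ranges.
      qat = (inference_type in ("QUANTIZED_UINT8", "INT8") and
             not (inference_type == "INT8" and
                  (post_training_optimize or weight_only_quantize)))
      return not qat

    assert not should_run_grappler("QUANTIZED_UINT8", False, False)  # uint8 QAT
    assert not should_run_grappler("INT8", False, False)             # int8 QAT
    assert should_run_grappler("INT8", True, False)    # post-training int8
    assert should_run_grappler("FLOAT", False, False)  # float model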
diff --git a/tensorflow/lite/python/lite_v2_test.py b/tensorflow/lite/python/lite_v2_test.py
index f3842ae..f4a6a4e 100644
--- a/tensorflow/lite/python/lite_v2_test.py
+++ b/tensorflow/lite/python/lite_v2_test.py
@@ -787,6 +787,12 @@
actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])
np.testing.assert_almost_equal(expected_value.numpy(), actual_value[0])
+    # Enable hybrid quantization; the result should match the float model.
+ converter.experimental_new_converter = True
+ converter.optimizations = [lite.Optimize.DEFAULT]
+ hybrid_tflite_model = converter.convert()
+ actual_value = self._evaluateTFLiteModel(hybrid_tflite_model, [input_data])
+ np.testing.assert_almost_equal(expected_value.numpy(), actual_value[0])
if __name__ == '__main__':
test.main()
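For reference, the hybrid (dynamic-range) quantization enabled in the new test
corresponds to the following 2.x usage. This is a sketch; `model` is an
assumed tf.keras model, not part of this change:

    import tensorflow as tf

    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    converter.experimental_new_converter = True
    # With Optimize.DEFAULT and no representative dataset, weights are
    # quantized to 8 bits while activations stay float at inference time.
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    tflite_model = converter.convert()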