[lite] Fix for the flow when lowering to saved model part is triggered.
refactor saved model handling in a shared function and call it.
PiperOrigin-RevId: 417621294
Change-Id: I78f582f088e21a498a6f988e79548ca869f10d00
diff --git a/tensorflow/lite/python/lite.py b/tensorflow/lite/python/lite.py
index c3a693a..f3431c1 100644
--- a/tensorflow/lite/python/lite.py
+++ b/tensorflow/lite/python/lite.py
@@ -939,6 +939,35 @@
graph=frozen_func.graph)
return graph_def
+ def _convert_from_saved_model(self, graph_def):
+ """Helper method that converts saved model.
+
+ Args:
+ graph_def: GraphDef object for the model, used only for stats.
+
+ Returns:
+ The converted TFLite model.
+ """
+ # Update conversion params with graph_def.
+ self._save_conversion_params_metric(graph_def)
+ # Get quantization options and do some sanity checks.
+ quant_mode = QuantizationMode(
+ self.optimizations, self.target_spec, self.representative_dataset,
+ graph_def, self._experimental_disable_per_channel,
+ self._experimental_new_dynamic_range_quantizer,
+ self._experimental_low_bit_qat)
+ self._validate_inference_input_output_types(quant_mode)
+ converter_kwargs = {
+ "enable_tflite_resource_variables":
+ self.experimental_enable_resource_variables
+ }
+ converter_kwargs.update(self._get_base_converter_args())
+ converter_kwargs.update(quant_mode.converter_flags())
+
+ result = _convert_saved_model(**converter_kwargs)
+ return self._optimize_tflite_model(
+ result, quant_mode, quant_io=self.experimental_new_quantizer)
+
def convert(self, graph_def, input_tensors, output_tensors):
"""Converts a TensorFlow GraphDef based on instance variables.
@@ -1052,27 +1081,7 @@
_convert_debug_info_func(self._trackable_obj.graph_debug_info),
graph_def)
- # Update conversion params with graph_def.
- self._save_conversion_params_metric(graph_def)
- # Get quantization options and do some sanity checks.
- quant_mode = QuantizationMode(
- self.optimizations, self.target_spec, self.representative_dataset,
- graph_def, self._experimental_disable_per_channel,
- self._experimental_new_dynamic_range_quantizer,
- self._experimental_low_bit_qat)
- self._validate_inference_input_output_types(quant_mode)
-
- converter_kwargs = {
- "enable_tflite_resource_variables":
- self.experimental_enable_resource_variables
- }
- converter_kwargs.update(self._get_base_converter_args())
- converter_kwargs.update(quant_mode.converter_flags())
-
- result = _convert_saved_model(**converter_kwargs)
-
- return self._optimize_tflite_model(
- result, quant_mode, quant_io=self.experimental_new_quantizer)
+ return self._convert_from_saved_model(graph_def)
class TFLiteKerasModelConverterV2(TFLiteConverterBaseV2):
@@ -1337,11 +1346,11 @@
"""
temp_dir = tempfile.mkdtemp()
try:
- graph_def, input_tensors, output_tensors = (
+ graph_def, input_tensors, _ = (
self._convert_concrete_functions_to_saved_model(temp_dir))
if self.saved_model_dir:
- return super(TFLiteFrozenGraphConverterV2,
- self).convert(graph_def, input_tensors, output_tensors)
+ self._validate_inputs(graph_def, input_tensors)
+ return self._convert_from_saved_model(graph_def)
finally:
shutil.rmtree(temp_dir, True)
return None