Fix the TF quantizer to run properly with non-default tags.
TF quantizer's PTQ process raises an error when tags other than `{"serve"}` (the default tag) are used. Fix the quantizer to properly use the user-provided tags.
Non-default tags are used in legacy TF1 to identify a specific meta-graph, when multiple meta-graphs could be stored in a single SavedModel.
PiperOrigin-RevId: 454772314
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test.py b/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test.py
index 828a585..8a0c4b8 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test.py
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test.py
@@ -13,7 +13,7 @@
# limitations under the License.
# ==============================================================================
"""Tests for quantize_model."""
-from typing import List, Mapping, Set, Tuple
+from typing import List, Mapping, Sequence, Set, Tuple
import warnings
from absl.testing import parameterized
@@ -71,7 +71,7 @@
def _create_simple_tf1_conv_model(
use_variable_for_filter=False) -> Tuple[core.Tensor, core.Tensor]:
- """Create a basic convolution model.
+ """Creates a basic convolution model.
This is intended to be used for TF1 (graph mode) tests.
@@ -100,6 +100,31 @@
return in_placeholder, output_tensor
+def _create_data_generator(
+ input_key: str,
+ shape: Sequence[int],
+ num_examples=128) -> quantize_model._RepresentativeDataset:
+ """Creates a data generator to be used as representative dataset.
+
+ Supports generating random value input tensors mapped by the `input_key`.
+
+ Args:
+ input_key: The string key that identifies the created tensor as an input.
+ shape: Shape of the tensor data.
+ num_examples: Number of examples in the representative dataset.
+
+ Returns:
+ data_gen: A callable that creates a generator of
+ `quantize_model._RepresentativeSample`.
+ """
+
+ def data_gen():
+ for _ in range(num_examples):
+ yield {input_key: random_ops.random_uniform(shape, minval=-1., maxval=1.)}
+
+ return data_gen
+
+
def _save_tf1_model(sess: session.Session, saved_model_path: str,
signature_key: str, tags: Set[str],
inputs: Mapping[str, core.Tensor],
@@ -129,7 +154,7 @@
tags: Set[str],
input_key: str,
output_key: str,
- use_variable=False) -> None:
+ use_variable=False) -> core.Tensor:
"""Creates and saves a simple convolution model.
This is intended to be used for TF1 (graph mode) tests.
@@ -143,6 +168,9 @@
output_key: The key to the output tensor.
use_variable: Setting this to `True` makes the filter for the conv operation
a `tf.Variable`.
+
+ Returns:
+ in_placeholder: The placeholder tensor used as an input to the model.
"""
with ops.Graph().as_default(), session.Session() as sess:
in_placeholder, output_tensor = _create_simple_tf1_conv_model(
@@ -159,6 +187,8 @@
inputs={input_key: in_placeholder},
outputs={output_key: output_tensor})
+ return in_placeholder
+
@test_util.run_all_in_graph_and_eager_modes
class QuantizationMethodTest(test.TestCase):
@@ -734,20 +764,11 @@
@test_util.run_in_graph_and_eager_modes
def test_ptq_model_with_tf1_saved_model_with_variable(self):
-
- def gen_data():
- for _ in range(255):
- yield {
- 'x':
- random_ops.random_uniform(
- shape=(1, 3, 4, 3), minval=-6, maxval=6)
- }
-
input_saved_model_path = self.create_tempdir('input').full_path
signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
tags = {tag_constants.SERVING}
- _create_and_save_tf1_conv_model(
+ input_placeholder = _create_and_save_tf1_conv_model(
input_saved_model_path,
signature_key,
tags,
@@ -762,13 +783,16 @@
quantization_method=quant_opts_pb2.QuantizationMethod(
experimental_method=_ExperimentalMethod.STATIC_RANGE))
+ data_gen = _create_data_generator(
+ input_key='x', shape=input_placeholder.shape)
+
converted_model = quantize_model.quantize(
input_saved_model_path,
signature_keys,
tags,
output_directory,
quantization_options,
- representative_dataset=gen_data)
+ representative_dataset=data_gen)
self.assertIsNotNone(converted_model)
self.assertEqual(
@@ -780,20 +804,11 @@
@test_util.run_in_graph_and_eager_modes
def test_ptq_model_with_tf1_saved_model(self):
-
- def gen_data():
- for _ in range(255):
- yield {
- 'p':
- random_ops.random_uniform(
- shape=(1, 3, 4, 3), minval=0, maxval=150)
- }
-
input_saved_model_path = self.create_tempdir('input').full_path
tags = {tag_constants.SERVING}
signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
- _create_and_save_tf1_conv_model(
+ input_placeholder = _create_and_save_tf1_conv_model(
input_saved_model_path,
signature_key,
tags,
@@ -808,13 +823,16 @@
quantization_method=quant_opts_pb2.QuantizationMethod(
experimental_method=_ExperimentalMethod.STATIC_RANGE))
+ data_gen = _create_data_generator(
+ input_key='p', shape=input_placeholder.shape)
+
converted_model = quantize_model.quantize(
input_saved_model_path,
signature_keys,
tags,
output_directory,
quantization_options,
- representative_dataset=gen_data)
+ representative_dataset=data_gen)
self.assertIsNotNone(converted_model)
self.assertEqual(
@@ -827,21 +845,11 @@
@test_util.run_in_graph_and_eager_modes
def test_ptq_model_with_tf1_saved_model_invalid_input_key_raises_value_error(
self):
-
- # Representative generator function that yields with an invalid input key.
- def gen_data():
- for _ in range(255):
- yield {
- 'invalid_input_key':
- random_ops.random_uniform(
- shape=(1, 3, 4, 3), minval=-5, maxval=5)
- }
-
input_saved_model_path = self.create_tempdir('input').full_path
signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
tags = {tag_constants.SERVING}
- _create_and_save_tf1_conv_model(
+ input_placeholder = _create_and_save_tf1_conv_model(
input_saved_model_path,
signature_key,
tags,
@@ -856,6 +864,10 @@
quantization_method=quant_opts_pb2.QuantizationMethod(
experimental_method=_ExperimentalMethod.STATIC_RANGE))
+ # Representative generator function that yields with an invalid input key.
+ invalid_data_gen = _create_data_generator(
+ input_key='invalid_input_key', shape=input_placeholder.shape)
+
with self.assertRaisesRegex(
ValueError,
'Failed to run graph for post-training quantization calibration'):
@@ -865,7 +877,84 @@
tags,
output_directory,
quantization_options,
- representative_dataset=gen_data)
+ representative_dataset=invalid_data_gen)
+
+ @test_util.run_in_graph_and_eager_modes
+ def test_ptq_model_with_non_default_tags(self):
+ input_saved_model_path = self.create_tempdir('input').full_path
+ signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
+ # Use a different set of tags other than {"serve"}.
+ tags = {tag_constants.TRAINING, tag_constants.GPU}
+
+ # Non-default tags are usually used when saving multiple metagraphs in TF1.
+ input_placeholder = _create_and_save_tf1_conv_model(
+ input_saved_model_path,
+ signature_key,
+ tags,
+ input_key='input',
+ output_key='output',
+ use_variable=True)
+
+ signature_keys = [signature_key]
+ output_directory = self.create_tempdir().full_path
+
+ quantization_options = quant_opts_pb2.QuantizationOptions(
+ quantization_method=quant_opts_pb2.QuantizationMethod(
+ experimental_method=_ExperimentalMethod.STATIC_RANGE))
+
+ data_gen = _create_data_generator(
+ input_key='input', shape=input_placeholder.shape)
+
+ converted_model = quantize_model.quantize(
+ input_saved_model_path,
+ signature_keys,
+ tags,
+ output_directory,
+ quantization_options,
+ representative_dataset=data_gen)
+
+ self.assertIsNotNone(converted_model)
+ self.assertEqual(
+ list(converted_model.signatures._signatures.keys()), signature_keys)
+
+ output_loader = saved_model_loader.SavedModelLoader(output_directory)
+ output_meta_graphdef = output_loader.get_meta_graph_def_from_tags(tags)
+ self.assertTrue(_contains_quantized_function_call(output_meta_graphdef))
+
+ @test_util.run_in_graph_and_eager_modes
+ def test_ptq_model_with_wrong_tags_raises_error(self):
+ input_saved_model_path = self.create_tempdir('input').full_path
+ signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
+ save_tags = {tag_constants.TRAINING, tag_constants.GPU}
+
+ input_placeholder = _create_and_save_tf1_conv_model(
+ input_saved_model_path,
+ signature_key,
+ save_tags,
+ input_key='input',
+ output_key='output',
+ use_variable=True)
+
+ signature_keys = [signature_key]
+ output_directory = self.create_tempdir().full_path
+
+ quantization_options = quant_opts_pb2.QuantizationOptions(
+ quantization_method=quant_opts_pb2.QuantizationMethod(
+ experimental_method=_ExperimentalMethod.STATIC_RANGE))
+
+ # Try to use a different set of tags to quantize.
+ tags = {tag_constants.SERVING}
+ data_gen = _create_data_generator(
+ input_key='input', shape=input_placeholder.shape)
+ with self.assertRaisesRegex(RuntimeError,
+ 'Failed to retrieve MetaGraphDef'):
+ quantize_model.quantize(
+ input_saved_model_path,
+ signature_keys,
+ tags,
+ output_directory,
+ quantization_options,
+ representative_dataset=data_gen)
@test_util.run_all_in_graph_and_eager_modes
@@ -961,6 +1050,40 @@
# Currently conv is not supported.
self.assertFalse(_contains_quantized_function_call(output_meta_graphdef))
+ def test_conv_model_with_wrong_tags_raises_error(self):
+ input_saved_model_path = self.create_tempdir('input').full_path
+ signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
+ save_tags = {tag_constants.TRAINING, tag_constants.GPU}
+
+ input_placeholder = _create_and_save_tf1_conv_model(
+ input_saved_model_path,
+ signature_key,
+ save_tags,
+ input_key='input',
+ output_key='output',
+ use_variable=True)
+
+ signature_keys = [signature_key]
+ output_directory = self.create_tempdir().full_path
+
+ quantization_options = quant_opts_pb2.QuantizationOptions(
+ quantization_method=quant_opts_pb2.QuantizationMethod(
+ experimental_method=_ExperimentalMethod.DYNAMIC_RANGE))
+
+ # Try to use a different set of tags to quantize.
+ tags = {tag_constants.SERVING}
+ data_gen = _create_data_generator(
+ input_key='input', shape=input_placeholder.shape)
+ with self.assertRaisesRegex(RuntimeError,
+ 'Failed to retrieve MetaGraphDef'):
+ quantize_model.quantize(
+ input_saved_model_path,
+ signature_keys,
+ tags,
+ output_directory,
+ quantization_options,
+ representative_dataset=data_gen)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py
index 7c5e470..23b2828 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py
@@ -84,7 +84,13 @@
tags = set([tag_constants.SERVING])
loader = saved_model_loader.SavedModelLoader(saved_model_path)
- meta_graphdef = loader.get_meta_graph_def_from_tags(tags)
+ try:
+ meta_graphdef = loader.get_meta_graph_def_from_tags(tags)
+ except RuntimeError as runtime_error:
+ raise RuntimeError(
+ f'Failed to retrieve MetaGraphDef with tags {tags}'
+ f' from a SavedModel in {saved_model_path}.') from runtime_error
+
signatures = {}
for key, signature_def in meta_graphdef.signature_def.items():
if key == _INIT_OP_SIGNATURE_KEY:
@@ -242,7 +248,7 @@
def _run_graph_for_calibration_graph_mode(
- model_dir: str, signature_keys: List[str],
+ model_dir: str, signature_keys: List[str], tags: Set[str],
representative_dataset: _RepresentativeDataset) -> None:
"""Runs the graph for calibration in graph mode.
@@ -255,6 +261,7 @@
model_dir: Path to SavedModel directory.
signature_keys: A list of signature keys that identifies a function to run
the data samples with.
+ tags: Set of tags identifying the MetaGraphDef within the SavedModel.
representative_dataset: Representative dataset used for calibration.
Raises:
@@ -262,7 +269,7 @@
"""
with session.Session() as sess:
meta_graph: meta_graph_pb2.MetaGraphDef = saved_model_loader.load(
- sess, tags=[tag_constants.SERVING], export_dir=model_dir)
+ sess, tags, export_dir=model_dir)
for sample in representative_dataset():
signature_key, input_data = _get_signature_key_and_input(
@@ -286,7 +293,7 @@
def _run_graph_for_calibration_eager_mode(
- model_dir: str, signature_keys: List[str],
+ model_dir: str, signature_keys: List[str], tags: Set[str],
representative_dataset: _RepresentativeDataset) -> None:
"""Runs the graph for calibration in eager mode.
@@ -299,12 +306,13 @@
model_dir: Path to SavedModel directory.
signature_keys: A list of signature keys that identifies a function to run
the data samples with.
+ tags: Set of tags identifying the MetaGraphDef within the SavedModel.
representative_dataset: Representative dataset used for calibration.
Raises:
ValueError: When the samples in representative dataset is invalid.
"""
- root: autotrackable.AutoTrackable = saved_model_load(model_dir)
+ root: autotrackable.AutoTrackable = saved_model_load(model_dir, tags)
for sample in representative_dataset():
signature_key, input_data = _get_signature_key_and_input(
sample, signature_keys)
@@ -318,11 +326,12 @@
) from ex
-def _static_range_quantize(saved_model_path: str,
- signature_keys=None,
- tags=None,
- output_directory=None,
- representative_dataset=None):
+def _static_range_quantize(
+ saved_model_path: str,
+ signature_keys: List[str],
+ tags: Set[str],
+ output_directory: str,
+ representative_dataset: Optional[_RepresentativeDataset] = None) ->...:
"""Quantizes the given SavedModel via static range quantization.
Args:
@@ -345,6 +354,8 @@
Raises:
ValueError: when representative_dataset is not provided for non-QAT model.
+ RuntimeError: When a MetaGraphDef could not be found associated with `tags`
+ in the SavedModel.
"""
is_qat_saved_model = _is_qat_saved_model(saved_model_path)
signatures = _get_signatures_from_saved_model(saved_model_path,
@@ -390,7 +401,7 @@
"The input SavedModel doesn't contain a valid signature")
v1_builder.add_meta_graph_and_variables(
- sess, [tag_constants.SERVING], signature_def_map=signatures)
+ sess, tags, signature_def_map=signatures)
v1_builder.save()
@@ -401,10 +412,10 @@
try:
if context.executing_eagerly():
_run_graph_for_calibration_eager_mode(float_model_dir, signature_keys,
- representative_dataset)
+ tags, representative_dataset)
else:
_run_graph_for_calibration_graph_mode(float_model_dir, signature_keys,
- representative_dataset)
+ tags, representative_dataset)
except Exception as ex:
raise ValueError(
'Failed to run graph for post-training quantization calibration.'
@@ -435,7 +446,7 @@
graph_def = working_graph.as_graph_def()
v1_builder.add_meta_graph_and_variables(
- sess, [tag_constants.SERVING], signature_def_map=signatures)
+ sess, tags, signature_def_map=signatures)
v1_builder.save()
signatures = _get_signatures_from_saved_model(calibrated_model_dir,
@@ -464,7 +475,7 @@
raise ValueError("The input SavedModel doesn't contain a valid signature")
v1_builder.add_meta_graph_and_variables(
- sess, [tag_constants.SERVING], signature_def_map=signatures)
+ sess, tags, signature_def_map=signatures)
v1_builder.save()
@@ -544,8 +555,8 @@
not provided, this should be a model trained with QAT.
signature_keys: List of keys identifying SignatureDef containing inputs and
outputs. If None, ["serving_default"] is used.
- tags: Set of tags identifying the MetaGraphDef within the SavedModel to
- analyze. If None, {"serve"} is used.
+ tags: (TF1 SavedModel only) Set of tags identifying the MetaGraphDef within
+ the SavedModel to analyze. If None, {"serve"} is used.
output_directory: The path to save the output SavedModel (must be an empty
directory).
quantization_options: A set of options for quantization.