Fix the TF quantizer to run properly with non-default tags.

TF quantizer's PTQ process raises an error when tags other than `{"serve"}` (the default tag) are used. Fix the quantizer to properly use the user-provided tags.
Non-default tags are used in legacy TF1 to identify a specific meta-graph, when multiple meta-graphs could be stored in a single SavedModel.

PiperOrigin-RevId: 454772314
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test.py b/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test.py
index 828a585..8a0c4b8 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test.py
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 """Tests for quantize_model."""
-from typing import List, Mapping, Set, Tuple
+from typing import List, Mapping, Sequence, Set, Tuple
 import warnings
 
 from absl.testing import parameterized
@@ -71,7 +71,7 @@
 
 def _create_simple_tf1_conv_model(
     use_variable_for_filter=False) -> Tuple[core.Tensor, core.Tensor]:
-  """Create a basic convolution model.
+  """Creates a basic convolution model.
 
   This is intended to be used for TF1 (graph mode) tests.
 
@@ -100,6 +100,31 @@
   return in_placeholder, output_tensor
 
 
+def _create_data_generator(
+    input_key: str,
+    shape: Sequence[int],
+    num_examples=128) -> quantize_model._RepresentativeDataset:
+  """Creates a data generator to be used as representative dataset.
+
+  Supports generating random value input tensors mapped by the `input_key`.
+
+  Args:
+    input_key: The string key that identifies the created tensor as an input.
+    shape: Shape of the tensor data.
+    num_examples: Number of examples in the representative dataset.
+
+  Returns:
+    data_gen: A callable that creates a generator of
+    `quantize_model._RepresentativeSample`.
+  """
+
+  def data_gen():
+    for _ in range(num_examples):
+      yield {input_key: random_ops.random_uniform(shape, minval=-1., maxval=1.)}
+
+  return data_gen
+
+
 def _save_tf1_model(sess: session.Session, saved_model_path: str,
                     signature_key: str, tags: Set[str],
                     inputs: Mapping[str, core.Tensor],
@@ -129,7 +154,7 @@
                                     tags: Set[str],
                                     input_key: str,
                                     output_key: str,
-                                    use_variable=False) -> None:
+                                    use_variable=False) -> core.Tensor:
   """Creates and saves a simple convolution model.
 
   This is intended to be used for TF1 (graph mode) tests.
@@ -143,6 +168,9 @@
     output_key: The key to the output tensor.
     use_variable: Setting this to `True` makes the filter for the conv operation
       a `tf.Variable`.
+
+  Returns:
+    in_placeholder: The placeholder tensor used as an input to the model.
   """
   with ops.Graph().as_default(), session.Session() as sess:
     in_placeholder, output_tensor = _create_simple_tf1_conv_model(
@@ -159,6 +187,8 @@
         inputs={input_key: in_placeholder},
         outputs={output_key: output_tensor})
 
+  return in_placeholder
+
 
 @test_util.run_all_in_graph_and_eager_modes
 class QuantizationMethodTest(test.TestCase):
@@ -734,20 +764,11 @@
 
   @test_util.run_in_graph_and_eager_modes
   def test_ptq_model_with_tf1_saved_model_with_variable(self):
-
-    def gen_data():
-      for _ in range(255):
-        yield {
-            'x':
-                random_ops.random_uniform(
-                    shape=(1, 3, 4, 3), minval=-6, maxval=6)
-        }
-
     input_saved_model_path = self.create_tempdir('input').full_path
     signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
     tags = {tag_constants.SERVING}
 
-    _create_and_save_tf1_conv_model(
+    input_placeholder = _create_and_save_tf1_conv_model(
         input_saved_model_path,
         signature_key,
         tags,
@@ -762,13 +783,16 @@
         quantization_method=quant_opts_pb2.QuantizationMethod(
             experimental_method=_ExperimentalMethod.STATIC_RANGE))
 
+    data_gen = _create_data_generator(
+        input_key='x', shape=input_placeholder.shape)
+
     converted_model = quantize_model.quantize(
         input_saved_model_path,
         signature_keys,
         tags,
         output_directory,
         quantization_options,
-        representative_dataset=gen_data)
+        representative_dataset=data_gen)
 
     self.assertIsNotNone(converted_model)
     self.assertEqual(
@@ -780,20 +804,11 @@
 
   @test_util.run_in_graph_and_eager_modes
   def test_ptq_model_with_tf1_saved_model(self):
-
-    def gen_data():
-      for _ in range(255):
-        yield {
-            'p':
-                random_ops.random_uniform(
-                    shape=(1, 3, 4, 3), minval=0, maxval=150)
-        }
-
     input_saved_model_path = self.create_tempdir('input').full_path
     tags = {tag_constants.SERVING}
     signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
 
-    _create_and_save_tf1_conv_model(
+    input_placeholder = _create_and_save_tf1_conv_model(
         input_saved_model_path,
         signature_key,
         tags,
@@ -808,13 +823,16 @@
         quantization_method=quant_opts_pb2.QuantizationMethod(
             experimental_method=_ExperimentalMethod.STATIC_RANGE))
 
+    data_gen = _create_data_generator(
+        input_key='p', shape=input_placeholder.shape)
+
     converted_model = quantize_model.quantize(
         input_saved_model_path,
         signature_keys,
         tags,
         output_directory,
         quantization_options,
-        representative_dataset=gen_data)
+        representative_dataset=data_gen)
 
     self.assertIsNotNone(converted_model)
     self.assertEqual(
@@ -827,21 +845,11 @@
   @test_util.run_in_graph_and_eager_modes
   def test_ptq_model_with_tf1_saved_model_invalid_input_key_raises_value_error(
       self):
-
-    # Representative generator function that yields with an invalid input key.
-    def gen_data():
-      for _ in range(255):
-        yield {
-            'invalid_input_key':
-                random_ops.random_uniform(
-                    shape=(1, 3, 4, 3), minval=-5, maxval=5)
-        }
-
     input_saved_model_path = self.create_tempdir('input').full_path
     signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
     tags = {tag_constants.SERVING}
 
-    _create_and_save_tf1_conv_model(
+    input_placeholder = _create_and_save_tf1_conv_model(
         input_saved_model_path,
         signature_key,
         tags,
@@ -856,6 +864,10 @@
         quantization_method=quant_opts_pb2.QuantizationMethod(
             experimental_method=_ExperimentalMethod.STATIC_RANGE))
 
+    # Representative generator function that yields with an invalid input key.
+    invalid_data_gen = _create_data_generator(
+        input_key='invalid_input_key', shape=input_placeholder.shape)
+
     with self.assertRaisesRegex(
         ValueError,
         'Failed to run graph for post-training quantization calibration'):
@@ -865,7 +877,84 @@
           tags,
           output_directory,
           quantization_options,
-          representative_dataset=gen_data)
+          representative_dataset=invalid_data_gen)
+
+  @test_util.run_in_graph_and_eager_modes
+  def test_ptq_model_with_non_default_tags(self):
+    input_saved_model_path = self.create_tempdir('input').full_path
+    signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
+    # Use a different set of tags other than {"serve"}.
+    tags = {tag_constants.TRAINING, tag_constants.GPU}
+
+    # Non-default tags are usually used when saving multiple metagraphs in TF1.
+    input_placeholder = _create_and_save_tf1_conv_model(
+        input_saved_model_path,
+        signature_key,
+        tags,
+        input_key='input',
+        output_key='output',
+        use_variable=True)
+
+    signature_keys = [signature_key]
+    output_directory = self.create_tempdir().full_path
+
+    quantization_options = quant_opts_pb2.QuantizationOptions(
+        quantization_method=quant_opts_pb2.QuantizationMethod(
+            experimental_method=_ExperimentalMethod.STATIC_RANGE))
+
+    data_gen = _create_data_generator(
+        input_key='input', shape=input_placeholder.shape)
+
+    converted_model = quantize_model.quantize(
+        input_saved_model_path,
+        signature_keys,
+        tags,
+        output_directory,
+        quantization_options,
+        representative_dataset=data_gen)
+
+    self.assertIsNotNone(converted_model)
+    self.assertEqual(
+        list(converted_model.signatures._signatures.keys()), signature_keys)
+
+    output_loader = saved_model_loader.SavedModelLoader(output_directory)
+    output_meta_graphdef = output_loader.get_meta_graph_def_from_tags(tags)
+    self.assertTrue(_contains_quantized_function_call(output_meta_graphdef))
+
+  @test_util.run_in_graph_and_eager_modes
+  def test_ptq_model_with_wrong_tags_raises_error(self):
+    input_saved_model_path = self.create_tempdir('input').full_path
+    signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
+    save_tags = {tag_constants.TRAINING, tag_constants.GPU}
+
+    input_placeholder = _create_and_save_tf1_conv_model(
+        input_saved_model_path,
+        signature_key,
+        save_tags,
+        input_key='input',
+        output_key='output',
+        use_variable=True)
+
+    signature_keys = [signature_key]
+    output_directory = self.create_tempdir().full_path
+
+    quantization_options = quant_opts_pb2.QuantizationOptions(
+        quantization_method=quant_opts_pb2.QuantizationMethod(
+            experimental_method=_ExperimentalMethod.STATIC_RANGE))
+
+    # Try to use a different set of tags to quantize.
+    tags = {tag_constants.SERVING}
+    data_gen = _create_data_generator(
+        input_key='input', shape=input_placeholder.shape)
+    with self.assertRaisesRegex(RuntimeError,
+                                'Failed to retrieve MetaGraphDef'):
+      quantize_model.quantize(
+          input_saved_model_path,
+          signature_keys,
+          tags,
+          output_directory,
+          quantization_options,
+          representative_dataset=data_gen)
 
 
 @test_util.run_all_in_graph_and_eager_modes
@@ -961,6 +1050,40 @@
     # Currently conv is not supported.
     self.assertFalse(_contains_quantized_function_call(output_meta_graphdef))
 
+  def test_conv_model_with_wrong_tags_raises_error(self):
+    input_saved_model_path = self.create_tempdir('input').full_path
+    signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
+    save_tags = {tag_constants.TRAINING, tag_constants.GPU}
+
+    input_placeholder = _create_and_save_tf1_conv_model(
+        input_saved_model_path,
+        signature_key,
+        save_tags,
+        input_key='input',
+        output_key='output',
+        use_variable=True)
+
+    signature_keys = [signature_key]
+    output_directory = self.create_tempdir().full_path
+
+    quantization_options = quant_opts_pb2.QuantizationOptions(
+        quantization_method=quant_opts_pb2.QuantizationMethod(
+            experimental_method=_ExperimentalMethod.DYNAMIC_RANGE))
+
+    # Try to use a different set of tags to quantize.
+    tags = {tag_constants.SERVING}
+    data_gen = _create_data_generator(
+        input_key='input', shape=input_placeholder.shape)
+    with self.assertRaisesRegex(RuntimeError,
+                                'Failed to retrieve MetaGraphDef'):
+      quantize_model.quantize(
+          input_saved_model_path,
+          signature_keys,
+          tags,
+          output_directory,
+          quantization_options,
+          representative_dataset=data_gen)
+
 
 if __name__ == '__main__':
   test.main()
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py
index 7c5e470..23b2828 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py
@@ -84,7 +84,13 @@
     tags = set([tag_constants.SERVING])
 
   loader = saved_model_loader.SavedModelLoader(saved_model_path)
-  meta_graphdef = loader.get_meta_graph_def_from_tags(tags)
+  try:
+    meta_graphdef = loader.get_meta_graph_def_from_tags(tags)
+  except RuntimeError as runtime_error:
+    raise RuntimeError(
+        f'Failed to retrieve MetaGraphDef with tags {tags}'
+        f' from a SavedModel in {saved_model_path}.') from runtime_error
+
   signatures = {}
   for key, signature_def in meta_graphdef.signature_def.items():
     if key == _INIT_OP_SIGNATURE_KEY:
@@ -242,7 +248,7 @@
 
 
 def _run_graph_for_calibration_graph_mode(
-    model_dir: str, signature_keys: List[str],
+    model_dir: str, signature_keys: List[str], tags: Set[str],
     representative_dataset: _RepresentativeDataset) -> None:
   """Runs the graph for calibration in graph mode.
 
@@ -255,6 +261,7 @@
     model_dir: Path to SavedModel directory.
     signature_keys: A list of signature keys that identifies a function to run
       the data samples with.
+    tags: Set of tags identifying the MetaGraphDef within the SavedModel.
     representative_dataset: Representative dataset used for calibration.
 
   Raises:
@@ -262,7 +269,7 @@
   """
   with session.Session() as sess:
     meta_graph: meta_graph_pb2.MetaGraphDef = saved_model_loader.load(
-        sess, tags=[tag_constants.SERVING], export_dir=model_dir)
+        sess, tags, export_dir=model_dir)
 
     for sample in representative_dataset():
       signature_key, input_data = _get_signature_key_and_input(
@@ -286,7 +293,7 @@
 
 
 def _run_graph_for_calibration_eager_mode(
-    model_dir: str, signature_keys: List[str],
+    model_dir: str, signature_keys: List[str], tags: Set[str],
     representative_dataset: _RepresentativeDataset) -> None:
   """Runs the graph for calibration in eager mode.
 
@@ -299,12 +306,13 @@
     model_dir: Path to SavedModel directory.
     signature_keys: A list of signature keys that identifies a function to run
       the data samples with.
+    tags: Set of tags identifying the MetaGraphDef within the SavedModel.
     representative_dataset: Representative dataset used for calibration.
 
   Raises:
     ValueError: When the samples in representative dataset is invalid.
   """
-  root: autotrackable.AutoTrackable = saved_model_load(model_dir)
+  root: autotrackable.AutoTrackable = saved_model_load(model_dir, tags)
   for sample in representative_dataset():
     signature_key, input_data = _get_signature_key_and_input(
         sample, signature_keys)
@@ -318,11 +326,12 @@
       ) from ex
 
 
-def _static_range_quantize(saved_model_path: str,
-                           signature_keys=None,
-                           tags=None,
-                           output_directory=None,
-                           representative_dataset=None):
+def _static_range_quantize(
+    saved_model_path: str,
+    signature_keys: List[str],
+    tags: Set[str],
+    output_directory: str,
+    representative_dataset: Optional[_RepresentativeDataset] = None) ->...:
   """Quantizes the given SavedModel via static range quantization.
 
   Args:
@@ -345,6 +354,8 @@
 
   Raises:
     ValueError: when representative_dataset is not provided for non-QAT model.
+    RuntimeError: When a MetaGraphDef could not be found associated with `tags`
+      in the SavedModel.
   """
   is_qat_saved_model = _is_qat_saved_model(saved_model_path)
   signatures = _get_signatures_from_saved_model(saved_model_path,
@@ -390,7 +401,7 @@
             "The input SavedModel doesn't contain a valid signature")
 
       v1_builder.add_meta_graph_and_variables(
-          sess, [tag_constants.SERVING], signature_def_map=signatures)
+          sess, tags, signature_def_map=signatures)
 
     v1_builder.save()
 
@@ -401,10 +412,10 @@
     try:
       if context.executing_eagerly():
         _run_graph_for_calibration_eager_mode(float_model_dir, signature_keys,
-                                              representative_dataset)
+                                              tags, representative_dataset)
       else:
         _run_graph_for_calibration_graph_mode(float_model_dir, signature_keys,
-                                              representative_dataset)
+                                              tags, representative_dataset)
     except Exception as ex:
       raise ValueError(
           'Failed to run graph for post-training quantization calibration.'
@@ -435,7 +446,7 @@
       graph_def = working_graph.as_graph_def()
 
       v1_builder.add_meta_graph_and_variables(
-          sess, [tag_constants.SERVING], signature_def_map=signatures)
+          sess, tags, signature_def_map=signatures)
 
     v1_builder.save()
     signatures = _get_signatures_from_saved_model(calibrated_model_dir,
@@ -464,7 +475,7 @@
       raise ValueError("The input SavedModel doesn't contain a valid signature")
 
     v1_builder.add_meta_graph_and_variables(
-        sess, [tag_constants.SERVING], signature_def_map=signatures)
+        sess, tags, signature_def_map=signatures)
 
   v1_builder.save()
 
@@ -544,8 +555,8 @@
       not provided, this should be a model trained with QAT.
     signature_keys: List of keys identifying SignatureDef containing inputs and
       outputs. If None, ["serving_default"] is used.
-    tags: Set of tags identifying the MetaGraphDef within the SavedModel to
-      analyze. If None, {"serve"} is used.
+    tags: (TF1 SavedModel only) Set of tags identifying the MetaGraphDef within
+      the SavedModel to analyze. If None, {"serve"} is used.
     output_directory: The path to save the output SavedModel (must be an empty
       directory).
     quantization_options: A set of options for quantization.