Check in the Python code for the TensorFlow quantization passes
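
For context, a minimal usage sketch of the new Python API, distilled from
the integration tests added below (the paths are illustrative, not part of
this change):

    from tensorflow.compiler.mlir.quantization.tensorflow.python import quantize_model
    from tensorflow.python.saved_model import tag_constants

    # A QAT SavedModel (with FakeQuant ops) quantizes directly; a float
    # model additionally needs `representative_dataset=` for calibration.
    converted = quantize_model.quantize(
        '/tmp/qat_saved_model',        # input SavedModel directory
        ['serving_default'],           # signature keys to quantize
        [tag_constants.SERVING],       # MetaGraphDef tags
        '/tmp/quantized_saved_model')  # output directory (must be empty)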

PiperOrigin-RevId: 431033331
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/BUILD
index 7b9480b..61047e7 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/BUILD
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/BUILD
@@ -22,7 +22,6 @@
         "passes/util.h",
     ],
     deps = [
-        "//tensorflow/compiler/mlir/lite:tensorflow_lite",
         "//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:QuantOps",
@@ -38,7 +37,6 @@
         "passes/quantize_composite_functions.td",
         "passes/utils.td",
     ],
-    visibility = ["//visibility:private"],
     deps = [
         "//tensorflow/compiler/mlir/lite:tensorflow_lite_ops_td_files",
         "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files",
@@ -123,8 +121,6 @@
         "//tensorflow/core/platform:env",
         "//tensorflow/core/platform:macros",
         "//tensorflow/core/platform:path",
-        "//tensorflow/lite/c:common",
-        "//tensorflow/lite/kernels:padding",
         "//tensorflow/lite/kernels/internal:quantization_util",
         "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/random",
@@ -147,7 +143,6 @@
     deps = [
         ":passes",
         "//tensorflow/compiler/mlir:init_mlir",
-        "//tensorflow/compiler/mlir/lite:tensorflow_lite",
         "//tensorflow/compiler/mlir/tensorflow",
         "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes",
         "//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes",
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/BUILD
index 9560c5b..808be4d 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/BUILD
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/BUILD
@@ -5,6 +5,7 @@
     "tf_copts",
     "tf_custom_op_library",
     "tf_gen_op_wrapper_py",
+    "tf_py_test",
 )
 
 package(
@@ -74,14 +75,16 @@
     ],
 )
 
-# TODO(b/220688154): Re-enable test once the quantize model wrapper is ready.
-#py_test(
-#    name = "custom_aggregator_op_test",
-#    size = "small",
-#    srcs = ["custom_aggregator_op_test.py"],
-#    deps = [
-#        ":custom_aggregator_op",
-#        "//tensorflow/compiler/mlir/quantization/tensorflow:quantize_model_wrapper",
-#        "//tensorflow:tensorflow_py",
-#    ],
-#)
+tf_py_test(
+    name = "custom_aggregator_op_test",
+    size = "small",
+    srcs = ["custom_aggregator_op_test.py"],
+    tags = [
+        "no_oss",  # TODO(b/220688154): Enable OSS tests.
+    ],
+    deps = [
+        ":custom_aggregator_op",
+        "//tensorflow:tensorflow_py",
+        "//tensorflow/compiler/mlir/quantization/tensorflow/python:quantize_model_wrapper",
+    ],
+)
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/custom_aggregator_op.py b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/custom_aggregator_op.py
index 1a82541..e7e8a5d 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/custom_aggregator_op.py
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/custom_aggregator_op.py
@@ -14,12 +14,10 @@
 # ==============================================================================
 """Custom Aggregator op is for collecting numeric metrics from the given input."""
 
-# pylint: disable=g-direct-tensorflow-import
 from tensorflow.compiler.mlir.quantization.tensorflow.calibrator import custom_aggregator_op_wrapper
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import load_library
 from tensorflow.python.platform import resource_loader
-# pylint: enable=g-direct-tensorflow-import
 
 _custom_aggregator_op = load_library.load_op_library(
     resource_loader.get_path_to_datafile('_custom_aggregator_op.so'))
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/custom_aggregator_op_test.py b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/custom_aggregator_op_test.py
index 23c2195..06bc5d0 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/custom_aggregator_op_test.py
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/custom_aggregator_op_test.py
@@ -14,16 +14,16 @@
 # ==============================================================================
 """Tests for Custom Aggregator op."""
 
-# pylint:disable=g-direct-tensorflow-import
-from tensorflow.compiler.mlir.quantization.tensorflow import convert_model_wrapper
+
 from tensorflow.compiler.mlir.quantization.tensorflow.calibrator import custom_aggregator_op
-from tensorflow.python import tf
+from tensorflow.compiler.mlir.quantization.tensorflow.python import quantize_model_wrapper
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
-
-# pylint:enable=g-direct-tensorflow-import
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
 
 
-class CustomAggregatorTest(tf.test.TestCase):
+class CustomAggregatorTest(test.TestCase):
 
   def setUp(self):
     super(CustomAggregatorTest, self).setUp()
@@ -31,54 +31,59 @@
 
   def testBypassAndMinMax(self):
     with self.test_session():
-      convert_model_wrapper.clear_calibrator()
-      input_tensor = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0], tf.float32)
+      quantize_model_wrapper.clear_calibrator()
+      input_tensor = array_ops.constant([1.0, 2.0, 3.0, 4.0, 5.0],
+                                        dtypes.float32)
       aggregator = custom_aggregator_op.custom_aggregator(
           input_tensor, tensor_id='1')
       self.assertAllEqual(aggregator.eval(), [1.0, 2.0, 3.0, 4.0, 5.0])
-      min_max = convert_model_wrapper.get_min_max_from_calibrator('1')
+      min_max = quantize_model_wrapper.get_min_max_from_calibrator('1')
       self.assertAllEqual(min_max, (1.0, 5.0))
 
   def testTwoIdentities(self):
     with self.test_session():
-      convert_model_wrapper.clear_calibrator()
-      input_tensor1 = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0], tf.float32)
+      quantize_model_wrapper.clear_calibrator()
+      input_tensor1 = array_ops.constant([1.0, 2.0, 3.0, 4.0, 5.0],
+                                         dtypes.float32)
       aggregator1 = custom_aggregator_op.custom_aggregator(
           input_tensor1, tensor_id='2')
       self.assertAllEqual(aggregator1.eval(), [1.0, 2.0, 3.0, 4.0, 5.0])
-      input_tensor2 = tf.constant([-1.0, -2.0, -3.0, -4.0, -5.0], tf.float32)
+      input_tensor2 = array_ops.constant([-1.0, -2.0, -3.0, -4.0, -5.0],
+                                         dtypes.float32)
       aggregator2 = custom_aggregator_op.custom_aggregator(
           input_tensor2, tensor_id='3')
       self.assertAllEqual(aggregator2.eval(), [-1.0, -2.0, -3.0, -4.0, -5.0])
 
-      min_max = convert_model_wrapper.get_min_max_from_calibrator('2')
+      min_max = quantize_model_wrapper.get_min_max_from_calibrator('2')
       self.assertAllEqual(min_max, (1.0, 5.0))
-      min_max = convert_model_wrapper.get_min_max_from_calibrator('3')
+      min_max = quantize_model_wrapper.get_min_max_from_calibrator('3')
       self.assertAllEqual(min_max, (-5.0, -1.0))
 
   def testClearData(self):
     with self.test_session():
-      convert_model_wrapper.clear_calibrator()
-      input_tensor1 = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0], tf.float32)
+      quantize_model_wrapper.clear_calibrator()
+      input_tensor1 = array_ops.constant([1.0, 2.0, 3.0, 4.0, 5.0],
+                                         dtypes.float32)
       aggregator1 = custom_aggregator_op.custom_aggregator(
           input_tensor1, tensor_id='4')
       self.assertAllEqual(aggregator1.eval(), [1.0, 2.0, 3.0, 4.0, 5.0])
-      input_tensor2 = tf.constant([-1.0, -2.0, -3.0, -4.0, -5.0], tf.float32)
+      input_tensor2 = array_ops.constant([-1.0, -2.0, -3.0, -4.0, -5.0],
+                                         dtypes.float32)
       aggregator2 = custom_aggregator_op.custom_aggregator(
           input_tensor2, tensor_id='5')
       self.assertAllEqual(aggregator2.eval(), [-1.0, -2.0, -3.0, -4.0, -5.0])
 
-      min_max = convert_model_wrapper.get_min_max_from_calibrator('4')
+      min_max = quantize_model_wrapper.get_min_max_from_calibrator('4')
       self.assertAllEqual(min_max, (1.0, 5.0))
-      min_max = convert_model_wrapper.get_min_max_from_calibrator('5')
+      min_max = quantize_model_wrapper.get_min_max_from_calibrator('5')
       self.assertAllEqual(min_max, (-5.0, -1.0))
 
-      convert_model_wrapper.clear_data_from_calibrator('4')
+      quantize_model_wrapper.clear_data_from_calibrator('4')
       with self.assertRaises(ValueError):
-        convert_model_wrapper.get_min_max_from_calibrator('4')
-      min_max = convert_model_wrapper.get_min_max_from_calibrator('5')
+        quantize_model_wrapper.get_min_max_from_calibrator('4')
+      min_max = quantize_model_wrapper.get_min_max_from_calibrator('5')
       self.assertAllEqual(min_max, (-5.0, -1.0))
 
 
 if __name__ == '__main__':
-  tf.test.main()
+  test.main()
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quant_opt.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quant_opt.cc
index d7aa55c..0a0802b 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quant_opt.cc
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quant_opt.cc
@@ -21,7 +21,6 @@
 #include "mlir/InitAllPasses.h"  // from @llvm-project
 #include "mlir/Support/MlirOptMain.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/init_mlir.h"
-#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
 #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
@@ -33,12 +32,11 @@
   mlir::registerTensorFlowPasses();
 
   mlir::DialectRegistry registry;
-  registry
-      .insert<mlir::scf::SCFDialect, mlir::TF::TensorFlowDialect,
-              mlir::tf_saved_model::TensorFlowSavedModelDialect,
-              mlir::StandardOpsDialect, mlir::shape::ShapeDialect,
-              mlir::arith::ArithmeticDialect, mlir::quant::QuantizationDialect,
-              mlir::TFL::TensorFlowLiteDialect>();
+  registry.insert<mlir::scf::SCFDialect, mlir::TF::TensorFlowDialect,
+                  mlir::tf_saved_model::TensorFlowSavedModelDialect,
+                  mlir::StandardOpsDialect, mlir::shape::ShapeDialect,
+                  mlir::arith::ArithmeticDialect,
+                  mlir::quant::QuantizationDialect>();
   return failed(
       mlir::MlirOptMain(argc, argv, "TF quant Pass Driver\n", registry));
 }
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/util.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/util.cc
index 2083c84..9cc2d1c 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/util.cc
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/util.cc
@@ -14,14 +14,13 @@
 ==============================================================================*/
 #include "tensorflow/compiler/mlir/quantization/tensorflow/passes/util.h"
 
-#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
 #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
 
 namespace mlir {
 namespace quant {
 
 bool HasQuantizedTensors(Operation *op) {
-  if (!isa<TFL::QuantizeOp>(op) && IsOpNotQuantizable(op)) return false;
+  if (IsOpNotQuantizable(op)) return false;
   for (Type operand_type : op->getOperandTypes()) {
     auto tensor_type = operand_type.dyn_cast<TensorType>();
     if (tensor_type && tensor_type.getElementType().isa<QuantizedType>()) {
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD
new file mode 100644
index 0000000..e3f6180
--- /dev/null
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD
@@ -0,0 +1,125 @@
+load("//tensorflow:tensorflow.bzl", "pytype_strict_library")
+
+# buildifier: disable=same-origin-load
+load("//tensorflow:tensorflow.bzl", "tf_py_test")
+
+# buildifier: disable=same-origin-load
+load("//tensorflow:tensorflow.bzl", "tf_python_pybind_extension")
+
+package(
+    default_visibility = [
+        "//tensorflow/compiler/mlir/quantization:__subpackages__",
+    ],
+    licenses = ["notice"],
+)
+
+cc_library(
+    name = "quantize_model_lib",
+    srcs = [
+        "quantize_model.cc",
+    ],
+    hdrs = [
+        "quantize_model.h",
+    ],
+    deps = [
+        "//tensorflow/cc/saved_model:loader",
+        "//tensorflow/compiler/mlir/lite/quantization:quantization_config",
+        "//tensorflow/compiler/mlir/quantization/tensorflow:passes",
+        "//tensorflow/compiler/mlir/tensorflow",
+        "//tensorflow/compiler/mlir/tensorflow:error_util",
+        "//tensorflow/compiler/mlir/tensorflow:export_graphdef",
+        "//tensorflow/compiler/mlir/tensorflow:mlir_import_options",
+        "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
+        "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes",
+        "//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes",
+        "//tensorflow/compiler/mlir/tensorflow:tf_saved_model_passes",
+        "//tensorflow/compiler/mlir/tensorflow:translate_lib",
+        "//tensorflow/core:core_cpu_base",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal_impl",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core/platform:errors",
+        "//tensorflow/core/platform:path",
+        "//tensorflow/core/platform:statusor",
+        "//tensorflow/core/platform:stringpiece",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Parser",
+        "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:QuantOps",
+        "@llvm-project//mlir:SCFDialect",
+        "@llvm-project//mlir:Shape",
+        "@llvm-project//mlir:StandardOps",
+        "@llvm-project//mlir:Transforms",
+    ],
+)
+
+tf_python_pybind_extension(
+    name = "quantize_model_wrapper",
+    srcs = [
+        "quantize_model_wrapper.cc",
+    ],
+    module_name = "quantize_model_wrapper",
+    deps = [
+        ":quantize_model_lib",
+        "//tensorflow/compiler/mlir/quantization/tensorflow/calibrator:calibrator_singleton",
+        "//tensorflow/lite/python/interpreter_wrapper:python_utils",
+        "//tensorflow/python/lib/core:pybind11_lib",
+        "//tensorflow/python/lib/core:pybind11_status",
+        "@com_google_absl//absl/strings",
+        "@pybind11",
+    ],
+)
+
+pytype_strict_library(
+    name = "quantize_model",
+    srcs = [
+        "quantize_model.py",
+    ],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":quantize_model_wrapper",
+        "//tensorflow:tensorflow_py",
+        "//tensorflow/compiler/mlir/quantization/tensorflow/calibrator:custom_aggregator_op",
+        "//tensorflow/core:protos_all_py",
+        "//tensorflow/python:framework",
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python/client:session",
+        "//tensorflow/python/saved_model:builder",
+        "//tensorflow/python/saved_model:load",
+        "//tensorflow/python/saved_model:loader",
+        "//tensorflow/python/saved_model:signature_constants",
+        "//tensorflow/python/saved_model:tag_constants",
+        "//tensorflow/python/saved_model:utils",
+        "@flatbuffers//:runtime_py",
+    ],
+)
+
+tf_py_test(
+    name = "quantize_model_test",
+    size = "medium",
+    srcs = ["integration_test/quantize_model_test.py"],
+    tags = [
+        "no_oss",  # TODO(b/220688154): Enable OSS tests.
+    ],
+    deps = [
+        ":quantize_model",
+        "//tensorflow:tensorflow_py",
+    ],
+)
+
+tf_py_test(
+    name = "concurrency_test",
+    size = "medium",
+    srcs = ["integration_test/concurrency_test.py"],
+    tags = [
+        "no_oss",  # TODO(b/220688154): Enable OSS tests.
+    ],
+    deps = [
+        ":quantize_model",
+        "//tensorflow:tensorflow_py",
+    ],
+)
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/concurrency_test.py b/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/concurrency_test.py
new file mode 100644
index 0000000..5cc4139
--- /dev/null
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/concurrency_test.py
@@ -0,0 +1,89 @@
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Concurrency tests for quantize_model."""
+
+from concurrent import futures
+
+import numpy as np
+import tensorflow as tf  # pylint: disable=unused-import
+
+from tensorflow.compiler.mlir.quantization.tensorflow.python import quantize_model
+from tensorflow.python.eager import def_function
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_spec
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+from tensorflow.python.saved_model import save as saved_model_save
+from tensorflow.python.saved_model.saved_model import tag_constants
+from tensorflow.python.training.tracking import tracking
+
+
+class MultiThreadedTest(test.TestCase):
+  """Tests involving multiple threads."""
+
+  def setUp(self):
+    super(MultiThreadedTest, self).setUp()
+    self.pool = futures.ThreadPoolExecutor(max_workers=4)
+
+  def _convert_with_calibration(self):
+
+    class ModelWithAdd(tracking.AutoTrackable):
+      """Basic model with addition."""
+
+      @def_function.function(input_signature=[
+          tensor_spec.TensorSpec(shape=[10], dtype=dtypes.float32, name='x'),
+          tensor_spec.TensorSpec(shape=[10], dtype=dtypes.float32, name='y')
+      ])
+      def add(self, x, y):
+        res = math_ops.add(x, y)
+        return {'output': res}
+
+    def data_gen():
+      for _ in range(255):
+        yield {
+            'x':
+                ops.convert_to_tensor(
+                    np.random.uniform(size=(10)).astype('f4')),
+            'y':
+                ops.convert_to_tensor(
+                    np.random.uniform(size=(10)).astype('f4'))
+        }
+
+    root = ModelWithAdd()
+
+    temp_path = self.create_tempdir().full_path
+    saved_model_save.save(
+        root, temp_path, signatures=root.add.get_concrete_function())
+
+    model = quantize_model.quantize(
+        temp_path, ['serving_default'], [tag_constants.SERVING],
+        representative_dataset=data_gen)
+    self.assertIsNotNone(model)
+
+  def testMultipleConversionJobsWithCalibration(self):
+    # Ensure that multiple conversion jobs with calibration won't encounter any
+    # concurrency issue.
+    with self.pool:
+      jobs = []
+      for _ in range(10):
+        jobs.append(self.pool.submit(self._convert_with_calibration))
+
+      for job in jobs:
+        job.result()
+
+
+if __name__ == '__main__':
+  test.main()
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test.py b/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test.py
new file mode 100644
index 0000000..9c86f5e
--- /dev/null
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test.py
@@ -0,0 +1,271 @@
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for quantize_model."""
+
+from absl.testing import parameterized
+import numpy as np
+import tensorflow as tf  # pylint: disable=unused-import
+
+from tensorflow.compiler.mlir.quantization.tensorflow.python import quantize_model
+from tensorflow.python.eager import def_function
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_spec
+from tensorflow.python.module import module
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.platform import test
+from tensorflow.python.saved_model import loader_impl as saved_model_loader
+from tensorflow.python.saved_model import save as saved_model_save
+from tensorflow.python.saved_model.saved_model import signature_constants
+from tensorflow.python.saved_model.saved_model import tag_constants
+from tensorflow.python.training.tracking import tracking
+
+
+def _contains_quantized_function_call(meta_graphdef):
+  """Returns true if the graph def has quantized function call."""
+  for func in meta_graphdef.graph_def.library.function:
+    if func.signature.name.startswith('quantized_'):
+      return True
+  return False
+
+
+class QuantizationTest(test.TestCase, parameterized.TestCase):
+
+  def test_qat_model(self):
+
+    class QATModelWithAdd(tracking.AutoTrackable):
+      """Basic model with Fake quant + add."""
+
+      @def_function.function(input_signature=[
+          tensor_spec.TensorSpec(shape=[10], dtype=dtypes.float32, name='x'),
+          tensor_spec.TensorSpec(shape=[10], dtype=dtypes.float32, name='y')
+      ])
+      def add(self, x, y):
+        float_res = math_ops.add(x, y)
+        x = array_ops.fake_quant_with_min_max_args(
+            x, min=-0.1, max=0.2, num_bits=8, narrow_range=False)
+        y = array_ops.fake_quant_with_min_max_args(
+            y, min=-0.3, max=0.4, num_bits=8, narrow_range=False)
+        res = math_ops.add(x, y)
+        res = array_ops.fake_quant_with_min_max_args(
+            res, min=-0.4, max=0.6, num_bits=8, narrow_range=False)
+        return {'output': res, 'float_output': float_res}
+
+    root = QATModelWithAdd()
+
+    temp_path = self.create_tempdir().full_path
+    saved_model_save.save(
+        root, temp_path, signatures=root.add.get_concrete_function())
+
+    output_directory = self.create_tempdir().full_path
+    tags = [tag_constants.SERVING]
+    model = quantize_model.quantize(temp_path, ['serving_default'], tags,
+                                    output_directory)
+    self.assertIsNotNone(model)
+    self.assertEqual(
+        list(model.signatures._signatures.keys()), ['serving_default'])
+    func = model.signatures['serving_default']
+    func_res = func(
+        x=array_ops.constant(0.1, shape=[10]),
+        y=array_ops.constant(0.1, shape=[10]))
+    self.assertAllClose(
+        func_res['output'], array_ops.constant(0.2, shape=[10]), atol=0.01)
+    self.assertAllClose(
+        func_res['float_output'],
+        array_ops.constant(0.2, shape=[10]),
+        atol=1e-3)
+
+    output_loader = saved_model_loader.SavedModelLoader(output_directory)
+    output_meta_graphdef = output_loader.get_meta_graph_def_from_tags(tags)
+    self.assertTrue(_contains_quantized_function_call(output_meta_graphdef))
+
+  def test_ptq_model(self):
+
+    class PTQModelWithAdd(tracking.AutoTrackable):
+      """Basic model with addition."""
+
+      @def_function.function(input_signature=[
+          tensor_spec.TensorSpec(shape=[10], dtype=dtypes.float32, name='x'),
+          tensor_spec.TensorSpec(shape=[10], dtype=dtypes.float32, name='y')
+      ])
+      def add(self, x, y):
+        res = math_ops.add(x, y)
+        return {'output': res, 'x': x, 'y': y}
+
+    def data_gen():
+      for _ in range(255):
+        yield {
+            'x':
+                ops.convert_to_tensor(
+                    np.random.uniform(size=(10)).astype('f4')),
+            'y':
+                ops.convert_to_tensor(
+                    np.random.uniform(size=(10)).astype('f4'))
+        }
+
+    root = PTQModelWithAdd()
+
+    temp_path = self.create_tempdir().full_path
+    saved_model_save.save(
+        root, temp_path, signatures=root.add.get_concrete_function())
+
+    output_directory = self.create_tempdir().full_path
+    tags = [tag_constants.SERVING]
+    model = quantize_model.quantize(
+        temp_path, ['serving_default'],
+        tags,
+        output_directory,
+        representative_dataset=data_gen)
+    self.assertIsNotNone(model)
+    self.assertEqual(
+        list(model.signatures._signatures.keys()), ['serving_default'])
+    func = model.signatures['serving_default']
+    func_res = func(
+        x=array_ops.constant(0.1, shape=[10]),
+        y=array_ops.constant(0.1, shape=[10]))
+    self.assertAllClose(
+        func_res['output'], array_ops.constant(0.2, shape=[10]), atol=0.01)
+    xy_atol = 1e-6
+    self.assertAllClose(
+        func_res['x'], array_ops.constant(0.1, shape=[10]), atol=xy_atol)
+    self.assertAllClose(
+        func_res['y'], array_ops.constant(0.1, shape=[10]), atol=xy_atol)
+
+    output_loader = saved_model_loader.SavedModelLoader(output_directory)
+    output_meta_graphdef = output_loader.get_meta_graph_def_from_tags(tags)
+    self.assertTrue(_contains_quantized_function_call(output_meta_graphdef))
+
+  @parameterized.named_parameters(
+      ('none', None),
+      ('relu', nn_ops.relu),
+      ('relu6', nn_ops.relu6),
+  )
+  def test_qat_conv_model(self, activation_fn):
+
+    class ConvModel(module.Module):
+
+      @def_function.function(input_signature=[
+          tensor_spec.TensorSpec(
+              name='input', shape=[1, 3, 4, 3], dtype=dtypes.float32),
+          tensor_spec.TensorSpec(
+              name='filter', shape=[2, 3, 3, 2], dtype=dtypes.float32),
+      ])
+      def conv(self, input_tensor, filter_tensor):
+        q_input = array_ops.fake_quant_with_min_max_args(
+            input_tensor, min=-0.1, max=0.2, num_bits=8, narrow_range=False)
+        q_filters = array_ops.fake_quant_with_min_max_args(
+            filter_tensor, min=-1.0, max=2.0, num_bits=8, narrow_range=False)
+        bias = array_ops.constant([0, 0], dtype=dtypes.float32)
+        out = nn_ops.conv2d(
+            q_input,
+            q_filters,
+            strides=[1, 1, 2, 1],
+            dilations=[1, 1, 1, 1],
+            padding='SAME',
+            data_format='NHWC')
+        out = nn_ops.bias_add(out, bias, data_format='NHWC')
+        if activation_fn is not None:
+          out = activation_fn(out)
+        q_out = array_ops.fake_quant_with_min_max_args(
+            out, min=-0.3, max=0.4, num_bits=8, narrow_range=False)
+        return {'output': q_out}
+
+    model = ConvModel()
+    input_saved_model_path = self.create_tempdir('input').full_path
+    saved_model_save.save(model, input_saved_model_path)
+
+    signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
+    tags = [tag_constants.SERVING]
+    output_directory = self.create_tempdir().full_path
+    converted_model = quantize_model.quantize(
+        input_saved_model_path, [signature_key],
+        tags,
+        output_directory=output_directory)
+    self.assertIsNotNone(converted_model)
+    self.assertEqual(
+        list(converted_model.signatures._signatures.keys()), [signature_key])
+
+    input_data = np.random.uniform(
+        low=-0.1, high=0.2, size=(1, 3, 4, 3)).astype('f4')
+    filter_data = np.random.uniform(
+        low=-0.5, high=0.5, size=(2, 3, 3, 2)).astype('f4')
+
+    expected_outputs = model.conv(input_data, filter_data)
+    got_outputs = converted_model.signatures[signature_key](
+        input=ops.convert_to_tensor(input_data),
+        filter=ops.convert_to_tensor(filter_data))
+    # TODO(b/215633216): Check if the accuracy is acceptable.
+    self.assertAllClose(expected_outputs, got_outputs, atol=0.01)
+
+    output_loader = saved_model_loader.SavedModelLoader(output_directory)
+    output_meta_graphdef = output_loader.get_meta_graph_def_from_tags(tags)
+    self.assertTrue(_contains_quantized_function_call(output_meta_graphdef))
+
+  def test_conv_ptq_model(self):
+
+    class ConvModel(module.Module):
+
+      @def_function.function(input_signature=[
+          tensor_spec.TensorSpec(shape=[1, 3, 4, 3], dtype=dtypes.float32)
+      ])
+      def conv(self, input_tensor):
+        filters = np.random.uniform(
+            low=-10, high=10, size=(2, 3, 3, 2)).astype('f4')
+        bias = np.random.uniform(low=0, high=10, size=(2)).astype('f4')
+        out = nn_ops.conv2d(
+            input_tensor,
+            filters,
+            strides=[1, 1, 2, 1],
+            dilations=[1, 1, 1, 1],
+            padding='SAME',
+            data_format='NHWC')
+        out = nn_ops.bias_add(out, bias, data_format='NHWC')
+        out = nn_ops.relu6(out)
+        return {'output': out}
+
+    model = ConvModel()
+    input_saved_model_path = self.create_tempdir('input').full_path
+    saved_model_save.save(model, input_saved_model_path)
+
+    def data_gen():
+      for _ in range(255):
+        yield {
+            'input_tensor':
+                ops.convert_to_tensor(
+                    np.random.uniform(low=0, high=150,
+                                      size=(1, 3, 4, 3)).astype('f4')),
+        }
+
+    tags = [tag_constants.SERVING]
+    output_directory = self.create_tempdir().full_path
+    converted_model = quantize_model.quantize(
+        input_saved_model_path, ['serving_default'],
+        tags,
+        output_directory,
+        representative_dataset=data_gen)
+    self.assertIsNotNone(converted_model)
+    self.assertEqual(
+        list(converted_model.signatures._signatures.keys()),
+        ['serving_default'])
+
+    output_loader = saved_model_loader.SavedModelLoader(output_directory)
+    output_meta_graphdef = output_loader.get_meta_graph_def_from_tags(tags)
+    self.assertTrue(_contains_quantized_function_call(output_meta_graphdef))
+
+
+if __name__ == '__main__':
+  test.main()
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.cc b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.cc
new file mode 100644
index 0000000..e3763a2
--- /dev/null
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.cc
@@ -0,0 +1,268 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.h"
+
+#include <memory>
+#include <string>
+#include <string_view>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "llvm/Support/Debug.h"
+#include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
+#include "mlir/Dialect/SCF/SCF.h"  // from @llvm-project
+#include "mlir/Dialect/Shape/IR/Shape.h"  // from @llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
+#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/Location.h"  // from @llvm-project
+#include "mlir/IR/MLIRContext.h"  // from @llvm-project
+#include "mlir/Parser.h"  // from @llvm-project
+#include "mlir/Pass/Pass.h"  // from @llvm-project
+#include "mlir/Pass/PassManager.h"  // from @llvm-project
+#include "mlir/Transforms/Passes.h"  // from @llvm-project
+#include "tensorflow/cc/saved_model/loader.h"
+#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
+#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
+#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
+#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h"
+#include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h"
+#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_import_options.h"
+#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
+#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"
+#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
+#include "tensorflow/core/common_runtime/graph_constructor.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/errors.h"
+#include "tensorflow/core/platform/path.h"
+#include "tensorflow/core/platform/statusor.h"
+#include "tensorflow/core/platform/stringpiece.h"
+#include "tensorflow/core/util/env_var.h"
+
+namespace mlir {
+namespace quant {
+
+absl::StatusOr<tensorflow::GraphDef> QuantizeQATModel(
+    absl::string_view saved_model_path, absl::string_view exported_names_str,
+    absl::string_view tags) {
+  const std::unordered_set<std::string> tag_set =
+      absl::StrSplit(tags, ',', absl::SkipEmpty());
+  std::vector<std::string> exported_names_vec =
+      absl::StrSplit(exported_names_str, ',', absl::SkipEmpty());
+
+  // Import the SavedModel into an MLIR module.
+  DialectRegistry registry;
+  registry.insert<StandardOpsDialect, scf::SCFDialect,
+                  tf_saved_model::TensorFlowSavedModelDialect,
+                  TF::TensorFlowDialect, shape::ShapeDialect,
+                  QuantizationDialect>();
+  MLIRContext context(registry);
+
+  tensorflow::MLIRImportOptions import_options;
+  import_options.upgrade_legacy = true;
+  auto bundle = std::make_unique<tensorflow::SavedModelBundle>();
+
+  // TODO(b/213406917): Add support for the object graph based saved model
+  // input.
+  auto module_or = tensorflow::SavedModelSignatureDefsToMlirImport(
+      saved_model_path, tag_set, absl::Span<std::string>(exported_names_vec),
+      &context, import_options, /*lift_variables=*/true, &bundle);
+
+  if (!module_or.status().ok()) {
+    return absl::InternalError("failed to import SavedModel: " +
+                               module_or.status().error_message());
+  }
+
+  OwningOpRef<mlir::ModuleOp> moduleRef = module_or.ConsumeValueOrDie();
+
+  PassManager pm(&context);
+
+  pm.addPass(createCanonicalizerPass());
+  // Freezes constants so that FakeQuant ops can reference quantization ranges.
+  pm.addPass(tf_saved_model::CreateOptimizeGlobalTensorsPass());
+  pm.addPass(mlir::createInlinerPass());
+  pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
+  pm.addPass(tf_saved_model::CreateFreezeGlobalTensorsPass());
+
+  pm.addNestedPass<FuncOp>(CreateConvertFakeQuantToQdqPass());
+  pm.addNestedPass<FuncOp>(TF::CreateFusedKernelMatcherPass());
+  pm.addPass(CreateLiftQuantizableSpotsAsFunctionsPass());
+  pm.addPass(CreateInsertQuantizedFunctionsPass());
+  pm.addPass(CreateQuantizeCompositeFunctionsPass());
+  pm.addPass(mlir::createSymbolDCEPass());
+
+  pm.addPass(CreateInsertMainFunctionPass());
+  pm.addNestedPass<FuncOp>(CreateFunctionalToExecutorDialectConversionPass());
+  pm.addPass(CreateBreakUpIslandsPass());
+
+  mlir::StatusScopedDiagnosticHandler diagnostic_handler(&context);
+  if (failed(pm.run(*moduleRef))) {
+    return absl::InternalError(
+        "failed to apply the quantization: " +
+        diagnostic_handler.ConsumeStatus().error_message());
+  }
+
+  // Export as GraphDef.
+  tensorflow::GraphExportConfig confs;
+  auto graph_or = tensorflow::ConvertMlirToGraphdef(*moduleRef, confs);
+  if (!graph_or.ok()) {
+    return absl::InternalError("failed to convert MLIR to graphdef: " +
+                               graph_or.status().error_message());
+  }
+
+  return *graph_or.ConsumeValueOrDie();
+}
+
+absl::StatusOr<tensorflow::GraphDef> QuantizePTQModelPreCalibration(
+    absl::string_view saved_model_path, absl::string_view exported_names_str,
+    absl::string_view tags) {
+  const std::unordered_set<std::string> tag_set =
+      absl::StrSplit(tags, ',', absl::SkipEmpty());
+  std::vector<std::string> exported_names_vec =
+      absl::StrSplit(exported_names_str, ',', absl::SkipEmpty());
+
+  // Import the SavedModel into an MLIR module.
+  DialectRegistry registry;
+  registry.insert<StandardOpsDialect, scf::SCFDialect,
+                  tf_saved_model::TensorFlowSavedModelDialect,
+                  TF::TensorFlowDialect, shape::ShapeDialect,
+                  QuantizationDialect>();
+  MLIRContext context(registry);
+
+  tensorflow::MLIRImportOptions import_options;
+  import_options.upgrade_legacy = true;
+  auto bundle = std::make_unique<tensorflow::SavedModelBundle>();
+
+  // TODO(b/213406917): Add support for the object graph based saved model
+  // input.
+  auto module_or = tensorflow::SavedModelSignatureDefsToMlirImport(
+      saved_model_path, tag_set, absl::Span<std::string>(exported_names_vec),
+      &context, import_options,
+      /*lift_variables=*/true, &bundle);
+
+  if (!module_or.status().ok()) {
+    return absl::InternalError("failed to import SavedModel: " +
+                               module_or.status().error_message());
+  }
+
+  OwningOpRef<mlir::ModuleOp> moduleRef = module_or.ConsumeValueOrDie();
+
+  PassManager pm(&context);
+
+  pm.addPass(createCanonicalizerPass());
+  pm.addNestedPass<FuncOp>(TF::CreateFusedKernelMatcherPass());
+  pm.addPass(CreateLiftQuantizableSpotsAsFunctionsPass());
+  pm.addNestedPass<FuncOp>(CreateInsertCustomAggregationOpsPass());
+  pm.addPass(CreateIssueIDsOfCustomAggregationOpsPass());
+  pm.addPass(CreateInsertMainFunctionPass());
+  pm.addNestedPass<FuncOp>(CreateFunctionalToExecutorDialectConversionPass());
+  pm.addPass(CreateBreakUpIslandsPass());
+
+  mlir::StatusScopedDiagnosticHandler diagnostic_handler(&context);
+  if (failed(pm.run(*moduleRef))) {
+    return absl::InternalError(
+        "failed to apply the quantization at the pre-calibration stage: " +
+        diagnostic_handler.ConsumeStatus().error_message());
+  }
+
+  // Export as GraphDef.
+  tensorflow::GraphExportConfig confs;
+  auto graph_or = tensorflow::ConvertMlirToGraphdef(*moduleRef, confs);
+  if (!graph_or.ok()) {
+    return absl::InternalError("failed to convert MLIR to graphdef: " +
+                               graph_or.status().error_message());
+  }
+
+  return *graph_or.ConsumeValueOrDie();
+}
+
+absl::StatusOr<tensorflow::GraphDef> QuantizePTQModelPostCalibration(
+    absl::string_view saved_model_path, absl::string_view exported_names_str,
+    absl::string_view tags) {
+  const std::unordered_set<std::string> tag_set =
+      absl::StrSplit(tags, ',', absl::SkipEmpty());
+  std::vector<std::string> exported_names_vec =
+      absl::StrSplit(exported_names_str, ',', absl::SkipEmpty());
+
+  // Import the SavedModel into an MLIR module.
+  DialectRegistry registry;
+  registry.insert<StandardOpsDialect, scf::SCFDialect,
+                  tf_saved_model::TensorFlowSavedModelDialect,
+                  TF::TensorFlowDialect, shape::ShapeDialect,
+                  QuantizationDialect>();
+  MLIRContext context(registry);
+
+  tensorflow::MLIRImportOptions import_options;
+  import_options.upgrade_legacy = true;
+  auto bundle = std::make_unique<tensorflow::SavedModelBundle>();
+
+  // TODO(b/213406917): Add support for the object graph based saved model
+  // input.
+  auto module_or = tensorflow::SavedModelSignatureDefsToMlirImport(
+      saved_model_path, tag_set, absl::Span<std::string>(exported_names_vec),
+      &context, import_options,
+      /*lift_variables=*/true, &bundle);
+
+  if (!module_or.status().ok()) {
+    return absl::InternalError("failed to import SavedModel: " +
+                               module_or.status().error_message());
+  }
+
+  OwningOpRef<mlir::ModuleOp> moduleRef = module_or.ConsumeValueOrDie();
+
+  PassManager pm(&context);
+
+  pm.addPass(createCanonicalizerPass());
+  pm.addNestedPass<FuncOp>(CreateConvertCustomAggregationOpToQuantStatsPass());
+  pm.addPass(CreateInsertQuantizedFunctionsPass());
+  pm.addPass(CreateQuantizeCompositeFunctionsPass());
+  pm.addPass(mlir::createSymbolDCEPass());
+  pm.addPass(CreateInsertMainFunctionPass());
+  pm.addNestedPass<FuncOp>(CreateFunctionalToExecutorDialectConversionPass());
+  pm.addPass(CreateBreakUpIslandsPass());
+
+  mlir::StatusScopedDiagnosticHandler diagnostic_handler(&context);
+  if (failed(pm.run(*moduleRef))) {
+    return absl::InternalError(
+        "failed to apply the quantization at the post-calibation stage: " +
+        diagnostic_handler.ConsumeStatus().error_message());
+  }
+
+  // Export as GraphDef.
+  tensorflow::GraphExportConfig confs;
+  auto graph_or = tensorflow::ConvertMlirToGraphdef(*moduleRef, confs);
+  if (!graph_or.ok()) {
+    return absl::InternalError("failed to convert MLIR to graphdef: " +
+                               graph_or.status().error_message());
+  }
+
+  return *graph_or.ConsumeValueOrDie();
+}
+
+}  // namespace quant
+}  // namespace mlir
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.h b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.h
new file mode 100644
index 0000000..a40d486
--- /dev/null
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.h
@@ -0,0 +1,47 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_PYTHON_QUANTIZE_MODEL_H_
+#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_PYTHON_QUANTIZE_MODEL_H_
+
+#include <string>
+#include <vector>
+
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "tensorflow/core/framework/graph.pb.h"
+
+namespace mlir {
+namespace quant {
+
+// Quantizes the given QAT-enabled model.
+absl::StatusOr<tensorflow::GraphDef> QuantizeQATModel(
+    absl::string_view saved_model_path, absl::string_view exported_names_str,
+    absl::string_view tags);
+
+// Quantizes the given model with post-training quantization. This function
+// covers the pre-calibration stage, which inserts the calibration ops.
+absl::StatusOr<tensorflow::GraphDef> QuantizePTQModelPreCalibration(
+    absl::string_view saved_model_path, absl::string_view exported_names_str,
+    absl::string_view tags);
+
+// Quantizes the given model with post-training quantization. This function
+// covers the post-calibration stage, which quantizes the model using the
+// collected calibration statistics.
+absl::StatusOr<tensorflow::GraphDef> QuantizePTQModelPostCalibration(
+    absl::string_view saved_model_path, absl::string_view exported_names_str,
+    absl::string_view tags);
+
+}  // namespace quant
+}  // namespace mlir
+
+#endif  // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_PYTHON_QUANTIZE_MODEL_H_
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py
new file mode 100644
index 0000000..469c8f4
--- /dev/null
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py
@@ -0,0 +1,291 @@
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Defines TF Quantization API from SavedModel to SavedModel."""
+
+import tempfile
+import uuid
+import warnings
+
+from tensorflow.compiler.mlir.quantization.tensorflow.python import quantize_model_wrapper
+from tensorflow.core.framework import graph_pb2
+from tensorflow.python.client import session
+from tensorflow.python.framework import importer
+from tensorflow.python.framework import ops
+from tensorflow.python.saved_model import builder
+from tensorflow.python.saved_model import loader_impl as saved_model_loader
+from tensorflow.python.saved_model import signature_constants
+from tensorflow.python.saved_model import tag_constants
+from tensorflow.python.saved_model.load import load as saved_model_load
+
+
+# The signature key of the saved model init op.
+_INIT_OP_SIGNATURE_KEY = '__saved_model_init_op'
+
+
+def _legalize_tensor_name(tensor_name: str) -> str:
+  """Converts tensor name from 'name:index' to 'name__index' format."""
+  return tensor_name.replace(':', '__')
+
+
+def _is_qat_saved_model(saved_model_path: str) -> bool:
+  """Checks if the SavedModel is QAT-enabled by looking for 'FakeQuant' ops."""
+  saved_model_proto = saved_model_loader.parse_saved_model(saved_model_path)
+  for meta_graph in saved_model_proto.meta_graphs:
+    if any(
+        node.op.startswith('FakeQuant') for node in meta_graph.graph_def.node):
+      return True
+    for function in meta_graph.graph_def.library.function:
+      if any(node.op.startswith('FakeQuant') for node in function.node_def):
+        return True
+  return False
+
+
+def _get_signatures_from_saved_model(saved_model_path: str,
+                                     signature_keys=None,
+                                     tags=None):
+  """Gets a map from signature keys to their SignatureDef from a saved model."""
+  if tags is None:
+    tags = set([tag_constants.SERVING])
+
+  loader = saved_model_loader.SavedModelLoader(saved_model_path)
+  meta_graphdef = loader.get_meta_graph_def_from_tags(tags)
+  signatures = {}
+  for key, signature_def in meta_graphdef.signature_def.items():
+    if key == _INIT_OP_SIGNATURE_KEY:
+      continue
+    if signature_keys is not None and key not in signature_keys:
+      continue
+    signatures[key] = signature_def
+
+  return signatures
+
+
+def _fix_tensor_names(signatures, exported_graph):
+  """Tries fixing tensor names in the signatures to match the exported graph.
+
+  The output tensor names in the original graph usually become names of the
+  return nodes in the exported graph. This function tries to fix that and checks
+  if the input tensor names are found in the exported graph.
+
+  Args:
+    signatures: the signatures of the original graph.
+    exported_graph: The PTQ-exported GraphDef.
+
+  Returns:
+    Fixed signatures or None if it couldn't be fixed.
+  """
+  if signatures is None:
+    return None
+
+  # The InsertMainFunctionPass populates input and output nodes of the newly
+  # inserted main function with "tf_saved_model.index_path" attributes. These
+  # attributes can be used to identify outputs in the exported graph.
+  output_index_path_map = {}
+  for op in exported_graph.get_operations():
+    if (op.type == '_Retval' and
+        op.get_attr('tf_saved_model.index_path') is not None):
+      index_path_name = op.get_attr('tf_saved_model.index_path')[0]
+      index_path_name = index_path_name.decode('utf-8')
+      output_index_path_map[index_path_name] = op.inputs[0].name
+
+  for signature_def in signatures.values():
+    for tensor_info in signature_def.inputs.values():
+      try:
+        exported_graph.get_tensor_by_name(tensor_info.name)
+      except KeyError:
+        # If input tensors are not found, the signatures can't be used for the
+        # exported graph.
+        warnings.warn('Cannot find the tensor with name %s in the graph.' %
+                      tensor_info.name)
+        return None
+
+    for tensor_info in signature_def.outputs.values():
+      try:
+        if tensor_info.name in output_index_path_map:
+          tensor_info.name = output_index_path_map[tensor_info.name]
+        else:
+          # Tries to find the return node with the given name and use its input
+          # as the output tensor name.
+          return_node = exported_graph.get_operation_by_name(
+              _legalize_tensor_name(tensor_info.name))
+          tensor_info.name = return_node.inputs[0].name
+      except KeyError:
+        warnings.warn(
+            'Cannot find the tensor or node with name %s in the graph.' %
+            tensor_info.name)
+        return None
+
+  return signatures
+
+
+def quantize(saved_model_path: str,
+             signature_keys=None,
+             tags=None,
+             output_directory=None,
+             representative_dataset=None):
+  """Quantizes the given SavedModel.
+
+  Args:
+    saved_model_path: Path to the saved model. When representative_dataset is
+      not provided, this should be a model trained with QAT.
+    signature_keys: List of keys identifying SignatureDef containing inputs and
+      outputs.
+    tags: Set of tags identifying the MetaGraphDef within the SavedModel to
+      analyze.
+    output_directory: The path to save the output SavedModel (must be an empty
+      directory).
+    representative_dataset: A generator that yields a dictionary in
+      {input_name: input_tensor} format, or a tuple of a signature key and
+      such a dictionary, providing calibration data for quantizing the model.
+      This should be provided when the model is not a QAT model.
+
+  Returns:
+    A SavedModel object with TF quantization applied.
+
+  Raises:
+    ValueError: when representative_dataset is not provided for a non-QAT
+      model.
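+
+  Example (illustrative; `data_gen` stands for a user-provided generator
+  yielding {input_name: input_tensor} dictionaries):
+
+    model = quantize_model.quantize(
+        '/tmp/saved_model', ['serving_default'], [tag_constants.SERVING],
+        '/tmp/quantized_model', representative_dataset=data_gen)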
+  """
+  if tags is None:
+    tags = set([tag_constants.SERVING])
+  if signature_keys is None:
+    signature_keys = [signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
+
+  is_qat_saved_model = _is_qat_saved_model(saved_model_path)
+  signatures = _get_signatures_from_saved_model(saved_model_path,
+                                                signature_keys, tags)
+
+  # A representative dataset is required to calibrate a non-QAT model.
+  if representative_dataset is None and not is_qat_saved_model:
+    raise ValueError(
+        'When `representative_dataset` is not provided, the model should be '
+        'trained with quantization-aware training (QAT).')
+
+  if is_qat_saved_model:
+    # QAT models already carry quantization ranges, so quantize directly.
+    graph_def_serialized = (
+        quantize_model_wrapper.quantize_qat_model(saved_model_path,
+                                                  ','.join(signature_keys),
+                                                  ','.join(tags)))
+  else:
+    # For PTQ, insert calibration ops, run the representative dataset to
+    # collect min/max statistics, then quantize with the calibrated ranges.
+    graph_def_serialized = (
+        quantize_model_wrapper.quantize_ptq_model_pre_calibration(
+            saved_model_path, ','.join(signature_keys), ','.join(tags)))
+
+    graph_def = graph_pb2.GraphDef()
+    graph_def.ParseFromString(graph_def_serialized)
+
+    float_model_dir = tempfile.mkdtemp()
+    v1_builder = builder.SavedModelBuilder(float_model_dir)
+
+    with session.Session(graph=ops.Graph()) as sess:
+      for function_def in graph_def.library.function:
+        for node_def in function_def.node_def:
+          if node_def.op == 'CustomAggregator':
+            node_def.attr['id'].s = uuid.uuid4().hex.encode('ascii')
+
+      importer.import_graph_def(graph_def, name='')
+      working_graph = ops.get_default_graph()
+      graph_def = working_graph.as_graph_def()
+
+      signatures = _fix_tensor_names(signatures, working_graph)
+      if signatures is None:
+        raise ValueError(
+            "The input SavedModel doesn't contain a valid signature")
+
+      v1_builder.add_meta_graph_and_variables(
+          sess, [tag_constants.SERVING], signature_def_map=signatures)
+
+    v1_builder.save()
+
+    float_model = saved_model_load(float_model_dir)
+
+    for sample in representative_dataset():
+      # TODO(b/214311251): Add a test case with multiple signatures.
+      if isinstance(sample, tuple):
+        if not isinstance(sample[1], dict):
+          raise ValueError('You need to provide a dictionary with input '
+                           'names and values as the second element of the '
+                           'tuple')
+        signature_key = sample[0]
+        input_data_map = sample[1]
+      elif isinstance(sample, dict):
+        if len(signature_keys) > 1:
+          raise ValueError('When the model has multiple signatures, you need '
+                           'to provide a tuple with a signature key and a '
+                           'dictionary with input names and values')
+        signature_key = signature_keys[0]
+        input_data_map = sample
+      else:
+        raise ValueError('You need to provide either a dictionary with input '
+                         'names and values or a tuple with a signature key '
+                         'and a dictionary with input names and values')
+      func = float_model.signatures[signature_key]
+      func(**input_data_map)
+
+    for function_def in graph_def.library.function:
+      for node_def in function_def.node_def:
+        if node_def.op == 'CustomAggregator':
+          node_id = node_def.attr['id'].s
+          min_val, max_val = quantize_model_wrapper.get_min_max_from_calibrator(
+              node_id)
+          quantize_model_wrapper.clear_data_from_calibrator(node_id)
+          node_def.attr['min'].f = float(min_val)
+          node_def.attr['max'].f = float(max_val)
+
+    calibrated_model_dir = tempfile.mkdtemp()
+    v1_builder = builder.SavedModelBuilder(calibrated_model_dir)
+
+    with session.Session(graph=ops.Graph()) as sess:
+      importer.import_graph_def(graph_def, name='')
+      working_graph = ops.get_default_graph()
+      graph_def = working_graph.as_graph_def()
+
+      v1_builder.add_meta_graph_and_variables(
+          sess, [tag_constants.SERVING], signature_def_map=signatures)
+
+    v1_builder.save()
+    signatures = _get_signatures_from_saved_model(calibrated_model_dir,
+                                                  signature_keys, tags)
+
+    graph_def_serialized = (
+        quantize_model_wrapper.quantize_ptq_model_post_calibration(
+            calibrated_model_dir,
+            ','.join(signature_keys),
+            ','.join(tags),
+        ))
+
+  graph_def = graph_pb2.GraphDef()
+  graph_def.ParseFromString(graph_def_serialized)
+
+  if output_directory is None:
+    output_directory = tempfile.mkdtemp()
+  v1_builder = builder.SavedModelBuilder(output_directory)
+
+  with session.Session(graph=ops.Graph()) as sess:
+    importer.import_graph_def(graph_def, name='')
+    working_graph = ops.get_default_graph()
+
+    signatures = _fix_tensor_names(signatures, working_graph)
+    if signatures is None:
+      raise ValueError("The input SavedModel doesn't contain a valid signature")
+
+    v1_builder.add_meta_graph_and_variables(
+        sess, [tag_constants.SERVING], signature_def_map=signatures)
+
+  v1_builder.save()
+
+  return saved_model_load(output_directory)
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model_wrapper.cc b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model_wrapper.cc
new file mode 100644
index 0000000..3680832
--- /dev/null
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model_wrapper.cc
@@ -0,0 +1,144 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <pybind11/stl.h>
+
+#include <memory>
+#include <utility>
+
+#include "absl/strings/string_view.h"
+#include "pybind11/pybind11.h"
+#include "pybind11/stl.h"
+#include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibrator_singleton.h"
+#include "tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.h"
+#include "tensorflow/lite/python/interpreter_wrapper/python_utils.h"
+#include "tensorflow/python/lib/core/pybind11_lib.h"
+#include "tensorflow/python/lib/core/pybind11_status.h"
+
+using tensorflow::calibrator::CalibratorSingleton;
+
+namespace py = pybind11;
+
+PyObject* QuantizeQATModel(absl::string_view saved_model_path,
+                           absl::string_view exported_names_str,
+                           absl::string_view tags) {
+  auto graph_def_or =
+      mlir::quant::QuantizeQATModel(saved_model_path, exported_names_str, tags);
+  if (!graph_def_or.ok()) {
+    PyErr_Format(PyExc_ValueError, "%s",
+                 graph_def_or.status().error_message().c_str());
+    return nullptr;
+  }
+
+  std::string ret_str = graph_def_or.ValueOrDie().SerializeAsString();
+
+  return tflite::python_utils::ConvertToPyString(ret_str.c_str(),
+                                                 ret_str.size());
+}
+
+PyObject* QuantizePTQModelPreCalibration(absl::string_view saved_model_path,
+                                         absl::string_view exported_names_str,
+                                         absl::string_view tags) {
+  auto graph_def_or = mlir::quant::QuantizePTQModelPreCalibration(
+      saved_model_path, exported_names_str, tags);
+  if (!graph_def_or.ok()) {
+    PyErr_Format(PyExc_ValueError, "%s",
+                 graph_def_or.status().error_message().c_str());
+    return nullptr;
+  }
+
+  std::string ret_str = graph_def_or.ValueOrDie().SerializeAsString();
+
+  return tflite::python_utils::ConvertToPyString(ret_str.c_str(),
+                                                 ret_str.size());
+}
+
+PyObject* QuantizePTQModelPostCalibration(absl::string_view saved_model_path,
+                                          absl::string_view exported_names_str,
+                                          absl::string_view tags) {
+  auto graph_def_or = mlir::quant::QuantizePTQModelPostCalibration(
+      saved_model_path, exported_names_str, tags);
+  if (!graph_def_or.ok()) {
+    PyErr_Format(PyExc_ValueError, "%s",
+                 graph_def_or.status().error_message().c_str());
+    return nullptr;
+  }
+
+  std::string ret_str = graph_def_or.ValueOrDie().SerializeAsString();
+
+  return tflite::python_utils::ConvertToPyString(ret_str.c_str(),
+                                                 ret_str.size());
+}
+
+py::tuple GetMinMaxFromCalibrator(absl::string_view id) {
+  absl::optional<std::pair<float, float>> min_max =
+      CalibratorSingleton::GetMinMax(id);
+  if (!min_max.has_value()) {
+    PyErr_Format(PyExc_ValueError, "No calibrated data for '%s'",
+                 std::string{id}.c_str());
+    throw py::error_already_set();
+  }
+
+  return py::make_tuple(min_max->first, min_max->second);
+}
+
+PYBIND11_MODULE(quantize_model_wrapper, m) {
+  m.def(
+      "clear_calibrator",
+      []() { CalibratorSingleton::ClearCollectedInformation(); },
+      R"pbdoc(
+      Clears the collected metrics from the calibrator.
+    )pbdoc");
+  m.def(
+      "clear_data_from_calibrator",
+      [](absl::string_view id) { CalibratorSingleton::ClearData(id); },
+      R"pbdoc(
+      Clears the collected data of the given id from the calibrator.
+    )pbdoc");
+  m.def(
+      "get_min_max_from_calibrator",
+      [](absl::string_view id) { return GetMinMaxFromCalibrator(id); },
+      R"pbdoc(
+      Returns a tuple of the min and max values for the given id.
+    )pbdoc");
+  m.def(
+      "quantize_qat_model",
+      [](absl::string_view saved_model_path,
+         absl::string_view exported_names_str, absl::string_view tags) {
+        return tensorflow::PyoOrThrow(
+            QuantizeQATModel(saved_model_path, exported_names_str, tags));
+      },
+      R"pbdoc(
+      Returns a serialized GraphDef of the quantized QAT model.
+    )pbdoc");
+  m.def(
+      "quantize_ptq_model_pre_calibration",
+      [](absl::string_view saved_model_path,
+         absl::string_view exported_names_str, absl::string_view tags) {
+        return tensorflow::PyoOrThrow(QuantizePTQModelPreCalibration(
+            saved_model_path, exported_names_str, tags));
+      },
+      R"pbdoc(
+      Returns a serialized GraphDef with calibration ops inserted
+      (pre-calibration stage).
+    )pbdoc");
+  m.def(
+      "quantize_ptq_model_post_calibration",
+      [](absl::string_view saved_model_path,
+         absl::string_view exported_names_str, absl::string_view tags) {
+        return tensorflow::PyoOrThrow(QuantizePTQModelPostCalibration(
+            saved_model_path, exported_names_str, tags));
+      },
+      R"pbdoc(
+      Returns a serialized GraphDef of the quantized model
+      (post-calibration stage).
+    )pbdoc");
+}