Check in the Python code for TensorFlow quantization passes
PiperOrigin-RevId: 431033331
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/BUILD
index 7b9480b..61047e7 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/BUILD
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/BUILD
@@ -22,7 +22,6 @@
"passes/util.h",
],
deps = [
- "//tensorflow/compiler/mlir/lite:tensorflow_lite",
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:QuantOps",
@@ -38,7 +37,6 @@
"passes/quantize_composite_functions.td",
"passes/utils.td",
],
- visibility = ["//visibility:private"],
deps = [
"//tensorflow/compiler/mlir/lite:tensorflow_lite_ops_td_files",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files",
@@ -123,8 +121,6 @@
"//tensorflow/core/platform:env",
"//tensorflow/core/platform:macros",
"//tensorflow/core/platform:path",
- "//tensorflow/lite/c:common",
- "//tensorflow/lite/kernels:padding",
"//tensorflow/lite/kernels/internal:quantization_util",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/random",
@@ -147,7 +143,6 @@
deps = [
":passes",
"//tensorflow/compiler/mlir:init_mlir",
- "//tensorflow/compiler/mlir/lite:tensorflow_lite",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_passes",
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes",
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/BUILD
index 9560c5b..808be4d 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/BUILD
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/BUILD
@@ -5,6 +5,7 @@
"tf_copts",
"tf_custom_op_library",
"tf_gen_op_wrapper_py",
+ "tf_py_test",
)
package(
@@ -74,14 +75,16 @@
],
)
-# TODO(b/220688154): Re-enable test once the quantize model wrapper is ready.
-#py_test(
-# name = "custom_aggregator_op_test",
-# size = "small",
-# srcs = ["custom_aggregator_op_test.py"],
-# deps = [
-# ":custom_aggregator_op",
-# "//tensorflow/compiler/mlir/quantization/tensorflow:quantize_model_wrapper",
-# "//tensorflow:tensorflow_py",
-# ],
-#)
+tf_py_test(
+ name = "custom_aggregator_op_test",
+ size = "small",
+ srcs = ["custom_aggregator_op_test.py"],
+ tags = [
+ "no_oss", # TODO(b/220688154): Enable OSS tests.
+ ],
+ deps = [
+ ":custom_aggregator_op",
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/compiler/mlir/quantization/tensorflow/python:quantize_model_wrapper",
+ ],
+)
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/custom_aggregator_op.py b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/custom_aggregator_op.py
index 1a82541..e7e8a5d 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/custom_aggregator_op.py
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/custom_aggregator_op.py
@@ -14,12 +14,10 @@
# ==============================================================================
"""Custom Aggregator op is for collecting numeric metrics from the given input."""
-# pylint: disable=g-direct-tensorflow-import
from tensorflow.compiler.mlir.quantization.tensorflow.calibrator import custom_aggregator_op_wrapper
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import load_library
from tensorflow.python.platform import resource_loader
-# pylint: enable=g-direct-tensorflow-import
_custom_aggregator_op = load_library.load_op_library(
resource_loader.get_path_to_datafile('_custom_aggregator_op.so'))
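
Note for reviewers: as the tests in the next file exercise, this op acts as an identity on its input while recording min/max statistics in the calibrator singleton under the given `tensor_id`. A minimal usage sketch (explicit graph mode as in the tests; the id 'example' is illustrative):

  from tensorflow.compiler.mlir.quantization.tensorflow.calibrator import custom_aggregator_op
  from tensorflow.compiler.mlir.quantization.tensorflow.python import quantize_model_wrapper
  from tensorflow.python.client import session
  from tensorflow.python.framework import dtypes
  from tensorflow.python.framework import ops
  from tensorflow.python.ops import array_ops

  with ops.Graph().as_default(), session.Session() as sess:
    inp = array_ops.constant([1.0, 2.0, 3.0], dtypes.float32)
    out = custom_aggregator_op.custom_aggregator(inp, tensor_id='example')
    sess.run(out)  # Values pass through unchanged: [1.0, 2.0, 3.0].
    # The recorded statistics can now be read back from the calibrator.
    print(quantize_model_wrapper.get_min_max_from_calibrator('example'))  # (1.0, 3.0)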
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/custom_aggregator_op_test.py b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/custom_aggregator_op_test.py
index 23c2195..06bc5d0 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/custom_aggregator_op_test.py
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/custom_aggregator_op_test.py
@@ -14,16 +14,16 @@
# ==============================================================================
"""Tests for Custom Aggregator op."""
-# pylint:disable=g-direct-tensorflow-import
-from tensorflow.compiler.mlir.quantization.tensorflow import convert_model_wrapper
+
from tensorflow.compiler.mlir.quantization.tensorflow.calibrator import custom_aggregator_op
-from tensorflow.python import tf
+from tensorflow.compiler.mlir.quantization.tensorflow.python import quantize_model_wrapper
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
-
-# pylint:enable=g-direct-tensorflow-import
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
-class CustomAggregatorTest(tf.test.TestCase):
+class CustomAggregatorTest(test.TestCase):
def setUp(self):
super(CustomAggregatorTest, self).setUp()
@@ -31,54 +31,59 @@
def testBypassAndMinMax(self):
with self.test_session():
- convert_model_wrapper.clear_calibrator()
- input_tensor = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0], tf.float32)
+ quantize_model_wrapper.clear_calibrator()
+ input_tensor = array_ops.constant([1.0, 2.0, 3.0, 4.0, 5.0],
+ dtypes.float32)
aggregator = custom_aggregator_op.custom_aggregator(
input_tensor, tensor_id='1')
self.assertAllEqual(aggregator.eval(), [1.0, 2.0, 3.0, 4.0, 5.0])
- min_max = convert_model_wrapper.get_min_max_from_calibrator('1')
+ min_max = quantize_model_wrapper.get_min_max_from_calibrator('1')
self.assertAllEqual(min_max, (1.0, 5.0))
def testTwoIdentities(self):
with self.test_session():
- convert_model_wrapper.clear_calibrator()
- input_tensor1 = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0], tf.float32)
+ quantize_model_wrapper.clear_calibrator()
+ input_tensor1 = array_ops.constant([1.0, 2.0, 3.0, 4.0, 5.0],
+ dtypes.float32)
aggregator1 = custom_aggregator_op.custom_aggregator(
input_tensor1, tensor_id='2')
self.assertAllEqual(aggregator1.eval(), [1.0, 2.0, 3.0, 4.0, 5.0])
- input_tensor2 = tf.constant([-1.0, -2.0, -3.0, -4.0, -5.0], tf.float32)
+ input_tensor2 = array_ops.constant([-1.0, -2.0, -3.0, -4.0, -5.0],
+ dtypes.float32)
aggregator2 = custom_aggregator_op.custom_aggregator(
input_tensor2, tensor_id='3')
self.assertAllEqual(aggregator2.eval(), [-1.0, -2.0, -3.0, -4.0, -5.0])
- min_max = convert_model_wrapper.get_min_max_from_calibrator('2')
+ min_max = quantize_model_wrapper.get_min_max_from_calibrator('2')
self.assertAllEqual(min_max, (1.0, 5.0))
- min_max = convert_model_wrapper.get_min_max_from_calibrator('3')
+ min_max = quantize_model_wrapper.get_min_max_from_calibrator('3')
self.assertAllEqual(min_max, (-5.0, -1.0))
def testClearData(self):
with self.test_session():
- convert_model_wrapper.clear_calibrator()
- input_tensor1 = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0], tf.float32)
+ quantize_model_wrapper.clear_calibrator()
+ input_tensor1 = array_ops.constant([1.0, 2.0, 3.0, 4.0, 5.0],
+ dtypes.float32)
aggregator1 = custom_aggregator_op.custom_aggregator(
input_tensor1, tensor_id='4')
self.assertAllEqual(aggregator1.eval(), [1.0, 2.0, 3.0, 4.0, 5.0])
- input_tensor2 = tf.constant([-1.0, -2.0, -3.0, -4.0, -5.0], tf.float32)
+ input_tensor2 = array_ops.constant([-1.0, -2.0, -3.0, -4.0, -5.0],
+ dtypes.float32)
aggregator2 = custom_aggregator_op.custom_aggregator(
input_tensor2, tensor_id='5')
self.assertAllEqual(aggregator2.eval(), [-1.0, -2.0, -3.0, -4.0, -5.0])
- min_max = convert_model_wrapper.get_min_max_from_calibrator('4')
+ min_max = quantize_model_wrapper.get_min_max_from_calibrator('4')
self.assertAllEqual(min_max, (1.0, 5.0))
- min_max = convert_model_wrapper.get_min_max_from_calibrator('5')
+ min_max = quantize_model_wrapper.get_min_max_from_calibrator('5')
self.assertAllEqual(min_max, (-5.0, -1.0))
- convert_model_wrapper.clear_data_from_calibrator('4')
+ quantize_model_wrapper.clear_data_from_calibrator('4')
with self.assertRaises(ValueError):
- convert_model_wrapper.get_min_max_from_calibrator('4')
- min_max = convert_model_wrapper.get_min_max_from_calibrator('5')
+ quantize_model_wrapper.get_min_max_from_calibrator('4')
+ min_max = quantize_model_wrapper.get_min_max_from_calibrator('5')
self.assertAllEqual(min_max, (-5.0, -1.0))
if __name__ == '__main__':
- tf.test.main()
+ test.main()
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quant_opt.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quant_opt.cc
index d7aa55c..0a0802b 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quant_opt.cc
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quant_opt.cc
@@ -21,7 +21,6 @@
#include "mlir/InitAllPasses.h" // from @llvm-project
#include "mlir/Support/MlirOptMain.h" // from @llvm-project
#include "tensorflow/compiler/mlir/init_mlir.h"
-#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
@@ -33,12 +32,11 @@
mlir::registerTensorFlowPasses();
mlir::DialectRegistry registry;
- registry
- .insert<mlir::scf::SCFDialect, mlir::TF::TensorFlowDialect,
- mlir::tf_saved_model::TensorFlowSavedModelDialect,
- mlir::StandardOpsDialect, mlir::shape::ShapeDialect,
- mlir::arith::ArithmeticDialect, mlir::quant::QuantizationDialect,
- mlir::TFL::TensorFlowLiteDialect>();
+ registry.insert<mlir::scf::SCFDialect, mlir::TF::TensorFlowDialect,
+ mlir::tf_saved_model::TensorFlowSavedModelDialect,
+ mlir::StandardOpsDialect, mlir::shape::ShapeDialect,
+ mlir::arith::ArithmeticDialect,
+ mlir::quant::QuantizationDialect>();
return failed(
mlir::MlirOptMain(argc, argv, "TF quant Pass Driver\n", registry));
}
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/util.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/util.cc
index 2083c84..9cc2d1c 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/util.cc
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/util.cc
@@ -14,14 +14,13 @@
==============================================================================*/
#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/util.h"
-#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
namespace mlir {
namespace quant {
bool HasQuantizedTensors(Operation *op) {
- if (!isa<TFL::QuantizeOp>(op) && IsOpNotQuantizable(op)) return false;
+ if (IsOpNotQuantizable(op)) return false;
for (Type operand_type : op->getOperandTypes()) {
auto tensor_type = operand_type.dyn_cast<TensorType>();
if (tensor_type && tensor_type.getElementType().isa<QuantizedType>()) {
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD
new file mode 100644
index 0000000..e3f6180
--- /dev/null
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD
@@ -0,0 +1,125 @@
+load("//tensorflow:tensorflow.bzl", "pytype_strict_library")
+
+# buildifier: disable=same-origin-load
+load("//tensorflow:tensorflow.bzl", "tf_py_test")
+
+# buildifier: disable=same-origin-load
+load("//tensorflow:tensorflow.bzl", "tf_python_pybind_extension")
+
+package(
+ default_visibility = [
+ "//tensorflow/compiler/mlir/quantization:__subpackages__",
+ ],
+ licenses = ["notice"],
+)
+
+cc_library(
+ name = "quantize_model_lib",
+ srcs = [
+ "quantize_model.cc",
+ ],
+ hdrs = [
+ "quantize_model.h",
+ ],
+ deps = [
+ "//tensorflow/cc/saved_model:loader",
+ "//tensorflow/compiler/mlir/lite/quantization:quantization_config",
+ "//tensorflow/compiler/mlir/quantization/tensorflow:passes",
+ "//tensorflow/compiler/mlir/tensorflow",
+ "//tensorflow/compiler/mlir/tensorflow:error_util",
+ "//tensorflow/compiler/mlir/tensorflow:export_graphdef",
+ "//tensorflow/compiler/mlir/tensorflow:mlir_import_options",
+ "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
+ "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes",
+ "//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes",
+ "//tensorflow/compiler/mlir/tensorflow:tf_saved_model_passes",
+ "//tensorflow/compiler/mlir/tensorflow:translate_lib",
+ "//tensorflow/core:core_cpu_base",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal_impl",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core/platform:errors",
+ "//tensorflow/core/platform:path",
+ "//tensorflow/core/platform:statusor",
+ "//tensorflow/core/platform:stringpiece",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ "@llvm-project//llvm:Support",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:Parser",
+ "@llvm-project//mlir:Pass",
+ "@llvm-project//mlir:QuantOps",
+ "@llvm-project//mlir:SCFDialect",
+ "@llvm-project//mlir:Shape",
+ "@llvm-project//mlir:StandardOps",
+ "@llvm-project//mlir:Transforms",
+ ],
+)
+
+tf_python_pybind_extension(
+ name = "quantize_model_wrapper",
+ srcs = [
+ "quantize_model_wrapper.cc",
+ ],
+ module_name = "quantize_model_wrapper",
+ deps = [
+ ":quantize_model_lib",
+ "//tensorflow/compiler/mlir/quantization/tensorflow/calibrator:calibrator_singleton",
+ "//tensorflow/lite/python/interpreter_wrapper:python_utils",
+ "//tensorflow/python/lib/core:pybind11_lib",
+ "//tensorflow/python/lib/core:pybind11_status",
+ "@com_google_absl//absl/strings",
+ "@pybind11",
+ ],
+)
+
+pytype_strict_library(
+ name = "quantize_model",
+ srcs = [
+ "quantize_model.py",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":quantize_model_wrapper",
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/compiler/mlir/quantization/tensorflow/calibrator:custom_aggregator_op",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python/client:session",
+ "//tensorflow/python/saved_model:builder",
+ "//tensorflow/python/saved_model:load",
+ "//tensorflow/python/saved_model:loader",
+ "//tensorflow/python/saved_model:signature_constants",
+ "//tensorflow/python/saved_model:tag_constants",
+ "//tensorflow/python/saved_model:utils",
+ "@flatbuffers//:runtime_py",
+ ],
+)
+
+tf_py_test(
+ name = "quantize_model_test",
+ size = "medium",
+ srcs = ["integration_test/quantize_model_test.py"],
+ tags = [
+ "no_oss", # TODO(b/220688154): Enable OSS tests.
+ ],
+ deps = [
+ ":quantize_model",
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+tf_py_test(
+ name = "concurrency_test",
+ size = "medium",
+ srcs = ["integration_test/concurrency_test.py"],
+ tags = [
+ "no_oss", # TODO(b/220688154): Enable OSS tests.
+ ],
+ deps = [
+ ":quantize_model",
+ "//tensorflow:tensorflow_py",
+ ],
+)
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/concurrency_test.py b/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/concurrency_test.py
new file mode 100644
index 0000000..5cc4139
--- /dev/null
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/concurrency_test.py
@@ -0,0 +1,89 @@
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Concurrency tests for quantize_model."""
+
+from concurrent import futures
+
+import numpy as np
+import tensorflow as tf # pylint: disable=unused-import
+
+from tensorflow.compiler.mlir.quantization.tensorflow.python import quantize_model
+from tensorflow.python.eager import def_function
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_spec
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+from tensorflow.python.saved_model import save as saved_model_save
+from tensorflow.python.saved_model.saved_model import tag_constants
+from tensorflow.python.training.tracking import tracking
+
+
+class MultiThreadedTest(test.TestCase):
+ """Tests involving multiple threads."""
+
+ def setUp(self):
+ super(MultiThreadedTest, self).setUp()
+ self.pool = futures.ThreadPoolExecutor(max_workers=4)
+
+ def _convert_with_calibration(self):
+
+ class ModelWithAdd(tracking.AutoTrackable):
+ """Basic model with addition."""
+
+ @def_function.function(input_signature=[
+ tensor_spec.TensorSpec(shape=[10], dtype=dtypes.float32, name='x'),
+ tensor_spec.TensorSpec(shape=[10], dtype=dtypes.float32, name='y')
+ ])
+ def add(self, x, y):
+ res = math_ops.add(x, y)
+ return {'output': res}
+
+ def data_gen():
+ for _ in range(255):
+ yield {
+ 'x':
+ ops.convert_to_tensor(
+ np.random.uniform(size=(10)).astype('f4')),
+ 'y':
+ ops.convert_to_tensor(
+ np.random.uniform(size=(10)).astype('f4'))
+ }
+
+ root = ModelWithAdd()
+
+ temp_path = self.create_tempdir().full_path
+ saved_model_save.save(
+ root, temp_path, signatures=root.add.get_concrete_function())
+
+ model = quantize_model.quantize(
+ temp_path, ['serving_default'], [tag_constants.SERVING],
+ representative_dataset=data_gen)
+ self.assertIsNotNone(model)
+
+ def testMultipleConversionJobsWithCalibration(self):
+ # Ensure that multiple conversion jobs with calibration won't encounter any
+ # concurrency issue.
+ with self.pool:
+ jobs = []
+ for _ in range(10):
+ jobs.append(self.pool.submit(self._convert_with_calibration))
+
+ for job in jobs:
+ job.result()
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test.py b/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test.py
new file mode 100644
index 0000000..9c86f5e
--- /dev/null
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test.py
@@ -0,0 +1,271 @@
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for quantize_model."""
+
+from absl.testing import parameterized
+import numpy as np
+import tensorflow as tf # pylint: disable=unused-import
+
+from tensorflow.compiler.mlir.quantization.tensorflow.python import quantize_model
+from tensorflow.python.eager import def_function
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_spec
+from tensorflow.python.module import module
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.platform import test
+from tensorflow.python.saved_model import loader_impl as saved_model_loader
+from tensorflow.python.saved_model import save as saved_model_save
+from tensorflow.python.saved_model.saved_model import signature_constants
+from tensorflow.python.saved_model.saved_model import tag_constants
+from tensorflow.python.training.tracking import tracking
+
+
+def _contains_quantized_function_call(meta_graphdef):
+ """Returns true if the graph def has quantized function call."""
+ for func in meta_graphdef.graph_def.library.function:
+ if func.signature.name.startswith('quantized_'):
+ return True
+ return False
+
+
+class QuantizationTest(test.TestCase, parameterized.TestCase):
+
+ def test_qat_model(self):
+
+ class QATModelWithAdd(tracking.AutoTrackable):
+ """Basic model with Fake quant + add."""
+
+ @def_function.function(input_signature=[
+ tensor_spec.TensorSpec(shape=[10], dtype=dtypes.float32, name='x'),
+ tensor_spec.TensorSpec(shape=[10], dtype=dtypes.float32, name='y')
+ ])
+ def add(self, x, y):
+ float_res = math_ops.add(x, y)
+ x = array_ops.fake_quant_with_min_max_args(
+ x, min=-0.1, max=0.2, num_bits=8, narrow_range=False)
+ y = array_ops.fake_quant_with_min_max_args(
+ y, min=-0.3, max=0.4, num_bits=8, narrow_range=False)
+ res = math_ops.add(x, y)
+ res = array_ops.fake_quant_with_min_max_args(
+ res, min=-0.4, max=0.6, num_bits=8, narrow_range=False)
+ return {'output': res, 'float_output': float_res}
+
+ root = QATModelWithAdd()
+
+ temp_path = self.create_tempdir().full_path
+ saved_model_save.save(
+ root, temp_path, signatures=root.add.get_concrete_function())
+
+ output_directory = self.create_tempdir().full_path
+ tags = [tag_constants.SERVING]
+    model = quantize_model.quantize(temp_path, ['serving_default'], tags,
+                                    output_directory)
+ self.assertIsNotNone(model)
+ self.assertEqual(
+ list(model.signatures._signatures.keys()), ['serving_default'])
+ func = model.signatures['serving_default']
+ func_res = func(
+ x=array_ops.constant(0.1, shape=[10]),
+ y=array_ops.constant(0.1, shape=[10]))
+ self.assertAllClose(
+ func_res['output'], array_ops.constant(0.2, shape=[10]), atol=0.01)
+ self.assertAllClose(
+ func_res['float_output'],
+ array_ops.constant(0.2, shape=[10]),
+ atol=1e-3)
+
+ output_loader = saved_model_loader.SavedModelLoader(output_directory)
+ output_meta_graphdef = output_loader.get_meta_graph_def_from_tags(tags)
+ self.assertTrue(_contains_quantized_function_call(output_meta_graphdef))
+
+ def test_ptq_model(self):
+
+ class PTQModelWithAdd(tracking.AutoTrackable):
+ """Basic model with addition."""
+
+ @def_function.function(input_signature=[
+ tensor_spec.TensorSpec(shape=[10], dtype=dtypes.float32, name='x'),
+ tensor_spec.TensorSpec(shape=[10], dtype=dtypes.float32, name='y')
+ ])
+ def add(self, x, y):
+ res = math_ops.add(x, y)
+ return {'output': res, 'x': x, 'y': y}
+
+ def data_gen():
+ for _ in range(255):
+ yield {
+ 'x':
+ ops.convert_to_tensor(
+ np.random.uniform(size=(10)).astype('f4')),
+ 'y':
+ ops.convert_to_tensor(
+ np.random.uniform(size=(10)).astype('f4'))
+ }
+
+ root = PTQModelWithAdd()
+
+ temp_path = self.create_tempdir().full_path
+ saved_model_save.save(
+ root, temp_path, signatures=root.add.get_concrete_function())
+
+ output_directory = self.create_tempdir().full_path
+ tags = [tag_constants.SERVING]
+ model = quantize_model.quantize(
+ temp_path, ['serving_default'],
+ tags,
+ output_directory,
+ representative_dataset=data_gen)
+ self.assertIsNotNone(model)
+ self.assertEqual(
+ list(model.signatures._signatures.keys()), ['serving_default'])
+ func = model.signatures['serving_default']
+ func_res = func(
+ x=array_ops.constant(0.1, shape=[10]),
+ y=array_ops.constant(0.1, shape=[10]))
+ self.assertAllClose(
+ func_res['output'], array_ops.constant(0.2, shape=[10]), atol=0.01)
+ xy_atol = 1e-6
+ self.assertAllClose(
+ func_res['x'], array_ops.constant(0.1, shape=[10]), atol=xy_atol)
+ self.assertAllClose(
+ func_res['y'], array_ops.constant(0.1, shape=[10]), atol=xy_atol)
+
+ output_loader = saved_model_loader.SavedModelLoader(output_directory)
+ output_meta_graphdef = output_loader.get_meta_graph_def_from_tags(tags)
+ self.assertTrue(_contains_quantized_function_call(output_meta_graphdef))
+
+ @parameterized.named_parameters(
+ ('none', None),
+ ('relu', nn_ops.relu),
+ ('relu6', nn_ops.relu6),
+ )
+ def test_qat_conv_model(self, activation_fn):
+
+ class ConvModel(module.Module):
+
+ @def_function.function(input_signature=[
+ tensor_spec.TensorSpec(
+ name='input', shape=[1, 3, 4, 3], dtype=dtypes.float32),
+ tensor_spec.TensorSpec(
+ name='filter', shape=[2, 3, 3, 2], dtype=dtypes.float32),
+ ])
+ def conv(self, input_tensor, filter_tensor):
+ q_input = array_ops.fake_quant_with_min_max_args(
+ input_tensor, min=-0.1, max=0.2, num_bits=8, narrow_range=False)
+ q_filters = array_ops.fake_quant_with_min_max_args(
+ filter_tensor, min=-1.0, max=2.0, num_bits=8, narrow_range=False)
+ bias = array_ops.constant([0, 0], dtype=dtypes.float32)
+ out = nn_ops.conv2d(
+ q_input,
+ q_filters,
+ strides=[1, 1, 2, 1],
+ dilations=[1, 1, 1, 1],
+ padding='SAME',
+ data_format='NHWC')
+ out = nn_ops.bias_add(out, bias, data_format='NHWC')
+ if activation_fn is not None:
+ out = activation_fn(out)
+ q_out = array_ops.fake_quant_with_min_max_args(
+ out, min=-0.3, max=0.4, num_bits=8, narrow_range=False)
+ return {'output': q_out}
+
+ model = ConvModel()
+ input_saved_model_path = self.create_tempdir('input').full_path
+ saved_model_save.save(model, input_saved_model_path)
+
+ signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
+ tags = [tag_constants.SERVING]
+ output_directory = self.create_tempdir().full_path
+ converted_model = quantize_model.quantize(
+ input_saved_model_path, [signature_key],
+ tags,
+ output_directory=output_directory)
+ self.assertIsNotNone(converted_model)
+ self.assertEqual(
+ list(converted_model.signatures._signatures.keys()), [signature_key])
+
+ input_data = np.random.uniform(
+ low=-0.1, high=0.2, size=(1, 3, 4, 3)).astype('f4')
+ filter_data = np.random.uniform(
+ low=-0.5, high=0.5, size=(2, 3, 3, 2)).astype('f4')
+
+ expected_outputs = model.conv(input_data, filter_data)
+ got_outputs = converted_model.signatures[signature_key](
+ input=ops.convert_to_tensor(input_data),
+ filter=ops.convert_to_tensor(filter_data))
+ # TODO(b/215633216): Check if the accuracy is acceptable.
+ self.assertAllClose(expected_outputs, got_outputs, atol=0.01)
+
+ output_loader = saved_model_loader.SavedModelLoader(output_directory)
+ output_meta_graphdef = output_loader.get_meta_graph_def_from_tags(tags)
+ self.assertTrue(_contains_quantized_function_call(output_meta_graphdef))
+
+ def test_conv_ptq_model(self):
+
+ class ConvModel(module.Module):
+
+ @def_function.function(input_signature=[
+ tensor_spec.TensorSpec(shape=[1, 3, 4, 3], dtype=dtypes.float32)
+ ])
+ def conv(self, input_tensor):
+ filters = np.random.uniform(
+ low=-10, high=10, size=(2, 3, 3, 2)).astype('f4')
+ bias = np.random.uniform(low=0, high=10, size=(2)).astype('f4')
+ out = nn_ops.conv2d(
+ input_tensor,
+ filters,
+ strides=[1, 1, 2, 1],
+ dilations=[1, 1, 1, 1],
+ padding='SAME',
+ data_format='NHWC')
+ out = nn_ops.bias_add(out, bias, data_format='NHWC')
+ out = nn_ops.relu6(out)
+ return {'output': out}
+
+ model = ConvModel()
+ input_saved_model_path = self.create_tempdir('input').full_path
+ saved_model_save.save(model, input_saved_model_path)
+
+ def data_gen():
+ for _ in range(255):
+ yield {
+ 'input_tensor':
+ ops.convert_to_tensor(
+ np.random.uniform(low=0, high=150,
+ size=(1, 3, 4, 3)).astype('f4')),
+ }
+
+ tags = [tag_constants.SERVING]
+ output_directory = self.create_tempdir().full_path
+ converted_model = quantize_model.quantize(
+ input_saved_model_path, ['serving_default'],
+ tags,
+ output_directory,
+ representative_dataset=data_gen)
+ self.assertIsNotNone(converted_model)
+ self.assertEqual(
+ list(converted_model.signatures._signatures.keys()),
+ ['serving_default'])
+
+ output_loader = saved_model_loader.SavedModelLoader(output_directory)
+ output_meta_graphdef = output_loader.get_meta_graph_def_from_tags(tags)
+ self.assertTrue(_contains_quantized_function_call(output_meta_graphdef))
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.cc b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.cc
new file mode 100644
index 0000000..e3763a2
--- /dev/null
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.cc
@@ -0,0 +1,268 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.h"
+
+#include <memory>
+#include <string>
+#include <string_view>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "llvm/Support/Debug.h"
+#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
+#include "mlir/Dialect/SCF/SCF.h" // from @llvm-project
+#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
+#include "mlir/IR/BuiltinOps.h" // from @llvm-project
+#include "mlir/IR/Location.h" // from @llvm-project
+#include "mlir/IR/MLIRContext.h" // from @llvm-project
+#include "mlir/Parser.h" // from @llvm-project
+#include "mlir/Pass/Pass.h" // from @llvm-project
+#include "mlir/Pass/PassManager.h" // from @llvm-project
+#include "mlir/Transforms/Passes.h" // from @llvm-project
+#include "tensorflow/cc/saved_model/loader.h"
+#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
+#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
+#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
+#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h"
+#include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h"
+#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_import_options.h"
+#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
+#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"
+#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
+#include "tensorflow/core/common_runtime/graph_constructor.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/errors.h"
+#include "tensorflow/core/platform/path.h"
+#include "tensorflow/core/platform/statusor.h"
+#include "tensorflow/core/platform/stringpiece.h"
+#include "tensorflow/core/util/env_var.h"
+
+namespace mlir {
+namespace quant {
+
+absl::StatusOr<tensorflow::GraphDef> QuantizeQATModel(
+ absl::string_view saved_model_path, absl::string_view exported_names_str,
+ absl::string_view tags) {
+ const std::unordered_set<std::string> tag_set =
+ absl::StrSplit(tags, ',', absl::SkipEmpty());
+ std::vector<std::string> exported_names_vec =
+ absl::StrSplit(exported_names_str, ',', absl::SkipEmpty());
+
+ // Convert the SavedModelBundle to an MLIR module.
+ DialectRegistry registry;
+ registry.insert<StandardOpsDialect, scf::SCFDialect,
+ tf_saved_model::TensorFlowSavedModelDialect,
+ TF::TensorFlowDialect, shape::ShapeDialect,
+ QuantizationDialect>();
+ MLIRContext context(registry);
+
+ tensorflow::MLIRImportOptions import_options;
+ import_options.upgrade_legacy = true;
+ auto bundle = std::make_unique<tensorflow::SavedModelBundle>();
+
+  // TODO(b/213406917): Add support for the object-graph-based SavedModel input.
+ auto module_or = tensorflow::SavedModelSignatureDefsToMlirImport(
+ saved_model_path, tag_set, absl::Span<std::string>(exported_names_vec),
+ &context, import_options, /*lift_variables=*/true, &bundle);
+
+ if (!module_or.status().ok()) {
+ return absl::InternalError("failed to import SavedModel: " +
+ module_or.status().error_message());
+ }
+
+ OwningOpRef<mlir::ModuleOp> moduleRef = module_or.ConsumeValueOrDie();
+
+ PassManager pm(&context);
+
+ std::string error;
+ llvm::raw_string_ostream error_stream(error);
+
+ pm.addPass(createCanonicalizerPass());
+ // Freezes constants so that FakeQuant ops can reference quantization ranges.
+ pm.addPass(tf_saved_model::CreateOptimizeGlobalTensorsPass());
+ pm.addPass(mlir::createInlinerPass());
+ pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
+ pm.addPass(tf_saved_model::CreateFreezeGlobalTensorsPass());
+
+ pm.addNestedPass<FuncOp>(CreateConvertFakeQuantToQdqPass());
+ pm.addNestedPass<FuncOp>(TF::CreateFusedKernelMatcherPass());
+ pm.addPass(CreateLiftQuantizableSpotsAsFunctionsPass());
+ pm.addPass(CreateInsertQuantizedFunctionsPass());
+ pm.addPass(CreateQuantizeCompositeFunctionsPass());
+ pm.addPass(mlir::createSymbolDCEPass());
+
+ pm.addPass(CreateInsertMainFunctionPass());
+ pm.addNestedPass<FuncOp>(CreateFunctionalToExecutorDialectConversionPass());
+ pm.addPass(CreateBreakUpIslandsPass());
+
+ mlir::StatusScopedDiagnosticHandler diagnostic_handler(&context);
+ if (failed(pm.run(*moduleRef))) {
+ return absl::InternalError(
+ "failed to apply the quantization: " +
+ diagnostic_handler.ConsumeStatus().error_message());
+ }
+
+ // Export as GraphDef.
+ tensorflow::GraphExportConfig confs;
+ auto graph_or = tensorflow::ConvertMlirToGraphdef(*moduleRef, confs);
+ if (!graph_or.ok()) {
+ return absl::InternalError("failed to convert MLIR to graphdef: " +
+ graph_or.status().error_message());
+ }
+
+ return *graph_or.ConsumeValueOrDie();
+}
+
+absl::StatusOr<tensorflow::GraphDef> QuantizePTQModelPreCalibration(
+ absl::string_view saved_model_path, absl::string_view exported_names_str,
+ absl::string_view tags) {
+ const std::unordered_set<std::string> tag_set =
+ absl::StrSplit(tags, ',', absl::SkipEmpty());
+ std::vector<std::string> exported_names_vec =
+ absl::StrSplit(exported_names_str, ',', absl::SkipEmpty());
+
+ // Convert the SavedModelBundle to an MLIR module.
+ DialectRegistry registry;
+ registry.insert<StandardOpsDialect, scf::SCFDialect,
+ tf_saved_model::TensorFlowSavedModelDialect,
+ TF::TensorFlowDialect, shape::ShapeDialect,
+ QuantizationDialect>();
+ MLIRContext context(registry);
+
+ tensorflow::MLIRImportOptions import_options;
+ import_options.upgrade_legacy = true;
+ auto bundle = std::make_unique<tensorflow::SavedModelBundle>();
+
+  // TODO(b/213406917): Add support for the object-graph-based SavedModel input.
+ auto module_or = tensorflow::SavedModelSignatureDefsToMlirImport(
+ saved_model_path, tag_set, absl::Span<std::string>(exported_names_vec),
+ &context, import_options,
+ /*lift_variables=*/true, &bundle);
+
+ if (!module_or.status().ok()) {
+ return absl::InternalError("failed to import SavedModel: " +
+ module_or.status().error_message());
+ }
+
+ OwningOpRef<mlir::ModuleOp> moduleRef = module_or.ConsumeValueOrDie();
+
+ PassManager pm(&context);
+
+ pm.addPass(createCanonicalizerPass());
+ pm.addNestedPass<FuncOp>(TF::CreateFusedKernelMatcherPass());
+ pm.addPass(CreateLiftQuantizableSpotsAsFunctionsPass());
+ pm.addNestedPass<FuncOp>(CreateInsertCustomAggregationOpsPass());
+ pm.addPass(CreateIssueIDsOfCustomAggregationOpsPass());
+ pm.addPass(CreateInsertMainFunctionPass());
+ pm.addNestedPass<FuncOp>(CreateFunctionalToExecutorDialectConversionPass());
+ pm.addPass(CreateBreakUpIslandsPass());
+
+ mlir::StatusScopedDiagnosticHandler diagnostic_handler(&context);
+ if (failed(pm.run(*moduleRef))) {
+ return absl::InternalError(
+ "failed to apply the quantization at the pre-calibration stage: " +
+ diagnostic_handler.ConsumeStatus().error_message());
+ }
+
+ // Export as GraphDef.
+ tensorflow::GraphExportConfig confs;
+ auto graph_or = tensorflow::ConvertMlirToGraphdef(*moduleRef, confs);
+ if (!graph_or.ok()) {
+ return absl::InternalError("failed to convert MLIR to graphdef: " +
+ graph_or.status().error_message());
+ }
+
+ return *graph_or.ConsumeValueOrDie();
+}
+
+absl::StatusOr<tensorflow::GraphDef> QuantizePTQModelPostCalibration(
+ absl::string_view saved_model_path, absl::string_view exported_names_str,
+ absl::string_view tags) {
+ const std::unordered_set<std::string> tag_set =
+ absl::StrSplit(tags, ',', absl::SkipEmpty());
+ std::vector<std::string> exported_names_vec =
+ absl::StrSplit(exported_names_str, ',', absl::SkipEmpty());
+
+ // Convert the SavedModelBundle to an MLIR module.
+ DialectRegistry registry;
+ registry.insert<StandardOpsDialect, scf::SCFDialect,
+ tf_saved_model::TensorFlowSavedModelDialect,
+ TF::TensorFlowDialect, shape::ShapeDialect,
+ QuantizationDialect>();
+ MLIRContext context(registry);
+
+ tensorflow::MLIRImportOptions import_options;
+ import_options.upgrade_legacy = true;
+ auto bundle = std::make_unique<tensorflow::SavedModelBundle>();
+
+  // TODO(b/213406917): Add support for the object-graph-based SavedModel input.
+ auto module_or = tensorflow::SavedModelSignatureDefsToMlirImport(
+ saved_model_path, tag_set, absl::Span<std::string>(exported_names_vec),
+ &context, import_options,
+ /*lift_variables=*/true, &bundle);
+
+ if (!module_or.status().ok()) {
+ return absl::InternalError("failed to import SavedModel: " +
+ module_or.status().error_message());
+ }
+
+ OwningOpRef<mlir::ModuleOp> moduleRef = module_or.ConsumeValueOrDie();
+
+ PassManager pm(&context);
+
+ pm.addPass(createCanonicalizerPass());
+ pm.addNestedPass<FuncOp>(CreateConvertCustomAggregationOpToQuantStatsPass());
+ pm.addPass(CreateInsertQuantizedFunctionsPass());
+ pm.addPass(CreateQuantizeCompositeFunctionsPass());
+ pm.addPass(mlir::createSymbolDCEPass());
+ pm.addPass(CreateInsertMainFunctionPass());
+ pm.addNestedPass<FuncOp>(CreateFunctionalToExecutorDialectConversionPass());
+ pm.addPass(CreateBreakUpIslandsPass());
+
+ mlir::StatusScopedDiagnosticHandler diagnostic_handler(&context);
+ if (failed(pm.run(*moduleRef))) {
+ return absl::InternalError(
+ "failed to apply the quantization at the post-calibation stage: " +
+ diagnostic_handler.ConsumeStatus().error_message());
+ }
+
+ // Export as GraphDef.
+ tensorflow::GraphExportConfig confs;
+ auto graph_or = tensorflow::ConvertMlirToGraphdef(*moduleRef, confs);
+ if (!graph_or.ok()) {
+ return absl::InternalError("failed to convert MLIR to graphdef: " +
+ graph_or.status().error_message());
+ }
+
+ return *graph_or.ConsumeValueOrDie();
+}
+
+} // namespace quant
+} // namespace mlir
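
Note for reviewers: the three entry points above are sequenced from Python. For PTQ, the pre-calibration graph is run over representative data to populate the calibrator before the post-calibration rewrite. A minimal sketch of that sequencing, mirroring quantize_model.py below (paths are illustrative; signature keys and tags are passed as comma-separated strings, and tag_constants.SERVING is 'serve'):

  from tensorflow.compiler.mlir.quantization.tensorflow.python import quantize_model_wrapper
  from tensorflow.core.framework import graph_pb2

  # Stage 1: lift quantizable spots and insert CustomAggregator ops.
  serialized = quantize_model_wrapper.quantize_ptq_model_pre_calibration(
      '/tmp/my_model', 'serving_default', 'serve')
  graph_def = graph_pb2.GraphDef()
  graph_def.ParseFromString(serialized)

  # Stage 2 (elided): import graph_def, run it over representative data so each
  # CustomAggregator records min/max, write those values back into the node
  # attrs, and save the calibrated model.

  # Stage 3: convert aggregators to quant stats and quantize the composite
  # functions.
  serialized = quantize_model_wrapper.quantize_ptq_model_post_calibration(
      '/tmp/my_calibrated_model', 'serving_default', 'serve')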
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.h b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.h
new file mode 100644
index 0000000..a40d486
--- /dev/null
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.h
@@ -0,0 +1,47 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_PYTHON_QUANTIZE_MODEL_H_
+#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_PYTHON_QUANTIZE_MODEL_H_
+
+#include <string>
+#include <vector>
+
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "tensorflow/core/framework/graph.pb.h"
+
+namespace mlir {
+namespace quant {
+
+// Quantizes the given QAT-enabled model.
+absl::StatusOr<tensorflow::GraphDef> QuantizeQATModel(
+ absl::string_view saved_model_path, absl::string_view exported_names_str,
+ absl::string_view tags);
+
+// Applies post-training quantization to the given model. This function covers
+// the pre-calibration stage.
+absl::StatusOr<tensorflow::GraphDef> QuantizePTQModelPreCalibration(
+ absl::string_view saved_model_path, absl::string_view exported_names_str,
+ absl::string_view tags);
+
+// Applies post-training quantization to the given model. This function covers
+// the post-calibration stage.
+absl::StatusOr<tensorflow::GraphDef> QuantizePTQModelPostCalibration(
+ absl::string_view saved_model_path, absl::string_view exported_names_str,
+    absl::string_view tags);
+
+}  // namespace quant
+} // namespace mlir
+
+#endif  // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_PYTHON_QUANTIZE_MODEL_H_
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py
new file mode 100644
index 0000000..469c8f4
--- /dev/null
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py
@@ -0,0 +1,291 @@
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Defines TF Quantization API from SavedModel to SavedModel."""
+
+import tempfile
+import uuid
+import warnings
+
+from tensorflow.compiler.mlir.quantization.tensorflow.python import quantize_model_wrapper
+from tensorflow.core.framework import graph_pb2
+from tensorflow.python.client import session
+from tensorflow.python.framework import importer
+from tensorflow.python.framework import ops
+from tensorflow.python.saved_model import builder
+from tensorflow.python.saved_model import loader_impl as saved_model_loader
+from tensorflow.python.saved_model import signature_constants
+from tensorflow.python.saved_model import tag_constants
+from tensorflow.python.saved_model.load import load as saved_model_load
+
+
+# The signature key of the saved model init op.
+_INIT_OP_SIGNATURE_KEY = '__saved_model_init_op'
+
+
+def _legalize_tensor_name(tensor_name: str) -> str:
+ """Converts tensor name from 'name:index' to 'name__index' format."""
+ return tensor_name.replace(':', '__')
+
+
+def _is_qat_saved_model(saved_model_path: str) -> bool:
+ """Checks if the SavedModel is QAT-enabled by looking for 'FakeQuant' ops."""
+ saved_model_proto = saved_model_loader.parse_saved_model(saved_model_path)
+ for meta_graph in saved_model_proto.meta_graphs:
+ if any(
+ node.op.startswith('FakeQuant') for node in meta_graph.graph_def.node):
+ return True
+ for function in meta_graph.graph_def.library.function:
+ if any(node.op.startswith('FakeQuant') for node in function.node_def):
+ return True
+ return False
+
+
+def _get_signatures_from_saved_model(saved_model_path: str,
+ signature_keys=None,
+ tags=None):
+ """Gets a map from signature keys to their SignatureDef from a saved model."""
+ if tags is None:
+ tags = set([tag_constants.SERVING])
+
+ loader = saved_model_loader.SavedModelLoader(saved_model_path)
+ meta_graphdef = loader.get_meta_graph_def_from_tags(tags)
+ signatures = {}
+ for key, signature_def in meta_graphdef.signature_def.items():
+ if key == _INIT_OP_SIGNATURE_KEY:
+ continue
+ if signature_keys is not None and key not in signature_keys:
+ continue
+ signatures[key] = signature_def
+
+ return signatures
+
+
+def _fix_tensor_names(signatures, exported_graph):
+ """Tries fixing tensor names in the signatures to match the exported graph.
+
+ The output tensor names in the original graph usually become names of the
+ return nodes in the exported graph. This function tries to fix that and checks
+ if the input tensor names are found in the exported graph.
+
+ Args:
+ signatures: the signatures of the original graph.
+ exported_graph: The PTQ-exported GraphDef.
+
+ Returns:
+ Fixed signatures or None if it couldn't be fixed.
+ """
+ if signatures is None:
+ return None
+
+ # The InsertMainFunctionPass populates input and output nodes of the newly
+ # inserted main function with "tf_saved_model.index_path" attributes. These
+ # attributes can be used to identify outputs in the exported graph.
+ output_index_path_map = {}
+ for op in exported_graph.get_operations():
+    if (op.type == '_Retval' and
+        'tf_saved_model.index_path' in op.node_def.attr):
+ index_path_name = op.get_attr('tf_saved_model.index_path')[0]
+ index_path_name = index_path_name.decode('utf-8')
+ output_index_path_map[index_path_name] = op.inputs[0].name
+
+ for signature_def in signatures.values():
+ for tensor_info in signature_def.inputs.values():
+ try:
+ exported_graph.get_tensor_by_name(tensor_info.name)
+ except KeyError:
+ # If input tensors are not found, the signatures can't be used for the
+ # exported graph.
+ warnings.warn('Cannot find the tensor with name %s in the graph.' %
+ tensor_info.name)
+ return None
+
+ for tensor_info in signature_def.outputs.values():
+ try:
+ if tensor_info.name in output_index_path_map:
+ tensor_info.name = output_index_path_map[tensor_info.name]
+ else:
+ # Tries to find the return node with the given name and use its input
+ # as the output tensor name.
+ return_node = exported_graph.get_operation_by_name(
+ _legalize_tensor_name(tensor_info.name))
+ tensor_info.name = return_node.inputs[0].name
+ except KeyError:
+ warnings.warn(
+ 'Cannot find the tensor or node with name %s in the graph.' %
+ tensor_info.name)
+ return None
+
+ return signatures
+
+
+def quantize(saved_model_path: str,
+ signature_keys=None,
+ tags=None,
+ output_directory=None,
+ representative_dataset=None):
+ """Quantizes the given SavedModel.
+
+ Args:
+ saved_model_path: Path to the saved model. When representative_dataset is
+ not provided, this should be a model trained with QAT.
+ signature_keys: List of keys identifying SignatureDef containing inputs and
+ outputs.
+ tags: Set of tags identifying the MetaGraphDef within the SavedModel to
+ analyze.
+ output_directory: The path to save the output SavedModel (must be an empty
+ directory).
+    representative_dataset: a generator that yields either a dictionary in
+      {input_name: input_tensor} format, or a tuple of a signature key and
+      such a dictionary, supplying calibration data for quantizing the model.
+      This must be provided when the model is not a QAT model.
+
+ Returns:
+ A SavedModel object with TF quantization applied.
+
+ Raises:
+    ValueError: when representative_dataset is not provided for a non-QAT
+      model.
+ """
+ if tags is None:
+ tags = set([tag_constants.SERVING])
+ if signature_keys is None:
+ signature_keys = [signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
+
+ is_qat_saved_model = _is_qat_saved_model(saved_model_path)
+ signatures = _get_signatures_from_saved_model(saved_model_path,
+ signature_keys, tags)
+
+  # A representative dataset is required when the model is not QAT-trained.
+ if representative_dataset is None and not is_qat_saved_model:
+ raise ValueError(
+ 'When `representative_dataset` is not provided, the model should be '
+ 'trained with quantization-aware training (QAT).')
+
+ if is_qat_saved_model:
+    # QAT models are quantized in one pass; FakeQuant ops carry the ranges.
+ graph_def_serialized = (
+ quantize_model_wrapper.quantize_qat_model(saved_model_path,
+ ','.join(signature_keys),
+ ','.join(tags)))
+ else:
+    # PTQ models first run the pre-calibration stage, which inserts
+    # CustomAggregator ops to collect calibration statistics.
+ graph_def_serialized = (
+ quantize_model_wrapper.quantize_ptq_model_pre_calibration(
+ saved_model_path, ','.join(signature_keys), ','.join(tags)))
+
+    graph_def = graph_pb2.GraphDef()
+    graph_def.ParseFromString(graph_def_serialized)
+
+    float_model_dir = tempfile.mkdtemp()
+    v1_builder = builder.SavedModelBuilder(float_model_dir)
+
+    with session.Session(graph=ops.Graph()) as sess:
+      for function_def in graph_def.library.function:
+        for node_def in function_def.node_def:
+          if node_def.op == 'CustomAggregator':
+            node_def.attr['id'].s = uuid.uuid4().hex.encode('ascii')
+
+      importer.import_graph_def(graph_def, name='')
+      working_graph = ops.get_default_graph()
+      graph_def = working_graph.as_graph_def()
+
+      signatures = _fix_tensor_names(signatures, working_graph)
+      if signatures is None:
+        raise ValueError(
+            "The input SavedModel doesn't contain a valid signature")
+
+      v1_builder.add_meta_graph_and_variables(
+          sess, [tag_constants.SERVING], signature_def_map=signatures)
+
+    v1_builder.save()
+
+    float_model = saved_model_load(float_model_dir)
+
+    for sample in representative_dataset():
+      # TODO(b/214311251): Add a test case with multiple signatures.
+      if isinstance(sample, tuple):
+        if not isinstance(sample[1], dict):
+          raise ValueError('You need to provide a dictionary with input '
+                           'names and values in the second argument in the '
+                           'tuple')
+        signature_key = sample[0]
+        input_data_map = sample[1]
+      elif isinstance(sample, dict):
+        if len(signature_keys) > 1:
+          raise ValueError('When the model has multiple signatures, you need '
+                           'to provide a tuple with signature key and a '
+                           'dictionary with input names and values')
+        signature_key = signature_keys[0]
+        input_data_map = sample
+      else:
+        raise ValueError('You need to provide either a dictionary with input '
+                         'names and values or a tuple with signature key and '
+                         'a dictionary with input names and values')
+      func = float_model.signatures[signature_key]
+      func(**input_data_map)
+
+    for function_def in graph_def.library.function:
+      for node_def in function_def.node_def:
+        if node_def.op == 'CustomAggregator':
+          node_id = node_def.attr['id'].s
+          min_val, max_val = (
+              quantize_model_wrapper.get_min_max_from_calibrator(node_id))
+          quantize_model_wrapper.clear_data_from_calibrator(node_id)
+          node_def.attr['min'].f = float(min_val)
+          node_def.attr['max'].f = float(max_val)
+
+    calibrated_model_dir = tempfile.mkdtemp()
+    v1_builder = builder.SavedModelBuilder(calibrated_model_dir)
+
+    with session.Session(graph=ops.Graph()) as sess:
+      importer.import_graph_def(graph_def, name='')
+      working_graph = ops.get_default_graph()
+      graph_def = working_graph.as_graph_def()
+
+      v1_builder.add_meta_graph_and_variables(
+          sess, [tag_constants.SERVING], signature_def_map=signatures)
+
+    v1_builder.save()
+    signatures = _get_signatures_from_saved_model(calibrated_model_dir,
+                                                  signature_keys, tags)
+
+    graph_def_serialized = (
+        quantize_model_wrapper.quantize_ptq_model_post_calibration(
+            calibrated_model_dir,
+            ','.join(signature_keys),
+            ','.join(tags),
+        ))
+
+ graph_def = graph_pb2.GraphDef()
+ graph_def.ParseFromString(graph_def_serialized)
+
+ if output_directory is None:
+ output_directory = tempfile.mkdtemp()
+ v1_builder = builder.SavedModelBuilder(output_directory)
+
+ with session.Session(graph=ops.Graph()) as sess:
+ importer.import_graph_def(graph_def, name='')
+ working_graph = ops.get_default_graph()
+
+ signatures = _fix_tensor_names(signatures, working_graph)
+ if signatures is None:
+ raise ValueError("The input SavedModel doesn't contain a valid signature")
+
+ v1_builder.add_meta_graph_and_variables(
+ sess, [tag_constants.SERVING], signature_def_map=signatures)
+
+ v1_builder.save()
+
+ return saved_model_load(output_directory)
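
Note for reviewers: for reference, a minimal end-to-end sketch of the quantize() API defined above (paths are illustrative, and the PTQ model is assumed to have a single input named 'x'):

  import numpy as np

  from tensorflow.compiler.mlir.quantization.tensorflow.python import quantize_model
  from tensorflow.python.framework import ops
  from tensorflow.python.saved_model import tag_constants

  # QAT model: FakeQuant ops already carry the ranges, so no dataset is needed.
  quantized = quantize_model.quantize('/tmp/qat_model', ['serving_default'],
                                      [tag_constants.SERVING], '/tmp/qat_out')

  # PTQ model: a representative dataset drives calibration.
  def data_gen():
    for _ in range(255):
      yield {'x': ops.convert_to_tensor(np.random.uniform(size=(10)).astype('f4'))}

  quantized = quantize_model.quantize(
      '/tmp/ptq_model', ['serving_default'], [tag_constants.SERVING],
      '/tmp/ptq_out', representative_dataset=data_gen)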
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model_wrapper.cc b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model_wrapper.cc
new file mode 100644
index 0000000..3680832
--- /dev/null
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model_wrapper.cc
@@ -0,0 +1,144 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <utility>
+
+#include "absl/strings/string_view.h"
+#include "pybind11/pybind11.h"
+#include "pybind11/stl.h"
+#include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibrator_singleton.h"
+#include "tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.h"
+#include "tensorflow/lite/python/interpreter_wrapper/python_utils.h"
+#include "tensorflow/python/lib/core/pybind11_lib.h"
+#include "tensorflow/python/lib/core/pybind11_status.h"
+
+using tensorflow::calibrator::CalibratorSingleton;
+
+namespace py = pybind11;
+
+PyObject* QuantizeQATModel(absl::string_view saved_model_path,
+ absl::string_view exported_names_str,
+ absl::string_view tags) {
+ auto graph_def_or =
+ mlir::quant::QuantizeQATModel(saved_model_path, exported_names_str, tags);
+ if (!graph_def_or.ok()) {
+    PyErr_Format(PyExc_ValueError, "%s",
+                 graph_def_or.status().error_message().c_str());
+ return nullptr;
+ }
+
+ std::string ret_str = graph_def_or.ValueOrDie().SerializeAsString();
+
+ return tflite::python_utils::ConvertToPyString(ret_str.c_str(),
+ ret_str.size());
+}
+
+PyObject* QuantizePTQModelPreCalibration(absl::string_view saved_model_path,
+ absl::string_view exported_names_str,
+ absl::string_view tags) {
+ auto graph_def_or = mlir::quant::QuantizePTQModelPreCalibration(
+ saved_model_path, exported_names_str, tags);
+ if (!graph_def_or.ok()) {
+    PyErr_Format(PyExc_ValueError, "%s",
+                 graph_def_or.status().error_message().c_str());
+ return nullptr;
+ }
+
+ std::string ret_str = graph_def_or.ValueOrDie().SerializeAsString();
+
+ return tflite::python_utils::ConvertToPyString(ret_str.c_str(),
+ ret_str.size());
+}
+
+PyObject* QuantizePTQModelPostCalibration(absl::string_view saved_model_path,
+ absl::string_view exported_names_str,
+ absl::string_view tags) {
+ auto graph_def_or = mlir::quant::QuantizePTQModelPostCalibration(
+ saved_model_path, exported_names_str, tags);
+ if (!graph_def_or.ok()) {
+    PyErr_Format(PyExc_ValueError, "%s",
+                 graph_def_or.status().error_message().c_str());
+ return nullptr;
+ }
+
+ std::string ret_str = graph_def_or.ValueOrDie().SerializeAsString();
+
+ return tflite::python_utils::ConvertToPyString(ret_str.c_str(),
+ ret_str.size());
+}
+
+py::tuple GetMinMaxFromCalibrator(absl::string_view id) {
+ absl::optional<std::pair<float, float>> min_max =
+ CalibratorSingleton::GetMinMax(id);
+ if (!min_max.has_value()) {
+ PyErr_Format(PyExc_ValueError, "No calibrated data for '%s'",
+ std::string{id}.c_str());
+ throw py::error_already_set();
+ }
+
+ return py::make_tuple(min_max->first, min_max->second);
+}
+
+PYBIND11_MODULE(quantize_model_wrapper, m) {
+ m.def(
+ "clear_calibrator",
+ []() { CalibratorSingleton::ClearCollectedInformation(); },
+ R"pbdoc(
+ Clears the collected metrics from the calibrator.
+ )pbdoc");
+ m.def(
+ "clear_data_from_calibrator",
+ [](absl::string_view id) { CalibratorSingleton::ClearData(id); },
+ R"pbdoc(
+        Clears the collected data of the given id from the calibrator.
+ )pbdoc");
+ m.def(
+ "get_min_max_from_calibrator",
+ [](absl::string_view id) { return GetMinMaxFromCalibrator(id); },
+ R"pbdoc(
+        Returns a tuple of the min and max values for the given id.
+ )pbdoc");
+ m.def(
+ "quantize_qat_model",
+ [](absl::string_view saved_model_path,
+ absl::string_view exported_names_str, absl::string_view tags) {
+ return tensorflow::PyoOrThrow(
+ QuantizeQATModel(saved_model_path, exported_names_str, tags));
+ },
+ R"pbdoc(
+        Quantizes a QAT-trained model and returns a serialized GraphDef.
+ )pbdoc");
+ m.def(
+ "quantize_ptq_model_pre_calibration",
+ [](absl::string_view saved_model_path,
+ absl::string_view exported_names_str, absl::string_view tags) {
+ return tensorflow::PyoOrThrow(QuantizePTQModelPreCalibration(
+ saved_model_path, exported_names_str, tags));
+ },
+ R"pbdoc(
+        Runs the PTQ pre-calibration stage and returns a serialized GraphDef.
+ )pbdoc");
+ m.def(
+ "quantize_ptq_model_post_calibration",
+ [](absl::string_view saved_model_path,
+ absl::string_view exported_names_str, absl::string_view tags) {
+ return tensorflow::PyoOrThrow(QuantizePTQModelPostCalibration(
+ saved_model_path, exported_names_str, tags));
+ },
+ R"pbdoc(
+        Runs the PTQ post-calibration stage and returns a serialized GraphDef.
+ )pbdoc");
+}