Add a pass to legalize TF quantization emulation ops to MLIR Quant dialect ops

Potentially this pass can also be shared with TFLite; that is left as a TODO
for now.
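
For illustration, a minimal sketch of the rewrite (adapted from the tests
added in this change): given a tf.FakeQuantWithMinMaxVars op whose min/max
operands are constants (min = 0.0, max = 255.0 here),

  %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %min, %max)
      {num_bits = 3, narrow_range = false}
      : (tensor<8xf32>, tensor<f32>, tensor<f32>) -> tensor<8xf32>

the pass keeps the op in place and inserts a quant.qcast/quant.dcast pair,
recording the derived quantization parameters, between it and its users:

  %1 = "quant.qcast"(%0)
      : (tensor<8xf32>) -> tensor<8x!quant.uniform<i8:f32, 1.000000e+00:-128>>
  %2 = "quant.dcast"(%1)
      : (tensor<8x!quant.uniform<i8:f32, 1.000000e+00:-128>>) -> tensor<8xf32>
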
PiperOrigin-RevId: 292421480
Change-Id: Ica6e6270343700a704c0d00bba6640daf117d1b0
diff --git a/tensorflow/compiler/mlir/BUILD b/tensorflow/compiler/mlir/BUILD
index 1fe2023..a9a6182 100644
--- a/tensorflow/compiler/mlir/BUILD
+++ b/tensorflow/compiler/mlir/BUILD
@@ -66,6 +66,7 @@
"//tensorflow/compiler/mlir/lite:tensorflow_lite_optimize",
"//tensorflow/compiler/mlir/lite:tensorflow_lite_quantize",
"//tensorflow/compiler/mlir/lite/quantization:quantization_passes",
+ "//tensorflow/compiler/mlir/lite/quantization/tensorflow:tf_to_quant",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_passes",
diff --git a/tensorflow/compiler/mlir/lite/quantization/tensorflow/BUILD b/tensorflow/compiler/mlir/lite/quantization/tensorflow/BUILD
new file mode 100644
index 0000000..96d6c4f
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/quantization/tensorflow/BUILD
@@ -0,0 +1,36 @@
+package(
+ default_visibility = [
+ ":friends",
+ ],
+ licenses = ["notice"], # Apache 2.0
+)
+
+package_group(
+ name = "friends",
+ includes = ["//third_party/mlir:subpackages"],
+ packages = [
+ "//tensorflow/compiler/mlir/...",
+ "//tensorflow/compiler/mlir/lite/...",
+ ],
+)
+
+cc_library(
+ name = "tf_to_quant",
+ srcs = [
+ "tf_to_quant.cc",
+ ],
+ hdrs = [
+ "passes.h",
+ ],
+ deps = [
+ "//tensorflow/compiler/mlir/lite/quantization:quantization_config",
+ "//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
+ "//tensorflow/compiler/mlir/tensorflow",
+ "@com_google_absl//absl/strings",
+ "@llvm-project//llvm:support",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:Pass",
+ "@llvm-project//mlir:QuantOps",
+ ],
+ alwayslink = 1,
+)
diff --git a/tensorflow/compiler/mlir/lite/quantization/tensorflow/passes.h b/tensorflow/compiler/mlir/lite/quantization/tensorflow/passes.h
new file mode 100644
index 0000000..c345da0
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/quantization/tensorflow/passes.h
@@ -0,0 +1,35 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_TENSORFLOW_PASSES_H_
+#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_TENSORFLOW_PASSES_H_
+
+#include <memory>
+
+#include "mlir/IR/Function.h" // TF:llvm-project
+#include "mlir/Pass/Pass.h" // TF:llvm-project
+
+namespace mlir {
+namespace TF {
+
+// Legalizes TF quantization emulation ops to ops in the Quant dialect so that
+// the quantization passes can work on them.
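+// The pass is registered under the "tf-to-quant" flag (see tf_to_quant.cc),
+// so it can also be exercised directly, e.g.: tf-opt -tf-to-quant input.mlir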
+std::unique_ptr<OpPassBase<FuncOp>> CreateLegalizeTFToQuantPass();
+
+} // namespace TF
+} // namespace mlir
+#endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_TENSORFLOW_PASSES_H_
diff --git a/tensorflow/compiler/mlir/lite/quantization/tensorflow/tests/BUILD b/tensorflow/compiler/mlir/lite/quantization/tensorflow/tests/BUILD
new file mode 100644
index 0000000..4faa8d2
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/quantization/tensorflow/tests/BUILD
@@ -0,0 +1,19 @@
+load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")
+
+package(licenses = ["notice"])
+
+glob_lit_tests(
+ data = [":test_utilities"],
+ driver = "@llvm-project//mlir:run_lit.sh",
+ test_file_exts = ["mlir"],
+)
+
+# Bundle together all of the test utilities that are used by tests.
+filegroup(
+ name = "test_utilities",
+ testonly = True,
+ data = [
+ "//tensorflow/compiler/mlir:tf-opt",
+ "@llvm-project//llvm:FileCheck",
+ ],
+)
diff --git a/tensorflow/compiler/mlir/lite/quantization/tensorflow/tests/tf_to_quant.mlir b/tensorflow/compiler/mlir/lite/quantization/tensorflow/tests/tf_to_quant.mlir
new file mode 100644
index 0000000..6664dca
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/quantization/tensorflow/tests/tf_to_quant.mlir
@@ -0,0 +1,148 @@
+// RUN: tf-opt -tf-to-quant %s | FileCheck %s
+
+// CHECK-LABEL: fakeQuantPerChannelForActivation
+func @fakeQuantPerChannelForActivation(%arg0: tensor<8x3xf32>) -> (tensor<8x3xf32>) {
+ %arg1 = constant dense<[0.0, -1.0, 1.0]> : tensor<3xf32>
+ %arg2 = constant dense<[255.0, 254.0, 256.0]> : tensor<3xf32>
+ %0 = "tf.FakeQuantWithMinMaxVarsPerChannel"(%arg0, %arg1, %arg2) {num_bits = 3, narrow_range = false} : (tensor<8x3xf32>, tensor<3xf32>, tensor<3xf32>) -> tensor<8x3xf32>
+ return %0 : tensor<8x3xf32>
+
+// CHECK: %[[fq:.*]] = "tf.FakeQuantWithMinMaxVarsPerChannel"(%arg0, %cst, %cst_0)
+// CHECK: %[[q:.*]] = "quant.qcast"(%[[fq]]) : (tensor<8x3xf32>) -> tensor<8x3x!quant.uniform<i8:f32:1, {1.000000e+00:-128,1.000000e+00:-127,1.000000e+00:-128}>>
+// CHECK: %[[dq:.*]] = "quant.dcast"(%[[q]])
+// CHECK: return %[[dq]]
+}
+
+// CHECK-LABEL: fakeQuantForActivation
+func @fakeQuantForActivation(tensor<8xf32>) -> (tensor<8xf32>) {
+^bb0(%arg0: tensor<8xf32>):
+ %arg1 = constant dense<0.0> : tensor<f32>
+ %arg2 = constant dense<255.0> : tensor<f32>
+ %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg1, %arg2) {num_bits = 3, narrow_range = false} : (tensor<8xf32>, tensor<f32>, tensor<f32>) -> tensor<8xf32>
+ return %0 : tensor<8xf32>
+
+// CHECK: %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %cst, %cst_0)
+// CHECK: %1 = "quant.qcast"(%0) : (tensor<8xf32>) -> tensor<8x!quant.uniform<i8:f32, 1.000000e+00:-128>>
+// CHECK: %2 = "quant.dcast"(%1)
+// CHECK: return %2
+}
+
+// CHECK-LABEL: fakeQuantForActivationNoDuplication
+func @fakeQuantForActivationNoDuplication(tensor<8xf32>) -> (tensor<8x!quant.uniform<i8:f32, 1.000000e+00:-128>>) {
+^bb0(%arg0: tensor<8xf32>):
+ %arg1 = constant dense<0.0> : tensor<f32>
+ %arg2 = constant dense<255.0> : tensor<f32>
+ %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg1, %arg2) {num_bits = 3, narrow_range = false} : (tensor<8xf32>, tensor<f32>, tensor<f32>) -> tensor<8xf32>
+ %1 = "quant.qcast"(%0) : (tensor<8xf32>) -> tensor<8x!quant.uniform<i8:f32, 1.000000e+00:-128>>
+ return %1 : tensor<8x!quant.uniform<i8:f32, 1.000000e+00:-128>>
+
+// CHECK: %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %cst, %cst_0) {narrow_range = false, num_bits = 3 : i64}
+// CHECK: %1 = "quant.qcast"(%0) : (tensor<8xf32>) -> tensor<8x!quant.uniform<i8:f32, 1.000000e+00:-128>>
+// CHECK: return %1
+}
+
+// CHECK-LABEL: fakeQuantFolded
+func @fakeQuantFolded() -> (tensor<8xf32>) {
+ %in = constant dense<0.0> : tensor<8xf32>
+ %min = constant dense<0.0> : tensor<f32>
+ %max = constant dense<255.0> : tensor<f32>
+ %mini = "tf.Identity"(%min) : (tensor<f32>) -> tensor<f32>
+ %maxi = "tf.Identity"(%max) : (tensor<f32>) -> tensor<f32>
+ %rst = "tf.FakeQuantWithMinMaxVars"(%in, %mini, %maxi) {num_bits = 3, narrow_range = false} : (tensor<8xf32>, tensor<f32>, tensor<f32>) -> tensor<8xf32>
+ return %rst : tensor<8xf32>
+
+// CHECK: %[[CONSTANT:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<8xf32>}
+// CHECK: %[[QUANTIZE:.*]] = "quant.qcast"(%[[CONSTANT]]) : (tensor<8xf32>) -> tensor<8x!quant.uniform<i8:f32, 1.000000e+00:-128>>
+// CHECK: %[[DEQUANTIZE:.*]] = "quant.dcast"(%[[QUANTIZE]])
+// CHECK: return %[[DEQUANTIZE]] : tensor<8xf32>
+}
+
+// CHECK-LABEL: fakeQuantNotFolded
+func @fakeQuantNotFolded(tensor<8xf32>, tensor<f32>, tensor<f32>) -> (tensor<8xf32>) {
+^bb0(%arg0: tensor<8xf32>, %arg3: tensor<f32>, %arg4: tensor<f32>):
+ %1 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg3, %arg4) {num_bits = 3, narrow_range = false} : (tensor<8xf32>, tensor<f32>, tensor<f32>) -> tensor<8xf32>
+ return %1 : tensor<8xf32>
+
+// CHECK: %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg1, %arg2)
+// CHECK: return %0 : tensor<8xf32>
+}
+
+// CHECK-LABEL: fakeQuantWithConv2D
+func @fakeQuantWithConv2D(tensor<256x32x32x3xf32>) -> (tensor<256x30x30x16xf32>) {
+^bb0(%arg: tensor<256x32x32x3xf32>) :
+ %in = constant dense<0.0> : tensor<3x3x3x16xf32>
+ %min = constant dense<0.0> : tensor<f32>
+ %max = constant dense<255.0> : tensor<f32>
+ %mini = "tf.Identity"(%min) : (tensor<f32>) -> tensor<f32>
+ %maxi = "tf.Identity"(%max) : (tensor<f32>) -> tensor<f32>
+ %fq = "tf.FakeQuantWithMinMaxVars"(%in, %mini, %maxi) {num_bits = 3, narrow_range = false} : (tensor<3x3x3x16xf32>, tensor<f32>, tensor<f32>) -> tensor<3x3x3x16xf32>
+ %rst = "tf.Conv2D"(%arg, %fq) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32>
+ return %rst : tensor<256x30x30x16xf32>
+
+// CHECK: %[[CONSTANT0:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<3x3x3x16xf32>}
+// CHECK: %[[QUANTIZE:.*]] = "quant.qcast"(%[[CONSTANT0]]) : (tensor<3x3x3x16xf32>) -> tensor<3x3x3x16x!quant.uniform<i8:f32, 1.000000e+00:-128>>
+// CHECK: %[[DEQUANTIZE:.*]] = "quant.dcast"(%[[QUANTIZE]])
+// CHECK: %[[CONV:.*]] = "tf.Conv2D"(%arg0, %[[DEQUANTIZE]])
+// CHECK: return %[[CONV]]
+}
+
+// CHECK-LABEL: perChannelFakeQuantWithConv2D
+func @perChannelFakeQuantWithConv2D(tensor<256x32x32x3xf32>) -> (tensor<256x30x30x16xf32>) {
+^bb0(%arg: tensor<256x32x32x3xf32>) :
+ %in = constant dense<0.0> : tensor<3x3x3x16xf32>
+ %min = constant dense<0.0> : tensor<16xf32>
+ %max = constant dense<255.0> : tensor<16xf32>
+ %mini = "tf.Identity"(%min) : (tensor<16xf32>) -> tensor<16xf32>
+ %maxi = "tf.Identity"(%max) : (tensor<16xf32>) -> tensor<16xf32>
+ %fq = "tf.FakeQuantWithMinMaxVarsPerChannel"(%in, %mini, %maxi) {num_bits = 3, narrow_range = false} : (tensor<3x3x3x16xf32>, tensor<16xf32>, tensor<16xf32>) -> tensor<3x3x3x16xf32>
+ %rst = "tf.Conv2D"(%arg, %fq) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32>
+ return %rst : tensor<256x30x30x16xf32>
+
+// CHECK: %[[CONSTANT0:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<3x3x3x16xf32>}
+// CHECK: %[[QUANTIZE:.*]] = "quant.qcast"(%[[CONSTANT0]]) : (tensor<3x3x3x16xf32>) -> tensor<3x3x3x16x!quant.uniform<i8:f32:3,
+// CHECK-SAME: {1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,
+// CHECK-SAME: 1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128}>>
+// CHECK: %[[DEQUANTIZE:.*]] = "quant.dcast"(%[[QUANTIZE]])
+// CHECK: %[[CONV:.*]] = "tf.Conv2D"(%arg0, %[[DEQUANTIZE]])
+// CHECK: return %[[CONV]] : tensor<256x30x30x16xf32>
+}
+
+// CHECK-LABEL: fakeQuantWithDepthwiseConv2D
+func @fakeQuantWithDepthwiseConv2D(tensor<256x32x32x3xf32>) -> (tensor<256x30x30x16xf32>) {
+^bb0(%arg: tensor<256x32x32x3xf32>) :
+ %in = constant dense<0.0> : tensor<3x3x3x16xf32>
+ %min = constant dense<0.0> : tensor<f32>
+ %max = constant dense<255.0> : tensor<f32>
+ %mini = "tf.Identity"(%min) : (tensor<f32>) -> tensor<f32>
+ %maxi = "tf.Identity"(%max) : (tensor<f32>) -> tensor<f32>
+ %fq = "tf.FakeQuantWithMinMaxVars"(%in, %mini, %maxi) {num_bits = 3, narrow_range = false} : (tensor<3x3x3x16xf32>, tensor<f32>, tensor<f32>) -> tensor<3x3x3x16xf32>
+ %rst = "tf.DepthwiseConv2dNative"(%arg, %fq) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32>
+ return %rst : tensor<256x30x30x16xf32>
+
+// CHECK: %[[CONSTANT0:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<3x3x3x16xf32>}
+// CHECK: %[[QUANTIZE:.*]] = "quant.qcast"(%[[CONSTANT0]]) : (tensor<3x3x3x16xf32>) -> tensor<3x3x3x16x!quant.uniform<i8:f32, 1.000000e+00:-128>>
+// CHECK: %[[DEQUANTIZE:.*]] = "quant.dcast"(%[[QUANTIZE]])
+// CHECK: %[[CONV:.*]] = "tf.DepthwiseConv2dNative"(%arg0, %[[DEQUANTIZE]])
+// CHECK: return %[[CONV]]
+}
+
+// CHECK-LABEL: perChannelFakeQuantWithDepthwiseConv2D
+func @perChannelFakeQuantWithDepthwiseConv2D(tensor<256x32x32x3xf32>) -> (tensor<256x30x30x16xf32>) {
+^bb0(%arg: tensor<256x32x32x3xf32>) :
+ %in = constant dense<0.0> : tensor<3x3x3x16xf32>
+ %min = constant dense<0.0> : tensor<16xf32>
+ %max = constant dense<255.0> : tensor<16xf32>
+ %mini = "tf.Identity"(%min) : (tensor<16xf32>) -> tensor<16xf32>
+ %maxi = "tf.Identity"(%max) : (tensor<16xf32>) -> tensor<16xf32>
+ %fq = "tf.FakeQuantWithMinMaxVarsPerChannel"(%in, %mini, %maxi) {num_bits = 3, narrow_range = false} : (tensor<3x3x3x16xf32>, tensor<16xf32>, tensor<16xf32>) -> tensor<3x3x3x16xf32>
+ %rst = "tf.DepthwiseConv2dNative"(%arg, %fq) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32>
+ return %rst : tensor<256x30x30x16xf32>
+
+// CHECK: %[[CONSTANT0:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<3x3x3x16xf32>}
+// CHECK: %[[QUANTIZE:.*]] = "quant.qcast"(%[[CONSTANT0]]) : (tensor<3x3x3x16xf32>) -> tensor<3x3x3x16x!quant.uniform<i8:f32:3,
+// CHECK-SAME: {1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,
+// CHECK-SAME: 1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128}>>
+// CHECK: %[[DEQUANTIZE:.*]] = "quant.dcast"(%[[QUANTIZE]])
+// CHECK: %[[CONV:.*]] = "tf.DepthwiseConv2dNative"(%arg0, %[[DEQUANTIZE]])
+// CHECK: return %[[CONV]]
+}
diff --git a/tensorflow/compiler/mlir/lite/quantization/tensorflow/tf_to_quant.cc b/tensorflow/compiler/mlir/lite/quantization/tensorflow/tf_to_quant.cc
new file mode 100644
index 0000000..64fddd0
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/quantization/tensorflow/tf_to_quant.cc
@@ -0,0 +1,170 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
+#include "mlir/IR/PatternMatch.h" // TF:llvm-project
+#include "mlir/Pass/Pass.h" // TF:llvm-project
+#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
+
+namespace mlir {
+namespace TF {
+
+//===----------------------------------------------------------------------===//
+// The pass to legalize the quantization emulation ops from TF.
+//
+namespace {
+
+// Legalizes TF quantization emulation ops to their counterparts in the Quant dialect.
+struct LegalizeTFToQuant : public FunctionPass<LegalizeTFToQuant> {
+ explicit LegalizeTFToQuant() = default;
+ LegalizeTFToQuant(const LegalizeTFToQuant &) {}
+
+ /// Performs the lowering to Quant ops dialect.
+ void runOnFunction() override;
+};
+
+// TODO(fengliuai): move this rule to PreparePatterns.td
+// TODO(b/140968741): propagate the sign from the command line. Currently all
+// FakeQuant ops are assumed to be targeting UINT8, but the per-channel kernel
+// is actually INT8.
+// Inserts a "tfl.quantize" and "tfl.dequantize" op pair (QDQs) after the
+// "tf.FakeQuantWithMinMaxVarsOp" to be constant folded. Since the constant
+// folding logic will use a "std.constant" op to replace the
+// "tf.FakeQuantWithMinMaxVarsOp", the "tfl.quantize" op is used to preserve
+// the quantization parameters as a TypeAttr and "tfl.dequantize" op used to
+// convert the output type to the next op. Here are the transformations:
+//
+//   input   min cst       max cst          input   min cst       max cst
+//    \       |             |                \       |             |
+//     \  (tf.Identity) (tf.Identity)   =>    \  (tf.Identity) (tf.Identity)
+//      \     |             |                  \     |             |
+//        tf.FakeQuantWithMinMaxVars        tf.FakeQuantWithMinMaxVars
+//                    |                                  |
+//                                                  quant.qcast
+//                                                       |
+//                                                  quant.dcast
+//                                                       |
+// If the input is a constant, the result pattern will eventually be converted to
+//
+//            quant-emulated input
+//                     |
+//                quant.qcast
+//                     |
+//                quant.dcast
+//                     |
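+// For a concrete example (mirroring tests/tf_to_quant.mlir): min = 0.0 and
+// max = 255.0 with signed 8-bit storage yield the element type
+// !quant.uniform<i8:f32, 1.000000e+00:-128>, i.e. scale = (255 - 0) / 255 = 1.0
+// and zero point = -128.
+//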
+template <typename TFFakeQuantOp, bool PerAxis>
+struct InsertQuantOpsAfterTFFakeQuantOp
+ : public OpRewritePattern<TFFakeQuantOp> {
+ using BaseType = InsertQuantOpsAfterTFFakeQuantOp<TFFakeQuantOp, PerAxis>;
+
+ explicit InsertQuantOpsAfterTFFakeQuantOp<TFFakeQuantOp, PerAxis>(
+ MLIRContext *ctx)
+ : OpRewritePattern<TFFakeQuantOp>(ctx) {}
+
+ PatternMatchResult matchAndRewrite(TFFakeQuantOp tf_op,
+ PatternRewriter &rewriter) const override {
+ // We don't want to insert quantize/dequantize if the quantize op exists.
+ auto res = tf_op.outputs();
+ if (!res.hasOneUse() || isa<quant::QuantizeCastOp>(*res.user_begin()))
+ return this->matchFailure();
+
+ // Extract the min/max constant values from the operands. We also handle
+ // the special case where tf.Identity ops sit between the min/max
+ // constants and the tf.FakeQuantWithMinMaxVarsOp.
+ Value min = tf_op.min(), max = tf_op.max();
+ DenseFPElementsAttr min_value, max_value;
+ if (auto id1 = dyn_cast_or_null<TF::IdentityOp>(min.getDefiningOp())) {
+ id1.replaceAllUsesWith(id1.input());
+ min = tf_op.min();
+ rewriter.eraseOp(id1);
+ }
+ if (auto id2 = dyn_cast_or_null<TF::IdentityOp>(max.getDefiningOp())) {
+ id2.replaceAllUsesWith(id2.input());
+ max = tf_op.max();
+ rewriter.eraseOp(id2);
+ }
+ if (!matchPattern(min, m_Constant(&min_value))) return this->matchFailure();
+ if (!matchPattern(max, m_Constant(&max_value))) return this->matchFailure();
+
+ int quant_dim = -1;
+ if (PerAxis) {
+ // tf.FakeQuantWithMinMaxVarsPerChannel always quantizes along the last
+ // dimension, so use the last dimension as the quantization dimension.
+ quant_dim = res.getType().template cast<ShapedType>().getRank() - 1;
+ }
+ // Use the min/max from the operands and the num_bits and narrow_range
+ // attribute to create the quantization parameter for the new quantize op.
+ rewriter.setInsertionPointAfter(tf_op);
+ IntegerAttr num_bits =
+ rewriter.getI64IntegerAttr(tf_op.num_bits().getSExtValue());
+ BoolAttr narrow_range = rewriter.getBoolAttr(tf_op.narrow_range());
+ Type res_type = tf_op.getType();
+ TypeAttr qtype = quant::GetQuantizedTypeAttr(
+ rewriter, res_type, min_value, max_value, quant_dim, num_bits,
+ narrow_range, /*is_signed=*/true);
+ if (!qtype) return this->matchFailure();
+
+ // Finally, use the quantization parameter to create the quantize and
+ // dequantize ops, and insert them between the tf.FakeQuantWithMinMaxVarsOp
+ // and its users.
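+ // Note that replaceAllUsesWith below also rewrites the operand of the
+ // newly created quantize op, so that single use is restored to the
+ // original value right afterwards.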
+ Value value = tf_op.outputs();
+ auto quantize = rewriter.create<quant::QuantizeCastOp>(
+ tf_op.getLoc(), qtype.getValue(), value);
+ auto dequantize = rewriter.create<quant::DequantizeCastOp>(
+ tf_op.getLoc(), res_type, quantize.getResult());
+ value.replaceAllUsesWith(dequantize);
+ quantize.getOperation()->replaceUsesOfWith(dequantize, value);
+
+ return this->matchSuccess();
+ }
+};
+
+using PreparePerTensorFakeQuant =
+ InsertQuantOpsAfterTFFakeQuantOp<TF::FakeQuantWithMinMaxVarsOp, false>;
+
+using PreparePerChannelFakeQuant =
+ InsertQuantOpsAfterTFFakeQuantOp<TF::FakeQuantWithMinMaxVarsPerChannelOp,
+ true>;
+
+// TODO(fengliuai): add the support of the tf.QuantizeAndDequantize*
+// legalization.
+
+void LegalizeTFToQuant::runOnFunction() {
+ OwningRewritePatternList patterns;
+ auto func = getFunction();
+ auto *ctx = func.getContext();
+ patterns.insert<PreparePerTensorFakeQuant, PreparePerChannelFakeQuant>(ctx);
+ applyPatternsGreedily(func, patterns);
+}
+} // namespace
+
+// Creates an instance of the pass that legalizes TF ops to the Quant dialect.
+std::unique_ptr<OpPassBase<FuncOp>> CreateLegalizeTFToQuantPass() {
+ return std::make_unique<LegalizeTFToQuant>();
+}
+
+static PassRegistration<LegalizeTFToQuant> pass(
+ "tf-to-quant", "Legalize TF to quant ops dialect");
+
+} // namespace TF
+} // namespace mlir
diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/tests/BUILD b/tensorflow/compiler/mlir/lite/quantization/xla/tests/BUILD
new file mode 100644
index 0000000..9a36428
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/quantization/xla/tests/BUILD
@@ -0,0 +1 @@
+# TODO(fengliuai): describe this package.
diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc
index ff6b84d..3419ee2 100644
--- a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc
@@ -82,6 +82,7 @@
};
// TODO(fengliuai): move this rule to PreparePatterns.td
+// TODO(fengliuai): reuse the quantization/tensorflow/tf_to_quant pass.
// TODO(b/140968741): propagate the sign from the command line. Currently all
// FakeQuant ops are assumed to be targeting UINT8, but the per-channel kernel
// is actually INT8.