A pass to legalize tf quantization emulation ops to mlir quant dialect ops

Potentially this pass can also be shared with tflite. We put it as a TODO for now.

PiperOrigin-RevId: 292421480
Change-Id: Ica6e6270343700a704c0d00bba6640daf117d1b0
diff --git a/tensorflow/compiler/mlir/BUILD b/tensorflow/compiler/mlir/BUILD
index 1fe2023..a9a6182 100644
--- a/tensorflow/compiler/mlir/BUILD
+++ b/tensorflow/compiler/mlir/BUILD
@@ -66,6 +66,7 @@
         "//tensorflow/compiler/mlir/lite:tensorflow_lite_optimize",
         "//tensorflow/compiler/mlir/lite:tensorflow_lite_quantize",
         "//tensorflow/compiler/mlir/lite/quantization:quantization_passes",
+        "//tensorflow/compiler/mlir/lite/quantization/tensorflow:tf_to_quant",
         "//tensorflow/compiler/mlir/tensorflow",
         "//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
         "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes",
diff --git a/tensorflow/compiler/mlir/lite/quantization/tensorflow/BUILD b/tensorflow/compiler/mlir/lite/quantization/tensorflow/BUILD
new file mode 100644
index 0000000..96d6c4f
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/quantization/tensorflow/BUILD
@@ -0,0 +1,36 @@
+package(
+    default_visibility = [
+        ":friends",
+    ],
+    licenses = ["notice"],  # Apache 2.0
+)
+
+package_group(
+    name = "friends",
+    includes = ["//third_party/mlir:subpackages"],
+    packages = [
+        "//tensorflow/compiler/mlir/...",
+        "//tensorflow/compiler/mlir/lite/...",
+    ],
+)
+
+cc_library(
+    name = "tf_to_quant",
+    srcs = [
+        "tf_to_quant.cc",
+    ],
+    hdrs = [
+        "passes.h",
+    ],
+    deps = [
+        "//tensorflow/compiler/mlir/lite/quantization:quantization_config",
+        "//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
+        "//tensorflow/compiler/mlir/tensorflow",
+        "@com_google_absl//absl/strings",
+        "@llvm-project//llvm:support",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:QuantOps",
+    ],
+    alwayslink = 1,
+)
diff --git a/tensorflow/compiler/mlir/lite/quantization/tensorflow/passes.h b/tensorflow/compiler/mlir/lite/quantization/tensorflow/passes.h
new file mode 100644
index 0000000..c345da0
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/quantization/tensorflow/passes.h
@@ -0,0 +1,32 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_TENSORFLOW_PASSES_H_
+#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_TENSORFLOW_PASSES_H_
+
+#include <memory>
+
+#include "mlir/IR/Function.h"  // TF:llvm-project
+#include "mlir/Pass/Pass.h"  // TF:llvm-project
+
+namespace mlir {
+namespace TF {
+
+// Legalize the tf ops to the quant ops, so the quantization passes can work.
+std::unique_ptr<OpPassBase<FuncOp>> CreateLegalizeTFToQuantPass();
+
+}  // namespace TF
+}  // namespace mlir
+#endif  // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_TENSORFLOW_PASSES_H_
diff --git a/tensorflow/compiler/mlir/lite/quantization/tensorflow/tests/BUILD b/tensorflow/compiler/mlir/lite/quantization/tensorflow/tests/BUILD
new file mode 100644
index 0000000..4faa8d2
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/quantization/tensorflow/tests/BUILD
@@ -0,0 +1,19 @@
+load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")
+
+package(licenses = ["notice"])
+
+glob_lit_tests(
+    data = [":test_utilities"],
+    driver = "@llvm-project//mlir:run_lit.sh",
+    test_file_exts = ["mlir"],
+)
+
+# Bundle together all of the test utilities that are used by tests.
+filegroup(
+    name = "test_utilities",
+    testonly = True,
+    data = [
+        "//tensorflow/compiler/mlir:tf-opt",
+        "@llvm-project//llvm:FileCheck",
+    ],
+)
diff --git a/tensorflow/compiler/mlir/lite/quantization/tensorflow/tests/tf_to_quant.mlir b/tensorflow/compiler/mlir/lite/quantization/tensorflow/tests/tf_to_quant.mlir
new file mode 100644
index 0000000..6664dca
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/quantization/tensorflow/tests/tf_to_quant.mlir
@@ -0,0 +1,149 @@
+// RUN: tf-opt -tf-to-quant %s
+//| FileCheck %s
+
+// CHECK-LABEL: fakeQuantPerChannelForActivation
+func @fakeQuantPerChannelForActivation(%arg0: tensor<8x3xf32>) -> (tensor<8x3xf32>) {
+  %arg1 = constant dense<[0.0, -1.0, 1.0]> : tensor<3xf32>
+  %arg2 = constant dense<[255.0, 254.0, 256.0]> : tensor<3xf32>
+  %0 = "tf.FakeQuantWithMinMaxVarsPerChannel"(%arg0, %arg1, %arg2) {num_bits = 3, narrow_range = false} : (tensor<8x3xf32>, tensor<3xf32>, tensor<3xf32>) -> tensor<8x3xf32>
+  return %0 : tensor<8x3xf32>
+
+// CHECK:  %[[fq:.*]] = "tf.FakeQuantWithMinMaxVarsPerChannel"(%arg0, %cst, %cst_0)
+// CHECK:  %[[q:.*]] = "quant.qcast"(%[[fq]]) : (tensor<8x3xf32>) -> tensor<8x3x!quant.uniform<i8:f32:1, {1.000000e+00:-128,1.000000e+00:-127,1.000000e+00:-128}>>
+// CHECK:  %[[dq:.*]] = "quant.dcast"(%[[q]])
+// CHECK:  return %[[dq]]
+}
+
+// CHECK-LABEL: fakeQuantForActivation
+func @fakeQuantForActivation(tensor<8xf32>) -> (tensor<8xf32>) {
+^bb0(%arg0: tensor<8xf32>):
+  %arg1 = constant dense<0.0> : tensor<f32>
+  %arg2 = constant dense<255.0> : tensor<f32>
+  %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg1, %arg2) {num_bits = 3, narrow_range = false} : (tensor<8xf32>, tensor<f32>, tensor<f32>) -> tensor<8xf32>
+  return %0 : tensor<8xf32>
+
+// CHECK:  %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %cst, %cst_0)
+// CHECK:  %1 = "quant.qcast"(%0) : (tensor<8xf32>) -> tensor<8x!quant.uniform<i8:f32, 1.000000e+00:-128>>
+// CHECK:  %2 = "quant.dcast"(%1)
+// CHECK:  return %2
+}
+
+// CHECK-LABEL: fakeQuantForActivationNoDuplication
+func @fakeQuantForActivationNoDuplication(tensor<8xf32>) -> (tensor<8x!quant.uniform<i8:f32, 1.000000e+00:-128>>) {
+^bb0(%arg0: tensor<8xf32>):
+  %arg1 = constant dense<0.0> : tensor<f32>
+  %arg2 = constant dense<255.0> : tensor<f32>
+  %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg1, %arg2) {num_bits = 3, narrow_range = false} : (tensor<8xf32>, tensor<f32>, tensor<f32>) -> tensor<8xf32>
+  %1 = "quant.qcast"(%0) : (tensor<8xf32>) -> tensor<8x!quant.uniform<i8:f32, 1.000000e+00:-128>>
+  return %1 : tensor<8x!quant.uniform<i8:f32, 1.000000e+00:-128>>
+
+// CHECK:  %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %cst, %cst_0) {narrow_range = false, num_bits = 3 : i64}
+// CHECK:  %1 = "quant.qcast"(%0) : (tensor<8xf32>) -> tensor<8x!quant.uniform<i8:f32, 1.000000e+00:-128>>
+// CHECK:  return %1
+}
+
+// CHECK-LABEL: fakeQuantFolded
+func @fakeQuantFolded() -> (tensor<8xf32>) {
+  %in = constant dense<0.0> : tensor<8xf32>
+  %min = constant dense<0.0> : tensor<f32>
+  %max = constant dense<255.0> : tensor<f32>
+  %mini = "tf.Identity"(%min) : (tensor<f32>) -> tensor<f32>
+  %maxi = "tf.Identity"(%max) : (tensor<f32>) -> tensor<f32>
+  %rst = "tf.FakeQuantWithMinMaxVars"(%in, %mini, %maxi) {num_bits = 3, narrow_range = false} : (tensor<8xf32>, tensor<f32>, tensor<f32>) -> tensor<8xf32>
+  return %rst : tensor<8xf32>
+
+// CHECK: %[[CONSTANT:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<8xf32>}
+// CHECK: %[[QUANTIZE:.*]] = "quant.qcast"(%[[CONSTANT]]) : (tensor<8xf32>) -> tensor<8x!quant.uniform<i8:f32, 1.000000e+00:-128>>
+// CHECK: %[[DEQUANTIZE:.*]] = "quant.dcast"(%[[QUANTIZE]])
+// CHECK: return %[[DEQUANTIZE]] : tensor<8xf32>
+}
+
+// CHECK-LABEL: fakeQuantNotFolded
+func @fakeQuantNotFolded(tensor<8xf32>, tensor<f32>, tensor<f32>) -> (tensor<8xf32>) {
+^bb0(%arg0: tensor<8xf32>, %arg3: tensor<f32>, %arg4: tensor<f32>):
+  %1 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg3, %arg4) {num_bits = 3, narrow_range = false} : (tensor<8xf32>, tensor<f32>, tensor<f32>) -> tensor<8xf32>
+  return %1 : tensor<8xf32>
+
+// CHECK: %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg1, %arg2)
+// CHECK: return %0 : tensor<8xf32>
+}
+
+// CHECK-LABEL: fakeQuantWithConv2D
+func @fakeQuantWithConv2D(tensor<256x32x32x3xf32>) -> (tensor<256x30x30x16xf32>) {
+^bb0(%arg: tensor<256x32x32x3xf32>) :
+  %in = constant dense<0.0> : tensor<3x3x3x16xf32>
+  %min = constant dense<0.0> : tensor<f32>
+  %max = constant dense<255.0> : tensor<f32>
+  %mini = "tf.Identity"(%min) : (tensor<f32>) -> tensor<f32>
+  %maxi = "tf.Identity"(%max) : (tensor<f32>) -> tensor<f32>
+  %fq = "tf.FakeQuantWithMinMaxVars"(%in, %mini, %maxi) {num_bits = 3, narrow_range = false} : (tensor<3x3x3x16xf32>, tensor<f32>, tensor<f32>) -> tensor<3x3x3x16xf32>
+  %rst = "tf.Conv2D"(%arg, %fq) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32>
+  return %rst : tensor<256x30x30x16xf32>
+
+// CHECK: %[[CONSTANT0:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<3x3x3x16xf32>}
+// CHECK: %[[QUANTIZE:.*]] = "quant.qcast"(%[[CONSTANT0]]) : (tensor<3x3x3x16xf32>) -> tensor<3x3x3x16x!quant.uniform<i8:f32, 1.000000e+00:-128>>
+// CHECK: %[[DEQUANTIZE:.*]] = "quant.dcast"(%[[QUANTIZE]])
+// CHECK: %[[CONV:.*]] = "tf.Conv2D"(%arg0, %[[DEQUANTIZE]])
+// CHECK: return %[[CONV]]
+}
+
+// CHECK-LABEL: perChannelFakeQuantWithConv2D
+func @perChannelFakeQuantWithConv2D(tensor<256x32x32x3xf32>) -> (tensor<256x30x30x16xf32>) {
+^bb0(%arg: tensor<256x32x32x3xf32>) :
+  %in = constant dense<0.0> : tensor<3x3x3x16xf32>
+  %min = constant dense<0.0> : tensor<16xf32>
+  %max = constant dense<255.0> : tensor<16xf32>
+  %mini = "tf.Identity"(%min) : (tensor<16xf32>) -> tensor<16xf32>
+  %maxi = "tf.Identity"(%max) : (tensor<16xf32>) -> tensor<16xf32>
+  %fq = "tf.FakeQuantWithMinMaxVarsPerChannel"(%in, %mini, %maxi) {num_bits = 3, narrow_range = false} : (tensor<3x3x3x16xf32>, tensor<16xf32>, tensor<16xf32>) -> tensor<3x3x3x16xf32>
+  %rst = "tf.Conv2D"(%arg, %fq) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32>
+  return %rst : tensor<256x30x30x16xf32>
+
+// CHECK: %[[CONSTANT0:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<3x3x3x16xf32>}
+// CHECK: %[[QUANTIZE:.*]] = "quant.qcast"(%[[CONSTANT0]]) : (tensor<3x3x3x16xf32>) -> tensor<3x3x3x16x!quant.uniform<i8:f32:3,
+// CHECK-SAME: {1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,
+// CHECK-SAME: 1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128}>>
+// CHECK: %[[DEQUANTIZE:.*]] = "quant.dcast"(%[[QUANTIZE]])
+// CHECK: %[[CONV:.*]] = "tf.Conv2D"(%arg0, %[[DEQUANTIZE]])
+// CHECK: return %[[CONV]] : tensor<256x30x30x16xf32>
+}
+
+// CHECK-LABEL: fakeQuantWithDepthwiseConv2D
+func @fakeQuantWithDepthwiseConv2D(tensor<256x32x32x3xf32>) -> (tensor<256x30x30x16xf32>) {
+^bb0(%arg: tensor<256x32x32x3xf32>) :
+  %in = constant dense<0.0> : tensor<3x3x3x16xf32>
+  %min = constant dense<0.0> : tensor<f32>
+  %max = constant dense<255.0> : tensor<f32>
+  %mini = "tf.Identity"(%min) : (tensor<f32>) -> tensor<f32>
+  %maxi = "tf.Identity"(%max) : (tensor<f32>) -> tensor<f32>
+  %fq = "tf.FakeQuantWithMinMaxVars"(%in, %mini, %maxi) {num_bits = 3, narrow_range = false} : (tensor<3x3x3x16xf32>, tensor<f32>, tensor<f32>) -> tensor<3x3x3x16xf32>
+  %rst = "tf.DepthwiseConv2dNative"(%arg, %fq) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32>
+  return %rst : tensor<256x30x30x16xf32>
+
+// CHECK: %[[CONSTANT0:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<3x3x3x16xf32>}
+// CHECK: %[[QUANTIZE:.*]] = "quant.qcast"(%[[CONSTANT0]]) : (tensor<3x3x3x16xf32>) -> tensor<3x3x3x16x!quant.uniform<i8:f32, 1.000000e+00:-128>>
+// CHECK: %[[DEQUANTIZE:.*]] = "quant.dcast"(%[[QUANTIZE]])
+// CHECK: %[[CONV:.*]] = "tf.DepthwiseConv2dNative"(%arg0, %[[DEQUANTIZE]])
+// CHECK: return %[[CONV]]
+}
+
+// CHECK-LABEL: perChannelFakeQuantWithDepthwiseConv2D
+func @perChannelFakeQuantWithDepthwiseConv2D(tensor<256x32x32x3xf32>) -> (tensor<256x30x30x16xf32>) {
+^bb0(%arg: tensor<256x32x32x3xf32>) :
+  %in = constant dense<0.0> : tensor<3x3x3x16xf32>
+  %min = constant dense<0.0> : tensor<16xf32>
+  %max = constant dense<255.0> : tensor<16xf32>
+  %mini = "tf.Identity"(%min) : (tensor<16xf32>) -> tensor<16xf32>
+  %maxi = "tf.Identity"(%max) : (tensor<16xf32>) -> tensor<16xf32>
+  %fq = "tf.FakeQuantWithMinMaxVarsPerChannel"(%in, %mini, %maxi) {num_bits = 3, narrow_range = false} : (tensor<3x3x3x16xf32>, tensor<16xf32>, tensor<16xf32>) -> tensor<3x3x3x16xf32>
+  %rst = "tf.DepthwiseConv2dNative"(%arg, %fq) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32>
+  return %rst : tensor<256x30x30x16xf32>
+
+// CHECK: %[[CONSTANT0:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<3x3x3x16xf32>}
+// CHECK: %[[QUANTIZE:.*]] = "quant.qcast"(%[[CONSTANT0]]) : (tensor<3x3x3x16xf32>) -> tensor<3x3x3x16x!quant.uniform<i8:f32:3,
+// CHECK-SAME: {1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,
+// CHECK-SAME: 1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128}>>
+// CHECK: %[[DEQUANTIZE:.*]] = "quant.dcast"(%[[QUANTIZE]])
+// CHECK: %[[CONV:.*]] = "tf.DepthwiseConv2dNative"(%arg0, %[[DEQUANTIZE]])
+// CHECK: return %[[CONV]]
+}
diff --git a/tensorflow/compiler/mlir/lite/quantization/tensorflow/tf_to_quant.cc b/tensorflow/compiler/mlir/lite/quantization/tensorflow/tf_to_quant.cc
new file mode 100644
index 0000000..64fddd0
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/quantization/tensorflow/tf_to_quant.cc
@@ -0,0 +1,162 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "mlir/Dialect/QuantOps/QuantOps.h"  // TF:llvm-project
+#include "mlir/IR/PatternMatch.h"  // TF:llvm-project
+#include "mlir/Pass/Pass.h"  // TF:llvm-project
+#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
+
+namespace mlir {
+namespace TF {
+
+//===----------------------------------------------------------------------===//
+// The pass to legalize the quantization emulation ops from TF.
+//
+namespace {
+
+// Legalize TF quantization emulation ops to that in Quant ops dialect.
+struct LegalizeTFToQuant : public FunctionPass<LegalizeTFToQuant> {
+  explicit LegalizeTFToQuant() = default;
+  LegalizeTFToQuant(const LegalizeTFToQuant &) {}
+
+  /// Performs the lowering to Quant ops dialect.
+  void runOnFunction() override;
+};
+
+// TODO(fengliuai): move this rule to PreparePatterns.td
+// TODO(b/140968741): propagate the sign from the command line. Currently all
+// the FakeQuant is assumed to targeting UIN8, but per-channel kernel is
+// actually INT8.
+// Inserts a "tfl.quantize" and "tfl.dequantize" op pair (QDQs) after the
+// "tf.FakeQuantWithMinMaxVarsOp" to be constant folded. Since the constant
+// folding logic will use a "std.constant" op to replace the
+// "tf.FakeQuantWithMinMaxVarsOp", the "tfl.quantize" op is used to preserve
+// the quantization parameters as a TypeAttr and "tfl.dequantize" op used to
+// convert the output type to the next op. Here are the transformations:
+//
+// input   min cst       max cst          input   min cst       max cst
+//  \       |             |                \       |             |
+//   \  (tf.Identity) (tf.Identity)   =>    \  (tf.Identity) (tf.Identity)
+//    \     |             |                  \     |             |
+//       tf.FakeQuantWithMinMaxVars       tf.FakeQuantWithMinMaxVars
+//                   |                                 |
+//                                                tf.quantize
+//                                                     |
+//                                                tf.dequantize
+//                                                     |
+// If the input is a constant, the result pattern will eventually converted to
+//
+//            quant-emulated input
+//                   |
+//               tf.quantize
+//                   |
+//              tf.dequantize
+//                   |
+template <typename TFFakeQuantOp, bool PerAxis>
+struct InsertQuantOpsAfterTFFakeQuantOp
+    : public OpRewritePattern<TFFakeQuantOp> {
+  using BaseType = InsertQuantOpsAfterTFFakeQuantOp<TFFakeQuantOp, PerAxis>;
+
+  explicit InsertQuantOpsAfterTFFakeQuantOp<TFFakeQuantOp, PerAxis>(
+      MLIRContext *ctx)
+      : OpRewritePattern<TFFakeQuantOp>(ctx) {}
+
+  PatternMatchResult matchAndRewrite(TFFakeQuantOp tf_op,
+                                     PatternRewriter &rewriter) const override {
+    // We don't want to insert quantize/dequantize if the quantize op exists.
+    auto res = tf_op.outputs();
+    if (!res.hasOneUse() || isa<quant::QuantizeCastOp>(*res.user_begin()))
+      return this->matchFailure();
+
+    // Extract the min/max constant values from the operands. We also consider
+    // a special case that there are tf.Identity ops between the min/max
+    // constants and the tf.FakeQuantWithMinMaxVarsOp.
+    Value min = tf_op.min(), max = tf_op.max();
+    DenseFPElementsAttr min_value, max_value;
+    if (auto id1 = dyn_cast_or_null<TF::IdentityOp>(min.getDefiningOp())) {
+      id1.replaceAllUsesWith(id1.input());
+      min = tf_op.min();
+      rewriter.eraseOp(id1);
+    }
+    if (auto id2 = dyn_cast_or_null<TF::IdentityOp>(max.getDefiningOp())) {
+      id2.replaceAllUsesWith(id2.input());
+      max = tf_op.max();
+      rewriter.eraseOp(id2);
+    }
+    if (!matchPattern(min, m_Constant(&min_value))) return this->matchFailure();
+    if (!matchPattern(max, m_Constant(&max_value))) return this->matchFailure();
+
+    int quant_dim = -1;
+    if (PerAxis) {
+      // This is a special case that the quant_dim is the last dimensions
+      // according to the tf.FakeQuantWithMinMaxPerChannel.
+      quant_dim = res.getType().template cast<ShapedType>().getRank() - 1;
+    }
+    // Use the min/max from the operands and the num_bits and narrow_range
+    // attribute to create the quantization parameter for the new quantize op.
+    rewriter.setInsertionPointAfter(tf_op);
+    IntegerAttr num_bits =
+        rewriter.getI64IntegerAttr(tf_op.num_bits().getSExtValue());
+    BoolAttr narrow_range = rewriter.getBoolAttr(tf_op.narrow_range());
+    Type res_type = tf_op.getType();
+    TypeAttr qtype = quant::GetQuantizedTypeAttr(
+        rewriter, res_type, min_value, max_value, quant_dim, num_bits,
+        narrow_range, /*is_signed=*/true);
+    if (!qtype) this->matchFailure();
+
+    // Finally, use the quantization parameter to create the quantize and
+    // dequantize ops, and insert them between the tf.FakeQuantWithMinMaxVarsOp
+    // and its users.
+    Value value = tf_op.outputs();
+    auto quantize = rewriter.create<quant::QuantizeCastOp>(
+        tf_op.getLoc(), qtype.getValue(), value);
+    auto dequantize = rewriter.create<quant::DequantizeCastOp>(
+        tf_op.getLoc(), res_type, quantize.getResult());
+    value.replaceAllUsesWith(dequantize);
+    quantize.getOperation()->replaceUsesOfWith(dequantize, value);
+
+    return this->matchSuccess();
+  }
+};
+
+using PreparePerTensorFakeQuant =
+    InsertQuantOpsAfterTFFakeQuantOp<TF::FakeQuantWithMinMaxVarsOp, false>;
+
+using PreparePerChannelFakeQuant =
+    InsertQuantOpsAfterTFFakeQuantOp<TF::FakeQuantWithMinMaxVarsPerChannelOp,
+                                     true>;
+
+// TODO(fengliuai): add the support of the tf.QuantizeAndDequantize*
+// legalization.
+
+void LegalizeTFToQuant::runOnFunction() {
+  OwningRewritePatternList patterns;
+  auto func = getFunction();
+  auto *ctx = func.getContext();
+  patterns.insert<PreparePerTensorFakeQuant, PreparePerChannelFakeQuant>(ctx);
+  applyPatternsGreedily(func, patterns);
+}
+}  // namespace
+
+// Creates an instance of the TensorFlow dialect to QuantOps dialect pass.
+std::unique_ptr<OpPassBase<FuncOp>> CreateLegalizeTFToQuantPass() {
+  return std::make_unique<LegalizeTFToQuant>();
+}
+
+static PassRegistration<LegalizeTFToQuant> pass(
+    "tf-to-quant", "Legalize TF to quant ops dialect");
+
+}  // namespace TF
+}  // namespace mlir
diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/tests/BUILD b/tensorflow/compiler/mlir/lite/quantization/xla/tests/BUILD
new file mode 100644
index 0000000..9a36428
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/quantization/xla/tests/BUILD
@@ -0,0 +1 @@
+# TODO(fengliuai): describe this package.
diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc
index ff6b84d..3419ee2 100644
--- a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc
@@ -82,6 +82,7 @@
 };
 
 // TODO(fengliuai): move this rule to PreparePatterns.td
+// TODO(fengliuai): reuse the quantization/tensorflow/tf_to_quant pass.
 // TODO(b/140968741): propagate the sign from the command line. Currently all
 // the FakeQuant is assumed to targeting UIN8, but per-channel kernel is
 // actually INT8.