Only infer the fixed output range when the input graph has dequantize ops
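
With this change, the fixed output range of ops implementing
FixedOutputRangeInterface is enforced during quantization parameter
propagation only when the input graph already contains dequantize ops
(typically a graph produced by quantization-aware training) or when
post-training quantization is requested.

Callers of ApplyQuantizationParamsPropagation pass the flag explicitly;
a sketch mirroring the updated call site in prepare_quantize.cc:

    ApplyQuantizationParamsPropagation(
        func, is_signed,
        disable_per_channel || quant_specs_.disable_per_channel,
        GetOpQuantSpec,
        enforce_fixed_output_range || quant_specs_.post_training_quantization);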

PiperOrigin-RevId: 327312937
Change-Id: Ice4c2e35aeb074e34516dc434ca0c066af947ca8
diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc b/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc
index 9e0ad99..16b5149 100644
--- a/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc
+++ b/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc
@@ -99,12 +99,14 @@
  public:
   explicit QuantizationDriver(FuncOp fn, bool is_signed,
                               bool disable_per_channel,
-                              OpQuantSpecGetter op_quant_spec_getter)
+                              OpQuantSpecGetter op_quant_spec_getter,
+                              bool enforce_fixed_output_range)
       : fn_(fn),
         builder_(fn.getBody()),
         is_signed_(is_signed),
         disable_per_channel_(disable_per_channel),
-        op_quant_spec_getter_(op_quant_spec_getter) {}
+        op_quant_spec_getter_(op_quant_spec_getter),
+        enforce_fixed_output_range_(enforce_fixed_output_range) {}
 
   // The entry point of the quantization parameters propagation.
   void Run();
@@ -354,6 +356,8 @@
   llvm::SmallVector<BlockArgument, 4> args_;
 
   OpQuantSpecGetter op_quant_spec_getter_;
+
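+  // Whether to enforce the quantization parameters of ops with a fixed
+  // output range (implementers of FixedOutputRangeInterface) during
+  // propagation. Set for post-training quantization and for graphs that
+  // already contain dequantize ops.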
+  bool enforce_fixed_output_range_;
 };
 }  // namespace
 
@@ -794,7 +798,8 @@
     }
 
     // TODO(fengliuai): make the bit width configurable.
-    if (auto restricted = llvm::dyn_cast<FixedOutputRangeInterface>(op)) {
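+    // Only enforce the fixed output range when requested, i.e. for
+    // post-training quantization or for graphs with dequantize ops.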
+    auto restricted = llvm::dyn_cast<FixedOutputRangeInterface>(op);
+    if (restricted && enforce_fixed_output_range_) {
      // TODO(fengliuai): different results can have different fixed ranges.
       auto params = restricted.GetFixedOutputRange(is_signed_, /*bit_width=*/8);
       for (auto i = 0; i < op->getNumResults(); ++i) {
@@ -864,10 +869,12 @@
   }
 }
 
-void ApplyQuantizationParamsPropagation(
-    mlir::FuncOp func, bool is_signed, bool disable_per_channel,
-    OpQuantSpecGetter op_quant_spec_getter) {
-  QuantizationDriver(func, is_signed, disable_per_channel, op_quant_spec_getter)
+void ApplyQuantizationParamsPropagation(mlir::FuncOp func, bool is_signed,
+                                        bool disable_per_channel,
+                                        OpQuantSpecGetter op_quant_spec_getter,
+                                        bool enforce_fixed_output_range) {
+  QuantizationDriver(func, is_signed, disable_per_channel, op_quant_spec_getter,
+                     enforce_fixed_output_range)
       .Run();
 }
 
diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h
index 07e5ba4..6e356ac 100644
--- a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h
+++ b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h
@@ -490,9 +490,13 @@
 // and the propagation results are materialized by inserting pairs of quantize
 // and dequantize ops to this function. Set `disable_per_channel` to true to not
 // use per channel quantization even if the op supports it.
+// Set `enforce_fixed_output_range` to true to infer the quantization
+// parameters from ops with a fixed output range. This is used for
+// post-training quantization and for graphs that already contain dequantize
+// ops.
 void ApplyQuantizationParamsPropagation(mlir::FuncOp func, bool is_signed,
                                         bool disable_per_channel,
-                                        OpQuantSpecGetter op_quant_spec_getter);
+                                        OpQuantSpecGetter op_quant_spec_getter,
+                                        bool enforce_fixed_output_range);
 
 // The function might contain more stats ops than required, and it will
 // introduce requantize if the calibration stats have conflicts. This method
diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc
index 9a27d0de..07b7aac 100644
--- a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc
@@ -23,6 +23,7 @@
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Support/CommandLine.h"
 #include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
+#include "mlir/IR/Function.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
 #include "mlir/IR/Value.h"  // from @llvm-project
@@ -122,6 +123,10 @@
   // the best quantization practice. This also fixes some simple violations.
   void SanityCheckAndAdjustment(FuncOp func);
 
+  // Whether the func contains dequantize ops. This is used to determine
+  // whether to infer the quantization parameters from the fixed output range
+  // property of certain ops.
+  bool ContainsDequantizeOps(FuncOp func);
+
   QuantizationSpecs quant_specs_;
 };
 
@@ -285,6 +290,13 @@
   });
 }
 
+bool PrepareQuantizePass::ContainsDequantizeOps(FuncOp func) {
+  for (const auto& op : func.getOps()) {
+    if (llvm::isa<quant::DequantizeCastOp>(op)) return true;
+  }
+  return false;
+}
+
 using PrepareQuantStats =
     quant::ConvertStatsToQDQs<quant::QuantizeCastOp, quant::DequantizeCastOp>;
 
@@ -309,6 +321,7 @@
   OwningRewritePatternList patterns;
   bool is_signed = quant_specs_.IsSignedInferenceType();
   int bit_width = quant_specs_.GetQuantizationTypeWidth();
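+  // The fixed output range should be enforced when the graph already has
+  // dequantize ops, e.g. because it went through quantization-aware training.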
+  bool enforce_fixed_output_range = ContainsDequantizeOps(func);
   if (is_signed) {
     patterns.insert<quant::ConvertUnsignedToSigned<quant::QuantizeCastOp>>(ctx);
     // Convert quant stats to int8 quantization parameters.
@@ -327,7 +340,8 @@
   // values (tensors).
   ApplyQuantizationParamsPropagation(
       func, is_signed, disable_per_channel || quant_specs_.disable_per_channel,
-      GetOpQuantSpec);
+      GetOpQuantSpec,
+      enforce_fixed_output_range || quant_specs_.post_training_quantization);
 
   ConvertMlirQuantOpsToTFLQuantOps(func);
 }