Only infer the fixed output range when the input graph has dequantize ops
PiperOrigin-RevId: 327312937
Change-Id: Ice4c2e35aeb074e34516dc434ca0c066af947ca8
diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc b/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc
index 9e0ad99..16b5149 100644
--- a/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc
+++ b/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc
@@ -99,12 +99,14 @@
public:
explicit QuantizationDriver(FuncOp fn, bool is_signed,
bool disable_per_channel,
- OpQuantSpecGetter op_quant_spec_getter)
+ OpQuantSpecGetter op_quant_spec_getter,
+ bool enforce_fixed_output_range)
: fn_(fn),
builder_(fn.getBody()),
is_signed_(is_signed),
disable_per_channel_(disable_per_channel),
- op_quant_spec_getter_(op_quant_spec_getter) {}
+ op_quant_spec_getter_(op_quant_spec_getter),
+ enforce_fixed_output_range_(enforce_fixed_output_range) {}
// The entry point of the quantization parameters propagation.
void Run();
@@ -354,6 +356,8 @@
llvm::SmallVector<BlockArgument, 4> args_;
OpQuantSpecGetter op_quant_spec_getter_;
+
+ bool enforce_fixed_output_range_;
};
} // namespace
@@ -794,7 +798,8 @@
}
// TODO(fengliuai): make the bit width configurable.
- if (auto restricted = llvm::dyn_cast<FixedOutputRangeInterface>(op)) {
+ auto restricted = llvm::dyn_cast<FixedOutputRangeInterface>(op);
+ if (restricted && enforce_fixed_output_range_) {
// TODO(fengliuai): different result can have different fixed range.
auto params = restricted.GetFixedOutputRange(is_signed_, /*bit_width=*/8);
for (auto i = 0; i < op->getNumResults(); ++i) {
@@ -864,10 +869,12 @@
}
}
-void ApplyQuantizationParamsPropagation(
- mlir::FuncOp func, bool is_signed, bool disable_per_channel,
- OpQuantSpecGetter op_quant_spec_getter) {
- QuantizationDriver(func, is_signed, disable_per_channel, op_quant_spec_getter)
+void ApplyQuantizationParamsPropagation(mlir::FuncOp func, bool is_signed,
+ bool disable_per_channel,
+ OpQuantSpecGetter op_quant_spec_getter,
+ bool post_training_quantization) {
+ QuantizationDriver(func, is_signed, disable_per_channel, op_quant_spec_getter,
+ post_training_quantization)
.Run();
}
diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h
index 07e5ba4..6e356ac 100644
--- a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h
+++ b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h
@@ -490,9 +490,13 @@
// and the propagation results are materialized by inserting pairs of quantize
// and dequantize ops to this function. Set `disable_per_channel` to true to not
// use per channel quantization even the op supports it.
+// Set `enforce_fixed_output_range` to true to infer quantization parameters
+// from ops with a fixed output range. This is only used for post-training
+// quantization.
void ApplyQuantizationParamsPropagation(mlir::FuncOp func, bool is_signed,
bool disable_per_channel,
- OpQuantSpecGetter op_quant_spec_getter);
+ OpQuantSpecGetter op_quant_spec_getter,
+ bool enforce_fixed_output_range);
// The function might contain more stats ops than required, and it will
// introduce requantize if the calibration stats have conflicts. This method
diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc
index 9a27d0de..07b7aac 100644
--- a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc
@@ -23,6 +23,7 @@
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/CommandLine.h"
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
+#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
@@ -122,6 +123,10 @@
// the best quantization practise. This also fixes some simple violations.
void SanityCheckAndAdjustment(FuncOp func);
+ // Whether the func contains Dequantize ops. This is used to determine whether
+ // to use the quantization parameters from the fixed output range property.
+ bool ContainsQuantizeOps(FuncOp func);
+
QuantizationSpecs quant_specs_;
};
@@ -285,6 +290,13 @@
});
}
+bool PrepareQuantizePass::ContainsQuantizeOps(FuncOp func) {
+ for (const auto& op : func.getOps()) {
+ if (llvm::isa<quant::DequantizeCastOp>(op)) return true;
+ }
+ return false;
+}
+
using PrepareQuantStats =
quant::ConvertStatsToQDQs<quant::QuantizeCastOp, quant::DequantizeCastOp>;
@@ -309,6 +321,7 @@
OwningRewritePatternList patterns;
bool is_signed = quant_specs_.IsSignedInferenceType();
int bit_width = quant_specs_.GetQuantizationTypeWidth();
+ bool enforce_fixed_output_range = ContainsQuantizeOps(func);
if (is_signed) {
patterns.insert<quant::ConvertUnsignedToSigned<quant::QuantizeCastOp>>(ctx);
// Convert quant stats to int8 quantization parameters.
@@ -327,7 +340,8 @@
// values (tensors).
ApplyQuantizationParamsPropagation(
func, is_signed, disable_per_channel || quant_specs_.disable_per_channel,
- GetOpQuantSpec);
+ GetOpQuantSpec,
+ enforce_fixed_output_range || quant_specs_.post_training_quantization);
ConvertMlirQuantOpsToTFLQuantOps(func);
}