| /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| // This transformation pass applies quantization propagation on TFLite dialect. |
| #include <iterator> |
| #include <string> |
| |
| #include "absl/memory/memory.h" |
| #include "llvm/ADT/STLExtras.h" |
| #include "llvm/ADT/SmallVector.h" |
| #include "llvm/ADT/StringRef.h" |
| #include "llvm/Support/CommandLine.h" |
| #include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project |
| #include "mlir/IR/Function.h" // from @llvm-project |
| #include "mlir/IR/MLIRContext.h" // from @llvm-project |
| #include "mlir/IR/PatternMatch.h" // from @llvm-project |
| #include "mlir/IR/Value.h" // from @llvm-project |
| #include "mlir/Pass/Pass.h" // from @llvm-project |
| #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" |
| #include "tensorflow/compiler/mlir/lite/quantization/lite/tfl_to_std.h" |
| #include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h" |
| #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" |
| #include "tensorflow/compiler/mlir/lite/transforms/passes.h" |
| #include "tensorflow/core/framework/types.pb.h" |
| |
// Test-only flag: names of functions that should be quantized even when they
// are not the target function named in the QuantizationSpecs.
// NOLINTNEXTLINE
static llvm::cl::list<std::string> quantize_allowlist(
    "tfl-test-quantize-allowlist", llvm::cl::value_desc("list"),
    llvm::cl::desc("comma separated list of allowlisted functions to be "
                   "quantized. Only used in tests"),
    llvm::cl::CommaSeparated);

// Test-only flag: selects signed (QINT8) vs. unsigned (QUINT8) inference type
// in the default-constructed pass (see PrepareQuantizePass()).
// NOLINTNEXTLINE
static llvm::cl::opt<bool> quantize_signed(
    "tfl-test-quantize-signed", llvm::cl::value_desc("bool"),
    llvm::cl::desc("signed inference type. Only used in tests"),
    llvm::cl::init(false));

// When set, forces per-tensor quantization of weights instead of per-channel.
// Combined (OR) with QuantizationSpecs::disable_per_channel in runOnFunction.
// NOLINTNEXTLINE
static llvm::cl::opt<bool> disable_per_channel(
    "tfl-disable-per-channel", llvm::cl::value_desc("bool"),
    llvm::cl::desc("Whether disable per-channel quantized weights."),
    llvm::cl::init(false));
| |
| //===----------------------------------------------------------------------===// |
| // The prepare-quantize Pass. |
| // |
| namespace mlir { |
| namespace TFL { |
| |
| namespace { |
| |
| // Applies prepare quantization on the model in TFL dialect. This pass runs |
| // before the quantization pass and propagate the quantization parameters |
| // across ops. This step is necessary for post-training quantization and also |
| // making the quantization rule for some operations in the quantization-aware |
| // training quantization simpler. |
// Applies prepare quantization on the model in TFL dialect. This pass runs
// before the quantization pass and propagates the quantization parameters
// across ops. This step is necessary for post-training quantization and also
// makes the quantization rule for some operations in the quantization-aware
// training quantization simpler.
class PrepareQuantizePass
    : public PassWrapper<PrepareQuantizePass, FunctionPass> {
 public:
  // Constructor used by the PassRegistration. The inference type (signed
  // QINT8 vs. unsigned QUINT8) comes from the -tfl-test-quantize-signed flag.
  // This is only used by test.
  explicit PrepareQuantizePass() {
    if (quantize_signed)
      quant_specs_.inference_type = tensorflow::DT_QINT8;
    else
      quant_specs_.inference_type = tensorflow::DT_QUINT8;
  }

  // Constructor used by manually creating the pass.
  explicit PrepareQuantizePass(const QuantizationSpecs& quant_specs)
      : quant_specs_(quant_specs) {}

  void runOnFunction() override;

 private:
  // Set the quantization parameters of the input nodes. These parameters are
  // converted from the user specified input value ranges. The input nodes with
  // non-float tensor types will be skipped because they are not quantizable.
  // Return true if number of input nodes doesn't equal to that of the input
  // ranges.
  bool SetInputNodesQuantizationParams(FuncOp func);

  // The function might contain more stats ops than required, and it will
  // introduce requantize if the calibration stats have conflicts. This method
  // tries to remove all the redundant stats ops.
  bool RemoveRedundantStats(FuncOp func);

  // Verify the quantization specification is expected for quantizing the
  // current function: for the target function the number of arguments must
  // match the number of user-provided input ranges. Other functions are
  // always considered legal.
  bool IsLegalQuantSpecs(FuncOp func) {
    if (func.getName() == quant_specs_.target_func) {
      return func.getNumArguments() == quant_specs_.input_ranges.size();
    }
    return true;
  }

  // Get the min and max values from the quantization specification for the
  // current function and argument index. For any function other than the
  // target function (e.g. one matched via `quantize_allowlist`), the default
  // uint8 range [0, 255] is used instead.
  std::pair<llvm::Optional<double>, llvm::Optional<double>>
  GetMinMaxValuesForArgument(llvm::StringRef func_name, int index) {
    if (func_name == quant_specs_.target_func) {
      return quant_specs_.input_ranges[index];
    } else {
      return {0.0, 255.0};
    }
  }

  // Apply some sanity check and report some warnings for those who don't
  // follow the best quantization practice. This also fixes some simple
  // violations.
  void SanityCheckAndAdjustment(FuncOp func);

  // Whether the func contains Quantize ops. This is used to determine whether
  // to use the quantization parameters from the fixed output range property.
  bool ContainsQuantizeOps(FuncOp func);

  QuantizationSpecs quant_specs_;
};
| |
// Inserts a quant.qcast/quant.dcast pair on every float function argument,
// using the min/max ranges from the quantization spec (or the default uint8
// range for allowlisted functions). Returns true — signalling the pass to
// abort — when the spec is inconsistent with the function signature; returns
// false otherwise, including when the function is simply skipped.
bool PrepareQuantizePass::SetInputNodesQuantizationParams(FuncOp func) {
  StringRef func_name = func.getName();
  auto& target_func = quant_specs_.target_func;

  // Skip this function because it isn't the target function from the spec or
  // in the function allowlist.
  if (target_func != func_name &&
      !llvm::is_contained(quantize_allowlist, func_name)) {
    return false;
  }

  // If the validation fails, the pass should stop immediately.
  if (!IsLegalQuantSpecs(func)) {
    return true;
  }

  OpBuilder builder(func);
  bool is_signed = quant_specs_.IsSignedInferenceType();
  IntegerAttr num_bits =
      builder.getI32IntegerAttr(quant_specs_.GetQuantizationTypeWidth());
  // Only activations are annotated here, so narrow_range is always false.
  BoolAttr narrow_range = builder.getBoolAttr(false);

  // Emits qcast+dcast for one argument `arg` (the i-th argument) of shaped
  // float type, at `insertion_point` in `block`. Non-float or unranked-scalar
  // arguments are left untouched.
  auto add_quantize_op = [&](Location loc, Type input_type, Block* block,
                             Block::iterator insertion_point, Value arg,
                             int i) {
    if (auto shaped = input_type.dyn_cast<ShapedType>()) {
      if (shaped.getElementType().isa<FloatType>()) {
        // If there are existing quantize ops, they are from training and we
        // should respect them.
        if (arg.hasOneUse() &&
            llvm::isa<quant::QuantizeCastOp>(*arg.user_begin())) {
          return;
        }

        auto min_max = GetMinMaxValuesForArgument(func_name, i);
        // If the input min/max or mean/std are not specified, then skip.
        if (!min_max.first.hasValue() || !min_max.second.hasValue()) return;

        TypeAttr params = quant::GetQuantizedTypeAttr(
            builder, input_type,
            builder.getF64FloatAttr(min_max.first.getValue()),
            builder.getF64FloatAttr(min_max.second.getValue()),
            /*quant_dim=*/-1, num_bits, narrow_range, is_signed);
        builder.setInsertionPoint(block, insertion_point);
        auto q_op =
            builder.create<quant::QuantizeCastOp>(loc, params.getValue(), arg);
        auto dq_op = builder.create<quant::DequantizeCastOp>(loc, input_type,
                                                             q_op.getResult());
        // Rewire arg -> qcast -> dcast -> original users. The qcast operand is
        // set *after* replaceAllUsesWith so the qcast itself keeps reading the
        // raw argument rather than its own dequantized output.
        arg.replaceAllUsesWith(dq_op.getResult());
        q_op.setOperand(arg);
      }
    }
  };

  for (int i = 0, e = func.getNumArguments(); i != e; ++i) {
    BlockArgument arg = func.getArgument(i);
    auto* arg_block = arg.getOwner();
    add_quantize_op(arg.getLoc(), arg.getType(), arg_block,
                    std::next(arg_block->begin(), i), arg, i);
  }

  return false;
}
| |
| #include "tensorflow/compiler/mlir/lite/utils/generated_op_quant_spec_getters.inc" |
| |
| bool PrepareQuantizePass::RemoveRedundantStats(FuncOp func) { |
| return RemoveRedundantStatsOps(func, GetOpQuantSpec); |
| } |
| |
| static Value Quantized(Operation* user) { |
| if (auto q = llvm::dyn_cast_or_null<quant::QuantizeCastOp>(user)) { |
| if (auto dq = llvm::dyn_cast_or_null<quant::DequantizeCastOp>( |
| *q.getResult().user_begin())) { |
| return dq.getResult(); |
| } |
| } |
| return {}; |
| } |
| |
// Applies sanity checks and small fix-ups on `func` before the quantization
// parameters are propagated: re-routes returned values through existing
// quantize/dequantize pairs, and emits warnings for patterns that will hurt
// quantization quality.
void PrepareQuantizePass::SanityCheckAndAdjustment(FuncOp func) {
  // If an op output has two users: one of them is a quantize op and another
  // one is returned directly, we decide to return the quantized result instead,
  // so this op can be quantized. This is only applied on the returned result
  // because the error will not be accumulated.

  func.walk([&](ReturnOp ret) {
    int i = 0;
    for (Value returned : ret.operands()) {
      llvm::SmallVector<Value, 4> quantized;
      for (auto user : returned.getUsers()) {
        if (auto q = Quantized(user)) {
          quantized.push_back(q);
        }
      }
      // Only rewrite when exactly one quantized alias exists, so the choice
      // of quantization parameters is unambiguous.
      if (quantized.size() == 1) {
        ret.setOperand(i, quantized.front());
      }
      i++;
    }
  });

  // We prefer placing quantization emulation ops on the results of the
  // concat ops; warn when a concat output is not (solely) consumed by one.
  func.walk([&](ConcatenationOp concat) {
    if (concat.output().hasOneUse() &&
        Quantized(*concat.output().user_begin())) {
      return;
    }
    concat.emitWarning(
        "Missing quantization parameter on the output might introduce "
        "quantization error!");
  });

  // Check for (Quant (Dequant $in), $qA) "qdq" pairs that couldn't be
  // eliminated at this point. This only occurs for the pattern
  //   (Quant (Dequant (Quant $in, $qB)), $qA)   with $qB != $qA,
  // where the qdq pair denotes a non-trivial requantization of an already
  // quantized value. Since this makes little sense (directly quantizing
  // (Quant $in, $qA) would introduce less quantization noise) the likely
  // cause is a minor error in constructing the original network model that
  // introduced back-to-back Fake Quantization operations. Hence: emit a
  // warning. N.b. at this point we're (temporarily) in the quantization
  // dialect (presumably to enable re-use in xla etc.), so it is
  // quant::*QuantizeCastOp we're matching here.
  //
  func.walk([&](quant::QuantizeCastOp q_op) {
    // Only interested in quantizes whose operand is produced by a dequantize.
    auto dq_op = dyn_cast_or_null<quant::DequantizeCastOp>(
        q_op.getOperand().getDefiningOp());
    if (!dq_op) {
      return;
    }
    auto dq_arg = dq_op.getOperand();

    if (!dq_arg.hasOneUse()) {
      // The initial quantization is used someplace else ... so it might be
      // reasonable for it to be requantized for another purpose.
      // TODO: ideally we would still check whether the requantization narrows
      // rather than widens the representation.
      return;
    }

    // Invariant:
    //   isa<quant::QuantizeCastOp>(dq_arg.getDefiningOp()) -->
    //   dq_arg.getType() != q_op.getResult().getType()
    //
    // as otherwise the qdq pair would have been optimized away.
    auto qd_arg_def_q_op =
        dyn_cast_or_null<quant::QuantizeCastOp>(dq_arg.getDefiningOp());
    if (!qd_arg_def_q_op) {
      return;
    }

    qd_arg_def_q_op.emitWarning()
        << " quantizer's output has another quantizer (" << q_op.getLoc()
        << ") as consumer - intentional?";
  });
}
| |
| bool PrepareQuantizePass::ContainsQuantizeOps(FuncOp func) { |
| for (const auto& op : func.getOps()) { |
| if (llvm::isa<quant::DequantizeCastOp>(op)) return true; |
| } |
| return false; |
| } |
| |
// Pattern that rewrites quant.stats ops into quant.qcast/quant.dcast pairs
// carrying quantization parameters derived from the recorded statistics.
using PrepareQuantStats =
    quant::ConvertStatsToQDQs<quant::QuantizeCastOp, quant::DequantizeCastOp>;
| |
void PrepareQuantizePass::runOnFunction() {
  FuncOp func = getFunction();
  MLIRContext* ctx = func.getContext();
  // Temporarily convert TFL quantization ops to the quant dialect so the
  // generic quantization utilities below can operate on them; converted back
  // at the end of this function.
  ConvertTFLQuantOpsToMlirQuantOps(func);

  if (quant_specs_.post_training_quantization) {
    RemoveRedundantStats(func);
  } else {
    // Set the quantization parameters for the quantizable input nodes. If this
    // failed, return the function immediately. This is only required for
    // quantization aware training model conversion.
    if (SetInputNodesQuantizationParams(func)) {
      return;
    }
  }

  // During the legalization, unsigned quantized type is used, so we have to
  // convert all of them to signed.
  OwningRewritePatternList patterns;
  bool is_signed = quant_specs_.IsSignedInferenceType();
  int bit_width = quant_specs_.GetQuantizationTypeWidth();
  // Existing dequantize ops indicate quantization parameters were already
  // inserted, so fixed output ranges should be enforced during propagation.
  bool enforce_fixed_output_range = ContainsQuantizeOps(func);
  if (is_signed) {
    patterns.insert<quant::ConvertUnsignedToSigned<quant::QuantizeCastOp>>(ctx);
    // Convert quant stats to int8 quantization parameters.
    // Currently, only activation stats are imported, so narrow_range = false.
    patterns.insert<PrepareQuantStats>(bit_width, false, true, ctx);
  } else {
    // Convert quant stats to uint8 quantization parameters.
    // Currently, only activation stats are imported, so narrow_range = false.
    patterns.insert<PrepareQuantStats>(bit_width, false, false, ctx);
  }
  applyPatternsAndFoldGreedily(func, patterns);

  SanityCheckAndAdjustment(func);

  // Finally, the quantization parameters can be propagated to the rest of the
  // values (tensors).
  ApplyQuantizationParamsPropagation(
      func, is_signed, disable_per_channel || quant_specs_.disable_per_channel,
      GetOpQuantSpec,
      enforce_fixed_output_range || quant_specs_.post_training_quantization);

  // Restore TFL dialect quantization ops for downstream passes.
  ConvertMlirQuantOpsToTFLQuantOps(func);
}
| |
| } // namespace |
| |
| // Creates an instance of the TensorFlow Lite dialect PrepareQuantize pass. |
| std::unique_ptr<OperationPass<FuncOp>> CreatePrepareQuantizePass( |
| const QuantizationSpecs& quant_specs) { |
| return std::make_unique<PrepareQuantizePass>(quant_specs); |
| } |
| |
// Registers the pass under -tfl-prepare-quantize so it can be run from
// command-line tools (e.g. mlir-opt-style test drivers).
static PassRegistration<PrepareQuantizePass> pass(
    "tfl-prepare-quantize", "Prepare TFL dialect for quantization");
| |
| } // namespace TFL |
| } // namespace mlir |