/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Transform pass for LSTMs.
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_PREPARE_QUANTIZE_LSTM
#define TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_PREPARE_QUANTIZE_LSTM
#include <algorithm>
#include <cmath>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/MathExtras.h"
#include "mlir/Dialect/Quant/FakeQuantSupport.h" // from @llvm-project
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/OpDefinition.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/tools/optimize/operator_property.h"
//===----------------------------------------------------------------------===//
// The prepare-quantize Pass for LSTM.
//
namespace mlir {
namespace TFL {
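// 32768 == 2^15, the magnitude of the most negative value representable in a
// signed 16-bit integer; it anchors the power-of-two scales used for 16-bit
// quantization.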
constexpr double power_of_two_scale = 32768.0;
// Same ordering as in //tensorflow/compiler/mlir/lite/ir/tfl_ops.td.
constexpr const char* intermediate_attributes[] = {
"input_to_input_intermediate", "input_to_forget_intermediate",
"input_to_cell_intermediate", "input_to_output_intermediate",
"effective_hidden_scale_intermediate"};
// Calculates the minimum power of two that is not less than the value.
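// For example, PowerOfTwoBound(0.3) == 0.5, PowerOfTwoBound(1.0) == 1.0, and
// PowerOfTwoBound(5.2) == 8.0.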
inline double PowerOfTwoBound(double value) {
return std::pow(2, std::ceil(std::log2(value)));
}
// Returns the element type of LSTM's intermediate tensor designated by the
// index.
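// For example, GetIntermediateElementType(op, 4) returns the quantized
// element type recorded in the effective_hidden_scale_intermediate attribute,
// or nullptr if that attribute is absent.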
template <typename LstmOp>
inline QuantizedType GetIntermediateElementType(LstmOp op, int tensor_index) {
if (tensor_index < 0 || tensor_index > 4) return nullptr;
TypeAttr attr = op->template getAttrOfType<TypeAttr>(
intermediate_attributes[tensor_index]);
if (!attr) {
return nullptr;
}
return QuantizedType::getQuantizedElementType(attr.getValue());
}
namespace operator_property = ::tflite::optimize::operator_property;
using Q = quant::QuantizeCastOp;
using DQ = quant::DequantizeCastOp;
template <typename LstmOp>
LogicalResult GetLstmProperty(
LstmOp op, operator_property::OpVariant* lstm_variant,
operator_property::OperatorProperty* op_property) {
if (llvm::isa<TFL::LSTMOp>(op.getOperation())) {
lstm_variant->op_code = tflite::BuiltinOperator_LSTM;
} else if (llvm::isa<TFL::UnidirectionalSequenceLSTMOp>(op.getOperation())) {
lstm_variant->op_code =
tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM;
} else {
op.emitError("ConvertLstmStatsToQDQs pass only supports LSTMs.");
return failure();
}
lstm_variant->use_projection =
!op.projection_weights().getType().template isa<NoneType>();
lstm_variant->use_peephole =
!op.cell_to_output_weights().getType().template isa<NoneType>();
lstm_variant->use_layer_norm =
!op.forget_layer_norm_coefficients().getType().template isa<NoneType>();
*op_property = operator_property::GetOperatorProperty(*lstm_variant);
// TODO(b/176258587) move this to operator_property.cc if this is needed in
// other components, too.
bool use_cifg =
op.input_to_input_weights().getType().template isa<NoneType>();
if (use_cifg) {
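    // Input-gate tensors, which do not exist in the CIFG variant:
    // input_to_input_weights (1), recurrent_to_input_weights (5),
    // cell_to_input_weights (9), input_gate_bias (12), and
    // input_layer_norm_coefficients (20).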
const absl::flat_hash_set<int> cifg_non_inputs = {1, 5, 9, 12, 20};
const int cifg_non_intermediate = 0;
op_property->inputs.erase(
std::remove_if(
op_property->inputs.begin(), op_property->inputs.end(),
[&](std::pair<int, operator_property::TensorProperty> input) {
return cifg_non_inputs.find(input.first) != cifg_non_inputs.end();
}),
op_property->inputs.end());
op_property->intermediates.erase(
std::remove_if(op_property->intermediates.begin(),
op_property->intermediates.end(),
[&](std::pair<int, operator_property::TensorProperty>
intermediate) {
return intermediate.first == cifg_non_intermediate;
}),
op_property->intermediates.end());
}
return success();
}
template <typename SourceOp>
struct PrepareLstmOutputScale : public OpRewritePattern<SourceOp> {
public:
explicit PrepareLstmOutputScale(MLIRContext* context)
: OpRewritePattern<SourceOp>(context) {}
LogicalResult matchAndRewrite(SourceOp op,
PatternRewriter& rewriter) const override {
operator_property::OpVariant lstm_variant;
operator_property::OperatorProperty lstm_property;
if (failed(GetLstmProperty(op, &lstm_variant, &lstm_property))) {
return failure();
}
if (lstm_property.restrict_scale.size() != 1) {
op.emitError() << "The LSTM's operator property expects exactly one "
<< "restrict scale requirement. Got "
<< lstm_property.restrict_scale.size()
<< " restrict scale requirements.";
return failure();
}
    // Use the same scale for the input/output pair specified in
    // restrict_scale.
const std::vector<int>& tensors = lstm_property.restrict_scale[0];
if (tensors.size() != 2) {
op.emitError(
"Unexpected restricted_scale from operator property."
" Should only have a pair of indices.");
return failure();
}
return processRestrictScale(op, tensors[0], tensors[1], rewriter);
}
private:
  // The LSTM's recurrent input activation and its output are quantized with
  // the combined range of both tensors: the input activation of the very
  // first inference is not reflected in the output, so its range would
  // otherwise not be captured.
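  // For example, if the input stats are [-1.0, 2.0] and the output stats are
  // [-3.0, 1.0], both stats ops are rewritten to carry the merged range
  // [-3.0, 2.0].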
LogicalResult processRestrictScale(SourceOp op, int input_index,
int output_index,
PatternRewriter& rewriter) const {
assert(output_index == 0);
if (!op.getResult().hasOneUse()) {
op.emitError()
<< "output " << output_index
<< " should have only one use, which should be quant.stats.";
return failure();
}
llvm::SmallVector<quant::StatisticsOp, 2> stats_ops = {
llvm::dyn_cast_or_null<quant::StatisticsOp>(
op.getOperand(input_index).getDefiningOp()),
llvm::dyn_cast_or_null<quant::StatisticsOp>(
*op.getResult().getUsers().begin()),
};
if (!stats_ops[0] || !stats_ops[1]) {
return failure(); // Already converted to Q-DQ pair.
}
llvm::SmallVector<llvm::APFloat, 4> min_max_values;
for (auto& stats_op : stats_ops) {
auto values = stats_op.layerStats()
.dyn_cast<DenseFPElementsAttr>()
.getValues<llvm::APFloat>();
min_max_values.insert(min_max_values.end(), values.begin(), values.end());
}
    // min_max_values now holds {input_min, input_max, output_min, output_max}.
    // If the two ranges are already identical, there is nothing to rewrite.
if (min_max_values[0] == min_max_values[2] &&
min_max_values[1] == min_max_values[3]) {
return failure();
}
mlir::ElementsAttr layer_stats = mlir::DenseFPElementsAttr::get(
mlir::RankedTensorType::get({2}, rewriter.getF32Type()),
{llvm::minimum(min_max_values[0], min_max_values[2]),
llvm::maximum(min_max_values[1], min_max_values[3])});
mlir::ElementsAttr axis_stats;
mlir::IntegerAttr axis;
for (auto& stats_op : stats_ops) {
rewriter.setInsertionPointAfter(stats_op);
rewriter.replaceOpWithNewOp<quant::StatisticsOp>(
stats_op, stats_op.arg(), layer_stats, axis_stats, axis);
}
return success();
}
};
// Quantize LSTM according to its quantization recipe.
template <typename SourceOp>
struct ConvertLstmStatsToQDQs : public OpRewritePattern<SourceOp> {
public:
ConvertLstmStatsToQDQs(MLIRContext* context,
const QuantizationSpecs& quant_specs)
: OpRewritePattern<SourceOp>(context, /*benefit=*/2),
quant_specs(quant_specs) {}
LogicalResult matchAndRewrite(SourceOp op,
PatternRewriter& rewriter) const override {
operator_property::OpVariant lstm_variant;
operator_property::OperatorProperty lstm_property;
if (failed(GetLstmProperty(op, &lstm_variant, &lstm_property))) {
return failure();
}
if (failed(processIntermediates(op, lstm_variant, lstm_property)) ||
failed(processInputs(op, lstm_variant, lstm_property, rewriter))) {
return failure();
}
return success();
}
private:
QuantizationSpecs quant_specs;
LogicalResult processIntermediates(
SourceOp op, const operator_property::OpVariant& lstm_variant,
const operator_property::OperatorProperty& lstm_property) const {
for (auto& enumerated_intermediates : lstm_property.intermediates) {
int index = enumerated_intermediates.first;
auto& tensor_property = enumerated_intermediates.second;
// intermediate tensors 0, 1, 2, 3 are only used with layer normalization.
if (!lstm_variant.use_layer_norm && index != 4) {
continue;
}
TypeAttr attr =
op->template getAttrOfType<TypeAttr>(intermediate_attributes[index]);
auto quant_type = GetIntermediateElementType<SourceOp>(op, index);
if (!quant_type) {
// intermediate tensor 4 is optional, unless the LSTM uses projection.
if (index == 4 && !lstm_variant.use_projection) {
return success();
}
op.emitError() << intermediate_attributes[index]
<< " is not quantized.";
return failure();
}
auto calibrated_type =
quant_type.template dyn_cast<quant::CalibratedQuantizedType>();
if (!calibrated_type) {
int num_storage_bits = quant_type.getStorageTypeIntegralWidth();
if (tensor_property.number_of_bits != num_storage_bits) {
op.emitError() << intermediate_attributes[index]
<< " is expected to be quantized with "
<< tensor_property.number_of_bits << " bits, but got "
<< num_storage_bits << " bits instead.";
return failure();
}
continue; // skip if it is already quantized.
}
quant::UniformQuantizedType qtype;
if (tensor_property.number_of_bits == 8) {
qtype = quant::fakeQuantAttrsToType(
op.getLoc(), tensor_property.number_of_bits,
calibrated_type.getMin(), calibrated_type.getMax(),
/*narrowRange=*/false, calibrated_type.getExpressedType(),
/*isSigned=*/quant_specs.IsSignedInferenceType());
} else if (tensor_property.number_of_bits == 16) {
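      // 16-bit intermediates are quantized symmetrically over [-max, max]
      // with narrowRange set, which keeps the zero point at 0.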
double max = std::max(std::abs(calibrated_type.getMin()),
std::abs(calibrated_type.getMax()));
qtype = quant::fakeQuantAttrsToType(
op.getLoc(), tensor_property.number_of_bits, -max, max,
/*narrowRange=*/true, calibrated_type.getExpressedType(),
/*isSigned=*/true);
} else {
op.emitError() << "Unsupported quantization bits: "
<< tensor_property.number_of_bits;
return failure();
}
op->setAttr(intermediate_attributes[index],
TypeAttr::get(qtype.castFromExpressedType(
qtype.castToExpressedType(attr.getValue()))));
}
return success();
}
LogicalResult processInputs(
SourceOp op, const operator_property::OpVariant& lstm_variant,
const operator_property::OperatorProperty& lstm_property,
PatternRewriter& rewriter) const {
for (auto& enumerated_inputs : lstm_property.inputs) {
int index = enumerated_inputs.first;
auto& tensor_property = enumerated_inputs.second;
Value input = op.getOperand(index);
if (input.getDefiningOp() == nullptr) continue;
// TODO(b/172517537): make this work with non-PTQ case.
if (llvm::isa<ConstantOp, TFL::ConstOp>(input.getDefiningOp())) {
        // Tensors with a derived scale are biases; they are handled during
        // scale propagation.
if (tensor_property.use_derived_scale) continue;
if (failed(processConstantOp(op, input.getDefiningOp(), index,
tensor_property, rewriter))) {
return failure();
}
} else {
if (auto stats_op =
llvm::dyn_cast<quant::StatisticsOp>(input.getDefiningOp())) {
if (failed(replaceStatsOp(op, stats_op, index, tensor_property,
rewriter))) {
return failure();
}
        } else if (!llvm::isa<DQ>(input.getDefiningOp()) &&
                   !llvm::isa<SameScalesOpInterface>(input.getDefiningOp())) {
          // A DQ op means the stats op was already converted to a Q-DQ pair;
          // an op with the same-scale requirement means the stats op is not
          // immediately available because it is connected through such ops.
          // Anything else is unsupported.
          // TODO(b/172517537): make this work with non-PTQ case.
          op.emitError() << "Input " << index
                         << " should be from DequantizeCast, Statistics,"
                         << " or ops with same scale requirement.";
          input.getDefiningOp()->emitError();
          return failure();
        }
}
}
return success();
}
// For weights, use quantization scale directly inferred from the values.
//
// input 1~4: input to gate weights
// input 5~8: recurrent to gate weights
// input 9~11: peephole weights, input 16: projection weight
// input 20~23: normalization weights
LogicalResult processConstantOp(
SourceOp op, Operation* const_op, int input_index,
const operator_property::TensorProperty& tensor_property,
PatternRewriter& rewriter) const {
    // Non-float tensors are not weights and do not require quantization.
auto type = const_op->getResult(0).getType().dyn_cast<ShapedType>();
if (!type || !type.getElementType().isa<FloatType>()) return success();
DenseFPElementsAttr attr;
if (!matchPattern(const_op->getResult(0), m_Constant(&attr))) {
const_op->emitError("Not a constant op.");
return failure();
}
UniformQuantizedType quant_type =
quant::GetUniformQuantizedTypeForWeight(
attr, /*symmetric=*/true,
/*num_bits=*/tensor_property.number_of_bits, /*is_signed=*/true,
/*narrow_range=*/true)
.template dyn_cast<quant::UniformQuantizedType>();
if (!quant_type) {
const_op->emitError("Failed to get quantized type");
return failure();
}
// TODO(b/172517537): duplicate the constant when the bias is shared.
Type expressed_type = const_op->getResult(0).getType();
Type cast_type = quant_type.castFromExpressedType(expressed_type);
rewriter.setInsertionPointAfter(const_op);
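    // Rewrite the weight as const -> quant.qcast -> quant.dcast and feed the
    // dequantized value back into the LSTM input.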
auto q = rewriter.create<Q>(const_op->getLoc(), cast_type,
const_op->getResult(0));
auto dq = rewriter.create<DQ>(const_op->getLoc(), expressed_type, q);
op.setOperand(input_index, dq.getResult());
return success();
}
LogicalResult replaceStatsOp(
SourceOp op, quant::StatisticsOp stats_op, int input_index,
const operator_property::TensorProperty& tensor_property,
PatternRewriter& rewriter) const {
if (tensor_property.state_tensor && !stats_op.getResult().hasOneUse()) {
// TODO(b/172517537): check if other tensors should go through this
// check too.
op.emitError() << "Input tensor [" << input_index
<< "] is a state tensor, but has more than one use.";
return failure();
}
auto stats = stats_op.layerStats().dyn_cast<DenseFPElementsAttr>();
if (!stats || stats.getNumElements() != 2) {
stats_op.emitError("Stats should have 2 values.");
return failure();
}
quant::QuantizedType quant_type;
double min = FloatAttr::getValueAsDouble(stats.getValue<APFloat>({0}));
double max = FloatAttr::getValueAsDouble(stats.getValue<APFloat>({1}));
Type expressed = getElementTypeOrSelf(stats_op.getType());
if (tensor_property.extend_to_power_of_two) {
if (tensor_property.number_of_bits != 16) {
op.emitError(
"extended power of 2 scale is only supported for 16-bit"
" quantization.");
return failure();
}
double bound = PowerOfTwoBound(std::max(std::abs(min), std::abs(max)));
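      // With 16 bits, -llvm::minIntN(16) == 32768 == 2^15; since bound is a
      // power of two, the resulting scale bound / 32768 is itself an exact
      // power of two.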
      // QuantizationFlags::Signed marks the storage type as signed.
quant_type = UniformQuantizedType::getChecked(
quant::QuantizationFlags::Signed,
rewriter.getIntegerType(tensor_property.number_of_bits), expressed,
/*scale=*/bound / -llvm::minIntN(tensor_property.number_of_bits),
/*zeroPoint=*/0, llvm::minIntN(tensor_property.number_of_bits),
llvm::maxIntN(tensor_property.number_of_bits), op.getLoc());
} else {
quant_type = quant::fakeQuantAttrsToType(
op.getLoc(), tensor_property.number_of_bits, min, max,
/*narrowRange=*/false, expressed,
/*isSigned=*/true);
}
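    // Replace the stats op with a quant.qcast/quant.dcast pair carrying the
    // chosen quantized type.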
rewriter.setInsertionPointAfter(stats_op);
Type result_type = quant_type.castFromExpressedType(stats_op.getType());
auto q = rewriter.create<Q>(stats_op.getLoc(), result_type, stats_op.arg());
rewriter.replaceOpWithNewOp<DQ>(stats_op, stats_op.getType(), q);
return success();
}
};
// Returns a function that returns the quantized type of a bias input.
// The scale of bias is a multiplication of given scale and scales from the
// quantization type of other operands.
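// For example, if bias propagation yields a uniform type with scale s and the
// derived-scale recipe contributes an extra factor f, the returned type has
// scale s * f.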
inline quant::AccumulatorScaleFunc GetUniformQuantizedTypeForBiasWithScale(
double scale) {
return [=](const std::vector<quant::QuantParams>& quant_params)
-> quant::QuantParams {
if (auto qtype = GetUniformQuantizedTypeForBias(quant_params)
.dyn_cast_or_null<UniformQuantizedType>()) {
return quant::UniformQuantizedType::get(
qtype.getFlags(), qtype.getStorageType(), qtype.getExpressedType(),
qtype.getScale() * scale, qtype.getZeroPoint(),
qtype.getStorageTypeMin(), qtype.getStorageTypeMax());
}
return {};
};
}
// Returns quantization spec for LSTMs based on their operator properties.
template <typename LstmOp>
std::unique_ptr<quant::OpQuantSpec> GetLstmOpQuantSpec(LstmOp op) {
operator_property::OpVariant lstm_variant;
operator_property::OperatorProperty lstm_property;
if (failed(GetLstmProperty(op, &lstm_variant, &lstm_property))) {
return nullptr;
}
auto spec = absl::make_unique<quant::OpQuantSpec>();
for (const auto& enumerated_inputs : lstm_property.inputs) {
int index = enumerated_inputs.first;
auto& tensor_property = enumerated_inputs.second;
if (tensor_property.use_derived_scale) {
double scale = 1.0;
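      // The derived bias scale is the product of the scales of the recipe's
      // intermediate tensors and its constant factors.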
for (int tensor_index :
tensor_property.derived_scale.intermediate_tensors) {
auto quant_type = GetIntermediateElementType<LstmOp>(op, tensor_index);
if (!quant_type ||
!quant_type.template isa<quant::UniformQuantizedType>()) {
op->emitError() << "While processing derived scale, intermediate "
<< intermediate_attributes[tensor_index]
<< " is not quantized.";
return nullptr;
}
scale *= quant_type.template dyn_cast<quant::UniformQuantizedType>()
.getScale();
}
for (float factor : tensor_property.derived_scale.factors) {
scale *= factor;
}
spec->biases_params.emplace(
index,
std::make_pair(tensor_property.derived_scale.input_tensors,
GetUniformQuantizedTypeForBiasWithScale(scale)));
}
}
return spec;
}
} // namespace TFL
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_PREPARE_QUANTIZE_LSTM