/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <string>
#include <utility>
#include <vector>
#include "absl/strings/string_view.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/Location.h" // from @llvm-project
#include "mlir/IR/OperationSupport.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/Verifier.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h"
#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/util.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
namespace mlir {
namespace quant {
namespace {
constexpr char kQuantizeFuncName[] = "quantize_i8";
constexpr char kDequantizeFuncName[] = "dequantize_i8";
constexpr char kAttrMapAttribute[] = "attr_map";
class QuantizeCompositeFunctionsPass
: public mlir::PassWrapper<QuantizeCompositeFunctionsPass,
OperationPass<ModuleOp>> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(QuantizeCompositeFunctionsPass)
explicit QuantizeCompositeFunctionsPass() {}
explicit QuantizeCompositeFunctionsPass(
QuantizationMethod quantization_method)
: quantization_method_(quantization_method) {}
StringRef getArgument() const final {
// This is the argument used to refer to the pass in
// the textual format (on the command line, for example).
return "quant-quantize-composite-functions";
}
StringRef getDescription() const final {
// This is a brief description of the pass.
return "Quantize composite functions with QDQ input/outputs.";
}
void getDependentDialects(DialectRegistry& registry) const override {
registry.insert<TF::TensorFlowDialect, QuantizationDialect>();
}
private:
void runOnOperation() override;
QuantizationMethod quantization_method_ =
QuantizationMethod::kQuantizationAwareTraining;
};
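// Creates the scalar scale and zero point constants for a per-tensor
// UniformQuantizedType. Schematically (parameter values illustrative), for an
// element type !quant.uniform<i8:f32, 0.047:-128> this materializes roughly:
//   %scale = "tf.Const"() {value = dense<4.700000e-02> : tensor<f32>}
//   %zero_point = "tf.Const"() {value = dense<-128> : tensor<i32>}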
LogicalResult CreateUniformQuantizedTypeParams(UniformQuantizedType qtype,
Location loc,
PatternRewriter& rewriter,
Value& scale,
Value& zero_point) {
TensorType scale_type = RankedTensorType::get({}, rewriter.getF32Type());
TensorType zero_point_type = scale_type.clone(rewriter.getI32Type());
scale = rewriter.create<TF::ConstOp>(
loc, scale_type,
DenseFPElementsAttr::get(scale_type,
{static_cast<float>(qtype.getScale())}));
zero_point = rewriter.create<TF::ConstOp>(
loc, zero_point_type,
DenseIntElementsAttr::get(zero_point_type,
{static_cast<int32_t>(qtype.getZeroPoint())}));
return success(scale && zero_point);
}
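// Creates the per-channel scale and zero point constants for a
// UniformQuantizedPerAxisType. Schematically (values illustrative), for two
// channels this materializes roughly:
//   %scale = "tf.Const"() {value = dense<[0.01, 0.02]> : tensor<2xf32>}
//   %zero_point = "tf.Const"() {value = dense<[0, 0]> : tensor<2xi32>}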
LogicalResult CreateUniformQuantizedPerAxisTypeParams(
UniformQuantizedPerAxisType qtype, Location loc, PatternRewriter& rewriter,
Value& scale, Value& zero_point) {
// The consuming op should already know the quantized channel information, so
// it is not passed during the conversion. This design may change if needed.
ArrayRef<double> scales = qtype.getScales();
ArrayRef<int64_t> zero_points = qtype.getZeroPoints();
const int num_channels = scales.size();
TensorType scale_type = RankedTensorType::get(
{static_cast<int64_t>(num_channels)}, rewriter.getF32Type());
TensorType zero_point_type = scale_type.clone(rewriter.getI32Type());
llvm::SmallVector<float, 4> float_scales;
llvm::SmallVector<int32_t, 4> int32_zero_points;
float_scales.reserve(num_channels);
int32_zero_points.reserve(num_channels);
for (int i = 0; i < num_channels; ++i) {
float_scales.push_back(scales[i]);
int32_zero_points.push_back(zero_points[i]);
}
scale = rewriter.create<TF::ConstOp>(
loc, scale_type, DenseFPElementsAttr::get(scale_type, float_scales));
zero_point = rewriter.create<TF::ConstOp>(
loc, zero_point_type,
DenseIntElementsAttr::get(zero_point_type, int32_zero_points));
return success(scale && zero_point);
}
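// Dispatches to one of the helpers above depending on whether the quantized
// element type is per-tensor or per-axis; fails for any other quantized type.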
LogicalResult CreateQuantizationParams(QuantizedType elem_type, Location loc,
PatternRewriter& rewriter, Value& scale,
Value& zero_point) {
if (!elem_type) {
return failure();
}
if (auto qtype = elem_type.dyn_cast<UniformQuantizedType>()) {
return CreateUniformQuantizedTypeParams(qtype, loc, rewriter, scale,
zero_point);
} else if (auto qtype = elem_type.dyn_cast<UniformQuantizedPerAxisType>()) {
return CreateUniformQuantizedPerAxisTypeParams(qtype, loc, rewriter, scale,
zero_point);
}
return failure();
}
// Replaces a quant.qcast op with a call to the composite quantize_i8 function.
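// Schematically (types and quantization parameters illustrative), the rewrite
// is:
//   %q = "quant.qcast"(%arg)
//       : (tensor<*xf32>) -> tensor<*x!quant.uniform<i8:f32, 0.047:-128>>
// ->
//   %call = "tf.PartitionedCall"(%arg, %scale, %zero_point)
//       {f = @quantize_i8} : (...) -> tensor<*xi8>
//   %q = "quant.scast"(%call)
//       : (tensor<*xi8>) -> tensor<*x!quant.uniform<i8:f32, 0.047:-128>>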
class ReplaceQuantizePattern : public mlir::OpRewritePattern<QuantizeCastOp> {
public:
explicit ReplaceQuantizePattern(MLIRContext* context)
: OpRewritePattern<QuantizeCastOp>(context) {}
private:
LogicalResult matchAndRewrite(QuantizeCastOp q_op,
PatternRewriter& rewriter) const override {
auto output_type = q_op.getType().cast<TensorType>();
auto elem_type = output_type.getElementType().dyn_cast<QuantizedType>();
const Location loc = q_op->getLoc();
Value scale, zero_point;
if (failed(CreateQuantizationParams(elem_type, loc, rewriter, scale,
zero_point))) {
return failure();
}
SmallVector<Type> output_types = {
output_type.clone(elem_type.getStorageType())};
SmallVector<Value> args = {q_op.arg(), scale, zero_point};
FlatSymbolRefAttr func_name =
FlatSymbolRefAttr::get(rewriter.getStringAttr(kQuantizeFuncName));
auto quantize_call = rewriter.create<TF::PartitionedCallOp>(
loc, output_types, args, func_name,
/*config=*/"", /*config_proto=*/"", /*executor_type=*/"");
auto scast_op = rewriter.create<quant::StorageCastOp>(
loc, output_type, quantize_call->getResult(0));
q_op->replaceAllUsesWith(scast_op);
return success();
}
};
// Replaces a quant.dcast op with a call to the composite dequantize_i8
// function.
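// Schematically (types illustrative), the rewrite is:
//   %dq = "quant.dcast"(%arg)
//       : (tensor<*x!quant.uniform<i8:f32, 0.047:-128>>) -> tensor<*xf32>
// ->
//   %scast = "quant.scast"(%arg)
//       : (tensor<*x!quant.uniform<i8:f32, 0.047:-128>>) -> tensor<*xi8>
//   %dq = "tf.PartitionedCall"(%scast, %scale, %zero_point)
//       {f = @dequantize_i8} : (...) -> tensor<*xf32>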
class ReplaceDequantizePattern
: public mlir::OpRewritePattern<DequantizeCastOp> {
public:
explicit ReplaceDequantizePattern(MLIRContext* context)
: OpRewritePattern<DequantizeCastOp>(context) {}
private:
LogicalResult matchAndRewrite(DequantizeCastOp dq_op,
PatternRewriter& rewriter) const override {
auto input_type = dq_op.arg().getType().cast<TensorType>();
auto elem_type = input_type.getElementType().dyn_cast<QuantizedType>();
const Location loc = dq_op->getLoc();
Value scale, zero_point;
if (failed(CreateQuantizationParams(elem_type, loc, rewriter, scale,
zero_point))) {
return failure();
}
TensorType output_type = input_type.clone(elem_type.getStorageType());
auto scast_op =
rewriter.create<quant::StorageCastOp>(loc, output_type, dq_op.arg());
FlatSymbolRefAttr func_name =
FlatSymbolRefAttr::get(rewriter.getStringAttr(kDequantizeFuncName));
SmallVector<Value> args = {scast_op->getResult(0), scale, zero_point};
auto dequantize_call = rewriter.create<TF::PartitionedCallOp>(
loc, dq_op.getResult().getType(), args, func_name,
/*config=*/"", /*config_proto=*/"", /*executor_type=*/"");
dq_op->replaceAllUsesWith(dequantize_call);
return success();
}
};
// Returns true if all float tensor inputs/outputs of the call op have been
// quantized, i.e. the op has at least one quantized tensor type and no
// remaining float tensor types.
bool IsQuantizedCall(TF::PartitionedCallOp call_op) {
bool has_quantized_types = false;
for (Value input : call_op.args()) {
if (auto type = input.getType().dyn_cast<TensorType>()) {
if (type.getElementType().isa<FloatType>()) {
return false;
}
if (type.getElementType().isa<QuantizedType>()) {
has_quantized_types = true;
}
}
}
for (Value output : call_op.output()) {
if (auto type = output.getType().dyn_cast<TensorType>()) {
if (type.getElementType().isa<FloatType>()) {
return false;
}
if (type.getElementType().isa<QuantizedType>()) {
has_quantized_types = true;
}
}
}
return has_quantized_types;
}
// Transfers the attributes of the corresponding ops from the float function to
// the quantized function using the attr_map attribute. In the quantized
// function, this map (map1) is in {attr_name_1: attr_identifier} format; in
// the float function, this map (map2) is in {attr_identifier: attr_name_2}
// format. The attribute identifiers must match between the two maps:
// attr_name_1 is the name of the attribute that needs to be set in the
// quantized function, and attr_name_2 is the name of the attribute in the
// float function that corresponds to the same identifier.
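// For example (attribute names hypothetical), if an op in the quantized
// function carries attr_map = "strides:0,dilations:1" and the matching op in
// the float function carries attr_map = "0:strides,1:dilations", then the
// "strides" and "dilations" attributes of the quantized op are populated from
// the float op's attributes registered under identifiers "0" and "1".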
LogicalResult TransferAttributes(func::FuncOp float_func,
func::FuncOp quantized_func) {
// A map to find an attribute from its identifier.
llvm::StringMap<Attribute> identifier_to_attr;
for (Operation& inner_op : float_func.getBody().front().getOperations()) {
if (!inner_op.hasAttr(kAttrMapAttribute)) continue;
std::string attr_map_str =
inner_op.getAttrOfType<StringAttr>(kAttrMapAttribute).str();
for (absl::string_view element_str : absl::StrSplit(attr_map_str, ',')) {
std::vector<absl::string_view> key_and_value_pair =
absl::StrSplit(element_str, ':');
if (key_and_value_pair.size() != 2) {
float_func.emitError("The attr_map attribute is malformed");
return failure();
}
identifier_to_attr.insert(
{llvm::StringRef(std::string(key_and_value_pair[0])),
inner_op.getAttr(
llvm::StringRef(std::string(key_and_value_pair[1])))});
}
}
// Set the attributes for ops with the attr_map attribute.
for (Operation& inner_op : quantized_func.getBody().front().getOperations()) {
if (!inner_op.hasAttr(kAttrMapAttribute)) continue;
std::string attr_map_str =
inner_op.getAttrOfType<StringAttr>(kAttrMapAttribute).str();
for (absl::string_view element_str : absl::StrSplit(attr_map_str, ',')) {
std::vector<absl::string_view> key_and_value_pair =
absl::StrSplit(element_str, ':');
if (key_and_value_pair.size() != 2) {
float_func.emitError("The attr_map attribute is malformed");
return failure();
}
if (identifier_to_attr.count(
llvm::StringRef(std::string(key_and_value_pair[1]))) == 0) {
float_func.emitWarning(absl::StrCat("Using the default value for the '",
key_and_value_pair[0],
"' attribute"));
continue;
}
inner_op.setAttr(llvm::StringRef(std::string(key_and_value_pair[0])),
identifier_to_attr[llvm::StringRef(
std::string(key_and_value_pair[1]))]);
}
inner_op.removeAttr(kAttrMapAttribute);
}
return success();
}
// Rewrites a PartitionedCall op with quantized input/output types, as created
// by the QuantizePass, into a call to the corresponding quantized function,
// unwrapping the quantization parameters (scales and zero points) into
// explicit arguments.
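// Schematically (function names, types, and argument order illustrative), a
// call such as:
//   %out = "tf.PartitionedCall"(%q_arg) {f = @fused_matmul_fn_1, ...}
//       : (tensor<*x!quant.uniform<...>>) -> tensor<*x!quant.uniform<...>>
// is rewritten to:
//   %arg_i8 = "quant.scast"(%q_arg) : (...) -> tensor<*xi8>
//   %out_i8 = "tf.PartitionedCall"(%arg_i8, %arg_scale, %arg_zp, %out_scale,
//       %out_zp) {f = @quantized_matmul_fn_0} : (...) -> tensor<*xi8>
//   %out = "quant.scast"(%out_i8) : (tensor<*xi8>) -> tensor<*x!quant...>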
class QuantizeFunctionPattern
: public mlir::OpRewritePattern<TF::PartitionedCallOp> {
public:
explicit QuantizeFunctionPattern(MLIRContext* context)
: OpRewritePattern<TF::PartitionedCallOp>(context) {}
private:
LogicalResult matchAndRewrite(TF::PartitionedCallOp call_op,
PatternRewriter& rewriter) const override {
auto f_attr = call_op.fAttr().dyn_cast<FlatSymbolRefAttr>();
// removeAttr will return nullptr if no attribute was removed.
if (!call_op->removeAttr(kQuantTraitAttrName) || !f_attr) {
return failure();
}
if (!f_attr.getValue().startswith("fused_") || !IsQuantizedCall(call_op)) {
return failure();
}
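// Derive the target function name by stripping the "fused_" prefix and the
// trailing unique suffix, e.g. (illustrative) "fused_matmul_fn_1" ->
// "quantized_matmul_fn".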
llvm::Twine quantized_function_name = llvm::Twine(
"quantized_", f_attr.getValue().substr(6).rsplit('_').first);
SmallVector<Value, 4> args;
SmallVector<Value, 4> qparam_args;
SmallVector<Type, 4> result_types;
for (Value arg : call_op.args()) {
if (auto arg_type = arg.getType().dyn_cast<TensorType>()) {
QuantizedType qtype =
arg_type.getElementType().dyn_cast<QuantizedType>();
if (qtype &&
!qtype.isa<UniformQuantizedType, UniformQuantizedPerAxisType>()) {
return failure();
}
}
}
for (Value result : call_op->getResults()) {
if (auto result_type = result.getType().dyn_cast<TensorType>()) {
QuantizedType qtype =
result_type.getElementType().dyn_cast<QuantizedType>();
if (qtype &&
!qtype.isa<UniformQuantizedType, UniformQuantizedPerAxisType>()) {
return failure();
}
}
}
rewriter.setInsertionPoint(call_op);
for (Value arg : call_op.args()) {
TensorType arg_type = arg.getType().dyn_cast<TensorType>();
if (!arg_type) {
args.push_back(arg);
continue;
}
QuantizedType qtype = arg_type.getElementType().dyn_cast<QuantizedType>();
if (!qtype) {
args.push_back(arg);
continue;
}
Value scale, zero_point;
if (failed(CreateQuantizationParams(qtype, arg.getLoc(), rewriter, scale,
zero_point))) {
// As the quantized types are already checked, this is unexpected.
call_op->emitError(
"Failed to create quantization parameter for an argument.");
return failure();
}
auto scast_op = rewriter.create<StorageCastOp>(
arg.getLoc(), arg_type.clone(qtype.getStorageType()), arg);
args.push_back(scast_op.getResult());
qparam_args.push_back(scale);
qparam_args.push_back(zero_point);
}
DenseMap<Value, StorageCastOp> replace_map;
rewriter.setInsertionPointAfter(call_op);
for (Value result : call_op->getResults()) {
TensorType result_type = result.getType().dyn_cast<TensorType>();
if (!result_type) {
result_types.push_back(result.getType());
continue;
}
QuantizedType qtype =
result_type.getElementType().dyn_cast<QuantizedType>();
if (!qtype) {
result_types.push_back(result_type);
continue;
}
Value scale, zero_point;
if (failed(CreateQuantizationParams(qtype, result.getLoc(), rewriter,
scale, zero_point))) {
// As the quantized types are already checked, this is unexpected.
call_op->emitError(
"Failed to create quantization parameter for a result.");
return failure();
}
auto scast_op =
rewriter.create<StorageCastOp>(call_op.getLoc(), result_type, result);
replace_map.insert(std::make_pair(result, scast_op));
result_types.push_back(result_type.clone(qtype.getStorageType()));
qparam_args.push_back(scale);
qparam_args.push_back(zero_point);
}
for (auto replace_pair : replace_map) {
Value result = replace_pair.first;
StorageCastOp scast_op = replace_pair.second;
result.replaceAllUsesExcept(scast_op, scast_op);
}
args.insert(args.end(), qparam_args.begin(), qparam_args.end());
// Make a copy of the quantized function.
auto module = call_op->getParentOfType<ModuleOp>();
SymbolTable symbol_table(module);
func::FuncOp float_func =
dyn_cast<func::FuncOp>(symbol_table.lookup(f_attr.getValue()));
func::FuncOp quantized_func = dyn_cast<func::FuncOp>(
symbol_table.lookup(quantized_function_name.str()));
rewriter.setInsertionPointAfter(float_func);
func::FuncOp new_quantized_func =
dyn_cast<func::FuncOp>(quantized_func->clone());
if (new_quantized_func == nullptr) {
return failure();
}
StringAttr new_quant_func_name = symbol_table.insert(new_quantized_func);
// Transfer the attributes of the float function's ops to the cloned quantized
// function via the attr_map attribute.
if (failed(TransferAttributes(float_func, new_quantized_func))) {
return failure();
}
rewriter.setInsertionPoint(call_op);
rewriter.replaceOpWithNewOp<TF::PartitionedCallOp>(
call_op, result_types, args,
FlatSymbolRefAttr::get(new_quant_func_name));
return success();
}
};
// Converts the const -> quant.qcast pattern into a quantized constant, after
// the quantization parameters have been safely included in each quantized
// composite function.
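// Schematically (values illustrative):
//   %w = "tf.Const"() {value = dense<...> : tensor<2xf32>}
//   %q = "quant.qcast"(%w)
//       : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32, 0.047:-128>>
// ->
//   %w_i8 = "tf.Const"() {value = dense<...> : tensor<2xi8>}
//   %q = "quant.scast"(%w_i8)
//       : (tensor<2xi8>) -> tensor<2x!quant.uniform<i8:f32, 0.047:-128>>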
class QuantizeConstPattern : public OpRewritePattern<QuantizeCastOp> {
public:
// This pattern should have a higher benefit than ReplaceQuantizePattern.
explicit QuantizeConstPattern(MLIRContext* context)
: OpRewritePattern<QuantizeCastOp>(context, /*benefit=*/10) {}
LogicalResult matchAndRewrite(QuantizeCastOp q_op,
PatternRewriter& rewriter) const override {
DenseFPElementsAttr attr;
if (!matchPattern(q_op.arg(), m_Constant(&attr))) {
return failure();
}
ShapedType tensor_qtype = q_op.getResult().getType().cast<ShapedType>();
Attribute quantized_attr = Quantize(attr, tensor_qtype);
if (!quantized_attr) {
return failure();
}
Type storage_type =
tensor_qtype.getElementType().cast<QuantizedType>().getStorageType();
ShapedType new_type = tensor_qtype.clone(storage_type);
Location loc = q_op.arg().getLoc();
auto const_op = rewriter.create<TF::ConstOp>(loc, new_type, quantized_attr);
// Add an scast op to match the quantize -> composite function pattern. The
// added scast is then removed by canonicalization. ([scast - scast] -> [])
auto scast_op = rewriter.create<quant::StorageCastOp>(loc, tensor_qtype,
const_op.output());
q_op->replaceAllUsesWith(scast_op);
return success();
}
};
static PassRegistration<QuantizeCompositeFunctionsPass> pass;
#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize_composite_functions.inc"
void QuantizeCompositeFunctionsPass::runOnOperation() {
MLIRContext* ctx = &getContext();
ModuleOp module = getOperation();
PassManager pm(ctx);
// The intermediate output of the QuantizePass contains PartitionedCall ops
// with quantized input and output types, which are not allowed in the TF
// dialect, so the verifier is disabled for this internal pipeline. This
// workaround can be removed once the composite call supports quantized types.
pm.enableVerifier(false);
pm.addNestedPass<func::FuncOp>(
CreatePrepareQuantizePass(quantization_method_));
pm.addNestedPass<func::FuncOp>(CreateQuantizePass());
pm.addNestedPass<func::FuncOp>(CreatePostQuantizePass());
if (failed(pm.run(module))) {
signalPassFailure();
}
RewritePatternSet patterns(ctx);
patterns.add<QuantizeFunctionPattern>(ctx);
if (failed(applyPatternsAndFoldGreedily(module, std::move(patterns)))) {
signalPassFailure();
}
// Constant quantization is a lossy transformation, so it is applied only
// after all the other patterns have been applied.
RewritePatternSet patterns_2(ctx);
populateWithGenerated(patterns_2);
patterns_2.add<ReplaceQuantizePattern, ReplaceDequantizePattern,
QuantizeConstPattern>(ctx);
if (failed(applyPatternsAndFoldGreedily(module, std::move(patterns_2))) ||
failed(verify(module))) {
signalPassFailure();
}
}
} // namespace
std::unique_ptr<OperationPass<ModuleOp>> CreateQuantizeCompositeFunctionsPass(
QuantizationMethod quantization_method) {
return std::make_unique<QuantizeCompositeFunctionsPass>(quantization_method);
}
} // namespace quant
} // namespace mlir