blob: 79c68d2a5fc57e3305e5fd68eb9fbf6475123156 [file] [log] [blame]
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Legalize TensorFlow Lite to TOSA
#include <algorithm>
#include <climits>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <fstream>
#include <iterator>
#include <limits>
#include <numeric>
#include <string>
#include <unordered_set>
#include "llvm/ADT/ArrayRef.h"
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
#include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project
#include "mlir/Dialect/Traits.h" // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/Matchers.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/tosa/transforms/legalize_common.h"
#include "tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h"
#include "tensorflow/compiler/mlir/tosa/transforms/passes.h"
// Command-line name of this pass; DEBUG_TYPE reuses it as the LLVM debug tag.
#define PASS_NAME "tosa-legalize-tfl"
#define DEBUG_TYPE PASS_NAME
// Compile-time switch consumed by code outside this chunk.
// NOTE(review): presumably selects explicit rescaling in the quantized
// HardSwish lowering — confirm against its use site.
#define HARDSWISH_EXPLICIT_RESCALING false
namespace mlir {
namespace tosa {
namespace {
// Pull in the tablegen-generated pass base classes (this file derives from
// TosaLegalizeTFLPassBase below).
#define GEN_PASS_CLASSES
#include "tensorflow/compiler/mlir/tosa/transforms/passes.h.inc"
// Performs lowering to TOSA dialect.
// disabled_patterns_ / enabled_patterns_ are not declared here; they are
// inherited from the generated base class (presumably as pass options).
class LegalizeTFL : public TosaLegalizeTFLPassBase<LegalizeTFL> {
public:
LegalizeTFL() = default;
// Constructs the pass with explicit lists of pattern names to disable/enable.
explicit LegalizeTFL(ArrayRef<std::string> disabled_patterns,
ArrayRef<std::string> enabled_patterns) {
this->disabled_patterns_ = disabled_patterns;
this->enabled_patterns_ = enabled_patterns;
}
void runOnOperation() override;
// Defined outside this chunk; presumably populates frozen_patterns_.
LogicalResult initialize(MLIRContext* context) override;
private:
// Frozen pattern set reused across runs of the pass.
FrozenRewritePatternSet frozen_patterns_;
};
#include "tensorflow/compiler/mlir/tosa/transforms/tfl_legalize_patterns.inc"
// Input from tfl.conv2d takes a 64-bit bias, while tosa.conv2d expects 48
// bits. A customized truncation is needed here instead of tablegen to handle
// attributes with negative values.
// Rewrite pattern rooted at arith.constant (benefit 1). matchAndRewrite is
// defined outside this chunk.
struct ConvertConstantOp : public RewritePattern {
explicit ConvertConstantOp(MLIRContext* context)
: RewritePattern(arith::ConstantOp::getOperationName(), 1, context) {}
LogicalResult matchAndRewrite(Operation* op,
PatternRewriter& rewriter) const override;
};
// Declares `struct ConvertTFL<tfl_op>Op`, a RewritePattern rooted at
// TFL::<tfl_op>Op with benefit 1. Its matchAndRewrite is defined out of line.
// The macro body ends with `}`; each invocation supplies the trailing `;`.
// (No comments inside the macro: a `//` before a line continuation would
// swallow the next line.)
#define DECL_CONVERT_OP(tfl_op) \
struct ConvertTFL##tfl_op##Op : public RewritePattern { \
explicit ConvertTFL##tfl_op##Op(MLIRContext* context) \
: RewritePattern(TFL::tfl_op##Op::getOperationName(), 1, context) {} \
LogicalResult matchAndRewrite(Operation* op, \
PatternRewriter& rewriter) const override; \
}
// One pattern declaration per TFL op handled by this pass.
DECL_CONVERT_OP(Relu);
DECL_CONVERT_OP(Relu6);
DECL_CONVERT_OP(Equal);
DECL_CONVERT_OP(NotEqual);
DECL_CONVERT_OP(Greater);
DECL_CONVERT_OP(GreaterEqual);
DECL_CONVERT_OP(Add);
DECL_CONVERT_OP(Sub);
DECL_CONVERT_OP(Mul);
DECL_CONVERT_OP(Square);
DECL_CONVERT_OP(SquaredDifference);
DECL_CONVERT_OP(Round);
DECL_CONVERT_OP(Div);
DECL_CONVERT_OP(Maximum);
DECL_CONVERT_OP(Minimum);
DECL_CONVERT_OP(FloorMod);
DECL_CONVERT_OP(FloorDiv);
DECL_CONVERT_OP(AddN);
DECL_CONVERT_OP(AveragePool2D);
DECL_CONVERT_OP(MaxPool2D);
DECL_CONVERT_OP(Concatenation);
DECL_CONVERT_OP(Reshape);
DECL_CONVERT_OP(Rank);
DECL_CONVERT_OP(Shape);
DECL_CONVERT_OP(ExpandDims);
DECL_CONVERT_OP(Squeeze);
DECL_CONVERT_OP(Fill);
DECL_CONVERT_OP(Elu);
DECL_CONVERT_OP(Softmax);
DECL_CONVERT_OP(LogSoftmax);
DECL_CONVERT_OP(Sqrt);
DECL_CONVERT_OP(L2Normalization);
DECL_CONVERT_OP(ReduceAny);
DECL_CONVERT_OP(ReduceMax);
DECL_CONVERT_OP(ReduceMin);
DECL_CONVERT_OP(Mean);
DECL_CONVERT_OP(ReduceProd);
DECL_CONVERT_OP(Sum);
DECL_CONVERT_OP(Conv2D);
DECL_CONVERT_OP(TransposeConv);
DECL_CONVERT_OP(DepthwiseConv2D);
DECL_CONVERT_OP(FullyConnected);
DECL_CONVERT_OP(BatchMatMul);
DECL_CONVERT_OP(Split);
DECL_CONVERT_OP(SplitV);
DECL_CONVERT_OP(Pack);
DECL_CONVERT_OP(Unpack);
DECL_CONVERT_OP(Transpose);
DECL_CONVERT_OP(Tile);
DECL_CONVERT_OP(Slice);
DECL_CONVERT_OP(StridedSlice);
DECL_CONVERT_OP(HardSwish);
DECL_CONVERT_OP(ZerosLike);
DECL_CONVERT_OP(Less);
DECL_CONVERT_OP(LessEqual);
DECL_CONVERT_OP(Pad);
DECL_CONVERT_OP(PadV2);
DECL_CONVERT_OP(ResizeBilinear);
DECL_CONVERT_OP(ResizeNearestNeighbor);
DECL_CONVERT_OP(Select);
DECL_CONVERT_OP(SelectV2);
DECL_CONVERT_OP(SpaceToBatchNd);
DECL_CONVERT_OP(BatchToSpaceNd);
DECL_CONVERT_OP(SpaceToDepth);
DECL_CONVERT_OP(DepthToSpace);
DECL_CONVERT_OP(Sin);
DECL_CONVERT_OP(Cos);
DECL_CONVERT_OP(Logistic);
DECL_CONVERT_OP(Tanh);
DECL_CONVERT_OP(PRelu);
DECL_CONVERT_OP(LeakyRelu);
DECL_CONVERT_OP(Neg);
DECL_CONVERT_OP(Yield);
DECL_CONVERT_OP(Custom);
DECL_CONVERT_OP(ReverseV2);
DECL_CONVERT_OP(Quantize);
DECL_CONVERT_OP(Dequantize);
DECL_CONVERT_OP(Const);
DECL_CONVERT_OP(QConst);
DECL_CONVERT_OP(Gather);
DECL_CONVERT_OP(GatherNd);
DECL_CONVERT_OP(SparseToDense);
DECL_CONVERT_OP(OneHot);
DECL_CONVERT_OP(ArgMax);
DECL_CONVERT_OP(FakeQuant);
#undef DECL_CONVERT_OP
// Lowers tfl.relu to tosa.clamp.
//
// Non-quantized path: integer bounds [0, INT32_MAX], float bounds
// [0, FLT_MAX].
// Quantized path: the input is rescaled into the output quantization domain,
// then clamped to [output zero point, output storage max]. Real 0.0 maps to
// the zero point. The upper bound is limited to the storage type's range so
// the clamp attribute stays representable in the element type.
LogicalResult ConvertTFLReluOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_relu_op = cast<TFL::ReluOp>(op);
  ShapedType input_type = tfl_relu_op.x().getType().dyn_cast<ShapedType>();
  ShapedType output_type =
      tfl_relu_op.getResult().getType().dyn_cast<ShapedType>();
  // Not a ranked tensor output
  if (!input_type || !output_type) return failure();
  bool input_is_qtype =
      input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
  bool output_is_qtype =
      output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
  if (input_is_qtype != output_is_qtype) {
    return op->emitOpError(
        "ConvertTFLReluOp: input/output tensor should "
        "be all quantized or all floating-point.");
  }
  int64_t clamp_min = 0;
  int64_t clamp_max = std::numeric_limits<int32_t>::max();
  Value clamp_in = tfl_relu_op.x();
  if (output_is_qtype) {
    UniformQuantizedType input_qtype =
        input_type.getElementType()
            .dyn_cast<mlir::quant::UniformQuantizedType>();
    UniformQuantizedType output_qtype =
        output_type.getElementType()
            .dyn_cast<mlir::quant::UniformQuantizedType>();
    clamp_min = output_qtype.getZeroPoint();
    // The rescaled values are stored in the output storage type (e.g. int8).
    // INT32_MAX is not representable there; bound by the storage maximum.
    clamp_max = output_qtype.getStorageTypeMax();
    clamp_in =
        buildRescale(rewriter, op, output_type, tfl_relu_op.x(),
                     input_qtype.getScale() / output_qtype.getScale(),
                     input_qtype.getZeroPoint(), output_qtype.getZeroPoint(),
                     /*double_round=*/false, /*scale32=*/true);
  }
  CreateReplaceOpAndInfer<tosa::ClampOp>(
      rewriter, op, output_type, clamp_in,
      rewriter.getI64IntegerAttr(clamp_min),
      rewriter.getI64IntegerAttr(clamp_max),
      rewriter.getF32FloatAttr(0.0f),
      rewriter.getF32FloatAttr(std::numeric_limits<float>::max()));
  return success();
}
// Lowers tfl.relu6 to tosa.clamp.
//
// Non-quantized path: integer bounds [0, 6], float bounds [0.0, 6.0].
// Quantized path: the input is rescaled into the output quantization domain,
// then clamped to [zp, min(zp + round(6 / scale), storage max)], where zp and
// scale are the output's quantization parameters. Capping at the storage max
// keeps the bound representable when the scale is small enough that 6.0
// quantizes past the storage type's range.
LogicalResult ConvertTFLRelu6Op::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_relu6_op = cast<TFL::Relu6Op>(op);
  ShapedType input_type = tfl_relu6_op.x().getType().dyn_cast<ShapedType>();
  ShapedType output_type =
      tfl_relu6_op.getResult().getType().dyn_cast<ShapedType>();
  // Not a ranked tensor output
  if (!input_type || !output_type) return failure();
  bool input_is_qtype =
      input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
  bool output_is_qtype =
      output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
  if (input_is_qtype != output_is_qtype) {
    return op->emitOpError(
        "ConvertTFLRelu6Op: input/output tensor should "
        "be all quantized or all floating-point.");
  }
  int64_t clamp_min = 0;
  int64_t clamp_max = 6;
  Value clamp_in = tfl_relu6_op.x();
  if (output_is_qtype && input_is_qtype) {
    UniformQuantizedType input_qtype =
        input_type.getElementType()
            .dyn_cast<mlir::quant::UniformQuantizedType>();
    UniformQuantizedType output_qtype =
        output_type.getElementType()
            .dyn_cast<mlir::quant::UniformQuantizedType>();
    clamp_min = output_qtype.getZeroPoint();
    const int64_t quantized_six = std::llround(6.0f / output_qtype.getScale()) +
                                  output_qtype.getZeroPoint();
    clamp_max =
        std::min<int64_t>(quantized_six, output_qtype.getStorageTypeMax());
    clamp_in =
        buildRescale(rewriter, op, output_type, tfl_relu6_op.x(),
                     input_qtype.getScale() / output_qtype.getScale(),
                     input_qtype.getZeroPoint(), output_qtype.getZeroPoint(),
                     /*double_round=*/false, /*scale32=*/true);
  }
  CreateReplaceOpAndInfer<tosa::ClampOp>(rewriter, op, output_type, clamp_in,
                                         rewriter.getI64IntegerAttr(clamp_min),
                                         rewriter.getI64IntegerAttr(clamp_max),
                                         rewriter.getF32FloatAttr(0.0f),
                                         rewriter.getF32FloatAttr(6.0f));
  return success();
}
// Shared operand preparation for the TFL comparison lowerings
// (equal/not_equal/greater/greater_equal/less/less_equal).
//
// Comparison results are boolean tensors even when the inputs are quantized,
// so only the two inputs are required to agree on quantization.
// - Non-quantized inputs are forwarded unchanged.
// - Quantized inputs must share scale and zero point. Each is rescaled to
//   int32 with its zero point removed, so the integer comparison orders the
//   represented real values.
//
// On success, `newOperands` receives the (possibly rescaled) x then y.
static LogicalResult prepareMatchAndRewriteComparison(
    Operation* op, mlir::OperandRange operands, PatternRewriter& rewriter,
    llvm::SmallVectorImpl<Value>& newOperands) {
  Value x = operands[0];
  Value y = operands[1];
  Value result = op->getResult(0);
  ShapedType input_x_type = x.getType().dyn_cast<ShapedType>();
  ShapedType input_y_type = y.getType().dyn_cast<ShapedType>();
  ShapedType output_type = result.getType().dyn_cast<ShapedType>();
  // Not a shaped tensor output
  if (!input_x_type || !input_y_type || !output_type) return failure();
  bool input_x_is_qtype =
      input_x_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
  bool input_y_is_qtype =
      input_y_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
  // The output (a boolean tensor) is deliberately not part of this check:
  // comparing it against the inputs rejected every quantized comparison and
  // made the quantized path below unreachable.
  if (input_x_is_qtype != input_y_is_qtype) {
    return op->emitOpError(
        "ConvertTFLEqualOp: input/output tensor should "
        "be all quantized or all floating-point.");
  }
  if (!input_x_is_qtype) {
    newOperands.push_back(x);
    newOperands.push_back(y);
    return success();
  }
  UniformQuantizedType input_x_qtype =
      input_x_type.getElementType()
          .dyn_cast<mlir::quant::UniformQuantizedType>();
  UniformQuantizedType input_y_qtype =
      input_y_type.getElementType()
          .dyn_cast<mlir::quant::UniformQuantizedType>();
  if (input_x_qtype.getScale() != input_y_qtype.getScale() ||
      input_x_qtype.getZeroPoint() != input_y_qtype.getZeroPoint()) {
    return op->emitOpError(
        "ConvertTFLEqualOp: input_x and input_y scale/zp "
        "must be the same");
  }
  x = buildRescaleToInt32(rewriter, op, x, 1.0f, input_x_qtype.getZeroPoint());
  y = buildRescaleToInt32(rewriter, op, y, 1.0f, input_y_qtype.getZeroPoint());
  newOperands.push_back(x);
  newOperands.push_back(y);
  return success();
}
// Lowers tfl.equal to tosa.equal on the operands prepared by
// prepareMatchAndRewriteComparison.
LogicalResult ConvertTFLEqualOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  llvm::SmallVector<Value, 2> prepared;
  if (failed(prepareMatchAndRewriteComparison(op, op->getOperands(), rewriter,
                                              prepared)))
    return failure();
  Type result_type = op->getResult(0).getType();
  CreateReplaceOpAndInfer<tosa::EqualOp>(rewriter, op, result_type,
                                         prepared[0], prepared[1]);
  return success();
}
// Lowers tfl.not_equal as logical_not(equal(x, y)), since TOSA has no direct
// not-equal op.
LogicalResult ConvertTFLNotEqualOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  llvm::SmallVector<Value, 2> prepared;
  if (failed(prepareMatchAndRewriteComparison(op, op->getOperands(), rewriter,
                                              prepared)))
    return failure();
  Type bool_type = op->getResult(0).getType();
  auto is_equal = CreateOpAndInfer<tosa::EqualOp>(
      rewriter, op->getLoc(), bool_type, prepared[0], prepared[1]);
  CreateReplaceOpAndInfer<tosa::LogicalNotOp>(rewriter, op, bool_type,
                                              is_equal);
  return success();
}
// Lowers tfl.greater (x > y) to tosa.greater with operands in their original
// order.
LogicalResult ConvertTFLGreaterOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  llvm::SmallVector<Value, 2> prepared;
  if (failed(prepareMatchAndRewriteComparison(op, op->getOperands(), rewriter,
                                              prepared)))
    return failure();
  Value lhs = prepared[0];
  Value rhs = prepared[1];
  CreateReplaceOpAndInfer<tosa::GreaterOp>(
      rewriter, op, op->getResult(0).getType(), lhs, rhs);
  return success();
}
// Lowers tfl.greater_equal (x >= y) to tosa.greater_equal with operands in
// their original order.
LogicalResult ConvertTFLGreaterEqualOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  llvm::SmallVector<Value, 2> prepared;
  if (failed(prepareMatchAndRewriteComparison(op, op->getOperands(), rewriter,
                                              prepared)))
    return failure();
  Value lhs = prepared[0];
  Value rhs = prepared[1];
  CreateReplaceOpAndInfer<tosa::GreaterEqualOp>(
      rewriter, op, op->getResult(0).getType(), lhs, rhs);
  return success();
}
// Lowers tfl.less using the identity (x < y) == (y > x): tosa.greater with
// swapped operands.
LogicalResult ConvertTFLLessOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  llvm::SmallVector<Value, 2> prepared;
  if (failed(prepareMatchAndRewriteComparison(op, op->getOperands(), rewriter,
                                              prepared)))
    return failure();
  Value x = prepared[0];
  Value y = prepared[1];
  CreateReplaceOpAndInfer<tosa::GreaterOp>(
      rewriter, op, op->getResult(0).getType(), y, x);
  return success();
}
// Lowers tfl.less_equal using the identity (x <= y) == (y >= x):
// tosa.greater_equal with swapped operands.
LogicalResult ConvertTFLLessEqualOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  llvm::SmallVector<Value, 2> prepared;
  if (failed(prepareMatchAndRewriteComparison(op, op->getOperands(), rewriter,
                                              prepared)))
    return failure();
  Value x = prepared[0];
  Value y = prepared[1];
  CreateReplaceOpAndInfer<tosa::GreaterEqualOp>(
      rewriter, op, op->getResult(0).getType(), y, x);
  return success();
}
// Shared lowering for tfl.add / tfl.sub onto the corresponding TOSA op.
//
// Non-quantized: emits TosaOp(lhs, rhs) directly.
// Quantized: follows tensorflow/lite/kernels/add.cc. Both inputs are rescaled
// to int32 at scale 2 * max(lhs.scale, rhs.scale), with an extra left shift for
// precision (20 bits for 8-bit storage, 15 for 16-bit, chosen from the output
// width). The op is applied in int32 and the result is rescaled back to the
// output type.
// If the op carries a fused activation, it is applied before replacement.
// `operands` is unused; the operands are read through the typed TflOp.
template <typename TflOp, typename TosaOp>
static LogicalResult matchAndRewriteAddSub(Operation* op,
                                           mlir::OperandRange operands,
                                           PatternRewriter& rewriter) {
  auto tfl_op = cast<TflOp>(op);
  Value lhs = tfl_op.lhs();
  Value rhs = tfl_op.rhs();
  Value tfl_result = tfl_op.getResult();
  auto lhs_type = lhs.getType().dyn_cast<ShapedType>();
  auto rhs_type = rhs.getType().dyn_cast<ShapedType>();
  auto out_type = tfl_result.getType().dyn_cast<ShapedType>();
  // Not a ranked tensor output
  if (!lhs_type || !rhs_type || !out_type) return failure();
  auto is_quantized = [](ShapedType t) {
    return t.getElementType().isa<mlir::quant::UniformQuantizedType>();
  };
  const bool out_quantized = is_quantized(out_type);
  if (is_quantized(lhs_type) != out_quantized ||
      is_quantized(rhs_type) != out_quantized) {
    return op->emitOpError(
        "ConvertTFLAddOp: input/output tensor should "
        "be all quantized or all floating-point.");
  }
  Value combined;
  if (!out_quantized) {
    combined =
        CreateOpAndInfer<TosaOp>(rewriter, op->getLoc(), out_type, lhs, rhs)
            .getResult();
  } else {
    auto lhs_q =
        lhs_type.getElementType().dyn_cast<mlir::quant::UniformQuantizedType>();
    auto rhs_q =
        rhs_type.getElementType().dyn_cast<mlir::quant::UniformQuantizedType>();
    auto out_q =
        out_type.getElementType().dyn_cast<mlir::quant::UniformQuantizedType>();
    const int32_t kShiftFor8Bit = 20;
    const int32_t kShiftFor16Bit = 15;
    int32_t shift = kShiftFor8Bit;
    if (out_q.getStorageTypeIntegralWidth() == 16) shift = kShiftFor16Bit;
    const double lhs_scale = lhs_q.getScale();
    const double rhs_scale = rhs_q.getScale();
    const double out_scale = out_q.getScale();
    const double shift_factor = static_cast<double>(1 << shift);
    const double twice_max_scale = 2.0 * std::max(lhs_scale, rhs_scale);
    const double lhs_multiplier = shift_factor * lhs_scale / twice_max_scale;
    const double rhs_multiplier = shift_factor * rhs_scale / twice_max_scale;
    const double out_multiplier =
        twice_max_scale / (out_scale * shift_factor);
    Value lhs_i32 = buildRescaleToInt32(rewriter, op, lhs, lhs_multiplier,
                                        lhs_q.getZeroPoint());
    Value rhs_i32 = buildRescaleToInt32(rewriter, op, rhs, rhs_multiplier,
                                        rhs_q.getZeroPoint());
    ShapedType i32_type = out_type.clone(rewriter.getI32Type());
    auto combined_i32 = CreateOpAndInfer<TosaOp>(rewriter, op->getLoc(),
                                                 i32_type, lhs_i32, rhs_i32);
    combined = buildRescaleFromInt32(rewriter, op, out_type,
                                     combined_i32.getResult(), out_multiplier,
                                     out_q.getZeroPoint());
  }
  if (auto activation = tfl_op.fused_activation_functionAttr()) {
    llvm::Optional<Value> activated =
        convertFusedActivation(rewriter, op, combined, activation);
    if (!activated) return failure();
    combined = activated.getValue();
  }
  rewriter.replaceOp(op, {combined});
  return success();
}
// tfl.add -> tosa.add through the shared add/sub lowering.
LogicalResult ConvertTFLAddOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto operand_range = op->getOperands();
  return matchAndRewriteAddSub<TFL::AddOp, tosa::AddOp>(op, operand_range,
                                                        rewriter);
}
// tfl.sub -> tosa.sub through the shared add/sub lowering.
LogicalResult ConvertTFLSubOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto operand_range = op->getOperands();
  return matchAndRewriteAddSub<TFL::SubOp, tosa::SubOp>(op, operand_range,
                                                        rewriter);
}
// Lowers tfl.mul through convertMultiplyOp, then applies the fused activation
// (if any) before replacing the op.
LogicalResult ConvertTFLMulOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_mul_op = cast<TFL::MulOp>(op);
  llvm::Optional<Value> product =
      convertMultiplyOp(rewriter, op, tfl_mul_op.getResult(), tfl_mul_op.lhs(),
                        tfl_mul_op.rhs());
  if (!product) return failure();
  Value replacement = product.getValue();
  if (auto activation = tfl_mul_op.fused_activation_functionAttr()) {
    llvm::Optional<Value> activated =
        convertFusedActivation(rewriter, op, replacement, activation);
    if (!activated) return failure();
    replacement = activated.getValue();
  }
  rewriter.replaceOp(op, {replacement});
  return success();
}
// Lowers tfl.square as x * x via convertMultiplyOp.
LogicalResult ConvertTFLSquareOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_square_op = cast<TFL::SquareOp>(op);
  Value operand = tfl_square_op.x();
  if (llvm::Optional<Value> squared = convertMultiplyOp(
          rewriter, op, tfl_square_op.getResult(), operand, operand)) {
    rewriter.replaceOp(op, {squared.getValue()});
    return success();
  }
  return failure();
}
// Lowers tfl.squared_difference via the shared convertSquaredDifferenceOp
// helper.
LogicalResult ConvertTFLSquaredDifferenceOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_squared_op = cast<TFL::SquaredDifferenceOp>(op);
  Value lhs = tfl_squared_op.lhs();
  Value rhs = tfl_squared_op.rhs();
  if (llvm::Optional<Value> lowered = convertSquaredDifferenceOp(
          rewriter, op, tfl_squared_op.getResult(), lhs, rhs)) {
    rewriter.replaceOp(op, {lowered.getValue()});
    return success();
  }
  return failure();
}
// Lowers tfl.round.
//
// Float inputs go through convertRoundOp. For non-float element types,
// rounding is the identity, so the op is replaced by its input.
LogicalResult ConvertTFLRoundOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_round_op = cast<TFL::RoundOp>(op);
  ShapedType input_type = tfl_round_op.x().getType().dyn_cast<ShapedType>();
  if (!input_type) {
    return op->emitOpError("Round: input not shaped tensor type");
  }
  if (!input_type.getElementType().isa<FloatType>()) {
    // Round on int is nonsensical. Replace through the rewriter (not a direct
    // replaceAllUsesWith): the pattern driver must be notified of the change,
    // and the now-dead op must be erased.
    rewriter.replaceOp(op, {tfl_round_op.x()});
    return success();
  }
  llvm::Optional<Value> result =
      convertRoundOp(rewriter, op, tfl_round_op.getResult(), tfl_round_op.x());
  if (!result) return failure();
  rewriter.replaceOp(op, {result.getValue()});
  return success();
}
// Lowers tfl.div.
//
// Integer element types map to tosa.div. All other element types are lowered
// as lhs * reciprocal(rhs) with tosa.mul (shift 0). The fused activation, if
// present, is applied before replacement.
LogicalResult ConvertTFLDivOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_div_op = cast<TFL::DivOp>(op);
  auto output_type = tfl_div_op.getResult().getType().dyn_cast<ShapedType>();
  // Not a ranked tensor output
  if (!output_type) return failure();
  Value numerator = tfl_div_op.lhs();
  Value denominator = tfl_div_op.rhs();
  Value quotient;
  if (!output_type.getElementType().isa<IntegerType>()) {
    auto inverse = CreateOpAndInfer<tosa::ReciprocalOp>(
        rewriter, op->getLoc(), denominator.getType(), denominator);
    quotient = CreateOpAndInfer<tosa::MulOp>(rewriter, op->getLoc(),
                                             output_type, numerator,
                                             inverse.getResult(), 0)
                   .getResult();
  } else {
    quotient = CreateOpAndInfer<tosa::DivOp>(rewriter, op->getLoc(),
                                             output_type, numerator,
                                             denominator)
                   .getResult();
  }
  if (auto activation = tfl_div_op.fused_activation_functionAttr()) {
    llvm::Optional<Value> activated =
        convertFusedActivation(rewriter, op, quotient, activation);
    if (!activated) return failure();
    quotient = activated.getValue();
  }
  rewriter.replaceOp(op, {quotient});
  return success();
}
// Lowers tfl.maximum to tosa.maximum.
//
// For quantized tensors, both operands are widened to int32 with unit scale
// and zero offset. The maximum is taken in int32 and the result is narrowed
// back the same way. Quantization parameters are not compared, so the stored
// values are compared directly.
LogicalResult ConvertTFLMaximumOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_max_op = cast<TFL::MaximumOp>(op);
  Value lhs = tfl_max_op.lhs();
  Value rhs = tfl_max_op.rhs();
  auto lhs_type = lhs.getType().dyn_cast<ShapedType>();
  auto rhs_type = rhs.getType().dyn_cast<ShapedType>();
  auto out_type = tfl_max_op.getResult().getType().dyn_cast<ShapedType>();
  // Not a shaped tensor output
  if (!lhs_type || !rhs_type || !out_type) return failure();
  auto is_quantized = [](ShapedType t) {
    return t.getElementType().isa<mlir::quant::UniformQuantizedType>();
  };
  const bool out_quantized = is_quantized(out_type);
  if (is_quantized(lhs_type) != out_quantized ||
      is_quantized(rhs_type) != out_quantized) {
    return op->emitOpError(
        "ConvertTFLMaximumOp: input/output tensor should "
        "be all quantized or all floating-point.");
  }
  if (!out_quantized) {
    auto max_op = CreateOpAndInfer<tosa::MaximumOp>(rewriter, op->getLoc(),
                                                    out_type, lhs, rhs);
    rewriter.replaceOp(op, {max_op.getResult()});
    return success();
  }
  ShapedType i32_type = out_type.clone(rewriter.getI32Type());
  Value lhs_i32 = buildRescaleToInt32(rewriter, op, lhs, 1.0f, 0);
  Value rhs_i32 = buildRescaleToInt32(rewriter, op, rhs, 1.0f, 0);
  auto max_i32 = CreateOpAndInfer<tosa::MaximumOp>(rewriter, op->getLoc(),
                                                   i32_type, lhs_i32, rhs_i32);
  Value narrowed = buildRescaleFromInt32(rewriter, op, out_type,
                                         max_i32.getResult(), 1.0f, 0);
  rewriter.replaceOp(op, {narrowed});
  return success();
}
// Lowers tfl.minimum to tosa.minimum.
//
// For quantized tensors, both operands are widened to int32 with unit scale
// and zero offset. The minimum is taken in int32 and the result is narrowed
// back the same way. Quantization parameters are not compared, so the stored
// values are compared directly.
LogicalResult ConvertTFLMinimumOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_min_op = cast<TFL::MinimumOp>(op);
  Value lhs = tfl_min_op.lhs();
  Value rhs = tfl_min_op.rhs();
  auto lhs_type = lhs.getType().dyn_cast<ShapedType>();
  auto rhs_type = rhs.getType().dyn_cast<ShapedType>();
  auto out_type = tfl_min_op.getResult().getType().dyn_cast<ShapedType>();
  // Not a shaped tensor output
  if (!lhs_type || !rhs_type || !out_type) return failure();
  auto is_quantized = [](ShapedType t) {
    return t.getElementType().isa<mlir::quant::UniformQuantizedType>();
  };
  const bool out_quantized = is_quantized(out_type);
  if (is_quantized(lhs_type) != out_quantized ||
      is_quantized(rhs_type) != out_quantized) {
    return op->emitOpError(
        "ConvertTFLMinimumOp: input/output tensor should "
        "be all quantized or all floating-point.");
  }
  if (!out_quantized) {
    auto min_op = CreateOpAndInfer<tosa::MinimumOp>(rewriter, op->getLoc(),
                                                    out_type, lhs, rhs);
    rewriter.replaceOp(op, {min_op.getResult()});
    return success();
  }
  ShapedType i32_type = out_type.clone(rewriter.getI32Type());
  Value lhs_i32 = buildRescaleToInt32(rewriter, op, lhs, 1.0f, 0);
  Value rhs_i32 = buildRescaleToInt32(rewriter, op, rhs, 1.0f, 0);
  auto min_i32 = CreateOpAndInfer<tosa::MinimumOp>(rewriter, op->getLoc(),
                                                   i32_type, lhs_i32, rhs_i32);
  Value narrowed = buildRescaleFromInt32(rewriter, op, out_type,
                                         min_i32.getResult(), 1.0f, 0);
  rewriter.replaceOp(op, {narrowed});
  return success();
}
// Lowers tfl.floor_div via the shared convertFloorDivOp helper.
LogicalResult ConvertTFLFloorDivOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_floordiv_op = cast<TFL::FloorDivOp>(op);
  Value lhs = tfl_floordiv_op.lhs();
  Value rhs = tfl_floordiv_op.rhs();
  if (llvm::Optional<Value> lowered = convertFloorDivOp(
          rewriter, op, tfl_floordiv_op.getResult(), lhs, rhs)) {
    rewriter.replaceOp(op, {lowered.getValue()});
    return success();
  }
  return failure();
}
// Lowers tfl.floor_mod via the shared convertFloorModOp helper.
LogicalResult ConvertTFLFloorModOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_floormod_op = cast<TFL::FloorModOp>(op);
  Value lhs = tfl_floormod_op.lhs();
  Value rhs = tfl_floormod_op.rhs();
  if (llvm::Optional<Value> lowered = convertFloorModOp(
          rewriter, op, tfl_floormod_op.getResult(), lhs, rhs)) {
    rewriter.replaceOp(op, {lowered.getValue()});
    return success();
  }
  return failure();
}
// Lowers tfl.add_n to a chain of tosa.add ops: in0 + in1 first, then each
// further input i is added as (in_i + running_sum).
LogicalResult ConvertTFLAddNOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_addn_op = cast<TFL::AddNOp>(op);
  ShapedType output_type =
      tfl_addn_op.getResult().getType().dyn_cast<ShapedType>();
  // Not a shaped output
  if (!output_type) return failure();
  SmallVector<Value> inputs(tfl_addn_op.inputs());
  // The chain below indexes inputs[0] and inputs[1]. Report a match failure
  // instead of asserting, which would read out of bounds when assertions are
  // compiled out.
  if (inputs.size() < 2) return failure();
  auto newOp = CreateOpAndInfer<tosa::AddOp>(rewriter, op->getLoc(),
                                             output_type, inputs[0], inputs[1]);
  for (size_t i = 2; i < inputs.size(); ++i) {
    newOp = CreateOpAndInfer<tosa::AddOp>(rewriter, op->getLoc(), output_type,
                                          inputs[i], newOp.getResult());
  }
  rewriter.replaceOp(op, {newOp.getResult()});
  return success();
}
// Lowers tfl.average_pool_2d (NHWC) to tosa.avg_pool2d.
//
// Kernel and stride come from the TFL attributes. Padding is derived from the
// TFL padding string through getPaddingValuesFromPadType, using a synthetic
// [1, kh, kw, 1] "filter" type. The pool runs in the input element type; a
// tosa.cast to the output type is added when that differs.
LogicalResult ConvertTFLAveragePool2DOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_avgpool_op = cast<TFL::AveragePool2DOp>(op);
  ShapedType input_type =
      tfl_avgpool_op.input().getType().dyn_cast<ShapedType>();
  ShapedType output_type =
      tfl_avgpool_op.getResult().getType().dyn_cast<ShapedType>();
  // Both types are dereferenced below (padding computation, element type);
  // bail out if either is not shaped.
  if (!input_type || !output_type) return failure();
  // Kernels and strides are dimensionally ordered
  SmallVector<int64_t, 4> i64array({1, 1, 1, 1});
  ArrayAttr kernel_size;
  ArrayAttr stride;
  ArrayAttr pad;
  {
    int64_t kernel_h = tfl_avgpool_op.filter_height();
    int64_t kernel_w = tfl_avgpool_op.filter_width();
    kernel_size = rewriter.getI64ArrayAttr({kernel_h, kernel_w});
    // i64array is formatted as NHWC now
    i64array[1] = kernel_h;
    i64array[2] = kernel_w;
  }
  {
    int64_t stride_h = tfl_avgpool_op.stride_h();
    int64_t stride_w = tfl_avgpool_op.stride_w();
    stride = rewriter.getI64ArrayAttr({stride_h, stride_w});
  }
  {
    tensorflow::Padding tf_pad;
    if (!GetPaddingFromString(tfl_avgpool_op.padding().str(), &tf_pad).ok())
      return failure();
    // Pooling has no non-unit dilation
    ArrayAttr dilation = rewriter.getI64ArrayAttr({1, 1});
    RankedTensorType filter_type = RankedTensorType::get(
        llvm::makeArrayRef(i64array), rewriter.getIntegerType(64));
    // TFLite doesn't support explicit padding
    if (!getPaddingValuesFromPadType(
            tf_pad,
            tensorflow::FORMAT_NHWC,  // TFLite only supports this
            1,                        // tensorflow::FORMAT_OHWI,
            input_type, filter_type, stride, dilation, rewriter, pad))
      return failure();
  }
  auto average_etype = input_type.getElementType();
  auto average_type = output_type.clone(average_etype);
  Value result;
  if (average_etype.isa<quant::UniformQuantizedType>()) {
    // TensorFlow Lite doesn't use the zero point when calculating
    // quantized average pool, while TOSA does. Force the TOSA
    // zero_points to zero to ensure that the calculations match
    auto zero = rewriter.getI32IntegerAttr(0);
    auto quant_attr = tosa::UnaryOpQuantizationAttr::get(
        /*input_zp=*/zero, /*output_zp=*/zero, rewriter.getContext());
    result = CreateOpAndInfer<tosa::AvgPool2dOp>(
        rewriter, op->getLoc(), average_type, tfl_avgpool_op.input(),
        kernel_size, stride, pad, quant_attr);
  } else {
    result = CreateOpAndInfer<tosa::AvgPool2dOp>(
        rewriter, op->getLoc(), average_type, tfl_avgpool_op.input(),
        kernel_size, stride, pad);
  }
  if (average_type != output_type) {
    result = CreateOpAndInfer<tosa::CastOp>(rewriter, op->getLoc(), output_type,
                                            result);
  }
  rewriter.replaceOp(op, result);
  return success();
}
// Lowers tfl.max_pool_2d (NHWC) to tosa.max_pool2d.
//
// Kernel and stride come from the TFL attributes. Padding is derived from the
// TFL padding string through getPaddingValuesFromPadType, using a synthetic
// [1, kh, kw, 1] "filter" type.
LogicalResult ConvertTFLMaxPool2DOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_maxpool_op = cast<TFL::MaxPool2DOp>(op);
  ShapedType input_type =
      tfl_maxpool_op.input().getType().dyn_cast<ShapedType>();
  ShapedType output_type =
      tfl_maxpool_op.getResult().getType().dyn_cast<ShapedType>();
  // input_type is passed to the padding computation, so it must be shaped
  // too.
  if (!input_type || !output_type) return failure();
  // Kernels and strides are dimensionally ordered
  SmallVector<int64_t, 4> i64array({1, 1, 1, 1});
  ArrayAttr kernel_size;
  ArrayAttr stride;
  ArrayAttr pad;
  {
    int64_t kernel_h = tfl_maxpool_op.filter_height();
    int64_t kernel_w = tfl_maxpool_op.filter_width();
    kernel_size = rewriter.getI64ArrayAttr({kernel_h, kernel_w});
    // i64array is formatted as NHWC now
    i64array[1] = kernel_h;
    i64array[2] = kernel_w;
  }
  {
    int64_t stride_h = tfl_maxpool_op.stride_h();
    int64_t stride_w = tfl_maxpool_op.stride_w();
    stride = rewriter.getI64ArrayAttr({stride_h, stride_w});
  }
  {
    tensorflow::Padding tf_pad;
    if (!GetPaddingFromString(tfl_maxpool_op.padding().str(), &tf_pad).ok())
      return failure();
    // Pooling has no non-unit dilation
    ArrayAttr dilation = rewriter.getI64ArrayAttr({1, 1});
    RankedTensorType filter_type =
        RankedTensorType::get(i64array, rewriter.getIntegerType(64));
    // TFLite doesn't support explicit padding
    if (!getPaddingValuesFromPadType(
            tf_pad,
            tensorflow::FORMAT_NHWC,  // TFLite only supports this
            1,                        // tensorflow::FORMAT_OHWI,
            input_type, filter_type, stride, dilation, rewriter, pad))
      return failure();
  }
  CreateReplaceOpAndInfer<tosa::MaxPool2dOp>(rewriter, op, output_type,
                                             tfl_maxpool_op.input(),
                                             kernel_size, stride, pad);
  return success();
}
// Lowers tfl.conv_2d (NHWC input, OHWI filter) to tosa.conv2d.
//
// The convolution is emitted with the bias element type as its accumulator
// type: i32 for quantized outputs (unless the bias says otherwise), or the
// output element type for float. Quantized results are then rescaled to the
// output type with buildRescaleOpConvOutput. The fused activation, if present,
// is applied last.
LogicalResult ConvertTFLConv2DOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_conv2d_op = cast<TFL::Conv2DOp>(op);
  RankedTensorType input_type =
      tfl_conv2d_op.input().getType().dyn_cast<RankedTensorType>();
  RankedTensorType filter_type =
      tfl_conv2d_op.filter().getType().dyn_cast<RankedTensorType>();
  ShapedType output_type =
      tfl_conv2d_op.getResult().getType().dyn_cast<ShapedType>();
  // Not a ranked tensor output
  if (!input_type) return failure();
  if (!output_type) return failure();
  if (!filter_type) return failure();
  bool input_is_qtype =
      input_type.getElementType().isa<mlir::quant::QuantizedType>();
  bool filter_is_qtype =
      filter_type.getElementType().isa<mlir::quant::QuantizedType>();
  bool output_is_qtype =
      output_type.getElementType().isa<mlir::quant::QuantizedType>();
  if ((input_is_qtype != filter_is_qtype) ||
      (input_is_qtype != output_is_qtype)) {
    return op->emitOpError(
        "ConvertTFLConv2DOp: input/filter/output tensor should "
        "be all quantized or all floating-point.");
  }
  ArrayAttr pad;
  ArrayAttr stride;
  ArrayAttr dilation;
  {
    int64_t stride_h = tfl_conv2d_op.stride_h();
    int64_t stride_w = tfl_conv2d_op.stride_w();
    stride = rewriter.getI64ArrayAttr({stride_h, stride_w});
  }
  {
    int64_t dilation_h = tfl_conv2d_op.dilation_h_factor();
    int64_t dilation_w = tfl_conv2d_op.dilation_w_factor();
    dilation = rewriter.getI64ArrayAttr({dilation_h, dilation_w});
  }
  {
    tensorflow::Padding tf_pad;
    if (!GetPaddingFromString(tfl_conv2d_op.padding().str(), &tf_pad).ok())
      return failure();
    // TFLite doesn't support explicit padding
    if (!getPaddingValuesFromPadType(
            tf_pad,
            tensorflow::FORMAT_NHWC,  // TFLite only supports this
            1,                        // tensorflow::FORMAT_OHWI,
            input_type, filter_type, stride, dilation, rewriter, pad))
      return failure();
  }
  Value unquantized_bias = tfl_conv2d_op.bias();
  Type bias_ety =
      output_is_qtype ? rewriter.getI32Type() : output_type.getElementType();
  if (unquantized_bias) {
    // A non-shaped bias (e.g. TFL's `none` placeholder) cannot feed
    // tosa.conv2d. Fail the match instead of asserting inside cast<>.
    auto bias_type = unquantized_bias.getType().dyn_cast<ShapedType>();
    if (!bias_type) return failure();
    bias_ety = bias_type.getElementType();
  }
  auto a1_conv2d_op = CreateOpAndInfer<tosa::Conv2DOp>(
      rewriter, op->getLoc(), output_type.clone(bias_ety),
      tfl_conv2d_op.input(), tfl_conv2d_op.filter(), unquantized_bias, pad,
      stride, dilation);
  Value conv2d_output;
  if (input_is_qtype) {
    conv2d_output =
        buildRescaleOpConvOutput(rewriter, op, a1_conv2d_op.getResult(),
                                 input_type, filter_type, output_type);
  } else {
    conv2d_output = a1_conv2d_op.getResult();
  }
  auto fused_activation_fn = tfl_conv2d_op.fused_activation_functionAttr();
  if (fused_activation_fn) {
    llvm::Optional<Value> fused_activation_val = convertFusedActivation(
        rewriter, op, conv2d_output, fused_activation_fn);
    if (!fused_activation_val) return failure();
    rewriter.replaceOp(op, {fused_activation_val.getValue()});
    return success();
  }
  rewriter.replaceOp(op, {conv2d_output});
  return success();
}
// Lowers tfl.transpose_conv to tosa.transpose_conv2d.
//
// A zero bias with one element per output channel is synthesized. It is i48
// for 16-bit activations with 8-bit weights, i32 for other quantized inputs,
// and f32 otherwise. Quantized results are passed through
// buildRescaleOpConvOutput to produce the TFLite output type.
//
// NOTE(review): this lowering never reads a bias from tfl_conv_op. If the TFL
// op carries one, confirm that dropping it is intentional.
LogicalResult ConvertTFLTransposeConvOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_conv_op = cast<TFL::TransposeConvOp>(op);

  ShapedType input_type = tfl_conv_op.input().getType().dyn_cast<ShapedType>();
  ShapedType filter_type =
      tfl_conv_op.weights().getType().dyn_cast<ShapedType>();
  ShapedType output_type =
      tfl_conv_op.getResult().getType().dyn_cast<ShapedType>();
  // Input, filter and output must all be shaped types.
  if (!input_type) return failure();
  if (!output_type) return failure();
  if (!filter_type) return failure();

  bool input_is_qtype =
      input_type.getElementType().isa<mlir::quant::QuantizedType>();
  bool filter_is_qtype =
      filter_type.getElementType().isa<mlir::quant::QuantizedType>();
  bool output_is_qtype =
      output_type.getElementType().isa<mlir::quant::QuantizedType>();

  if ((input_is_qtype != filter_is_qtype) ||
      (input_is_qtype != output_is_qtype)) {
    return op->emitOpError(
        "ConvertTFLTransposeConvOp: input/filter/output tensor should "
        "be all quantized or all floating-point.");
  }

  ArrayAttr stride;
  ArrayAttr dilation;
  ArrayAttr outpad;
  ArrayAttr output_shape;
  {
    int64_t stride_h = tfl_conv_op.stride_h();
    int64_t stride_w = tfl_conv_op.stride_w();
    stride = rewriter.getI64ArrayAttr({stride_h, stride_w});
  }

  // tfl.transpose_conv doesn't support dilations
  dilation = rewriter.getI64ArrayAttr({1, 1});

  {
    tensorflow::Padding tf_pad;
    if (!GetPaddingFromString(tfl_conv_op.padding().str(), &tf_pad).ok())
      return failure();

    if (!getTransposeConv2dPaddingValues(
            tf_pad,
            tensorflow::FORMAT_NHWC,  // TFLite only supports this
            1,                        // tensorflow::FORMAT_OHWI,
            input_type, filter_type, output_type, stride, dilation, rewriter,
            outpad))
      return failure();
  }
  {
    ElementsAttr output_shape_elems;
    // Prefer a constant output_shape operand.
    if (matchPattern(tfl_conv_op.output_shape(),
                     m_Constant(&output_shape_elems))) {
      SmallVector<int64_t> shape_vec;
      for (int i = 0; i < output_shape_elems.getNumElements(); i++)
        shape_vec.push_back(
            output_shape_elems.getValues<APInt>()[i].getSExtValue());
      output_shape = rewriter.getI64ArrayAttr(shape_vec);
    } else if (output_type.hasRank()) {
      // Use output tensor's shape otherwise
      output_shape = rewriter.getI64ArrayAttr(output_type.getShape());
    } else {
      // TODO(suderman): Figure out rankless shape propagation.
      return failure();
    }
  }

  // Number of output channels; it sizes the synthesized zero bias. A dynamic
  // extent (-1) cannot size a constant and previously produced a huge vector
  // allocation. So fall back to the filter's leading dimension when the
  // result's channel dim is unknown, and fail if neither is static.
  // TODO(suderman): We need to figure out how to guarantee output channel
  // propagation.
  int64_t output_channel = 0;
  if (output_type.hasRank() && output_type.getRank() == 4 &&
      !output_type.isDynamicDim(3)) {
    output_channel = output_type.getDimSize(3);
  } else if (filter_type.hasRank() && filter_type.getRank() >= 1 &&
             !filter_type.isDynamicDim(0)) {
    output_channel = filter_type.getDimSize(0);
  } else {
    return failure();
  }

  llvm::Optional<Value> zero_bias;
  if (input_is_qtype) {
    uint32_t input_bits = input_type.getElementType()
                              .dyn_cast<mlir::quant::QuantizedType>()
                              .getStorageTypeIntegralWidth();
    uint32_t weight_bits = filter_type.getElementType()
                               .dyn_cast<mlir::quant::QuantizedType>()
                               .getStorageTypeIntegralWidth();

    if (input_bits == 16 && weight_bits == 8) {
      // 16x8 quantization accumulates in 48 bits.
      SmallVector<APInt> vec(output_channel, APInt(48, 0, true));
      zero_bias = getConstTensor<APInt>(rewriter, op, vec, {output_channel});
    } else {
      SmallVector<int32_t> vec(output_channel, 0);
      zero_bias = getConstTensor<int32_t>(rewriter, op, vec, {output_channel});
    }
  } else {
    SmallVector<float> vec(output_channel, 0.0f);
    zero_bias = getConstTensor<float>(rewriter, op, vec, {output_channel});
  }

  if (!zero_bias) return failure();
  Type bias_ety = zero_bias->getType().cast<ShapedType>().getElementType();

  auto a1_conv2d_op = CreateOpAndInfer<tosa::TransposeConv2DOp>(
      rewriter, op->getLoc(), output_type.clone(bias_ety), tfl_conv_op.input(),
      tfl_conv_op.weights(), zero_bias.getValue(), outpad, stride, dilation,
      output_shape);

  Value conv2d_output;
  if (input_is_qtype) {
    conv2d_output =
        buildRescaleOpConvOutput(rewriter, op, a1_conv2d_op.getResult(),
                                 input_type, filter_type, output_type);
  } else {
    conv2d_output = a1_conv2d_op.getResult();
  }

  rewriter.replaceOp(op, {conv2d_output});

  return success();
}
// Lowers tfl.depthwise_conv_2d to tosa.depthwise_conv2d.
//
// TFLite orders the depthwise filter as [I, H, W, O], while TOSA expects
// [H, W, channels, depth_multiplier]. The filter is therefore rewritten as:
//
//   a1_transpose = tosa.transpose(filter, {1, 2, 3, 0})    // [H, W, O, I]
//   a2_reshape   = tosa.reshape(a1_transpose,
//                               {H, W, O / depth_multiplier, depth_multiplier})
//   a3_depthwise = tosa.depthwise_conv2d(input, a2_reshape, bias, pad,
//                                        stride, dilation)
//
// For quantized operands the result is passed through
// buildRescaleOpConvOutput, and any fused activation is applied last.
LogicalResult ConvertTFLDepthwiseConv2DOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_conv2d_op = cast<TFL::DepthwiseConv2DOp>(op);

  ShapedType input_type =
      tfl_conv2d_op.input().getType().dyn_cast<ShapedType>();
  ShapedType filter_type =
      tfl_conv2d_op.filter().getType().dyn_cast<ShapedType>();
  ShapedType output_type =
      tfl_conv2d_op.getResult().getType().dyn_cast<ShapedType>();
  // Not a shaped output
  if (!input_type) return failure();
  if (!output_type) return failure();
  if (!filter_type) return failure();

  bool input_is_qtype =
      input_type.getElementType().isa<mlir::quant::QuantizedType>();
  bool filter_is_qtype =
      filter_type.getElementType().isa<mlir::quant::QuantizedType>();
  bool output_is_qtype =
      output_type.getElementType().isa<mlir::quant::QuantizedType>();

  if ((input_is_qtype != filter_is_qtype) ||
      (input_is_qtype != output_is_qtype)) {
    return op->emitOpError(
        "ConvertTFLDepthwiseConv2DOp: input/filter/output tensor should "
        "be all quantized or all floating-point.");
  }

  // The filter shape is indexed as rank 4 below to build the transpose and
  // reshape, so the rank must be known and equal to 4.
  if (!filter_type.hasRank() || filter_type.getRank() != 4) return failure();
  auto filter_shape = filter_type.getShape();

  // Every check that can reject the pattern runs before any op is created.
  auto depth_multiplier = tfl_conv2d_op.depth_multiplierAttr();
  int64_t multiplier = depth_multiplier.getInt();
  // Guard the division below against zero/negative multipliers.
  if (multiplier <= 0) return failure();
  // A dynamic channel extent stays dynamic after the reshape. Dividing the -1
  // sentinel would produce a bogus extent (0 for multipliers > 1). A static
  // extent must split evenly, or the reshape's element count would not match.
  int64_t reshaped_channels = -1;
  if (!filter_type.isDynamicDim(3)) {
    if (filter_shape[3] % multiplier != 0) return failure();
    reshaped_channels = filter_shape[3] / multiplier;
  }

  ArrayAttr pad;
  ArrayAttr stride;
  ArrayAttr dilation;
  {
    int64_t stride_h = tfl_conv2d_op.stride_h();
    int64_t stride_w = tfl_conv2d_op.stride_w();
    stride = rewriter.getI64ArrayAttr({stride_h, stride_w});
  }
  {
    int64_t dilation_h = tfl_conv2d_op.dilation_h_factor();
    int64_t dilation_w = tfl_conv2d_op.dilation_w_factor();
    dilation = rewriter.getI64ArrayAttr({dilation_h, dilation_w});
  }
  {
    tensorflow::Padding tf_pad;
    if (!GetPaddingFromString(tfl_conv2d_op.padding().str(), &tf_pad).ok())
      return failure();

    if (!getPaddingValuesFromPadType(
            tf_pad,
            tensorflow::FORMAT_NHWC,  // TFLite only supports this
            1,                        // tensorflow::FORMAT_OHWI,
            input_type, filter_type, stride, dilation, rewriter, pad))
      return failure();
  }

  // [I, H, W, O] -> [H, W, O, I]
  SmallVector<int64_t, 4> a1_transpose_dims = {filter_shape[1], filter_shape[2],
                                               filter_shape[3], filter_shape[0]};
  // [H, W, O, I] -> [H, W, O / depth_multiplier, depth_multiplier]
  SmallVector<int64_t, 4> a2_reshape_dims = {
      a1_transpose_dims[0], a1_transpose_dims[1], reshaped_channels,
      multiplier};

  llvm::Optional<Value> a1_filter_transpose_perms = getConstTensor<int32_t>(
      rewriter, op, /*vec=*/{1, 2, 3, 0}, /*shape=*/{4});

  if (!a1_filter_transpose_perms) return failure();

  auto a1_filter_transpose_op = CreateOpAndInfer<tosa::TransposeOp>(
      rewriter, op->getLoc(),
      RankedTensorType::get(ArrayRef<int64_t>(a1_transpose_dims),
                            filter_type.getElementType()),
      tfl_conv2d_op.filter(), a1_filter_transpose_perms.getValue());

  auto a2_filter_reshape_op = CreateOpAndInfer<tosa::ReshapeOp>(
      rewriter, op->getLoc(),
      RankedTensorType::get(ArrayRef<int64_t>(a2_reshape_dims),
                            filter_type.getElementType()),
      a1_filter_transpose_op.getResult(),
      rewriter.getI64ArrayAttr(a2_reshape_dims));

  // The bias element type (when a bias is present) decides the result
  // element type; otherwise i32 for quantized and the output type for float.
  Value unquantized_bias = tfl_conv2d_op.bias();
  Type bias_ety =
      output_is_qtype ? rewriter.getI32Type() : output_type.getElementType();
  if (unquantized_bias)
    bias_ety = unquantized_bias.getType().cast<ShapedType>().getElementType();

  auto a3_depthwise_conv2d_op = CreateOpAndInfer<tosa::DepthwiseConv2DOp>(
      rewriter, op->getLoc(), output_type.clone(bias_ety),
      tfl_conv2d_op.input(), a2_filter_reshape_op.getResult(), unquantized_bias,
      pad, stride, dilation);

  Value conv2d_output;
  if (input_is_qtype) {
    conv2d_output = buildRescaleOpConvOutput(
        rewriter, op, a3_depthwise_conv2d_op.getResult(), input_type,
        filter_type, output_type);
  } else {
    conv2d_output = a3_depthwise_conv2d_op.getResult();
  }

  auto fused_activation_fn = tfl_conv2d_op.fused_activation_functionAttr();

  if (fused_activation_fn) {
    llvm::Optional<Value> fused_activation_val = convertFusedActivation(
        rewriter, op, conv2d_output, fused_activation_fn);

    if (!fused_activation_val) return failure();

    rewriter.replaceOp(op, {fused_activation_val.getValue()});
    return success();
  }

  rewriter.replaceOp(op, {conv2d_output});

  return success();
}
// Lowers tfl.batch_matmul to tosa.matmul. When adj_x / adj_y is set, the
// corresponding operand is first transposed with perms {0, 2, 1}, which
// assumes rank-3 operands.
//
// For quantized operands tosa.matmul produces an integer accumulator: i48
// when the lhs storage type is 16 bits wide, i32 otherwise. That accumulator
// is then rescaled to the TFLite result type, matching how the convolution
// lowerings type their accumulators. Typing the matmul directly with the
// quantized result type would feed a quantized value into the rescale.
LogicalResult ConvertTFLBatchMatMulOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_mm_op = cast<TFL::BatchMatMulOp>(op);
  auto result_ty = tfl_mm_op.getType().cast<ShapedType>();
  Value lhs = tfl_mm_op.x();
  Value rhs = tfl_mm_op.y();
  ShapedType lhs_type = lhs.getType().cast<ShapedType>();
  ShapedType rhs_type = rhs.getType().cast<ShapedType>();
  bool transpose_lhs = tfl_mm_op.adj_x();
  bool transpose_rhs = tfl_mm_op.adj_y();

  bool lhs_is_qtype =
      lhs_type.getElementType().isa<mlir::quant::QuantizedType>();
  bool rhs_is_qtype =
      rhs_type.getElementType().isa<mlir::quant::QuantizedType>();
  bool result_is_qtype =
      result_ty.getElementType().isa<mlir::quant::QuantizedType>();

  if ((lhs_is_qtype != rhs_is_qtype) || (lhs_is_qtype != result_is_qtype)) {
    return op->emitOpError(
        "ConvertTFLBatchMatMulOp: lhs/rhs/output tensor should "
        "be all quantized or all floating-point.");
  }

  if (transpose_lhs) {
    Value perms =
        getConstTensor<int32_t>(rewriter, op, /*vec=*/{0, 2, 1}, /*shape=*/{3})
            .getValue();
    Type output_type = UnrankedTensorType::get(lhs_type.getElementType());
    lhs = CreateOpAndInfer<tosa::TransposeOp>(rewriter, op->getLoc(),
                                              output_type, lhs, perms)
              .getResult();
  }

  if (transpose_rhs) {
    Value perms =
        getConstTensor<int32_t>(rewriter, op, /*vec=*/{0, 2, 1}, /*shape=*/{3})
            .getValue();
    Type output_type = UnrankedTensorType::get(rhs_type.getElementType());
    rhs = CreateOpAndInfer<tosa::TransposeOp>(rewriter, op->getLoc(),
                                              output_type, rhs, perms)
              .getResult();
  }

  // Float matmuls produce the result element type directly. Quantized ones
  // accumulate in a wide integer type and are rescaled below.
  Type accum_ety = result_ty.getElementType();
  if (result_is_qtype) {
    auto lhs_qtype =
        lhs_type.getElementType().cast<mlir::quant::QuantizedType>();
    accum_ety = lhs_qtype.getStorageTypeIntegralWidth() == 16
                    ? rewriter.getIntegerType(48)
                    : rewriter.getI32Type();
  }

  Value matmul =
      CreateOpAndInfer<tosa::MatMulOp>(rewriter, op->getLoc(),
                                       result_ty.clone(accum_ety), lhs, rhs)
          .getResult();

  if (lhs_is_qtype) {
    matmul = buildRescaleOpConvOutput(rewriter, op, matmul, lhs_type, rhs_type,
                                      result_ty);
  }

  rewriter.replaceOp(op, matmul);

  return success();
}
// Lowers tfl.fully_connected to tosa.fully_connected.
//
// A non-rank-2 input is first reshaped to [batch, filter_dim1]. A missing
// (non-ranked) bias is replaced by a zero constant with one element per output
// channel (filter dim 0):
//   - the input element type for float inputs,
//   - i48 for 16-bit quantized inputs,
//   - i32 for other quantized inputs.
// Quantized results are rescaled to the output element type. When the output
// is ranked, the result is reshaped to the output shape. Finally the fused
// activation, if any, is applied.
LogicalResult ConvertTFLFullyConnectedOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_fc_op = cast<TFL::FullyConnectedOp>(op);

  ShapedType output_type =
      tfl_fc_op.getResult(0).getType().dyn_cast<ShapedType>();
  // Not a shaped tensor output
  if (!output_type) return failure();

  RankedTensorType input_type =
      tfl_fc_op.input().getType().dyn_cast<RankedTensorType>();
  RankedTensorType filter_type =
      tfl_fc_op.filter().getType().dyn_cast<RankedTensorType>();
  RankedTensorType bias_type =
      tfl_fc_op.bias().getType().dyn_cast<RankedTensorType>();
  if (!input_type || !filter_type) return failure();
  // The filter is read as [output_channels, input_features] below.
  if (filter_type.getRank() != 2) return failure();

  bool input_is_qtype =
      input_type.getElementType().isa<mlir::quant::QuantizedType>();
  bool filter_is_qtype =
      filter_type.getElementType().isa<mlir::quant::QuantizedType>();
  bool output_is_qtype =
      output_type.getElementType().isa<mlir::quant::QuantizedType>();

  if ((input_is_qtype != filter_is_qtype) ||
      (input_is_qtype != output_is_qtype)) {
    return op->emitOpError(
        "ConvertTFLFullyConnectedOp: input/filter/output tensor should "
        "be all quantized or all floating-point.");
  }

  // Decide the synthesized bias element type before any op is created, so
  // every rejection leaves the IR untouched.
  Type new_bias_ety;
  if (!bias_type) {
    // A dynamic channel count cannot size a constant bias.
    if (filter_type.isDynamicDim(0)) return failure();
    if (input_type.getElementType().isa<FloatType>()) {
      new_bias_ety = input_type.getElementType();
    } else {
      if (!input_is_qtype) {
        return op->emitOpError(
            "ConvertTFLFullyConnectedOp: input must be quantized type if it's "
            "not float type.");
      }
      auto input_qtype =
          input_type.getElementType().cast<mlir::quant::QuantizedType>();
      new_bias_ety = input_qtype.getStorageTypeIntegralWidth() == 16
                         ? rewriter.getIntegerType(48)
                         : rewriter.getI32Type();
    }
  }

  Value input_val = tfl_fc_op.input();

  // tfl.fully_connected() can takes various dimension tensor as input
  // need to reshape it to rank 2 tensor, which tosa.fully_connected only
  // supports if input tensor is rank 4. It's not always reshaping to (dim[0] *
  // dim[1], dim[2] * dim[3]).

  // In some networks it's reshaping to (dim[0], dim[1] * dim[2] * dim[3]) so a
  // more general way to determine the reshape's shape is by looking at filter's
  // shape[1].
  if (input_type.getRank() != 2) {
    int64_t num_elems = filter_type.getDimSize(1);
    // getNumElements() requires a static input shape, and the division needs
    // a known, non-zero feature count.
    if (!input_type.hasStaticShape() || num_elems <= 0) return failure();
    int64_t num_batch = input_type.getNumElements() / num_elems;
    SmallVector<int64_t, 2> shape_vals({num_batch, num_elems});

    RankedTensorType reshape_type =
        RankedTensorType::get(shape_vals, input_type.getElementType());
    auto reshape_op = CreateOpAndInfer<tosa::ReshapeOp>(
        rewriter, op->getLoc(), reshape_type, tfl_fc_op.input(),
        rewriter.getI64ArrayAttr(shape_vals));

    input_val = reshape_op.getResult();
  }

  Value bias_val;
  if (!bias_type) {
    // For some matmuls, the bias may actually be a "UnitType" which has no
    // value. TOSA requires bias to be an array of output_channel_count values,
    // so create a constant of the appropriate number and type of zeros.
    RankedTensorType new_bias_type =
        RankedTensorType::get({filter_type.getDimSize(0)}, new_bias_ety);
    // getZeroAttr builds the zero splat at the element type's own bit width.
    // Building from float/int32_t host arrays only matches f32/i32 and trips
    // DenseElementsAttr's bitwidth assertion for i48 (16-bit) biases.
    DenseElementsAttr bias_attr =
        rewriter.getZeroAttr(new_bias_type).cast<DenseElementsAttr>();

    auto bias_op = CreateOpAndInfer<tosa::ConstOp>(rewriter, op->getLoc(),
                                                   new_bias_type, bias_attr);
    bias_val = bias_op.getResult();
  } else {
    bias_val = tfl_fc_op.bias();
  }

  Type bias_ety = bias_val.getType().cast<ShapedType>().getElementType();

  auto fc_op = CreateOpAndInfer<tosa::FullyConnectedOp>(
      rewriter, op->getLoc(), UnrankedTensorType::get(bias_ety), input_val,
      tfl_fc_op.filter(), bias_val);

  Value fc_output;
  if (input_is_qtype) {
    fc_output = buildRescaleOpConvOutput(
        rewriter, op, fc_op.getResult(), input_type, filter_type,
        UnrankedTensorType::get(output_type.getElementType()));
  } else {
    fc_output = fc_op.getResult();
  }

  // If we know the output rank, we need to ensure the output shape is correct.
  ShapedType fc_type = fc_output.getType().cast<ShapedType>();
  if (output_type.hasRank()) {
    fc_output = CreateOpAndInfer<tosa::ReshapeOp>(
        rewriter, op->getLoc(),
        UnrankedTensorType::get(fc_type.getElementType()), fc_output,
        rewriter.getI64ArrayAttr(output_type.getShape()));
  }

  auto fused_activation_fn = tfl_fc_op.fused_activation_functionAttr();

  if (fused_activation_fn) {
    llvm::Optional<Value> fused_activation_val =
        convertFusedActivation(rewriter, op, fc_output, fused_activation_fn);

    if (!fused_activation_val) return failure();

    rewriter.replaceOp(op, {fused_activation_val.getValue()});
    return success();
  }

  rewriter.replaceOp(op, {fc_output});

  return success();
}
// Lowers tfl.concatenation through the shared convertConcatV2Op helper.
// A missing axis attribute is treated as axis 0.
LogicalResult ConvertTFLConcatenationOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto concat_op = cast<TFL::ConcatenationOp>(op);
  SmallVector<Value> operands(concat_op.values());

  IntegerAttr axis_attr = concat_op.axisAttr();
  if (!axis_attr) axis_attr = rewriter.getI64IntegerAttr(0);
  int32_t axis = axis_attr.getInt();

  llvm::Optional<Value> lowered =
      convertConcatV2Op(rewriter, op, concat_op.getResult(), operands, axis);
  if (!lowered.hasValue()) return failure();

  rewriter.replaceOp(op, {*lowered});
  return success();
}
// Lowers tfl.reshape to tosa.reshape.
//
// The target shape comes from the constant `shape` operand when there is one.
// Otherwise the output type must be ranked, and every dimension starts as -1
// (dynamic). Any -1 entry is then filled from the ranked output type, and at
// most one dimension may remain dynamic.
LogicalResult ConvertTFLReshapeOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_reshape_op = cast<TFL::ReshapeOp>(op);

  ShapedType output_type =
      tfl_reshape_op.getResult().getType().dyn_cast<ShapedType>();
  // Not a shaped tensor output
  if (!output_type) return failure();

  SmallVector<int64_t> shape_vals;

  // Either the output type needs to be ranked or we need a constant input
  // to compute the output rank.
  ElementsAttr shape_attr;
  if (!matchPattern(tfl_reshape_op.shape(), m_Constant(&shape_attr))) {
    if (!output_type.hasRank()) return failure();
    shape_vals.resize(output_type.getRank(), -1);
  } else {
    // Read through APInt so shape constants of any integer width (i32 or i64)
    // are accepted, like the other constant reads in this file.
    // getValues<int32_t>() asserts on i64 data.
    for (APInt dim : shape_attr.getValues<APInt>())
      shape_vals.push_back(dim.getSExtValue());
  }

  // Propagate the agreement between the output shape and constant value.
  if (output_type.hasRank()) {
    if (output_type.getRank() != static_cast<int64_t>(shape_vals.size()))
      return failure();
    for (int64_t i = 0, e = output_type.getRank(); i < e; ++i) {
      if (shape_vals[i] == -1) shape_vals[i] = output_type.getDimSize(i);
    }
  }

  // We cannot handle more than 1 dynamic dimension.
  if (llvm::count(shape_vals, -1) > 1) return failure();

  ArrayAttr new_shape_attr = rewriter.getI64ArrayAttr(shape_vals);
  CreateReplaceOpAndInfer<tosa::ReshapeOp>(
      rewriter, op, output_type, tfl_reshape_op.input(), new_shape_attr);
  return success();
}
// Lowers tfl.rank of a ranked input to a tosa.const holding the rank as a
// single-element i32 tensor.
LogicalResult ConvertTFLRankOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto rank_op = cast<TFL::RankOp>(op);

  RankedTensorType ranked_input =
      rank_op.input().getType().dyn_cast<RankedTensorType>();
  // The rank is only known statically for ranked inputs.
  if (!ranked_input) return failure();

  int32_t rank = ranked_input.getRank();
  RankedTensorType const_type =
      RankedTensorType::get({1}, rewriter.getIntegerType(32));
  auto const_value = DenseElementsAttr::get(const_type, {rank});

  Value rank_value = CreateOpAndInfer<tosa::ConstOp>(rewriter, op->getLoc(),
                                                     const_type, const_value)
                         .getResult();
  rewriter.replaceOp(op, {rank_value});
  return success();
}
// Lowers tfl.shape of a statically shaped input to a tosa.const holding the
// input's dimensions. The constant uses the result's integer element type, so
// an i64 result gets an i64 constant; previously the constant was always i32.
LogicalResult ConvertTFLShapeOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_shape_op = cast<TFL::ShapeOp>(op);

  RankedTensorType output_type =
      tfl_shape_op.getResult().getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!output_type) return failure();

  RankedTensorType input_type =
      tfl_shape_op.input().getType().dyn_cast<RankedTensorType>();
  if (!input_type || !input_type.hasStaticShape())
    return rewriter.notifyMatchFailure(op, "input shape not static");

  // Build the dimensions at the result's own integer width, so the constant
  // type agrees with the op being replaced.
  auto shape_ety = output_type.getElementType().dyn_cast<IntegerType>();
  if (!shape_ety)
    return rewriter.notifyMatchFailure(op, "output element type not integer");

  SmallVector<APInt> shape_arr;
  for (int64_t dim : input_type.getShape())
    shape_arr.emplace_back(shape_ety.getWidth(), dim, /*isSigned=*/true);

  RankedTensorType shape_type = RankedTensorType::get(
      {static_cast<int64_t>(shape_arr.size())}, shape_ety);
  auto shape_attr =
      DenseElementsAttr::get(shape_type, llvm::makeArrayRef(shape_arr));
  auto shape_const = CreateOpAndInfer<tosa::ConstOp>(rewriter, op->getLoc(),
                                                     shape_type, shape_attr);

  rewriter.replaceOp(op, {shape_const.getResult()});

  return success();
}
// Lowers tfl.expand_dims through the shared convertExpandDimsOp helper. The
// `dim` operand is forwarded unchanged.
LogicalResult ConvertTFLExpandDimsOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto expand_op = cast<TFL::ExpandDimsOp>(op);

  llvm::Optional<Value> expanded = convertExpandDimsOp(
      rewriter, op, expand_op.getResult(), expand_op.input(), expand_op.dim());
  if (!expanded.hasValue()) return failure();

  rewriter.replaceOp(op, {*expanded});
  return success();
}
// Lowers tfl.squeeze through the shared convertSqueezeOp helper, after
// converting the squeeze_dims array attribute to plain int32 values.
LogicalResult ConvertTFLSqueezeOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto squeeze_op = cast<TFL::SqueezeOp>(op);

  SmallVector<int32_t> dims;
  for (Attribute dim_attr : squeeze_op.squeeze_dimsAttr())
    dims.push_back(dim_attr.dyn_cast<IntegerAttr>().getInt());

  llvm::Optional<Value> squeezed = convertSqueezeOp(
      rewriter, op, squeeze_op.getResult(), squeeze_op.input(), dims);
  if (!squeezed.hasValue()) return failure();

  rewriter.replaceOp(op, {*squeezed});
  return success();
}
// Lowers tfl.fill with constant `dims` and a constant scalar value to a
// tosa.const splat of that value, shaped by `dims`.
//
// The splat is built at the value's own element type. The previous path went
// through float/int32_t host arrays, which only matched f32/i32 and asserted
// for other widths (i64, i8, i1, f16, ...).
LogicalResult ConvertTFLFillOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_fill_op = cast<TFL::FillOp>(op);

  RankedTensorType output_type =
      tfl_fill_op.getResult().getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!output_type) return failure();

  ElementsAttr dims_elems;
  if (!matchPattern(tfl_fill_op.dims(), m_Constant(&dims_elems)))
    return failure();
  SmallVector<int64_t> dims_vals;
  for (APInt dim : dims_elems.getValues<APInt>()) {
    // A negative extent cannot describe a statically shaped constant.
    if (dim.isNegative()) return failure();
    dims_vals.push_back(dim.getSExtValue());
  }

  ElementsAttr value_elem;
  if (!matchPattern(tfl_fill_op.input(), m_Constant(&value_elem)))
    return failure();
  // The fill value must be exactly one scalar element.
  if (value_elem.getNumElements() != 1) return failure();

  Type fill_ety = value_elem.getType().getElementType();
  RankedTensorType fill_type = RankedTensorType::get(dims_vals, fill_ety);

  // A single-element APFloat/APInt array produces a splat DenseElementsAttr.
  DenseElementsAttr fill_attr;
  if (fill_ety.isa<FloatType>()) {
    APFloat fill_value = value_elem.getValues<APFloat>()[0];
    fill_attr =
        DenseElementsAttr::get(fill_type, llvm::makeArrayRef(fill_value));
  } else if (fill_ety.isIntOrIndex()) {
    APInt fill_value = value_elem.getValues<APInt>()[0];
    fill_attr =
        DenseElementsAttr::get(fill_type, llvm::makeArrayRef(fill_value));
  } else {
    return failure();
  }

  auto fill_const_op = CreateOpAndInfer<tosa::ConstOp>(rewriter, op->getLoc(),
                                                       fill_type, fill_attr);
  rewriter.replaceOp(op, {fill_const_op.getResult()});

  return success();
}
// Lowers tfl.reduce_any through convertReduceAnyOp. Requires a ranked result
// and constant reduction indices.
LogicalResult ConvertTFLReduceAnyOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto any_op = cast<TFL::ReduceAnyOp>(op);
  RankedTensorType result_type =
      any_op.getResult().getType().dyn_cast<RankedTensorType>();

  ElementsAttr axes;
  if (!result_type ||
      !matchPattern(any_op.reduction_indices(), m_Constant(&axes)))
    return failure();

  llvm::Optional<Value> reduced =
      convertReduceAnyOp(rewriter, op, result_type, any_op.input(), axes);
  if (!reduced.hasValue()) return failure();

  rewriter.replaceOp(op, {*reduced});
  return success();
}
// Lowers tfl.reduce_max through convertReduceMaxOp. Requires a ranked result
// and constant axes.
LogicalResult ConvertTFLReduceMaxOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto max_op = cast<TFL::ReduceMaxOp>(op);
  RankedTensorType result_type =
      max_op.getResult().getType().dyn_cast<RankedTensorType>();

  ElementsAttr axes;
  if (!result_type || !matchPattern(max_op.axes(), m_Constant(&axes)))
    return failure();

  llvm::Optional<Value> reduced =
      convertReduceMaxOp(rewriter, op, result_type, max_op.input(), axes);
  if (!reduced.hasValue()) return failure();

  rewriter.replaceOp(op, {*reduced});
  return success();
}
// Lowers tfl.reduce_min through convertReduceMinOp. Requires a ranked result
// and constant axes.
LogicalResult ConvertTFLReduceMinOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto min_op = cast<TFL::ReduceMinOp>(op);
  RankedTensorType result_type =
      min_op.getResult().getType().dyn_cast<RankedTensorType>();

  ElementsAttr axes;
  if (!result_type || !matchPattern(min_op.axes(), m_Constant(&axes)))
    return failure();

  llvm::Optional<Value> reduced =
      convertReduceMinOp(rewriter, op, result_type, min_op.input(), axes);
  if (!reduced.hasValue()) return failure();

  rewriter.replaceOp(op, {*reduced});
  return success();
}
// Lowers tfl.reduce_prod through convertReduceProdOp. Requires a ranked
// result and constant axes.
LogicalResult ConvertTFLReduceProdOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto prod_op = cast<TFL::ReduceProdOp>(op);
  RankedTensorType result_type =
      prod_op.getResult().getType().dyn_cast<RankedTensorType>();

  ElementsAttr axes;
  if (!result_type || !matchPattern(prod_op.axes(), m_Constant(&axes)))
    return failure();

  llvm::Optional<Value> reduced =
      convertReduceProdOp(rewriter, op, result_type, prod_op.input(), axes);
  if (!reduced.hasValue()) return failure();

  rewriter.replaceOp(op, {*reduced});
  return success();
}
// Lowers tfl.mean through convertReduceMeanOp. Requires a ranked result and a
// constant `axis` operand.
LogicalResult ConvertTFLMeanOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto mean_op = cast<TFL::MeanOp>(op);
  RankedTensorType result_type =
      mean_op.getResult().getType().dyn_cast<RankedTensorType>();

  ElementsAttr axes;
  if (!result_type || !matchPattern(mean_op.axis(), m_Constant(&axes)))
    return failure();

  llvm::Optional<Value> reduced =
      convertReduceMeanOp(rewriter, op, result_type, mean_op.input(), axes);
  if (!reduced.hasValue()) return failure();

  rewriter.replaceOp(op, {*reduced});
  return success();
}
// Lowers tfl.sum through convertReduceSumOp. Requires a ranked result and
// constant axes.
LogicalResult ConvertTFLSumOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto sum_op = cast<TFL::SumOp>(op);
  RankedTensorType result_type =
      sum_op.getResult().getType().dyn_cast<RankedTensorType>();

  ElementsAttr axes;
  if (!result_type || !matchPattern(sum_op.axes(), m_Constant(&axes)))
    return failure();

  llvm::Optional<Value> reduced =
      convertReduceSumOp(rewriter, op, result_type, sum_op.input(), axes);
  if (!reduced.hasValue()) return failure();

  rewriter.replaceOp(op, {*reduced});
  return success();
}
// Lowers tfl.elu through the shared convertEluOp helper.
LogicalResult ConvertTFLEluOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto elu_op = cast<TFL::EluOp>(op);

  if (llvm::Optional<Value> lowered =
          convertEluOp(rewriter, op, elu_op.getResult(), elu_op.x())) {
    rewriter.replaceOp(op, {*lowered});
    return success();
  }
  return failure();
}
// Lowers tfl.softmax through convertSoftmaxOp, passing the op's beta
// attribute as a double.
LogicalResult ConvertTFLSoftmaxOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto softmax_op = cast<TFL::SoftmaxOp>(op);
  double beta = softmax_op.betaAttr().getValueAsDouble();

  llvm::Optional<Value> lowered = convertSoftmaxOp(
      rewriter, op, softmax_op.getResult(), softmax_op.input(), beta);
  if (!lowered.hasValue()) return failure();

  rewriter.replaceOp(op, {*lowered});
  return success();
}
// Lowers tfl.sqrt as reciprocal(rsqrt(x)) using tosa.rsqrt followed by
// tosa.reciprocal.
LogicalResult ConvertTFLSqrtOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto sqrt_op = cast<TFL::SqrtOp>(op);

  Value inv_sqrt = CreateOpAndInfer<tosa::RsqrtOp>(rewriter, op->getLoc(),
                                                   sqrt_op.getType(),
                                                   sqrt_op.x())
                       .getResult();
  CreateReplaceOpAndInfer<tosa::ReciprocalOp>(rewriter, op, inv_sqrt.getType(),
                                              inv_sqrt);
  return success();
}
// Lowers tfl.l2_normalization for ranked f32 inputs as
//   input * rsqrt(max(reduce_sum(input * input, last_axis), sqrt(FLT_MIN)))
// followed by the fused activation, if any. Other inputs are not matched.
LogicalResult ConvertTFLL2NormalizationOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto l2norm_op = cast<TFL::L2NormalizationOp>(op);
  Value input = l2norm_op.input();
  auto input_ty = input.getType().cast<ShapedType>();
  Location loc = op->getLoc();

  // Only ranked f32 inputs are handled.
  if (!input_ty.hasRank() || !input_ty.getElementType().isF32())
    return failure();

  IntegerAttr no_shift = rewriter.getIntegerAttr(rewriter.getI32Type(), 0);
  UnrankedTensorType unranked_ty =
      UnrankedTensorType::get(input_ty.getElementType());

  auto squared = CreateOpAndInfer<tosa::MulOp>(rewriter, loc, unranked_ty,
                                               input, input, no_shift);
  auto sum_sq = CreateOpAndInfer<tosa::ReduceSumOp>(
      rewriter, loc, unranked_ty, squared,
      rewriter.getI64IntegerAttr(input_ty.getRank() - 1));

  // Lower bound applied to the sum of squares before the rsqrt.
  SmallVector<float> floor_vals(1, sqrt(std::numeric_limits<float>::min()));
  Value floor_val =
      getConstTensor<float>(rewriter, op, floor_vals, {}).getValue();
  auto clamped = CreateOpAndInfer<tosa::MaximumOp>(rewriter, loc, unranked_ty,
                                                   sum_sq, floor_val);

  Value inv_norm =
      CreateOpAndInfer<tosa::RsqrtOp>(rewriter, loc, unranked_ty, clamped)
          .getResult();
  Value normalized = CreateOpAndInfer<tosa::MulOp>(rewriter, loc, unranked_ty,
                                                   inv_norm, input, no_shift)
                         .getResult();

  if (auto act_fn = l2norm_op.fused_activation_functionAttr()) {
    llvm::Optional<Value> activated =
        convertFusedActivation(rewriter, op, normalized, act_fn);
    if (!activated) return failure();
    normalized = activated.getValue();
  }

  rewriter.replaceOp(op, normalized);
  return success();
}
// Lowers tfl.log_softmax through the shared convertLogSoftmaxOp helper.
LogicalResult ConvertTFLLogSoftmaxOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto log_softmax_op = cast<TFL::LogSoftmaxOp>(op);

  llvm::Optional<Value> lowered = convertLogSoftmaxOp(
      rewriter, op, log_softmax_op.getResult(), log_softmax_op.input());
  if (!lowered.hasValue()) return failure();

  rewriter.replaceOp(op, {*lowered});
  return success();
}
// Lowers tfl.slice with constant `begin` and `size` operands to tosa.slice.
//
// TFLite allows size -1 to mean "from begin to the end of this dimension",
// but tosa.slice takes explicit extents. Such entries are resolved from the
// input type when that dimension is static. Otherwise -1 is forwarded
// unchanged, as before.
LogicalResult ConvertTFLSliceOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_slice_op = cast<TFL::SliceOp>(op);

  ShapedType output_type =
      tfl_slice_op.getResult().getType().dyn_cast<ShapedType>();
  // Not a shaped tensor output
  if (!output_type) return failure();

  ElementsAttr begin_elems, size_elems;

  SmallVector<int64_t> begin_vals, size_vals;

  if (!matchPattern(tfl_slice_op.begin(), m_Constant(&begin_elems)) ||
      !matchPattern(tfl_slice_op.size(), m_Constant(&size_elems))) {
    return failure();
  }

  for (int i = 0; i < begin_elems.getNumElements(); i++)
    begin_vals.push_back(begin_elems.getValues<APInt>()[i].getSExtValue());

  auto input_type =
      tfl_slice_op.input().getType().dyn_cast<RankedTensorType>();
  for (int i = 0; i < size_elems.getNumElements(); i++) {
    int64_t size = size_elems.getValues<APInt>()[i].getSExtValue();
    if (size == -1 && input_type && i < input_type.getRank() &&
        static_cast<size_t>(i) < begin_vals.size() &&
        !input_type.isDynamicDim(i)) {
      size = input_type.getDimSize(i) - begin_vals[i];
    }
    size_vals.push_back(size);
  }

  ArrayAttr begin = rewriter.getI64ArrayAttr(begin_vals);
  ArrayAttr size = rewriter.getI64ArrayAttr(size_vals);

  CreateReplaceOpAndInfer<tosa::SliceOp>(rewriter, op, output_type,
                                         tfl_slice_op.input(), begin, size);
  return success();
}
// Lowers tfl.tile to tosa.tile. The `multiples` operand must be a constant;
// its values become the tosa.tile multiples attribute.
LogicalResult ConvertTFLTileOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tile_op = cast<TFL::TileOp>(op);

  ShapedType result_type =
      tile_op.getResult().getType().dyn_cast<ShapedType>();
  if (!result_type) return failure();

  ElementsAttr multiples_const;
  if (!matchPattern(tile_op.multiples(), m_Constant(&multiples_const)))
    return failure();

  SmallVector<int64_t> multiples;
  for (APInt multiple : multiples_const.getValues<APInt>())
    multiples.push_back(multiple.getSExtValue());

  CreateReplaceOpAndInfer<tosa::TileOp>(rewriter, op, result_type,
                                        tile_op.input(),
                                        rewriter.getI64ArrayAttr(multiples));
  return success();
}
// Lowers tfl.transpose to tosa.transpose. The permutation operand is
// forwarded as-is.
LogicalResult ConvertTFLTransposeOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto transpose_op = cast<TFL::TransposeOp>(op);
  CreateReplaceOpAndInfer<tosa::TransposeOp>(
      rewriter, op, transpose_op.getResult().getType(), transpose_op.input(),
      transpose_op.perm());
  return success();
}
// Lowers tfl.pack through the shared convertPackOp helper. A missing axis
// attribute is treated as axis 0; at least one input is asserted.
LogicalResult ConvertTFLPackOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto pack_op = cast<TFL::PackOp>(op);

  SmallVector<Value> inputs(pack_op.values());
  assert(!inputs.empty());

  IntegerAttr axis_attr = pack_op.axisAttr();
  if (!axis_attr) axis_attr = rewriter.getI64IntegerAttr(0);
  int32_t axis = axis_attr.getInt();

  llvm::Optional<Value> packed =
      convertPackOp(rewriter, op, pack_op.getResult(), inputs, axis);
  if (!packed.hasValue()) return failure();

  rewriter.replaceOp(op, {*packed});
  return success();
}
// Lowers tfl.unpack through the shared convertUnpackOp helper and replaces
// every result of the op with the helper's outputs. A missing axis attribute
// is treated as axis 0.
LogicalResult ConvertTFLUnpackOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto unpack_op = cast<TFL::UnpackOp>(op);

  IntegerAttr axis_attr = unpack_op.axisAttr();
  if (!axis_attr) axis_attr = rewriter.getI64IntegerAttr(0);
  int32_t axis = axis_attr.getInt();

  llvm::Optional<SmallVector<Value>> pieces =
      convertUnpackOp(rewriter, op, unpack_op.input(), axis);
  if (!pieces.hasValue()) return failure();

  rewriter.replaceOp(op, *pieces);
  return success();
}
// Lowers tfl.split: splits the input value into `num_splits` parts along the
// (constant) split_dim axis using the shared convertSplitOp helper.
LogicalResult ConvertTFLSplitOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto split = cast<TFL::SplitOp>(op);

  // The number of pieces is mandatory; without it there is nothing to lower.
  auto num_splits_attr = split.num_splitsAttr();
  if (!num_splits_attr) return failure();
  const int32_t num_split = num_splits_attr.getInt();

  // TFLite MLIR stores split_dim as a 0-D tensor operand rather than an
  // integer attribute, so it has to fold to a constant.
  ElementsAttr split_dim_elems;
  if (!matchPattern(split.split_dim(), m_Constant(&split_dim_elems)))
    return op->emitOpError("Cannot read split_dim elems");
  const int32_t axis = split_dim_elems.getValues<APInt>()[0].getSExtValue();

  llvm::Optional<SmallVector<Value>> pieces = convertSplitOp(
      rewriter, op, split.getResult(0), split.value(), num_split, axis);
  if (!pieces) return failure();

  rewriter.replaceOp(op, pieces.getValue());
  return success();
}
// Lowers tfl.split_v: splits the input value along the (constant) split_dim
// axis into pieces whose sizes come from the constant size_splits operand.
LogicalResult ConvertTFLSplitVOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto splitv = cast<TFL::SplitVOp>(op);

  // Piece sizes must be known at compile time.
  ElementsAttr sizes_elems;
  if (!matchPattern(splitv.size_splits(), m_Constant(&sizes_elems)))
    return failure();
  SmallVector<int32_t> size_split;
  for (const APInt& piece_size : sizes_elems.getValues<APInt>())
    size_split.push_back(piece_size.getSExtValue());

  // TFLite MLIR stores split_dim as a 0-D tensor operand rather than an
  // integer attribute, so it has to fold to a constant.
  ElementsAttr split_dim_elems;
  if (!matchPattern(splitv.split_dim(), m_Constant(&split_dim_elems)))
    return op->emitOpError("Cannot read split_dim elems");
  const int32_t axis = split_dim_elems.getValues<APInt>()[0].getSExtValue();

  llvm::Optional<SmallVector<Value>> pieces = convertSplitVOp(
      rewriter, op, splitv.getResult(0), splitv.value(), size_split, axis);
  if (!pieces) return failure();

  rewriter.replaceOp(op, pieces.getValue());
  return success();
}
// Lowers tfl.pad to tosa.pad with the same input and paddings operands.
LogicalResult ConvertTFLPadOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto pad = cast<TFL::PadOp>(op);
  ShapedType result_ty = pad.getResult().getType().dyn_cast<ShapedType>();
  // Only shaped results are handled.
  if (!result_ty) return failure();

  Value source = pad.input();
  Value paddings = pad.padding();
  auto tosa_pad = CreateOpAndInfer<tosa::PadOp>(rewriter, op->getLoc(),
                                                result_ty, source, paddings);
  rewriter.replaceOp(op, {tosa_pad.getResult()});
  return success();
}
// Lowers tfl.padv2 to tosa.pad. Unlike tfl.pad, the explicit pad value is
// forwarded as the third tosa.pad operand.
LogicalResult ConvertTFLPadV2Op::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto padv2 = cast<TFL::PadV2Op>(op);
  CreateReplaceOpAndInfer<tosa::PadOp>(rewriter, op, padv2.getType(),
                                       padv2.input(), padv2.padding(),
                                       padv2.constant_values());
  return success();
}
// Lowers tfl.resize_bilinear through the shared convertResizeOp helper in
// "BILINEAR" mode, forwarding the align_corners / half_pixel_centers flags.
LogicalResult ConvertTFLResizeBilinearOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto resize = cast<TFL::ResizeBilinearOp>(op);
  RankedTensorType result_ty =
      resize.getResult().getType().dyn_cast<RankedTensorType>();
  // The resize helper needs a ranked result type.
  if (!result_ty) return failure();

  const bool align_corners = resize.align_cornersAttr().getValue();
  const bool half_pixel_centers = resize.half_pixel_centersAttr().getValue();
  llvm::Optional<Value> resized =
      convertResizeOp(rewriter, op, result_ty, resize.input(),
                      StringRef("BILINEAR"), align_corners, half_pixel_centers);
  if (!resized) return failure();

  rewriter.replaceOp(op, {resized.getValue()});
  return success();
}
// Lowers tfl.resize_nearest_neighbor through the shared convertResizeOp helper
// in "NEAREST_NEIGHBOR" mode, forwarding the align_corners /
// half_pixel_centers flags.
LogicalResult ConvertTFLResizeNearestNeighborOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto resize = cast<TFL::ResizeNearestNeighborOp>(op);
  RankedTensorType result_ty =
      resize.getResult().getType().dyn_cast<RankedTensorType>();
  // The resize helper needs a ranked result type.
  if (!result_ty) return failure();

  const bool align_corners = resize.align_cornersAttr().getValue();
  const bool half_pixel_centers = resize.half_pixel_centersAttr().getValue();
  llvm::Optional<Value> resized = convertResizeOp(
      rewriter, op, result_ty, resize.input(), StringRef("NEAREST_NEIGHBOR"),
      align_corners, half_pixel_centers);
  if (!resized) return failure();

  rewriter.replaceOp(op, {resized.getValue()});
  return success();
}
// Lowers tfl.select via the shared convertSelectOp helper.
LogicalResult ConvertTFLSelectOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto select = cast<TFL::SelectOp>(op);
  Value cond = select.condition();
  Value x_val = select.x();
  Value y_val = select.y();

  llvm::Optional<Value> lowered =
      convertSelectOp(rewriter, op, select.getResult(), cond, x_val, y_val);
  if (!lowered) return failure();

  rewriter.replaceOp(op, {lowered.getValue()});
  return success();
}
// Lowers tfl.select_v2 via the same convertSelectOp helper used for
// tfl.select.
LogicalResult ConvertTFLSelectV2Op::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto select = cast<TFL::SelectV2Op>(op);
  Value cond = select.condition();
  Value x_val = select.x();
  Value y_val = select.y();

  llvm::Optional<Value> lowered =
      convertSelectOp(rewriter, op, select.getResult(), cond, x_val, y_val);
  if (!lowered) return failure();

  rewriter.replaceOp(op, {lowered.getValue()});
  return success();
}
// Lowers tfl.space_to_batch_nd via the shared convertSpaceToBatchNDOp helper.
LogicalResult ConvertTFLSpaceToBatchNdOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto s2b = cast<TFL::SpaceToBatchNdOp>(op);
  Value source = s2b.input();
  Value block_shape = s2b.block_shape();
  Value paddings = s2b.paddings();

  llvm::Optional<Value> lowered = convertSpaceToBatchNDOp(
      rewriter, op, s2b.getResult(), source, block_shape, paddings);
  if (!lowered) return failure();

  rewriter.replaceOp(op, {lowered.getValue()});
  return success();
}
// Lowers tfl.batch_to_space_nd via the shared convertBatchToSpaceNDOp helper.
LogicalResult ConvertTFLBatchToSpaceNdOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto b2s = cast<TFL::BatchToSpaceNdOp>(op);
  Value source = b2s.input();
  Value block_shape = b2s.block_shape();
  Value indices = b2s.indices();

  llvm::Optional<Value> lowered = convertBatchToSpaceNDOp(
      rewriter, op, b2s.getResult(), source, block_shape, indices);
  if (!lowered) return failure();

  rewriter.replaceOp(op, {lowered.getValue()});
  return success();
}
// Lowers tfl.space_to_depth via the shared convertSpaceToDepthOp helper. The
// data layout is always passed as "NHWC".
LogicalResult ConvertTFLSpaceToDepthOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto s2d = cast<TFL::SpaceToDepthOp>(op);
  auto block_size = s2d.block_sizeAttr();
  StringAttr layout = rewriter.getStringAttr("NHWC");

  llvm::Optional<Value> lowered = convertSpaceToDepthOp(
      rewriter, op, s2d.getResult(), s2d.input(), block_size, layout);
  if (!lowered) return failure();

  rewriter.replaceOp(op, {lowered.getValue()});
  return success();
}
// Lowers tfl.depth_to_space via the shared convertDepthToSpaceOp helper. The
// data layout is always passed as "NHWC".
LogicalResult ConvertTFLDepthToSpaceOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto d2s = cast<TFL::DepthToSpaceOp>(op);
  auto block_size = d2s.block_sizeAttr();
  StringAttr layout = rewriter.getStringAttr("NHWC");

  llvm::Optional<Value> lowered = convertDepthToSpaceOp(
      rewriter, op, d2s.getResult(), d2s.input(), block_size, layout);
  if (!lowered) return failure();

  rewriter.replaceOp(op, {lowered.getValue()});
  return success();
}
// Lowers tfl.strided_slice via the shared convertStridedSliceOp helper. The
// five bitmask attributes are unpacked to integers before the call.
LogicalResult ConvertTFLStridedSliceOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto slice = cast<TFL::StridedSliceOp>(op);

  const int64_t begin_mask = slice.begin_maskAttr().getInt();
  const int64_t end_mask = slice.end_maskAttr().getInt();
  const int64_t ellipsis_mask = slice.ellipsis_maskAttr().getInt();
  const int64_t new_axis_mask = slice.new_axis_maskAttr().getInt();
  const int64_t shrink_axis_mask = slice.shrink_axis_maskAttr().getInt();

  llvm::Optional<Value> lowered = convertStridedSliceOp(
      rewriter, op, slice.getResult(), slice.input(), slice.begin(),
      slice.end(), slice.strides(), begin_mask, end_mask, ellipsis_mask,
      new_axis_mask, shrink_axis_mask);
  if (!lowered) return failure();

  rewriter.replaceOp(op, {lowered.getValue()});
  return success();
}
// Lowers tfl.zeros_like via the shared convertZerosLikeOp helper.
LogicalResult ConvertTFLZerosLikeOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto zeros_like = cast<TFL::ZerosLikeOp>(op);
  Value source = zeros_like.input();

  llvm::Optional<Value> lowered =
      convertZerosLikeOp(rewriter, op, zeros_like.getResult(), source);
  if (!lowered) return failure();

  rewriter.replaceOp(op, {lowered.getValue()});
  return success();
}
// Lowers TFL hard_swish, f(x) = (x * relu6(x + 3)) / 6.
//
// - Quantized (input and output both per-tensor uniform quantized, 8-bit
//   storage): a single tosa.table lookup generated from the formula above.
// - Floating point: an explicit add / clamp / mul / reciprocal / mul sequence.
//
// Mixed quantized/float signatures, non-uniform quantized types and quantized
// storage widths other than 8 bits fail to match and leave the op untouched.
LogicalResult ConvertTFLHardSwishOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_hardswish_op = cast<TFL::HardSwishOp>(op);
  RankedTensorType output_type =
      tfl_hardswish_op.getResult().getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!output_type) return failure();
  RankedTensorType input_type =
      tfl_hardswish_op.input().getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor input
  if (!input_type) return failure();

  bool input_is_qtype =
      input_type.getElementType().isa<mlir::quant::QuantizedType>();
  bool output_is_qtype =
      output_type.getElementType().isa<mlir::quant::QuantizedType>();
  // A mixed signature would send a quantized tensor through the f32 lowering
  // below (or vice versa) and produce ill-typed TOSA IR.
  if (input_is_qtype != output_is_qtype) {
    return rewriter.notifyMatchFailure(
        op,
        "ConvertTFLHardSwishOp: input/output tensor should be all quantized "
        "or all floating-point.");
  }

  if (input_is_qtype) {
    // Should match TFLite reference numerical behavior
    mlir::quant::UniformQuantizedType input_qtype =
        input_type.getElementType()
            .dyn_cast<mlir::quant::UniformQuantizedType>();
    mlir::quant::UniformQuantizedType output_qtype =
        output_type.getElementType()
            .dyn_cast<mlir::quant::UniformQuantizedType>();
    // The table needs a single scale/zero-point pair on each side.
    if (!input_qtype || !output_qtype) {
      return rewriter.notifyMatchFailure(
          op,
          "ConvertTFLHardSwishOp: input/output must be uniformly quantized.");
    }
    // Only the 8-bit table lookup is implemented. Other widths must report a
    // match failure rather than success without rewriting the op.
    if (input_qtype.getStorageTypeIntegralWidth() != 8) {
      return rewriter.notifyMatchFailure(
          op, "ConvertTFLHardSwishOp: only 8-bit quantized input is supported.");
    }

    auto hardswish_func = [](double v) -> double {
      double w = v + 3.0;
      w = w < 0.0 ? 0.0 : w > 6.0 ? 6.0 : w;
      return v * w / 6.0;
    };
    // Implement with 8-bit table lookup.
    Value table_const = getTosaConst8bitTable(
        rewriter, op, input_qtype.getScale(), input_qtype.getZeroPoint(),
        output_qtype.getScale(), output_qtype.getZeroPoint(), hardswish_func);
    CreateReplaceOpAndInfer<tosa::TableOp>(
        rewriter, op, output_type, tfl_hardswish_op.input(), table_const);
    return success();
  }

  // op1 = constop(3)
  // op2 = add(x, op1)
  // op3 = clamp(op2, 0, 6)
  // op4 = mul(x, op3)
  // op5 = reciprocal(6)
  // op6 = mul (op4, op5)
  Value op1_value = getTosaConstTensorSingleF32(rewriter, op, 3.0);
  auto op2_add_x_op1 =
      CreateOpAndInfer<tosa::AddOp>(rewriter, op->getLoc(), output_type,
                                    tfl_hardswish_op.input(), op1_value);
  auto op3_relu_op2_6 = CreateOpAndInfer<tosa::ClampOp>(
      rewriter, op->getLoc(), output_type, op2_add_x_op1.getResult(),
      rewriter.getI64IntegerAttr(0), rewriter.getI64IntegerAttr(0),
      rewriter.getF32FloatAttr(0.0f), rewriter.getF32FloatAttr(6.0f));
  auto op4_mul_x_op3 = CreateOpAndInfer<tosa::MulOp>(
      rewriter, op->getLoc(), output_type, tfl_hardswish_op.input(),
      op3_relu_op2_6.getResult(), 0);
  auto const_6 = getTosaConstTensorSingleF32(rewriter, op, 6.0);
  auto op5_reciprocal_6 = CreateOpAndInfer<tosa::ReciprocalOp>(
      rewriter, op->getLoc(), const_6.getType(), const_6);
  auto op6_mul_op4_op5 = CreateOpAndInfer<tosa::MulOp>(
      rewriter, op->getLoc(), output_type, op4_mul_x_op3.getResult(),
      op5_reciprocal_6.getResult(), 0);
  rewriter.replaceOp(op, {op6_mul_op4_op5.getResult()});
  return success();
}
// Lowers f32 tfl.sin to an int16 tosa.table lookup.
//
// The input is folded into a single period, mapped onto the int16 table
// domain, looked up in a 513-entry sine table, and the 32-bit table result is
// scaled back to f32.
LogicalResult ConvertTFLSinOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_sin_op = cast<TFL::SinOp>(op);
  Location loc = op->getLoc();
  Value input = tfl_sin_op.x();
  RankedTensorType input_ty = input.getType().dyn_cast<RankedTensorType>();
  ShapedType output_ty =
      tfl_sin_op.getResult().getType().dyn_cast<ShapedType>();
  // Validate the casts before querying element types. An unranked input or a
  // non-shaped result would otherwise be dereferenced as a null type.
  if (!input_ty || !output_ty) return failure();

  Type input_ety = input_ty.getElementType();
  Type output_ety = output_ty.getElementType();
  if (input_ety != output_ety) {
    return rewriter.notifyMatchFailure(
        op, "ConvertTFLSinOp: input/output element type must match");
  }

  bool input_is_fp = input_ety.isF32();
  bool output_is_fp = output_ety.isF32();
  if (!input_is_fp || !output_is_fp) {
    return rewriter.notifyMatchFailure(
        op, "ConvertTFLSinOp: input/result must be fp32.");
  }

  // To perform a sin operation we remap the sin domain to be over a single
  // period of the function, remapping to the domain of the table function.
  // We then remap the range of the table function to map to the range of the
  // sin operation.

  // 1. Normalize the period of the domain from [0, 2*pi) to [0, 1).
  auto fp_scalar_ty = RankedTensorType::get({}, rewriter.getF32Type());
  Value fp_scale = rewriter.create<tosa::ConstOp>(
      loc, fp_scalar_ty,
      DenseElementsAttr::get(fp_scalar_ty, {static_cast<float>(0.5 / M_PI)}));

  // 2. Remap the periodic behavior of the domain to line up within [0, 1).
  Value fp_scaled = CreateOpAndInfer<tosa::MulOp>(
      rewriter, loc, input_ty, input, fp_scale, rewriter.getI32IntegerAttr(0));
  auto floored =
      CreateOpAndInfer<tosa::FloorOp>(rewriter, loc, input_ty, fp_scaled);
  auto repeated = CreateOpAndInfer<tosa::SubOp>(rewriter, loc, input_ty,
                                                fp_scaled, floored);

  // 3. Scale and translate the normalized domain to the table domain. This
  // includes a translating and scaling to [-int16_max, int16_max] and casting
  // to an i16.
  Value one = rewriter.create<tosa::ConstOp>(
      loc, fp_scalar_ty, DenseElementsAttr::get(fp_scalar_ty, {1.0f}));
  Value two = rewriter.create<tosa::ConstOp>(
      loc, fp_scalar_ty, DenseElementsAttr::get(fp_scalar_ty, {2.0f}));
  auto scale_up = CreateOpAndInfer<tosa::MulOp>(
      rewriter, loc, input_ty, repeated, two, rewriter.getI32IntegerAttr(0));
  auto translate =
      CreateOpAndInfer<tosa::SubOp>(rewriter, loc, input_ty, scale_up, one);

  Value int_limit = rewriter.create<tosa::ConstOp>(
      loc, fp_scalar_ty,
      DenseElementsAttr::get(
          fp_scalar_ty,
          {static_cast<float>(std::numeric_limits<int16_t>::max())}));
  auto int_scaled =
      CreateOpAndInfer<tosa::MulOp>(rewriter, loc, input_ty, translate,
                                    int_limit, rewriter.getI32IntegerAttr(0));

  auto int16_ty = input_ty.clone(rewriter.getIntegerType(16));
  auto casted =
      CreateOpAndInfer<tosa::CastOp>(rewriter, loc, int16_ty, int_scaled);

  // 4. Build a 513-entry lookup table sampling sin over one full period:
  // entry i holds int16_max * sin(2*pi * i / 512).
  llvm::SmallVector<int16_t> values;
  const int num_values = 513;
  values.resize(num_values, 0);
  // First and last values should be 0 (sin(0) and sin(2*pi)).
  for (int i = 1; i < num_values - 1; ++i)
    values[i] = std::numeric_limits<int16_t>::max() *
                sin(static_cast<float>(i) * 2.0 * M_PI / (num_values - 1.0));

  auto table_ty =
      RankedTensorType::get({num_values}, rewriter.getIntegerType(16));
  Value table = rewriter.create<tosa::ConstOp>(
      loc, table_ty,
      DenseElementsAttr::get(table_ty, llvm::makeArrayRef(values)));

  auto table_result_ty = input_ty.clone(rewriter.getIntegerType(32));
  auto table_result = CreateOpAndInfer<tosa::TableOp>(
      rewriter, loc, table_result_ty, casted, table);

  // 5. The range of table is a 23-bit two's complement value. Normalize the
  // range by casting to an fp32 and dividing by 2^22.
  auto table_result_fp =
      CreateOpAndInfer<CastOp>(rewriter, loc, input_ty, table_result);
  auto output_scale = rewriter.create<ConstOp>(
      loc, fp_scalar_ty,
      DenseElementsAttr::get(
          fp_scalar_ty,
          {static_cast<float>(1.0 / static_cast<float>(1 << 22))}));

  CreateReplaceOpAndInfer<MulOp>(rewriter, op, output_ty, table_result_fp,
                                 output_scale, rewriter.getI32IntegerAttr(0));
  return success();
}
// Lowers tfl.cos using the identity cos(x) = sin(x + pi / 2). It adds a scalar
// f32 pi/2 constant to the input and re-emits the result as tfl.sin, which is
// lowered separately (see ConvertTFLSinOp).
LogicalResult ConvertTFLCosOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_cos_op = cast<TFL::CosOp>(op);
  Value input = tfl_cos_op.x();
  RankedTensorType input_ty = input.getType().dyn_cast<RankedTensorType>();
  ShapedType output_ty =
      tfl_cos_op.getResult().getType().dyn_cast<ShapedType>();

  if (!input_ty || !output_ty) return failure();

  // Require f32. The pi/2 offset below is an f32 constant, and the tfl.sin
  // lowering only accepts f32. Any other float width would produce a
  // mistyped add and an un-lowerable tfl.sin.
  bool input_is_fp = input_ty.getElementType().isF32();
  bool output_is_fp = output_ty.getElementType().isF32();

  if (!input_is_fp || !output_is_fp) {
    return rewriter.notifyMatchFailure(
        op, "ConvertTFLCosOp: input/result must be fp32.");
  }

  // Replace with the equivalent sin operation:
  //   cos(x) = sin(x + pi / 2).
  auto fp_scalar_ty = RankedTensorType::get({}, rewriter.getF32Type());
  auto pi_2 = rewriter.create<ConstOp>(
      op->getLoc(), fp_scalar_ty,
      DenseElementsAttr::get(fp_scalar_ty, {static_cast<float>(M_PI_2)}));
  auto offset = rewriter.create<AddOp>(op->getLoc(), input_ty, input, pi_2);

  CreateReplaceOpAndInfer<TFL::SinOp>(rewriter, op, output_ty, offset);
  return success();
}
// Lowers tfl.logistic (sigmoid).
//
// - Float: a direct tosa.sigmoid.
// - Quantized with 8-bit storage: an 8-bit tosa.table built from the sigmoid.
// - Any other quantized width (treated as int16): a 16-bit tosa.table
//   producing int32, followed by a 1/128 rescale back to the output type.
//   Both zero points must be 0 in that mode.
LogicalResult ConvertTFLLogisticOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto logistic = cast<TFL::LogisticOp>(op);
  Value x = logistic.x();
  ShapedType result_ty =
      logistic.getResult().getType().dyn_cast<ShapedType>();
  RankedTensorType x_ty = x.getType().dyn_cast<RankedTensorType>();
  if (!x_ty || !result_ty) return failure();

  auto x_qtype =
      x_ty.getElementType().dyn_cast<mlir::quant::UniformQuantizedType>();
  auto result_qtype =
      result_ty.getElementType().dyn_cast<mlir::quant::UniformQuantizedType>();
  if (static_cast<bool>(x_qtype) != static_cast<bool>(result_qtype)) {
    return op->emitOpError(
        "ConvertTFLLogisticOp: input/output tensor should "
        "be all quantized or all floating-point.");
  }

  // Floating point: TOSA has a native sigmoid.
  if (!x_qtype) {
    CreateReplaceOpAndInfer<tosa::SigmoidOp>(rewriter, op, result_ty, x);
    return success();
  }

  auto sigmoid = [](double v) -> double { return 1.0 / (1.0 + std::exp(-v)); };

  if (x_qtype.getStorageTypeIntegralWidth() == 8) {
    Value lut = getTosaConst8bitTable(
        rewriter, op, x_qtype.getScale(), x_qtype.getZeroPoint(),
        result_qtype.getScale(), result_qtype.getZeroPoint(), sigmoid);
    CreateReplaceOpAndInfer<tosa::TableOp>(rewriter, op, result_ty, x, lut);
    return success();
  }

  // int16 path.
  if (x_qtype.getZeroPoint() != 0 || result_qtype.getZeroPoint() != 0) {
    op->emitOpError(
        "ConvertTFLLogistic: input/output zeropoint should be 0 in 16-bit "
        "mode");
    return failure();
  }

  // Table generation mirrors gen_lut() in
  // tensorflow/lite/kernels/internal/common.h.
  double x_min = -32768 * x_qtype.getScale();
  double x_max = 32767 * x_qtype.getScale();
  Value lut = getTosaConst16bitTable(rewriter, op, sigmoid, x_min, x_max);

  ShapedType i32_ty = result_ty.clone(rewriter.getIntegerType(32));
  auto looked_up =
      CreateOpAndInfer<tosa::TableOp>(rewriter, op->getLoc(), i32_ty, x, lut);
  Value rescaled = buildRescale(rewriter, op, result_ty, looked_up.getResult(),
                                1.0 / 128.0, 0, 0, false, true);
  rewriter.replaceOp(op, {rescaled});
  return success();
}
// Lowers tfl.tanh.
//
// - Float: a direct tosa.tanh.
// - Quantized with 8-bit storage: an 8-bit tosa.table built from tanh.
// - Any other quantized width (treated as int16): a 16-bit tosa.table
//   producing int32, followed by a 1/128 rescale back to the output type.
//   Both zero points must be 0 in that mode.
LogicalResult ConvertTFLTanhOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_tanh_op = cast<TFL::TanhOp>(op);
  ShapedType output_type =
      tfl_tanh_op.getResult().getType().dyn_cast<ShapedType>();
  RankedTensorType input_type =
      tfl_tanh_op.input().getType().dyn_cast<RankedTensorType>();
  if (!input_type || !output_type) return failure();

  bool input_is_qtype =
      input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
  bool output_is_qtype =
      output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();

  if (input_is_qtype != output_is_qtype) {
    return op->emitOpError(
        "ConvertTFLTanhOp: input/output tensor should "
        "be all quantized or all floating-point.");
  }

  if (input_is_qtype) {
    ShapedType int32_type = output_type.clone(rewriter.getIntegerType(32));
    mlir::quant::UniformQuantizedType input_qtype =
        input_type.getElementType()
            .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
    mlir::quant::UniformQuantizedType output_qtype =
        output_type.getElementType()
            .dyn_cast_or_null<mlir::quant::UniformQuantizedType>();

    auto tanh_func = [](double x) -> double {
      x = std::exp(-2.0 * x);
      return (1.0 - x) / (1.0 + x);
    };

    if (input_qtype.getStorageTypeIntegralWidth() == 8) {
      Value table_const = getTosaConst8bitTable(
          rewriter, op, input_qtype.getScale(), input_qtype.getZeroPoint(),
          output_qtype.getScale(), output_qtype.getZeroPoint(), tanh_func);

      CreateReplaceOpAndInfer<tosa::TableOp>(rewriter, op, output_type,
                                             tfl_tanh_op.input(), table_const);
    } else {  // int16
      if (input_qtype.getZeroPoint() != 0 || output_qtype.getZeroPoint() != 0) {
        // The message names this op; it previously said ConvertTFLLogistic.
        op->emitOpError(
            "ConvertTFLTanhOp: input/output zeropoint should be 0 in 16-bit "
            "mode");
        return failure();
      }
      double input_min = -32768 * input_qtype.getScale();
      double input_max = 32767 * input_qtype.getScale();

      // Generate table with gen_lut() in
      // tensorflow/lite/kernels/internal/common.h
      Value table_const =
          getTosaConst16bitTable(rewriter, op, tanh_func, input_min, input_max);

      auto op1_table_in = CreateOpAndInfer<tosa::TableOp>(
          rewriter, op->getLoc(), int32_type, tfl_tanh_op.input(), table_const);

      Value op2_rescale_op1 =
          buildRescale(rewriter, op, output_type, op1_table_in.getResult(),
                       1.0 / 128.0, 0, 0, false, true);

      rewriter.replaceOp(op, {op2_rescale_op1});
    }
  } else {
    CreateReplaceOpAndInfer<tosa::TanhOp>(rewriter, op, output_type,
                                          tfl_tanh_op.input());
  }

  return success();
}
// Placeholder for a tfl.prelu lowering. No TOSA lowering is implemented here,
// so the pattern never matches, whether or not the result is a ranked tensor.
LogicalResult ConvertTFLPReluOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto prelu = cast<TFL::PReluOp>(op);
  if (!prelu.getResult().getType().isa<RankedTensorType>()) return failure();
  // TODO: implement PReLU (e.g. select(x >= 0, x, alpha * x)).
  return failure();
}
// Lowers TFL leaky_relu, out = x > 0 ? x : alpha * x, as element-wise TOSA ops.
//
// Floating point:
//   const_zero = constant(0)
//   op1 = mul(x, alpha)
//   op2 = greater_equal(x, const_zero)
//   out = select(op2, x, op1)
//
// Quantized (input and output both uniformly quantized):
//   op1 = rescale(x) to int32 (scale 1.0, input zero point)
//   op2 = greater_equal(op1, 0)
//   op3 = rescale(x) to the output type, scale in_scale * alpha / out_scale
//   op4 = rescale(x) to the output type, scale in_scale / out_scale
//   out = select(op2, op4, op3)
//
// If alpha could be constrained to 0.0 <= alpha <= 1.0, the simpler lowering
// max(mul(x, alpha), x) would work, but it is not correct for arbitrary alpha.
LogicalResult ConvertTFLLeakyReluOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_leakyrelu_op = cast<TFL::LeakyReluOp>(op);
  RankedTensorType input_type =
      tfl_leakyrelu_op.input().getType().dyn_cast<RankedTensorType>();

  ShapedType output_type =
      tfl_leakyrelu_op.getResult().getType().dyn_cast<ShapedType>();

  if (!input_type || !output_type) return failure();

  bool input_is_qtype =
      input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
  bool output_is_qtype =
      output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();

  // The quantized path casts BOTH element types to UniformQuantizedType and the
  // float path multiplies by an f32 constant, so mixed signatures cannot be
  // lowered.
  if (input_is_qtype != output_is_qtype) {
    return op->emitOpError(
        "ConvertTFLLeakyReluOp: input/output tensor should "
        "be all quantized or all floating-point.");
  }

  FloatAttr tmpAttr = tfl_leakyrelu_op.alphaAttr();
  // There is disagreement between the MLIR .td defaults and TF
  // documentation on 0.2 vs 0.3, but 0.2 will be used here.
  double alpha = 0.2;

  if (tmpAttr) {
    alpha = tmpAttr.getValueAsDouble();
  }

  if (output_is_qtype) {
    ShapedType rescale_type = output_type.clone(rewriter.getI32Type());

    UniformQuantizedType input_qtype =
        input_type.getElementType().cast<UniformQuantizedType>();

    UniformQuantizedType output_qtype =
        output_type.getElementType().cast<UniformQuantizedType>();

    double scale_alpha =
        input_qtype.getScale() * alpha / output_qtype.getScale();
    double scale_identity = input_qtype.getScale() / output_qtype.getScale();

    Value op1_rescale_in =
        buildRescaleToInt32(rewriter, op, tfl_leakyrelu_op.input(), 1.0,
                            input_qtype.getZeroPoint());

    Value const_zero = getTosaConstTensorSingleI32(rewriter, op, 0);
    auto op2_ge = CreateOpAndInfer<tosa::GreaterEqualOp>(
        rewriter, op->getLoc(), rescale_type.clone(rewriter.getI1Type()),
        op1_rescale_in, const_zero);

    Value op3_rescale_alpha_in = buildRescale(
        rewriter, op, output_type, tfl_leakyrelu_op.input(), scale_alpha,
        input_qtype.getZeroPoint(), output_qtype.getZeroPoint(), true, true);

    Value op4_rescale_identity_in = buildRescale(
        rewriter, op, output_type, tfl_leakyrelu_op.input(), scale_identity,
        input_qtype.getZeroPoint(), output_qtype.getZeroPoint(), true, true);

    CreateReplaceOpAndInfer<tosa::SelectOp>(rewriter, op, output_type, op2_ge,
                                            op4_rescale_identity_in,
                                            op3_rescale_alpha_in);

    return success();
  }

  Value const_zero = getTosaConstTensorSingleF32(rewriter, op, 0.0);

  auto op1_mul = CreateOpAndInfer<tosa::MulOp>(
      rewriter, op->getLoc(), output_type, tfl_leakyrelu_op.input(),
      getTosaConstTensorSingleF32(rewriter, op, alpha), 0);

  auto op2_ge = CreateOpAndInfer<tosa::GreaterEqualOp>(
      rewriter, op->getLoc(), output_type.clone(rewriter.getIntegerType(1)),
      tfl_leakyrelu_op.input(), const_zero);

  CreateReplaceOpAndInfer<tosa::SelectOp>(rewriter, op, output_type, op2_ge,
                                          tfl_leakyrelu_op.input(),
                                          op1_mul.getResult());

  return success();
}
// Lowers tfl.neg to tosa.negate for shaped results.
LogicalResult ConvertTFLNegOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto neg = cast<TFL::NegOp>(op);
  ShapedType result_ty = neg.getResult().getType().dyn_cast<ShapedType>();
  if (!result_ty) return failure();

  Value operand = neg.x();
  CreateReplaceOpAndInfer<tosa::NegateOp>(rewriter, op, result_ty, operand);
  return success();
}
// Re-creates the yield terminator as tosa.yield, forwarding the result types
// and operands unchanged.
LogicalResult ConvertTFLYieldOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto result_types = op->getResultTypes();
  auto yielded = op->getOperands();
  rewriter.replaceOpWithNewOp<tosa::YieldOp>(op, result_types, yielded);
  return success();
}
// Wraps tfl.custom in tosa.custom, keeping its custom_code identifier,
// result types and operands.
LogicalResult ConvertTFLCustomOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto custom = cast<TFL::CustomOp>(op);
  auto code = custom.custom_code();
  rewriter.replaceOpWithNewOp<tosa::CustomOp>(op, op->getResultTypes(), code,
                                              op->getOperands());
  return success();
}
// Lowers tfl.reverse_v2 to a chain of tosa.reverse ops, one per entry of the
// constant `axis` operand. An empty axis list lowers to tosa.identity.
// Negative axes are counted from the back. Any axis outside [-rank, rank) is
// a match failure, checked before any TOSA op is created.
LogicalResult ConvertTFLReverseV2Op::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_reverse_op = cast<TFL::ReverseV2Op>(op);

  RankedTensorType input_type =
      tfl_reverse_op.input().getType().dyn_cast<RankedTensorType>();
  RankedTensorType output_type =
      tfl_reverse_op.getResult().getType().dyn_cast<RankedTensorType>();
  if (!input_type || !output_type) return failure();

  ElementsAttr axis_elems;
  if (!matchPattern(tfl_reverse_op.axis(), m_Constant(&axis_elems)))
    return failure();

  const int64_t input_rank = input_type.getRank();

  // Normalize and validate all axes up front so that an invalid axis leaves
  // the IR untouched instead of emitting a partially built reverse chain.
  SmallVector<int64_t> axes;
  for (const APInt& axis : axis_elems.getValues<APInt>()) {
    int64_t axis_val = axis.getSExtValue();
    if (axis_val < 0) axis_val += input_rank;
    if (axis_val < 0 || axis_val >= input_rank) {
      return rewriter.notifyMatchFailure(
          op, "ConvertTFLReverseV2Op: axis out of range");
    }
    axes.push_back(axis_val);
  }

  Value val = tfl_reverse_op.input();
  if (axes.empty()) {
    auto identity_op = CreateOpAndInfer<tosa::IdentityOp>(
        rewriter, op->getLoc(), output_type, val);
    val = identity_op.getResult();
  } else {
    for (int64_t axis_val : axes) {
      auto axis_attr = rewriter.getI64IntegerAttr(axis_val);
      auto reverse_op = CreateOpAndInfer<tosa::ReverseOp>(
          rewriter, op->getLoc(), output_type, val, axis_attr);
      val = reverse_op.getResult();
    }
  }

  rewriter.replaceOp(op, {val});
  return success();
}
// Lowers tfl.quantize.
//
// - Quantized input: the op is a requantize, emitted as a single rescale from
//   the input scale/zero point to the result scale/zero point.
// - Otherwise: delegates to convertQuantizeOp with the inverse result scale.
//   For unsigned storage, the zero point is shifted down by 2^(bits-1).
LogicalResult ConvertTFLQuantizeOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto quantize = cast<TFL::QuantizeOp>(op);
  Value source = quantize.input();

  RankedTensorType source_ty = source.getType().dyn_cast<RankedTensorType>();
  ShapedType result_ty =
      quantize.getResult().getType().dyn_cast<ShapedType>();
  if (!source_ty || !result_ty) return failure();

  UniformQuantizedType result_qtype =
      result_ty.getElementType().dyn_cast<UniformQuantizedType>();
  if (!result_qtype) return failure();

  if (UniformQuantizedType source_qtype =
          source_ty.getElementType().dyn_cast<UniformQuantizedType>()) {
    const double scale_ratio =
        source_qtype.getScale() / result_qtype.getScale();
    Value requantized = buildRescale(
        rewriter, op, result_ty, source, scale_ratio,
        source_qtype.getZeroPoint(), result_qtype.getZeroPoint(), true, true);
    rewriter.replaceOp(op, {requantized});
    return success();
  }

  const double inverse_scale = 1 / result_qtype.getScale();
  const int64_t storage_bits = result_qtype.getStorageTypeIntegralWidth();
  int64_t zero_point = result_qtype.getZeroPoint();
  if (!result_qtype.isSigned()) zero_point -= (1 << (storage_bits - 1));

  llvm::Optional<Value> quantized = convertQuantizeOp(
      rewriter, op, result_ty, source, inverse_scale, zero_point);
  if (!quantized) return failure();

  rewriter.replaceOp(op, {quantized.getValue()});
  return success();
}
// Lowers tfl.dequantize according to the input element type:
// - float: a plain tosa.cast;
// - per-tensor uniform quantized: convertDequantizeOp with one scale/zp;
// - per-axis uniform quantized: convertDequantizeOp with per-channel
//   scales/zps along the quantized dimension.
// For unsigned storage, zero points are shifted down by 2^(bits-1).
// Any other element type fails to match.
LogicalResult ConvertTFLDequantizeOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto dequantize = cast<TFL::DequantizeOp>(op);
  Value source = dequantize.input();

  ShapedType result_ty =
      dequantize.getResult().getType().dyn_cast<ShapedType>();
  if (!result_ty) return failure();

  RankedTensorType source_ty = source.getType().dyn_cast<RankedTensorType>();
  if (!source_ty) return failure();

  Type source_ety = source_ty.getElementType();

  if (source_ety.isa<FloatType>()) {
    CreateReplaceOpAndInfer<tosa::CastOp>(rewriter, op, result_ty, source);
    return success();
  }

  if (auto per_tensor = source_ety.dyn_cast<quant::UniformQuantizedType>()) {
    const double scale = per_tensor.getScale();
    const int64_t storage_bits = per_tensor.getStorageTypeIntegralWidth();
    int64_t zero_point = per_tensor.getZeroPoint();
    if (!per_tensor.isSigned()) zero_point -= (1 << (storage_bits - 1));

    llvm::Optional<Value> dequantized = convertDequantizeOp(
        rewriter, op, result_ty, source, scale, zero_point, 0);
    if (!dequantized) return failure();
    rewriter.replaceOp(op, {dequantized.getValue()});
    return success();
  }

  if (auto per_axis =
          source_ety.dyn_cast<quant::UniformQuantizedPerAxisType>()) {
    const int64_t storage_bits = per_axis.getStorageTypeIntegralWidth();
    const bool is_signed = per_axis.isSigned();

    SmallVector<float> zero_points;
    for (int64_t zp : per_axis.getZeroPoints())
      zero_points.push_back(is_signed ? zp : zp - (1 << (storage_bits - 1)));

    SmallVector<float> scales;
    for (double s : per_axis.getScales())
      scales.push_back(static_cast<float>(s));

    llvm::Optional<Value> dequantized =
        convertDequantizeOp(rewriter, op, result_ty, source, scales,
                            zero_points, per_axis.getQuantizedDimension());
    if (!dequantized) return failure();
    rewriter.replaceOp(op, {dequantized.getValue()});
    return success();
  }

  return failure();
}
// Lowers tfl.pseudo_const to tosa.const carrying the same value attribute.
// The result type is adjusted to stay consistent with the attribute:
// - a quantized result element type is replaced by the attribute's element
//   type;
// - an unranked result takes its shape from the attribute, since some TFLite
//   folders create constants with unranked shapes.
LogicalResult ConvertTFLConstOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto constant = cast<TFL::ConstOp>(op);
  ShapedType result_ty =
      constant.getResult().getType().dyn_cast<ShapedType>();
  if (!result_ty) return failure();

  ElementsAttr value = constant.value();
  Type value_ety = value.getType().getElementType();

  if (result_ty.getElementType().isa<quant::QuantizedType>())
    result_ty = RankedTensorType::get(result_ty.getShape(), value_ety);

  if (!result_ty.hasRank())
    result_ty = value.getType().cast<ShapedType>().clone(value_ety);

  rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, result_ty, value);
  return success();
}
// Lowers tfl.pseudo_qconst to tosa.const carrying the same value attribute and
// the op's quantized element type. An unranked result (some TFLite folders
// emit these) takes its shape from the attribute.
LogicalResult ConvertTFLQConstOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto qconst = cast<TFL::QConstOp>(op);
  ShapedType result_ty = qconst.getResult().getType().dyn_cast<ShapedType>();
  if (!result_ty) return failure();

  ElementsAttr value = qconst.value();
  ShapedType const_ty =
      result_ty.hasRank()
          ? result_ty
          : value.getType().cast<ShapedType>().clone(
                result_ty.getElementType());

  rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, const_ty, value);
  return success();
}
// Lowers a shaped arith.constant to tosa.const.
//
// TOSA supports integers of at most 48 bits, so 64-bit integer constants are
// truncated element-wise to i48 (values that do not fit lose their high
// bits). Non-elements values, and 64-bit constants that are not held in a
// dense attribute, fail to match instead of asserting.
LogicalResult ConvertConstantOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_const_op = cast<arith::ConstantOp>(op);

  ShapedType output_type =
      tfl_const_op.getResult().getType().dyn_cast<ShapedType>();
  // Not a shaped output
  if (!output_type) return failure();

  ElementsAttr attr = tfl_const_op.getValueAttr().dyn_cast<ElementsAttr>();
  // tosa.const needs an elements attribute; anything else cannot be lowered.
  if (!attr) return failure();

  auto e_type = output_type.getElementType();
  // TOSA only supports integers up to 48 bits. A wider source is not
  // representable, so 64-bit data is truncated to 48 bits. mapValues is only
  // available on dense attributes.
  if (e_type.isInteger(64)) {
    auto dense_attr = attr.dyn_cast<DenseIntOrFPElementsAttr>();
    if (!dense_attr) return failure();
    e_type = rewriter.getIntegerType(48);
    attr = dense_attr.mapValues(
        e_type, [](const APInt& x) -> APInt { return x.trunc(48); });
  }

  // An unranked result takes its shape from the attribute.
  if (!output_type.hasRank()) {
    if (auto attr_type = attr.getType().dyn_cast<ShapedType>()) {
      output_type = attr_type.clone(e_type);
    }
  }

  output_type = output_type.clone(e_type);
  rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, output_type, attr);

  return success();
}
// Lowers tfl.gather via the shared convertGatherOp helper. batch_dims is
// optional and defaults to 0.
LogicalResult ConvertTFLGatherOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto gather = cast<TFL::GatherOp>(op);

  const int32_t axis = gather.axisAttr().getInt();
  auto batch_dims_attr = gather.batch_dimsAttr();
  const int32_t batch_dims =
      batch_dims_attr ? static_cast<int32_t>(batch_dims_attr.getInt()) : 0;

  llvm::Optional<Value> gathered =
      convertGatherOp(rewriter, op, gather.getResult(), gather.params(),
                      gather.indices(), batch_dims, axis);
  if (!gathered) return failure();

  rewriter.replaceOp(op, {gathered.getValue()});
  return success();
}
// Lowers tfl.gather_nd by delegating to the shared convertGatherNdOp helper.
LogicalResult ConvertTFLGatherNdOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_gathernd_op = cast<TFL::GatherNdOp>(op);
  Value params = tfl_gathernd_op.params();
  Value indices = tfl_gathernd_op.indices();

  llvm::Optional<Value> gathered = convertGatherNdOp(
      rewriter, op, tfl_gathernd_op.getResult(), params, indices);
  if (gathered) {
    rewriter.replaceOp(op, {gathered.getValue()});
    return success();
  }
  return failure();
}
// Lowers tfl.sparse_to_dense to a tosa.scatter into a constant tensor filled
// with the (splat, compile-time constant) default value.
//
// Strategy:
//   1. Build a [1, num_elements, 1] constant holding the default value.
//   2. Linearize each row of `sparse_indices` (shape [num_values, rank]) into
//      a flat offset via a multiply with the row-major strides of the result,
//      followed by a sum over axis 1.
//   3. Scatter `sparse_values` (reshaped to [1, num_values, 1]) at those
//      offsets and reshape back to the static result shape.
// Returns failure() for forms this lowering cannot express: dynamic result
// shapes, non-constant or non-splat defaults, non-int/float element types,
// indices that are not rank 2 and values that are not rank 1.
LogicalResult ConvertTFLSparseToDenseOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_sparse_to_dense_op = cast<TFL::SparseToDenseOp>(op);
  auto indices = tfl_sparse_to_dense_op.sparse_indices();
  auto values = tfl_sparse_to_dense_op.sparse_values();
  auto default_value = tfl_sparse_to_dense_op.default_value();
  auto indices_ty = indices.getType().cast<ShapedType>();
  auto indices_ety = indices_ty.getElementType();
  auto values_ty = values.getType().cast<ShapedType>();
  auto result_ty =
      tfl_sparse_to_dense_op.getResult().getType().cast<ShapedType>();
  auto result_ety = result_ty.getElementType();
  auto loc = op->getLoc();

  if (!result_ty.hasStaticShape()) return failure();
  auto result_rank = result_ty.getRank();

  // The linearization below reduces over axis 1 of the indices, so they must
  // be 2-D. Values are reshaped using getDimSize(0), so they must be 1-D.
  // Other forms (0-D/1-D indices, scalar values, unranked operands) would
  // otherwise assert in getShape()/getDimSize().
  if (!indices_ty.hasRank() || indices_ty.getRank() != 2) return failure();
  if (!values_ty.hasRank() || values_ty.getRank() != 1) return failure();

  // We want to generate the default tensor we need to scatter into. Note that
  // the result_ty needs to be a statically shaped tensor.
  ElementsAttr default_value_attr;
  if (!matchPattern(default_value, m_Constant(&default_value_attr)))
    return failure();
  if (!default_value_attr.isSplat()) return failure();

  ShapedType scatter_ty =
      RankedTensorType::get({1, result_ty.getNumElements(), 1}, result_ety);

  // Materialize the splat default in the result element type. Integers are
  // sign-extended or truncated to the result width (plain sext asserts when
  // narrowing). Floats must already match the result type. Anything else
  // (e.g. quantized types) is rejected instead of asserting.
  Type default_ety =
      default_value_attr.getType().cast<ShapedType>().getElementType();
  DenseElementsAttr default_splat;
  if (auto result_int_ty = result_ety.dyn_cast<IntegerType>()) {
    if (!default_ety.isa<IntegerType>()) return failure();
    default_splat = DenseElementsAttr::get(
        scatter_ty, default_value_attr.getSplatValue<APInt>().sextOrTrunc(
                        result_int_ty.getWidth()));
  } else if (result_ety.isa<FloatType>() && default_ety == result_ety) {
    default_splat = DenseElementsAttr::get(
        scatter_ty,
        llvm::makeArrayRef(default_value_attr.getSplatValue<APFloat>()));
  } else {
    return failure();
  }
  Value default_const =
      rewriter.create<tosa::ConstOp>(loc, scatter_ty, default_splat);

  // Row-major strides of the result: stride[i] = prod(dims[i+1:]).
  llvm::SmallVector<int32_t> multiply_constant_ints;
  multiply_constant_ints.resize(result_rank, 1);
  for (int i = result_rank - 1; i > 0; i--) {
    multiply_constant_ints[i - 1] =
        result_ty.getDimSize(i) * multiply_constant_ints[i];
  }

  indices_ety = rewriter.getI32Type();
  indices_ty = RankedTensorType::get(indices_ty.getShape(), indices_ety);
  indices = CreateOpAndInfer<tosa::CastOp>(rewriter, loc, indices_ty, indices);

  auto multiply_constant_type =
      RankedTensorType::get({result_rank}, indices_ety);
  auto multiply_constant_attr = DenseElementsAttr::get(
      multiply_constant_type, llvm::makeArrayRef(multiply_constant_ints));
  Value multiply_constant = CreateOpAndInfer<tosa::ConstOp>(
      rewriter, loc, multiply_constant_type, multiply_constant_attr);

  // indices * strides, then sum over axis 1 -> flat offset per value.
  Value multiply_op = CreateOpAndInfer<tosa::MulOp>(
      rewriter, loc, indices_ty, indices, multiply_constant, 0);
  Value reduce_op = CreateOpAndInfer<tosa::ReduceSumOp>(
      rewriter, loc, UnrankedTensorType::get(indices_ety), multiply_op,
      rewriter.getI64IntegerAttr(1));

  auto values_reshape_op = CreateOpAndInfer<tosa::ReshapeOp>(
      rewriter, loc, UnrankedTensorType::get(result_ety), values,
      rewriter.getI64ArrayAttr(
          ArrayRef<int64_t>{1, values_ty.getDimSize(0), 1}));
  auto index_reshape_op = CreateOpAndInfer<tosa::ReshapeOp>(
      rewriter, loc, UnrankedTensorType::get(indices_ety), reduce_op,
      rewriter.getI64ArrayAttr(ArrayRef<int64_t>{1, indices_ty.getDimSize(0)}));

  auto scatter = CreateOpAndInfer<tosa::ScatterOp>(
      rewriter, loc, UnrankedTensorType::get(result_ety), default_const,
      index_reshape_op, values_reshape_op);

  CreateReplaceOpAndInfer<tosa::ReshapeOp>(
      rewriter, op, result_ty, scatter,
      rewriter.getI64ArrayAttr(result_ty.getShape()));
  return success();
}
// Lowers tfl.one_hot via the shared convertOneHotOp helper. The `depth`
// operand must be a compile-time constant; its first element is used.
LogicalResult ConvertTFLOneHotOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tfl_one_hot_op = cast<TFL::OneHotOp>(op);

  ElementsAttr depth_attr;
  if (!matchPattern(tfl_one_hot_op.depth(), m_Constant(&depth_attr))) {
    return failure();
  }
  const APInt depth_value = depth_attr.getValues<APInt>()[0];
  const int32_t depth = depth_value.getSExtValue();
  const int32_t axis = tfl_one_hot_op.axisAttr().getInt();

  llvm::Optional<Value> one_hot = convertOneHotOp(
      rewriter, op, tfl_one_hot_op.getResult(), tfl_one_hot_op.indices(),
      tfl_one_hot_op.on_value(), tfl_one_hot_op.off_value(), depth, axis);
  if (one_hot) {
    rewriter.replaceOp(op, {one_hot.getValue()});
    return success();
  }
  return failure();
}
// Lowers tfl.arg_max to tosa.argmax. The reduction axis (`dim`) must be a
// compile-time constant; its first element is used.
//
// TFLite accepts a negative `dim` counting back from the last axis (range
// [-rank, rank)), whereas tosa.argmax requires 0 <= axis < rank. Negative
// axes are normalized by the input rank. Out-of-range axes, and negative
// axes on unranked inputs, are rejected.
LogicalResult ConvertTFLArgMaxOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto arg_max_op = cast<TFL::ArgMaxOp>(op);

  ElementsAttr dim_elems;
  if (!matchPattern(arg_max_op.dim(), m_Constant(&dim_elems))) return failure();
  int64_t dim = dim_elems.getValues<APInt>()[0].getSExtValue();

  auto input_ty = arg_max_op.input().getType().dyn_cast<RankedTensorType>();
  if (input_ty) {
    const int64_t rank = input_ty.getRank();
    if (dim < 0) dim += rank;
    if (dim < 0 || dim >= rank) return failure();
  } else if (dim < 0) {
    // Cannot normalize a negative axis without knowing the rank.
    return failure();
  }

  CreateReplaceOpAndInfer<tosa::ArgMaxOp>(
      rewriter, op, arg_max_op.getType(), arg_max_op.input(),
      rewriter.getIntegerAttr(rewriter.getI64Type(), dim));
  return success();
}
// Lowers tfl.fake_quant through the shared convertFakeQuantOp helper,
// forwarding the op's min/max/num_bits/narrow_range attributes.
LogicalResult ConvertTFLFakeQuantOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto fakequant_op = cast<TFL::FakeQuantOp>(op);

  // The helper needs a shaped result type; anything else is left untouched.
  auto output_type =
      fakequant_op.getResult().getType().dyn_cast<ShapedType>();
  if (!output_type) return failure();

  const auto min_val = fakequant_op.minAttr().getValueAsDouble();
  const auto max_val = fakequant_op.maxAttr().getValueAsDouble();
  const auto num_bits = fakequant_op.num_bitsAttr().getInt();
  const auto narrow_range = fakequant_op.narrow_rangeAttr().getValue();

  llvm::Optional<Value> lowered =
      convertFakeQuantOp(rewriter, op, output_type, fakequant_op.input(),
                         min_val, max_val, num_bits, narrow_range);
  if (lowered) {
    rewriter.replaceOp(op, {lowered.getValue()});
    return success();
  }
  return failure();
}
// Builds the TFL -> TOSA pattern list once per pass instance. It freezes the
// list together with the pass's disabled/enabled pattern label lists, so that
// runOnOperation can reuse the frozen set.
LogicalResult LegalizeTFL::initialize(MLIRContext* context) {
  RewritePatternSet tfl_to_tosa_patterns(context);
  mlir::tosa::populateLegalizeTFLPatterns(context, tfl_to_tosa_patterns);
  this->frozen_patterns_ = FrozenRewritePatternSet(
      std::move(tfl_to_tosa_patterns), this->disabled_patterns_,
      this->enabled_patterns_);
  return success();
}
void LegalizeTFL::runOnOperation() {
if (ApplyPatternsWithShapeResolution(getOperation(), this->frozen_patterns_)
.failed()) {
signalPassFailure();
}
}
} // namespace
// Registers the TFLite -> TOSA conversion patterns listed below into
// `patterns`. DEF_PATTERN_INSERT(X) adds the pattern class ConvertXOp under
// the label "X". Presumably these labels are what the pass's
// disabled/enabled pattern lists filter on. TODO confirm against the pass
// options.
// NOTE(review): patterns are added in the order listed; keep the order stable
// unless you have checked that no two patterns with equal benefit compete for
// the same op.
void populateLegalizeTFLPatterns(MLIRContext* ctx,
RewritePatternSet& patterns) {
// Expands to: patterns.addWithLabel<ConvertPATOp>({"PAT"}, ctx);
// NOTE(review): this macro is never #undef'd.
#define DEF_PATTERN_INSERT(PAT) \
patterns.addWithLabel<Convert##PAT##Op>({#PAT}, ctx);
DEF_PATTERN_INSERT(TFLAbs);
DEF_PATTERN_INSERT(TFLCeil);
DEF_PATTERN_INSERT(TFLFloor);
DEF_PATTERN_INSERT(TFLExp);
DEF_PATTERN_INSERT(TFLLog);
DEF_PATTERN_INSERT(TFLRsqrt);
DEF_PATTERN_INSERT(TFLLogicalNot);
DEF_PATTERN_INSERT(TFLCast);
DEF_PATTERN_INSERT(QuantStat);
DEF_PATTERN_INSERT(TFLLogicalAnd);
DEF_PATTERN_INSERT(TFLLogicalOr);
DEF_PATTERN_INSERT(TFLPow);
DEF_PATTERN_INSERT(TFLRelu);
DEF_PATTERN_INSERT(TFLRelu6);
DEF_PATTERN_INSERT(TFLEqual);
DEF_PATTERN_INSERT(TFLNotEqual);
DEF_PATTERN_INSERT(TFLGreater);
DEF_PATTERN_INSERT(TFLGreaterEqual);
DEF_PATTERN_INSERT(TFLAdd);
DEF_PATTERN_INSERT(TFLSub);
DEF_PATTERN_INSERT(TFLMul);
DEF_PATTERN_INSERT(TFLSquare);
DEF_PATTERN_INSERT(TFLSquaredDifference);
DEF_PATTERN_INSERT(TFLRound);
DEF_PATTERN_INSERT(TFLDiv);
DEF_PATTERN_INSERT(TFLMaximum);
DEF_PATTERN_INSERT(TFLMinimum);
DEF_PATTERN_INSERT(TFLFloorMod);
DEF_PATTERN_INSERT(TFLFloorDiv);
DEF_PATTERN_INSERT(TFLAddN);
DEF_PATTERN_INSERT(TFLAveragePool2D);
DEF_PATTERN_INSERT(TFLMaxPool2D);
DEF_PATTERN_INSERT(TFLConcatenation);
DEF_PATTERN_INSERT(TFLReshape);
DEF_PATTERN_INSERT(TFLRank);
DEF_PATTERN_INSERT(TFLShape);
DEF_PATTERN_INSERT(TFLExpandDims);
DEF_PATTERN_INSERT(TFLSqueeze);
DEF_PATTERN_INSERT(TFLFill);
DEF_PATTERN_INSERT(TFLElu);
DEF_PATTERN_INSERT(TFLSoftmax);
DEF_PATTERN_INSERT(TFLLogSoftmax);
DEF_PATTERN_INSERT(TFLSqrt);
DEF_PATTERN_INSERT(TFLL2Normalization);
DEF_PATTERN_INSERT(TFLReduceAny);
DEF_PATTERN_INSERT(TFLReduceMax);
DEF_PATTERN_INSERT(TFLReduceMin);
DEF_PATTERN_INSERT(TFLMean);
DEF_PATTERN_INSERT(TFLReduceProd);
DEF_PATTERN_INSERT(TFLSum);
DEF_PATTERN_INSERT(TFLConv2D);
DEF_PATTERN_INSERT(TFLTransposeConv);
DEF_PATTERN_INSERT(TFLDepthwiseConv2D);
DEF_PATTERN_INSERT(TFLFullyConnected);
DEF_PATTERN_INSERT(TFLBatchMatMul);
DEF_PATTERN_INSERT(TFLSplit);
DEF_PATTERN_INSERT(TFLSplitV);
DEF_PATTERN_INSERT(TFLPack);
DEF_PATTERN_INSERT(TFLUnpack);
DEF_PATTERN_INSERT(TFLTranspose);
DEF_PATTERN_INSERT(TFLTile);
DEF_PATTERN_INSERT(TFLSlice);
DEF_PATTERN_INSERT(TFLStridedSlice);
DEF_PATTERN_INSERT(TFLHardSwish);
DEF_PATTERN_INSERT(TFLZerosLike);
DEF_PATTERN_INSERT(TFLLess);
DEF_PATTERN_INSERT(TFLLessEqual);
DEF_PATTERN_INSERT(TFLPad);
DEF_PATTERN_INSERT(TFLPadV2);
DEF_PATTERN_INSERT(TFLResizeBilinear);
DEF_PATTERN_INSERT(TFLResizeNearestNeighbor);
DEF_PATTERN_INSERT(TFLSelect);
DEF_PATTERN_INSERT(TFLSelectV2);
DEF_PATTERN_INSERT(TFLSpaceToBatchNd);
DEF_PATTERN_INSERT(TFLBatchToSpaceNd);
DEF_PATTERN_INSERT(TFLSpaceToDepth);
DEF_PATTERN_INSERT(TFLDepthToSpace);
DEF_PATTERN_INSERT(TFLSin);
DEF_PATTERN_INSERT(TFLCos);
DEF_PATTERN_INSERT(TFLLogistic);
DEF_PATTERN_INSERT(TFLTanh);
DEF_PATTERN_INSERT(TFLPRelu);
DEF_PATTERN_INSERT(TFLLeakyRelu);
DEF_PATTERN_INSERT(TFLNeg);
DEF_PATTERN_INSERT(TFLYield);
DEF_PATTERN_INSERT(TFLCustom);
DEF_PATTERN_INSERT(TFLReverseV2);
DEF_PATTERN_INSERT(TFLQuantize);
DEF_PATTERN_INSERT(TFLDequantize);
DEF_PATTERN_INSERT(TFLConst);
DEF_PATTERN_INSERT(TFLQConst);
DEF_PATTERN_INSERT(TFLGather);
DEF_PATTERN_INSERT(TFLGatherNd);
DEF_PATTERN_INSERT(TFLSparseToDense);
// Not a TFL op: ConvertConstantOp lowers arith.constant.
DEF_PATTERN_INSERT(Constant);
DEF_PATTERN_INSERT(TFLOneHot);
DEF_PATTERN_INSERT(TFLArgMax);
DEF_PATTERN_INSERT(TFLFakeQuant);
}
// Factory for the TensorFlow Lite -> TOSA legalization pass (LegalizeTFL).
// `disabled_patterns` / `enabled_patterns` are forwarded unchanged to the
// pass constructor.
std::unique_ptr<OperationPass<func::FuncOp>> createLegalizeTFLPass(
    ArrayRef<std::string> disabled_patterns,
    ArrayRef<std::string> enabled_patterns) {
  std::unique_ptr<OperationPass<func::FuncOp>> pass =
      std::make_unique<LegalizeTFL>(disabled_patterns, enabled_patterns);
  return pass;
}
} // namespace tosa
} // namespace mlir