blob: d027539f8139adcf594f4ec29abc97a7a5e9f449 [file] [log] [blame]
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Legalize TensorFlow Lite to TOSA
#include <climits>
#include <cstddef>
#include <cstdint>
#include <fstream>
#include <iterator>
#include <numeric>
#include <unordered_set>
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
#include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/tosa/transforms/legalize_common.h"
#include "tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h"
#include "tensorflow/compiler/mlir/tosa/transforms/passes.h"
#define PASS_NAME "tosa-legalize-tfl"
#define DEBUG_TYPE PASS_NAME
#define HARDSWISH_EXPLICIT_RESCALING false
// Conditionally avoid converting some TFLite ops to TOSA.
// By default, all conversions will be invoked.
//
// The denylist file lists patterns which are not legalized from TFLite to TOSA.
llvm::cl::opt<std::string> tfl_tosa_denylist(
"tfl-tosa-denylist",
llvm::cl::desc("<a list of patterns not legalized from TFLite to TOSA>"),
llvm::cl::init("transforms/tfl_tosa_denylist.txt"),
llvm::cl::value_desc("pattern name"));
namespace mlir {
namespace tosa {
namespace {
#define GEN_PASS_CLASSES
#include "tensorflow/compiler/mlir/tosa/transforms/passes.h.inc"
// Performs lowering to TOSA dialect.
class LegalizeTFL : public TosaLegalizeTFLPassBase<LegalizeTFL> {
public:
explicit LegalizeTFL() {}
void runOnFunction() override;
};
#include "tensorflow/compiler/mlir/tosa/transforms/tfl_legalize_patterns.inc"
#define DECL_CONVERT_OP(tfl_op) \
struct ConvertTFL##tfl_op##Op : public RewritePattern { \
explicit ConvertTFL##tfl_op##Op(MLIRContext* context) \
: RewritePattern(TFL::tfl_op##Op::getOperationName(), 1, context) {} \
LogicalResult matchAndRewrite(Operation* op, \
PatternRewriter& rewriter) const override; \
}
DECL_CONVERT_OP(Relu);
DECL_CONVERT_OP(Relu6);
DECL_CONVERT_OP(Equal);
DECL_CONVERT_OP(NotEqual);
DECL_CONVERT_OP(Greater);
DECL_CONVERT_OP(GreaterEqual);
DECL_CONVERT_OP(Add);
DECL_CONVERT_OP(Sub);
DECL_CONVERT_OP(Mul);
DECL_CONVERT_OP(Square);
DECL_CONVERT_OP(SquaredDifference);
DECL_CONVERT_OP(Round);
DECL_CONVERT_OP(Div);
DECL_CONVERT_OP(Maximum);
DECL_CONVERT_OP(Minimum);
DECL_CONVERT_OP(FloorMod);
DECL_CONVERT_OP(FloorDiv);
DECL_CONVERT_OP(AddN);
DECL_CONVERT_OP(AveragePool2D);
DECL_CONVERT_OP(MaxPool2D);
DECL_CONVERT_OP(Concatenation);
DECL_CONVERT_OP(Reshape);
DECL_CONVERT_OP(Rank);
DECL_CONVERT_OP(Shape);
DECL_CONVERT_OP(ExpandDims);
DECL_CONVERT_OP(Squeeze);
DECL_CONVERT_OP(Fill);
DECL_CONVERT_OP(Elu);
DECL_CONVERT_OP(Softmax);
DECL_CONVERT_OP(LogSoftmax);
DECL_CONVERT_OP(Sqrt);
DECL_CONVERT_OP(ReduceAny);
DECL_CONVERT_OP(ReduceMax);
DECL_CONVERT_OP(ReduceMin);
DECL_CONVERT_OP(Mean);
DECL_CONVERT_OP(ReduceProd);
DECL_CONVERT_OP(Sum);
DECL_CONVERT_OP(Conv2D);
DECL_CONVERT_OP(TransposeConv);
DECL_CONVERT_OP(DepthwiseConv2D);
DECL_CONVERT_OP(FullyConnected);
DECL_CONVERT_OP(Split);
DECL_CONVERT_OP(SplitV);
DECL_CONVERT_OP(Pack);
DECL_CONVERT_OP(Unpack);
DECL_CONVERT_OP(Transpose);
DECL_CONVERT_OP(Tile);
DECL_CONVERT_OP(Slice);
DECL_CONVERT_OP(StridedSlice);
DECL_CONVERT_OP(HardSwish);
DECL_CONVERT_OP(ZerosLike);
DECL_CONVERT_OP(Less);
DECL_CONVERT_OP(LessEqual);
DECL_CONVERT_OP(Pad);
DECL_CONVERT_OP(ResizeBilinear);
DECL_CONVERT_OP(ResizeNearestNeighbor);
DECL_CONVERT_OP(Select);
DECL_CONVERT_OP(SelectV2);
DECL_CONVERT_OP(SpaceToBatchNd);
DECL_CONVERT_OP(BatchToSpaceNd);
DECL_CONVERT_OP(SpaceToDepth);
DECL_CONVERT_OP(DepthToSpace);
DECL_CONVERT_OP(Logistic);
DECL_CONVERT_OP(Tanh);
DECL_CONVERT_OP(PRelu);
DECL_CONVERT_OP(LeakyRelu);
DECL_CONVERT_OP(Neg);
DECL_CONVERT_OP(Yield);
DECL_CONVERT_OP(Custom);
DECL_CONVERT_OP(ReverseV2);
DECL_CONVERT_OP(Quantize);
DECL_CONVERT_OP(Dequantize);
DECL_CONVERT_OP(QConst);
DECL_CONVERT_OP(Gather);
DECL_CONVERT_OP(GatherNd);
DECL_CONVERT_OP(OneHot);
DECL_CONVERT_OP(ArgMax);
DECL_CONVERT_OP(FakeQuant);
#undef DECL_CONVERT_OP
// Input from tfl.conv2d takes 64 bits a bias, while tosa.conv2d expects 48
// bits. Need to do a customized truncate here instead of tablegen to handle
// attribute with negative value.
struct ConvertConstantOp : public RewritePattern {
explicit ConvertConstantOp(MLIRContext* context)
: RewritePattern(ConstantOp::getOperationName(), 1, context) {}
LogicalResult matchAndRewrite(Operation* op,
PatternRewriter& rewriter) const override;
};
LogicalResult ConvertTFLReluOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tfl_relu_op = cast<TFL::ReluOp>(op);
ShapedType input_type = tfl_relu_op.x().getType().dyn_cast<ShapedType>();
ShapedType output_type =
tfl_relu_op.getResult().getType().dyn_cast<ShapedType>();
// Not a ranked tensor output
if (!input_type || !output_type) return failure();
bool input_is_qtype =
input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
bool output_is_qtype =
output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
if (input_is_qtype != output_is_qtype) {
return op->emitOpError(
"ConvertTFLReluOp: input/output tensor should "
"be all quantized or all floating-point.");
}
int64_t clamp_min = 0;
Value clamp_in = tfl_relu_op.x();
if (output_is_qtype) {
UniformQuantizedType input_qtype =
input_type.getElementType()
.dyn_cast<mlir::quant::UniformQuantizedType>();
UniformQuantizedType output_qtype =
output_type.getElementType()
.dyn_cast<mlir::quant::UniformQuantizedType>();
clamp_min = output_qtype.getZeroPoint();
clamp_in =
buildRescale(rewriter, op, output_type, tfl_relu_op.x(),
input_qtype.getScale() / output_qtype.getScale(),
input_qtype.getZeroPoint(), output_qtype.getZeroPoint(),
/*double_round=*/false, /*scale32=*/true);
}
CreateReplaceOpAndInfer<tosa::ClampOp>(
rewriter, op, output_type, clamp_in,
rewriter.getI64IntegerAttr(clamp_min),
rewriter.getI64IntegerAttr(std::numeric_limits<int32_t>::max()),
rewriter.getF32FloatAttr(0.0f),
rewriter.getF32FloatAttr(std::numeric_limits<float>::max()));
return success();
}
LogicalResult ConvertTFLRelu6Op::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tfl_relu6_op = cast<TFL::Relu6Op>(op);
ShapedType input_type = tfl_relu6_op.x().getType().dyn_cast<ShapedType>();
ShapedType output_type =
tfl_relu6_op.getResult().getType().dyn_cast<ShapedType>();
// Not a ranked tensor output
if (!input_type || !output_type) return failure();
bool input_is_qtype =
input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
bool output_is_qtype =
output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
if (input_is_qtype != output_is_qtype) {
return op->emitOpError(
"ConvertTFLRelu6Op: input/output tensor should "
"be all quantized or all floating-point.");
}
int64_t clamp_min = 0;
int64_t clamp_max = 6;
Value clamp_in = tfl_relu6_op.x();
if (output_is_qtype && input_is_qtype) {
UniformQuantizedType input_qtype =
input_type.getElementType()
.dyn_cast<mlir::quant::UniformQuantizedType>();
UniformQuantizedType output_qtype =
output_type.getElementType()
.dyn_cast<mlir::quant::UniformQuantizedType>();
clamp_min = output_qtype.getZeroPoint();
clamp_max = std::llround(6.0f / output_qtype.getScale()) +
output_qtype.getZeroPoint();
clamp_in =
buildRescale(rewriter, op, output_type, tfl_relu6_op.x(),
input_qtype.getScale() / output_qtype.getScale(),
input_qtype.getZeroPoint(), output_qtype.getZeroPoint(),
/*double_round=*/false, /*scale32=*/true);
}
CreateReplaceOpAndInfer<tosa::ClampOp>(rewriter, op, output_type, clamp_in,
rewriter.getI64IntegerAttr(clamp_min),
rewriter.getI64IntegerAttr(clamp_max),
rewriter.getF32FloatAttr(0.0f),
rewriter.getF32FloatAttr(6.0f));
return success();
}
static LogicalResult prepareMatchAndRewriteComparison(
Operation* op, mlir::OperandRange operands, PatternRewriter& rewriter,
llvm::SmallVectorImpl<Value>& newOperands) {
Value x = operands[0];
Value y = operands[1];
Value result = op->getResult(0);
ShapedType input_x_type = x.getType().dyn_cast<ShapedType>();
ShapedType input_y_type = y.getType().dyn_cast<ShapedType>();
ShapedType output_type = result.getType().dyn_cast<ShapedType>();
// Not a shaped tensor output
if (!input_x_type || !input_y_type || !output_type) return failure();
bool input_x_is_qtype =
input_x_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
bool input_y_is_qtype =
input_y_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
bool output_is_qtype =
output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
if (input_x_is_qtype != input_y_is_qtype ||
input_y_is_qtype != output_is_qtype) {
return op->emitOpError(
"ConvertTFLEqualOp: input/output tensor should "
"be all quantized or all floating-point.");
}
if (!output_is_qtype && !input_x_is_qtype && !input_y_is_qtype) {
newOperands.push_back(x);
newOperands.push_back(y);
return success();
}
UniformQuantizedType input_x_qtype =
input_x_type.getElementType()
.dyn_cast<mlir::quant::UniformQuantizedType>();
UniformQuantizedType input_y_qtype =
input_y_type.getElementType()
.dyn_cast<mlir::quant::UniformQuantizedType>();
if (input_x_qtype.getScale() != input_y_qtype.getScale() ||
input_x_qtype.getZeroPoint() != input_y_qtype.getZeroPoint()) {
return op->emitOpError(
"ConvertTFLEqualOp: input_x and input_y scale/zp "
"must be the same");
}
x = buildRescaleToInt32(rewriter, op, x, 1.0f, input_x_qtype.getZeroPoint());
y = buildRescaleToInt32(rewriter, op, y, 1.0f, input_y_qtype.getZeroPoint());
newOperands.push_back(x);
newOperands.push_back(y);
return success();
}
LogicalResult ConvertTFLEqualOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
llvm::SmallVector<Value, 2> newOperands;
LogicalResult status = prepareMatchAndRewriteComparison(
op, op->getOperands(), rewriter, newOperands);
if (status.failed()) return failure();
CreateReplaceOpAndInfer<tosa::EqualOp>(
rewriter, op, op->getResult(0).getType(), newOperands[0], newOperands[1]);
return success();
}
LogicalResult ConvertTFLNotEqualOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
llvm::SmallVector<Value, 2> newOperands;
LogicalResult status = prepareMatchAndRewriteComparison(
op, op->getOperands(), rewriter, newOperands);
if (status.failed()) return failure();
auto equal_op = CreateOpAndInfer<tosa::EqualOp>(
rewriter, op->getLoc(), op->getResult(0).getType(), newOperands[0],
newOperands[1]);
CreateReplaceOpAndInfer<tosa::LogicalNotOp>(
rewriter, op, op->getResult(0).getType(), equal_op);
return success();
}
LogicalResult ConvertTFLGreaterOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
llvm::SmallVector<Value, 2> newOperands;
LogicalResult status = prepareMatchAndRewriteComparison(
op, op->getOperands(), rewriter, newOperands);
if (status.failed()) return failure();
CreateReplaceOpAndInfer<tosa::GreaterOp>(
rewriter, op, op->getResult(0).getType(), newOperands[0], newOperands[1]);
return success();
}
LogicalResult ConvertTFLGreaterEqualOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
llvm::SmallVector<Value, 2> newOperands;
LogicalResult status = prepareMatchAndRewriteComparison(
op, op->getOperands(), rewriter, newOperands);
if (status.failed()) return failure();
CreateReplaceOpAndInfer<tosa::GreaterEqualOp>(
rewriter, op, op->getResult(0).getType(), newOperands[0], newOperands[1]);
return success();
}
LogicalResult ConvertTFLLessOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
llvm::SmallVector<Value, 2> newOperands;
LogicalResult status = prepareMatchAndRewriteComparison(
op, op->getOperands(), rewriter, newOperands);
if (status.failed()) return failure();
CreateReplaceOpAndInfer<tosa::GreaterOp>(
rewriter, op, op->getResult(0).getType(), newOperands[1], newOperands[0]);
return success();
}
LogicalResult ConvertTFLLessEqualOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
llvm::SmallVector<Value, 2> newOperands;
LogicalResult status = prepareMatchAndRewriteComparison(
op, op->getOperands(), rewriter, newOperands);
if (status.failed()) return failure();
// Swapping the args handles the greater/less difference.
CreateReplaceOpAndInfer<tosa::GreaterEqualOp>(
rewriter, op, op->getResult(0).getType(), newOperands[1], newOperands[0]);
return success();
}
template <typename TflOp, typename TosaOp>
static LogicalResult matchAndRewriteAddSub(Operation* op,
mlir::OperandRange operands,
PatternRewriter& rewriter) {
auto tfl_add_op = cast<TflOp>(op);
ShapedType input_lhs_type =
tfl_add_op.lhs().getType().template dyn_cast<ShapedType>();
ShapedType input_rhs_type =
tfl_add_op.rhs().getType().template dyn_cast<ShapedType>();
ShapedType output_type =
tfl_add_op.getResult().getType().template dyn_cast<ShapedType>();
// Not a ranked tensor output
if (!input_lhs_type || !input_rhs_type || !output_type) return failure();
bool input_lhs_is_qtype =
input_lhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
bool input_rhs_is_qtype =
input_rhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
bool output_is_qtype =
output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
if (input_lhs_is_qtype != output_is_qtype ||
input_rhs_is_qtype != output_is_qtype) {
return op->emitOpError(
"ConvertTFLAddOp: input/output tensor should "
"be all quantized or all floating-point.");
}
Value output;
if (output_is_qtype && input_lhs_is_qtype && input_rhs_is_qtype) {
ShapedType rescale_type = output_type.clone(rewriter.getI32Type());
UniformQuantizedType input_lhs_qtype =
input_lhs_type.getElementType()
.dyn_cast<mlir::quant::UniformQuantizedType>();
UniformQuantizedType input_rhs_qtype =
input_rhs_type.getElementType()
.dyn_cast<mlir::quant::UniformQuantizedType>();
UniformQuantizedType output_qtype =
output_type.getElementType()
.dyn_cast<mlir::quant::UniformQuantizedType>();
// Following quantization described in tensorflow/lite/kernels/add.cc
// In details it does:
// 1. Rescale inputs to scale = 2.0 x max(lhs.scale, rhs.scale)
// 2. Extra left shift to input to increase precision
// Where input_shift = 20 if input is 8-bit
// input_shift = 15 if input is 16-bit
double in_lhs_scale = input_lhs_qtype.getScale();
double in_rhs_scale = input_rhs_qtype.getScale();
double output_scale = output_qtype.getScale();
double max_scale_2x = 2.0 * std::max(in_lhs_scale, in_rhs_scale);
const int32_t SHIFT_8_BIT = 20;
const int32_t SHIFT_16_BIT = 15;
int32_t input_shift = (output_qtype.getStorageTypeIntegralWidth() == 16)
? SHIFT_16_BIT
: SHIFT_8_BIT;
double lhs_rescale_scale =
static_cast<double>(1 << input_shift) * in_lhs_scale / max_scale_2x;
double rhs_rescale_scale =
static_cast<double>(1 << input_shift) * in_rhs_scale / max_scale_2x;
double output_rescale_scale =
max_scale_2x / (output_scale * static_cast<double>(1 << input_shift));
Value op1_rescale_lhs =
buildRescaleToInt32(rewriter, op, tfl_add_op.lhs(), lhs_rescale_scale,
input_lhs_qtype.getZeroPoint());
Value op2_rescale_rhs =
buildRescaleToInt32(rewriter, op, tfl_add_op.rhs(), rhs_rescale_scale,
input_rhs_qtype.getZeroPoint());
auto op3_add_op1_op2 = CreateOpAndInfer<TosaOp>(
rewriter, op->getLoc(), rescale_type, op1_rescale_lhs, op2_rescale_rhs);
Value op4_rescale_op3 = buildRescaleFromInt32(
rewriter, op, output_type, op3_add_op1_op2.getResult(),
output_rescale_scale, output_qtype.getZeroPoint());
output = op4_rescale_op3;
} else {
auto op1_add_in =
CreateOpAndInfer<TosaOp>(rewriter, op->getLoc(), output_type,
tfl_add_op.lhs(), tfl_add_op.rhs());
output = op1_add_in.getResult();
}
auto fused_activation_fn = tfl_add_op.fused_activation_functionAttr();
if (fused_activation_fn) {
llvm::Optional<Value> fused_activation_val =
convertFusedActivation(rewriter, op, output, fused_activation_fn);
if (!fused_activation_val) return failure();
rewriter.replaceOp(op, {fused_activation_val.getValue()});
return success();
}
rewriter.replaceOp(op, {output});
return success();
}
LogicalResult ConvertTFLAddOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
return matchAndRewriteAddSub<TFL::AddOp, tosa::AddOp>(op, op->getOperands(),
rewriter);
}
LogicalResult ConvertTFLSubOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
return matchAndRewriteAddSub<TFL::SubOp, tosa::SubOp>(op, op->getOperands(),
rewriter);
}
LogicalResult ConvertTFLMulOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tfl_mul_op = cast<TFL::MulOp>(op);
llvm::Optional<Value> result = convertMultiplyOp(
rewriter, op, tfl_mul_op.getResult(), tfl_mul_op.lhs(), tfl_mul_op.rhs());
if (!result) return failure();
auto fused_activation_fn = tfl_mul_op.fused_activation_functionAttr();
if (fused_activation_fn) {
llvm::Optional<Value> fused_activation_val = convertFusedActivation(
rewriter, op, result.getValue(), fused_activation_fn);
if (!fused_activation_val) return failure();
rewriter.replaceOp(op, {fused_activation_val.getValue()});
return success();
}
rewriter.replaceOp(op, {result.getValue()});
return success();
}
LogicalResult ConvertTFLSquareOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tfl_square_op = cast<TFL::SquareOp>(op);
llvm::Optional<Value> result =
convertMultiplyOp(rewriter, op, tfl_square_op.getResult(),
tfl_square_op.x(), tfl_square_op.x());
if (!result) return failure();
rewriter.replaceOp(op, {result.getValue()});
return success();
}
LogicalResult ConvertTFLSquaredDifferenceOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tfl_squared_op = cast<TFL::SquaredDifferenceOp>(op);
llvm::Optional<Value> result =
convertSquaredDifferenceOp(rewriter, op, tfl_squared_op.getResult(),
tfl_squared_op.lhs(), tfl_squared_op.rhs());
if (!result) return failure();
rewriter.replaceOp(op, {result.getValue()});
return success();
}
LogicalResult ConvertTFLRoundOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tfl_round_op = cast<TFL::RoundOp>(op);
ShapedType input_type = tfl_round_op.x().getType().dyn_cast<ShapedType>();
if (!input_type) {
return op->emitOpError("Round: input not shaped tensor type");
}
if (input_type.getElementType().isa<FloatType>()) {
llvm::Optional<Value> result = convertRoundOp(
rewriter, op, tfl_round_op.getResult(), tfl_round_op.x());
if (!result) return failure();
rewriter.replaceOp(op, {result.getValue()});
return success();
} else {
// Round on int is nonsensical. Instead, replace uses of result with the
// input.
tfl_round_op.replaceAllUsesWith(tfl_round_op.x());
return success();
}
}
LogicalResult ConvertTFLDivOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tfl_div_op = cast<TFL::DivOp>(op);
ShapedType output_type =
tfl_div_op.getResult().getType().dyn_cast<ShapedType>();
// Not a ranked tensor output
if (!output_type) return failure();
auto fused_activation_fn = tfl_div_op.fused_activation_functionAttr();
Type element_type = output_type.getElementType();
Value div_op;
if (element_type.isa<IntegerType>()) {
div_op = CreateOpAndInfer<tosa::DivOp>(rewriter, op->getLoc(), output_type,
tfl_div_op.lhs(), tfl_div_op.rhs())
.getResult();
} else {
auto reciprocal_op = CreateOpAndInfer<tosa::ReciprocalOp>(
rewriter, op->getLoc(), tfl_div_op.rhs().getType(), tfl_div_op.rhs());
div_op = CreateOpAndInfer<tosa::MulOp>(rewriter, op->getLoc(), output_type,
tfl_div_op.lhs(),
reciprocal_op.getResult(), 0)
.getResult();
}
if (fused_activation_fn) {
llvm::Optional<Value> fused_activation_val =
convertFusedActivation(rewriter, op, div_op, fused_activation_fn);
if (!fused_activation_val) return failure();
rewriter.replaceOp(op, {fused_activation_val.getValue()});
return success();
}
rewriter.replaceOp(op, {div_op});
return success();
}
LogicalResult ConvertTFLMaximumOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tfl_max_op = cast<TFL::MaximumOp>(op);
ShapedType input_lhs_type = tfl_max_op.lhs().getType().dyn_cast<ShapedType>();
ShapedType input_rhs_type = tfl_max_op.rhs().getType().dyn_cast<ShapedType>();
ShapedType output_type =
tfl_max_op.getResult().getType().dyn_cast<ShapedType>();
// Not a shaped tensor output
if (!input_lhs_type || !input_rhs_type || !output_type) return failure();
bool input_lhs_is_qtype =
input_lhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
bool input_rhs_is_qtype =
input_rhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
bool output_is_qtype =
output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
if (input_lhs_is_qtype != output_is_qtype ||
input_rhs_is_qtype != output_is_qtype) {
return op->emitOpError(
"ConvertTFLMaximumOp: input/output tensor should "
"be all quantized or all floating-point.");
}
Value output;
if (output_is_qtype) {
ShapedType rescale_type = output_type.clone(rewriter.getI32Type());
Value op1_rescale_lhs =
buildRescaleToInt32(rewriter, op, tfl_max_op.lhs(), 1.0f, 0);
Value op2_rescale_rhs =
buildRescaleToInt32(rewriter, op, tfl_max_op.rhs(), 1.0f, 0);
auto op3_max_op1_op2 = CreateOpAndInfer<tosa::MaximumOp>(
rewriter, op->getLoc(), rescale_type, op1_rescale_lhs, op2_rescale_rhs);
Value op4_rescale_op3 = buildRescaleFromInt32(
rewriter, op, output_type, op3_max_op1_op2.getResult(), 1.0f, 0);
output = op4_rescale_op3;
} else {
auto op1_max_in =
CreateOpAndInfer<tosa::MaximumOp>(rewriter, op->getLoc(), output_type,
tfl_max_op.lhs(), tfl_max_op.rhs());
output = op1_max_in.getResult();
}
rewriter.replaceOp(op, {output});
return success();
}
LogicalResult ConvertTFLMinimumOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tfl_min_op = cast<TFL::MinimumOp>(op);
ShapedType input_lhs_type = tfl_min_op.lhs().getType().dyn_cast<ShapedType>();
ShapedType input_rhs_type = tfl_min_op.rhs().getType().dyn_cast<ShapedType>();
ShapedType output_type =
tfl_min_op.getResult().getType().dyn_cast<ShapedType>();
// Not a shaped tensor output
if (!input_lhs_type || !input_rhs_type || !output_type) return failure();
bool input_lhs_is_qtype =
input_lhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
bool input_rhs_is_qtype =
input_rhs_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
bool output_is_qtype =
output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
if (input_lhs_is_qtype != output_is_qtype ||
input_rhs_is_qtype != output_is_qtype) {
return op->emitOpError(
"ConvertTFLMinimumOp: input/output tensor should "
"be all quantized or all floating-point.");
}
Value output;
if (output_is_qtype) {
ShapedType rescale_type = output_type.clone(rewriter.getI32Type());
Value op1_rescale_lhs =
buildRescaleToInt32(rewriter, op, tfl_min_op.lhs(), 1.0f, 0);
Value op2_rescale_rhs =
buildRescaleToInt32(rewriter, op, tfl_min_op.rhs(), 1.0f, 0);
auto op3_min_op1_op2 = CreateOpAndInfer<tosa::MinimumOp>(
rewriter, op->getLoc(), rescale_type, op1_rescale_lhs, op2_rescale_rhs);
Value op4_rescale_op3 = buildRescaleFromInt32(
rewriter, op, output_type, op3_min_op1_op2.getResult(), 1.0f, 0);
output = op4_rescale_op3;
} else {
auto op1_min_in =
CreateOpAndInfer<tosa::MinimumOp>(rewriter, op->getLoc(), output_type,
tfl_min_op.lhs(), tfl_min_op.rhs());
output = op1_min_in.getResult();
}
rewriter.replaceOp(op, {output});
return success();
}
LogicalResult ConvertTFLFloorDivOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tfl_floordiv_op = cast<TFL::FloorDivOp>(op);
llvm::Optional<Value> result =
convertFloorDivOp(rewriter, op, tfl_floordiv_op.getResult(),
tfl_floordiv_op.lhs(), tfl_floordiv_op.rhs());
if (!result) return failure();
rewriter.replaceOp(op, {result.getValue()});
return success();
}
LogicalResult ConvertTFLFloorModOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tfl_floormod_op = cast<TFL::FloorModOp>(op);
llvm::Optional<Value> result =
convertFloorModOp(rewriter, op, tfl_floormod_op.getResult(),
tfl_floormod_op.lhs(), tfl_floormod_op.rhs());
if (!result) return failure();
rewriter.replaceOp(op, {result.getValue()});
return success();
}
LogicalResult ConvertTFLAddNOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tfl_addn_op = cast<TFL::AddNOp>(op);
ShapedType output_type =
tfl_addn_op.getResult().getType().dyn_cast<ShapedType>();
// Not a shaped output
if (!output_type) return failure();
SmallVector<Value> inputs(tfl_addn_op.inputs());
assert(inputs.size() >= 2);
auto newOp = CreateOpAndInfer<tosa::AddOp>(rewriter, op->getLoc(),
output_type, inputs[0], inputs[1]);
for (int i = 2; i < inputs.size(); i++) {
newOp = CreateOpAndInfer<tosa::AddOp>(rewriter, op->getLoc(), output_type,
inputs[i], newOp.getResult());
}
rewriter.replaceOp(op, {newOp.getResult()});
return success();
}
LogicalResult ConvertTFLAveragePool2DOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tfl_avgpool_op = cast<TFL::AveragePool2DOp>(op);
ShapedType input_type =
tfl_avgpool_op.input().getType().dyn_cast<ShapedType>();
ShapedType output_type =
tfl_avgpool_op.getResult().getType().dyn_cast<ShapedType>();
// Not a shaped output
if (!output_type) return failure();
// Kernels and strides are dimensionally ordered
SmallVector<int64_t, 4> i64array({1, 1, 1, 1});
ArrayAttr kernel_size;
ArrayAttr stride;
ArrayAttr pad;
{
int64_t kernel_h = tfl_avgpool_op.filter_height();
int64_t kernel_w = tfl_avgpool_op.filter_width();
kernel_size = rewriter.getI64ArrayAttr({kernel_h, kernel_w});
// i64array is formatted as NHWC now
i64array[1] = kernel_h;
i64array[2] = kernel_w;
}
{
int64_t stride_h = tfl_avgpool_op.stride_h();
int64_t stride_w = tfl_avgpool_op.stride_w();
stride = rewriter.getI64ArrayAttr({stride_h, stride_w});
}
{
tensorflow::Padding tf_pad;
if (!GetPaddingFromString(tfl_avgpool_op.padding().str(), &tf_pad).ok())
return failure();
// Pooling has no non-unit dilation
ArrayAttr dilation = rewriter.getI64ArrayAttr({1, 1});
RankedTensorType filter_type = RankedTensorType::get(
llvm::makeArrayRef(i64array), rewriter.getIntegerType(64));
// TFLite doesn't support explicit padding
if (!getPaddingValuesFromPadType(
tf_pad,
tensorflow::FORMAT_NHWC, // TFLite only supports this
1, // tensorflow::FORMAT_OHWI,
input_type, filter_type, stride, dilation, rewriter, pad))
return failure();
}
CreateReplaceOpAndInfer<tosa::AvgPool2dOp>(rewriter, op, output_type,
tfl_avgpool_op.input(),
kernel_size, stride, pad);
return success();
}
LogicalResult ConvertTFLMaxPool2DOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tfl_maxpool_op = cast<TFL::MaxPool2DOp>(op);
ShapedType input_type =
tfl_maxpool_op.input().getType().dyn_cast<ShapedType>();
ShapedType output_type =
tfl_maxpool_op.getResult().getType().dyn_cast<ShapedType>();
// Not a shaped type
if (!output_type) return failure();
// Kernels and strides are dimensionally ordered
SmallVector<int64_t, 4> i64array({1, 1, 1, 1});
ArrayAttr kernel_size;
ArrayAttr stride;
ArrayAttr pad;
{
int64_t kernel_h = tfl_maxpool_op.filter_height();
int64_t kernel_w = tfl_maxpool_op.filter_width();
kernel_size = rewriter.getI64ArrayAttr({kernel_h, kernel_w});
// i64array is formatted as NHWC now
i64array[1] = kernel_h;
i64array[2] = kernel_w;
}
{
int64_t stride_h = tfl_maxpool_op.stride_h();
int64_t stride_w = tfl_maxpool_op.stride_w();
stride = rewriter.getI64ArrayAttr({stride_h, stride_w});
}
{
tensorflow::Padding tf_pad;
if (!GetPaddingFromString(tfl_maxpool_op.padding().str(), &tf_pad).ok())
return failure();
// Pooling has no non-unit dilation
ArrayAttr dilation = rewriter.getI64ArrayAttr({1, 1});
RankedTensorType filter_type =
RankedTensorType::get(i64array, rewriter.getIntegerType(64));
// TFLite doesn't support explicit padding
if (!getPaddingValuesFromPadType(
tf_pad,
tensorflow::FORMAT_NHWC, // TFLite only supports this
1, // tensorflow::FORMAT_OHWI,
input_type, filter_type, stride, dilation, rewriter, pad))
return failure();
}
CreateReplaceOpAndInfer<tosa::MaxPool2dOp>(rewriter, op, output_type,
tfl_maxpool_op.input(),
kernel_size, stride, pad);
return success();
}
LogicalResult ConvertTFLConv2DOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tfl_conv2d_op = cast<TFL::Conv2DOp>(op);
RankedTensorType input_type =
tfl_conv2d_op.input().getType().dyn_cast<RankedTensorType>();
RankedTensorType filter_type =
tfl_conv2d_op.filter().getType().dyn_cast<RankedTensorType>();
ShapedType output_type =
tfl_conv2d_op.getResult().getType().dyn_cast<ShapedType>();
// Not a ranked tensor output
if (!input_type) return failure();
if (!output_type) return failure();
if (!filter_type) return failure();
bool input_is_qtype =
input_type.getElementType().isa<mlir::quant::QuantizedType>();
bool filter_is_qtype =
filter_type.getElementType().isa<mlir::quant::QuantizedType>();
bool output_is_qtype =
output_type.getElementType().isa<mlir::quant::QuantizedType>();
if ((input_is_qtype != filter_is_qtype) ||
(input_is_qtype != output_is_qtype)) {
return op->emitOpError(
"ConvertTFLConv2DOp: input/filter/output tensor should "
"be all quantized or all floating-point.");
}
ArrayAttr pad;
ArrayAttr stride;
ArrayAttr dilation;
{
int64_t stride_h = tfl_conv2d_op.stride_h();
int64_t stride_w = tfl_conv2d_op.stride_w();
stride = rewriter.getI64ArrayAttr({stride_h, stride_w});
}
{
int64_t dilation_h = tfl_conv2d_op.dilation_h_factor();
int64_t dilation_w = tfl_conv2d_op.dilation_w_factor();
dilation = rewriter.getI64ArrayAttr({dilation_h, dilation_w});
}
{
tensorflow::Padding tf_pad;
if (!GetPaddingFromString(tfl_conv2d_op.padding().str(), &tf_pad).ok())
return failure();
// TFLite doesn't support explicit padding
if (!getPaddingValuesFromPadType(
tf_pad,
tensorflow::FORMAT_NHWC, // TFLite only supports this
1, // tensorflow::FORMAT_OHWI,
input_type, filter_type, stride, dilation, rewriter, pad))
return failure();
}
Value unquantized_bias = tfl_conv2d_op.bias();
auto a1_conv2d_op = CreateOpAndInfer<tosa::Conv2DOp>(
rewriter, op->getLoc(), output_type, tfl_conv2d_op.input(),
tfl_conv2d_op.filter(), unquantized_bias, pad, stride, dilation);
Value conv2d_output;
if (input_is_qtype) {
conv2d_output =
buildRescaleOpConvOutput(rewriter, op, a1_conv2d_op.getResult(),
input_type, filter_type, output_type);
} else {
conv2d_output = a1_conv2d_op.getResult();
}
auto fused_activation_fn = tfl_conv2d_op.fused_activation_functionAttr();
if (fused_activation_fn) {
llvm::Optional<Value> fused_activation_val = convertFusedActivation(
rewriter, op, conv2d_output, fused_activation_fn);
if (!fused_activation_val) return failure();
rewriter.replaceOp(op, {fused_activation_val.getValue()});
return success();
}
rewriter.replaceOp(op, {conv2d_output});
return success();
}
LogicalResult ConvertTFLTransposeConvOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tfl_conv_op = cast<TFL::TransposeConvOp>(op);
ShapedType input_type = tfl_conv_op.input().getType().dyn_cast<ShapedType>();
ShapedType filter_type =
tfl_conv_op.weights().getType().dyn_cast<ShapedType>();
ShapedType output_type =
tfl_conv_op.getResult().getType().dyn_cast<ShapedType>();
// Not a ranked tensor output
if (!input_type) return failure();
if (!output_type) return failure();
if (!filter_type) return failure();
bool input_is_qtype =
input_type.getElementType().isa<mlir::quant::QuantizedType>();
bool filter_is_qtype =
filter_type.getElementType().isa<mlir::quant::QuantizedType>();
bool output_is_qtype =
output_type.getElementType().isa<mlir::quant::QuantizedType>();
if ((input_is_qtype != filter_is_qtype) ||
(input_is_qtype != output_is_qtype)) {
return op->emitOpError(
"ConvertTFLConv2DOp: input/filter/output tensor should "
"be all quantized or all floating-point.");
}
ArrayAttr stride;
ArrayAttr dilation;
ArrayAttr outpad;
ArrayAttr output_shape;
{
int64_t stride_h = tfl_conv_op.stride_h();
int64_t stride_w = tfl_conv_op.stride_w();
stride = rewriter.getI64ArrayAttr({stride_h, stride_w});
}
// tfl.transpose_conv doesn't support dilations
dilation = rewriter.getI64ArrayAttr({1, 1});
{
tensorflow::Padding tf_pad;
if (!GetPaddingFromString(tfl_conv_op.padding().str(), &tf_pad).ok())
return failure();
if (!getTransposeConv2dPaddingValues(
tf_pad,
tensorflow::FORMAT_NHWC, // TFLite only supports this
1, // tensorflow::FORMAT_OHWI,
input_type, filter_type, output_type, stride, dilation, rewriter,
outpad))
return failure();
}
{
ElementsAttr output_shape_elems;
// Match from input_size tensor first
if (matchPattern(tfl_conv_op.output_shape(),
m_Constant(&output_shape_elems))) {
SmallVector<int64_t> shape_vec;
for (int i = 0; i < output_shape_elems.getNumElements(); i++)
shape_vec.push_back(
output_shape_elems.getValue<IntegerAttr>(i).getInt());
output_shape = rewriter.getI64ArrayAttr(shape_vec);
} else if (output_type.hasRank()) {
// Use output tensor's shape otherwise
output_shape = rewriter.getI64ArrayAttr(output_type.getShape());
} else {
// TODO(suderman): Figure out rankless shape propagation.
return failure();
}
}
int output_channel = 0;
// TODO(suderman): We need to figure out how to guarantee output channel
// propagation.
if (output_type.hasRank()) {
output_channel = output_type.getDimSize(3);
} else if (filter_type.hasRank()) {
output_channel = filter_type.getDimSize(0);
} else {
return failure();
}
llvm::Optional<Value> zero_bias;
if (input_is_qtype) {
uint32_t input_bits = input_type.getElementType()
.dyn_cast<mlir::quant::QuantizedType>()
.getStorageTypeIntegralWidth();
uint32_t weight_bits = filter_type.getElementType()
.dyn_cast<mlir::quant::QuantizedType>()
.getStorageTypeIntegralWidth();
if (input_bits == 16 && weight_bits == 8) {
SmallVector<APInt> vec(output_channel, APInt(48, 0, true));
zero_bias = getConstTensor<APInt>(rewriter, op, vec, {output_channel});
} else {
SmallVector<int32_t> vec(output_channel, 0);
zero_bias = getConstTensor<int32_t>(rewriter, op, vec, {output_channel});
}
} else {
SmallVector<float> vec(output_channel, 0.0f);
zero_bias = getConstTensor<float>(rewriter, op, vec, {output_channel});
}
if (!zero_bias) return failure();
auto a1_conv2d_op = CreateOpAndInfer<tosa::TransposeConv2DOp>(
rewriter, op->getLoc(), output_type, tfl_conv_op.input(),
tfl_conv_op.weights(), zero_bias.getValue(), outpad, stride, dilation,
output_shape);
Value conv2d_output;
if (input_is_qtype) {
conv2d_output =
buildRescaleOpConvOutput(rewriter, op, a1_conv2d_op.getResult(),
input_type, filter_type, output_type);
} else {
conv2d_output = a1_conv2d_op.getResult();
}
rewriter.replaceOp(op, {conv2d_output});
return success();
}
LogicalResult ConvertTFLDepthwiseConv2DOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tfl_conv2d_op = cast<TFL::DepthwiseConv2DOp>(op);
ShapedType input_type =
tfl_conv2d_op.input().getType().dyn_cast<ShapedType>();
ShapedType filter_type =
tfl_conv2d_op.filter().getType().dyn_cast<ShapedType>();
ShapedType output_type =
tfl_conv2d_op.getResult().getType().dyn_cast<ShapedType>();
// Not a shaped output
if (!input_type) return failure();
if (!output_type) return failure();
if (!filter_type) return failure();
bool input_is_qtype =
input_type.getElementType().isa<mlir::quant::QuantizedType>();
bool filter_is_qtype =
filter_type.getElementType().isa<mlir::quant::QuantizedType>();
bool output_is_qtype =
output_type.getElementType().isa<mlir::quant::QuantizedType>();
if ((input_is_qtype != filter_is_qtype) ||
(input_is_qtype != output_is_qtype)) {
return op->emitOpError(
"ConvertTFLConv2DOp: input/filter/output tensor should "
"be all quantized or all floating-point.");
}
// We need the filter shape to compute the transpose.
if (!filter_type.hasRank()) return failure();
auto filter_shape = filter_type.getShape();
// Operator depthwiseConv2D
// TFLite orders the depthwiseConv2D filter in IHWO, while TOSA orders
// filter in HWIO
//
// The lowering reorders the filter.
//
// a1_transpose = tosa.transpose(filter, {1, 2, 3, 0}) // HWIO
// a2_reshape = tosa.reshape(filter, H, W, depth_multiplier, I /
// depth_multiplier)
// a3_transpose_conv2d = tosa.transpose_conv2d(input, a2_reshape, padding,
// stride, dilation)
ArrayAttr pad;
ArrayAttr stride;
ArrayAttr dilation;
auto depth_multiplier = tfl_conv2d_op.depth_multiplierAttr();
{
int64_t stride_h = tfl_conv2d_op.stride_h();
int64_t stride_w = tfl_conv2d_op.stride_w();
stride = rewriter.getI64ArrayAttr({stride_h, stride_w});
}
{
int64_t dilation_h = tfl_conv2d_op.dilation_h_factor();
int64_t dilation_w = tfl_conv2d_op.dilation_w_factor();
dilation = rewriter.getI64ArrayAttr({dilation_h, dilation_w});
}
{
tensorflow::Padding tf_pad;
if (!GetPaddingFromString(tfl_conv2d_op.padding().str(), &tf_pad).ok())
return failure();
if (!getPaddingValuesFromPadType(
tf_pad,
tensorflow::FORMAT_NHWC, // TFLite only supports this
1, // tensorflow::FORMAT_OHWI,
input_type, filter_type, stride, dilation, rewriter, pad))
return failure();
}
SmallVector<int64_t, 4> a1_transpose_dims;
a1_transpose_dims.push_back(filter_shape[1]);
a1_transpose_dims.push_back(filter_shape[2]);
a1_transpose_dims.push_back(filter_shape[3]);
a1_transpose_dims.push_back(filter_shape[0]);
SmallVector<int64_t, 4> a2_reshape_dims;
a2_reshape_dims.push_back(a1_transpose_dims[0]);
a2_reshape_dims.push_back(a1_transpose_dims[1]);
a2_reshape_dims.push_back(a1_transpose_dims[2] / depth_multiplier.getInt());
a2_reshape_dims.push_back(depth_multiplier.getInt());
llvm::Optional<Value> a1_filter_transpose_perms = getConstTensor<int32_t>(
rewriter, op, /*vec=*/{1, 2, 3, 0}, /*shape=*/{4});
if (!a1_filter_transpose_perms) return failure();
auto a1_filter_transpose_op = CreateOpAndInfer<tosa::TransposeOp>(
rewriter, op->getLoc(),
RankedTensorType::get(ArrayRef<int64_t>(a1_transpose_dims),
filter_type.getElementType()),
tfl_conv2d_op.filter(), a1_filter_transpose_perms.getValue());
auto a2_filter_reshape_op = CreateOpAndInfer<tosa::ReshapeOp>(
rewriter, op->getLoc(),
RankedTensorType::get(ArrayRef<int64_t>(a2_reshape_dims),
filter_type.getElementType()),
a1_filter_transpose_op.getResult(),
rewriter.getI64ArrayAttr(a2_reshape_dims));
Value unquantized_bias = tfl_conv2d_op.bias();
auto a3_depthwise_conv2d_op = CreateOpAndInfer<tosa::DepthwiseConv2DOp>(
rewriter, op->getLoc(), output_type, tfl_conv2d_op.input(),
a2_filter_reshape_op.getResult(), unquantized_bias, pad, stride,
dilation);
Value conv2d_output;
if (input_is_qtype) {
conv2d_output = buildRescaleOpConvOutput(
rewriter, op, a3_depthwise_conv2d_op.getResult(), input_type,
filter_type, output_type);
} else {
conv2d_output = a3_depthwise_conv2d_op.getResult();
}
auto fused_activation_fn = tfl_conv2d_op.fused_activation_functionAttr();
if (fused_activation_fn) {
llvm::Optional<Value> fused_activation_val = convertFusedActivation(
rewriter, op, conv2d_output, fused_activation_fn);
if (!fused_activation_val) return failure();
rewriter.replaceOp(op, {fused_activation_val.getValue()});
return success();
}
rewriter.replaceOp(op, {conv2d_output});
return success();
}
LogicalResult ConvertTFLFullyConnectedOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tfl_fc_op = cast<TFL::FullyConnectedOp>(op);
ShapedType output_type =
tfl_fc_op.getResult(0).getType().dyn_cast<ShapedType>();
// Not a ranked tensor output
if (!output_type) return failure();
RankedTensorType input_type =
tfl_fc_op.input().getType().dyn_cast<RankedTensorType>();
RankedTensorType filter_type =
tfl_fc_op.filter().getType().dyn_cast<RankedTensorType>();
RankedTensorType bias_type =
tfl_fc_op.bias().getType().dyn_cast<RankedTensorType>();
if (!input_type || !filter_type) return failure();
bool input_is_qtype =
input_type.getElementType().isa<mlir::quant::QuantizedType>();
bool filter_is_qtype =
filter_type.getElementType().isa<mlir::quant::QuantizedType>();
bool output_is_qtype =
output_type.getElementType().isa<mlir::quant::QuantizedType>();
if ((input_is_qtype != filter_is_qtype) ||
(input_is_qtype != output_is_qtype)) {
return op->emitOpError(
"ConvertTFLFullyConnectedOp: input/filter/output tensor should "
"be all quantized or all floating-point.");
}
Value input_val = tfl_fc_op.input();
// tfl.fully_connected() can takes various dimension tensor as input
// need to reshape it to rank 2 tensor, which tosa.fully_connected only
// supports if input tensor is rank 4. It's not always reshaping to (dim[0] *
// dim[1], dim[2] * dim[3]).
// In some networks it's reshaping to (dim[0], dim[1] * dim[2] * dim[3]) so a
// more general way to determine the reshape's shape is by looking at filter's
// shape[1].
if (input_type.getRank() != 2) {
int64_t num_elems = filter_type.getShape()[1];
int64_t num_batch = input_type.getNumElements() / num_elems;
SmallVector<int64_t, 2> shape_vals({num_batch, num_elems});
RankedTensorType reshape_type =
RankedTensorType::get(shape_vals, input_type.getElementType());
auto reshape_op = CreateOpAndInfer<tosa::ReshapeOp>(
rewriter, op->getLoc(), reshape_type, tfl_fc_op.input(),
rewriter.getI64ArrayAttr(shape_vals));
input_val = reshape_op.getResult();
}
Value bias_val;
if (!bias_type) {
// For some matmuls, the bias may actually be a "UnitType" which has no
// value. TOSA requires bias to be an array of output_channel_count values,
// so create a constant of the appropriate number and type of zeros.
SmallVector<int64_t, 1> bias_shape({filter_type.getShape()[0]});
RankedTensorType bias_type =
RankedTensorType::get(bias_shape, input_type.getElementType());
DenseElementsAttr bias_attr;
if (input_type.getElementType().isa<FloatType>()) {
SmallVector<float> bias_arr(bias_shape[0]);
for (int i = 0; i < bias_shape[0]; i++) {
bias_arr[i] = 0.0;
}
bias_attr =
DenseElementsAttr::get(bias_type, llvm::makeArrayRef(bias_arr));
} else {
SmallVector<int32_t> bias_arr(bias_shape[0]);
for (int i = 0; i < bias_shape[0]; i++) {
bias_arr[i] = 0;
}
bias_attr =
DenseElementsAttr::get(bias_type, llvm::makeArrayRef(bias_arr));
}
auto bias_op = CreateOpAndInfer<tosa::ConstOp>(rewriter, op->getLoc(),
bias_type, bias_attr);
bias_val = bias_op.getResult();
} else {
bias_val = tfl_fc_op.bias();
}
auto fc_op = CreateOpAndInfer<tosa::FullyConnectedOp>(
rewriter, op->getLoc(), output_type, input_val, tfl_fc_op.filter(),
bias_val);
Value fc_output;
if (input_is_qtype) {
fc_output = buildRescaleOpConvOutput(rewriter, op, fc_op.getResult(),
input_type, filter_type, output_type);
} else {
fc_output = fc_op.getResult();
}
auto fused_activation_fn = tfl_fc_op.fused_activation_functionAttr();
if (fused_activation_fn) {
llvm::Optional<Value> fused_activation_val =
convertFusedActivation(rewriter, op, fc_output, fused_activation_fn);
if (!fused_activation_val) return failure();
rewriter.replaceOp(op, {fused_activation_val.getValue()});
return success();
}
rewriter.replaceOp(op, {fc_output});
return success();
}
LogicalResult ConvertTFLConcatenationOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tfl_concat_op = cast<TFL::ConcatenationOp>(op);
SmallVector<Value> values(tfl_concat_op.values());
IntegerAttr axis_attr;
{
auto tmpAttr = tfl_concat_op.axisAttr();
if (!tmpAttr) {
tmpAttr = rewriter.getI64IntegerAttr(0);
}
axis_attr = tmpAttr;
}
int32_t axis = axis_attr.getInt();
llvm::Optional<Value> result =
convertConcatV2Op(rewriter, op, tfl_concat_op.getResult(), values, axis);
if (!result) return failure();
rewriter.replaceOp(op, {result.getValue()});
return success();
}
LogicalResult ConvertTFLReshapeOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tfl_reshape_op = cast<TFL::ReshapeOp>(op);
ShapedType output_type =
tfl_reshape_op.getResult().getType().dyn_cast<RankedTensorType>();
// Not a shaped tensor output
if (!output_type) return failure();
SmallVector<int64_t> shape_vals;
// Either the output type needs to be ranked or we need a constant input
// to compute the output rank.
ElementsAttr shape_attr;
if (!matchPattern(tfl_reshape_op.shape(), m_Constant(&shape_attr))) {
if (!output_type.hasRank()) return failure();
shape_vals.resize(output_type.getRank(), -1);
} else {
for (auto dim : shape_attr.getValues<int32_t>()) shape_vals.push_back(dim);
}
// Propagate the agreement between the output shape and constant value.
if (output_type.hasRank()) {
if (output_type.getRank() != shape_vals.size()) return failure();
for (int i = 0; i < output_type.getRank(); i++) {
if (shape_vals[i] == -1) shape_vals[i] = output_type.getDimSize(i);
}
}
// We cannot handle more than 1 dynamic dimension.
int64_t dynamic_count = 0;
for (auto val : shape_vals)
if (val == -1) dynamic_count++;
if (dynamic_count > 1) return failure();
ArrayAttr new_shape_attr = rewriter.getI64ArrayAttr(shape_vals);
CreateReplaceOpAndInfer<tosa::ReshapeOp>(
rewriter, op, output_type, tfl_reshape_op.input(), new_shape_attr);
return success();
}
LogicalResult ConvertTFLRankOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tfl_rank_op = cast<TFL::RankOp>(op);
RankedTensorType input_type =
tfl_rank_op.input().getType().dyn_cast<RankedTensorType>();
if (!input_type) return failure();
int32_t rank = input_type.getRank();
RankedTensorType rank_type =
RankedTensorType::get({1}, rewriter.getIntegerType(32));
auto rank_attr = DenseElementsAttr::get(rank_type, {rank});
auto rank_const = CreateOpAndInfer<tosa::ConstOp>(rewriter, op->getLoc(),
rank_type, rank_attr);
rewriter.replaceOp(op, {rank_const.getResult()});
return success();
}
LogicalResult ConvertTFLShapeOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tfl_shape_op = cast<TFL::ShapeOp>(op);
RankedTensorType output_type =
tfl_shape_op.getResult().getType().dyn_cast<RankedTensorType>();
// Not a ranked tensor output
if (!output_type) return failure();
RankedTensorType input_type =
tfl_shape_op.input().getType().dyn_cast<RankedTensorType>();
if (!input_type) return failure();
auto input_shape = input_type.getShape();
SmallVector<int32_t> shape_arr;
for (int i = 0; i < input_shape.size(); i++) {
shape_arr.emplace_back(input_shape[i]);
}
RankedTensorType shape_type = RankedTensorType::get(
{static_cast<int32_t>(shape_arr.size())}, rewriter.getIntegerType(32));
auto shape_attr =
DenseElementsAttr::get(shape_type, llvm::makeArrayRef(shape_arr));
auto shape_const = CreateOpAndInfer<tosa::ConstOp>(rewriter, op->getLoc(),
shape_type, shape_attr);
rewriter.replaceOp(op, {shape_const.getResult()});
return success();
}
LogicalResult ConvertTFLExpandDimsOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tfl_expanddims_op = cast<TFL::ExpandDimsOp>(op);
llvm::Optional<Value> result =
convertExpandDimsOp(rewriter, op, tfl_expanddims_op.getResult(),
tfl_expanddims_op.input(), tfl_expanddims_op.dim());
if (!result) return failure();
rewriter.replaceOp(op, {result.getValue()});
return success();
}
LogicalResult ConvertTFLSqueezeOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tfl_squeeze_op = cast<TFL::SqueezeOp>(op);
// Copy squeeze_dims into int32_t array
auto squeeze_dims_attr = tfl_squeeze_op.squeeze_dimsAttr();
SmallVector<int32_t> squeeze_dims;
for (auto& squeeze_dim : squeeze_dims_attr) {
squeeze_dims.emplace_back(squeeze_dim.dyn_cast<IntegerAttr>().getInt());
}
llvm::Optional<Value> result =
convertSqueezeOp(rewriter, op, tfl_squeeze_op.getResult(),
tfl_squeeze_op.input(), squeeze_dims);
if (!result) return failure();
rewriter.replaceOp(op, {result.getValue()});
return success();
}
LogicalResult ConvertTFLFillOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tfl_fill_op = cast<TFL::FillOp>(op);
RankedTensorType output_type =
tfl_fill_op.getResult().getType().dyn_cast<RankedTensorType>();
// Not a ranked tensor output
if (!output_type) return failure();
ElementsAttr dims_elems;
if (!matchPattern(tfl_fill_op.dims(), m_Constant(&dims_elems)))
return failure();
SmallVector<int64_t> dims_vals;
uint32_t total_size = 1;
for (int i = 0; i < dims_elems.getNumElements(); i++) {
dims_vals.push_back(dims_elems.getValue<IntegerAttr>(i).getInt());
total_size *= dims_vals[i];
}
ElementsAttr value_elem;
if (!matchPattern(tfl_fill_op.input(), m_Constant(&value_elem)))
return failure();
RankedTensorType fill_type = RankedTensorType::get(
ArrayRef<int64_t>(dims_vals), value_elem.getType().getElementType());
DenseElementsAttr fill_attr;
// Convert to a compatible zero type.
if (value_elem.getType().getElementType().isa<FloatType>()) {
SmallVector<float> fill_arr(
total_size,
value_elem.getValue<FloatAttr>(0).getValue().convertToFloat());
fill_attr = DenseElementsAttr::get(fill_type, llvm::makeArrayRef(fill_arr));
} else {
SmallVector<int32_t> fill_arr(
total_size,
value_elem.getValue<IntegerAttr>(0).getValue().getLimitedValue());
fill_attr = DenseElementsAttr::get(fill_type, llvm::makeArrayRef(fill_arr));
}
auto fill_const_op = CreateOpAndInfer<tosa::ConstOp>(rewriter, op->getLoc(),
fill_type, fill_attr);
rewriter.replaceOp(op, {fill_const_op.getResult()});
return success();
}
LogicalResult ConvertTFLReduceAnyOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tfl_any_op = cast<TFL::ReduceAnyOp>(op);
RankedTensorType output_type =
tfl_any_op.getResult().getType().dyn_cast<RankedTensorType>();
if (!output_type) return failure();
ElementsAttr axes_elems;
if (!matchPattern(tfl_any_op.reduction_indices(), m_Constant(&axes_elems)))
return failure();
bool keep_dims = false;
auto keep_dims_attr = tfl_any_op.keep_dimsAttr();
if (keep_dims_attr) keep_dims = keep_dims_attr.getValue();
llvm::Optional<Value> result = convertReduceAnyOp(
rewriter, op, output_type, tfl_any_op.input(), axes_elems, keep_dims);
if (!result) return failure();
rewriter.replaceOp(op, {result.getValue()});
return success();
}
LogicalResult ConvertTFLReduceMaxOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tfl_max_op = cast<TFL::ReduceMaxOp>(op);
RankedTensorType output_type =
tfl_max_op.getResult().getType().dyn_cast<RankedTensorType>();
if (!output_type) return failure();
ElementsAttr axes_elems;
if (!matchPattern(tfl_max_op.axes(), m_Constant(&axes_elems)))
return failure();
bool keep_dims = false;
auto keep_dims_attr = tfl_max_op.keep_dimsAttr();
if (keep_dims_attr) keep_dims = keep_dims_attr.getValue();
llvm::Optional<Value> result = convertReduceMaxOp(
rewriter, op, output_type, tfl_max_op.input(), axes_elems, keep_dims);
if (!result) return failure();
rewriter.replaceOp(op, {result.getValue()});
return success();
}
LogicalResult ConvertTFLReduceMinOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tfl_min_op = cast<TFL::ReduceMinOp>(op);
RankedTensorType output_type =
tfl_min_op.getResult().getType().dyn_cast<RankedTensorType>();
if (!output_type) return failure();
ElementsAttr axes_elems;
if (!matchPattern(tfl_min_op.axes(), m_Constant(&axes_elems)))
return failure();
bool keep_dims = false;
auto keep_dims_attr = tfl_min_op.keep_dimsAttr();
if (keep_dims_attr) keep_dims = keep_dims_attr.getValue();
llvm::Optional<Value> result = convertReduceMinOp(
rewriter, op, output_type, tfl_min_op.input(), axes_elems, keep_dims);
if (!result) return failure();
rewriter.replaceOp(op, {result.getValue()});
return success();
}
LogicalResult ConvertTFLReduceProdOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tfl_prod_op = cast<TFL::ReduceProdOp>(op);
RankedTensorType output_type =
tfl_prod_op.getResult().getType().dyn_cast<RankedTensorType>();
if (!output_type) return failure();
ElementsAttr axes_elems;
if (!matchPattern(tfl_prod_op.axes(), m_Constant(&axes_elems)))
return failure();
bool keep_dims = false;
auto keep_dims_attr = tfl_prod_op.keep_dimsAttr();
if (keep_dims_attr) keep_dims = keep_dims_attr.getValue();
llvm::Optional<Value> result = convertReduceProdOp(
rewriter, op, output_type, tfl_prod_op.input(), axes_elems, keep_dims);
if (!result) return failure();
rewriter.replaceOp(op, {result.getValue()});
return success();
}
LogicalResult ConvertTFLMeanOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tfl_mean_op = cast<TFL::MeanOp>(op);
RankedTensorType output_type =
tfl_mean_op.getResult().getType().dyn_cast<RankedTensorType>();
if (!output_type) return failure();
ElementsAttr axes_elems;
if (!matchPattern(tfl_mean_op.axis(), m_Constant(&axes_elems)))
return failure();
bool keep_dims = false;
auto keep_dims_attr = tfl_mean_op.keep_dimsAttr();
if (keep_dims_attr) keep_dims = keep_dims_attr.getValue();
llvm::Optional<Value> result = convertReduceMeanOp(
rewriter, op, output_type, tfl_mean_op.input(), axes_elems, keep_dims);
if (!result) return failure();
rewriter.replaceOp(op, {result.getValue()});
return success();
}
LogicalResult ConvertTFLSumOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tfl_sum_op = cast<TFL::SumOp>(op);
RankedTensorType output_type =
tfl_sum_op.getResult().getType().dyn_cast<RankedTensorType>();
if (!output_type) return failure();
ElementsAttr axes_elems;
if (!matchPattern(tfl_sum_op.axes(), m_Constant(&axes_elems)))
return failure();
bool keep_dims = false;
auto keep_dims_attr = tfl_sum_op.keep_dimsAttr();
if (keep_dims_attr) keep_dims = keep_dims_attr.getValue();
llvm::Optional<Value> result = convertReduceSumOp(
rewriter, op, output_type, tfl_sum_op.input(), axes_elems, keep_dims);
if (!result) return failure();
rewriter.replaceOp(op, {result.getValue()});
return success();
}
LogicalResult ConvertTFLEluOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tfl_elu_op = cast<TFL::EluOp>(op);
llvm::Optional<Value> result =
convertEluOp(rewriter, op, tfl_elu_op.getResult(), tfl_elu_op.x());
if (!result) return failure();
rewriter.replaceOp(op, {result.getValue()});
return success();
}
LogicalResult ConvertTFLSoftmaxOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tfl_softmax_op = cast<TFL::SoftmaxOp>(op);
llvm::Optional<Value> result = convertSoftmaxOp(
rewriter, op, tfl_softmax_op.getResult(), tfl_softmax_op.input(),
tfl_softmax_op.betaAttr().getValueAsDouble());
if (!result) return failure();
rewriter.replaceOp(op, {result.getValue()});
return success();
}
LogicalResult ConvertTFLSqrtOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tfl_rsqrt_op = cast<TFL::SqrtOp>(op);
auto rsqrt = CreateOpAndInfer<tosa::RsqrtOp>(
rewriter, op->getLoc(), tfl_rsqrt_op.getType(), tfl_rsqrt_op.x());
CreateReplaceOpAndInfer<tosa::ReciprocalOp>(rewriter, op, rsqrt.getType(),
rsqrt);
return success();
}
LogicalResult ConvertTFLLogSoftmaxOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tfl_logsoftmax_op = cast<TFL::LogSoftmaxOp>(op);
llvm::Optional<Value> result = convertLogSoftmaxOp(
rewriter, op, tfl_logsoftmax_op.getResult(), tfl_logsoftmax_op.input());
if (!result) return failure();
rewriter.replaceOp(op, {result.getValue()});
return success();
}
LogicalResult ConvertTFLSliceOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tfl_slice_op = cast<TFL::SliceOp>(op);
RankedTensorType output_type =
tfl_slice_op.getResult().getType().dyn_cast<RankedTensorType>();
// Not a ranked tensor output
if (!output_type) return failure();
ElementsAttr begin_elems, size_elems;
SmallVector<int64_t> begin_vals, size_vals;
if (!matchPattern(tfl_slice_op.begin(), m_Constant(&begin_elems)) ||
!matchPattern(tfl_slice_op.size(), m_Constant(&size_elems))) {
return failure();
}
for (int i = 0; i < begin_elems.getNumElements(); i++)
begin_vals.push_back(begin_elems.getValue<IntegerAttr>(i).getInt());
for (int i = 0; i < size_elems.getNumElements(); i++)
size_vals.push_back(size_elems.getValue<IntegerAttr>(i).getInt());
ArrayAttr begin = rewriter.getI64ArrayAttr(begin_vals);
ArrayAttr size = rewriter.getI64ArrayAttr(size_vals);
CreateReplaceOpAndInfer<tosa::SliceOp>(rewriter, op, output_type,
tfl_slice_op.input(), begin, size);
return success();
}
LogicalResult ConvertTFLTileOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tfl_tile_op = cast<TFL::TileOp>(op);
RankedTensorType output_type =
tfl_tile_op.getResult().getType().dyn_cast<RankedTensorType>();
// Not a ranked tensor output
if (!output_type) return failure();
ElementsAttr multiples_elems;
if (!matchPattern(tfl_tile_op.multiples(), m_Constant(&multiples_elems)))
return failure();
SmallVector<int64_t> multiples_vals;
for (int i = 0; i < multiples_elems.getNumElements(); i++)
multiples_vals.push_back(multiples_elems.getValue<IntegerAttr>(i).getInt());
ArrayAttr multiples_attr = rewriter.getI64ArrayAttr(multiples_vals);
CreateReplaceOpAndInfer<tosa::TileOp>(rewriter, op, output_type,
tfl_tile_op.input(), multiples_attr);
return success();
}
LogicalResult ConvertTFLTransposeOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tfl_transpose_op = cast<TFL::TransposeOp>(op);
Type output_type = tfl_transpose_op.getResult().getType();
CreateReplaceOpAndInfer<tosa::TransposeOp>(rewriter, op, output_type,
tfl_transpose_op.input(),
tfl_transpose_op.perm());
return success();
}
LogicalResult ConvertTFLPackOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tfl_pack_op = cast<TFL::PackOp>(op);
SmallVector<Value> inputs(tfl_pack_op.values());
assert(inputs.size() >= 2);
IntegerAttr axis_attr;
{
auto tmpAttr = tfl_pack_op.axisAttr();
if (!tmpAttr) tmpAttr = rewriter.getI64IntegerAttr(0);
axis_attr = tmpAttr;
}
int32_t axis_i32 = axis_attr.getInt();
llvm::Optional<Value> result =
convertPackOp(rewriter, op, tfl_pack_op.getResult(), inputs, axis_i32);
if (!result) return failure();
rewriter.replaceOp(op, {result.getValue()});
return success();
}
LogicalResult ConvertTFLUnpackOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tfl_unpack_op = cast<TFL::UnpackOp>(op);
IntegerAttr axis_attr;
{
auto tmpAttr = tfl_unpack_op.axisAttr();
if (!tmpAttr) tmpAttr = rewriter.getI64IntegerAttr(0);
axis_attr = tmpAttr;
}
int32_t axis_i32 = axis_attr.getInt();
llvm::Optional<SmallVector<Value>> results =
convertUnpackOp(rewriter, op, tfl_unpack_op.input(), axis_i32);
if (!results) return failure();
rewriter.replaceOp(op, results.getValue());
return success();
}
// Splits in num_split parts along split_dim
LogicalResult ConvertTFLSplitOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tfl_split_op = cast<TFL::SplitOp>(op);
// Get the number of splits
int32_t num_split = -1;
auto numSplitAttr = tfl_split_op.num_splitsAttr();
if (numSplitAttr) {
num_split = numSplitAttr.getInt();
} else {
return failure();
}
// Get the axis
ElementsAttr axisAttrElems;
if (!matchPattern(tfl_split_op.split_dim(), m_Constant(&axisAttrElems))) {
return op->emitOpError("Cannot read split_dim elems");
}
// The axis/split_dim parameter is stored as a 0D tensor instead of
// an integer attribute in TFLite MLIR.
int32_t axis = axisAttrElems.getValue<IntegerAttr>({}).getInt();
llvm::Optional<SmallVector<Value>> results =
convertSplitOp(rewriter, op, tfl_split_op.getResult(0),
tfl_split_op.value(), num_split, axis);
if (!results) return failure();
rewriter.replaceOp(op, results.getValue());
return success();
}
// Splits in num_split parts along split_dim
LogicalResult ConvertTFLSplitVOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tfl_splitv_op = cast<TFL::SplitVOp>(op);
// Get the size_splits array
SmallVector<int32_t> size_split;
ElementsAttr size_split_elems;
if (!matchPattern(tfl_splitv_op.size_splits(),
m_Constant(&size_split_elems))) {
return failure();
}
for (int i = 0; i < size_split_elems.getNumElements(); i++) {
size_split.push_back(size_split_elems.getValue<IntegerAttr>(i).getInt());
}
// Get the axis
ElementsAttr axisAttrElems;
if (!matchPattern(tfl_splitv_op.split_dim(), m_Constant(&axisAttrElems))) {
return op->emitOpError("Cannot read split_dim elems");
}
// The axis/split_dim parameter is stored as a 0D tensor instead of
// an integer attribute in TFLite MLIR.
int32_t axis = axisAttrElems.getValue<IntegerAttr>(0).getInt();
llvm::Optional<SmallVector<Value>> results =
convertSplitVOp(rewriter, op, tfl_splitv_op.getResult(0),
tfl_splitv_op.value(), size_split, axis);
if (!results) return failure();
rewriter.replaceOp(op, results.getValue());
return success();
}
LogicalResult ConvertTFLPadOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tfl_pad_op = cast<TFL::PadOp>(op);
RankedTensorType output_type =
tfl_pad_op.getResult().getType().dyn_cast<RankedTensorType>();
// Not a ranked tensor output
if (!output_type) return failure();
auto pad_op =
CreateOpAndInfer<tosa::PadOp>(rewriter, op->getLoc(), output_type,
tfl_pad_op.input(), tfl_pad_op.padding());
rewriter.replaceOp(op, {pad_op.getResult()});
return success();
}
LogicalResult ConvertTFLResizeBilinearOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tfl_resize_op = cast<TFL::ResizeBilinearOp>(op);
RankedTensorType output_type =
tfl_resize_op.getResult().getType().dyn_cast<RankedTensorType>();
// Not a ranked tensor output
if (!output_type) return failure();
llvm::Optional<Value> result = convertResizeOp(
rewriter, op, output_type, tfl_resize_op.input(), StringRef("BILINEAR"),
tfl_resize_op.align_cornersAttr().getValue(),
tfl_resize_op.half_pixel_centersAttr().getValue());
if (!result) return failure();
rewriter.replaceOp(op, {result.getValue()});
return success();
}
LogicalResult ConvertTFLResizeNearestNeighborOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tfl_resize_op = cast<TFL::ResizeNearestNeighborOp>(op);
RankedTensorType output_type =
tfl_resize_op.getResult().getType().dyn_cast<RankedTensorType>();
// Not a ranked tensor output
if (!output_type) return failure();
llvm::Optional<Value> result =
convertResizeOp(rewriter, op, output_type, tfl_resize_op.input(),
StringRef("NEAREST_NEIGHBOR"),
tfl_resize_op.align_cornersAttr().getValue(),
tfl_resize_op.half_pixel_centersAttr().getValue());
if (!result) return failure();
rewriter.replaceOp(op, {result.getValue()});
return success();
}
LogicalResult ConvertTFLSelectOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tfl_sel_op = cast<TFL::SelectOp>(op);
llvm::Optional<Value> result =
convertSelectOp(rewriter, op, tfl_sel_op.getResult(),
tfl_sel_op.condition(), tfl_sel_op.x(), tfl_sel_op.y());
if (!result) return failure();
rewriter.replaceOp(op, {result.getValue()});
return success();
}
LogicalResult ConvertTFLSelectV2Op::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tfl_sel_op = cast<TFL::SelectV2Op>(op);
llvm::Optional<Value> result =
convertSelectOp(rewriter, op, tfl_sel_op.getResult(),
tfl_sel_op.condition(), tfl_sel_op.x(), tfl_sel_op.y());
if (!result) return failure();
rewriter.replaceOp(op, {result.getValue()});
return success();
}
LogicalResult ConvertTFLSpaceToBatchNdOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tfl_s2b_op = cast<TFL::SpaceToBatchNdOp>(op);
llvm::Optional<Value> result = convertSpaceToBatchNDOp(
rewriter, op, tfl_s2b_op.getResult(), tfl_s2b_op.input(),
tfl_s2b_op.block_shape(), tfl_s2b_op.paddings());
if (!result) return failure();
rewriter.replaceOp(op, {result.getValue()});
return success();
}
LogicalResult ConvertTFLBatchToSpaceNdOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tfl_b2s_op = cast<TFL::BatchToSpaceNdOp>(op);
llvm::Optional<Value> result = convertBatchToSpaceNDOp(
rewriter, op, tfl_b2s_op.getResult(), tfl_b2s_op.input(),
tfl_b2s_op.block_shape(), tfl_b2s_op.indices());
if (!result) return failure();
rewriter.replaceOp(op, {result.getValue()});
return success();
}
LogicalResult ConvertTFLSpaceToDepthOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tfl_s2d_op = cast<TFL::SpaceToDepthOp>(op);
auto block_size_attr = tfl_s2d_op.block_sizeAttr();
llvm::Optional<Value> result = convertSpaceToDepthOp(
rewriter, op, tfl_s2d_op.getResult(), tfl_s2d_op.input(), block_size_attr,
rewriter.getStringAttr("NHWC"));
if (!result) return failure();
rewriter.replaceOp(op, {result.getValue()});
return success();
}
LogicalResult ConvertTFLDepthToSpaceOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tfl_d2s_op = cast<TFL::DepthToSpaceOp>(op);
auto block_size_attr = tfl_d2s_op.block_sizeAttr();
llvm::Optional<Value> result = convertDepthToSpaceOp(
rewriter, op, tfl_d2s_op.getResult(), tfl_d2s_op.input(), block_size_attr,
rewriter.getStringAttr("NHWC"));
if (!result) return failure();
rewriter.replaceOp(op, {result.getValue()});
return success();
}
LogicalResult ConvertTFLStridedSliceOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tfl_ss_op = cast<TFL::StridedSliceOp>(op);
llvm::Optional<Value> result = convertStridedSliceOp(
rewriter, op, tfl_ss_op.getResult(), tfl_ss_op.input(), tfl_ss_op.begin(),
tfl_ss_op.end(), tfl_ss_op.strides(), tfl_ss_op.begin_maskAttr().getInt(),
tfl_ss_op.end_maskAttr().getInt(), tfl_ss_op.ellipsis_maskAttr().getInt(),
tfl_ss_op.new_axis_maskAttr().getInt(),
tfl_ss_op.shrink_axis_maskAttr().getInt());
if (!result) return failure();
rewriter.replaceOp(op, {result.getValue()});
return success();
}
LogicalResult ConvertTFLZerosLikeOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tfl_zeroslike_op = cast<TFL::ZerosLikeOp>(op);
llvm::Optional<Value> result = convertZerosLikeOp(
rewriter, op, tfl_zeroslike_op.getResult(), tfl_zeroslike_op.input());
if (!result) return failure();
rewriter.replaceOp(op, {result.getValue()});
return success();
}
LogicalResult ConvertTFLHardSwishOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tfl_hardswish_op = cast<TFL::HardSwishOp>(op);
RankedTensorType output_type =
tfl_hardswish_op.getResult().getType().dyn_cast<RankedTensorType>();
// Not a ranked tensor output
if (!output_type) return failure();
RankedTensorType input_type =
tfl_hardswish_op.input().getType().dyn_cast<RankedTensorType>();
// Not a ranked tensor output
if (!input_type) return failure();
// TFL hardswish: f(x) -> (x * relu6(x+3))/6
if (input_type.getElementType().isa<mlir::quant::QuantizedType>() &&
output_type.getElementType().isa<mlir::quant::QuantizedType>()) {
// Should match TFLite reference numerical behavior
mlir::quant::UniformQuantizedType input_qtype =
input_type.getElementType()
.dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
mlir::quant::UniformQuantizedType output_qtype =
output_type.getElementType()
.dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
auto hardswish_func = [](double v) -> double {
double w = v + 3.0;
w = w < 0.0 ? 0.0 : w > 6.0 ? 6.0 : w;
return v * w / 6.0;
};
if (input_qtype.getStorageTypeIntegralWidth() == 8) {
// Implement with 8-bit table lookup.
Value table_const = getTosaConst8bitTable(
rewriter, op, input_qtype.getScale(), input_qtype.getZeroPoint(),
output_qtype.getScale(), output_qtype.getZeroPoint(), hardswish_func);
CreateReplaceOpAndInfer<tosa::TableOp>(
rewriter, op, output_type, tfl_hardswish_op.input(), table_const);
}
} else {
// op1 = constop(3)
// op2 = add(x, op1)
// op3 = clamp(op2, 0, 6)
// op4 = mul(x, op3)
// op5 = reciprocal(6)
// op6 = mul (op4, op5)
Value op1_value = getTosaConstTensorSingleF32(rewriter, op, 3.0);
auto op2_add_x_op1 =
CreateOpAndInfer<tosa::AddOp>(rewriter, op->getLoc(), output_type,
tfl_hardswish_op.input(), op1_value);
auto op3_relu_op2_6 = CreateOpAndInfer<tosa::ClampOp>(
rewriter, op->getLoc(), output_type, op2_add_x_op1.getResult(),
rewriter.getI64IntegerAttr(0), rewriter.getI64IntegerAttr(0),
rewriter.getF32FloatAttr(0.0f), rewriter.getF32FloatAttr(6.0f));
auto op4_mul_x_op3 = CreateOpAndInfer<tosa::MulOp>(
rewriter, op->getLoc(), output_type, tfl_hardswish_op.input(),
op3_relu_op2_6.getResult(), 0);
auto const_6 = getTosaConstTensorSingleF32(rewriter, op, 6.0);
auto op5_reciprocal_6 = CreateOpAndInfer<tosa::ReciprocalOp>(
rewriter, op->getLoc(), const_6.getType(), const_6);
auto op6_mul_op4_op5 = CreateOpAndInfer<tosa::MulOp>(
rewriter, op->getLoc(), output_type, op4_mul_x_op3.getResult(),
op5_reciprocal_6.getResult(), 0);
rewriter.replaceOp(op, {op6_mul_op4_op5.getResult()});
}
return success();
}
LogicalResult ConvertTFLLogisticOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tfl_logistic_op = cast<TFL::LogisticOp>(op);
RankedTensorType output_type =
tfl_logistic_op.getResult().getType().dyn_cast<RankedTensorType>();
RankedTensorType input_type =
tfl_logistic_op.x().getType().dyn_cast<RankedTensorType>();
if (!input_type || !output_type) return failure();
bool input_is_qtype =
input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
bool output_is_qtype =
output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
if (input_is_qtype != output_is_qtype) {
return op->emitOpError(
"ConvertTFLLogisticOp: input/output tensor should "
"be all quantized or all floating-point.");
}
if (input_is_qtype) {
RankedTensorType int32_type = RankedTensorType::get(
output_type.getShape(), rewriter.getIntegerType(32));
mlir::quant::UniformQuantizedType input_qtype =
input_type.getElementType()
.dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
mlir::quant::UniformQuantizedType output_qtype =
output_type.getElementType()
.dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
auto sigmoid_func = [](double x) -> double {
return 1.0 / (1.0 + std::exp(-x));
};
if (input_qtype.getStorageTypeIntegralWidth() == 8) {
Value table_const = getTosaConst8bitTable(
rewriter, op, input_qtype.getScale(), input_qtype.getZeroPoint(),
output_qtype.getScale(), output_qtype.getZeroPoint(), sigmoid_func);
CreateReplaceOpAndInfer<tosa::TableOp>(rewriter, op, output_type,
tfl_logistic_op.x(), table_const);
} else { // int16
if (input_qtype.getZeroPoint() != 0 || output_qtype.getZeroPoint() != 0) {
op->emitOpError(
"ConvertTFLLogistic: input/output zeropoint should be 0 in 16-bit "
"mode");
return failure();
}
double input_min = -32768 * input_qtype.getScale();
double input_max = 32767 * input_qtype.getScale();
// Generate table with gen_lut() in
// tensorflow/lite/kernels/internal/common.h
Value table_const = getTosaConst16bitTable(rewriter, op, sigmoid_func,
input_min, input_max);
auto op1_table_in = CreateOpAndInfer<tosa::TableOp>(
rewriter, op->getLoc(), int32_type, tfl_logistic_op.x(), table_const);
Value op2_rescale_op1 =
buildRescale(rewriter, op, output_type, op1_table_in.getResult(),
1.0 / 128.0, 0, 0, false, true);
rewriter.replaceOp(op, {op2_rescale_op1});
}
} else {
CreateReplaceOpAndInfer<tosa::SigmoidOp>(rewriter, op, output_type,
tfl_logistic_op.x());
}
return success();
}
LogicalResult ConvertTFLTanhOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tfl_tanh_op = cast<TFL::TanhOp>(op);
RankedTensorType output_type =
tfl_tanh_op.getResult().getType().dyn_cast<RankedTensorType>();
RankedTensorType input_type =
tfl_tanh_op.input().getType().dyn_cast<RankedTensorType>();
if (!input_type || !output_type) return failure();
bool input_is_qtype =
input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
bool output_is_qtype =
output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
if (input_is_qtype != output_is_qtype) {
return op->emitOpError(
"ConvertTFLTanhOp: input/output tensor should "
"be all quantized or all floating-point.");
}
if (input_is_qtype) {
RankedTensorType int32_type = RankedTensorType::get(
output_type.getShape(), rewriter.getIntegerType(32));
mlir::quant::UniformQuantizedType input_qtype =
input_type.getElementType()
.dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
mlir::quant::UniformQuantizedType output_qtype =
output_type.getElementType()
.dyn_cast_or_null<mlir::quant::UniformQuantizedType>();
auto tanh_func = [](double x) -> double {
x = std::exp(-2.0 * x);
return (1.0 - x) / (1.0 + x);
};
if (input_qtype.getStorageTypeIntegralWidth() == 8) {
Value table_const = getTosaConst8bitTable(
rewriter, op, input_qtype.getScale(), input_qtype.getZeroPoint(),
output_qtype.getScale(), output_qtype.getZeroPoint(), tanh_func);
CreateReplaceOpAndInfer<tosa::TableOp>(rewriter, op, output_type,
tfl_tanh_op.input(), table_const);
} else { // int16
if (input_qtype.getZeroPoint() != 0 || output_qtype.getZeroPoint() != 0) {
op->emitOpError(
"ConvertTFLLogistic: input/output zeropoint should be 0 in 16-bit "
"mode");
return failure();
}
double input_min = -32768 * input_qtype.getScale();
double input_max = 32767 * input_qtype.getScale();
// Generate table with gen_lut() in
// tensorflow/lite/kernels/internal/common.h
Value table_const =
getTosaConst16bitTable(rewriter, op, tanh_func, input_min, input_max);
auto op1_table_in = CreateOpAndInfer<tosa::TableOp>(
rewriter, op->getLoc(), int32_type, tfl_tanh_op.input(), table_const);
Value op2_rescale_op1 =
buildRescale(rewriter, op, output_type, op1_table_in.getResult(),
1.0 / 128.0, 0, 0, false, true);
rewriter.replaceOp(op, {op2_rescale_op1});
}
} else {
CreateReplaceOpAndInfer<tosa::TanhOp>(rewriter, op, output_type,
tfl_tanh_op.input());
}
return success();
}
LogicalResult ConvertTFLPReluOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tfl_prelu_op = cast<TFL::PReluOp>(op);
RankedTensorType output_type =
tfl_prelu_op.getResult().getType().dyn_cast<RankedTensorType>();
if (!output_type) return failure();
return failure();
}
LogicalResult ConvertTFLLeakyReluOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tfl_leakyrelu_op = cast<TFL::LeakyReluOp>(op);
RankedTensorType input_type =
tfl_leakyrelu_op.input().getType().dyn_cast<RankedTensorType>();
RankedTensorType output_type =
tfl_leakyrelu_op.getResult().getType().dyn_cast<RankedTensorType>();
if (!input_type || !output_type) return failure();
bool output_is_qtype =
output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
// Implement LeakyRelu as element-wise:
// out = x > 0 ? x : alpha * x
//
// In TOSA ops:
//
// const_zero = constant(0)
// op1 = mul(x, alpha)
// op2 = greater_equal(x, const_zero)
// out = select(a2, x, a1)
//
// If alpha can be constrained to 0.0 <= alpha <= 1.0, then
// an alternative simpler lowering could be implemented with:
//
// max(mul(x, alapha), x)
//
// But this alternative is not robust unless alpha meets those constraints.
FloatAttr tmpAttr = tfl_leakyrelu_op.alphaAttr();
// There is disagreement between the MLIR .td defaults and TF
// documentation on 0.2 vs 0.3, but 0.2 will be used here.
double alpha = 0.2;
if (tmpAttr) {
alpha = tmpAttr.getValueAsDouble();
}
if (output_is_qtype) {
// op1 = rescale(input)
// rescaled_alpha = (alpha << alpha_shift) // Remains within int32 range
// op2 = mul(rescaled_input, rescaled_alpha, alpha_shift)
// op3 = greater_equal(op1, 0)
// op4 = select(op3, op1, op2)
// out = rescale(op4)
RankedTensorType rescale_type =
RankedTensorType::get(output_type.getShape(), rewriter.getI32Type());
UniformQuantizedType input_qtype =
input_type.getElementType().cast<UniformQuantizedType>();
UniformQuantizedType output_qtype =
output_type.getElementType().cast<UniformQuantizedType>();
double scale_alpha =
input_qtype.getScale() * alpha / output_qtype.getScale();
double scale_identity = input_qtype.getScale() / output_qtype.getScale();
Value op1_rescale_in =
buildRescaleToInt32(rewriter, op, tfl_leakyrelu_op.input(), 1.0,
input_qtype.getZeroPoint());
Value const_zero = getTosaConstTensorSingleI32(rewriter, op, 0);
auto op2_ge = CreateOpAndInfer<tosa::GreaterEqualOp>(
rewriter, op->getLoc(),
RankedTensorType::get(rescale_type.getShape(), rewriter.getI1Type()),
op1_rescale_in, const_zero);
Value op3_rescale_alpha_in = buildRescale(
rewriter, op, output_type, tfl_leakyrelu_op.input(), scale_alpha,
input_qtype.getZeroPoint(), output_qtype.getZeroPoint(), true, true);
Value op4_rescale_identity_in = buildRescale(
rewriter, op, output_type, tfl_leakyrelu_op.input(), scale_identity,
input_qtype.getZeroPoint(), output_qtype.getZeroPoint(), true, true);
CreateReplaceOpAndInfer<tosa::SelectOp>(rewriter, op, output_type, op2_ge,
op4_rescale_identity_in,
op3_rescale_alpha_in);
return success();
} else {
Value const_zero = getTosaConstTensorSingleF32(rewriter, op, 0.0);
auto op1_mul = CreateOpAndInfer<tosa::MulOp>(
rewriter, op->getLoc(), output_type, tfl_leakyrelu_op.input(),
getTosaConstTensorSingleF32(rewriter, op, alpha), 0);
auto op2_ge = CreateOpAndInfer<tosa::GreaterEqualOp>(
rewriter, op->getLoc(),
RankedTensorType::get(output_type.getShape(),
rewriter.getIntegerType(1)),
tfl_leakyrelu_op.input(), const_zero);
CreateReplaceOpAndInfer<tosa::SelectOp>(rewriter, op, output_type, op2_ge,
tfl_leakyrelu_op.input(),
op1_mul.getResult());
return success();
}
}
LogicalResult ConvertTFLNegOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tfl_neg_op = cast<TFL::NegOp>(op);
RankedTensorType output_type =
tfl_neg_op.getResult().getType().dyn_cast<RankedTensorType>();
if (!output_type) return failure();
CreateReplaceOpAndInfer<tosa::NegateOp>(rewriter, op, output_type,
tfl_neg_op.x());
return success();
}
LogicalResult ConvertTFLYieldOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
rewriter.replaceOpWithNewOp<tosa::YieldOp>(op, op->getResultTypes(),
op->getOperands());
return success();
}
LogicalResult ConvertTFLCustomOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tfl_custom_op = cast<TFL::CustomOp>(op);
rewriter.replaceOpWithNewOp<tosa::CustomOp>(
op, op->getResultTypes(), tfl_custom_op.custom_code(), op->getOperands());
return success();
}
LogicalResult ConvertTFLReverseV2Op::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tfl_reverse_op = cast<TFL::ReverseV2Op>(op);
RankedTensorType input_type =
tfl_reverse_op.input().getType().dyn_cast<RankedTensorType>();
RankedTensorType output_type =
tfl_reverse_op.getResult().getType().dyn_cast<RankedTensorType>();
if (!input_type || !output_type) return failure();
ElementsAttr axis_elems;
if (!matchPattern(tfl_reverse_op.axis(), m_Constant(&axis_elems)))
return failure();
auto input_rank = input_type.getShape().size();
Value val = tfl_reverse_op.input();
if (axis_elems.getNumElements() == 0) {
auto identity_op = CreateOpAndInfer<tosa::IdentityOp>(
rewriter, op->getLoc(), output_type, val);
val = identity_op.getResult();
} else {
for (int i = 0; i < axis_elems.getNumElements(); i++) {
int64_t axis_val = axis_elems.getValue<IntegerAttr>(i).getInt();
if (axis_val < 0) axis_val += input_rank;
auto axis_attr = rewriter.getI64IntegerAttr(axis_val);
auto reverse_op = CreateOpAndInfer<tosa::ReverseOp>(
rewriter, op->getLoc(), output_type, val, axis_attr);
val = reverse_op.getResult();
}
}
rewriter.replaceOp(op, {val});
return success();
}
LogicalResult ConvertTFLQuantizeOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tfl_quantize_op = cast<TFL::QuantizeOp>(op);
RankedTensorType input_type =
tfl_quantize_op.input().getType().dyn_cast<RankedTensorType>();
RankedTensorType output_type =
tfl_quantize_op.getResult().getType().dyn_cast<RankedTensorType>();
if (!input_type || !output_type) return failure();
ShapedType qtype =
tfl_quantize_op.getResult().getType().dyn_cast<ShapedType>();
if (!qtype) return failure();
UniformQuantizedType element_type =
qtype.getElementType().dyn_cast<UniformQuantizedType>();
if (!element_type) return failure();
UniformQuantizedType input_element_type =
input_type.getElementType().dyn_cast<UniformQuantizedType>();
// If input is already a quantized type, this is basically a RESCALE (or
// tensorflow::ops::Requantize)
if (input_element_type) {
double rescale_scale =
input_element_type.getScale() / element_type.getScale();
Value rescale_op =
buildRescale(rewriter, op, output_type, tfl_quantize_op.input(),
rescale_scale, input_element_type.getZeroPoint(),
element_type.getZeroPoint(), true, true);
rewriter.replaceOp(op, {rescale_op});
return success();
} else {
double scale = 1 / element_type.getScale();
int64_t zp = element_type.getZeroPoint();
int64_t num_bits = element_type.getStorageTypeIntegralWidth();
zp = element_type.isSigned() ? zp : zp - (1 << (num_bits - 1));
llvm::Optional<Value> result = convertQuantizeOp(
rewriter, op, output_type, tfl_quantize_op.input(), scale, zp);
if (!result) return failure();
rewriter.replaceOp(op, {result.getValue()});
return success();
}
}
LogicalResult ConvertTFLDequantizeOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tfl_dequantize_op = cast<TFL::DequantizeOp>(op);
RankedTensorType output_type =
tfl_dequantize_op.getResult().getType().dyn_cast<RankedTensorType>();
// Not a ranked tensor output
if (!output_type) return failure();
RankedTensorType qtype =
tfl_dequantize_op.input().getType().dyn_cast<RankedTensorType>();
if (!qtype) return failure();
Type element_type = qtype.getElementType();
if (element_type.isa<FloatType>()) {
CreateReplaceOpAndInfer<tosa::CastOp>(rewriter, op, output_type,
tfl_dequantize_op.input());
return success();
}
if (auto eq_ty = element_type.dyn_cast<quant::UniformQuantizedType>()) {
double scale = eq_ty.getScale();
int64_t zp = eq_ty.getZeroPoint();
int64_t num_bits = eq_ty.getStorageTypeIntegralWidth();
zp = eq_ty.isSigned() ? zp : zp - (1 << (num_bits - 1));
llvm::Optional<Value> result = convertDequantizeOp(
rewriter, op, output_type, tfl_dequantize_op.input(), scale, zp, 0);
if (!result) return failure();
rewriter.replaceOp(op, {result.getValue()});
return success();
}
if (quant::UniformQuantizedPerAxisType eq_ty =
element_type.dyn_cast<quant::UniformQuantizedPerAxisType>()) {
SmallVector<float> zps;
for (auto zp : eq_ty.getZeroPoints()) {
int64_t num_bits = eq_ty.getStorageTypeIntegralWidth();
zps.push_back(eq_ty.isSigned() ? zp : zp - (1 << (num_bits - 1)));
}
SmallVector<float> scales;
for (auto scale : eq_ty.getScales()) {
scales.push_back(scale);
}
llvm::Optional<Value> result = convertDequantizeOp(
rewriter, op, output_type, tfl_dequantize_op.input(), scales, zps,
eq_ty.getQuantizedDimension());
if (!result) return failure();
rewriter.replaceOp(op, {result.getValue()});
return success();
}
return failure();
}
LogicalResult ConvertTFLQConstOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tfl_qconst_op = cast<TFL::QConstOp>(op);
RankedTensorType output_type =
tfl_qconst_op.getResult().getType().dyn_cast<RankedTensorType>();
// Not a ranked tensor output
if (!output_type) return failure();
ElementsAttr elements = tfl_qconst_op.value();
Type element_type = elements.getType().getElementType();
if (output_type.getElementType().isa<quant::QuantizedType>()) {
output_type = RankedTensorType::get(output_type.getShape(), element_type);
}
rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, output_type, elements);
return success();
}
LogicalResult ConvertConstantOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tfl_const_op = cast<ConstantOp>(op);
ShapedType output_type =
tfl_const_op.getResult().getType().dyn_cast<ShapedType>();
// Not a ranked tensor output
if (!output_type) return failure();
ElementsAttr attr = tfl_const_op.valueAttr().dyn_cast<ElementsAttr>();
auto e_type = output_type.getElementType();
// TOSA only support up to 48-bits
// If source is higher than that, it's not representabble.
// For data type like 64 bits, we need to truncate them into 48 bits.
if (e_type.isInteger(64)) {
e_type = rewriter.getIntegerType(48);
attr = attr.cast<DenseIntOrFPElementsAttr>().mapValues(
e_type, [](const APInt& x) -> APInt { return x.trunc(48); });
}
output_type = output_type.clone(e_type);
rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, output_type, attr);
return success();
}
LogicalResult ConvertTFLGatherOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tfl_gather_op = cast<TFL::GatherOp>(op);
int32_t axis = tfl_gather_op.axisAttr().getInt();
int32_t batch_dims = 0; // Not a parameter in tfl.Gather; default to 0.
llvm::Optional<Value> result = convertGatherOp(
rewriter, op, tfl_gather_op.getResult(), tfl_gather_op.params(),
tfl_gather_op.indices(), batch_dims, axis);
if (!result) return failure();
rewriter.replaceOp(op, {result.getValue()});
return success();
}
LogicalResult ConvertTFLGatherNdOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tfl_gathernd_op = cast<TFL::GatherNdOp>(op);
llvm::Optional<Value> result =
convertGatherNdOp(rewriter, op, tfl_gathernd_op.getResult(),
tfl_gathernd_op.params(), tfl_gathernd_op.indices());
if (!result) return failure();
rewriter.replaceOp(op, {result.getValue()});
return success();
}
LogicalResult ConvertTFLOneHotOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tfl_one_hot_op = cast<TFL::OneHotOp>(op);
ElementsAttr depth_elems;
if (!matchPattern(tfl_one_hot_op.depth(), m_Constant(&depth_elems)))
return failure();
int32_t depth = depth_elems.getValue<IntegerAttr>({}).getInt();
IntegerAttr axisAttr = tfl_one_hot_op.axisAttr();
int32_t axis = axisAttr.getInt();
llvm::Optional<Value> result = convertOneHotOp(
rewriter, op, tfl_one_hot_op.getResult(), tfl_one_hot_op.indices(),
tfl_one_hot_op.on_value(), tfl_one_hot_op.off_value(), depth, axis);
if (!result) return failure();
rewriter.replaceOp(op, {result.getValue()});
return success();
}
LogicalResult ConvertTFLArgMaxOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto arg_max_op = cast<TFL::ArgMaxOp>(op);
ElementsAttr dim_elems;
if (!matchPattern(arg_max_op.dim(), m_Constant(&dim_elems))) return failure();
int32_t dim = dim_elems.getValue<IntegerAttr>({}).getInt();
CreateReplaceOpAndInfer<tosa::ArgMaxOp>(
rewriter, op, arg_max_op.getType(), arg_max_op.input(),
rewriter.getIntegerAttr(rewriter.getI64Type(), dim));
return success();
}
LogicalResult ConvertTFLFakeQuantOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto fakequant_op = cast<TFL::FakeQuantOp>(op);
RankedTensorType output_type =
fakequant_op.getResult().getType().dyn_cast<RankedTensorType>();
// Not a ranked tensor output
if (!output_type) return failure();
llvm::Optional<Value> result =
convertFakeQuantOp(rewriter, op, output_type, fakequant_op.input(),
fakequant_op.minAttr().getValueAsDouble(),
fakequant_op.maxAttr().getValueAsDouble(),
fakequant_op.num_bitsAttr().getInt(),
fakequant_op.narrow_rangeAttr().getValue());
if (!result) return failure();
rewriter.replaceOp(op, {result.getValue()});
return success();
}
void LegalizeTFL::runOnFunction() {
ConversionTarget target(getContext());
target.addIllegalDialect<TFL::TensorFlowLiteDialect>();
target.addIllegalOp<quant::StatisticsOp>();
// Operations are legal if they don't contain any illegal type.
target.markUnknownOpDynamicallyLegal([](Operation* op) {
if (auto constantOp = dyn_cast<ConstantOp>(op)) {
return constantOp.getType().isa<NoneType>();
}
return true;
});
auto* ctx = &getContext();
auto func = getFunction();
RewritePatternSet patterns(&getContext());
// Add the generated patterns to the list.
populateWithGenerated(patterns);
#define DEF_PATTERN_INSERT(PAT) patterns.insert<Convert##PAT##Op>(ctx);
DEF_PATTERN_INSERT(TFLRelu);
DEF_PATTERN_INSERT(TFLRelu6);
DEF_PATTERN_INSERT(TFLEqual);
DEF_PATTERN_INSERT(TFLNotEqual);
DEF_PATTERN_INSERT(TFLGreater);
DEF_PATTERN_INSERT(TFLGreaterEqual);
DEF_PATTERN_INSERT(TFLAdd);
DEF_PATTERN_INSERT(TFLSub);
DEF_PATTERN_INSERT(TFLMul);
DEF_PATTERN_INSERT(TFLSquare);
DEF_PATTERN_INSERT(TFLSquaredDifference);
DEF_PATTERN_INSERT(TFLDiv);
DEF_PATTERN_INSERT(TFLMaximum);
DEF_PATTERN_INSERT(TFLMinimum);
DEF_PATTERN_INSERT(TFLFloorMod);
DEF_PATTERN_INSERT(TFLFloorDiv);
DEF_PATTERN_INSERT(TFLAddN);
DEF_PATTERN_INSERT(TFLAveragePool2D);
DEF_PATTERN_INSERT(TFLMaxPool2D);
DEF_PATTERN_INSERT(TFLConcatenation);
DEF_PATTERN_INSERT(TFLReshape);
DEF_PATTERN_INSERT(TFLRank);
DEF_PATTERN_INSERT(TFLShape);
DEF_PATTERN_INSERT(TFLExpandDims);
DEF_PATTERN_INSERT(TFLSqueeze);
DEF_PATTERN_INSERT(TFLFill);
DEF_PATTERN_INSERT(TFLElu);
DEF_PATTERN_INSERT(TFLSoftmax);
DEF_PATTERN_INSERT(TFLLogSoftmax);
DEF_PATTERN_INSERT(TFLSqrt);
DEF_PATTERN_INSERT(TFLReduceAny);
DEF_PATTERN_INSERT(TFLReduceMax);
DEF_PATTERN_INSERT(TFLReduceMin);
DEF_PATTERN_INSERT(TFLMean);
DEF_PATTERN_INSERT(TFLReduceProd);
DEF_PATTERN_INSERT(TFLSum);
DEF_PATTERN_INSERT(TFLConv2D);
DEF_PATTERN_INSERT(TFLTransposeConv);
DEF_PATTERN_INSERT(TFLDepthwiseConv2D);
DEF_PATTERN_INSERT(TFLFullyConnected);
DEF_PATTERN_INSERT(TFLSplit);
DEF_PATTERN_INSERT(TFLSplitV);
DEF_PATTERN_INSERT(TFLPack);
DEF_PATTERN_INSERT(TFLUnpack);
DEF_PATTERN_INSERT(TFLTranspose);
DEF_PATTERN_INSERT(TFLTile);
DEF_PATTERN_INSERT(TFLSlice);
DEF_PATTERN_INSERT(TFLStridedSlice);
DEF_PATTERN_INSERT(TFLZerosLike);
DEF_PATTERN_INSERT(TFLHardSwish);
DEF_PATTERN_INSERT(TFLLess);
DEF_PATTERN_INSERT(TFLLessEqual);
DEF_PATTERN_INSERT(TFLPad);
DEF_PATTERN_INSERT(TFLResizeBilinear);
DEF_PATTERN_INSERT(TFLResizeNearestNeighbor);
DEF_PATTERN_INSERT(TFLSelect);
DEF_PATTERN_INSERT(TFLSelectV2);
DEF_PATTERN_INSERT(TFLSpaceToBatchNd);
DEF_PATTERN_INSERT(TFLBatchToSpaceNd);
DEF_PATTERN_INSERT(TFLSpaceToDepth);
DEF_PATTERN_INSERT(TFLDepthToSpace);
DEF_PATTERN_INSERT(TFLLogistic);
DEF_PATTERN_INSERT(TFLTanh);
DEF_PATTERN_INSERT(TFLPRelu);
DEF_PATTERN_INSERT(TFLLeakyRelu);
DEF_PATTERN_INSERT(TFLNeg);
DEF_PATTERN_INSERT(TFLYield);
DEF_PATTERN_INSERT(TFLCustom);
DEF_PATTERN_INSERT(TFLReverseV2);
DEF_PATTERN_INSERT(TFLQuantize);
DEF_PATTERN_INSERT(TFLDequantize);
DEF_PATTERN_INSERT(TFLQConst);
DEF_PATTERN_INSERT(Constant);
DEF_PATTERN_INSERT(TFLGather);
DEF_PATTERN_INSERT(TFLGatherNd);
DEF_PATTERN_INSERT(TFLArgMax);
DEF_PATTERN_INSERT(TFLFakeQuant);
DEF_PATTERN_INSERT(TFLOneHot);
if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) {
signalPassFailure();
}
// Insert UnrealizedConversionCasts to guarantee ReturnOp agress with
// the FuncOp type.
IRRewriter rewriter(func.getContext());
func.walk([&](ReturnOp op) {
FuncOp parent = dyn_cast<FuncOp>(op->getParentOp());
if (!parent) return;
rewriter.setInsertionPoint(op);
FunctionType funcTy = func.getType();
auto resultTys = funcTy.getResults();
bool castAdded = false;
SmallVector<Value> castedValues;
for (auto it : llvm::zip(op->getOperands(), resultTys)) {
auto operand = std::get<0>(it);
auto currentTy = operand.getType();
auto castTy = std::get<1>(it);
if (currentTy == castTy) {
castedValues.push_back(operand);
continue;
}
castedValues.push_back(
rewriter.create<tensor::CastOp>(op.getLoc(), castTy, operand)
.getResult());
castAdded = true;
}
if (castAdded) {
rewriter.replaceOpWithNewOp<ReturnOp>(op, castedValues);
}
});
}
} // namespace
// Creates an instance of the TensorFlow Lite dialect LegalizeTFL pass.
std::unique_ptr<OperationPass<FuncOp>> createLegalizeTFLPass() {
return std::make_unique<LegalizeTFL>();
}
} // namespace tosa
} // namespace mlir