Add a Legalization Enable Option to the TFLite to TOSA pass in order to add rewrites dynamically
Change-Id: I24c29e1e3488a6ebf2a0433a749bd487918e6e3b
Signed-off-by: Aaron DeBattista <aaron.debattista@arm.com>
diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc
index b27fedb..a9a1279 100644
--- a/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc
+++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc
@@ -31,6 +31,11 @@
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Traits.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/Transforms/DialectConversion.h"
+
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/tosa/transforms/legalize_common.h"
#include "tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h"
@@ -40,16 +45,6 @@
#define DEBUG_TYPE PASS_NAME
#define HARDSWISH_EXPLICIT_RESCALING false
-// Conditionally avoid converting some TFLite ops to TOSA.
-// By default, all conversions will be invoked.
-//
-// The denylist file lists patterns which are not legalized from TFLite to TOSA.
-llvm::cl::opt<std::string> tfl_tosa_denylist(
- "tfl-tosa-denylist",
- llvm::cl::desc("<a list of patterns not legalized from TFLite to TOSA>"),
- llvm::cl::init("transforms/tfl_tosa_denylist.txt"),
- llvm::cl::value_desc("pattern name"));
-
namespace mlir {
namespace tosa {
namespace {
@@ -59,12 +54,24 @@
// Performs lowering to TOSA dialect.
class LegalizeTFL : public TosaLegalizeTFLPassBase<LegalizeTFL> {
public:
- explicit LegalizeTFL() {}
+ explicit LegalizeTFL(std::unordered_set<std::string> &legalization_disable) {this->legalization_disable = legalization_disable;}
void runOnFunction() override;
+ private:
+ std::unordered_set<std::string> legalization_disable;
};
#include "tensorflow/compiler/mlir/tosa/transforms/tfl_legalize_patterns.inc"
+// Input from tfl.conv2d takes 64 bits a bias, while tosa.conv2d expects 48
+// bits. Need to do a customized truncate here instead of tablegen to handle
+// attribute with negative value.
+struct ConvertConstantOp : public RewritePattern {
+ explicit ConvertConstantOp(MLIRContext* context)
+ : RewritePattern(arith::ConstantOp::getOperationName(), 1, context) {}
+ LogicalResult matchAndRewrite(Operation* op,
+ PatternRewriter& rewriter) const override;
+};
+
#define DECL_CONVERT_OP(tfl_op) \
struct ConvertTFL##tfl_op##Op : public RewritePattern { \
explicit ConvertTFL##tfl_op##Op(MLIRContext* context) \
@@ -155,19 +162,9 @@
DECL_CONVERT_OP(OneHot);
DECL_CONVERT_OP(ArgMax);
DECL_CONVERT_OP(FakeQuant);
+
#undef DECL_CONVERT_OP
-// Input from tfl.conv2d takes 64 bits a bias, while tosa.conv2d expects 48
-// bits. Need to do a customized truncate here instead of tablegen to handle
-// attribute with negative value.
-
-struct ConvertConstantOp : public RewritePattern {
- explicit ConvertConstantOp(MLIRContext* context)
- : RewritePattern(arith::ConstantOp::getOperationName(), 1, context) {}
- LogicalResult matchAndRewrite(Operation* op,
- PatternRewriter& rewriter) const override;
-};
-
LogicalResult ConvertTFLReluOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tfl_relu_op = cast<TFL::ReluOp>(op);
@@ -3095,21 +3092,42 @@
void LegalizeTFL::runOnFunction() {
OwningRewritePatternList patterns(&getContext());
- populateLegalizeTFLPatterns(&getContext(), patterns);
+ populateLegalizeTFLPatterns(&getContext(), patterns, legalization_disable);
auto func = getFunction();
if (ApplyPatternsWithShapeResolution(func, std::move(patterns)).failed()) {
signalPassFailure();
}
}
+
} // namespace
void populateLegalizeTFLPatterns(MLIRContext* ctx,
- RewritePatternSet& patterns) {
- // Add the generated patterns to the list.
- populateWithGenerated(patterns);
+ RewritePatternSet& patterns,
+ std::unordered_set<std::string> legalization_disable) {
+ // 0. If there is no entry for the Op in the legalization_disable set
+ // than we add the rewrite.
+ // 1. If there is an entry for the Op in the legalization_disable set
+ // than we skip adding the rewrite
+ #define DEF_PATTERN_INSERT(PAT) \
+ if (legalization_disable.find(#PAT) == legalization_disable.end()) \
+ patterns.insert<Convert##PAT##Op>(ctx);
-#define DEF_PATTERN_INSERT(PAT) patterns.insert<Convert##PAT##Op>(ctx);
+ DEF_PATTERN_INSERT(TFLAbs);
+ DEF_PATTERN_INSERT(TFLCeil);
+ DEF_PATTERN_INSERT(TFLFloor);
+ DEF_PATTERN_INSERT(TFLExp);
+ DEF_PATTERN_INSERT(TFLLog);
+ DEF_PATTERN_INSERT(TFLRsqrt);
+ DEF_PATTERN_INSERT(TFLLogicalNot);
+ DEF_PATTERN_INSERT(TFLCast);
+
+ DEF_PATTERN_INSERT(QuantStat);
+
+ DEF_PATTERN_INSERT(TFLLogicalAnd);
+ DEF_PATTERN_INSERT(TFLLogicalOr);
+ DEF_PATTERN_INSERT(TFLPow);
+
DEF_PATTERN_INSERT(TFLRelu);
DEF_PATTERN_INSERT(TFLRelu6);
DEF_PATTERN_INSERT(TFLEqual);
@@ -3121,6 +3139,7 @@
DEF_PATTERN_INSERT(TFLMul);
DEF_PATTERN_INSERT(TFLSquare);
DEF_PATTERN_INSERT(TFLSquaredDifference);
+ DEF_PATTERN_INSERT(TFLRound);
DEF_PATTERN_INSERT(TFLDiv);
DEF_PATTERN_INSERT(TFLMaximum);
DEF_PATTERN_INSERT(TFLMinimum);
@@ -3160,8 +3179,8 @@
DEF_PATTERN_INSERT(TFLTile);
DEF_PATTERN_INSERT(TFLSlice);
DEF_PATTERN_INSERT(TFLStridedSlice);
- DEF_PATTERN_INSERT(TFLZerosLike);
DEF_PATTERN_INSERT(TFLHardSwish);
+ DEF_PATTERN_INSERT(TFLZerosLike);
DEF_PATTERN_INSERT(TFLLess);
DEF_PATTERN_INSERT(TFLLessEqual);
DEF_PATTERN_INSERT(TFLPad);
@@ -3186,19 +3205,19 @@
DEF_PATTERN_INSERT(TFLDequantize);
DEF_PATTERN_INSERT(TFLConst);
DEF_PATTERN_INSERT(TFLQConst);
- DEF_PATTERN_INSERT(Constant);
DEF_PATTERN_INSERT(TFLGather);
DEF_PATTERN_INSERT(TFLGatherNd);
DEF_PATTERN_INSERT(TFLSparseToDense);
+ DEF_PATTERN_INSERT(Constant);
+ DEF_PATTERN_INSERT(TFLOneHot);
DEF_PATTERN_INSERT(TFLArgMax);
DEF_PATTERN_INSERT(TFLFakeQuant);
- DEF_PATTERN_INSERT(TFLOneHot);
-#undef DEF_PATTERN_INSERT
}
// Creates an instance of the TensorFlow Lite dialect LegalizeTFL pass.
-std::unique_ptr<OperationPass<FuncOp>> createLegalizeTFLPass() {
- return std::make_unique<LegalizeTFL>();
+std::unique_ptr<OperationPass<FuncOp>> createLegalizeTFLPass(
+ std::unordered_set<std::string> legalization_disable) {
+ return std::make_unique<LegalizeTFL>(legalization_disable);
}
} // namespace tosa
diff --git a/tensorflow/compiler/mlir/tosa/transforms/passes.h b/tensorflow/compiler/mlir/tosa/transforms/passes.h
index 84d297c..fbf8235 100644
--- a/tensorflow/compiler/mlir/tosa/transforms/passes.h
+++ b/tensorflow/compiler/mlir/tosa/transforms/passes.h
@@ -17,6 +17,7 @@
#define TENSORFLOW_COMPILER_MLIR_TOSA_TRANSFORMS_PASSES_H
#include <memory>
+#include <unordered_set>
#include "mlir/Pass/Pass.h" // from @llvm-project
@@ -24,11 +25,11 @@
namespace tosa {
void populateLegalizeTFPatterns(MLIRContext* ctx, RewritePatternSet& patterns);
-void populateLegalizeTFLPatterns(MLIRContext* ctx, RewritePatternSet& patterns);
+void populateLegalizeTFLPatterns(MLIRContext* ctx, RewritePatternSet& patterns, std::unordered_set<std::string> legalization_disable = {});
std::unique_ptr<OperationPass<FuncOp>> createLegalizeTFPass();
std::unique_ptr<OperationPass<FuncOp>> createFuseBiasTFPass();
-std::unique_ptr<OperationPass<FuncOp>> createLegalizeTFLPass();
+std::unique_ptr<OperationPass<FuncOp>> createLegalizeTFLPass(std::unordered_set<std::string> legalization_disable = {});
std::unique_ptr<OperationPass<FuncOp>> createLegalizeTFTFLPass();
std::unique_ptr<OperationPass<FuncOp>> createConvertTFLUint8Pass();
std::unique_ptr<OperationPass<FuncOp>> createStripQuantTypesPass();
diff --git a/tensorflow/compiler/mlir/tosa/transforms/passes.td b/tensorflow/compiler/mlir/tosa/transforms/passes.td
index b657455..a06d47b 100644
--- a/tensorflow/compiler/mlir/tosa/transforms/passes.td
+++ b/tensorflow/compiler/mlir/tosa/transforms/passes.td
@@ -23,7 +23,7 @@
def TosaLegalizeTFLPass : FunctionPass<"tosa-legalize-tfl"> {
let summary = "Legalize from TensorFlow Lite to TOSA";
- let constructor = "createLegalizeTFLPass()";
+ let constructor = "createLegalizeTFLPass({})";
let dependentDialects = ["TosaDialect"];
}
diff --git a/tensorflow/compiler/mlir/tosa/transforms/tfl_legalize_patterns.td b/tensorflow/compiler/mlir/tosa/transforms/tfl_legalize_patterns.td
index cbf84c0..df9b912 100644
--- a/tensorflow/compiler/mlir/tosa/transforms/tfl_legalize_patterns.td
+++ b/tensorflow/compiler/mlir/tosa/transforms/tfl_legalize_patterns.td
@@ -21,29 +21,31 @@
include "mlir/Dialect/Quant/QuantOps.td"
include "mlir/Dialect/Tosa/IR/TosaOps.td"
+//===----------------------------------------------------------------------===//
// Unary ops patterns.
-def : Pat<(TFL_AbsOp $arg), (Tosa_AbsOp $arg)>;
-def : Pat<(TFL_CeilOp $arg), (Tosa_CeilOp $arg)>;
-def : Pat<(TFL_FloorOp $arg), (Tosa_FloorOp $arg)>;
-def : Pat<(TFL_ExpOp $arg), (Tosa_ExpOp $arg)>;
-def : Pat<(TFL_LogOp $arg), (Tosa_LogOp $arg)>;
-def : Pat<(TFL_RsqrtOp $arg), (Tosa_RsqrtOp $arg)>;
-def : Pat<(TFL_LogicalNotOp $arg), (Tosa_LogicalNotOp $arg)>;
-def : Pat<(TFL_CastOp $in), (Tosa_CastOp $in)>;
+//===----------------------------------------------------------------------===//
+
+def ConvertTFLAbsOp : Pat<(TFL_AbsOp $arg), (Tosa_AbsOp $arg)>;
+def ConvertTFLCeilOp : Pat<(TFL_CeilOp $arg), (Tosa_CeilOp $arg)>;
+def ConvertTFLFloorOp : Pat<(TFL_FloorOp $arg), (Tosa_FloorOp $arg)>;
+def ConvertTFLExpOp : Pat<(TFL_ExpOp $arg), (Tosa_ExpOp $arg)>;
+def ConvertTFLLogOp : Pat<(TFL_LogOp $arg), (Tosa_LogOp $arg)>;
+def ConvertTFLRsqrtOp : Pat<(TFL_RsqrtOp $arg), (Tosa_RsqrtOp $arg)>;
+def ConvertTFLLogicalNotOp : Pat<(TFL_LogicalNotOp $arg), (Tosa_LogicalNotOp $arg)>;
+def ConvertTFLCastOp: Pat<(TFL_CastOp $in), (Tosa_CastOp $in)>;
// Removing the quant.stats op for unquantized models.
-def : Pat<(quant_StatisticsOp $value, $layer_stats, $axis_stats, $axis),
- (replaceWithValue $value)>;
+def ConvertQuantStatOp : Pat<(quant_StatisticsOp $value, $layer_stats, $axis_stats, $axis),
+ (replaceWithValue $value)>;
//===----------------------------------------------------------------------===//
// Binary ops patterns.
//===----------------------------------------------------------------------===//
-def : Pat<(TFL_LogicalAndOp $l, $r), (Tosa_LogicalAndOp $l, $r)>;
-def : Pat<(TFL_LogicalOrOp $l, $r), (Tosa_LogicalOrOp $l, $r)>;
-def : Pat<(TFL_PowOp $l, $r), (Tosa_PowOp $l, $r)>;
+def ConvertTFLLogicalAndOp : Pat<(TFL_LogicalAndOp $l, $r), (Tosa_LogicalAndOp $l, $r)>;
+def ConvertTFLLogicalOrOp : Pat<(TFL_LogicalOrOp $l, $r), (Tosa_LogicalOrOp $l, $r)>;
+def ConvertTFLPowOp : Pat<(TFL_PowOp $l, $r), (Tosa_PowOp $l, $r)>;
//===----------------------------------------------------------------------===//
// Ternary ops patterns.
//===----------------------------------------------------------------------===//
-