Add a Legalization Enable Option to the TFLite to TOSA pass in order to add rewrites dynamically

Change-Id: I24c29e1e3488a6ebf2a0433a749bd487918e6e3b
Signed-off-by: Aaron DeBattista <aaron.debattista@arm.com>
diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc
index b27fedb..a9a1279 100644
--- a/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc
+++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc
@@ -31,6 +31,11 @@
 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
 #include "mlir/Support/LLVM.h"  // from @llvm-project
 #include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Traits.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/Transforms/DialectConversion.h"
+
 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
 #include "tensorflow/compiler/mlir/tosa/transforms/legalize_common.h"
 #include "tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h"
@@ -40,16 +45,6 @@
 #define DEBUG_TYPE PASS_NAME
 #define HARDSWISH_EXPLICIT_RESCALING false
 
-// Conditionally avoid converting some TFLite ops to TOSA.
-// By default, all conversions will be invoked.
-//
-// The denylist file lists patterns which are not legalized from TFLite to TOSA.
-llvm::cl::opt<std::string> tfl_tosa_denylist(
-    "tfl-tosa-denylist",
-    llvm::cl::desc("<a list of patterns not legalized from TFLite to TOSA>"),
-    llvm::cl::init("transforms/tfl_tosa_denylist.txt"),
-    llvm::cl::value_desc("pattern name"));
-
 namespace mlir {
 namespace tosa {
 namespace {
@@ -59,12 +54,24 @@
 // Performs lowering to TOSA dialect.
 class LegalizeTFL : public TosaLegalizeTFLPassBase<LegalizeTFL> {
  public:
-  explicit LegalizeTFL() {}
+  explicit LegalizeTFL(std::unordered_set<std::string> &legalization_disable) {this->legalization_disable = legalization_disable;}
   void runOnFunction() override;
+ private:
+  std::unordered_set<std::string> legalization_disable;
 };
 
 #include "tensorflow/compiler/mlir/tosa/transforms/tfl_legalize_patterns.inc"
 
+// Input from tfl.conv2d takes 64 bits a bias, while tosa.conv2d expects 48
+// bits. Need to do a customized truncate here instead of tablegen to handle
+// attribute with negative value.
+struct ConvertConstantOp : public RewritePattern {
+  explicit ConvertConstantOp(MLIRContext* context)
+      : RewritePattern(arith::ConstantOp::getOperationName(), 1, context) {}
+  LogicalResult matchAndRewrite(Operation* op,
+                                PatternRewriter& rewriter) const override;
+};
+
 #define DECL_CONVERT_OP(tfl_op)                                              \
   struct ConvertTFL##tfl_op##Op : public RewritePattern {                    \
     explicit ConvertTFL##tfl_op##Op(MLIRContext* context)                    \
@@ -155,19 +162,9 @@
 DECL_CONVERT_OP(OneHot);
 DECL_CONVERT_OP(ArgMax);
 DECL_CONVERT_OP(FakeQuant);
+
 #undef DECL_CONVERT_OP
 
-// Input from tfl.conv2d takes 64 bits a bias, while tosa.conv2d expects 48
-// bits. Need to do a customized truncate here instead of tablegen to handle
-// attribute with negative value.
-
-struct ConvertConstantOp : public RewritePattern {
-  explicit ConvertConstantOp(MLIRContext* context)
-      : RewritePattern(arith::ConstantOp::getOperationName(), 1, context) {}
-  LogicalResult matchAndRewrite(Operation* op,
-                                PatternRewriter& rewriter) const override;
-};
-
 LogicalResult ConvertTFLReluOp::matchAndRewrite(
     Operation* op, PatternRewriter& rewriter) const {
   auto tfl_relu_op = cast<TFL::ReluOp>(op);
@@ -3095,21 +3092,42 @@
 
 void LegalizeTFL::runOnFunction() {
   OwningRewritePatternList patterns(&getContext());
-  populateLegalizeTFLPatterns(&getContext(), patterns);
+  populateLegalizeTFLPatterns(&getContext(), patterns, legalization_disable);
 
   auto func = getFunction();
   if (ApplyPatternsWithShapeResolution(func, std::move(patterns)).failed()) {
     signalPassFailure();
   }
 }
+
 }  // namespace
 
 void populateLegalizeTFLPatterns(MLIRContext* ctx,
-                                 RewritePatternSet& patterns) {
-  // Add the generated patterns to the list.
-  populateWithGenerated(patterns);
+                                 RewritePatternSet& patterns,
+                                 std::unordered_set<std::string> legalization_disable) {
+  //  0. If there is no entry for the Op in the legalization_disable set
+  //     than we add the rewrite.
+  //  1. If there is an entry for the Op in the legalization_disable set
+  //     than we skip adding the rewrite
+  #define DEF_PATTERN_INSERT(PAT)                                              \
+    if (legalization_disable.find(#PAT) == legalization_disable.end())         \
+      patterns.insert<Convert##PAT##Op>(ctx);
 
-#define DEF_PATTERN_INSERT(PAT) patterns.insert<Convert##PAT##Op>(ctx);
+  DEF_PATTERN_INSERT(TFLAbs);
+  DEF_PATTERN_INSERT(TFLCeil);
+  DEF_PATTERN_INSERT(TFLFloor);
+  DEF_PATTERN_INSERT(TFLExp);
+  DEF_PATTERN_INSERT(TFLLog);
+  DEF_PATTERN_INSERT(TFLRsqrt);
+  DEF_PATTERN_INSERT(TFLLogicalNot);
+  DEF_PATTERN_INSERT(TFLCast);
+
+  DEF_PATTERN_INSERT(QuantStat);
+
+  DEF_PATTERN_INSERT(TFLLogicalAnd);
+  DEF_PATTERN_INSERT(TFLLogicalOr);
+  DEF_PATTERN_INSERT(TFLPow);
+
   DEF_PATTERN_INSERT(TFLRelu);
   DEF_PATTERN_INSERT(TFLRelu6);
   DEF_PATTERN_INSERT(TFLEqual);
@@ -3121,6 +3139,7 @@
   DEF_PATTERN_INSERT(TFLMul);
   DEF_PATTERN_INSERT(TFLSquare);
   DEF_PATTERN_INSERT(TFLSquaredDifference);
+  DEF_PATTERN_INSERT(TFLRound);
   DEF_PATTERN_INSERT(TFLDiv);
   DEF_PATTERN_INSERT(TFLMaximum);
   DEF_PATTERN_INSERT(TFLMinimum);
@@ -3160,8 +3179,8 @@
   DEF_PATTERN_INSERT(TFLTile);
   DEF_PATTERN_INSERT(TFLSlice);
   DEF_PATTERN_INSERT(TFLStridedSlice);
-  DEF_PATTERN_INSERT(TFLZerosLike);
   DEF_PATTERN_INSERT(TFLHardSwish);
+  DEF_PATTERN_INSERT(TFLZerosLike);
   DEF_PATTERN_INSERT(TFLLess);
   DEF_PATTERN_INSERT(TFLLessEqual);
   DEF_PATTERN_INSERT(TFLPad);
@@ -3186,19 +3205,19 @@
   DEF_PATTERN_INSERT(TFLDequantize);
   DEF_PATTERN_INSERT(TFLConst);
   DEF_PATTERN_INSERT(TFLQConst);
-  DEF_PATTERN_INSERT(Constant);
   DEF_PATTERN_INSERT(TFLGather);
   DEF_PATTERN_INSERT(TFLGatherNd);
   DEF_PATTERN_INSERT(TFLSparseToDense);
+  DEF_PATTERN_INSERT(Constant);
+  DEF_PATTERN_INSERT(TFLOneHot);
   DEF_PATTERN_INSERT(TFLArgMax);
   DEF_PATTERN_INSERT(TFLFakeQuant);
-  DEF_PATTERN_INSERT(TFLOneHot);
-#undef DEF_PATTERN_INSERT
 }
 
 // Creates an instance of the TensorFlow Lite dialect LegalizeTFL pass.
-std::unique_ptr<OperationPass<FuncOp>> createLegalizeTFLPass() {
-  return std::make_unique<LegalizeTFL>();
+std::unique_ptr<OperationPass<FuncOp>> createLegalizeTFLPass(
+  std::unordered_set<std::string> legalization_disable) {
+  return std::make_unique<LegalizeTFL>(legalization_disable);
 }
 
 }  // namespace tosa
diff --git a/tensorflow/compiler/mlir/tosa/transforms/passes.h b/tensorflow/compiler/mlir/tosa/transforms/passes.h
index 84d297c..fbf8235 100644
--- a/tensorflow/compiler/mlir/tosa/transforms/passes.h
+++ b/tensorflow/compiler/mlir/tosa/transforms/passes.h
@@ -17,6 +17,7 @@
 #define TENSORFLOW_COMPILER_MLIR_TOSA_TRANSFORMS_PASSES_H
 
 #include <memory>
+#include <unordered_set>
 
 #include "mlir/Pass/Pass.h"  // from @llvm-project
 
@@ -24,11 +25,11 @@
 namespace tosa {
 
 void populateLegalizeTFPatterns(MLIRContext* ctx, RewritePatternSet& patterns);
-void populateLegalizeTFLPatterns(MLIRContext* ctx, RewritePatternSet& patterns);
+void populateLegalizeTFLPatterns(MLIRContext* ctx, RewritePatternSet& patterns, std::unordered_set<std::string> legalization_disable = {});
 
 std::unique_ptr<OperationPass<FuncOp>> createLegalizeTFPass();
 std::unique_ptr<OperationPass<FuncOp>> createFuseBiasTFPass();
-std::unique_ptr<OperationPass<FuncOp>> createLegalizeTFLPass();
+std::unique_ptr<OperationPass<FuncOp>> createLegalizeTFLPass(std::unordered_set<std::string> legalization_disable = {});
 std::unique_ptr<OperationPass<FuncOp>> createLegalizeTFTFLPass();
 std::unique_ptr<OperationPass<FuncOp>> createConvertTFLUint8Pass();
 std::unique_ptr<OperationPass<FuncOp>> createStripQuantTypesPass();
diff --git a/tensorflow/compiler/mlir/tosa/transforms/passes.td b/tensorflow/compiler/mlir/tosa/transforms/passes.td
index b657455..a06d47b 100644
--- a/tensorflow/compiler/mlir/tosa/transforms/passes.td
+++ b/tensorflow/compiler/mlir/tosa/transforms/passes.td
@@ -23,7 +23,7 @@
 
 def TosaLegalizeTFLPass : FunctionPass<"tosa-legalize-tfl"> {
   let summary = "Legalize from TensorFlow Lite to TOSA";
-  let constructor = "createLegalizeTFLPass()";
+  let constructor = "createLegalizeTFLPass({})";
   let dependentDialects = ["TosaDialect"];
 }
 
diff --git a/tensorflow/compiler/mlir/tosa/transforms/tfl_legalize_patterns.td b/tensorflow/compiler/mlir/tosa/transforms/tfl_legalize_patterns.td
index cbf84c0..df9b912 100644
--- a/tensorflow/compiler/mlir/tosa/transforms/tfl_legalize_patterns.td
+++ b/tensorflow/compiler/mlir/tosa/transforms/tfl_legalize_patterns.td
@@ -21,29 +21,31 @@
 include "mlir/Dialect/Quant/QuantOps.td"
 include "mlir/Dialect/Tosa/IR/TosaOps.td"
 
+//===----------------------------------------------------------------------===//
 // Unary ops patterns.
-def : Pat<(TFL_AbsOp $arg), (Tosa_AbsOp $arg)>;
-def : Pat<(TFL_CeilOp $arg), (Tosa_CeilOp $arg)>;
-def : Pat<(TFL_FloorOp $arg), (Tosa_FloorOp $arg)>;
-def : Pat<(TFL_ExpOp $arg), (Tosa_ExpOp $arg)>;
-def : Pat<(TFL_LogOp $arg), (Tosa_LogOp $arg)>;
-def : Pat<(TFL_RsqrtOp $arg), (Tosa_RsqrtOp $arg)>;
-def : Pat<(TFL_LogicalNotOp $arg), (Tosa_LogicalNotOp $arg)>;
-def : Pat<(TFL_CastOp $in), (Tosa_CastOp $in)>;
+//===----------------------------------------------------------------------===//
+
+def ConvertTFLAbsOp : Pat<(TFL_AbsOp $arg), (Tosa_AbsOp $arg)>;
+def ConvertTFLCeilOp : Pat<(TFL_CeilOp $arg), (Tosa_CeilOp $arg)>;
+def ConvertTFLFloorOp : Pat<(TFL_FloorOp $arg), (Tosa_FloorOp $arg)>;
+def ConvertTFLExpOp : Pat<(TFL_ExpOp $arg), (Tosa_ExpOp $arg)>;
+def ConvertTFLLogOp : Pat<(TFL_LogOp $arg), (Tosa_LogOp $arg)>;
+def ConvertTFLRsqrtOp : Pat<(TFL_RsqrtOp $arg), (Tosa_RsqrtOp $arg)>;
+def ConvertTFLLogicalNotOp : Pat<(TFL_LogicalNotOp $arg), (Tosa_LogicalNotOp $arg)>;
+def ConvertTFLCastOp: Pat<(TFL_CastOp $in), (Tosa_CastOp $in)>;
 
 // Removing the quant.stats op for unquantized models.
-def : Pat<(quant_StatisticsOp $value, $layer_stats, $axis_stats, $axis),
-          (replaceWithValue $value)>;
+def ConvertQuantStatOp : Pat<(quant_StatisticsOp $value, $layer_stats, $axis_stats, $axis),
+                         (replaceWithValue $value)>;
 
 //===----------------------------------------------------------------------===//
 // Binary ops patterns.
 //===----------------------------------------------------------------------===//
 
-def : Pat<(TFL_LogicalAndOp $l, $r), (Tosa_LogicalAndOp $l, $r)>;
-def : Pat<(TFL_LogicalOrOp $l, $r), (Tosa_LogicalOrOp $l, $r)>;
-def : Pat<(TFL_PowOp $l, $r), (Tosa_PowOp $l, $r)>;
+def ConvertTFLLogicalAndOp : Pat<(TFL_LogicalAndOp $l, $r), (Tosa_LogicalAndOp $l, $r)>;
+def ConvertTFLLogicalOrOp : Pat<(TFL_LogicalOrOp $l, $r), (Tosa_LogicalOrOp $l, $r)>;
+def ConvertTFLPowOp : Pat<(TFL_PowOp $l, $r), (Tosa_PowOp $l, $r)>;
 
 //===----------------------------------------------------------------------===//
 // Ternary ops patterns.
 //===----------------------------------------------------------------------===//
-