[TFLite] Add a standard pass pipeline for TFLite
This represents a std configuration pass pipeline.
PiperOrigin-RevId: 299289688
Change-Id: If097f6146eee3589b664ead87924b9bad6d738ab
diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc
index a948895..e44242c 100644
--- a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc
+++ b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc
@@ -180,3 +180,85 @@
}
} // namespace tensorflow
+
+namespace mlir {
+namespace TFL {
+
+struct StandardPipelineOptions
+ : public PassPipelineOptions<StandardPipelineOptions> {
+ // TODO(b/150915052): All the tf_tfl_translate_cl flags should
+ // move inside this.
+};
+
+// NOLINTNEXTLINE
+// This creates the standard pass pipeline for TF->TFLite. This
+// represents a std configuration for TFLite, for use with APIs like
+// tensorflow/python/pywrap_mlir.py::experimental_run_pass_pipeline
+// This does not yet include quantization passes.
+void CreateTFLStandardPipeline(OpPassManager& pm,
+ const StandardPipelineOptions& options) {
+ OpPassManager& func_pm = pm.nest<FuncOp>();
+
+ // tf_executor dialect passes - Cleaning up the IR.
+ func_pm.addPass(tf_executor::CreateSwitchFoldPass());
+ func_pm.addPass(tf_executor::CreateTFExecutorGraphPruningPass());
+ func_pm.addPass(tf_executor::CreateTFExecutorIslandCoarseningPass());
+
+ // more cleanup of executor dialect and raise to control flow.
+ pm.addPass(mlir::CreateTFExecutorToControlDialectConversion());
+ pm.addPass(mlir::TFControlFlow::CreateRaiseTFControlFlowPass());
+
+ // This is needed for control flow support with TF TensorList.
+ pm.addPass(mlir::TFL::CreateLowerStaticTensorListPass());
+
+ // Saved model pass to mark global tensors immutable.
+ pm.addPass(mlir::tf_saved_model::CreateOptimizeGlobalTensorsPass());
+ // Used to mark non-exported functions in saved model private.
+ pm.addPass(mlir::tf_saved_model::
+ CreateMarkFunctionVisibilityUsingSavedModelLinkagePass());
+ // Op fusion pass.
+ pm.addPass(mlir::TFL::CreatePrepareCompositeFunctionsPass());
+
+ pm.addNestedPass<mlir::FuncOp>(mlir::TFL::CreateLegalizeTFWhilePass());
+
+ pm.addPass(mlir::createInlinerPass());
+
+ // Canonicalize, CSE etc.
+ pm.addPass(mlir::TF::CreateDecodeConstantPass());
+ pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
+ pm.addNestedPass<mlir::FuncOp>(mlir::createCSEPass());
+ // DCE for private symbols.
+ pm.addPass(mlir::createSymbolDCEPass());
+
+ // freeze global tensors.
+ pm.addPass(mlir::tf_saved_model::CreateFreezeGlobalTensorsPass());
+
+ // TFLite dialect passes.
+ pm.addPass(mlir::TFL::CreatePrepareTFPass(true));
+ pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
+ pm.addPass(mlir::TFL::CreateLegalizeTFPass());
+ pm.addPass(mlir::TFL::CreateOptimizePass());
+ pm.addPass(mlir::TFL::CreateOptimizeFunctionalOpsPass());
+
+ // Canonicalize, CSE etc.
+ pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
+ pm.addNestedPass<mlir::FuncOp>(mlir::createCSEPass());
+
+ // Pass for stateful operands like LSTM.
+ pm.addPass(mlir::TFL::CreateSplitMergedOperandsPass());
+
+ pm.addPass(mlir::TFL::CreateWhileOutlinePass());
+
+ pm.addPass(mlir::TFL::CreateRuntimeTypeVerifyPass());
+}
+
+// Registers a pass pipeline for the standard TFL passes.
+static mlir::PassPipelineRegistration<StandardPipelineOptions> pipeline(
+ "tfl-standard-pipeline",
+ "Run the standard passes involved in transforming/optimizing the TF "
+ "program to TFLite after "
+ "importing into MLIR.",
+ CreateTFLStandardPipeline);
+
+} // namespace TFL
+} // namespace mlir