[TFLite] Add saved model imported TF->TFLite conversion passes

Add a boolean flag to the pass config. The higher level API ConvertMLIRToTFLiteFlatBuffer is expected to set this accordingly.

For now, we plan to incrementally enable saved model imported conversion. To start with if tflite_convert.from_saved_model() is used in the python API and the underlying model is a RNN conversion, then this flag needs to be set to true. Eventually, we imagine this flag to be always true when the tflite_convert.from_saved_model() is invoked from the higher level API.

PiperOrigin-RevId: 298984717
Change-Id: Ibeef86667e03b51ac12a71ce79cb4493534749ab
diff --git a/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h b/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h
index b14041e..0b05a15 100644
--- a/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h
+++ b/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h
@@ -36,7 +36,8 @@
         form_clusters(false),
         inline_functions(true),
         unfold_batch_matmul(true),
-        legalize_tf_while(true) {}
+        legalize_tf_while(true),
+        saved_model_import(false) {}
 
   // If `emit_builtin_tflite_ops` is true, TF Lite legalization passes will be
   // added, which produces TF Lite ops.
@@ -66,6 +67,10 @@
   // Note: This is staging step and will be removed.
   // TODO(b/137395003): Remove post switching legalization.
   bool legalize_tf_while;
+
+  // This flag indicates whether the TF program to be converted is being
+  // imported into MLIR via saved model import.
+  bool saved_model_import;
 };
 
 }  // namespace TFL
diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc
index b000de1..358db8f 100644
--- a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc
+++ b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc
@@ -84,6 +84,18 @@
     pass_manager->addPass(mlir::TFL::CreateLowerStaticTensorListPass());
   }
 
+  if (pass_config.saved_model_import) {
+    // This pass does resource analysis of saved model global tensors and marks
+    // those deemed read-only as immutable.
+    pass_manager->addPass(
+        mlir::tf_saved_model::CreateOptimizeGlobalTensorsPass());
+    // This pass marks non-exported functions as symbol visibility 'private'
+    // those deemed read-only as immutable.
+    pass_manager->addPass(
+        mlir::tf_saved_model::
+            CreateMarkFunctionVisibilityUsingSavedModelLinkagePass());
+  }
+
   // Enable fusing composite ops that can be lowered to built-in TFLite ops.
   if (pass_config.emit_builtin_tflite_ops) {
     pass_manager->addPass(mlir::TFL::CreatePrepareCompositeFunctionsPass());
@@ -114,6 +126,10 @@
         mlir::TFL::CreateLegalizeTFWhilePass());
   }
 
+  if (pass_config.inline_functions) {
+    pass_manager->addPass(mlir::createInlinerPass());
+  }
+
   // TODO(jpienaar): Revise post dialect constants.
   pass_manager->addPass(mlir::TF::CreateDecodeConstantPass());
   // Canonicalization includes const folding, which is utilized here to optimize
@@ -121,9 +137,13 @@
   // tf.Conv2D is split into tf.Transpose and tfl.Conv2D.
   pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
   pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCSEPass());
-
-  if (pass_config.inline_functions) {
-    pass_manager->addPass(mlir::createInlinerPass());
+  // This pass does dead code elimination based on symbol visibility.
+  pass_manager->addPass(mlir::createSymbolDCEPass());
+  if (pass_config.saved_model_import) {
+    // This pass 'freezes' immutable global tensors and inlines them as tf
+    // constant ops.
+    pass_manager->addPass(
+        mlir::tf_saved_model::CreateFreezeGlobalTensorsPass());
   }
 
   // The below passes only make sense if Builtin TFLite ops are enabled