Move WhileOutline and RuntimeVerify inside AddTFToTFLConversionPasses (NFC)

PiperOrigin-RevId: 406609247
Change-Id: I7c5dc04f10071816f6bd7dad4710835098809ee8
diff --git a/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h b/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h
index 47cb6ec..8553ee4 100644
--- a/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h
+++ b/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h
@@ -63,6 +63,8 @@
   // Note: This is staging step and will be removed.
   // TODO(b/137395003): Remove post switching legalization.
   bool legalize_tf_while;
+  // Whether to outline WhileOp at the end of the pipeline.
+  bool outline_tf_while = false;
   // Whether to do shape inference.
   bool shape_inference;
   // Whether to do TFLite runtime verification.
diff --git a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc
index 4c99ffd..a90651f 100644
--- a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc
+++ b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc
@@ -336,17 +336,13 @@
       std::make_unique<mlir::TFL::ErrorCollectorInstrumentation>(
           module->getContext()));
 
-  tensorflow::AddTFToTFLConversionPasses(model_flags, toco_flags, pass_config,
-                                         &pm, session);
-  // Convert back to outlined while format for export back to flatbuffer.
-  if (pass_config.legalize_tf_while) {
-    pm.addPass(mlir::TFL::CreateWhileOutlinePass());
-  }
-  pm.addPass(mlir::TFL::CreateRuntimeVerifyPass());
-
+  mlir::TFL::PassConfig pass_config_copy = pass_config;
+  pass_config_copy.outline_tf_while = true;
+  tensorflow::AddTFToTFLConversionPasses(model_flags, toco_flags,
+                                         pass_config_copy, &pm, session);
   auto status = ConvertTFExecutorToTFLOrFlatbuffer(
       module.get(), /*export_to_mlir=*/false, toco_flags,
-      pass_config.quant_specs, saved_model_tags, result, &pm);
+      pass_config_copy.quant_specs, saved_model_tags, result, &pm);
   if (toco_flags.has_dump_graphviz_dir()) {
     TF_RETURN_IF_ERROR(DumpOpGraphToFile(
         // rename once we enable the new converter feature flag.
diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc
index b62a4e2..e6bab46 100644
--- a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc
+++ b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc
@@ -318,6 +318,13 @@
   if (pass_config.unfold_large_splat_constant) {
     pass_manager->addPass(mlir::TFL::CreateUnfoldLargeSplatConstantPass());
   }
+  if (pass_config.outline_tf_while) {
+    pass_manager->addPass(mlir::TFL::CreateWhileOutlinePass());
+  }
+  if (pass_config.runtime_verification) {
+    pass_manager->addNestedPass<mlir::FuncOp>(
+        mlir::TFL::CreateRuntimeVerifyPass());
+  }
 }
 
 void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc
index e24c21d..25da586 100644
--- a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc
+++ b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc
@@ -275,6 +275,8 @@
   pass_config.unfold_batch_matmul = unfold_batchmatmul;
   pass_config.unfold_large_splat_constant = unfold_large_splat_constant;
   pass_config.guarantee_all_funcs_one_use = guarantee_all_funcs_one_use;
+  pass_config.runtime_verification = true;
+  pass_config.outline_tf_while = true;
 
   if (enable_hlo_to_tf_conversion) {
     pass_config.enable_hlo_to_tf_conversion = true;
@@ -283,13 +285,6 @@
   // TODO(b/153507667): Pass the session object when importing logic is removed.
   tensorflow::AddTFToTFLConversionPasses(pass_config, &pm,
                                          /*session=*/llvm::None);
-  // TODO(b/150901738): Move those into tf_tfl_translate.cc.
-  // Convert back to outlined while format for export back to flatbuffer.
-  if (pass_config.legalize_tf_while) {
-    pm.addPass(mlir::TFL::CreateWhileOutlinePass());
-  }
-  pm.addPass(mlir::TFL::CreateRuntimeVerifyPass());
-
   toco::TocoFlags toco_flags;
   toco_flags.set_force_select_tf_ops(!emit_builtin_tflite_ops);
   toco_flags.set_enable_select_tf_ops(emit_select_tf_ops);