Move WhileOutline and RuntimeVerify inside AddTFToTFLConversionPasses (NFC)
PiperOrigin-RevId: 406609247
Change-Id: I7c5dc04f10071816f6bd7dad4710835098809ee8
diff --git a/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h b/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h
index 47cb6ec..8553ee4 100644
--- a/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h
+++ b/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h
@@ -63,6 +63,8 @@
// Note: This is staging step and will be removed.
// TODO(b/137395003): Remove post switching legalization.
bool legalize_tf_while;
+ // Whether to outline WhileOp at the end of the pipeline.
+ bool outline_tf_while = false;
// Whether to do shape inference.
bool shape_inference;
// Whether to do TFLite runtime verification.
diff --git a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc
index 4c99ffd..a90651f 100644
--- a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc
+++ b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc
@@ -336,17 +336,13 @@
std::make_unique<mlir::TFL::ErrorCollectorInstrumentation>(
module->getContext()));
- tensorflow::AddTFToTFLConversionPasses(model_flags, toco_flags, pass_config,
- &pm, session);
- // Convert back to outlined while format for export back to flatbuffer.
- if (pass_config.legalize_tf_while) {
- pm.addPass(mlir::TFL::CreateWhileOutlinePass());
- }
- pm.addPass(mlir::TFL::CreateRuntimeVerifyPass());
-
+ mlir::TFL::PassConfig pass_config_copy = pass_config;
+ pass_config_copy.outline_tf_while = true;
+ tensorflow::AddTFToTFLConversionPasses(model_flags, toco_flags,
+ pass_config_copy, &pm, session);
auto status = ConvertTFExecutorToTFLOrFlatbuffer(
module.get(), /*export_to_mlir=*/false, toco_flags,
- pass_config.quant_specs, saved_model_tags, result, &pm);
+ pass_config_copy.quant_specs, saved_model_tags, result, &pm);
if (toco_flags.has_dump_graphviz_dir()) {
TF_RETURN_IF_ERROR(DumpOpGraphToFile(
// rename once we enable the new converter feature flag.
diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc
index b62a4e2..e6bab46 100644
--- a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc
+++ b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc
@@ -318,6 +318,13 @@
if (pass_config.unfold_large_splat_constant) {
pass_manager->addPass(mlir::TFL::CreateUnfoldLargeSplatConstantPass());
}
+ if (pass_config.outline_tf_while) {
+ pass_manager->addPass(mlir::TFL::CreateWhileOutlinePass());
+ }
+ if (pass_config.runtime_verification) {
+ pass_manager->addNestedPass<mlir::FuncOp>(
+ mlir::TFL::CreateRuntimeVerifyPass());
+ }
}
void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc
index e24c21d..25da586 100644
--- a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc
+++ b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc
@@ -275,6 +275,8 @@
pass_config.unfold_batch_matmul = unfold_batchmatmul;
pass_config.unfold_large_splat_constant = unfold_large_splat_constant;
pass_config.guarantee_all_funcs_one_use = guarantee_all_funcs_one_use;
+ pass_config.runtime_verification = true;
+ pass_config.outline_tf_while = true;
if (enable_hlo_to_tf_conversion) {
pass_config.enable_hlo_to_tf_conversion = true;
@@ -283,13 +285,6 @@
// TODO(b/153507667): Pass the session object when importing logic is removed.
tensorflow::AddTFToTFLConversionPasses(pass_config, &pm,
/*session=*/llvm::None);
- // TODO(b/150901738): Move those into tf_tfl_translate.cc.
- // Convert back to outlined while format for export back to flatbuffer.
- if (pass_config.legalize_tf_while) {
- pm.addPass(mlir::TFL::CreateWhileOutlinePass());
- }
- pm.addPass(mlir::TFL::CreateRuntimeVerifyPass());
-
toco::TocoFlags toco_flags;
toco_flags.set_force_select_tf_ops(!emit_builtin_tflite_ops);
toco_flags.set_enable_select_tf_ops(emit_select_tf_ops);