Expose a way to run an MLIR pass pipeline from python.
PiperOrigin-RevId: 276753194
Change-Id: Iebc195f58376ab37159371dfbfd5e89badd8c530
diff --git a/tensorflow/compiler/mlir/python/mlir.i b/tensorflow/compiler/mlir/python/mlir.i
index ba5bfb9..98db04d 100644
--- a/tensorflow/compiler/mlir/python/mlir.i
+++ b/tensorflow/compiler/mlir/python/mlir.i
@@ -17,7 +17,6 @@
%{
-#include "mlir/Parser.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Pass/PassManager.h"
#include "llvm/Support/raw_ostream.h"
@@ -114,41 +113,6 @@
return MlirModuleToString(*module_or.ConsumeValueOrDie(), show_debug_info);
}
-
-string ExperimentalRunPassPipeline(
- const string &mlir_txt,
- const string &pass_pipeline,
- bool show_debug_info,
- TF_Status* status) {
- mlir::MLIRContext context;
- mlir::OwningModuleRef module;
- {
- mlir::StatusScopedDiagnosticHandler diagnostic_handler(&context);
- module = mlir::parseSourceString(mlir_txt, &context);
- if (!module) {
- Set_TF_Status_from_Status(status, diagnostic_handler.ConsumeStatus());
- return "// error";
- }
- }
-
- // Run the pass_pipeline on the module.
- mlir::PassManager pm(&context);
- std::string error;
- llvm::raw_string_ostream error_stream(error);
- if (failed(mlir::parsePassPipeline(pass_pipeline, pm, error_stream))) {
- TF_SetStatus(status, TF_INVALID_ARGUMENT,
- ("Invalid pass_pipeline: " + error_stream.str()).c_str());
- return "// error";
- }
-
- mlir::StatusScopedDiagnosticHandler diagnostic_handler(&context);
- if (failed(pm.run(*module))) {
- Set_TF_Status_from_Status(status, diagnostic_handler.ConsumeStatus());
- return "// error";
- }
- return MlirModuleToString(*module, show_debug_info);
-}
-
} // namespace swig
} // namespace tensorflow
@@ -160,7 +124,6 @@
%unignore tensorflow::swig;
%unignore tensorflow::swig::ImportGraphDef;
%unignore tensorflow::swig::ExperimentalConvertSavedModelToMlir;
-%unignore tensorflow::swig::ExperimentalRunPassPipeline;
// Wrap this function
namespace tensorflow {
@@ -173,11 +136,6 @@
const string &exported_names,
bool show_debug_info,
TF_Status* status);
-static string ExperimentalRunPassPipeline(
- const string &mlir_txt,
- const string &pass_pipeline,
- bool show_debug_info,
- TF_Status* status);
} // namespace swig
} // namespace tensorflow
@@ -193,13 +151,6 @@
str(exported_names).encode('utf-8'),
show_debug_info
).decode('utf-8');
-
-def experimental_run_pass_pipeline(mlir_txt, pass_pipeline, show_debug_info):
- return ExperimentalRunPassPipeline(
- mlir_txt.encode('utf-8'),
- pass_pipeline.encode('utf-8'),
- show_debug_info
- ).decode('utf-8');
%}
%unignoreall
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/common.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/common.py
index ebdbe3a..77b7b3a 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/common.py
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/common.py
@@ -81,11 +81,6 @@
logging.info('Saved model to: %s', save_model_path)
mlir = pywrap_tensorflow.experimental_convert_saved_model_to_mlir(
save_model_path, ','.join(exported_names), show_debug_info)
- # We don't strictly need this, but it serves as a handy sanity check
- # for that API, which is otherwise a bit annoying to test.
- # The canonicalization shouldn't affect these tests in any way.
- mlir = pywrap_tensorflow.experimental_run_pass_pipeline(
- mlir, 'canonicalize', show_debug_info)
print(mlir)
app.run(app_main)