Export the MLIR classes and functions from C++ to Python with pybind11 instead of swig. This is part of a larger effort to deprecate swig and eventually with modularization break pywrap_tensorflow into smaller components. It will also make exporting C++ ops to Python significantly easier. Please refer to https://github.com/tensorflow/community/blob/master/rfcs/20190208-pybind11.md for more information.
PiperOrigin-RevId: 292076160
Change-Id: I62bf3aac988c3ce4e18aa01ee49d8aa9ffde383d
diff --git a/tensorflow/compiler/mlir/python/BUILD b/tensorflow/compiler/mlir/python/BUILD
index 5291cf3..07405c0 100644
--- a/tensorflow/compiler/mlir/python/BUILD
+++ b/tensorflow/compiler/mlir/python/BUILD
@@ -3,9 +3,29 @@
licenses = ["notice"], # Apache 2.0
)
-exports_files(
- ["mlir.i"],
+cc_library(
+ name = "mlir",
+ srcs = ["mlir.cc"],
+ hdrs = ["mlir.h"],
+ deps = [
+ "//tensorflow/c:tf_status",
+ "//tensorflow/c:tf_status_helper",
+ "//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
+ "//tensorflow/compiler/mlir/tensorflow:error_util",
+ "//tensorflow/compiler/mlir/tensorflow:import_utils",
+ "@llvm-project//llvm:support",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:Parser",
+ "@llvm-project//mlir:Pass",
+ ],
+)
+
+filegroup(
+ name = "pywrap_mlir_hdrs",
+ srcs = [
+ "mlir.h",
+ ],
visibility = [
- "//tensorflow/python:__subpackages__",
+ "//tensorflow/python:__pkg__",
],
)
diff --git a/tensorflow/compiler/mlir/python/mlir.cc b/tensorflow/compiler/mlir/python/mlir.cc
new file mode 100644
index 0000000..e6ac78b
--- /dev/null
+++ b/tensorflow/compiler/mlir/python/mlir.cc
@@ -0,0 +1,157 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <string>
+
+#include "llvm/Support/raw_ostream.h"
+#include "mlir/Parser.h" // TF:llvm-project
+#include "mlir/Pass/PassManager.h" // TF:llvm-project
+#include "mlir/Pass/PassRegistry.h" // TF:llvm-project
+#include "tensorflow/c/tf_status.h"
+#include "tensorflow/c/tf_status_helper.h"
+#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
+#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
+#include "tensorflow/compiler/mlir/tensorflow/utils/import_utils.h"
+
+namespace tensorflow {
+
+std::string ImportGraphDef(const std::string &proto,
+ const std::string &pass_pipeline,
+ TF_Status *status) {
+ GraphDef graphdef;
+ auto s = tensorflow::LoadProtoFromBuffer(proto, &graphdef);
+ if (!s.ok()) {
+ Set_TF_Status_from_Status(status, s);
+ return "// error";
+ }
+ GraphDebugInfo debug_info;
+ GraphImportConfig specs;
+ mlir::MLIRContext context;
+ auto module = ConvertGraphdefToMlir(graphdef, debug_info, specs, &context);
+ if (!module.ok()) {
+ Set_TF_Status_from_Status(status, module.status());
+ return "// error";
+ }
+
+ // Run the pass_pipeline on the module if not empty.
+ if (!pass_pipeline.empty()) {
+ mlir::PassManager pm(&context);
+ std::string error;
+ llvm::raw_string_ostream error_stream(error);
+ if (failed(mlir::parsePassPipeline(pass_pipeline, pm, error_stream))) {
+ TF_SetStatus(status, TF_INVALID_ARGUMENT,
+ ("Invalid pass_pipeline: " + error_stream.str()).c_str());
+ return "// error";
+ }
+
+ mlir::StatusScopedDiagnosticHandler statusHandler(&context);
+ if (failed(pm.run(*module.ValueOrDie()))) {
+ Set_TF_Status_from_Status(status, statusHandler.ConsumeStatus());
+ return "// error";
+ }
+ }
+ return MlirModuleToString(*module.ConsumeValueOrDie());
+}
+
+std::string ExperimentalConvertSavedModelToMlir(
+ const std::string &saved_model_path, const std::string &exported_names_str,
+ bool show_debug_info, TF_Status *status) {
+ // Load the saved model into a SavedModelV2Bundle.
+
+ tensorflow::SavedModelV2Bundle bundle;
+ auto load_status =
+ tensorflow::SavedModelV2Bundle::Load(saved_model_path, &bundle);
+ if (!load_status.ok()) {
+ Set_TF_Status_from_Status(status, load_status);
+ return "// error";
+ }
+
+ // Convert the SavedModelV2Bundle to an MLIR module.
+
+ std::vector<string> exported_names =
+ absl::StrSplit(exported_names_str, ',', absl::SkipEmpty());
+ mlir::MLIRContext context;
+ auto module_or = ConvertSavedModelToMlir(
+ &bundle, &context, absl::Span<std::string>(exported_names));
+ if (!module_or.status().ok()) {
+ Set_TF_Status_from_Status(status, module_or.status());
+ return "// error";
+ }
+
+ return MlirModuleToString(*module_or.ConsumeValueOrDie(), show_debug_info);
+}
+
+std::string ExperimentalConvertSavedModelV1ToMlir(
+ const std::string &saved_model_path, const std::string &tags,
+ bool show_debug_info, TF_Status *status) {
+ // Load the saved model into a SavedModelBundle.
+
+ std::unordered_set<string> tag_set =
+ absl::StrSplit(tags, ',', absl::SkipEmpty());
+
+ tensorflow::SavedModelBundle bundle;
+ auto load_status =
+ tensorflow::LoadSavedModel({}, {}, saved_model_path, tag_set, &bundle);
+ if (!load_status.ok()) {
+ Set_TF_Status_from_Status(status, load_status);
+ return "// error";
+ }
+
+ // Convert the SavedModelBundle to an MLIR module.
+
+ mlir::MLIRContext context;
+ auto module_or = ConvertSavedModelV1ToMlir(bundle, &context);
+ if (!module_or.status().ok()) {
+ Set_TF_Status_from_Status(status, module_or.status());
+ return "// error";
+ }
+
+ return MlirModuleToString(*module_or.ConsumeValueOrDie(), show_debug_info);
+}
+
+std::string ExperimentalRunPassPipeline(const std::string &mlir_txt,
+ const std::string &pass_pipeline,
+ bool show_debug_info,
+ TF_Status *status) {
+ mlir::MLIRContext context;
+ mlir::OwningModuleRef module;
+ {
+ mlir::StatusScopedDiagnosticHandler diagnostic_handler(&context);
+ module = mlir::parseSourceString(mlir_txt, &context);
+ if (!module) {
+ Set_TF_Status_from_Status(status, diagnostic_handler.ConsumeStatus());
+ return "// error";
+ }
+ }
+
+ // Run the pass_pipeline on the module.
+ mlir::PassManager pm(&context);
+ std::string error;
+ llvm::raw_string_ostream error_stream(error);
+ if (failed(mlir::parsePassPipeline(pass_pipeline, pm, error_stream))) {
+ TF_SetStatus(status, TF_INVALID_ARGUMENT,
+ ("Invalid pass_pipeline: " + error_stream.str()).c_str());
+ return "// error";
+ }
+
+ mlir::StatusScopedDiagnosticHandler diagnostic_handler(&context);
+ if (failed(pm.run(*module))) {
+ Set_TF_Status_from_Status(status, diagnostic_handler.ConsumeStatus());
+ return "// error";
+ }
+ return MlirModuleToString(*module, show_debug_info);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/mlir/python/mlir.h b/tensorflow/compiler/mlir/python/mlir.h
new file mode 100644
index 0000000..b85b409
--- /dev/null
+++ b/tensorflow/compiler/mlir/python/mlir.h
@@ -0,0 +1,67 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Functions for getting information about kernels registered in the binary.
+// Migrated from previous SWIG file (mlir.i) authored by aminim@.
+#ifndef TENSORFLOW_COMPILER_MLIR_PYTHON_MLIR_H_
+#define TENSORFLOW_COMPILER_MLIR_PYTHON_MLIR_H_
+
+#include <string>
+
+#include "tensorflow/c/tf_status.h"
+
+namespace tensorflow {
+
+// Simple wrapper to support tf.mlir.experimental.convert_graph_def.
+// Load a .pbptx, convert to MLIR, and (optionally) optimize the module before
+// returning it as a string.
+// This is an early experimental API, ideally we should return a wrapper object
+// around a Python binding to the MLIR module.
+std::string ImportGraphDef(const std::string &proto,
+ const std::string &pass_pipeline, TF_Status *status);
+
+// Load a SavedModel and return a textual MLIR string corresponding to it.
+//
+// Args:
+// saved_model_path: File path from which to load the SavedModel.
+// exported_names_str: Comma-separated list of names to export.
+// Empty means "export all".
+//
+// Returns:
+// A string of textual MLIR representing the raw imported SavedModel.
+std::string ExperimentalConvertSavedModelToMlir(
+ const std::string &saved_model_path, const std::string &exported_names_str,
+ bool show_debug_info, TF_Status *status);
+
+// Load a SavedModel V1 and return a textual MLIR string corresponding to it.
+//
+// Args:
+// saved_model_path: File path from which to load the SavedModel.
+// tags: Tags to identify MetaGraphDef that need to be loaded.
+//
+// Returns:
+// A string of textual MLIR representing the raw imported SavedModel.
+std::string ExperimentalConvertSavedModelV1ToMlir(
+ const std::string &saved_model_path, const std::string &tags,
+ bool show_debug_info, TF_Status *status);
+
+std::string ExperimentalRunPassPipeline(const std::string &mlir_txt,
+ const std::string &pass_pipeline,
+ bool show_debug_info,
+ TF_Status *status);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_MLIR_PYTHON_MLIR_H_
diff --git a/tensorflow/compiler/mlir/python/mlir.i b/tensorflow/compiler/mlir/python/mlir.i
deleted file mode 100644
index b1d5328..0000000
--- a/tensorflow/compiler/mlir/python/mlir.i
+++ /dev/null
@@ -1,252 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-%include "tensorflow/python/platform/base.i"
-
-%{
-
-#include "mlir/Parser.h"
-#include "mlir/Pass/PassRegistry.h"
-#include "mlir/Pass/PassManager.h"
-#include "llvm/Support/raw_ostream.h"
-#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
-#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
-#include "tensorflow/compiler/mlir/tensorflow/utils/import_utils.h"
-
-namespace tensorflow {
-namespace swig {
-
-// Simple wrapper to support tf.mlir.experimental.convert_graph_def.
-// Load a .pbptx, convert to MLIR, and (optionally) optimize the module before
-// returning it as a string.
-// This is an early experimental API, ideally we should return a wrapper object
-// around a Python binding to the MLIR module.
-string ImportGraphDef(const string &proto, const string &pass_pipeline, TF_Status* status) {
- GraphDef graphdef;
- auto s = tensorflow::LoadProtoFromBuffer(proto, &graphdef);
- if (!s.ok()) {
- Set_TF_Status_from_Status(status, s);
- return "// error";
- }
- GraphDebugInfo debug_info;
- GraphImportConfig specs;
- mlir::MLIRContext context;
- auto module = ConvertGraphdefToMlir(graphdef, debug_info, specs, &context);
- if (!module.ok()) {
- Set_TF_Status_from_Status(status, module.status());
- return "// error";
- }
-
- // Run the pass_pipeline on the module if not empty.
- if (!pass_pipeline.empty()) {
- mlir::PassManager pm(&context);
- std::string error;
- llvm::raw_string_ostream error_stream(error);
- if (failed(mlir::parsePassPipeline(pass_pipeline, pm, error_stream))) {
- TF_SetStatus(status, TF_INVALID_ARGUMENT,
- ("Invalid pass_pipeline: " + error_stream.str()).c_str());
- return "// error";
- }
-
- mlir::StatusScopedDiagnosticHandler statusHandler(&context);
- if (failed(pm.run(*module.ValueOrDie()))) {
- Set_TF_Status_from_Status(status, statusHandler.ConsumeStatus());
- return "// error";
- }
- }
- return MlirModuleToString(*module.ConsumeValueOrDie());
-}
-
-// Load a SavedModel and return a textual MLIR string corresponding to it.
-//
-// Args:
-// saved_model_path: File path from which to load the SavedModel.
-// exported_names_str: Comma-separated list of names to export.
-// Empty means "export all".
-//
-// Returns:
-// A string of textual MLIR representing the raw imported SavedModel.
-string ExperimentalConvertSavedModelToMlir(
- const string &saved_model_path,
- const string &exported_names_str,
- bool show_debug_info,
- TF_Status* status) {
- // Load the saved model into a SavedModelV2Bundle.
-
- tensorflow::SavedModelV2Bundle bundle;
- auto load_status = tensorflow::SavedModelV2Bundle::Load(
- saved_model_path, &bundle);
- if (!load_status.ok()) {
- Set_TF_Status_from_Status(status, load_status);
- return "// error";
- }
-
- // Convert the SavedModelV2Bundle to an MLIR module.
-
- std::vector<string> exported_names =
- absl::StrSplit(exported_names_str, ',', absl::SkipEmpty());
- mlir::MLIRContext context;
- auto module_or = ConvertSavedModelToMlir(&bundle, &context,
- absl::Span<std::string>(exported_names));
- if (!module_or.status().ok()) {
- Set_TF_Status_from_Status(status, module_or.status());
- return "// error";
- }
-
- return MlirModuleToString(*module_or.ConsumeValueOrDie(), show_debug_info);
-}
-
-// Load a SavedModel V1 and return a textual MLIR string corresponding to it.
-//
-// Args:
-// saved_model_path: File path from which to load the SavedModel.
-// tags: Tags to identify MetaGraphDef that need to be loaded.
-//
-// Returns:
-// A string of textual MLIR representing the raw imported SavedModel.
-string ExperimentalConvertSavedModelV1ToMlir(
- const string &saved_model_path,
- const string &tags,
- bool show_debug_info,
- TF_Status* status) {
- // Load the saved model into a SavedModelBundle.
-
- std::unordered_set<string> tag_set
- = absl::StrSplit(tags, ',', absl::SkipEmpty());
-
- tensorflow::SavedModelBundle bundle;
- auto load_status = tensorflow::LoadSavedModel(
- {}, {},
- saved_model_path, tag_set, &bundle);
- if (!load_status.ok()) {
- Set_TF_Status_from_Status(status, load_status);
- return "// error";
- }
-
- // Convert the SavedModelBundle to an MLIR module.
-
- mlir::MLIRContext context;
- auto module_or = ConvertSavedModelV1ToMlir(bundle, &context);
- if (!module_or.status().ok()) {
- Set_TF_Status_from_Status(status, module_or.status());
- return "// error";
- }
-
- return MlirModuleToString(*module_or.ConsumeValueOrDie(), show_debug_info);
-}
-
-
-string ExperimentalRunPassPipeline(
- const string &mlir_txt,
- const string &pass_pipeline,
- bool show_debug_info,
- TF_Status* status) {
- mlir::MLIRContext context;
- mlir::OwningModuleRef module;
- {
- mlir::StatusScopedDiagnosticHandler diagnostic_handler(&context);
- module = mlir::parseSourceString(mlir_txt, &context);
- if (!module) {
- Set_TF_Status_from_Status(status, diagnostic_handler.ConsumeStatus());
- return "// error";
- }
- }
-
- // Run the pass_pipeline on the module.
- mlir::PassManager pm(&context);
- std::string error;
- llvm::raw_string_ostream error_stream(error);
- if (failed(mlir::parsePassPipeline(pass_pipeline, pm, error_stream))) {
- TF_SetStatus(status, TF_INVALID_ARGUMENT,
- ("Invalid pass_pipeline: " + error_stream.str()).c_str());
- return "// error";
- }
-
- mlir::StatusScopedDiagnosticHandler diagnostic_handler(&context);
- if (failed(pm.run(*module))) {
- Set_TF_Status_from_Status(status, diagnostic_handler.ConsumeStatus());
- return "// error";
- }
- return MlirModuleToString(*module, show_debug_info);
-}
-
-} // namespace swig
-} // namespace tensorflow
-
-%}
-
-%ignoreall
-
-%unignore tensorflow;
-%unignore tensorflow::swig;
-%unignore tensorflow::swig::ImportGraphDef;
-%unignore tensorflow::swig::ExperimentalConvertSavedModelToMlir;
-%unignore tensorflow::swig::ExperimentalConvertSavedModelV1ToMlir;
-%unignore tensorflow::swig::ExperimentalRunPassPipeline;
-
-// Wrap this function
-namespace tensorflow {
-namespace swig {
-static string ImportGraphDef(const string &graphdef,
- const string &pass_pipeline,
- TF_Status* status);
-static string ExperimentalConvertSavedModelToMlir(
- const string &saved_model_path,
- const string &exported_names,
- bool show_debug_info,
- TF_Status* status);
-static string ExperimentalConvertSavedModelV1ToMlir(
- const string &saved_model_path,
- const string &tags,
- bool show_debug_info,
- TF_Status* status);
-static string ExperimentalRunPassPipeline(
- const string &mlir_txt,
- const string &pass_pipeline,
- bool show_debug_info,
- TF_Status* status);
-} // namespace swig
-} // namespace tensorflow
-
-%insert("python") %{
-def import_graphdef(graphdef, pass_pipeline):
- return ImportGraphDef(str(graphdef).encode('utf-8'), pass_pipeline.encode('utf-8')).decode('utf-8');
-
-def experimental_convert_saved_model_to_mlir(saved_model_path,
- exported_names,
- show_debug_info):
- return ExperimentalConvertSavedModelToMlir(
- str(saved_model_path).encode('utf-8'),
- str(exported_names).encode('utf-8'),
- show_debug_info
- ).decode('utf-8');
-
-def experimental_convert_saved_model_v1_to_mlir(saved_model_path,
- tags, show_debug_info):
- return ExperimentalConvertSavedModelV1ToMlir(
- str(saved_model_path).encode('utf-8'),
- str(tags).encode('utf-8'),
- show_debug_info
- ).decode('utf-8');
-
-def experimental_run_pass_pipeline(mlir_txt, pass_pipeline, show_debug_info):
- return ExperimentalRunPassPipeline(
- mlir_txt.encode('utf-8'),
- pass_pipeline.encode('utf-8'),
- show_debug_info
- ).decode('utf-8');
-%}
-
-%unignoreall
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/common.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/common.py
index fd8221c..de61800 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/common.py
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/common.py
@@ -29,7 +29,7 @@
from absl import logging
import tensorflow.compat.v2 as tf
-from tensorflow.python import pywrap_tensorflow
+from tensorflow.python import pywrap_mlir # pylint: disable=g-direct-tensorflow-import
# Use /tmp to make debugging the tests easier (see README.md)
flags.DEFINE_string('save_model_path', '',
@@ -84,13 +84,13 @@
tf.saved_model.save(
create_module_fn(), save_model_path, options=save_options)
logging.info('Saved model to: %s', save_model_path)
- mlir = pywrap_tensorflow.experimental_convert_saved_model_to_mlir(
+ mlir = pywrap_mlir.experimental_convert_saved_model_to_mlir(
save_model_path, ','.join(exported_names), show_debug_info)
# We don't strictly need this, but it serves as a handy sanity check
# for that API, which is otherwise a bit annoying to test.
# The canonicalization shouldn't affect these tests in any way.
- mlir = pywrap_tensorflow.experimental_run_pass_pipeline(
- mlir, 'canonicalize', show_debug_info)
+ mlir = pywrap_mlir.experimental_run_pass_pipeline(mlir, 'canonicalize',
+ show_debug_info)
print(mlir)
app.run(app_main)
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/common_v1.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/common_v1.py
index 35858d2..fb29470 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/common_v1.py
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/common_v1.py
@@ -28,7 +28,7 @@
from absl import logging
import tensorflow.compat.v1 as tf
-from tensorflow.python import pywrap_tensorflow
+from tensorflow.python import pywrap_mlir # pylint: disable=g-direct-tensorflow-import
# Use /tmp to make debugging the tests easier (see README.md)
flags.DEFINE_string('save_model_path', '', 'Path to save the model to.')
@@ -80,14 +80,15 @@
builder.save()
logging.info('Saved model to: %s', save_model_path)
- mlir = pywrap_tensorflow.experimental_convert_saved_model_v1_to_mlir(
+ mlir = pywrap_mlir.experimental_convert_saved_model_v1_to_mlir(
save_model_path, ','.join([tf.saved_model.tag_constants.SERVING]),
show_debug_info)
# We don't strictly need this, but it serves as a handy sanity check
# for that API, which is otherwise a bit annoying to test.
# The canonicalization shouldn't affect these tests in any way.
- mlir = pywrap_tensorflow.experimental_run_pass_pipeline(
- mlir, 'tf-standard-pipeline', show_debug_info)
+ mlir = pywrap_mlir.experimental_run_pass_pipeline(mlir,
+ 'tf-standard-pipeline',
+ show_debug_info)
print(mlir)
app.run(app_main)
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 548fc8c..4db47fe 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -1130,6 +1130,7 @@
":platform",
":pywrap_tensorflow",
":pywrap_tfe",
+ ":pywrap_mlir",
":random_seed",
":sparse_tensor",
":tensor_spec",
@@ -5544,7 +5545,6 @@
"grappler/tf_optimizer.i",
"lib/core/strings.i",
"platform/base.i",
- "//tensorflow/compiler/mlir/python:mlir.i",
],
# add win_def_file for pywrap_tensorflow
win_def_file = select({
@@ -5572,7 +5572,6 @@
"//tensorflow/c:tf_status_helper",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_experimental",
- "//tensorflow/compiler/mlir:passes",
"//tensorflow/core/distributed_runtime/rpc:grpc_rpc_factory_registration",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
"//tensorflow/core/distributed_runtime/rpc:grpc_session",
@@ -5593,6 +5592,7 @@
"//tensorflow/lite/toco/python:toco_python_api",
"//tensorflow/python/eager:pywrap_tfe_lib",
"//tensorflow/core/util/tensor_bundle",
+ "//tensorflow/compiler/mlir/python:mlir",
] + (tf_additional_lib_deps() +
tf_additional_plugin_deps()) + if_ngraph([
"@ngraph_tf//:ngraph_tf",
@@ -5631,6 +5631,7 @@
"//tensorflow/core/common_runtime/eager:context", # tfe
"//tensorflow/core/profiler/lib:profiler_session", # tfe
"//tensorflow/c:tf_status_helper", # tfe
+ "//tensorflow/compiler/mlir/python:mlir", # mlir
]
# Filter the DEF file to reduce the number of symbols to 64K or less.
@@ -7681,6 +7682,34 @@
)
py_library(
+ name = "pywrap_mlir",
+ srcs = ["pywrap_mlir.py"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":_pywrap_mlir",
+ ":pywrap_tensorflow",
+ ],
+)
+
+tf_python_pybind_extension(
+ name = "_pywrap_mlir",
+ srcs = ["mlir_wrapper.cc"],
+ hdrs = [
+ "lib/core/safe_ptr.h",
+ "//tensorflow/c:headers",
+ "//tensorflow/c/eager:headers",
+ "//tensorflow/compiler/mlir/python:pywrap_mlir_hdrs",
+ ],
+ module_name = "_pywrap_mlir",
+ deps = [
+ ":pybind11_lib",
+ ":pybind11_status",
+ "//third_party/python_runtime:headers",
+ "@pybind11",
+ ],
+)
+
+py_library(
name = "pywrap_tfe",
srcs = ["pywrap_tfe.py"],
visibility = ["//visibility:public"],
diff --git a/tensorflow/python/compiler/mlir/BUILD b/tensorflow/python/compiler/mlir/BUILD
index ee191e4..fe59213 100644
--- a/tensorflow/python/compiler/mlir/BUILD
+++ b/tensorflow/python/compiler/mlir/BUILD
@@ -10,7 +10,7 @@
srcs = ["mlir.py"],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/python:pywrap_tensorflow",
+ "//tensorflow/python:pywrap_mlir",
"//tensorflow/python:util",
],
)
diff --git a/tensorflow/python/compiler/mlir/mlir.py b/tensorflow/python/compiler/mlir/mlir.py
index 3766b84..84d23c3 100644
--- a/tensorflow/python/compiler/mlir/mlir.py
+++ b/tensorflow/python/compiler/mlir/mlir.py
@@ -18,7 +18,7 @@
from __future__ import division
from __future__ import print_function
-from tensorflow.python import pywrap_tensorflow as import_graphdef
+from tensorflow.python import pywrap_mlir
from tensorflow.python.util.tf_export import tf_export
@@ -38,4 +38,4 @@
Raises a RuntimeError on error.
"""
- return import_graphdef.import_graphdef(graph_def, pass_pipeline)
+ return pywrap_mlir.import_graphdef(graph_def, pass_pipeline)
diff --git a/tensorflow/python/mlir_wrapper.cc b/tensorflow/python/mlir_wrapper.cc
new file mode 100644
index 0000000..3e83566
--- /dev/null
+++ b/tensorflow/python/mlir_wrapper.cc
@@ -0,0 +1,67 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "include/pybind11/pybind11.h"
+#include "tensorflow/c/tf_status.h"
+#include "tensorflow/compiler/mlir/python/mlir.h"
+#include "tensorflow/python/lib/core/pybind11_lib.h"
+#include "tensorflow/python/lib/core/pybind11_status.h"
+#include "tensorflow/python/lib/core/safe_ptr.h"
+
+PYBIND11_MODULE(_pywrap_mlir, m) {
+ m.def("ImportGraphDef",
+ [](const std::string &graphdef, const std::string &pass_pipeline) {
+ tensorflow::Safe_TF_StatusPtr status =
+ tensorflow::make_safe(TF_NewStatus());
+ std::string output =
+ tensorflow::ImportGraphDef(graphdef, pass_pipeline, status.get());
+ tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
+ return output;
+ });
+
+ m.def("ExperimentalConvertSavedModelToMlir",
+ [](const std::string &saved_model_path,
+ const std::string &exported_names, bool show_debug_info) {
+ tensorflow::Safe_TF_StatusPtr status =
+ tensorflow::make_safe(TF_NewStatus());
+ std::string output = tensorflow::ExperimentalConvertSavedModelToMlir(
+ saved_model_path, exported_names, show_debug_info, status.get());
+ tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
+ return output;
+ });
+
+ m.def("ExperimentalConvertSavedModelV1ToMlir",
+ [](const std::string &saved_model_path, const std::string &tags,
+ bool show_debug_info) {
+ tensorflow::Safe_TF_StatusPtr status =
+ tensorflow::make_safe(TF_NewStatus());
+ std::string output =
+ tensorflow::ExperimentalConvertSavedModelV1ToMlir(
+ saved_model_path, tags, show_debug_info, status.get());
+ tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
+ return output;
+ });
+
+ m.def("ExperimentalRunPassPipeline",
+ [](const std::string &mlir_txt, const std::string &pass_pipeline,
+ bool show_debug_info) {
+ tensorflow::Safe_TF_StatusPtr status =
+ tensorflow::make_safe(TF_NewStatus());
+ std::string output = tensorflow::ExperimentalRunPassPipeline(
+ mlir_txt, pass_pipeline, show_debug_info, status.get());
+ tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
+ return output;
+ });
+};
diff --git a/tensorflow/python/pywrap_mlir.py b/tensorflow/python/pywrap_mlir.py
new file mode 100644
index 0000000..73c69a8
--- /dev/null
+++ b/tensorflow/python/pywrap_mlir.py
@@ -0,0 +1,49 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Python module for MLIR functions exported by pybind11."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=invalid-import-order, g-bad-import-order, wildcard-import, unused-import, undefined-variable
+from tensorflow.python import pywrap_tensorflow
+from tensorflow.python._pywrap_mlir import *
+
+
+def import_graphdef(graphdef, pass_pipeline):
+ return ImportGraphDef(
+ str(graphdef).encode('utf-8'),
+ pass_pipeline.encode('utf-8'))
+
+
+def experimental_convert_saved_model_to_mlir(saved_model_path, exported_names,
+ show_debug_info):
+ return ExperimentalConvertSavedModelToMlir(
+ str(saved_model_path).encode('utf-8'),
+ str(exported_names).encode('utf-8'), show_debug_info)
+
+
+def experimental_convert_saved_model_v1_to_mlir(saved_model_path, tags,
+ show_debug_info):
+ return ExperimentalConvertSavedModelV1ToMlir(
+ str(saved_model_path).encode('utf-8'),
+ str(tags).encode('utf-8'), show_debug_info)
+
+
+def experimental_run_pass_pipeline(mlir_txt, pass_pipeline, show_debug_info):
+ return ExperimentalRunPassPipeline(
+ mlir_txt.encode('utf-8'), pass_pipeline.encode('utf-8'),
+ show_debug_info)
diff --git a/tensorflow/python/tensorflow.i b/tensorflow/python/tensorflow.i
index 6ac9dcf..f3deea2 100644
--- a/tensorflow/python/tensorflow.i
+++ b/tensorflow/python/tensorflow.i
@@ -24,8 +24,6 @@
%include "tensorflow/python/grappler/tf_optimizer.i"
%include "tensorflow/python/grappler/cost_analyzer.i"
-%include "tensorflow/compiler/mlir/python/mlir.i"
-
// TODO(slebedev): This is a temporary workaround for projects implicitly
// relying on TensorFlow exposing tensorflow::Status.
%unignoreall
diff --git a/tensorflow/tools/def_file_filter/symbols_pybind.txt b/tensorflow/tools/def_file_filter/symbols_pybind.txt
index e657edc..694a1a47 100644
--- a/tensorflow/tools/def_file_filter/symbols_pybind.txt
+++ b/tensorflow/tools/def_file_filter/symbols_pybind.txt
@@ -189,3 +189,9 @@
[context] # tfe
tensorflow::EagerContext::WaitForAndCloseRemoteContexts
+
+[mlir] # mlir
+tensorflow::ExperimentalRunPassPipeline
+tensorflow::ExperimentalConvertSavedModelV1ToMlir
+tensorflow::ExperimentalConvertSavedModelToMlir
+tensorflow::ImportGraphDef