Export the MLIR classes and functions from C++ to Python with pybind11 instead of swig. This is part of a larger effort to deprecate swig and eventually with modularization break pywrap_tensorflow into smaller components. It will also make exporting C++ ops to Python significantly easier. Please refer to https://github.com/tensorflow/community/blob/master/rfcs/20190208-pybind11.md for more information.

PiperOrigin-RevId: 292076160
Change-Id: I62bf3aac988c3ce4e18aa01ee49d8aa9ffde383d
diff --git a/tensorflow/compiler/mlir/python/BUILD b/tensorflow/compiler/mlir/python/BUILD
index 5291cf3..07405c0 100644
--- a/tensorflow/compiler/mlir/python/BUILD
+++ b/tensorflow/compiler/mlir/python/BUILD
@@ -3,9 +3,29 @@
     licenses = ["notice"],  # Apache 2.0
 )
 
-exports_files(
-    ["mlir.i"],
+cc_library(
+    name = "mlir",
+    srcs = ["mlir.cc"],
+    hdrs = ["mlir.h"],
+    deps = [
+        "//tensorflow/c:tf_status",
+        "//tensorflow/c:tf_status_helper",
+        "//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
+        "//tensorflow/compiler/mlir/tensorflow:error_util",
+        "//tensorflow/compiler/mlir/tensorflow:import_utils",
+        "@llvm-project//llvm:support",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Parser",
+        "@llvm-project//mlir:Pass",
+    ],
+)
+
+filegroup(
+    name = "pywrap_mlir_hdrs",
+    srcs = [
+        "mlir.h",
+    ],
     visibility = [
-        "//tensorflow/python:__subpackages__",
+        "//tensorflow/python:__pkg__",
     ],
 )
diff --git a/tensorflow/compiler/mlir/python/mlir.cc b/tensorflow/compiler/mlir/python/mlir.cc
new file mode 100644
index 0000000..e6ac78b
--- /dev/null
+++ b/tensorflow/compiler/mlir/python/mlir.cc
@@ -0,0 +1,157 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <string>
+
+#include "llvm/Support/raw_ostream.h"
+#include "mlir/Parser.h"  // TF:llvm-project
+#include "mlir/Pass/PassManager.h"  // TF:llvm-project
+#include "mlir/Pass/PassRegistry.h"  // TF:llvm-project
+#include "tensorflow/c/tf_status.h"
+#include "tensorflow/c/tf_status_helper.h"
+#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
+#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
+#include "tensorflow/compiler/mlir/tensorflow/utils/import_utils.h"
+
+namespace tensorflow {
+
+std::string ImportGraphDef(const std::string &proto,
+                           const std::string &pass_pipeline,
+                           TF_Status *status) {
+  GraphDef graphdef;
+  auto s = tensorflow::LoadProtoFromBuffer(proto, &graphdef);
+  if (!s.ok()) {
+    Set_TF_Status_from_Status(status, s);
+    return "// error";
+  }
+  GraphDebugInfo debug_info;
+  GraphImportConfig specs;
+  mlir::MLIRContext context;
+  auto module = ConvertGraphdefToMlir(graphdef, debug_info, specs, &context);
+  if (!module.ok()) {
+    Set_TF_Status_from_Status(status, module.status());
+    return "// error";
+  }
+
+  // Run the pass_pipeline on the module if not empty.
+  if (!pass_pipeline.empty()) {
+    mlir::PassManager pm(&context);
+    std::string error;
+    llvm::raw_string_ostream error_stream(error);
+    if (failed(mlir::parsePassPipeline(pass_pipeline, pm, error_stream))) {
+      TF_SetStatus(status, TF_INVALID_ARGUMENT,
+                   ("Invalid pass_pipeline: " + error_stream.str()).c_str());
+      return "// error";
+    }
+
+    mlir::StatusScopedDiagnosticHandler statusHandler(&context);
+    if (failed(pm.run(*module.ValueOrDie()))) {
+      Set_TF_Status_from_Status(status, statusHandler.ConsumeStatus());
+      return "// error";
+    }
+  }
+  return MlirModuleToString(*module.ConsumeValueOrDie());
+}
+
+std::string ExperimentalConvertSavedModelToMlir(
+    const std::string &saved_model_path, const std::string &exported_names_str,
+    bool show_debug_info, TF_Status *status) {
+  // Load the saved model into a SavedModelV2Bundle.
+
+  tensorflow::SavedModelV2Bundle bundle;
+  auto load_status =
+      tensorflow::SavedModelV2Bundle::Load(saved_model_path, &bundle);
+  if (!load_status.ok()) {
+    Set_TF_Status_from_Status(status, load_status);
+    return "// error";
+  }
+
+  // Convert the SavedModelV2Bundle to an MLIR module.
+
+  std::vector<string> exported_names =
+      absl::StrSplit(exported_names_str, ',', absl::SkipEmpty());
+  mlir::MLIRContext context;
+  auto module_or = ConvertSavedModelToMlir(
+      &bundle, &context, absl::Span<std::string>(exported_names));
+  if (!module_or.status().ok()) {
+    Set_TF_Status_from_Status(status, module_or.status());
+    return "// error";
+  }
+
+  return MlirModuleToString(*module_or.ConsumeValueOrDie(), show_debug_info);
+}
+
+std::string ExperimentalConvertSavedModelV1ToMlir(
+    const std::string &saved_model_path, const std::string &tags,
+    bool show_debug_info, TF_Status *status) {
+  // Load the saved model into a SavedModelBundle.
+
+  std::unordered_set<string> tag_set =
+      absl::StrSplit(tags, ',', absl::SkipEmpty());
+
+  tensorflow::SavedModelBundle bundle;
+  auto load_status =
+      tensorflow::LoadSavedModel({}, {}, saved_model_path, tag_set, &bundle);
+  if (!load_status.ok()) {
+    Set_TF_Status_from_Status(status, load_status);
+    return "// error";
+  }
+
+  // Convert the SavedModelBundle to an MLIR module.
+
+  mlir::MLIRContext context;
+  auto module_or = ConvertSavedModelV1ToMlir(bundle, &context);
+  if (!module_or.status().ok()) {
+    Set_TF_Status_from_Status(status, module_or.status());
+    return "// error";
+  }
+
+  return MlirModuleToString(*module_or.ConsumeValueOrDie(), show_debug_info);
+}
+
+std::string ExperimentalRunPassPipeline(const std::string &mlir_txt,
+                                        const std::string &pass_pipeline,
+                                        bool show_debug_info,
+                                        TF_Status *status) {
+  mlir::MLIRContext context;
+  mlir::OwningModuleRef module;
+  {
+    mlir::StatusScopedDiagnosticHandler diagnostic_handler(&context);
+    module = mlir::parseSourceString(mlir_txt, &context);
+    if (!module) {
+      Set_TF_Status_from_Status(status, diagnostic_handler.ConsumeStatus());
+      return "// error";
+    }
+  }
+
+  // Run the pass_pipeline on the module.
+  mlir::PassManager pm(&context);
+  std::string error;
+  llvm::raw_string_ostream error_stream(error);
+  if (failed(mlir::parsePassPipeline(pass_pipeline, pm, error_stream))) {
+    TF_SetStatus(status, TF_INVALID_ARGUMENT,
+                 ("Invalid pass_pipeline: " + error_stream.str()).c_str());
+    return "// error";
+  }
+
+  mlir::StatusScopedDiagnosticHandler diagnostic_handler(&context);
+  if (failed(pm.run(*module))) {
+    Set_TF_Status_from_Status(status, diagnostic_handler.ConsumeStatus());
+    return "// error";
+  }
+  return MlirModuleToString(*module, show_debug_info);
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/compiler/mlir/python/mlir.h b/tensorflow/compiler/mlir/python/mlir.h
new file mode 100644
index 0000000..b85b409
--- /dev/null
+++ b/tensorflow/compiler/mlir/python/mlir.h
@@ -0,0 +1,67 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Functions for getting information about kernels registered in the binary.
+// Migrated from previous SWIG file (mlir.i) authored by aminim@.
+#ifndef TENSORFLOW_COMPILER_MLIR_PYTHON_MLIR_H_
+#define TENSORFLOW_COMPILER_MLIR_PYTHON_MLIR_H_
+
+#include <string>
+
+#include "tensorflow/c/tf_status.h"
+
+namespace tensorflow {
+
+// Simple wrapper to support tf.mlir.experimental.convert_graph_def.
+// Load a .pbptx, convert to MLIR, and (optionally) optimize the module before
+// returning it as a string.
+// This is an early experimental API, ideally we should return a wrapper object
+// around a Python binding to the MLIR module.
+std::string ImportGraphDef(const std::string &proto,
+                           const std::string &pass_pipeline, TF_Status *status);
+
+// Load a SavedModel and return a textual MLIR string corresponding to it.
+//
+// Args:
+//   saved_model_path: File path from which to load the SavedModel.
+//   exported_names_str: Comma-separated list of names to export.
+//                       Empty means "export all".
+//
+// Returns:
+//   A string of textual MLIR representing the raw imported SavedModel.
+std::string ExperimentalConvertSavedModelToMlir(
+    const std::string &saved_model_path, const std::string &exported_names_str,
+    bool show_debug_info, TF_Status *status);
+
+// Load a SavedModel V1 and return a textual MLIR string corresponding to it.
+//
+// Args:
+//   saved_model_path: File path from which to load the SavedModel.
+//   tags: Tags to identify MetaGraphDef that need to be loaded.
+//
+// Returns:
+//   A string of textual MLIR representing the raw imported SavedModel.
+std::string ExperimentalConvertSavedModelV1ToMlir(
+    const std::string &saved_model_path, const std::string &tags,
+    bool show_debug_info, TF_Status *status);
+
+std::string ExperimentalRunPassPipeline(const std::string &mlir_txt,
+                                        const std::string &pass_pipeline,
+                                        bool show_debug_info,
+                                        TF_Status *status);
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_COMPILER_MLIR_PYTHON_MLIR_H_
diff --git a/tensorflow/compiler/mlir/python/mlir.i b/tensorflow/compiler/mlir/python/mlir.i
deleted file mode 100644
index b1d5328..0000000
--- a/tensorflow/compiler/mlir/python/mlir.i
+++ /dev/null
@@ -1,252 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-%include "tensorflow/python/platform/base.i"
-
-%{
-
-#include "mlir/Parser.h"
-#include "mlir/Pass/PassRegistry.h"
-#include "mlir/Pass/PassManager.h"
-#include "llvm/Support/raw_ostream.h"
-#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
-#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
-#include "tensorflow/compiler/mlir/tensorflow/utils/import_utils.h"
-
-namespace tensorflow {
-namespace swig {
-
-// Simple wrapper to support tf.mlir.experimental.convert_graph_def.
-// Load a .pbptx, convert to MLIR, and (optionally) optimize the module before
-// returning it as a string.
-// This is an early experimental API, ideally we should return a wrapper object
-// around a Python binding to the MLIR module.
-string ImportGraphDef(const string &proto, const string &pass_pipeline, TF_Status* status) {
-  GraphDef graphdef;
-  auto s = tensorflow::LoadProtoFromBuffer(proto, &graphdef);
-  if (!s.ok()) {
-    Set_TF_Status_from_Status(status, s);
-    return "// error";
-  }
-  GraphDebugInfo debug_info;
-  GraphImportConfig specs;
-  mlir::MLIRContext context;
-  auto module = ConvertGraphdefToMlir(graphdef, debug_info, specs, &context);
-  if (!module.ok()) {
-    Set_TF_Status_from_Status(status, module.status());
-    return "// error";
-  }
-
-  // Run the pass_pipeline on the module if not empty.
-  if (!pass_pipeline.empty()) {
-    mlir::PassManager pm(&context);
-    std::string error;
-    llvm::raw_string_ostream error_stream(error);
-    if (failed(mlir::parsePassPipeline(pass_pipeline, pm, error_stream))) {
-      TF_SetStatus(status, TF_INVALID_ARGUMENT,
-                   ("Invalid pass_pipeline: " + error_stream.str()).c_str());
-      return "// error";
-    }
-
-    mlir::StatusScopedDiagnosticHandler statusHandler(&context);
-    if (failed(pm.run(*module.ValueOrDie()))) {
-      Set_TF_Status_from_Status(status, statusHandler.ConsumeStatus());
-      return "// error";
-    }
-  }
-  return MlirModuleToString(*module.ConsumeValueOrDie());
-}
-
-// Load a SavedModel and return a textual MLIR string corresponding to it.
-//
-// Args:
-//   saved_model_path: File path from which to load the SavedModel.
-//   exported_names_str: Comma-separated list of names to export.
-//                       Empty means "export all".
-//
-// Returns:
-//   A string of textual MLIR representing the raw imported SavedModel.
-string ExperimentalConvertSavedModelToMlir(
-    const string &saved_model_path,
-    const string &exported_names_str,
-    bool show_debug_info,
-    TF_Status* status) {
-  // Load the saved model into a SavedModelV2Bundle.
-
-  tensorflow::SavedModelV2Bundle bundle;
-  auto load_status = tensorflow::SavedModelV2Bundle::Load(
-      saved_model_path, &bundle);
-  if (!load_status.ok()) {
-    Set_TF_Status_from_Status(status, load_status);
-    return "// error";
-  }
-
-  // Convert the SavedModelV2Bundle to an MLIR module.
-
-  std::vector<string> exported_names =
-      absl::StrSplit(exported_names_str, ',', absl::SkipEmpty());
-  mlir::MLIRContext context;
-  auto module_or = ConvertSavedModelToMlir(&bundle, &context,
-      absl::Span<std::string>(exported_names));
-  if (!module_or.status().ok()) {
-    Set_TF_Status_from_Status(status, module_or.status());
-    return "// error";
-  }
-
-  return MlirModuleToString(*module_or.ConsumeValueOrDie(), show_debug_info);
-}
-
-// Load a SavedModel V1 and return a textual MLIR string corresponding to it.
-//
-// Args:
-//   saved_model_path: File path from which to load the SavedModel.
-//   tags: Tags to identify MetaGraphDef that need to be loaded.
-//
-// Returns:
-//   A string of textual MLIR representing the raw imported SavedModel.
-string ExperimentalConvertSavedModelV1ToMlir(
-    const string &saved_model_path,
-    const string &tags,
-    bool show_debug_info,
-    TF_Status* status) {
-  // Load the saved model into a SavedModelBundle.
-
-  std::unordered_set<string> tag_set
-      = absl::StrSplit(tags, ',', absl::SkipEmpty());
-
-  tensorflow::SavedModelBundle bundle;
-  auto load_status = tensorflow::LoadSavedModel(
-      {}, {},
-      saved_model_path, tag_set, &bundle);
-  if (!load_status.ok()) {
-    Set_TF_Status_from_Status(status, load_status);
-    return "// error";
-  }
-
-  // Convert the SavedModelBundle to an MLIR module.
-
-  mlir::MLIRContext context;
-  auto module_or = ConvertSavedModelV1ToMlir(bundle, &context);
-  if (!module_or.status().ok()) {
-    Set_TF_Status_from_Status(status, module_or.status());
-    return "// error";
-  }
-
-  return MlirModuleToString(*module_or.ConsumeValueOrDie(), show_debug_info);
-}
-
-
-string ExperimentalRunPassPipeline(
-    const string &mlir_txt,
-    const string &pass_pipeline,
-    bool show_debug_info,
-    TF_Status* status) {
-  mlir::MLIRContext context;
-  mlir::OwningModuleRef module;
-  {
-    mlir::StatusScopedDiagnosticHandler diagnostic_handler(&context);
-    module = mlir::parseSourceString(mlir_txt, &context);
-    if (!module) {
-      Set_TF_Status_from_Status(status, diagnostic_handler.ConsumeStatus());
-      return "// error";
-    }
-  }
-
-  // Run the pass_pipeline on the module.
-  mlir::PassManager pm(&context);
-  std::string error;
-  llvm::raw_string_ostream error_stream(error);
-  if (failed(mlir::parsePassPipeline(pass_pipeline, pm, error_stream))) {
-    TF_SetStatus(status, TF_INVALID_ARGUMENT,
-                 ("Invalid pass_pipeline: " + error_stream.str()).c_str());
-    return "// error";
-  }
-
-  mlir::StatusScopedDiagnosticHandler diagnostic_handler(&context);
-  if (failed(pm.run(*module))) {
-    Set_TF_Status_from_Status(status, diagnostic_handler.ConsumeStatus());
-    return "// error";
-  }
-  return MlirModuleToString(*module, show_debug_info);
-}
-
-}  // namespace swig
-}  // namespace tensorflow
-
-%}
-
-%ignoreall
-
-%unignore tensorflow;
-%unignore tensorflow::swig;
-%unignore tensorflow::swig::ImportGraphDef;
-%unignore tensorflow::swig::ExperimentalConvertSavedModelToMlir;
-%unignore tensorflow::swig::ExperimentalConvertSavedModelV1ToMlir;
-%unignore tensorflow::swig::ExperimentalRunPassPipeline;
-
-// Wrap this function
-namespace tensorflow {
-namespace swig {
-static string ImportGraphDef(const string &graphdef,
-                             const string &pass_pipeline,
-                             TF_Status* status);
-static string ExperimentalConvertSavedModelToMlir(
-    const string &saved_model_path,
-    const string &exported_names,
-    bool show_debug_info,
-    TF_Status* status);
-static string ExperimentalConvertSavedModelV1ToMlir(
-    const string &saved_model_path,
-    const string &tags,
-    bool show_debug_info,
-    TF_Status* status);
-static string ExperimentalRunPassPipeline(
-    const string &mlir_txt,
-    const string &pass_pipeline,
-    bool show_debug_info,
-    TF_Status* status);
-}  // namespace swig
-}  // namespace tensorflow
-
-%insert("python") %{
-def import_graphdef(graphdef, pass_pipeline):
-  return ImportGraphDef(str(graphdef).encode('utf-8'), pass_pipeline.encode('utf-8')).decode('utf-8');
-
-def experimental_convert_saved_model_to_mlir(saved_model_path,
-                                             exported_names,
-                                             show_debug_info):
-  return ExperimentalConvertSavedModelToMlir(
-    str(saved_model_path).encode('utf-8'),
-    str(exported_names).encode('utf-8'),
-    show_debug_info
-  ).decode('utf-8');
-
-def experimental_convert_saved_model_v1_to_mlir(saved_model_path,
-                                                tags, show_debug_info):
-  return ExperimentalConvertSavedModelV1ToMlir(
-    str(saved_model_path).encode('utf-8'),
-    str(tags).encode('utf-8'),
-    show_debug_info
-  ).decode('utf-8');
-
-def experimental_run_pass_pipeline(mlir_txt, pass_pipeline, show_debug_info):
-  return ExperimentalRunPassPipeline(
-    mlir_txt.encode('utf-8'),
-    pass_pipeline.encode('utf-8'),
-    show_debug_info
-  ).decode('utf-8');
-%}
-
-%unignoreall
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/common.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/common.py
index fd8221c..de61800 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/common.py
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/common.py
@@ -29,7 +29,7 @@
 from absl import logging
 import tensorflow.compat.v2 as tf
 
-from tensorflow.python import pywrap_tensorflow
+from tensorflow.python import pywrap_mlir  # pylint: disable=g-direct-tensorflow-import
 
 # Use /tmp to make debugging the tests easier (see README.md)
 flags.DEFINE_string('save_model_path', '',
@@ -84,13 +84,13 @@
     tf.saved_model.save(
         create_module_fn(), save_model_path, options=save_options)
     logging.info('Saved model to: %s', save_model_path)
-    mlir = pywrap_tensorflow.experimental_convert_saved_model_to_mlir(
+    mlir = pywrap_mlir.experimental_convert_saved_model_to_mlir(
         save_model_path, ','.join(exported_names), show_debug_info)
     # We don't strictly need this, but it serves as a handy sanity check
     # for that API, which is otherwise a bit annoying to test.
     # The canonicalization shouldn't affect these tests in any way.
-    mlir = pywrap_tensorflow.experimental_run_pass_pipeline(
-        mlir, 'canonicalize', show_debug_info)
+    mlir = pywrap_mlir.experimental_run_pass_pipeline(mlir, 'canonicalize',
+                                                      show_debug_info)
     print(mlir)
 
   app.run(app_main)
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/common_v1.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/common_v1.py
index 35858d2..fb29470 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/common_v1.py
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/common_v1.py
@@ -28,7 +28,7 @@
 from absl import logging
 import tensorflow.compat.v1 as tf
 
-from tensorflow.python import pywrap_tensorflow
+from tensorflow.python import pywrap_mlir  # pylint: disable=g-direct-tensorflow-import
 
 # Use /tmp to make debugging the tests easier (see README.md)
 flags.DEFINE_string('save_model_path', '', 'Path to save the model to.')
@@ -80,14 +80,15 @@
     builder.save()
 
     logging.info('Saved model to: %s', save_model_path)
-    mlir = pywrap_tensorflow.experimental_convert_saved_model_v1_to_mlir(
+    mlir = pywrap_mlir.experimental_convert_saved_model_v1_to_mlir(
         save_model_path, ','.join([tf.saved_model.tag_constants.SERVING]),
         show_debug_info)
     # We don't strictly need this, but it serves as a handy sanity check
     # for that API, which is otherwise a bit annoying to test.
     # The canonicalization shouldn't affect these tests in any way.
-    mlir = pywrap_tensorflow.experimental_run_pass_pipeline(
-        mlir, 'tf-standard-pipeline', show_debug_info)
+    mlir = pywrap_mlir.experimental_run_pass_pipeline(mlir,
+                                                      'tf-standard-pipeline',
+                                                      show_debug_info)
     print(mlir)
 
   app.run(app_main)
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 548fc8c..4db47fe 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -1130,6 +1130,7 @@
         ":platform",
         ":pywrap_tensorflow",
         ":pywrap_tfe",
+        ":pywrap_mlir",
         ":random_seed",
         ":sparse_tensor",
         ":tensor_spec",
@@ -5544,7 +5545,6 @@
         "grappler/tf_optimizer.i",
         "lib/core/strings.i",
         "platform/base.i",
-        "//tensorflow/compiler/mlir/python:mlir.i",
     ],
     # add win_def_file for pywrap_tensorflow
     win_def_file = select({
@@ -5572,7 +5572,6 @@
         "//tensorflow/c:tf_status_helper",
         "//tensorflow/c/eager:c_api",
         "//tensorflow/c/eager:c_api_experimental",
-        "//tensorflow/compiler/mlir:passes",
         "//tensorflow/core/distributed_runtime/rpc:grpc_rpc_factory_registration",
         "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
         "//tensorflow/core/distributed_runtime/rpc:grpc_session",
@@ -5593,6 +5592,7 @@
         "//tensorflow/lite/toco/python:toco_python_api",
         "//tensorflow/python/eager:pywrap_tfe_lib",
         "//tensorflow/core/util/tensor_bundle",
+        "//tensorflow/compiler/mlir/python:mlir",
     ] + (tf_additional_lib_deps() +
          tf_additional_plugin_deps()) + if_ngraph([
         "@ngraph_tf//:ngraph_tf",
@@ -5631,6 +5631,7 @@
     "//tensorflow/core/common_runtime/eager:context",  # tfe
     "//tensorflow/core/profiler/lib:profiler_session",  # tfe
     "//tensorflow/c:tf_status_helper",  # tfe
+    "//tensorflow/compiler/mlir/python:mlir",  # mlir
 ]
 
 # Filter the DEF file to reduce the number of symbols to 64K or less.
@@ -7681,6 +7682,34 @@
 )
 
 py_library(
+    name = "pywrap_mlir",
+    srcs = ["pywrap_mlir.py"],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":_pywrap_mlir",
+        ":pywrap_tensorflow",
+    ],
+)
+
+tf_python_pybind_extension(
+    name = "_pywrap_mlir",
+    srcs = ["mlir_wrapper.cc"],
+    hdrs = [
+        "lib/core/safe_ptr.h",
+        "//tensorflow/c:headers",
+        "//tensorflow/c/eager:headers",
+        "//tensorflow/compiler/mlir/python:pywrap_mlir_hdrs",
+    ],
+    module_name = "_pywrap_mlir",
+    deps = [
+        ":pybind11_lib",
+        ":pybind11_status",
+        "//third_party/python_runtime:headers",
+        "@pybind11",
+    ],
+)
+
+py_library(
     name = "pywrap_tfe",
     srcs = ["pywrap_tfe.py"],
     visibility = ["//visibility:public"],
diff --git a/tensorflow/python/compiler/mlir/BUILD b/tensorflow/python/compiler/mlir/BUILD
index ee191e4..fe59213 100644
--- a/tensorflow/python/compiler/mlir/BUILD
+++ b/tensorflow/python/compiler/mlir/BUILD
@@ -10,7 +10,7 @@
     srcs = ["mlir.py"],
     srcs_version = "PY2AND3",
     deps = [
-        "//tensorflow/python:pywrap_tensorflow",
+        "//tensorflow/python:pywrap_mlir",
         "//tensorflow/python:util",
     ],
 )
diff --git a/tensorflow/python/compiler/mlir/mlir.py b/tensorflow/python/compiler/mlir/mlir.py
index 3766b84..84d23c3 100644
--- a/tensorflow/python/compiler/mlir/mlir.py
+++ b/tensorflow/python/compiler/mlir/mlir.py
@@ -18,7 +18,7 @@
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.python import pywrap_tensorflow as import_graphdef
+from tensorflow.python import pywrap_mlir
 from tensorflow.python.util.tf_export import tf_export
 
 
@@ -38,4 +38,4 @@
     Raises a RuntimeError on error.
 
   """
-  return import_graphdef.import_graphdef(graph_def, pass_pipeline)
+  return pywrap_mlir.import_graphdef(graph_def, pass_pipeline)
diff --git a/tensorflow/python/mlir_wrapper.cc b/tensorflow/python/mlir_wrapper.cc
new file mode 100644
index 0000000..3e83566
--- /dev/null
+++ b/tensorflow/python/mlir_wrapper.cc
@@ -0,0 +1,67 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "include/pybind11/pybind11.h"
+#include "tensorflow/c/tf_status.h"
+#include "tensorflow/compiler/mlir/python/mlir.h"
+#include "tensorflow/python/lib/core/pybind11_lib.h"
+#include "tensorflow/python/lib/core/pybind11_status.h"
+#include "tensorflow/python/lib/core/safe_ptr.h"
+
+PYBIND11_MODULE(_pywrap_mlir, m) {
+  m.def("ImportGraphDef",
+        [](const std::string &graphdef, const std::string &pass_pipeline) {
+          tensorflow::Safe_TF_StatusPtr status =
+              tensorflow::make_safe(TF_NewStatus());
+          std::string output =
+              tensorflow::ImportGraphDef(graphdef, pass_pipeline, status.get());
+          tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
+          return output;
+        });
+
+  m.def("ExperimentalConvertSavedModelToMlir",
+        [](const std::string &saved_model_path,
+           const std::string &exported_names, bool show_debug_info) {
+          tensorflow::Safe_TF_StatusPtr status =
+              tensorflow::make_safe(TF_NewStatus());
+          std::string output = tensorflow::ExperimentalConvertSavedModelToMlir(
+              saved_model_path, exported_names, show_debug_info, status.get());
+          tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
+          return output;
+        });
+
+  m.def("ExperimentalConvertSavedModelV1ToMlir",
+        [](const std::string &saved_model_path, const std::string &tags,
+           bool show_debug_info) {
+          tensorflow::Safe_TF_StatusPtr status =
+              tensorflow::make_safe(TF_NewStatus());
+          std::string output =
+              tensorflow::ExperimentalConvertSavedModelV1ToMlir(
+                  saved_model_path, tags, show_debug_info, status.get());
+          tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
+          return output;
+        });
+
+  m.def("ExperimentalRunPassPipeline",
+        [](const std::string &mlir_txt, const std::string &pass_pipeline,
+           bool show_debug_info) {
+          tensorflow::Safe_TF_StatusPtr status =
+              tensorflow::make_safe(TF_NewStatus());
+          std::string output = tensorflow::ExperimentalRunPassPipeline(
+              mlir_txt, pass_pipeline, show_debug_info, status.get());
+          tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
+          return output;
+        });
+};
diff --git a/tensorflow/python/pywrap_mlir.py b/tensorflow/python/pywrap_mlir.py
new file mode 100644
index 0000000..73c69a8
--- /dev/null
+++ b/tensorflow/python/pywrap_mlir.py
@@ -0,0 +1,49 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Python module for MLIR functions exported by pybind11."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=invalid-import-order, g-bad-import-order, wildcard-import, unused-import, undefined-variable
+from tensorflow.python import pywrap_tensorflow
+from tensorflow.python._pywrap_mlir import *
+
+
+def import_graphdef(graphdef, pass_pipeline):
+  return ImportGraphDef(
+      str(graphdef).encode('utf-8'),
+      pass_pipeline.encode('utf-8'))
+
+
+def experimental_convert_saved_model_to_mlir(saved_model_path, exported_names,
+                                             show_debug_info):
+  return ExperimentalConvertSavedModelToMlir(
+      str(saved_model_path).encode('utf-8'),
+      str(exported_names).encode('utf-8'), show_debug_info)
+
+
+def experimental_convert_saved_model_v1_to_mlir(saved_model_path, tags,
+                                                show_debug_info):
+  return ExperimentalConvertSavedModelV1ToMlir(
+      str(saved_model_path).encode('utf-8'),
+      str(tags).encode('utf-8'), show_debug_info)
+
+
+def experimental_run_pass_pipeline(mlir_txt, pass_pipeline, show_debug_info):
+  return ExperimentalRunPassPipeline(
+      mlir_txt.encode('utf-8'), pass_pipeline.encode('utf-8'),
+      show_debug_info)
diff --git a/tensorflow/python/tensorflow.i b/tensorflow/python/tensorflow.i
index 6ac9dcf..f3deea2 100644
--- a/tensorflow/python/tensorflow.i
+++ b/tensorflow/python/tensorflow.i
@@ -24,8 +24,6 @@
 %include "tensorflow/python/grappler/tf_optimizer.i"
 %include "tensorflow/python/grappler/cost_analyzer.i"
 
-%include "tensorflow/compiler/mlir/python/mlir.i"
-
 // TODO(slebedev): This is a temporary workaround for projects implicitly
 // relying on TensorFlow exposing tensorflow::Status.
 %unignoreall
diff --git a/tensorflow/tools/def_file_filter/symbols_pybind.txt b/tensorflow/tools/def_file_filter/symbols_pybind.txt
index e657edc..694a1a47 100644
--- a/tensorflow/tools/def_file_filter/symbols_pybind.txt
+++ b/tensorflow/tools/def_file_filter/symbols_pybind.txt
@@ -189,3 +189,9 @@
 
 [context] # tfe
 tensorflow::EagerContext::WaitForAndCloseRemoteContexts
+
+[mlir] # mlir
+tensorflow::ExperimentalRunPassPipeline
+tensorflow::ExperimentalConvertSavedModelV1ToMlir
+tensorflow::ExperimentalConvertSavedModelToMlir
+tensorflow::ImportGraphDef