Export the toco functions from C++ to Python with pybind11 instead of swig. This is part of a larger effort to deprecate swig and eventually with modularization break pywrap_tensorflow into smaller components. It will also make exporting C++ ops to Python significantly easier. XLA is using the pybind11 macros already. Please refer to https://github.com/tensorflow/community/blob/master/rfcs/20190208-pybind11.md for more information.

We are adding `toco::` to the exported namespaces for pywrap_tensorflow's shared object. A few downstream modules also require a previous import of pywrap tensorflow, because the wrapper depends on the shared library. See https://github.com/tensorflow/tensorflow/pull/31955 for additional information.

PiperOrigin-RevId: 276096778
Change-Id: I042f488c36b00818b2344fb39c36cad97cee6eb8
diff --git a/tensorflow/lite/python/BUILD b/tensorflow/lite/python/BUILD
index 7540a4d..b4e0dd0 100644
--- a/tensorflow/lite/python/BUILD
+++ b/tensorflow/lite/python/BUILD
@@ -209,6 +209,7 @@
     ],
     srcs_version = "PY2AND3",
     deps = [
+        "//tensorflow/python:_pywrap_toco_api",
         "//tensorflow/python:pywrap_tensorflow",
         "//tensorflow/python:util",
     ],
diff --git a/tensorflow/lite/python/wrap_toco.py b/tensorflow/lite/python/wrap_toco.py
index e9b0176..213f31c 100644
--- a/tensorflow/lite/python/wrap_toco.py
+++ b/tensorflow/lite/python/wrap_toco.py
@@ -17,7 +17,11 @@
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.python import pywrap_tensorflow
+# We need to import pywrap_tensorflow prior to the toco wrapper.
+# pylint: disable=invalud-import-order,g-bad-import-order
+from tensorflow.python import pywrap_tensorflow  # pylint: disable=unused-import
+from tensorflow.python import _pywrap_toco_api
+
 
 # TODO(b/137402359): Remove lazy loading wrapper
 
@@ -25,7 +29,7 @@
 def wrapped_toco_convert(model_flags_str, toco_flags_str, input_data_str,
                          debug_info_str, enable_mlir_converter):
   """Wraps TocoConvert with lazy loader."""
-  return pywrap_tensorflow.TocoConvert(
+  return _pywrap_toco_api.TocoConvert(
       model_flags_str,
       toco_flags_str,
       input_data_str,
@@ -36,4 +40,4 @@
 
 def wrapped_get_potentially_supported_ops():
   """Wraps TocoGetPotentiallySupportedOps with lazy loader."""
-  return pywrap_tensorflow.TocoGetPotentiallySupportedOps()
+  return _pywrap_toco_api.TocoGetPotentiallySupportedOps()
diff --git a/tensorflow/lite/toco/python/BUILD b/tensorflow/lite/toco/python/BUILD
index 641ea7a..488f333 100644
--- a/tensorflow/lite/toco/python/BUILD
+++ b/tensorflow/lite/toco/python/BUILD
@@ -16,6 +16,16 @@
     ],
 )
 
+filegroup(
+    name = "toco_python_api_hdrs",
+    srcs = [
+        "toco_python_api.h",
+    ],
+    visibility = [
+        "//tensorflow/python:__subpackages__",
+    ],
+)
+
 cc_library(
     name = "toco_python_api",
     srcs = ["toco_python_api.cc"],
@@ -49,6 +59,7 @@
         if_false = [],
         if_true = ["//tensorflow/compiler/mlir/lite/python:graphdef_to_tfl_flatbuffer"],
     ),
+    alwayslink = True,
 )
 
 # Compatibility stub. Remove when internal customers moved.
@@ -61,7 +72,7 @@
         "//tensorflow/lite:__subpackages__",
     ],
     deps = [
-        "//tensorflow/python:pywrap_tensorflow",
+        "//tensorflow/python:_pywrap_toco_api",
     ],
 )
 
@@ -71,6 +82,7 @@
     python_version = "PY2",
     srcs_version = "PY2AND3",
     deps = [
+        "//tensorflow/python:_pywrap_toco_api",
         "//tensorflow/python:platform",
         "//tensorflow/python:pywrap_tensorflow",
     ],
@@ -89,10 +101,3 @@
         "no_pip",
     ],
 )
-
-exports_files(
-    ["toco.i"],
-    visibility = [
-        "//tensorflow/python:__subpackages__",
-    ],
-)
diff --git a/tensorflow/lite/toco/python/tensorflow_wrap_toco.py b/tensorflow/lite/toco/python/tensorflow_wrap_toco.py
index d70b043..ceef9b8 100644
--- a/tensorflow/lite/toco/python/tensorflow_wrap_toco.py
+++ b/tensorflow/lite/toco/python/tensorflow_wrap_toco.py
@@ -20,5 +20,5 @@
 # TODO(aselle): Remove once no clients internally depend on this.
 
 # pylint: disable=unused-import
-from tensorflow.python.pywrap_tensorflow import TocoConvert
+from tensorflow.python._pywrap_toco_api import TocoConvert
 # pylint: enable=unused-import
diff --git a/tensorflow/lite/toco/python/toco.i b/tensorflow/lite/toco/python/toco.i
deleted file mode 100644
index ee2b36c..0000000
--- a/tensorflow/lite/toco/python/toco.i
+++ /dev/null
@@ -1,49 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-%include "std_string.i"
-
-%{
-#include "tensorflow/lite/toco/python/toco_python_api.h"
-%}
-
-// The TensorFlow exception handler releases the GIL with
-// Py_BEGIN_ALLOW_THREADS. Remove that because these function use the Python
-// API to decode inputs.
-%noexception toco::TocoConvert;
-%noexception toco::TocoGetPotentiallySupportedOps;
-
-namespace toco {
-
-// Convert a model represented in `input_contents`. `model_flags_proto`
-// describes model parameters. `toco_flags_proto` describes conversion
-// parameters (see relevant .protos for more information). Returns a string
-// representing the contents of the converted model. When extended_return
-// flag is set to true returns a dictionary that contains string representation
-// of the converted model and some statistics like arithmetic ops count.
-// `debug_info_str` contains the `GraphDebugInfo` proto. When
-// `enable_mlir_converter` is True, use MLIR-based conversion instead of
-// TOCO conversion.
-PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw,
-                        PyObject* toco_flags_proto_txt_raw,
-                        PyObject* input_contents_txt_raw,
-                        bool extended_return = false,
-                        PyObject* debug_info_txt_raw = nullptr,
-                        bool enable_mlir_converter = false);
-
-// Returns a list of names of all ops potentially supported by tflite.
-PyObject* TocoGetPotentiallySupportedOps();
-
-} // namespace toco
diff --git a/tensorflow/lite/toco/python/toco_from_protos.py b/tensorflow/lite/toco/python/toco_from_protos.py
index a669214..4da0bb9 100644
--- a/tensorflow/lite/toco/python/toco_from_protos.py
+++ b/tensorflow/lite/toco/python/toco_from_protos.py
@@ -19,7 +19,11 @@
 
 import argparse
 import sys
-from tensorflow.python import pywrap_tensorflow
+
+# We need to import pywrap_tensorflow prior to the toco wrapper.
+# pylint: disable=invalud-import-order,g-bad-import-order
+from tensorflow.python import pywrap_tensorflow  # pylint: disable=unused-import
+from tensorflow.python import _pywrap_toco_api
 from tensorflow.python.platform import app
 
 FLAGS = None
@@ -43,7 +47,7 @@
 
   enable_mlir_converter = FLAGS.enable_mlir_converter
 
-  output_str = pywrap_tensorflow.TocoConvert(
+  output_str = _pywrap_toco_api.TocoConvert(
       model_str,
       toco_str,
       input_str,
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 90ceda1..97c3210 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -583,6 +583,20 @@
     ],
 )
 
+tf_python_pybind_extension(
+    name = "_pywrap_toco_api",
+    srcs = [
+        "lite/toco_python_api_wrapper.cc",
+    ],
+    hdrs = ["//tensorflow/lite/toco/python:toco_python_api_hdrs"],
+    module_name = "_pywrap_toco_api",
+    deps = [
+        "//tensorflow/python:pybind11_lib",
+        "//third_party/python_runtime:headers",
+        "@pybind11",
+    ],
+)
+
 cc_library(
     name = "cpp_python_util",
     srcs = ["util/util.cc"],
@@ -5188,7 +5202,6 @@
         "util/traceme.i",
         "util/transform_graph.i",
         "//tensorflow/compiler/mlir/python:mlir.i",
-        "//tensorflow/lite/toco/python:toco.i",
     ],
     # add win_def_file for pywrap_tensorflow
     win_def_file = select({
@@ -5292,6 +5305,7 @@
         "//tensorflow/core:core_cpu_impl",  # device_lib
         ":py_exception_registry",  # py_exception_registry
         ":kernel_registry",
+        "//tensorflow/lite/toco/python:toco_python_api",  # toco
     ],
     outs = ["pybind_symbol_target_libs_file.txt"],
     cmd = select({
diff --git a/tensorflow/python/lite/toco_python_api_wrapper.cc b/tensorflow/python/lite/toco_python_api_wrapper.cc
new file mode 100644
index 0000000..9199c79
--- /dev/null
+++ b/tensorflow/python/lite/toco_python_api_wrapper.cc
@@ -0,0 +1,57 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "include/pybind11/pybind11.h"
+#include "tensorflow/lite/toco/python/toco_python_api.h"
+#include "tensorflow/python/lib/core/pybind11_lib.h"
+
+namespace py = pybind11;
+
+PYBIND11_MODULE(_pywrap_toco_api, m) {
+  m.def(
+      "TocoConvert",
+      [](py::object model_flags_proto_txt_raw,
+         py::object toco_flags_proto_txt_raw, py::object input_contents_txt_raw,
+         bool extended_return, py::object debug_info_txt_raw,
+         bool enable_mlir_converter) {
+        return tensorflow::pyo_or_throw(toco::TocoConvert(
+            model_flags_proto_txt_raw.ptr(), toco_flags_proto_txt_raw.ptr(),
+            input_contents_txt_raw.ptr(), extended_return,
+            debug_info_txt_raw.ptr(), enable_mlir_converter));
+      },
+      py::arg("model_flags_proto_txt_raw"), py::arg("toco_flags_proto_txt_raw"),
+      py::arg("input_contents_txt_raw"), py::arg("extended_return") = false,
+      py::arg("debug_info_txt_raw") = py::none(),
+      py::arg("enable_mlir_converter") = false,
+      R"pbdoc(
+      Convert a model represented in `input_contents`. `model_flags_proto`
+      describes model parameters. `toco_flags_proto` describes conversion
+      parameters (see relevant .protos for more information). Returns a string
+      representing the contents of the converted model. When extended_return
+      flag is set to true returns a dictionary that contains string representation
+      of the converted model and some statistics like arithmetic ops count.
+      `debug_info_str` contains the `GraphDebugInfo` proto. When
+      `enable_mlir_converter` is True, tuse MLIR-based conversion instead of
+      TOCO conversion.
+    )pbdoc");
+  m.def(
+      "TocoGetPotentiallySupportedOps",
+      []() {
+        return tensorflow::pyo_or_throw(toco::TocoGetPotentiallySupportedOps());
+      },
+      R"pbdoc(
+      Returns a list of names of all ops potentially supported by tflite.
+    )pbdoc");
+}
diff --git a/tensorflow/python/tensorflow.i b/tensorflow/python/tensorflow.i
index 75629bc..950add0 100644
--- a/tensorflow/python/tensorflow.i
+++ b/tensorflow/python/tensorflow.i
@@ -29,8 +29,6 @@
 
 %include "tensorflow/python/lib/core/bfloat16.i"
 
-%include "tensorflow/lite/toco/python/toco.i"
-
 %include "tensorflow/python/lib/io/file_io.i"
 
 %include "tensorflow/python/framework/python_op_gen.i"
diff --git a/tensorflow/tf_exported_symbols.lds b/tensorflow/tf_exported_symbols.lds
index 0463233..bed2ab4 100644
--- a/tensorflow/tf_exported_symbols.lds
+++ b/tensorflow/tf_exported_symbols.lds
@@ -1,4 +1,5 @@
 *tensorflow*
+*toco*
 *perftools*gputools*
 *tf_*
 *TF_*
diff --git a/tensorflow/tf_version_script.lds b/tensorflow/tf_version_script.lds
index 563d178..f74644b 100644
--- a/tensorflow/tf_version_script.lds
+++ b/tensorflow/tf_version_script.lds
@@ -1,6 +1,7 @@
 tensorflow {
   global:
     *tensorflow*;
+    *toco*;
     *perftools*gputools*;
     *TF_*;
     *TFE_*;
diff --git a/tensorflow/tools/def_file_filter/def_file_filter.py.tpl b/tensorflow/tools/def_file_filter/def_file_filter.py.tpl
index 30c1fd2..f803d94d 100644
--- a/tensorflow/tools/def_file_filter/def_file_filter.py.tpl
+++ b/tensorflow/tools/def_file_filter/def_file_filter.py.tpl
@@ -70,6 +70,7 @@
                         r"^(TFE_\w*)$|"
                         r"nsync::|"
                         r"tensorflow::|"
+                        r"toco::|"
                         r"functor::|"
                         r"perftools::gputools")
 
diff --git a/tensorflow/tools/def_file_filter/symbols_pybind.txt b/tensorflow/tools/def_file_filter/symbols_pybind.txt
index 072c449..e1f4d19 100644
--- a/tensorflow/tools/def_file_filter/symbols_pybind.txt
+++ b/tensorflow/tools/def_file_filter/symbols_pybind.txt
@@ -83,3 +83,7 @@
 [kernel_registry] # kernel_registry
 tensorflow::swig::TryFindKernelClass
 
+[toco_python_api] # toco_python_api
+toco::TocoConvert
+toco::TocoGetPotentiallySupportedOps
+