Export the GetPythonWrappers functions from C++ to Python with pybind11 instead of swig. This is part of a larger effort to deprecate swig and eventually with modularization break pywrap_tensorflow into smaller components. It will also make exporting C++ ops to Python significantly easier. XLA is using the pybind11 macros already. Please refer to https://github.com/tensorflow/community/blob/master/rfcs/20190208-pybind11.md for more information.
PiperOrigin-RevId: 281122988
Change-Id: I47eb899954f8e4728fb9f69e8f0b3eb95fc33257
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 9de7eda..9ca3816 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -103,6 +103,7 @@
":_pywrap_events_writer",
":_pywrap_kernel_registry",
":_pywrap_py_exception_registry",
+ ":_pywrap_python_op_gen",
":_pywrap_quantize_training",
":_pywrap_stacktrace_handler",
":_pywrap_stat_summarizer",
@@ -903,6 +904,29 @@
alwayslink = 1,
)
+cc_header_only_library(
+ name = "python_op_gen_headers_lib",
+ extra_deps = [
+ "//tensorflow/core:protos_all_cc",
+ ],
+ deps = [
+ ":python_op_gen",
+ ],
+)
+
+tf_python_pybind_extension(
+ name = "_pywrap_python_op_gen",
+ srcs = ["framework/python_op_gen_wrapper.cc"],
+ module_name = "_pywrap_python_op_gen",
+ deps = [
+ ":pybind11_absl",
+ ":pybind11_lib",
+ ":python_op_gen_headers_lib",
+ "//third_party/python_runtime:headers",
+ "@pybind11",
+ ],
+)
+
cc_library(
name = "python_op_gen_main",
srcs = ["framework/python_op_gen_main.cc"],
@@ -991,6 +1015,7 @@
":_pywrap_kernel_registry",
":_pywrap_py_exception_registry",
":_pywrap_py_func", # TODO(b/142001480): remove once the bug is fixed.
+ ":_pywrap_python_op_gen",
":_pywrap_quantize_training",
":_pywrap_stacktrace_handler",
":_pywrap_stat_summarizer",
@@ -5272,7 +5297,6 @@
swig_includes = [
"client/tf_session.i",
"client/tf_sessionrun_wrapper.i",
- "framework/python_op_gen.i",
"grappler/cluster.i",
"grappler/cost_analyzer.i",
"grappler/item.i",
@@ -5394,6 +5418,7 @@
":ndarray_tensor", # checkpoint_reader
":numpy_lib", # checkpoint_reader
":safe_ptr", # checkpoint_reader
+ "//tensorflow/python:python_op_gen", # python_op_gen
"//tensorflow/core/util/tensor_bundle", # checkpoint_reader
],
outs = ["pybind_symbol_target_libs_file.txt"],
diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py
index 41f85c1..10034a9 100644
--- a/tensorflow/python/__init__.py
+++ b/tensorflow/python/__init__.py
@@ -53,6 +53,7 @@
from tensorflow.python import _pywrap_util_port
from tensorflow.python import _pywrap_stat_summarizer
from tensorflow.python import _pywrap_py_exception_registry
+from tensorflow.python import _pywrap_python_op_gen
from tensorflow.python import _pywrap_kernel_registry
from tensorflow.python import _pywrap_quantize_training
from tensorflow.python import _pywrap_transform_graph
diff --git a/tensorflow/python/framework/load_library.py b/tensorflow/python/framework/load_library.py
index f173304..1306e94 100644
--- a/tensorflow/python/framework/load_library.py
+++ b/tensorflow/python/framework/load_library.py
@@ -25,6 +25,7 @@
import platform
import sys
+from tensorflow.python import _pywrap_python_op_gen
from tensorflow.python import pywrap_tensorflow as py_tf
from tensorflow.python.lib.io import file_io
from tensorflow.python.util import deprecation
@@ -56,7 +57,8 @@
"""
lib_handle = py_tf.TF_LoadLibrary(library_filename)
try:
- wrappers = py_tf.GetPythonWrappers(py_tf.TF_GetOpList(lib_handle))
+ wrappers = _pywrap_python_op_gen.GetPythonWrappers(
+ py_tf.TF_GetOpList(lib_handle))
finally:
# Delete the library handle to release any memory held in C
# that are no longer needed.
@@ -156,4 +158,3 @@
errno.ENOENT,
'The file or folder to load kernel libraries from does not exist.',
library_location)
-
diff --git a/tensorflow/python/framework/python_op_gen.cc b/tensorflow/python/framework/python_op_gen.cc
index 3941ec1..093afc1 100644
--- a/tensorflow/python/framework/python_op_gen.cc
+++ b/tensorflow/python/framework/python_op_gen.cc
@@ -1073,9 +1073,8 @@
}
string GetPythonWrappers(const char* op_list_buf, size_t op_list_len) {
- string op_list_str(op_list_buf, op_list_len);
OpList ops;
- ops.ParseFromString(op_list_str);
+ ops.ParseFromArray(op_list_buf, op_list_len);
ApiDefMap api_def_map(ops);
return GetPythonOps(ops, api_def_map, {});
diff --git a/tensorflow/python/framework/python_op_gen.i b/tensorflow/python/framework/python_op_gen.i
deleted file mode 100644
index 26ec4e8..0000000
--- a/tensorflow/python/framework/python_op_gen.i
+++ /dev/null
@@ -1,41 +0,0 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-%include "tensorflow/python/platform/base.i"
-
-%{
-#include "tensorflow/python/framework/python_op_gen.h"
-%}
-
-// Input typemap for GetPythonWrappers.
-// Accepts a python object of 'bytes' type, and converts it to
-// a const char* pointer and size_t length. The default typemap
-// going from python bytes to const char* tries to decode the
-// contents from utf-8 to unicode for Python version >= 3, but
-// we want the bytes to be uninterpreted.
-%typemap(in) (const char* op_list_buf, size_t op_list_len) {
- char* c_string;
- Py_ssize_t py_size;
- if (PyBytes_AsStringAndSize($input, &c_string, &py_size) == -1) {
- SWIG_fail;
- }
- $1 = c_string;
- $2 = static_cast<size_t>(py_size);
-}
-
-
-%ignoreall;
-%unignore tensorflow::GetPythonWrappers;
-%include "tensorflow/python/framework/python_op_gen.h"
diff --git a/tensorflow/python/framework/python_op_gen_wrapper.cc b/tensorflow/python/framework/python_op_gen_wrapper.cc
new file mode 100644
index 0000000..9418432
--- /dev/null
+++ b/tensorflow/python/framework/python_op_gen_wrapper.cc
@@ -0,0 +1,34 @@
+
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <Python.h>
+
+#include "include/pybind11/pybind11.h"
+#include "tensorflow/python/framework/python_op_gen.h"
+
+namespace py = pybind11;
+
+PYBIND11_MODULE(_pywrap_python_op_gen, m) {
+ m.def("GetPythonWrappers", [](py::bytes input) {
+ char* c_string;
+ Py_ssize_t py_size;
+ if (PyBytes_AsStringAndSize(input.ptr(), &c_string, &py_size) == -1) {
+ throw py::error_already_set();
+ }
+ return py::bytes(
+ tensorflow::GetPythonWrappers(c_string, static_cast<size_t>(py_size)));
+ });
+};
diff --git a/tensorflow/python/tensorflow.i b/tensorflow/python/tensorflow.i
index ac52688..7985b9c 100644
--- a/tensorflow/python/tensorflow.i
+++ b/tensorflow/python/tensorflow.i
@@ -34,8 +34,4 @@
%include "tensorflow/python/grappler/cost_analyzer.i"
%include "tensorflow/python/grappler/model_analyzer.i"
-%include "tensorflow/python/framework/python_op_gen.i"
-
-
-
%include "tensorflow/compiler/mlir/python/mlir.i"
diff --git a/tensorflow/tools/def_file_filter/symbols_pybind.txt b/tensorflow/tools/def_file_filter/symbols_pybind.txt
index 7ef7a54..30a844d 100644
--- a/tensorflow/tools/def_file_filter/symbols_pybind.txt
+++ b/tensorflow/tools/def_file_filter/symbols_pybind.txt
@@ -112,3 +112,5 @@
tensorflow::detail::PyDecrefDeleter
tensorflow::make_safe
+[python_op_gen] # python_op_gen
+tensorflow::GetPythonWrappers