Added pybind11 type casters for tensorflow::Status and tensorflow::StringView
These are needed to simplify migrating some SWIG sources to pybind, e.g.
python/client/events_writer.i and python/lib/io/file_io.i
PiperOrigin-RevId: 268728509
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 7ae73b3..5d811df 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -370,6 +370,28 @@
)
cc_library(
+ name = "pybind11_absl",
+ hdrs = ["lib/core/pybind11_absl.h"],
+ features = ["-parse_headers"],
+ deps = [
+ "//tensorflow/core/platform:stringpiece",
+ "@pybind11",
+ ],
+)
+
+cc_library(
+ name = "pybind11_status",
+ hdrs = ["lib/core/pybind11_status.h"],
+ features = ["-parse_headers"],
+ deps = [
+ ":py_exception_registry",
+ "//tensorflow/core:lib",
+ "//third_party/python_runtime:headers",
+ "@pybind11",
+ ],
+)
+
+cc_library(
name = "kernel_registry",
srcs = ["util/kernel_registry.cc"],
hdrs = ["util/kernel_registry.h"],
diff --git a/tensorflow/python/lib/core/py_exception_registry.h b/tensorflow/python/lib/core/py_exception_registry.h
index 2b0f23b..d761ab4 100644
--- a/tensorflow/python/lib/core/py_exception_registry.h
+++ b/tensorflow/python/lib/core/py_exception_registry.h
@@ -18,6 +18,7 @@
#include <map>
#include "tensorflow/c/c_api.h"
+#include "tensorflow/core/lib/core/error_codes.pb.h"
#include "tensorflow/core/platform/logging.h"
#ifndef PyObject_HEAD
@@ -60,6 +61,10 @@
// called before using this function. `code` should not be TF_OK.
static PyObject* Lookup(TF_Code code);
+ static inline PyObject* Lookup(error::Code code) {
+ return Lookup(static_cast<TF_Code>(code));
+ }
+
private:
static PyExceptionRegistry* singleton_;
PyExceptionRegistry() = default;
diff --git a/tensorflow/python/lib/core/pybind11_absl.h b/tensorflow/python/lib/core/pybind11_absl.h
new file mode 100644
index 0000000..09f9681
--- /dev/null
+++ b/tensorflow/python/lib/core/pybind11_absl.h
@@ -0,0 +1,40 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_PYTHON_LIB_CORE_PYBIND11_ABSL_H_
+#define TENSORFLOW_PYTHON_LIB_CORE_PYBIND11_ABSL_H_
+
+#include "pybind11/pybind11.h"
+#include "tensorflow/core/platform/stringpiece.h"
+
+#if !defined(PYBIND11_CPP17)
+
+namespace pybind11 {
+namespace detail {
+
+// Convert between tensorflow::StringPiece (aka absl::string_view) and Python.
+//
+// pybind11 supports std::string_view, and absl::string_view is meant to be a
+// drop-in replacement for std::string_view, so we can just use the built in
+// implementation.
+template <>
+struct type_caster<tensorflow::StringPiece>
+ : string_caster<tensorflow::StringPiece, true> {};
+
+} // namespace detail
+} // namespace pybind11
+
+#endif // !defined(PYBIND11_CPP17)
+#endif // TENSORFLOW_PYTHON_LIB_CORE_PYBIND11_ABSL_H_
diff --git a/tensorflow/python/lib/core/pybind11_status.h b/tensorflow/python/lib/core/pybind11_status.h
new file mode 100644
index 0000000..ca3baeb
--- /dev/null
+++ b/tensorflow/python/lib/core/pybind11_status.h
@@ -0,0 +1,66 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_PYTHON_LIB_CORE_PYBIND11_STATUS_H_
+#define TENSORFLOW_PYTHON_LIB_CORE_PYBIND11_STATUS_H_
+
+#include <Python.h>
+
+#include "pybind11/pybind11.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/python/lib/core/py_exception_registry.h"
+
+namespace tensorflow {
+
+namespace py = ::pybind11;
+
+namespace pybind11 {
+
+inline void MaybeRaiseFromStatus(const Status& status) {
+ if (!status.ok()) {
+ // TODO(slebedev): translate to builtin exception classes instead?
+ auto* exc_type = PyExceptionRegistry::Lookup(status.code());
+ PyErr_SetObject(
+ exc_type,
+ py::make_tuple(nullptr, nullptr, status.error_message()).ptr());
+ throw py::error_already_set();
+ }
+}
+
+} // namespace pybind11
+} // namespace tensorflow
+
+namespace pybind11 {
+namespace detail {
+
+// Raise an exception if a given status is not OK, otherwise return None.
+//
+// The correspondence between status codes and exception classes is given
+// by PyExceptionRegistry. Note that the registry should be initialized
+// in order to be used, see PyExceptionRegistry::Init.
+template <>
+struct type_caster<::tensorflow::Status> {
+ public:
+ PYBIND11_TYPE_CASTER(::tensorflow::Status, _("Status"));
+ static handle cast(::tensorflow::Status status, return_value_policy, handle) {
+ tensorflow::pybind11::MaybeRaiseFromStatus(status);
+ return none();
+ }
+};
+
+} // namespace detail
+} // namespace pybind11
+
+#endif // TENSORFLOW_PYTHON_LIB_CORE_PYBIND11_STATUS_H_