Add ability to raise custom exceptions from xla::Status.
This is intended for PjRT clients that may want to raise a different exception other than `XlaRuntimeError`. Example usage:
```cpp
void RaiseCustomException(xla::Status) {
throw CustomException("");
}
xla::Status status = ...;
xla::status_casters_util::SetFunctionPointerAsPayload(status, &RaiseCustomException);
```
PiperOrigin-RevId: 451153222
diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD
index 2b7b287..3c2e398 100644
--- a/tensorflow/compiler/xla/python/BUILD
+++ b/tensorflow/compiler/xla/python/BUILD
@@ -115,6 +115,7 @@
features = ["-use_header_modules"],
deps = [
":exceptions",
+ ":status_casters_util",
"//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla:statusor",
"@pybind11",
@@ -122,6 +123,21 @@
)
cc_library(
+ name = "status_casters_util",
+ srcs = ["status_casters_util.cc"],
+ hdrs = ["status_casters_util.h"],
+ compatible_with = [],
+ copts = [
+ "-fexceptions",
+ "-fno-strict-aliasing",
+ ],
+ features = ["-use_header_modules"],
+ deps = [
+ "//tensorflow/compiler/xla:status",
+ ],
+)
+
+cc_library(
name = "exceptions",
hdrs = ["exceptions.h"],
compatible_with = [],
@@ -648,6 +664,19 @@
],
)
+pybind_extension(
+ name = "status_casters_example",
+ testonly = True,
+ srcs = ["status_casters_example.cc"],
+ deps = [
+ ":status_casters",
+ ":status_casters_util",
+ "//tensorflow/compiler/xla:status",
+ "//tensorflow/compiler/xla:statusor",
+ "@pybind11",
+ ],
+)
+
# TODO(phawkins): the configuration settings here are overly confusing. The right fix is to split
# xla_extension.so so that each backend is a separate plugin, however that must wait for a clean
# ABI separation between devices.
@@ -754,3 +783,16 @@
"//conditions:default": [],
}),
)
+
+py_test(
+ name = "status_casters_test",
+ srcs = ["status_casters_test.py"],
+ python_version = "PY3",
+ tags = ["no_oss"],
+ deps = [
+ ":status_casters_example",
+ ":xla_client",
+ ":xla_extension",
+ "@absl_py//absl/testing:absltest",
+ ] + xla_py_test_deps(),
+)
diff --git a/tensorflow/compiler/xla/python/status_casters.h b/tensorflow/compiler/xla/python/status_casters.h
index 5c6d48f..220b982 100644
--- a/tensorflow/compiler/xla/python/status_casters.h
+++ b/tensorflow/compiler/xla/python/status_casters.h
@@ -18,6 +18,7 @@
#include "pybind11/pybind11.h"
#include "tensorflow/compiler/xla/python/exceptions.h"
+#include "tensorflow/compiler/xla/python/status_casters_util.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
@@ -52,6 +53,13 @@
static handle cast(xla::Status src, return_value_policy /* policy */,
handle /* parent */) {
if (!src.ok()) {
+ std::optional<xla::status_casters_util::FunctionPtr> function =
+ xla::status_casters_util::GetFunctionPointerFromPayload(src);
+
+ if (function.has_value()) {
+ function.value()(src); // This is supposed to throw a custom exception
+ }
+
throw xla::XlaRuntimeError(src);
}
return none().inc_ref();
@@ -62,6 +70,7 @@
struct type_caster<xla::StatusOr<T>> {
public:
using value_conv = make_caster<T>;
+ using status_conv = make_caster<xla::Status>;
PYBIND11_TYPE_CASTER(xla::StatusOr<T>,
_("StatusOr[") + value_conv::name + _("]"));
@@ -69,7 +78,7 @@
static handle cast(xla::StatusOr<T> src, return_value_policy policy,
handle parent) {
if (!src.ok()) {
- throw xla::XlaRuntimeError(src.status());
+ return status_conv::cast(src.status(), policy, parent);
}
return value_conv::cast(std::forward<xla::StatusOr<T>>(src).ValueOrDie(),
policy, parent);
diff --git a/tensorflow/compiler/xla/python/status_casters_example.cc b/tensorflow/compiler/xla/python/status_casters_example.cc
new file mode 100644
index 0000000..f8cf8ec
--- /dev/null
+++ b/tensorflow/compiler/xla/python/status_casters_example.cc
@@ -0,0 +1,70 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <stdexcept>
+
+#include "pybind11/pybind11.h"
+#include "tensorflow/compiler/xla/python/status_casters.h"
+#include "tensorflow/compiler/xla/python/status_casters_util.h"
+#include "tensorflow/compiler/xla/status.h"
+#include "tensorflow/compiler/xla/statusor.h"
+
+namespace xla {
+
+namespace py = ::pybind11;
+
+namespace {
+
+class XlaTestError : public std::runtime_error {
+ public:
+ using std::runtime_error::runtime_error;
+};
+
+void ThrowXlaTestError(xla::Status status) {
+ DCHECK(!status.ok());
+ throw XlaTestError("XlaTestError");
+}
+
+xla::Status GetXlaStatus() {
+ return xla::Status(tensorflow::error::Code::UNKNOWN, "XlaStatus");
+}
+
+xla::StatusOr<int> GetXlaStatusOr() { return GetXlaStatus(); }
+
+xla::Status GetXlaStatusWithXlaTestError() {
+ xla::Status status(tensorflow::error::Code::UNKNOWN, "XlaTestError");
+ status_casters_util::SetFunctionPointerAsPayload(status, &ThrowXlaTestError);
+
+ return status;
+}
+
+xla::StatusOr<int> GetXlaStatusOrWithXlaTestError() {
+ return GetXlaStatusWithXlaTestError();
+}
+
+} // namespace
+
+PYBIND11_MODULE(status_casters_example, m) {
+ py::register_exception<XlaTestError>(m, "XlaTestError", PyExc_RuntimeError);
+
+ m.def("raise_xla_status", &GetXlaStatus);
+ m.def("raise_xla_status_or", &GetXlaStatusOr);
+
+ m.def("raise_xla_status_with_xla_test_error", &GetXlaStatusWithXlaTestError);
+ m.def("raise_xla_status_or_with_xla_test_error",
+ &GetXlaStatusOrWithXlaTestError);
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/python/status_casters_test.py b/tensorflow/compiler/xla/python/status_casters_test.py
new file mode 100644
index 0000000..a33f881
--- /dev/null
+++ b/tensorflow/compiler/xla/python/status_casters_test.py
@@ -0,0 +1,41 @@
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Python xla::Status casters."""
+
+from absl.testing import absltest
+
+from tensorflow.compiler.xla.python import status_casters_example
+from tensorflow.compiler.xla.python import xla_client
+
+
+class StatusCastersTest(absltest.TestCase):
+
+ def test_xla_runtime_error(self):
+ with self.assertRaises(xla_client.XlaRuntimeError):
+ status_casters_example.raise_xla_status()
+
+ with self.assertRaises(xla_client.XlaRuntimeError):
+ status_casters_example.raise_xla_status_or()
+
+ def test_xla_test_error(self):
+ with self.assertRaises(status_casters_example.XlaTestError):
+ status_casters_example.raise_xla_status_with_xla_test_error()
+
+ with self.assertRaises(status_casters_example.XlaTestError):
+ status_casters_example.raise_xla_status_or_with_xla_test_error()
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/tensorflow/compiler/xla/python/status_casters_util.cc b/tensorflow/compiler/xla/python/status_casters_util.cc
new file mode 100644
index 0000000..86c4c7a
--- /dev/null
+++ b/tensorflow/compiler/xla/python/status_casters_util.cc
@@ -0,0 +1,55 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/python/status_casters_util.h"
+
+#include <string>
+
+namespace xla {
+namespace status_casters_util {
+
+namespace {
+
+const char kStatusPayloadUrl[] = "xla.status_casters_util.function";
+
+absl::Cord SerializeFunctionPointer(FunctionPtr fn) {
+ return absl::Cord(
+ absl::string_view(reinterpret_cast<const char*>(&fn), sizeof(fn)));
+}
+
+FunctionPtr DeserializeFunctionPointer(const absl::Cord& payload) {
+ return *reinterpret_cast<FunctionPtr*>(
+ const_cast<char*>(std::string(payload).data()));
+}
+
+} // namespace
+
+void SetFunctionPointerAsPayload(xla::Status& status, FunctionPtr fn) {
+ status.SetPayload(kStatusPayloadUrl, SerializeFunctionPointer(fn));
+}
+
+std::optional<FunctionPtr> GetFunctionPointerFromPayload(
+ const xla::Status& status) {
+ std::optional<absl::Cord> payload = status.GetPayload(kStatusPayloadUrl);
+
+ if (!payload.has_value()) {
+ return std::nullopt;
+ }
+
+ return DeserializeFunctionPointer(payload.value());
+}
+
+} // namespace status_casters_util
+} // namespace xla
diff --git a/tensorflow/compiler/xla/python/status_casters_util.h b/tensorflow/compiler/xla/python/status_casters_util.h
new file mode 100644
index 0000000..0a9da01
--- /dev/null
+++ b/tensorflow/compiler/xla/python/status_casters_util.h
@@ -0,0 +1,48 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_STATUS_CASTERS_UTIL_H_
+#define TENSORFLOW_COMPILER_XLA_PYTHON_STATUS_CASTERS_UTIL_H_
+
+#include <optional>
+
+#include "tensorflow/compiler/xla/status.h"
+
+namespace xla {
+namespace status_casters_util {
+
+using FunctionPtr = void (*)(xla::Status);
+
+// Sets the function pointer `fn` as payload in `status`. The function must
+// accept `xla::Status` as a parameter, and its intended use is to cast it to a
+// custom exception raised in Python code.
+//
+// Example:
+// void RaiseCustomException(xla::Status) {
+// throw MyCustomException("");
+// }
+// xla::Status status = ...;
+// SetFunctionPointerAsPayload(status, &RaiseCustomException);
+void SetFunctionPointerAsPayload(xla::Status& status, FunctionPtr fn);
+
+// Gets the function pointer from the `status` payload, returns std::nullopt if
+// the function pointer was not set.
+std::optional<FunctionPtr> GetFunctionPointerFromPayload(
+ const xla::Status& status);
+
+} // namespace status_casters_util
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_PYTHON_STATUS_CASTERS_UTIL_H_