Add ability to raise custom exceptions from xla::Status.

This is intended for PjRT clients that may want to raise an exception other than `XlaRuntimeError`. Example usage:

```cpp
void RaiseCustomException(xla::Status status) {
  throw CustomException("");
}
xla::Status status = ...;
xla::status_casters_util::SetFunctionPointerAsPayload(status, &RaiseCustomException);
```
PiperOrigin-RevId: 451153222
diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD
index 2b7b287..3c2e398 100644
--- a/tensorflow/compiler/xla/python/BUILD
+++ b/tensorflow/compiler/xla/python/BUILD
@@ -115,6 +115,7 @@
     features = ["-use_header_modules"],
     deps = [
         ":exceptions",
+        ":status_casters_util",
         "//tensorflow/compiler/xla:status",
         "//tensorflow/compiler/xla:statusor",
         "@pybind11",
@@ -122,6 +123,21 @@
 )
 
 cc_library(
+    name = "status_casters_util",
+    srcs = ["status_casters_util.cc"],
+    hdrs = ["status_casters_util.h"],
+    compatible_with = [],
+    copts = [
+        "-fexceptions",
+        "-fno-strict-aliasing",
+    ],
+    features = ["-use_header_modules"],
+    deps = [
+        "//tensorflow/compiler/xla:status",
+    ],
+)
+
+cc_library(
     name = "exceptions",
     hdrs = ["exceptions.h"],
     compatible_with = [],
@@ -648,6 +664,19 @@
     ],
 )
 
+pybind_extension(
+    name = "status_casters_example",
+    testonly = True,
+    srcs = ["status_casters_example.cc"],
+    deps = [
+        ":status_casters",
+        ":status_casters_util",
+        "//tensorflow/compiler/xla:status",
+        "//tensorflow/compiler/xla:statusor",
+        "@pybind11",
+    ],
+)
+
 # TODO(phawkins): the configuration settings here are overly confusing. The right fix is to split
 # xla_extension.so so that each backend is a separate plugin, however that must wait for a clean
 # ABI separation between devices.
@@ -754,3 +783,16 @@
                "//conditions:default": [],
            }),
 )
+
+py_test(
+    name = "status_casters_test",
+    srcs = ["status_casters_test.py"],
+    python_version = "PY3",
+    tags = ["no_oss"],
+    deps = [
+        ":status_casters_example",
+        ":xla_client",
+        ":xla_extension",
+        "@absl_py//absl/testing:absltest",
+    ] + xla_py_test_deps(),
+)
diff --git a/tensorflow/compiler/xla/python/status_casters.h b/tensorflow/compiler/xla/python/status_casters.h
index 5c6d48f..220b982 100644
--- a/tensorflow/compiler/xla/python/status_casters.h
+++ b/tensorflow/compiler/xla/python/status_casters.h
@@ -18,6 +18,7 @@
 
 #include "pybind11/pybind11.h"
 #include "tensorflow/compiler/xla/python/exceptions.h"
+#include "tensorflow/compiler/xla/python/status_casters_util.h"
 #include "tensorflow/compiler/xla/status.h"
 #include "tensorflow/compiler/xla/statusor.h"
 
@@ -52,6 +53,13 @@
   static handle cast(xla::Status src, return_value_policy /* policy */,
                      handle /* parent */) {
     if (!src.ok()) {
+      std::optional<xla::status_casters_util::FunctionPtr> function =
+          xla::status_casters_util::GetFunctionPointerFromPayload(src);
+
+      if (function.has_value()) {
+        function.value()(src);  // Should throw; otherwise falls through below.
+      }
+
       throw xla::XlaRuntimeError(src);
     }
     return none().inc_ref();
@@ -62,6 +70,7 @@
 struct type_caster<xla::StatusOr<T>> {
  public:
   using value_conv = make_caster<T>;
+  using status_conv = make_caster<xla::Status>;
 
   PYBIND11_TYPE_CASTER(xla::StatusOr<T>,
                        _("StatusOr[") + value_conv::name + _("]"));
@@ -69,7 +78,7 @@
   static handle cast(xla::StatusOr<T> src, return_value_policy policy,
                      handle parent) {
     if (!src.ok()) {
+      return status_conv::cast(src.status(), policy, parent);  // payload-aware
     }
     return value_conv::cast(std::forward<xla::StatusOr<T>>(src).ValueOrDie(),
                             policy, parent);
diff --git a/tensorflow/compiler/xla/python/status_casters_example.cc b/tensorflow/compiler/xla/python/status_casters_example.cc
new file mode 100644
index 0000000..f8cf8ec
--- /dev/null
+++ b/tensorflow/compiler/xla/python/status_casters_example.cc
@@ -0,0 +1,70 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <stdexcept>
+
+#include "pybind11/pybind11.h"
+#include "tensorflow/compiler/xla/python/status_casters.h"
+#include "tensorflow/compiler/xla/python/status_casters_util.h"
+#include "tensorflow/compiler/xla/status.h"
+#include "tensorflow/compiler/xla/statusor.h"
+
+namespace xla {
+
+namespace py = ::pybind11;
+
+namespace {
+
+class XlaTestError : public std::runtime_error {
+ public:
+  using std::runtime_error::runtime_error;
+};
+
+void ThrowXlaTestError(xla::Status status) {
+  DCHECK(!status.ok());
+  throw XlaTestError("XlaTestError");
+}
+
+xla::Status GetXlaStatus() {
+  return xla::Status(tensorflow::error::Code::UNKNOWN, "XlaStatus");
+}
+
+xla::StatusOr<int> GetXlaStatusOr() { return GetXlaStatus(); }
+
+xla::Status GetXlaStatusWithXlaTestError() {
+  xla::Status status(tensorflow::error::Code::UNKNOWN, "XlaTestError");
+  status_casters_util::SetFunctionPointerAsPayload(status, &ThrowXlaTestError);
+
+  return status;
+}
+
+xla::StatusOr<int> GetXlaStatusOrWithXlaTestError() {
+  return GetXlaStatusWithXlaTestError();
+}
+
+}  // namespace
+
+PYBIND11_MODULE(status_casters_example, m) {
+  py::register_exception<XlaTestError>(m, "XlaTestError", PyExc_RuntimeError);
+
+  m.def("raise_xla_status", &GetXlaStatus);
+  m.def("raise_xla_status_or", &GetXlaStatusOr);
+
+  m.def("raise_xla_status_with_xla_test_error", &GetXlaStatusWithXlaTestError);
+  m.def("raise_xla_status_or_with_xla_test_error",
+        &GetXlaStatusOrWithXlaTestError);
+}
+
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/python/status_casters_test.py b/tensorflow/compiler/xla/python/status_casters_test.py
new file mode 100644
index 0000000..a33f881
--- /dev/null
+++ b/tensorflow/compiler/xla/python/status_casters_test.py
@@ -0,0 +1,41 @@
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Python xla::Status casters."""
+
+from absl.testing import absltest
+
+from tensorflow.compiler.xla.python import status_casters_example
+from tensorflow.compiler.xla.python import xla_client
+
+
+class StatusCastersTest(absltest.TestCase):
+
+  def test_xla_runtime_error(self):
+    with self.assertRaises(xla_client.XlaRuntimeError):
+      status_casters_example.raise_xla_status()
+
+    with self.assertRaises(xla_client.XlaRuntimeError):
+      status_casters_example.raise_xla_status_or()
+
+  def test_xla_test_error(self):
+    with self.assertRaises(status_casters_example.XlaTestError):
+      status_casters_example.raise_xla_status_with_xla_test_error()
+
+    with self.assertRaises(status_casters_example.XlaTestError):
+      status_casters_example.raise_xla_status_or_with_xla_test_error()
+
+
+if __name__ == '__main__':
+  absltest.main()
diff --git a/tensorflow/compiler/xla/python/status_casters_util.cc b/tensorflow/compiler/xla/python/status_casters_util.cc
new file mode 100644
index 0000000..86c4c7a
--- /dev/null
+++ b/tensorflow/compiler/xla/python/status_casters_util.cc
@@ -0,0 +1,63 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/python/status_casters_util.h"
+
+#include <cstring>
+#include <string>
+
+namespace xla {
+namespace status_casters_util {
+
+namespace {
+
+// Payload URL under which the function pointer bytes are stored.
+const char kStatusPayloadUrl[] = "xla.status_casters_util.function";
+
+// Copies the raw bytes of the function pointer into a Cord. The resulting
+// payload is only meaningful within the current process.
+absl::Cord SerializeFunctionPointer(FunctionPtr fn) {
+  return absl::Cord(
+      absl::string_view(reinterpret_cast<const char*>(&fn), sizeof(fn)));
+}
+
+// Reconstructs the function pointer from the payload bytes. Uses memcpy
+// rather than a reinterpret_cast dereference to avoid a potentially
+// misaligned, aliasing-violating read (undefined behavior).
+FunctionPtr DeserializeFunctionPointer(const absl::Cord& payload) {
+  std::string bytes(payload);
+  FunctionPtr fn = nullptr;
+  std::memcpy(&fn, bytes.data(), sizeof(fn));
+  return fn;
+}
+
+}  // namespace
+
+void SetFunctionPointerAsPayload(xla::Status& status, FunctionPtr fn) {
+  status.SetPayload(kStatusPayloadUrl, SerializeFunctionPointer(fn));
+}
+
+std::optional<FunctionPtr> GetFunctionPointerFromPayload(
+    const xla::Status& status) {
+  std::optional<absl::Cord> payload = status.GetPayload(kStatusPayloadUrl);
+
+  if (!payload.has_value()) {
+    return std::nullopt;
+  }
+
+  return DeserializeFunctionPointer(payload.value());
+}
+
+}  // namespace status_casters_util
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/python/status_casters_util.h b/tensorflow/compiler/xla/python/status_casters_util.h
new file mode 100644
index 0000000..0a9da01
--- /dev/null
+++ b/tensorflow/compiler/xla/python/status_casters_util.h
@@ -0,0 +1,52 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_STATUS_CASTERS_UTIL_H_
+#define TENSORFLOW_COMPILER_XLA_PYTHON_STATUS_CASTERS_UTIL_H_
+
+#include <optional>
+
+#include "tensorflow/compiler/xla/status.h"
+
+namespace xla {
+namespace status_casters_util {
+
+using FunctionPtr = void (*)(xla::Status);
+
+// Sets the function pointer `fn` as payload in `status`. The function must
+// accept `xla::Status` as a parameter, and its intended use is to cast it to a
+// custom exception raised in Python code.
+//
+// WARNING: The payload stores the raw bytes of an in-process function
+// pointer. It is only meaningful inside the process that set it and must
+// never be serialized or sent across process/RPC boundaries.
+//
+// Example:
+//   void RaiseCustomException(xla::Status) {
+//     throw MyCustomException("");
+//   }
+//   xla::Status status = ...;
+//   SetFunctionPointerAsPayload(status, &RaiseCustomException);
+void SetFunctionPointerAsPayload(xla::Status& status, FunctionPtr fn);
+
+// Gets the function pointer from the `status` payload, returns std::nullopt if
+// the function pointer was not set.
+std::optional<FunctionPtr> GetFunctionPointerFromPayload(
+    const xla::Status& status);
+
+}  // namespace status_casters_util
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_PYTHON_STATUS_CASTERS_UTIL_H_