[PJRT:C] PjRt C API for platform_version() function.
PiperOrigin-RevId: 456812812
diff --git a/tensorflow/compiler/xla/pjrt/c/pjrt_c_api.h b/tensorflow/compiler/xla/pjrt/c/pjrt_c_api.h
index a623435..eb07781 100644
--- a/tensorflow/compiler/xla/pjrt/c/pjrt_c_api.h
+++ b/tensorflow/compiler/xla/pjrt/c/pjrt_c_api.h
@@ -116,6 +116,24 @@
typedef PJRT_Error* PJRT_Client_Process_Index(
PJRT_Client_Process_Index_Args* args);
+typedef struct {
+ size_t struct_size;
+ void* priv;
+ PJRT_Client* client;
+ // `platform_version` has the same lifetime as `client`. It's owned by
+ // `client`.
+ const char* platform_version; // out
+ size_t platform_version_size; // out
+} PJRT_Client_PlatformVersion_Args;
+
+const size_t PJRT_Client_PlatformVersion_Args_STRUCT_SIZE =
+ PJRT_STRUCT_SIZE(PJRT_Client_PlatformVersion_Args, platform_version_size);
+
+// Returns a string containing human-readable, platform-specific version info
+// (e.g. the CUDA version on GPU or libtpu version on Cloud TPU).
+typedef PJRT_Error* PJRT_Client_PlatformVersion(
+ PJRT_Client_PlatformVersion_Args* args);
+
// -------------------------------- API access ---------------------------------
#define PJRT_API_STRUCT_FIELD(fn_type) fn_type* fn_type
@@ -131,6 +149,7 @@
PJRT_API_STRUCT_FIELD(PJRT_Client_Destroy);
PJRT_API_STRUCT_FIELD(PJRT_Client_PlatformName);
PJRT_API_STRUCT_FIELD(PJRT_Client_Process_Index);
+ PJRT_API_STRUCT_FIELD(PJRT_Client_PlatformVersion);
} PJRT_Api;
const size_t PJRT_Api_STRUCT_SIZE =
diff --git a/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc b/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc
index e81c9b4..1fbb3c3 100644
--- a/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc
+++ b/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc
@@ -87,4 +87,15 @@
return nullptr;
}
+PJRT_Error* PJRT_Client_PlatformVersion(
+ PJRT_Client_PlatformVersion_Args* args) {
+ PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes(
+ "PJRT_CLient_PlatformVersion_Args",
+ PJRT_Client_PlatformVersion_Args_STRUCT_SIZE, args->struct_size));
+ absl::string_view platform_version = args->client->client->platform_version();
+ args->platform_version = platform_version.data();
+ args->platform_version_size = platform_version.size();
+ return nullptr;
+}
+
} // namespace pjrt
diff --git a/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h b/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h
index a5b373d..38d9a47 100644
--- a/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h
+++ b/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h
@@ -39,6 +39,7 @@
PJRT_Error* PJRT_Client_Destroy(PJRT_Client_Destroy_Args* args);
PJRT_Error* PJRT_Client_PlatformName(PJRT_Client_PlatformName_Args* args);
PJRT_Error* PJRT_Client_Process_Index(PJRT_Client_Process_Index_Args* args);
+PJRT_Error* PJRT_Client_PlatformVersion(PJRT_Client_PlatformVersion_Args* args);
// Helper macros and functions
diff --git a/tensorflow/compiler/xla/pjrt/pjrt_c_api_client.cc b/tensorflow/compiler/xla/pjrt/pjrt_c_api_client.cc
index 831f1f9..80c926d 100644
--- a/tensorflow/compiler/xla/pjrt/pjrt_c_api_client.cc
+++ b/tensorflow/compiler/xla/pjrt/pjrt_c_api_client.cc
@@ -82,6 +82,20 @@
return process_index_args.process_index;
}
+absl::string_view PjRtCApiClient::platform_version() const {
+ PJRT_Client_PlatformVersion_Args args;
+ args.struct_size = PJRT_Client_PlatformVersion_Args_STRUCT_SIZE;
+ args.priv = nullptr;
+ args.client = c_client_;
+ PJRT_Error* error = c_api_->PJRT_Client_PlatformVersion(&args);
+ // TODO(b/236710439)
+ CHECK(error == nullptr);
+
+ absl::string_view platform_version(args.platform_version,
+ args.platform_version_size);
+ return platform_version;
+}
+
StatusOr<std::optional<std::string>> PjRtCApiClient::ExecutableFingerprint(
const PjRtExecutable& executable) const {
return wrapped_->ExecutableFingerprint(
diff --git a/tensorflow/compiler/xla/pjrt/pjrt_c_api_client.h b/tensorflow/compiler/xla/pjrt/pjrt_c_api_client.h
index cc3ecc0..8d5a768 100644
--- a/tensorflow/compiler/xla/pjrt/pjrt_c_api_client.h
+++ b/tensorflow/compiler/xla/pjrt/pjrt_c_api_client.h
@@ -127,9 +127,7 @@
absl::string_view platform_name() const override;
- absl::string_view platform_version() const override {
- return wrapped_->platform_version();
- }
+ absl::string_view platform_version() const override;
PjRtRuntimeType runtime_type() const override {
return wrapped_->runtime_type();