// blob: 80c926d3b87fec693ada491effe10ade3f7f8ef8 [file] [log] [blame]
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/pjrt/pjrt_c_api_client.h"

#include <cstddef>
#include <string>
#include <utility>

// TODO(skyewm): remove when everything goes through C API
#include "tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h"
#include "tensorflow/core/tpu/pjrt_api.h"
namespace xla {
// Constructs a client on top of the C API handle `c_client`.
//
// `c_api` is the function table used for later C API calls, `c_client` is the
// C client handle (released in the destructor), and `devices` are the C-API
// device wrappers, which this client takes ownership of.
//
// `devices[i]` is paired with `wrapped_->devices()[i]`, so both lists must have
// the same length; this is CHECKed. Every addressable device is additionally
// stored in `addressable_devices_` at the slot given by its
// `local_hardware_id()`; that id must be non-negative and unique (CHECKed).
PjRtCApiClient::PjRtCApiClient(
    const PJRT_Api* c_api, PJRT_Client* c_client,
    std::vector<std::unique_ptr<PjRtCApiDevice>> devices)
    : c_api_(c_api), c_client_(c_client), owned_devices_(std::move(devices)) {
  wrapped_ = c_client_->client.get();

  // The loop below indexes `wrapped_->devices()` with positions taken from
  // `owned_devices_`; a length mismatch would read out of bounds.
  CHECK_EQ(owned_devices_.size(), wrapped_->devices().size());

  for (size_t i = 0; i < owned_devices_.size(); ++i) {
    PjRtCApiDevice* const device = owned_devices_[i].get();
    wrapped_device_map_[wrapped_->devices()[i]] = device;
    device->SetClient(this);
    devices_.push_back(device);

    if (!device->IsAddressable()) {
      continue;
    }

    const int hardware_id = device->local_hardware_id();
    // A negative id would wrap to a huge unsigned index, skip the resize and
    // then index out of bounds, so reject it explicitly.
    CHECK_GE(hardware_id, 0);
    const size_t slot = static_cast<size_t>(hardware_id);
    if (slot >= addressable_devices_.size()) {
      addressable_devices_.resize(slot + 1);
    }
    // Two addressable devices must not report the same local hardware id.
    CHECK(addressable_devices_[slot] == nullptr) << hardware_id;
    addressable_devices_[slot] = device;
  }
}
// Returns the platform name reported by PJRT_Client_PlatformName for the
// underlying C client. The result is a view over the (pointer, size) pair the
// C API wrote into the argument struct. A non-null error from the C API is
// fatal (CHECK).
absl::string_view PjRtCApiClient::platform_name() const {
  PJRT_Client_PlatformName_Args name_args;
  name_args.struct_size = PJRT_Client_PlatformName_Args_STRUCT_SIZE;
  name_args.priv = nullptr;
  name_args.client = c_client_;

  PJRT_Error* const call_error = c_api_->PJRT_Client_PlatformName(&name_args);
  // TODO(b/236710439): handle error
  CHECK(call_error == nullptr);

  return absl::string_view(name_args.platform_name,
                           name_args.platform_name_size);
}
// Releases the underlying C client by calling PJRT_Client_Destroy on
// `c_client_`. A non-null error from the C API is fatal (CHECK).
PjRtCApiClient::~PjRtCApiClient() {
  PJRT_Client_Destroy_Args destroy_args;
  destroy_args.client = c_client_;
  destroy_args.priv = nullptr;
  destroy_args.struct_size = PJRT_Client_Destroy_Args_STRUCT_SIZE;

  PJRT_Error* const destroy_error = c_api_->PJRT_Client_Destroy(&destroy_args);
  // TODO(skyewm): handle error
  CHECK(destroy_error == nullptr);
}
int PjRtCApiClient::process_index() const {
PJRT_Client_Process_Index_Args process_index_args;
process_index_args.struct_size = PJRT_Client_Process_Index_Args_STRUCT_SIZE;
process_index_args.priv = nullptr;
process_index_args.client = c_client_;
PJRT_Error* error = c_api_->PJRT_Client_Process_Index(&process_index_args);
// TODO(b/236710439)
CHECK(error == nullptr);
return process_index_args.process_index;
}
// Returns the platform version reported by PJRT_Client_PlatformVersion for the
// underlying C client, as a view over the (pointer, size) pair the C API wrote
// into the argument struct. A non-null error from the C API is fatal (CHECK).
absl::string_view PjRtCApiClient::platform_version() const {
  PJRT_Client_PlatformVersion_Args version_args;
  version_args.client = c_client_;
  version_args.priv = nullptr;
  version_args.struct_size = PJRT_Client_PlatformVersion_Args_STRUCT_SIZE;

  PJRT_Error* const call_error =
      c_api_->PJRT_Client_PlatformVersion(&version_args);
  // TODO(b/236710439)
  CHECK(call_error == nullptr);

  return absl::string_view(version_args.platform_version,
                           version_args.platform_version_size);
}
// Forwards to the wrapped client's ExecutableFingerprint, passing the
// executable obtained from PjRtCApiExecutable::GetWrapped(&executable).
StatusOr<std::optional<std::string>> PjRtCApiClient::ExecutableFingerprint(
    const PjRtExecutable& executable) const {
  const PjRtExecutable* const inner_executable =
      PjRtCApiExecutable::GetWrapped(&executable);
  return wrapped_->ExecutableFingerprint(*inner_executable);
}
// Forwards to the wrapped client's SerializeExecutable, passing the executable
// obtained from PjRtCApiExecutable::GetWrapped(&executable).
StatusOr<std::string> PjRtCApiClient::SerializeExecutable(
    const PjRtExecutable& executable) const {
  const PjRtExecutable* const inner_executable =
      PjRtCApiExecutable::GetWrapped(&executable);
  return wrapped_->SerializeExecutable(*inner_executable);
}
// Forwards to the wrapped client's UnsafeBufferPointer, passing the buffer
// obtained from PjRtCApiBuffer::GetWrapped(buffer).
StatusOr<std::uintptr_t> PjRtCApiClient::UnsafeBufferPointer(
    PjRtBuffer* buffer) {
  PjRtBuffer* const inner_buffer = PjRtCApiBuffer::GetWrapped(buffer);
  return wrapped_->UnsafeBufferPointer(inner_buffer);
}
// Wraps a successfully created executable in a PjRtCApiExecutable owned by
// this client. If `to_wrap` holds an error, that error is returned unchanged.
StatusOr<std::unique_ptr<PjRtExecutable>> PjRtCApiClient::WrapExecutable(
    StatusOr<std::unique_ptr<PjRtExecutable>> to_wrap) {
  TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtExecutable> inner_executable,
                      std::move(to_wrap));
  auto c_api_executable = std::make_unique<PjRtCApiExecutable>(
      this, std::move(inner_executable));
  // Convert explicitly to the base-class pointer expected by the return type.
  return std::unique_ptr<PjRtExecutable>(std::move(c_api_executable));
}
// Wraps a successfully created buffer in a PjRtCApiBuffer owned by this
// client. If `to_wrap` holds an error, that error is returned unchanged.
StatusOr<std::unique_ptr<PjRtBuffer>> PjRtCApiClient::WrapBuffer(
    StatusOr<std::unique_ptr<PjRtBuffer>> to_wrap) {
  TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtBuffer> inner_buffer,
                      std::move(to_wrap));
  auto c_api_buffer =
      std::make_unique<PjRtCApiBuffer>(this, std::move(inner_buffer));
  // Convert explicitly to the base-class pointer expected by the return type.
  return std::unique_ptr<PjRtBuffer>(std::move(c_api_buffer));
}
// Builds a C-API executable that takes ownership of `wrapped` and belongs to
// `client`. `addressable_devices_` is filled, in order, with the result of
// `client_->GetCApiDevice` for each of the wrapped executable's addressable
// devices.
PjRtCApiExecutable::PjRtCApiExecutable(PjRtCApiClient* client,
                                       std::unique_ptr<PjRtExecutable> wrapped)
    : client_(client), wrapped_(std::move(wrapped)) {
  // The final size is known up front, so allocate once.
  addressable_devices_.reserve(wrapped_->addressable_devices().size());
  for (PjRtDevice* wrapped_device : wrapped_->addressable_devices()) {
    auto* const c_api_device = client_->GetCApiDevice(wrapped_device);
    addressable_devices_.push_back(c_api_device);
  }
}
// Runs the wrapped executable and re-wraps its results for the C-API client.
//
// Each argument list in `argument_handles` is converted with
// PjRtCApiBuffer::GetWrappedVector before being passed, together with
// `options` and `returned_futures`, to the wrapped executable's Execute. If
// that call fails, its error is returned. Otherwise every returned buffer is
// replaced in place by a PjRtCApiBuffer that owns it and is tied to `client_`.
StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
PjRtCApiExecutable::Execute(
    absl::Span<const std::vector<PjRtBuffer*>> argument_handles,
    const ExecuteOptions& options,
    std::optional<std::vector<PjRtFuture<Status>>>& returned_futures) {
  std::vector<std::vector<PjRtBuffer*>> wrapped_args;
  // One output list per input list: allocate once instead of growing
  // incrementally.
  wrapped_args.reserve(argument_handles.size());
  for (const std::vector<PjRtBuffer*>& args : argument_handles) {
    wrapped_args.push_back(PjRtCApiBuffer::GetWrappedVector(args));
  }

  TF_ASSIGN_OR_RETURN(
      std::vector<std::vector<std::unique_ptr<PjRtBuffer>>> out,
      wrapped_->Execute(wrapped_args, options, returned_futures));

  for (auto& buffer_list : out) {
    for (std::unique_ptr<PjRtBuffer>& buffer : buffer_list) {
      buffer = std::make_unique<PjRtCApiBuffer>(client_, std::move(buffer));
    }
  }
  return out;
}
// Runs the wrapped executable's ExecuteSharded and re-wraps its results.
//
// `argument_handles` is converted with PjRtCApiBuffer::GetWrappedVector and
// `device` with PjRtCApiDevice::GetWrapped; `options`, `returned_future` and
// `fill_future` are passed through unchanged. On failure the wrapped call's
// error is returned. On success each result buffer is replaced in place by a
// PjRtCApiBuffer that owns it and is tied to `client_`.
StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
PjRtCApiExecutable::ExecuteSharded(
    absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
    const ExecuteOptions& options,
    std::optional<PjRtFuture<Status>>& returned_future, bool fill_future) {
  std::vector<PjRtBuffer*> inner_args =
      PjRtCApiBuffer::GetWrappedVector(argument_handles);
  PjRtDevice* const inner_device = PjRtCApiDevice::GetWrapped(device);

  TF_ASSIGN_OR_RETURN(
      std::vector<std::unique_ptr<PjRtBuffer>> results,
      wrapped_->ExecuteSharded(inner_args, inner_device, options,
                               returned_future, fill_future));

  for (auto& result : results) {
    result = std::make_unique<PjRtCApiBuffer>(client_, std::move(result));
  }
  return results;
}
// Runs the wrapped executable's ExecutePortable and re-wraps its results.
//
// `argument_handles` is converted with PjRtCApiBuffer::GetWrappedVector and
// `device` with PjRtCApiDevice::GetWrapped; `options`, `returned_future` and
// `fill_future` are passed through unchanged. On failure the wrapped call's
// error is returned. On success each result buffer is replaced in place by a
// PjRtCApiBuffer that owns it and is tied to `client_`.
StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
PjRtCApiExecutable::ExecutePortable(
    absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
    const ExecuteOptions& options,
    std::optional<PjRtFuture<Status>>& returned_future, bool fill_future) {
  std::vector<PjRtBuffer*> inner_args =
      PjRtCApiBuffer::GetWrappedVector(argument_handles);
  PjRtDevice* const inner_device = PjRtCApiDevice::GetWrapped(device);

  TF_ASSIGN_OR_RETURN(
      std::vector<std::unique_ptr<PjRtBuffer>> results,
      wrapped_->ExecutePortable(inner_args, inner_device, options,
                                returned_future, fill_future));

  for (auto& result : results) {
    result = std::make_unique<PjRtCApiBuffer>(client_, std::move(result));
  }
  return results;
}
// Creates a PjRtCApiClient backed by the C API returned from
// tensorflow::tpu::PjrtApi().
//
// Steps: fetch the C API table, create a C client through PJRT_Client_Create,
// build one PjRtCApiDevice per device of the client held inside the C client
// handle, and hand all of it to a new PjRtCApiClient.
//
// A null API table or a non-null error from PJRT_Client_Create is currently
// fatal (CHECK) rather than being reported through the returned StatusOr.
StatusOr<std::unique_ptr<PjRtClient>> GetCApiClient() {
  const PJRT_Api* const api = tensorflow::tpu::PjrtApi();
  // TODO(skyewm): make status
  CHECK(api != nullptr);

  PJRT_Client_Create_Args create_args;
  create_args.priv = nullptr;
  create_args.struct_size = PJRT_Client_Create_Args_STRUCT_SIZE;
  PJRT_Error* const create_error = api->PJRT_Client_Create(&create_args);
  // TODO(skyewm): handle error
  CHECK(create_error == nullptr);

  PJRT_Client* const c_client = create_args.client;
  PjRtClient* const inner_client = c_client->client.get();

  // Mirror every device of the inner client with a C-API device wrapper.
  std::vector<std::unique_ptr<PjRtCApiDevice>> c_api_devices;
  c_api_devices.reserve(inner_client->devices().size());
  for (PjRtDevice* inner_device : inner_client->devices()) {
    c_api_devices.push_back(std::make_unique<PjRtCApiDevice>(inner_device));
  }

  auto c_api_client = std::make_unique<PjRtCApiClient>(
      api, c_client, std::move(c_api_devices));
  // Convert explicitly to the base-class pointer expected by the return type.
  return std::unique_ptr<PjRtClient>(std::move(c_api_client));
}
} // namespace xla