blob: 8d5a76820e398b83479a723695dffb589e762c48 [file] [log] [blame]
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_PJRT_PJRT_C_API_CLIENT_H_
#define TENSORFLOW_COMPILER_XLA_PJRT_PJRT_C_API_CLIENT_H_
#include <functional>
#include <string>
#include <utility>
#include "tensorflow/compiler/xla/pjrt/c/pjrt_c_api.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/core/platform/casts.h"
namespace xla {
class PjRtCApiDevice : public PjRtDevice {
public:
explicit PjRtCApiDevice(PjRtDevice* wrapped) : wrapped_(wrapped) {}
// Must set client exactly once.
void SetClient(PjRtClient* client) {
CHECK(client_ == nullptr) << ToString();
client_ = client;
}
PjRtClient* client() const override { return client_; }
bool IsAddressable() const override { return wrapped_->IsAddressable(); }
int id() const override { return wrapped_->id(); }
int process_index() const override { return wrapped_->process_index(); }
int local_hardware_id() const override {
return wrapped_->local_hardware_id();
}
absl::string_view device_kind() const override {
return wrapped_->device_kind();
}
std::string DebugString() const override { return wrapped_->DebugString(); }
std::string ToString() const override {
return absl::StrCat("PjRtCApiDevice(wrapped=", wrapped_->ToString(), ")");
}
Status TransferToInfeed(const LiteralSlice& literal) override {
return wrapped_->TransferToInfeed(literal);
}
Status TransferFromOutfeed(MutableBorrowingLiteral literal) override {
return wrapped_->TransferFromOutfeed(std::move(literal));
}
std::unique_ptr<ScopedAsyncTrackingEvent> CreateAsyncTrackingEvent(
absl::string_view description) const override {
return wrapped_->CreateAsyncTrackingEvent(description);
}
const absl::flat_hash_map<std::string, PjRtDeviceAttribute>& Attributes()
const override {
return wrapped_->Attributes();
}
PjRtDevice* wrapped() const { return wrapped_; }
static PjRtDevice* GetWrapped(PjRtDevice* c_api_device) {
return tensorflow::down_cast<PjRtCApiDevice*>(c_api_device)->wrapped();
}
private:
PjRtClient* client_ = nullptr;
PjRtDevice* wrapped_;
};
// PjRtClient intended to be implemented on top of the PJRT C API (`c_api_`,
// `c_client_`). Until that is complete, most calls are forwarded to a wrapped
// PjRtClient (`wrapped_`, see the TODO on that member). Devices are translated
// with PjRtCApiDevice::GetWrapped() / GetCApiDevice(), and results are passed
// through WrapBuffer() / WrapExecutable(). Operations not supported yet return
// an Unimplemented status.
class PjRtCApiClient : public PjRtClient {
 public:
  PjRtCApiClient(const PJRT_Api* c_api, PJRT_Client* c_client,
                 std::vector<std::unique_ptr<PjRtCApiDevice>> devices);
  ~PjRtCApiClient() override;

  int process_index() const override;
  int device_count() const override { return wrapped_->device_count(); }
  int addressable_device_count() const override {
    return wrapped_->addressable_device_count();
  }
  absl::Span<PjRtDevice* const> devices() const override { return devices_; }
  absl::Span<PjRtDevice* const> addressable_devices() const override {
    return addressable_devices_;
  }

  // Looks the device up on the wrapped client, then maps the result back to
  // this client's PjRtCApiDevice (CHECK-fails if no mapping exists).
  StatusOr<PjRtDevice*> LookupDevice(int device_id) const override {
    TF_ASSIGN_OR_RETURN(PjRtDevice * wrapped_device,
                        wrapped_->LookupDevice(device_id));
    return GetCApiDevice(wrapped_device);
  }
  StatusOr<PjRtDevice*> LookupAddressableDevice(
      int local_hardware_id) const override {
    TF_ASSIGN_OR_RETURN(PjRtDevice * wrapped_device,
                        wrapped_->LookupAddressableDevice(local_hardware_id));
    return GetCApiDevice(wrapped_device);
  }

  PjRtPlatformId platform_id() const override {
    return wrapped_->platform_id();
  }
  absl::string_view platform_name() const override;
  absl::string_view platform_version() const override;
  PjRtRuntimeType runtime_type() const override {
    return wrapped_->runtime_type();
  }
  StatusOr<DeviceAssignment> GetDefaultDeviceAssignment(
      int num_replicas, int num_partitions) const override {
    return wrapped_->GetDefaultDeviceAssignment(num_replicas, num_partitions);
  }
  StatusOr<std::unique_ptr<HloCostAnalysis>> GetHloCostAnalysis() override {
    return wrapped_->GetHloCostAnalysis();
  }

  // Compilation is delegated to the wrapped client, and the result is passed
  // through WrapExecutable(). `options` is already a by-value copy, so it is
  // moved on rather than copied a second time.
  StatusOr<std::unique_ptr<PjRtExecutable>> Compile(
      const XlaComputation& computation, CompileOptions options) override {
    return WrapExecutable(wrapped_->Compile(computation, std::move(options)));
  }
  StatusOr<std::unique_ptr<PjRtExecutable>> Compile(
      mlir::ModuleOp module, CompileOptions options) override {
    return WrapExecutable(wrapped_->Compile(module, std::move(options)));
  }

  StatusOr<std::optional<std::string>> ExecutableFingerprint(
      const PjRtExecutable& executable) const override;
  StatusOr<std::string> SerializeExecutable(
      const PjRtExecutable& executable) const override;
  StatusOr<std::unique_ptr<PjRtExecutable>> DeserializeExecutable(
      absl::string_view serialized, CompileOptions options) override {
    return WrapExecutable(
        wrapped_->DeserializeExecutable(serialized, std::move(options)));
  }

  StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
      const Shape& shape, PjRtDevice* device) override {
    return Unimplemented(
        "PJRT C API does not support CreateUninitializedBuffer");
  }
  StatusOr<std::unique_ptr<AsyncBufferTransferManager>>
  CreateBuffersForAsyncTransfer(absl::Span<const Shape> shapes,
                                PjRtDevice* device) override {
    return Unimplemented(
        "PJRT C API does not support CreateBuffersForAsyncTransfer");
  }

  // The buffer-creation methods below forward to the wrapped client on the
  // device wrapped by `device` (which must be a PjRtCApiDevice). They pass the
  // result through WrapBuffer(). By-value callbacks are moved, not copied.
  StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(
      const void* data, PrimitiveType type, absl::Span<int64_t const> dims,
      std::optional<absl::Span<int64_t const>> byte_strides,
      HostBufferSemantics host_buffer_semantics,
      std::function<void()> on_done_with_host_buffer,
      PjRtDevice* device) override {
    return WrapBuffer(wrapped_->BufferFromHostBuffer(
        data, type, dims, byte_strides, host_buffer_semantics,
        std::move(on_done_with_host_buffer),
        PjRtCApiDevice::GetWrapped(device)));
  }
  StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostLiteral(
      const LiteralSlice& literal, PjRtDevice* device) override {
    return WrapBuffer(wrapped_->BufferFromHostLiteral(
        literal, PjRtCApiDevice::GetWrapped(device)));
  }
  StatusOr<std::unique_ptr<PjRtBuffer>> CreateViewOfDeviceBuffer(
      void* device_ptr, const Shape& shape, PjRtDevice* device,
      std::function<void()> on_delete_callback) override {
    return WrapBuffer(wrapped_->CreateViewOfDeviceBuffer(
        device_ptr, shape, PjRtCApiDevice::GetWrapped(device),
        std::move(on_delete_callback)));
  }

  StatusOr<std::uintptr_t> UnsafeBufferPointer(PjRtBuffer* buffer) override;

  StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
  MakeCrossHostReceiveBuffers(absl::Span<const Shape> shapes,
                              PjRtDevice* device,
                              PjRtCrossHostRecvNotifier notifier) override {
    return Unimplemented(
        "PJRT C API does not support MakeCrossHostReceiveBuffers");
  }
  // The message previously named MakeCrossHostReceiveBuffers (copy-paste
  // error), which misreported which unsupported call was made.
  StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
  MakeCrossHostReceiveBuffersForGather(
      absl::Span<const Shape> shapes, std::vector<GatherDetails> gather_details,
      PjRtDevice* device, PjRtCrossHostRecvNotifier notifier) override {
    return Unimplemented(
        "PJRT C API does not support MakeCrossHostReceiveBuffersForGather");
  }

  StatusOr<ChannelHandle> CreateChannelHandle() override {
    return Unimplemented("PJRT C API does not support CreateChannelHandle");
  }
  StatusOr<ChannelHandle> CreateDeviceToHostChannelHandle() override {
    return Unimplemented(
        "PJRT C API does not support CreateDeviceToHostChannelHandle");
  }
  StatusOr<ChannelHandle> CreateHostToDeviceChannelHandle() override {
    return Unimplemented(
        "PJRT C API does not support CreateHostToDeviceChannelHandle");
  }

  Status Defragment() override { return wrapped_->Defragment(); }

  // Maps a device of the wrapped client to the PjRtCApiDevice registered for
  // it in `wrapped_device_map_`. CHECK-fails (with the offending pointer in
  // the message) if there is no such entry.
  PjRtDevice* GetCApiDevice(PjRtDevice* wrapped_device) const {
    auto it = wrapped_device_map_.find(wrapped_device);
    CHECK(it != wrapped_device_map_.end())
        << "No PjRtCApiDevice registered for wrapped device "
        << static_cast<const void*>(wrapped_device);
    return it->second;
  }

  StatusOr<std::unique_ptr<PjRtExecutable>> WrapExecutable(
      StatusOr<std::unique_ptr<PjRtExecutable>> to_wrap);
  StatusOr<std::unique_ptr<PjRtBuffer>> WrapBuffer(
      StatusOr<std::unique_ptr<PjRtBuffer>> to_wrap);

 private:
  const PJRT_Api* c_api_;
  PJRT_Client* c_client_;
  std::vector<std::unique_ptr<PjRtCApiDevice>> owned_devices_;
  std::vector<PjRtDevice*> devices_;
  std::vector<PjRtDevice*> addressable_devices_;
  // TODO(skyewm): this is a shim so we can run PjRtCApiClient code without the
  // C API being fully implemented. All methods using wrapped_ should either be
  // marked unimplemented or implemented in terms of the C API, at which point
  // wrapped_ and related functionality should be removed.
  PjRtClient* wrapped_;
  absl::flat_hash_map<PjRtDevice*, PjRtCApiDevice*> wrapped_device_map_;
};
class PjRtCApiBuffer : public PjRtBuffer {
public:
PjRtCApiBuffer(PjRtCApiClient* client, std::unique_ptr<PjRtBuffer> wrapped)
: client_(client), wrapped_(std::move(wrapped)) {}
const Shape& on_device_shape() const override {
return wrapped_->on_device_shape();
}
StatusOr<Shape> logical_on_device_shape() override {
return wrapped_->logical_on_device_shape();
}
PjRtDevice* device() const override {
return client_->GetCApiDevice(wrapped_->device());
}
PjRtClient* client() const override { return client_; }
StatusOr<std::unique_ptr<ExternalReference>> AcquireExternalReference()
override {
return wrapped_->AcquireExternalReference();
}
PjRtFuture<Status> ToLiteral(MutableLiteralBase* literal) override {
return wrapped_->ToLiteral(literal);
}
StatusOr<size_t> GetOnDeviceSizeInBytes() const override {
return wrapped_->GetOnDeviceSizeInBytes();
}
PjRtFuture<Status> CopyRawToHost(void* dst, int64_t offset,
int64_t transfer_size) override {
return wrapped_->CopyRawToHost(dst, offset, transfer_size);
}
void Delete() override { wrapped_->Delete(); }
StatusOr<std::unique_ptr<ExternalReference>> ReleaseDeviceMemoryOwnership(
bool wait_for_operations_to_complete) override {
return wrapped_->ReleaseDeviceMemoryOwnership(
wait_for_operations_to_complete);
}
bool IsDeleted() override { return wrapped_->IsDeleted(); }
StatusOr<std::unique_ptr<PjRtBuffer>> CopyToDevice(
PjRtDevice* dst_device) override {
if (dst_device->client() == client_) {
return client_->WrapBuffer(
wrapped_->CopyToDevice(PjRtCApiDevice::GetWrapped(dst_device)));
} else {
return wrapped_->CopyToDevice(dst_device);
}
}
void CopyToRemoteDevice(absl::string_view serialized_descriptor,
RemoteSendCallback on_done) override {
LOG(ERROR) << "PJRT C API does not support CopyToRemoteDevice";
}
void CopyToRemoteDeviceScattered(
absl::Span<const std::pair<std::string, RemoteSendCallback>>
serialized_descriptors_and_callbacks,
const ScatterDetails& scatter_details) override {
LOG(ERROR) << "PJRT C API does not support CopyToRemoteDeviceScattered";
}
PjRtFuture<Status> GetReadyFuture() override {
return wrapped_->GetReadyFuture();
}
bool IsOnCpu() const override { return wrapped_->IsOnCpu(); }
PjRtBuffer* wrapped() const { return wrapped_.get(); }
static PjRtBuffer* GetWrapped(PjRtBuffer* c_api_buffer) {
return tensorflow::down_cast<PjRtCApiBuffer*>(c_api_buffer)->wrapped();
}
static std::vector<PjRtBuffer*> GetWrappedVector(
absl::Span<PjRtBuffer* const> c_api_buffers) {
std::vector<PjRtBuffer*> wrapped;
wrapped.reserve(c_api_buffers.size());
for (PjRtBuffer* c_api_buf : c_api_buffers) {
wrapped.push_back(GetWrapped(c_api_buf));
}
return wrapped;
}
private:
PjRtCApiClient* client_;
std::unique_ptr<PjRtBuffer> wrapped_;
};
// PjRtExecutable adapter that owns an executable of the wrapped client. Most
// queries are forwarded to it. client() and addressable_devices() return this
// adapter's own members, which are set up by the out-of-line constructor.
class PjRtCApiExecutable : public PjRtExecutable {
 public:
  PjRtCApiExecutable(PjRtCApiClient* client,
                     std::unique_ptr<PjRtExecutable> wrapped);

  PjRtClient* client() const override { return client_; }
  absl::string_view name() const override { return wrapped_->name(); }
  int num_replicas() const override { return wrapped_->num_replicas(); }
  int num_partitions() const override { return wrapped_->num_partitions(); }
  int64_t SizeOfGeneratedCodeInBytes() const override {
    return wrapped_->SizeOfGeneratedCodeInBytes();
  }
  const DeviceAssignment& device_assignment() const override {
    return wrapped_->device_assignment();
  }
  absl::Span<const LogicalDeviceIds> addressable_device_logical_ids()
      const override {
    return wrapped_->addressable_device_logical_ids();
  }
  absl::Span<PjRtDevice* const> addressable_devices() const override {
    return addressable_devices_;
  }
  StatusOr<std::vector<std::shared_ptr<HloModule>>> GetHloModules()
      const override {
    return wrapped_->GetHloModules();
  }

  StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>> Execute(
      absl::Span<const std::vector<PjRtBuffer*>> argument_handles,
      const ExecuteOptions& options,
      std::optional<std::vector<PjRtFuture<Status>>>& returned_futures)
      override;
  StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecuteSharded(
      absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
      const ExecuteOptions& options,
      std::optional<PjRtFuture<Status>>& returned_future,
      bool fill_future) override;
  StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecutePortable(
      absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
      const ExecuteOptions& options,
      std::optional<PjRtFuture<Status>>& returned_future,
      bool fill_future) override;

  // Plain call; the previous `return` of a void expression was legal but
  // misleading.
  void Delete() override { wrapped_->Delete(); }
  bool IsDeleted() override { return wrapped_->IsDeleted(); }

  PjRtExecutable* wrapped() const { return wrapped_.get(); }
  // `c_api_executable` must really be a PjRtCApiExecutable. The down_cast is
  // not checked in release builds.
  static PjRtExecutable* GetWrapped(const PjRtExecutable* c_api_executable) {
    return tensorflow::down_cast<const PjRtCApiExecutable*>(c_api_executable)
        ->wrapped();
  }

 private:
  PjRtCApiClient* client_;
  std::unique_ptr<PjRtExecutable> wrapped_;
  std::vector<PjRtDevice*> addressable_devices_;
};
// Creates a PjRtClient for the PJRT C API and returns it, or a non-OK status on
// failure. It takes no arguments; the definition is not in this header.
// NOTE(review): presumably returns a PjRtCApiClient — confirm in the .cc file.
StatusOr<std::unique_ptr<PjRtClient>> GetCApiClient();
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_PJRT_PJRT_C_API_CLIENT_H_