/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Implementation notes:
//
// Asynchronous execution:
// -----------------------
//
// Computations and host-to-device transfers do not need to block the host
// waiting for the operation to complete but instead return control to the host
// immediately. This allows client logic to overlap with device-side
// computation.
//
// For a good user experience, we must be careful only to enqueue operations
// that are unlikely to fail; as a rule error checking must be done eagerly
// before returning control to the client.
//
// The degree to which the client can enqueue operations ahead of the device
// is limited by a semaphore. There are two modes: asynchronous, where we
// allow the client to enqueue up to 32 executions ahead of the device, and
// synchronous, where we limit the client to having one enqueued operation at
// a time. The value of 32 is arbitrary.
//
// Even in asynchronous mode, it is important that we do not permit
// unbounded queue-ahead. Firstly, it is problematic when the user does something
// like the following in Python:
// %timeit run_computation()
// To the timeit logic, run_computation() appears to be extremely cheap since
// it is deferring all of its real work and not blocking, and so %timeit will
// run it many (e.g., 10000) times to get better timing resolution, even though
// in reality it may be expensive. Secondly, on CPU the allocator is
// synchronized with the
// head of the compute stream, and we allocate buffers for all of the enqueued
// programs without any reuse (unlike GPU). This means that the memory usage
// is proportional to the queue size.
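//
// As a rough illustration (a sketch, not the code in this file; the helper
// names below are assumed for illustration only), the queue-ahead limit
// behaves like a counting semaphore that is acquired before each enqueue and
// released once the corresponding device-side work is known to have completed:
//
//   Semaphore compute_semaphore(/*capacity=*/32);      // 1 in synchronous mode
//   auto ticket = compute_semaphore.ScopedAcquire(1);  // may block the host
//   EnqueueExecutionOnComputeStream(...);              // returns immediately
//   // `ticket` is released when the enqueued execution is known to be done.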
//
// Multi-stream execution:
// -----------------------
//
// We use a multistream execution design, where different Streams are used for
// host-to-device transfers, device-to-host transfers, and compute. This allows
// us to overlap transfers on and off the device with computation.
//
// Synchronization between streams occurs via BufferSequencingEvents that
// describe when the contents of a logical buffer are known to be valid on
// a particular stream, and when a buffer's uses have all completed.
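//
// A minimal sketch of that pattern as it appears throughout this file (error
// handling omitted; `producer_stream` and `consumer_stream` are placeholders):
//
//   auto event = std::make_shared<BufferSequencingEvent>();
//   EventPool::Handle handle = local_device->event_pool()
//                                  .ThenAllocateAndRecordEvent(producer_stream)
//                                  .ValueOrDie();
//   event->SetSequencingEvent(std::move(handle), producer_stream);
//   ...
//   event->WaitForEventOnStream(consumer_stream);  // orders consumer after it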
//
// Synchronous vs asynchronous deallocation:
// -----------------------------------------
//
// See the comment on LocalDeviceState::AllocationModel for a discussion of the
// different allocation semantics on CPU, GPU, and TPU.
#include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h"
#include <cstddef>
#include <cstdlib>
#include <memory>
#include <string>
#include <vector>
#include "absl/base/casts.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_format.h"
#include "absl/synchronization/mutex.h"
#include "absl/time/time.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/cpu_function_runtime.h"
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/layout.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/pjrt/distributed/protocol.pb.h"
#include "tensorflow/compiler/xla/pjrt/event_pool.h"
#include "tensorflow/compiler/xla/pjrt/local_device_state.h"
#include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h"
#include "tensorflow/compiler/xla/pjrt/utils.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/connected_traceme.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/profiler/lib/traceme_encode.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"
#include "tensorflow/stream_executor/event.h"
#include "tensorflow/stream_executor/host/host_platform_id.h"
#include "tensorflow/stream_executor/lib/statusor.h"
#include "tensorflow/stream_executor/stream.h"
namespace xla {
PjRtPlatformId PjRtStreamExecutorDevice::platform_id() const {
return client_->platform_id();
}
absl::string_view PjRtStreamExecutorDevice::platform_name() const {
return client_->platform_name();
}
StatusOr<LocalDeviceState*> PjRtStreamExecutorDevice::GetLocalDeviceState()
const {
if (local_device_state_) {
return local_device_state_.get();
}
return InvalidArgument("Device %s is not a local device.", DebugString());
}
std::string PjRtStreamExecutorDevice::DebugString() const {
return absl::StrCat(platform_name(), ":", id());
}
StatusOr<DeviceAssignment> DevicesToDeviceAssignment(
absl::Span<const std::vector<PjRtDevice*>> devices) {
if (devices.empty()) {
return InvalidArgument(
"Device assignment passed to Compile() must be non-empty.");
}
if (devices[0].empty()) {
return InvalidArgument(
"Device assignment passed to Compile() must have a nonzero number of "
"partitions per replica; replica 0 had 0 partitions.");
}
DeviceAssignment xla_assignment(devices.size(), devices[0].size());
for (int replica = 0; replica < devices.size(); ++replica) {
if (devices[replica].size() != devices[0].size()) {
return InvalidArgument(
"Device assignment passed to Compile() has different numbers of "
"partitions between replicas; %d partitions for replica %d versus %d "
"partitions for replica 0.",
devices[replica].size(), replica, devices[0].size());
}
for (int partition = 0; partition < devices[replica].size(); ++partition) {
if (devices[0][0]->client()->platform_id() !=
devices[replica][partition]->client()->platform_id()) {
return InvalidArgument(
"Device assignment passed to Compile() must have devices of a "
"single kind, got %s for replica 0 partition 0 and %s for replica "
"%d partition %d.",
devices[0][0]->client()->platform_name(),
devices[replica][partition]->client()->platform_name(), replica,
partition);
}
xla_assignment(replica, partition) = devices[replica][partition]->id();
}
}
return xla_assignment;
}
class CpuAllocator : public tensorflow::Allocator {
public:
CpuAllocator() = default;
std::string Name() override { return "cpu"; }
void* AllocateRaw(size_t alignment, size_t num_bytes) override {
return tensorflow::port::AlignedMalloc(num_bytes, alignment);
}
void DeallocateRaw(void* ptr) override {
return tensorflow::port::AlignedFree(ptr);
}
};
PjRtStreamExecutorClient::PjRtStreamExecutorClient(
std::string platform_name, LocalClient* client,
std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices, int task_id,
std::unique_ptr<se::DeviceMemoryAllocator> allocator,
std::unique_ptr<tensorflow::Allocator> host_memory_allocator,
bool should_stage_host_to_device_transfers,
std::unique_ptr<gpu::GpuExecutableRunOptions> gpu_run_options)
: platform_id_(tensorflow::Fingerprint64(platform_name)),
platform_name_(std::move(platform_name)),
client_(client),
host_memory_allocator_(std::move(host_memory_allocator)),
owned_devices_(std::move(devices)),
task_id_(task_id),
owned_allocator_(std::move(allocator)),
should_stage_host_to_device_transfers_(
should_stage_host_to_device_transfers),
gpu_run_options_(std::move(gpu_run_options)),
thread_pool_(
tensorflow::Env::Default(), "pjrt_thread_pool",
std::max<int>(DefaultThreadPoolSize(), client->device_count())) {
if (owned_allocator_ != nullptr) {
allocator_ = owned_allocator_.get();
} else {
allocator_ = client_->backend().memory_allocator();
}
if (!host_memory_allocator_) {
host_memory_allocator_ = std::make_unique<CpuAllocator>();
}
for (const std::unique_ptr<PjRtStreamExecutorDevice>& device :
owned_devices_) {
devices_.push_back(device.get());
CHECK(id_to_device_.insert({device->id(), device.get()}).second)
<< "Duplicate device id: " << device->id();
if (device->IsAddressable()) {
int idx = device->local_hardware_id();
if (idx >= addressable_devices_.size()) {
addressable_devices_.resize(idx + 1);
}
CHECK(addressable_devices_[idx] == nullptr) << idx;
addressable_devices_[idx] = device.get();
}
device->SetClient(this);
}
for (int idx = 0; idx < addressable_devices_.size(); ++idx) {
CHECK(addressable_devices_[idx] != nullptr) << idx;
}
}
StatusOr<DeviceAssignment> PjRtStreamExecutorClient::GetDefaultDeviceAssignment(
int num_replicas, int num_partitions) const {
return client_->backend().computation_placer()->AssignDevices(num_replicas,
num_partitions);
}
StatusOr<std::unique_ptr<HloCostAnalysis>>
PjRtStreamExecutorClient::GetHloCostAnalysis() {
return absl::make_unique<HloCostAnalysis>(
client_->backend().compiler()->ShapeSizeBytesFunction());
}
namespace {
// Ensures that it is safe to deallocate any buffers that have been enqueued in
// an operation on stream. Called only in rare error cases that are triggered
// during enqueue. These cases generally correspond to resource exhaustion.
void StallStreamOnError(LocalDeviceState* local_device, se::Stream* stream) {
switch (local_device->allocation_model()) {
case LocalDeviceState::kAsynchronous:
// We can safely deallocate any dangling buffers immediately. NOTE: this
// assumes that any buffers enqueued on stream are local to stream's
// executor, and manual action may be needed if that condition is not met.
break;
case LocalDeviceState::kComputeSynchronized:
// This will stall computation but that's ok in this very rare error
// case.
if (stream != local_device->compute_stream()) {
local_device->compute_stream()->ThenWaitFor(stream);
}
break;
case LocalDeviceState::kSynchronous:
// This will stall the calling thread but that's ok in this very rare
// error case. If the stall fails just crash, since we have no other
// way to synchronize.
TF_CHECK_OK(stream->BlockHostUntilDone());
break;
}
}
// Does all necessary bookkeeping, after a buffer is successfully enqueued onto
// a stream, to ensure that the buffer will be kept alive until its use on that
// stream is complete.
//
// device_buffer: the buffer that was enqueued.
// buffer_local_device: the device the buffer was allocated on.
// stream_local_device: the device that manages usage_stream.
// event: an event that was recorded on usage_stream
// after the usage of device_buffer was enqueued.
// usage_stream: the stream the operation using device_buffer
// was enqueued on.
// prefer_to_retain_reference: relevant only for the compute synchronous
// allocation model. If true, retain a reference
// to device_buffer until after the operation
// completes. If false then the compute stream
// will have to be synchronized past event before
// device_buffer can be freed.
//
// prefer_to_retain_reference encodes a heuristic set by the caller for the
// compute synchronous model:
//
// Generally when a buffer is the destination of a copy to a device, it will
// subsequently be used on the device's compute stream before being freed. In
// that case, there is no need to retain a reference to the buffer. If the
// buffer is freed before being used on the compute stream, the free will be
// delayed until the host knows that event has completed, but this is expected
// to be uncommon.
//
// When a buffer is the source of a copy from a device, we need to either retain
// a reference to the buffer until the copy completes or serialize the compute
// stream behind the copy. It is often better to retain a reference since while
// that keeps memory alive longer, it avoids stalling the compute stream.
void RecordUsage(PjRtStreamExecutorBuffer::ScopedHold device_buffer,
LocalDeviceState* buffer_local_device,
LocalDeviceState* stream_local_device,
std::shared_ptr<BufferSequencingEvent> event,
se::Stream* usage_stream, bool prefer_to_retain_reference,
std::vector<std::shared_ptr<TrackedDeviceBuffer>>*
buffers_to_release = nullptr) {
tensorflow::profiler::TraceMe traceme("RecordUsage");
bool retain_buffer_until_completion =
// If the buffer wasn't allocated on the same device as the stream, always
// retain a reference.
(stream_local_device != buffer_local_device) ||
// In the synchronous allocation model, always retain a reference.
(stream_local_device->allocation_model() ==
LocalDeviceState::kSynchronous) ||
// In the compute synchronous model, use the caller's heuristic.
(stream_local_device->allocation_model() ==
LocalDeviceState::kComputeSynchronized &&
prefer_to_retain_reference);
if (retain_buffer_until_completion) {
if (buffers_to_release) {
buffers_to_release->push_back(device_buffer.buffer());
} else {
buffer_local_device->ThenRelease(usage_stream, device_buffer.buffer());
}
}
device_buffer.ConvertUsageHold(usage_stream, event,
retain_buffer_until_completion);
}
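// A caller-side sketch of the ScopedHold protocol that RecordUsage() builds on
// (illustration only; `stream` and `event` are placeholders for a usage stream
// and a BufferSequencingEvent recorded on it):
//
//   PjRtStreamExecutorBuffer::ScopedHold hold =
//       buffer->GetBufferWithUsageHold();
//   if (hold.ok()) {
//     // ... enqueue work on `stream` that reads the buffer, then record
//     // `event` on `stream` ...
//     hold.ConvertUsageHold(stream, event, /*reference_held=*/true);
//   }
//   // If the hold is never converted (e.g., on an early error return), its
//   // destructor drops it again via DropHold().
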
// Allocates the device buffers for a buffer that will be used as the
// destination of a copy, either from the host or another device. copy_stream
// may be nullptr, e.g., when allocating a buffer for a cross-host copy. If the
// buffer is a tuple then the tuple tables are allocated, and all necessary
// synchronization for them is dealt with, before the buffer is returned.
//
// It is safe to delete the returned PjRtBuffer without further
// synchronization if an error occurs before the buffer is used.
//
// The caller may optionally provide a definition event to be recorded in
// the buffer.
// TODO(phawkins): replace on_host_shape here with on_device_shape.
StatusOr<std::unique_ptr<PjRtStreamExecutorBuffer>> AllocateDestinationBuffer(
const Shape& on_host_shape, PjRtDevice* device,
LocalDeviceState* local_device, se::Stream* copy_stream,
bool is_uninitialized_create, PjRtClient* client,
std::shared_ptr<BufferSequencingEvent> definition_event = nullptr) {
if (on_host_shape.IsTuple() && on_host_shape.tuple_shapes_size() == 0) {
return InvalidArgument("Can't make a buffer from an empty tuple");
}
auto* se_client = tensorflow::down_cast<PjRtStreamExecutorClient*>(client);
TransferManager* transfer_manager =
se_client->client()->backend().transfer_manager();
TF_ASSIGN_OR_RETURN(ScopedShapedBuffer dst_buffer,
transfer_manager->AllocateScopedShapedBuffer(
on_host_shape, se_client->allocator(),
local_device->device_ordinal()));
if (local_device->allocation_model() ==
LocalDeviceState::kComputeSynchronized) {
if (copy_stream == nullptr) {
CHECK(is_uninitialized_create);
} else {
copy_stream->ThenWaitFor(local_device->compute_stream());
}
} else {
DCHECK(transfer_manager->CanShapedBufferBeAccessedNow(
local_device->compute_stream()->parent(), dst_buffer));
}
Shape on_device_shape = dst_buffer.on_device_shape();
absl::InlinedVector<std::shared_ptr<BufferSequencingEvent>, 2>
definition_events;
if (is_uninitialized_create) {
// There is not going to be any copy into the buffer so in general we don't
// need a definition event.
if (local_device->allocation_model() ==
LocalDeviceState::kComputeSynchronized) {
// The allocation is not valid until the compute stream passes this point,
// so add a definition event in the compute stream.
definition_events.emplace_back(std::make_shared<BufferSequencingEvent>());
TF_ASSIGN_OR_RETURN(EventPool::Handle event,
local_device->event_pool().ThenAllocateAndRecordEvent(
local_device->compute_stream()));
definition_events.back()->SetSequencingEvent(
std::move(event), local_device->compute_stream());
}
// If the caller provided a definition event then we record that.
if (definition_event) {
definition_events.emplace_back(definition_event);
}
} else {
// We have at least one definition event, for the copy completing to
// the device buffers.
if (definition_event) {
definition_events.emplace_back(definition_event);
} else {
definition_events.emplace_back(std::make_shared<BufferSequencingEvent>());
}
}
se::Stream* tuple_table_stream = local_device->host_to_device_stream();
if (on_device_shape.IsTuple()) {
// We also need to copy the tuple tables, so we'll have an additional
// definition event for that copy to complete.
if (tuple_table_stream != copy_stream) {
if (local_device->allocation_model() ==
LocalDeviceState::kComputeSynchronized) {
tuple_table_stream->ThenWaitFor(local_device->compute_stream());
} else {
DCHECK(transfer_manager->CanShapedBufferBeAccessedNow(
local_device->compute_stream()->parent(), dst_buffer));
}
}
TF_RETURN_IF_ERROR(transfer_manager->WriteTupleIndexTablesAsync(
tuple_table_stream, dst_buffer));
// CAUTION: From this point onwards we need to be careful about returning
// from error cases because we have started a transfer and must not allow
// dst_buffer to be freed too soon in the non-async allocation models.
definition_events.emplace_back(std::make_shared<BufferSequencingEvent>());
StatusOr<EventPool::Handle> event_or =
local_device->event_pool().ThenAllocateAndRecordEvent(
tuple_table_stream);
if (!event_or.ok()) {
StallStreamOnError(local_device, tuple_table_stream);
return event_or.status();
}
definition_events.back()->SetSequencingEvent(event_or.ConsumeValueOrDie(),
tuple_table_stream);
}
std::shared_ptr<TrackedDeviceBuffer> dst_device_buffer =
TrackedDeviceBuffer::FromScopedShapedBuffer(&dst_buffer,
definition_events);
auto py_buffer = absl::make_unique<PjRtStreamExecutorBuffer>(
on_device_shape, std::move(dst_device_buffer), client, device);
if (on_device_shape.IsTuple()) {
// Add a usage hold for the tuple table write and immediately convert it to
// the appropriate form of synchronization. prefer_to_retain_reference=false
// means don't retain a memory reference until the transfer is complete when
// using the ComputeSynchronized allocation model. This is a heuristic
// because in the common case destination buffers will be used on the
// compute stream and therefore don't require any synchronization before
// being freed. If the buffer is allocated and never used, the free will
// take longer and this is assumed to be ok.
RecordUsage(py_buffer->GetBufferWithUsageHold(), local_device, local_device,
definition_events.back(), tuple_table_stream,
/*prefer_to_retain_reference=*/false);
}
return py_buffer;
}
// Adds necessary synchronization after a copy has been enqueued to a buffer.
// definition_event was added when the buffer was allocated, but has not yet
// had an event recorded.
Status AddDestinationBufferSynchronization(
LocalDeviceState* local_device,
PjRtStreamExecutorBuffer::ScopedHold device_buffer,
std::shared_ptr<BufferSequencingEvent> definition_event,
se::Stream* copy_stream) {
StatusOr<EventPool::Handle> event_or =
local_device->event_pool().ThenAllocateAndRecordEvent(copy_stream);
if (!event_or.ok()) {
StallStreamOnError(local_device, copy_stream);
return event_or.status();
}
definition_event->SetSequencingEvent(event_or.ConsumeValueOrDie(),
copy_stream);
// prefer_to_retain_reference=false means don't retain a memory reference
// until the transfer is complete when using the ComputeSynchronized
// allocation model. This is a heuristic because in the common case
// destination buffers will be used on the compute stream and therefore don't
// require any synchronization before being freed. If the buffer is allocated
// and never used, the free will take longer and this is assumed to be ok.
RecordUsage(std::move(device_buffer), local_device, local_device,
definition_event, copy_stream,
/*prefer_to_retain_reference=*/false);
return Status::OK();
}
} // namespace
PjRtStreamExecutorBuffer::ScopedHold::~ScopedHold() {
if (ok()) {
parent_->DropHold(type_, buffer().get());
}
}
PjRtStreamExecutorBuffer::ScopedHold::ScopedHold(ScopedHold&& other)
: parent_(other.parent_),
type_(other.type_),
state_(other.state_),
status_(std::move(other.status_)),
buffer_(std::move(other.buffer_)) {
// Preserve the invariant that status is invalid if buffer == nullptr.
other.SetState(kMoved);
}
void PjRtStreamExecutorBuffer::ScopedHold::Acquire(
StatusOr<std::shared_ptr<TrackedDeviceBuffer>>&& buffer_or) {
CHECK(!ok());
if (buffer_or.ok()) {
buffer_ = buffer_or.ValueOrDie();
SetState(kValid);
} else {
status_ = buffer_or.status();
buffer_ = nullptr;
SetState(kError);
}
// Check the invariant holds.
CHECK(!ok() || buffer_ != nullptr);
}
PjRtStreamExecutorBuffer::ScopedHold::ForClosure
PjRtStreamExecutorBuffer::ScopedHold::ToClosure() {
CHECK(ok());
ForClosure for_closure(parent_, type_, state_, std::move(status_),
std::move(buffer_));
SetState(kReleased);
return for_closure;
}
void PjRtStreamExecutorBuffer::ScopedHold::ConvertUsageHold(
se::Stream* usage_stream, std::shared_ptr<BufferSequencingEvent> event,
bool reference_held) {
CHECK(ok());
CHECK_EQ(type_, kUsage);
parent_->ConvertUsageHold(buffer().get(), usage_stream, std::move(event),
reference_held);
SetState(kConverted);
}
void PjRtStreamExecutorBuffer::ScopedHold::ConfirmDonation() {
CHECK(ok());
CHECK_EQ(type_, kDonation);
parent_->ConfirmDonation(buffer().get());
SetState(kDonated);
}
void PjRtStreamExecutorBuffer::ScopedHold::AddToInput(
ShapeTree<MaybeOwningDeviceMemory>::iterator* iterator,
const ShapeTree<MaybeOwningDeviceMemory>::iterator& end,
ExecutionInput* execution_input,
se::DeviceMemoryAllocator* allocator) const {
CHECK(ok());
if (type_ == kDonation) {
buffer()->AddToInputAsDonated(iterator, end, execution_input, allocator);
} else {
CHECK_EQ(type_, kUsage);
buffer()->AddToInputAsImmutable(iterator, end);
}
}
bool PjRtStreamExecutorBuffer::IsOnCpu() const {
return client()->platform_id() == kCpuId;
}
StatusOr<Shape> PjRtStreamExecutorBuffer::logical_on_device_shape() {
if (IsOnCpu() && on_device_shape().is_dynamic()) {
// TODO(b/182468546): TransferManager may return corrupted dynamic_shape on
// CPU non-deterministically.
return Unimplemented(
"Gathering DynamicShape is not implemented properly on CPU yet.");
}
if (on_device_shape_.is_static()) {
return on_device_shape_;
}
auto* local_device = device_->local_device_state();
auto* stream = local_device->GetDeviceToHostStream();
ScopedHold device_buffer(this, ScopedHold::kUsage);
{
absl::MutexLock lock(&mu_);
// We can't perform any other action while a donation hold is in progress.
WaitForOutstandingDonationHold();
if (device_buffer_ == nullptr) {
return InvalidArgument(
"logical_on_device_shape() called on deleted or donated buffer");
}
AcquireHoldLocked(&device_buffer);
}
WaitForBufferDefinitionEventsOnStream(*device_buffer, stream);
ShapedBuffer shaped_buffer = device_buffer->AsShapedBuffer(on_device_shape_);
StatusOr<EventPool::Handle> event_or =
local_device->event_pool().AllocateEvent(stream->parent());
if (!event_or.ok()) {
return event_or.status();
}
Shape ret_shape = on_device_shape_;
TransferManager* transfer_manager =
client_->client()->backend().transfer_manager();
TF_RETURN_IF_ERROR(
transfer_manager->ReadDynamicShapes(stream, &shaped_buffer, &ret_shape));
return ret_shape;
}
namespace {
// Implements PjRtBuffer::ExternalReference as a wrapped
// ScopedHold::kExternalReference.
class ScopedHoldAsExternalReference : public PjRtBuffer::ExternalReference {
public:
explicit ScopedHoldAsExternalReference(
PjRtStreamExecutorBuffer::ScopedHold hold)
: external_reference_(std::move(hold)) {
CHECK(external_reference_.type() ==
PjRtStreamExecutorBuffer::ScopedHold::kExternalReference);
data_ptr_ = external_reference_->device_memory().front().opaque();
}
~ScopedHoldAsExternalReference() override = default;
private:
PjRtStreamExecutorBuffer::ScopedHold external_reference_;
};
} // namespace
StatusOr<std::unique_ptr<PjRtBuffer::ExternalReference>>
PjRtStreamExecutorBuffer::AcquireExternalReference() {
ScopedHold hold = GetBufferWithExternalReference();
Status hold_status = hold.status();
if (!hold_status.ok()) return hold_status;
return std::unique_ptr<ExternalReference>(
std::make_unique<ScopedHoldAsExternalReference>(std::move(hold)));
}
class TrackedDeviceBufferExternalReference
: public PjRtBuffer::ExternalReference {
public:
explicit TrackedDeviceBufferExternalReference(
std::shared_ptr<TrackedDeviceBuffer> tracked_device_buffer)
: tracked_device_buffer_(std::move(tracked_device_buffer)) {
data_ptr_ = tracked_device_buffer_->device_memory()[0].opaque();
}
~TrackedDeviceBufferExternalReference() override = default;
private:
std::shared_ptr<TrackedDeviceBuffer> tracked_device_buffer_;
};
StatusOr<std::unique_ptr<PjRtBuffer::ExternalReference>>
PjRtStreamExecutorBuffer::ReleaseDeviceMemoryOwnership(
bool wait_for_operations_to_complete) {
if (on_device_shape_.IsTuple()) {
return InvalidArgument(
"ReleaseDeviceMemoryOwnership allowed only for non-tuple");
}
TF_ASSIGN_OR_RETURN(
std::shared_ptr<TrackedDeviceBuffer> tracked_device_buffer,
Release(wait_for_operations_to_complete));
std::unique_ptr<PjRtBuffer::ExternalReference> ref;
if (tracked_device_buffer) {
ref = std::make_unique<TrackedDeviceBufferExternalReference>(
std::move(tracked_device_buffer));
}
return ref;
}
StatusOr<std::unique_ptr<PjRtBuffer>>
PjRtStreamExecutorClient::BufferFromHostBuffer(
const void* data, const Shape& shape,
HostBufferSemantics host_buffer_semantics,
std::function<void()> on_done_with_host_buffer, PjRtDevice* device) {
tensorflow::profiler::TraceMe traceme(
"PjRtStreamExecutorClient::BufferFromHostBuffer");
VLOG(2) << "PjRtStreamExecutorClient::BufferFromHostBuffer: shape: "
<< shape.ToString() << " device: " << device->DebugString();
if (shape.IsTuple()) {
return InvalidArgument("Use BufferFromHostLiteral to transfer a tuple");
}
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
->GetLocalDeviceState());
int64 size = ShapeUtil::ByteSizeOf(shape);
TransferManager* transfer_manager = client()->backend().transfer_manager();
TF_ASSIGN_OR_RETURN(Shape compact_shape,
transfer_manager->ChooseCompactLayoutForShape(shape));
// The CPU platform is special because the "host" and the "device" are in the
// same memory space. If the input shape is in the correct layout and we don't
// want to defer the copy onto a thread, we can use the following fast
// path.
bool is_cpu_platform =
local_device->executor()->platform()->id() == se::host::kHostPlatformId;
if (is_cpu_platform) {
// If we are on the host platform and the input buffer is sufficiently
// aligned, we can simply point to the input array's data without any
// further copies. At the time of writing we require a 16-byte alignment
// because XLA may generate code which requires it.
bool can_use_zero_copy =
host_buffer_semantics == HostBufferSemantics::kZeroCopy &&
((absl::bit_cast<std::uintptr_t>(data) &
(cpu_function_runtime::kMinAlign - 1)) == 0);
if (shape.layout() == compact_shape.layout() &&
(host_buffer_semantics ==
HostBufferSemantics::kImmutableOnlyDuringCall ||
can_use_zero_copy)) {
std::function<void()> on_delete_callback;
se::DeviceMemoryBase buffer;
// If we are on the host platform and the input buffer is sufficiently
// aligned, we can simply point to the input array's data without any
// further copies. At the time of writing we require a 16-byte alignment
// because XLA may generate code which requires it.
if (can_use_zero_copy) {
on_delete_callback = std::move(on_done_with_host_buffer);
buffer = se::DeviceMemoryBase(const_cast<void*>(data), size);
} else {
void* staging_buffer = host_memory_allocator()->AllocateRaw(
cpu_function_runtime::kMinAlign, size);
buffer = se::DeviceMemoryBase(staging_buffer, size);
std::memcpy(staging_buffer, data, size);
if (on_done_with_host_buffer) {
on_done_with_host_buffer();
}
on_delete_callback = [staging_buffer, host_memory_allocator =
host_memory_allocator()]() {
host_memory_allocator->DeallocateRaw(staging_buffer);
};
}
absl::Span<const std::shared_ptr<BufferSequencingEvent>>
definition_events;
auto device_buffer = std::make_shared<TrackedDeviceBuffer>(
/*allocator=*/nullptr, local_device->device_ordinal(),
std::initializer_list<se::DeviceMemoryBase>{buffer},
definition_events, std::move(on_delete_callback));
return std::unique_ptr<PjRtBuffer>(
std::make_unique<PjRtStreamExecutorBuffer>(
shape, std::move(device_buffer), this, device));
}
}
TF_ASSIGN_OR_RETURN(
std::unique_ptr<PjRtStreamExecutorBuffer> py_buffer,
AllocateDestinationBuffer(compact_shape, device, local_device,
local_device->host_to_device_stream(),
/*is_uninitialized_create=*/false, this));
PjRtStreamExecutorBuffer::ScopedHold device_buffer(
py_buffer->GetBufferWithUsageHold());
CHECK(device_buffer.ok());
// If necessary, allocate a host-side buffer for staging host-to-device
// transfers. On GPU this is a buffer in pinned memory.
std::shared_ptr<void> staging_buffer;
if (host_buffer_semantics == HostBufferSemantics::kImmutableOnlyDuringCall ||
should_stage_host_to_device_transfers()) {
void* ptr = host_memory_allocator()->AllocateRaw(
tensorflow::Allocator::kAllocatorAlignment, size);
staging_buffer = std::shared_ptr<void>(
ptr, [host_memory_allocator = host_memory_allocator()](void* ptr) {
host_memory_allocator->DeallocateRaw(ptr);
});
}
// Copy the buffer into a staging buffer before returning control to the
// caller if the caller only guaranteed that the buffer is valid for the
// duration of the call. Otherwise, we stage (if necessary) on a separate
// thread.
if (host_buffer_semantics == HostBufferSemantics::kImmutableOnlyDuringCall) {
std::memcpy(staging_buffer.get(), data, size);
if (on_done_with_host_buffer) {
on_done_with_host_buffer();
on_done_with_host_buffer = nullptr;
}
data = nullptr;
}
// The host to device transfer is performed on a thread pool, mostly because
// it includes linearization that may be slow. It is OK to capture the
// py_buffer pointer because the py_buffer can't be deleted until all the
// usage holds have gone away.
// TODO(misard) assess if it would be preferable to introduce a heuristic to
// put the transfer into the calling thread for small literals.
auto transfer_h2d = [local_client = client(), transfer_manager, local_device,
data, size,
movable_device_buffer{device_buffer.ToClosure()}, shape,
py_buffer{py_buffer.get()},
on_device_shape{py_buffer->on_device_shape()},
staging_buffer{std::move(staging_buffer)},
on_done_with_host_buffer{
std::move(on_done_with_host_buffer)},
host_buffer_semantics]() {
PjRtStreamExecutorBuffer::ScopedHold device_buffer(movable_device_buffer);
// This function uses TF_CHECK_OK and ValueOrDie() since we have no way
// to report failures from a callback. However, the operations here are
// unlikely to fail, and would not be recoverable even if they did: DMAs to
// memory that has already been allocated, and a possible Event
// allocation.
ShapedBuffer buffer = device_buffer->AsShapedBuffer(on_device_shape);
// If applicable on the backend, stage the transfer via host memory
// allocated via the host_memory_allocator. On GPU, this is pinned
// memory.
if (staging_buffer) {
// If we didn't already copy the input buffer into the staging buffer,
// do so now.
if (host_buffer_semantics !=
HostBufferSemantics::kImmutableOnlyDuringCall) {
std::memcpy(staging_buffer.get(), data, size);
}
BorrowingLiteral literal(static_cast<const char*>(staging_buffer.get()),
shape);
TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync(
local_device->host_to_device_stream(), literal, buffer));
} else {
BorrowingLiteral literal(static_cast<const char*>(data), shape);
// Otherwise, just transfer the literal.
TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync(
local_device->host_to_device_stream(), literal, buffer));
}
std::shared_ptr<BufferSequencingEvent> event =
device_buffer->definition_events()[0];
TF_CHECK_OK(AddDestinationBufferSynchronization(
local_device, std::move(device_buffer), event,
local_device->host_to_device_stream()));
local_device->callback_stream()->ThenWaitFor(
local_device->host_to_device_stream());
local_device->ThenExecuteOnCallbackThread(
local_device->callback_stream(),
[staging_buffer{std::move(staging_buffer)},
on_done_with_host_buffer{std::move(on_done_with_host_buffer)}]() {
if (on_done_with_host_buffer) {
on_done_with_host_buffer();
}
});
};
if (is_cpu_platform) {
// Using the thread_pool would be a double thread hop; the code
// already defers its work onto a stream (= thread on CPU).
transfer_h2d();
} else {
thread_pool()->Schedule(transfer_h2d);
}
return std::unique_ptr<PjRtBuffer>(std::move(py_buffer));
}
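// A hedged caller-side sketch of the host-buffer semantics handled above
// (assumes a `client`, a `device`, and a std::vector<float> `host_data`
// obtained elsewhere; not part of this file's API surface):
//
//   Shape shape = ShapeUtil::MakeShape(F32, {1024});
//   std::unique_ptr<PjRtBuffer> buffer =
//       client
//           ->BufferFromHostBuffer(
//               host_data.data(), shape,
//               PjRtStreamExecutorClient::HostBufferSemantics::
//                   kImmutableOnlyDuringCall,
//               /*on_done_with_host_buffer=*/nullptr, device)
//           .ValueOrDie();
//   // With kImmutableOnlyDuringCall the data is staged before the call
//   // returns, so `host_data` may be reused immediately. With kZeroCopy the
//   // device buffer may alias `host_data`, and on_done_with_host_buffer
//   // signals when it is safe to free it.
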
StatusOr<std::unique_ptr<PjRtBuffer>>
PjRtStreamExecutorClient::CreateUninitializedBuffer(const Shape& shape,
PjRtDevice* device) {
return CreateUninitializedBuffer(shape, device, nullptr);
}
StatusOr<std::unique_ptr<PjRtBuffer>>
PjRtStreamExecutorClient::CreateUninitializedBuffer(
const Shape& shape, PjRtDevice* device,
std::shared_ptr<BufferSequencingEvent> definition_event) {
tensorflow::profiler::TraceMe traceme(
"PjRtStreamExecutorClient::CreateUninitializedBuffer");
VLOG(2) << "PjRtStreamExecutorClient::CreateUninitializedBuffer: shape: "
<< shape.ToString() << " device: " << device->DebugString();
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
->GetLocalDeviceState());
TransferManager* transfer_manager = client()->backend().transfer_manager();
TF_ASSIGN_OR_RETURN(Shape compact_shape,
transfer_manager->ChooseCompactLayoutForShape(shape));
TF_ASSIGN_OR_RETURN(
std::unique_ptr<PjRtStreamExecutorBuffer> py_buffer,
AllocateDestinationBuffer(compact_shape, device, local_device,
/*copy_stream=*/nullptr,
/*is_uninitialized_create=*/true, this,
definition_event));
return std::unique_ptr<PjRtBuffer>(std::move(py_buffer));
}
StatusOr<std::unique_ptr<PjRtBuffer>>
PjRtStreamExecutorClient::BufferFromHostLiteral(const LiteralSlice& literal,
PjRtDevice* device) {
tensorflow::profiler::TraceMe traceme(
"PjRtStreamExecutorClient::BufferFromHostLiteral");
VLOG(2) << "PjRtStreamExecutorClient::BufferFromHostLiteral: shape: "
<< literal.shape().ToString() << " device: " << device->DebugString();
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
->GetLocalDeviceState());
TransferManager* transfer_manager = client()->backend().transfer_manager();
TF_ASSIGN_OR_RETURN(
Shape compact_shape,
transfer_manager->ChooseCompactLayoutForShape(literal.shape()));
TF_ASSIGN_OR_RETURN(
std::unique_ptr<PjRtStreamExecutorBuffer> py_buffer,
AllocateDestinationBuffer(compact_shape, device, local_device,
local_device->host_to_device_stream(),
/*is_uninitialized_create=*/false, this));
PjRtStreamExecutorBuffer::ScopedHold device_buffer(
py_buffer->GetBufferWithUsageHold());
CHECK(device_buffer.ok());
// The host to device transfer is performed on a thread pool, mostly because
// it includes linearization that may be slow. It is OK to capture the
// py_buffer pointer because the py_buffer can't be deleted until all the
// usage holds have gone away.
// TODO(misard) assess if it would be preferable to introduce a heuristic to
// put the transfer into the calling thread for small literals.
auto transfer_h2d = [local_client = client(), transfer_manager, local_device,
movable_device_buffer{device_buffer.ToClosure()},
literal, py_buffer{py_buffer.get()},
on_device_shape{py_buffer->on_device_shape()}]() {
PjRtStreamExecutorBuffer::ScopedHold device_buffer(movable_device_buffer);
// This function uses TF_CHECK_OK and ValueOrDie() since we have no way
// to report failures from a callback. However, the operations here are
// unlikely to fail, and would not be recoverable even if they did: DMAs to
// memory that has already been allocated, and a possible Event
// allocation.
se::Stream* h2d_stream = local_device->host_to_device_stream();
ShapedBuffer buffer = device_buffer->AsShapedBuffer(on_device_shape);
TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync(
h2d_stream, literal, buffer));
std::shared_ptr<BufferSequencingEvent> event =
device_buffer->definition_events()[0];
TF_CHECK_OK(AddDestinationBufferSynchronization(
local_device, std::move(device_buffer), event, h2d_stream));
// This can sometimes catch the case where the literal memory has been
// freed before the H2D transfer was issued.
h2d_stream->RefreshStatus()
.IgnoreError(); // Can return error::Unimplemented
QCHECK(h2d_stream->ok());
};
thread_pool()->Schedule(transfer_h2d);
return std::unique_ptr<PjRtBuffer>(std::move(py_buffer));
}
void PjRtStreamExecutorClient::MakeCrossHostReceiveBuffers(
absl::Span<const Shape> shapes, PjRtDevice* device,
PjRtCrossHostRecvNotifier&& notifier) {
if (shapes.empty()) {
notifier(InvalidArgument(
"shapes parameter empty in MakeCrossHostReceiveBuffers"));
return;
}
auto local_device_or =
tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
->GetLocalDeviceState();
if (!local_device_or.ok()) {
notifier(local_device_or.status());
return;
}
LocalDeviceState* local_device = local_device_or.ConsumeValueOrDie();
std::shared_ptr<BufferSequencingEvent> definition_event =
std::make_shared<BufferSequencingEvent>();
std::vector<std::unique_ptr<PjRtBuffer>> buffers;
buffers.reserve(shapes.size());
for (const auto& shape : shapes) {
StatusOr<std::unique_ptr<PjRtBuffer>> buffer_or = AllocateDestinationBuffer(
shape, device, local_device,
/*copy_stream=*/nullptr,
/*is_uninitialized_create=*/false, this, definition_event);
if (!buffer_or.ok()) {
notifier(buffer_or.status());
return;
}
buffers.push_back(buffer_or.ConsumeValueOrDie());
}
EnqueueCrossHostReceive(std::move(buffers), std::move(definition_event),
std::move(notifier));
}
StatusOr<std::unique_ptr<PjRtBuffer>>
PjRtStreamExecutorClient::CreateViewOfDeviceBuffer(
void* device_ptr, const Shape& shape, PjRtDevice* device,
std::function<void()> on_delete_callback) {
se::DeviceMemoryBase buffer(device_ptr, ShapeUtil::ByteSizeOf(shape));
absl::Span<const std::shared_ptr<BufferSequencingEvent>> definition_events;
auto device_buffer = std::make_shared<TrackedDeviceBuffer>(
/*allocator=*/nullptr, device->local_hardware_id(),
std::initializer_list<se::DeviceMemoryBase>{buffer}, definition_events,
std::move(on_delete_callback));
return std::unique_ptr<PjRtBuffer>(std::make_unique<PjRtStreamExecutorBuffer>(
shape, std::move(device_buffer), this, device));
}
// Transfer the given literal to the infeed queue of the given local device.
Status PjRtStreamExecutorDevice::TransferToInfeed(const LiteralSlice& literal) {
// Only support infeed to local device.
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, GetLocalDeviceState());
return local_device->client()->TransferToInfeedLocal(
literal, local_device->device_ordinal());
}
Status PjRtStreamExecutorDevice::TransferFromOutfeed(
MutableBorrowingLiteral literal) {
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, GetLocalDeviceState());
return local_device->client()->TransferFromOutfeedLocal(
local_device->device_ordinal(), literal);
}
StatusOr<PjRtDevice*> PjRtStreamExecutorClient::LookupAddressableDevice(
int local_hardware_id) const {
for (auto* device : addressable_devices_) {
if (local_hardware_id == device->local_hardware_id()) {
return device;
}
}
return InvalidArgument("No matching device found for local_hardware_id %d",
local_hardware_id);
}
PjRtStreamExecutorBuffer::PjRtStreamExecutorBuffer(
Shape on_device_shape, std::shared_ptr<TrackedDeviceBuffer> device_buffer,
PjRtClient* client, PjRtDevice* device)
: client_(tensorflow::down_cast<PjRtStreamExecutorClient*>(client)),
on_device_shape_(std::move(on_device_shape)),
device_(tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)),
device_buffer_(std::move(device_buffer)) {
for (int i = 0; i < ScopedHold::Type::kMaxValue; ++i) {
holds_[i] = 0;
}
}
PjRtStreamExecutorBuffer::~PjRtStreamExecutorBuffer() {
Delete();
for (int i = 0; i < ScopedHold::Type::kMaxValue; ++i) {
CHECK_EQ(holds_[i], 0);
}
}
int64 PjRtStreamExecutorBuffer::OnDeviceSizeInBytes() const {
return client_->client()
->backend()
.transfer_manager()
->GetByteSizeRequirement(on_device_shape_);
}
void PjRtStreamExecutorBuffer::WaitForOutstandingUsageHolds() {
auto not_in_usage_hold = [&]() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
return holds_[ScopedHold::kUsage] == 0;
};
mu_.Await(absl::Condition(&not_in_usage_hold));
}
void PjRtStreamExecutorBuffer::WaitForOutstandingDonationHold() {
auto not_in_donation_hold = [&]() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
return holds_[ScopedHold::kDonation] == 0;
};
mu_.Await(absl::Condition(&not_in_donation_hold));
}
StatusOr<std::shared_ptr<TrackedDeviceBuffer>>
PjRtStreamExecutorBuffer::Release(bool wait_for_operations_to_complete) {
tensorflow::profiler::TraceMe trace_me("PjRtStreamExecutorBuffer::Release");
std::shared_ptr<TrackedDeviceBuffer> device_buffer;
TrackedDeviceBuffer::StreamAndEventContainer events;
{
absl::MutexLock lock(&mu_);
// We first wait for a donation hold to complete if there is one in
// progress. If the donation succeeds via ConfirmDonation() then it will
// set device_buffer_ to nullptr before returning to this thread.
WaitForOutstandingDonationHold();
if (device_buffer_ == nullptr) {
return std::shared_ptr<TrackedDeviceBuffer>();
}
// Set device_buffer_ to null now so that no other
// thread can add a hold while we are in WaitForOutstandingUsageHolds()
// below.
std::swap(device_buffer_, device_buffer);
WaitForOutstandingUsageHolds();
// Now that all holds have completed and no more can be added, we can get
// the final set of usage events.
events = device_buffer->LockUseAndTransferUsageEvents();
}
LocalDeviceState* local_device_state = device_->local_device_state();
if (wait_for_operations_to_complete) {
// Block the host until all usage events have completed. Usage events
// dominate definition events, so this also waits for the buffer to be
// defined.
std::unique_ptr<se::Stream> stream;
for (const auto& stream_and_event : events) {
if (!stream_and_event.event->IsComplete()) {
if (stream == nullptr) {
stream = local_device_state->BorrowStreamFromPool();
}
stream_and_event.event->WaitForEventOnStream(stream.get());
}
}
if (stream != nullptr) {
TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
local_device_state->ReturnStreamToPool(std::move(stream));
}
} else {
if (local_device_state->allocation_model() ==
LocalDeviceState::kComputeSynchronized) {
std::unique_ptr<se::Stream> block_stream;
for (const auto& stream_and_event : events) {
// We only need to do something for events that didn't already acquire a
// reference to the buffer, and also which the compute stream didn't
// already wait for. Based on our heuristics this rare case should only
// occur when a buffer was copied to a device and then never used there.
// In that case we get a new stream and use it to hold onto a reference
// to the buffer until the events are complete.
if (!stream_and_event.reference_held &&
!stream_and_event.event->DefinedOn(
local_device_state->compute_stream()) &&
!stream_and_event.event->IsComplete()) {
if (block_stream == nullptr) {
block_stream = local_device_state->BorrowStreamFromPool();
}
stream_and_event.event->WaitForEventOnStream(block_stream.get());
}
}
if (block_stream != nullptr) {
se::Stream* block_stream_ptr = block_stream.release();
local_device_state->ThenExecuteOnCallbackThread(
block_stream_ptr,
[device_buffer, block_stream_ptr, local_device_state]() {
local_device_state->ReturnStreamToPool(
std::unique_ptr<se::Stream>(block_stream_ptr));
});
}
}
}
return device_buffer;
}
void PjRtStreamExecutorBuffer::Delete() {
// When wait_for_operations_to_complete is false, Release should never fail.
TF_CHECK_OK(Release(/*wait_for_operations_to_complete=*/false).status());
}
bool PjRtStreamExecutorBuffer::IsDeleted() {
absl::MutexLock lock(&mu_);
return device_buffer_ == nullptr;
}
StatusOr<std::shared_ptr<TrackedDeviceBuffer>>
PjRtStreamExecutorBuffer::GetBufferForHoldLocked(ScopedHold::Type type) {
// All callers should have called WaitForOutstandingDonationHold().
CHECK_EQ(holds_[ScopedHold::kDonation], 0);
if (type == ScopedHold::kDonation) {
if (device_buffer_ == nullptr) {
return InvalidArgument("Donation requested for invalid buffer");
}
if (holds_[ScopedHold::kExternalReference] > 0) {
return InvalidArgument(
"Donation requested for buffer with external reference");
}
// First add the donation hold.
++holds_[type];
// Then wait for any usage holds to be dropped or converted. No new usage
// holds can be added until we drop the donation hold so this wait will
// complete eventually.
WaitForOutstandingUsageHolds();
// Because we added a donation hold, nobody could release the buffer while
// we were waiting.
CHECK(device_buffer_ != nullptr);
} else {
if (device_buffer_ == nullptr) {
return InvalidArgument("Hold requested on deleted or donated buffer");
} else {
++holds_[type];
}
}
return device_buffer_;
}
void PjRtStreamExecutorBuffer::AcquireHoldLocked(ScopedHold* hold) {
hold->Acquire(GetBufferForHoldLocked(hold->type()));
}
void PjRtStreamExecutorBuffer::ConvertUsageHold(
TrackedDeviceBuffer* buffer, se::Stream* usage_stream,
std::shared_ptr<BufferSequencingEvent> event, bool reference_held) {
absl::MutexLock lock(&mu_);
CHECK(device_buffer_.get() == buffer || device_buffer_ == nullptr);
buffer->AddUsageEvent(usage_stream, std::move(event), reference_held);
CHECK_GT(holds_[ScopedHold::kUsage], 0);
--holds_[ScopedHold::kUsage];
}
void PjRtStreamExecutorBuffer::ConfirmDonation(
TrackedDeviceBuffer* device_buffer) {
{
absl::MutexLock lock(&mu_);
CHECK_EQ(holds_[ScopedHold::kUsage], 0);
CHECK_EQ(holds_[ScopedHold::kExternalReference], 0);
CHECK_EQ(holds_[ScopedHold::kDonation], 1);
holds_[ScopedHold::kDonation] = 0;
CHECK(device_buffer_.get() == device_buffer);
// As a sanity check ensure no more usage events can be added to the buffer.
device_buffer->LockUseAndTransferUsageEvents();
// Give up ownership of the device memory so we don't free it when the last
// reference to device_buffer_ goes away.
device_buffer->ReleaseDeviceMemory();
// Make *this invalid so it can't be used again. Any threads blocking in
// Release or GetBufferWithHold will see an invalid buffer and return.
device_buffer_.reset();
}
}
void PjRtStreamExecutorBuffer::DropHold(ScopedHold::Type type,
TrackedDeviceBuffer* buffer) {
absl::MutexLock lock(&mu_);
CHECK(device_buffer_.get() == buffer || device_buffer_ == nullptr);
CHECK_GT(holds_[type], 0);
--holds_[type];
if (type == ScopedHold::kDonation) {
CHECK_EQ(holds_[ScopedHold::kDonation], 0);
CHECK_EQ(holds_[ScopedHold::kUsage], 0);
CHECK_EQ(holds_[ScopedHold::kExternalReference], 0);
}
}
void PjRtStreamExecutorBuffer::ToLiteral(MutableLiteralBase* literal,
std::function<void(Status)> on_ready) {
if (IsEmptyTuple()) {
on_ready(InvalidArgument("ToLiteral called on empty tuple"));
return;
}
LocalDeviceState* local_device = device_->local_device_state();
se::Stream* stream = local_device->GetDeviceToHostStream();
ScopedHold device_buffer(this, ScopedHold::kUsage);
{
absl::MutexLock lock(&mu_);
// We can't perform any other action while a donation hold is in progress.
WaitForOutstandingDonationHold();
if (device_buffer_ == nullptr) {
on_ready(InvalidArgument(
"ToLiteral() called on deleted or donated buffer"));
return;
}
AcquireHoldLocked(&device_buffer);
}
WaitForBufferDefinitionEventsOnStream(*device_buffer, stream);
ShapedBuffer shaped_buffer = device_buffer->AsShapedBuffer(on_device_shape_);
StatusOr<EventPool::Handle> event_or =
local_device->event_pool().AllocateEvent(stream->parent());
if (!event_or.ok()) {
on_ready(event_or.status());
return;
}
client_->client()->backend().transfer_manager()->TransferLiteralFromDevice(
stream, shaped_buffer, literal, std::move(on_ready));
auto usage_event = std::make_shared<BufferSequencingEvent>();
local_device->event_pool().ThenRecordEvent(stream, event_or.ValueOrDie());
usage_event->SetSequencingEvent(event_or.ConsumeValueOrDie(), stream);
// When using the ComputeSynchronized allocation model, retain a reference to
// the device_buffer until the copy completes, to ensure that the buffer isn't
// deleted or donated while it is still in use. The choice of retaining a
// reference at the host is a heuristic; the alternative is to ensure, before
// freeing the buffer, that the compute stream is synchronized past the
// transfer, but it seems better to hold onto the buffer too long than to
// stall the compute stream, particularly since the overwhelmingly common
// use case of CopyToHostAsync will hold onto the reference long enough to
// read the buffer in a subsequent call to ToLiteral.
RecordUsage(std::move(device_buffer), local_device, local_device, usage_event,
stream,
/*prefer_to_retain_reference=*/true);
}
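// A sketch of how a caller might drive the asynchronous ToLiteral() above
// (illustration only; `host_shape` is whatever host-side shape the caller has
// chosen for the literal). Capturing the shared_ptr in the callback keeps the
// literal alive until the device-to-host transfer finishes:
//
//   auto literal = std::make_shared<Literal>(host_shape);
//   buffer->ToLiteral(literal.get(),
//                     [literal](Status s) { TF_CHECK_OK(s); /* ready */ });
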
StatusOr<ShapedBuffer> PjRtStreamExecutorBuffer::AsShapedBuffer() const {
absl::MutexLock lock(&mu_);
if (device_buffer_ == nullptr) {
return InvalidArgument(
"Attempted to fetch value of invalid/deleted buffer.");
}
return device_buffer_->AsShapedBuffer(on_device_shape_);
}
PjRtStreamExecutorBuffer::ScopedHold
PjRtStreamExecutorBuffer::GetBufferWithHold(ScopedHold::Type type) {
absl::MutexLock lock(&mu_);
// Ensure that at most one donation hold can be in progress at a time.
WaitForOutstandingDonationHold();
ScopedHold hold(this, type);
AcquireHoldLocked(&hold);
return hold;
}
StatusOr<std::pair<std::unique_ptr<PjRtBuffer>,
std::shared_ptr<BufferSequencingEvent>>>
PjRtStreamExecutorBuffer::CopyToDeviceHelper(
PjRtDevice* dst_device, LocalDeviceState* dst_local_device,
LocalDeviceState* transfer_local_device, se::Stream* transfer_stream,
std::shared_ptr<TrackedDeviceBuffer> src_device_buffer) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtStreamExecutorBuffer> py_buffer,
AllocateDestinationBuffer(
ShapeUtil::DeviceShapeToHostShape(on_device_shape_),
dst_device, dst_local_device, transfer_stream,
/*is_uninitialized_create=*/false, client_));
TF_ASSIGN_OR_RETURN(ShapedBuffer src_buffer, AsShapedBuffer());
WaitForBufferDefinitionEventsOnStream(*src_device_buffer, transfer_stream);
ScopedHold dst_device_buffer(py_buffer->GetBufferWithUsageHold());
CHECK(dst_device_buffer.ok());
ShapedBuffer dst_buffer = dst_device_buffer->AsShapedBuffer(on_device_shape_);
// Copy the leaf buffers.
StatusOr<std::shared_ptr<BufferSequencingEvent>> copy_event_or =
[&]() -> StatusOr<std::shared_ptr<BufferSequencingEvent>> {
for (const auto& leaf : src_buffer.buffers().leaves()) {
const ShapeIndex& index = leaf.first;
const se::DeviceMemoryBase& input_buffer = leaf.second;
const se::DeviceMemoryBase& output_buffer = dst_buffer.buffer(index);
TF_RET_CHECK(input_buffer.size() == output_buffer.size())
<< "input: " << input_buffer.size()
<< " output: " << output_buffer.size();
if (input_buffer.size() != 0) {
TF_RETURN_IF_ERROR(transfer_local_device->ThenMemcpyDeviceToDevice(
transfer_stream, dst_local_device->compute_stream(), input_buffer,
output_buffer));
}
}
std::shared_ptr<BufferSequencingEvent> event =
dst_device_buffer->definition_events()[0];
TF_RETURN_IF_ERROR(AddDestinationBufferSynchronization(
transfer_local_device, std::move(dst_device_buffer), event,
transfer_stream));
return event;
}();
if (!copy_event_or.ok()) {
StallStreamOnError(transfer_local_device, transfer_stream);
if (transfer_local_device == dst_local_device) {
// Some copies may have been enqueued before the error was returned, and
// StallStreamOnError only makes sure the destination device is ok, so
// make sure that the src buffer remains valid until after any transfers
// have completed.
device_->local_device_state()->ThenRelease(transfer_stream,
std::move(src_device_buffer));
}
return copy_event_or.status();
}
return std::pair<std::unique_ptr<PjRtBuffer>,
std::shared_ptr<BufferSequencingEvent>>(
std::unique_ptr<PjRtStreamExecutorBuffer>(std::move(py_buffer)),
copy_event_or.ConsumeValueOrDie());
}
StatusOr<std::unique_ptr<PjRtBuffer>> PjRtStreamExecutorBuffer::CopyToDevice(
PjRtDevice* dst_device) {
tensorflow::profiler::TraceMe traceme(
"PjRtStreamExecutorBuffer::CopyToDevice");
if (dst_device == device_) {
return InvalidArgument(
"CopyToDevice cannot accept the same source and destination devices");
}
// Copying across PjRtClients involves a copy through the host.
if (dst_device->client() != client_) {
TF_ASSIGN_OR_RETURN(std::shared_ptr<Literal> literal, ToLiteral());
// Avoid use-after-free on `literal` due to unsequenced move and use.
Literal* literal_pointer = literal.get();
return dst_device->client()->BufferFromHostBuffer(
literal_pointer->untyped_data(), literal_pointer->shape(),
PjRtStreamExecutorClient::HostBufferSemantics::kZeroCopy,
[literal{std::move(literal)}]() { /* frees literal */ }, dst_device);
}
TF_ASSIGN_OR_RETURN(
LocalDeviceState * dst_local_device,
tensorflow::down_cast<PjRtStreamExecutorDevice*>(dst_device)
->GetLocalDeviceState());
LocalDeviceState* transfer_local_device =
client_->EnqueueD2DTransfersOnSrcStream() ? device_->local_device_state()
: dst_local_device;
CHECK_EQ(dst_local_device->allocation_model(),
transfer_local_device->allocation_model());
se::Stream* transfer_stream =
transfer_local_device->GetDeviceToDeviceStream();
ScopedHold src_device_buffer(this, ScopedHold::kUsage);
{
absl::MutexLock lock(&mu_);
// We can't perform any other action while a donation hold is in progress.
WaitForOutstandingDonationHold();
if (device_buffer_ == nullptr) {
return InvalidArgument(
"CopyToDevice called on deleted or donated buffer");
}
AcquireHoldLocked(&src_device_buffer);
}
StatusOr<std::pair<std::unique_ptr<PjRtBuffer>,
std::shared_ptr<BufferSequencingEvent>>>
buffer_and_event_or = CopyToDeviceHelper(
dst_device, dst_local_device, transfer_local_device, transfer_stream,
src_device_buffer.buffer());
if (!buffer_and_event_or.ok()) {
return buffer_and_event_or.status();
}
auto& buffer_and_event = buffer_and_event_or.ValueOrDie();
std::unique_ptr<PjRtBuffer>& buffer = buffer_and_event.first;
std::shared_ptr<BufferSequencingEvent>& event = buffer_and_event.second;
// prefer_to_retain_reference=true means that, when using the
// ComputeSynchronized allocation model, retain a reference to the
// src_device_buffer until the copy completes. This is a heuristic; the
// alternative is to ensure, before freeing the buffer, that the compute
// stream is synchronized past the transfer, but it seems better to hold onto
// the buffer too long than to stall the compute stream.
RecordUsage(std::move(src_device_buffer), device_->local_device_state(),
transfer_local_device, event, transfer_stream,
/*prefer_to_retain_reference=*/true);
return std::move(buffer);
}
Status PjRtStreamExecutorBuffer::CopyToRemoteDevice(
absl::string_view serialized_descriptor) {
return client_->CopyToRemoteDevice(this, serialized_descriptor);
}
Status PjRtStreamExecutorBuffer::BlockHostUntilReady() {
tensorflow::profiler::TraceMe traceme(
"PjRtStreamExecutorBuffer::BlockHostUntilReady");
std::shared_ptr<TrackedDeviceBuffer> device_buffer;
{
absl::MutexLock lock(&mu_);
if (device_buffer_ == nullptr) {
return InvalidArgument(
"BlockHostUntilReady() called on deleted or donated buffer");
}
device_buffer = device_buffer_;
}
LocalDeviceState* local_device_state = device_->local_device_state();
std::unique_ptr<se::Stream> stream;
for (auto& event : device_buffer->definition_events()) {
if (!event->IsComplete()) {
if (stream == nullptr) {
stream = local_device_state->BorrowStreamFromPool();
}
event->WaitForEventOnStream(stream.get());
}
}
if (stream != nullptr) {
TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
local_device_state->ReturnStreamToPool(std::move(stream));
}
return Status::OK();
}
namespace {
// Helper struct for the tuple that is transiently constructed to hold the
// arguments of an execution.
struct TupleHandle {
// The ExecutionInput describing the tuple.
ExecutionInput execution_input;
// A definition event that has been recorded on the host_to_device stream
// after the tuple table transfer.
std::shared_ptr<BufferSequencingEvent> event;
};
Status CheckCompatibleShapes(bool strict_shape_checking,
const Shape& buffer_shape,
const Shape& execution_shape,
const TransferManager& transfer_manager,
int parameter_index) {
// TODO(misard) Support casting of tuple parameters.
if (strict_shape_checking || buffer_shape.IsTuple()) {
if (!ShapeUtil::Equal(buffer_shape, execution_shape)) {
return InvalidArgument(
"Executable expected shape %s for argument %d but got "
"incompatible "
"shape %s",
ShapeUtil::HumanStringWithLayout(execution_shape), parameter_index,
ShapeUtil::HumanStringWithLayout(buffer_shape));
}
} else {
if (transfer_manager.GetByteSizeRequirement(buffer_shape) !=
transfer_manager.GetByteSizeRequirement(execution_shape)) {
return InvalidArgument(
"Executable expected shape %s for argument %d but got "
"incompatible "
"shape %s",
ShapeUtil::HumanStringWithLayout(execution_shape), parameter_index,
ShapeUtil::HumanStringWithLayout(buffer_shape));
}
}
return Status::OK();
}
// Makes a tuple from the arguments to an execution.
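// Allocates the root tuple table on device_ordinal, points the ExecutionInput
// at the argument buffers, writes the tuple index table on the host-to-device
// stream, and records a BufferSequencingEvent on that stream after the tuple
// table transfer.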
StatusOr<TupleHandle> MakeTupleHelper(
PjRtStreamExecutorClient* client, LocalDeviceState* local_device,
bool strict_shape_checking, const Shape& tupled_parameter_shape,
absl::Span<PjRtBuffer* const> py_buffers,
absl::Span<const PjRtStreamExecutorBuffer::ScopedHold> device_buffers,
int device_ordinal) {
se::DeviceMemoryAllocator* allocator = client->allocator();
TransferManager* transfer_manager =
client->client()->backend().transfer_manager();
if (tupled_parameter_shape.tuple_shapes_size() != py_buffers.size()) {
return InvalidArgument("Executable expected %lld parameters but got %lld",
tupled_parameter_shape.tuple_shapes_size(),
py_buffers.size());
}
for (int i = 0; i < py_buffers.size(); ++i) {
TF_RETURN_IF_ERROR(CheckCompatibleShapes(
strict_shape_checking, py_buffers[i]->on_device_shape(),
tupled_parameter_shape.tuple_shapes(i), *transfer_manager, i));
}
se::Stream* stream = local_device->host_to_device_stream();
TF_ASSIGN_OR_RETURN(
se::OwningDeviceMemory root_table_memory,
allocator->Allocate(
device_ordinal,
transfer_manager->GetByteSizeRequirement(tupled_parameter_shape)));
if (local_device->allocation_model() ==
LocalDeviceState::kComputeSynchronized) {
stream->ThenWaitFor(local_device->compute_stream());
} else {
DCHECK(transfer_manager->CanBufferBeAccessedNow(
local_device->compute_stream()->parent(), root_table_memory.cref()));
}
ExecutionInput execution_input(tupled_parameter_shape);
ShapeTree<MaybeOwningDeviceMemory>::iterator input_iterator =
execution_input.MutableBuffers()->begin();
ShapeTree<MaybeOwningDeviceMemory>::iterator iterator_end =
execution_input.MutableBuffers()->end();
  // First set the root tuple table, which is the first buffer in the
  // ShapeTree.
execution_input.SetBuffer(
input_iterator->first,
MaybeOwningDeviceMemory(std::move(root_table_memory)));
++input_iterator;
// Then set each sub-tuple in turn from the parameters.
for (const PjRtStreamExecutorBuffer::ScopedHold& device_buffer :
device_buffers) {
device_buffer.AddToInput(&input_iterator, iterator_end, &execution_input,
allocator);
}
CHECK(input_iterator == iterator_end);
TF_RETURN_IF_ERROR(transfer_manager->WriteRootTupleIndexTable(
stream, execution_input.Buffers()));
StatusOr<EventPool::Handle> event_or =
local_device->event_pool().ThenAllocateAndRecordEvent(stream);
if (!event_or.ok()) {
StallStreamOnError(local_device, stream);
return event_or.status();
}
auto transfer_event = std::make_shared<BufferSequencingEvent>();
transfer_event->SetSequencingEvent(event_or.ConsumeValueOrDie(), stream);
return TupleHandle({std::move(execution_input), std::move(transfer_event)});
}
// Converts a ScopedShapedBuffer returned from an execution into a
// PjRtBuffer.
std::unique_ptr<PjRtBuffer> OutputBufferHelper(
ScopedShapedBuffer* result_buffer,
std::shared_ptr<BufferSequencingEvent> definition_event, PjRtClient* client,
PjRtDevice* device, LocalDeviceState* local_device,
std::vector<std::shared_ptr<TrackedDeviceBuffer>>& buffers_to_release) {
std::shared_ptr<TrackedDeviceBuffer> out_buffer =
TrackedDeviceBuffer::FromScopedShapedBuffer(result_buffer,
{definition_event});
auto pjrt_buffer = absl::make_unique<PjRtStreamExecutorBuffer>(
result_buffer->on_device_shape(), std::move(out_buffer), client, device);
RecordUsage(pjrt_buffer->GetBufferWithUsageHold(), local_device, local_device,
definition_event, local_device->compute_stream(),
/*prefer_to_retain_reference=*/false, &buffers_to_release);
return std::unique_ptr<PjRtBuffer>(std::move(pjrt_buffer));
}
} // namespace
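// Caches, for each LocalExecutable, the on-device shapes of its entry
// computation parameters, and sanity-checks the device assignment against the
// addressable devices (or the portable single-core case, where there is no
// assignment).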
PjRtStreamExecutorExecutable::PjRtStreamExecutorExecutable(
std::vector<std::unique_ptr<LocalExecutable>> executables,
bool parameter_is_tupled_arguments,
std::shared_ptr<DeviceAssignment> device_assignment,
std::vector<LogicalDeviceIds> addressable_device_logical_ids,
std::vector<PjRtDevice*> addressable_devices,
PjRtStreamExecutorClient* client)
: client_(client),
device_assignment_(std::move(device_assignment)),
parameter_is_tupled_arguments_(parameter_is_tupled_arguments),
addressable_device_logical_ids_(
std::move(addressable_device_logical_ids)),
addressable_devices_(std::move(addressable_devices)) {
TransferManager* transfer_manager =
client_->client()->backend().transfer_manager();
executables_.reserve(executables.size());
for (auto& executable : executables) {
const auto& computation_layout =
executable->executable()->module().entry_computation_layout();
std::vector<Shape> parameter_shapes;
parameter_shapes.reserve(computation_layout.parameter_count());
for (int i = 0; i < computation_layout.parameter_count(); ++i) {
parameter_shapes.push_back(transfer_manager->HostShapeToDeviceShape(
computation_layout.parameter_shape(i)));
}
executables_.emplace_back(std::move(executable));
on_device_executable_parameter_shapes_.push_back(
std::move(parameter_shapes));
}
int num_partitions;
if (device_assignment_ == nullptr) {
// This must go after `executables_` is initialized.
VLOG(1) << "PjRtStreamExecutorExecutable portable single-core";
num_partitions = 1;
CHECK(addressable_devices_.empty());
} else {
// This must go after `executables_` is initialized.
VLOG(1) << "PjRtStreamExecutorExecutable device_assignment:\n"
<< device_assignment_->ToString();
CHECK_GE(addressable_devices_.size(), 1) << device_assignment_->ToString();
CHECK_LE(addressable_devices_.size(), client_->addressable_device_count())
<< "Inconsistent local device count.";
num_partitions = device_assignment_->computation_count();
}
// SPMD sharding produces a single executable for multiple partitions.
if (executables_.size() > 1) {
CHECK_EQ(num_partitions, executables_.size())
<< "Number of executables " << executables_.size()
<< " did not match number of partitions " << num_partitions;
}
}
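// Populates parameters_that_must_be_donated_ by asking each executable's HLO
// module which parameters must be donated, taking into account whether the
// inputs are tupled.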
Status PjRtStreamExecutorExecutable::SetUpDonation(bool tuple_inputs) {
parameters_that_must_be_donated_.reserve(executables_.size());
for (auto& executable : executables_) {
TF_ASSIGN_OR_RETURN(absl::flat_hash_set<int> parameters_to_donate,
GetParametersThatMustBeDonated(
executable->executable()->module(), tuple_inputs));
parameters_that_must_be_donated_.emplace_back(
std::move(parameters_to_donate));
}
return Status::OK();
}
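// Returns the name of the first executable's HLO module, or a placeholder
// string if the executable carries no module.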
absl::string_view PjRtStreamExecutorExecutable::name() const {
Executable* executable = executables_[0]->executable();
if (executable->has_module()) {
return executable->module().name();
} else {
return "<unknown executable>";
}
}
bool PjRtStreamExecutorExecutable::MustDonateParameter(int executable_idx,
int parameter) const {
return parameters_that_must_be_donated_[executable_idx].contains(parameter);
}
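// Builds the ExecutionInputs for a single execution. If the executable takes
// a single tupled parameter and the caller passed untupled arguments, a
// transient tuple is constructed via MakeTupleHelper; otherwise the argument
// count is checked and each argument buffer is converted individually after
// validating it against the corresponding parameter shape. Finally the
// compute stream is made to wait on every collected event (including the
// tuple-table transfer event, if any).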
StatusOr<std::vector<ExecutionInput>>
PjRtStreamExecutorExecutable::MakeExecutionInputsAndWaitForEvents(
int device_ordinal, const ExecuteOptions& options,
absl::Span<const Shape> executable_parameter_shapes,
absl::Span<PjRtBuffer* const> argument_handles,
absl::Span<const PjRtStreamExecutorBuffer::ScopedHold> device_buffers,
absl::flat_hash_set<BufferSequencingEvent*>& events) const {
std::vector<ExecutionInput> execution_inputs;
LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
TransferManager* transfer_manager =
client_->client()->backend().transfer_manager();
// Lift tuple_handle outside the conditional so that the event it returns is
// not destroyed until after the loop below that waits on events.
absl::optional<TupleHandle> tuple_handle;
if (parameter_is_tupled_arguments_ && !options.arguments_are_tupled) {
TF_ASSIGN_OR_RETURN(
tuple_handle,
MakeTupleHelper(client_, device_state, options.strict_shape_checking,
executable_parameter_shapes[0], argument_handles,
device_buffers, device_ordinal));
events.insert(tuple_handle->event.get());
execution_inputs.emplace_back(std::move(tuple_handle->execution_input));
} else {
if (argument_handles.size() != executable_parameter_shapes.size()) {
return InvalidArgument("Executable expected %lld arguments but got %lld",
executable_parameter_shapes.size(),
argument_handles.size());
}
execution_inputs.reserve(argument_handles.size());
for (int i = 0; i < argument_handles.size(); ++i) {
PjRtBuffer* handle = argument_handles[i];
// Make an ExecutionInput from the device buffer.
TF_RETURN_IF_ERROR(CheckCompatibleShapes(
options.strict_shape_checking, handle->on_device_shape(),
executable_parameter_shapes[i], *transfer_manager, i));
execution_inputs.emplace_back(executable_parameter_shapes[i]);
ExecutionInput& execution_input = execution_inputs.back();
ShapeTree<MaybeOwningDeviceMemory>::iterator input_iterator =
execution_input.MutableBuffers()->begin();
ShapeTree<MaybeOwningDeviceMemory>::iterator iterator_end =
execution_input.MutableBuffers()->end();
device_buffers[i].AddToInput(&input_iterator, iterator_end,
&execution_input, client_->allocator());
CHECK(input_iterator == iterator_end);
}
}
for (BufferSequencingEvent* event : events) {
event->WaitForEventOnStream(device_state->compute_stream());
}
return execution_inputs;
}
// Enqueues a computation onto the compute stream. Each buffer returned in
// device_buffers has a usage or donation hold added that must be dropped on
// error or converted on success.
StatusOr<ScopedShapedBuffer> PjRtStreamExecutorExecutable::EnqueueExecution(
absl::Span<PjRtBuffer* const> argument_handles, int replica, int partition,
int executable_idx, const RunId& run_id, const ExecuteOptions& options,
PjRtDevice* device,
std::vector<PjRtStreamExecutorBuffer::ScopedHold>* device_buffers,
std::shared_ptr<DeviceAssignment> device_assignment,
std::vector<std::function<void()>>& compute_callbacks) const {
int device_ordinal = tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
->local_device_state()
->device_ordinal();
LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
tensorflow::profiler::TraceMeConsumer activity(
"PjRtStreamExecutorExecutable::EnqueueExecution",
tensorflow::profiler::ContextType::kPjRt, run_id.ToInt());
VLOG(3) << "Replica " << replica << ", partition " << partition
<< " mapped to device ordinal for execution: " << device_ordinal;
absl::flat_hash_set<BufferSequencingEvent*> events;
device_buffers->reserve(argument_handles.size());
for (int i = 0; i < argument_handles.size(); ++i) {
auto* handle =
tensorflow::down_cast<PjRtStreamExecutorBuffer*>(argument_handles[i]);
if (handle->device() != device) {
return InvalidArgument(
"Buffer passed to Execute() as argument %d to replica %d is on "
"device %s, but replica is assigned to device %s.",
i, replica, handle->device()->DebugString(), device->DebugString());
}
bool must_donate = MustDonateParameter(executable_idx, i);
device_buffers->emplace_back(handle->GetBufferWithHold(
must_donate ? PjRtStreamExecutorBuffer::ScopedHold::kDonation
: PjRtStreamExecutorBuffer::ScopedHold::kUsage));
PjRtStreamExecutorBuffer::ScopedHold& device_buffer =
device_buffers->back();
if (!device_buffer.ok()) {
return InvalidArgument(
"Invalid buffer passed to Execute() as argument %d to replica %d: "
"%s",
i, replica, device_buffer.status().ToString());
}
    // If we are trying to donate the buffer, wait on the usage events as well
    // as the definition events to ensure that all reads have completed before
    // the buffer is mutated. Usage holds are excluded during a donation hold,
    // so we know that the set of usage events won't be modified while we are
    // enqueueing.
GetDeviceBufferEvents(*device_buffer, /*get_usage_events=*/must_donate,
&events);
}
if (options.arguments_are_tupled) {
if (!parameter_is_tupled_arguments_) {
return InvalidArgument(
"Arguments may only be supplied as a tuple when the executable was "
"compiled with a single tupled parameter");
}
if (argument_handles.size() != 1) {
return InvalidArgument(
"Option arguments_are_tupled was true but %d buffers were passed to "
"execution",
argument_handles.size());
}
}
TF_ASSIGN_OR_RETURN(
std::vector<ExecutionInput> execution_inputs,
MakeExecutionInputsAndWaitForEvents(
device_ordinal, options,
on_device_executable_parameter_shapes_[executable_idx],
argument_handles, *device_buffers, events));
ExecutableRunOptions run_options;
run_options.set_stream(device_state->compute_stream());
run_options.set_host_to_device_stream(device_state->host_to_device_stream());
run_options.set_allocator(client_->allocator());
run_options.set_intra_op_thread_pool(
client_->client()->backend().eigen_intra_op_thread_pool_device());
run_options.set_device_assignment(device_assignment.get());
run_options.set_run_id(run_id);
run_options.set_rng_seed(device_state->GetNewPrngSeed());
run_options.set_gpu_executable_run_options(client_->gpu_run_options());
run_options.set_launch_id(options.launch_id);
if (run_options.launch_id() != 0) {
VLOG(1) << "launch id for " << name() << ": " << run_options.launch_id();
}
  // The choice of where we wait is arbitrary; the wait is for pacing, to
  // avoid problems such as memory fragmentation and running ahead too far,
  // not for correctness. Placing it before the executable launch allows the
  // inputs for the next executable to be fetched even if the launch is
  // delayed.
auto compute_reservation = std::make_shared<Semaphore::ScopedReservation>(
device_state->compute_semaphore().ScopedAcquire(1));
StatusOr<ExecutionOutput> result_buffer_or_status =
executables_[executable_idx]->RunAsync(std::move(execution_inputs),
run_options);
VLOG(1) << "Replica " << replica << " partition " << partition
<< " completed; ok=" << result_buffer_or_status.ok();
if (!result_buffer_or_status.ok()) {
return result_buffer_or_status.status();
}
if (device_state->allocation_model() == LocalDeviceState::kSynchronous) {
ExecutionOutput& execution_output = result_buffer_or_status.ValueOrDie();
    // If we used a transient tuple for the arguments, we donated its root
    // table buffer. In that case, and/or if we donated any input buffers that
    // were not aliased, the donated buffers are passed back to us via the
    // execution output. We need to ensure they aren't freed until after the
    // execution completes. (Currently XLA does not support aliasing tuple
    // tables, so if any donated parameter is a tuple there will be donated
    // but unaliased buffers.)
std::vector<se::OwningDeviceMemory> donated_memory =
execution_output.ConsumeToBeReleased();
absl::InlinedVector<se::DeviceMemoryBase, 3> donated_ptrs;
donated_ptrs.reserve(donated_memory.size());
for (se::OwningDeviceMemory& owning : donated_memory) {
// Release the owning memory so we can pass it to the closure.
donated_ptrs.push_back(owning.Release());
}
compute_callbacks.push_back(
[references{std::make_tuple(executables_[executable_idx],
compute_reservation, device_assignment)},
donated_ptrs{std::move(donated_ptrs)}, allocator{client_->allocator()},
device_ordinal]() {
for (const auto& ptr : donated_ptrs) {
TF_CHECK_OK(allocator->Deallocate(device_ordinal, ptr));
}
});
} else {
// Any donated memory returned by the ExecutionOutput can be immediately
// freed.
compute_callbacks.push_back(
[to_release{std::make_tuple(executables_[executable_idx],
compute_reservation,
device_assignment)}]() {});
}
return result_buffer_or_status.ConsumeValueOrDie().ConsumeResult();
}
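// Wraps the ScopedShapedBuffer produced by an execution in PjRtBuffers. When
// untuple_result is set and the result is a tuple, each top-level tuple
// element becomes its own buffer; under the kSynchronous allocation model the
// root tuple table is then kept alive by a compute callback and freed only
// after the execution completes.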
std::vector<std::unique_ptr<PjRtBuffer>>
PjRtStreamExecutorExecutable::MakeOutputBuffers(
int device_ordinal, const ExecuteOptions& options,
ScopedShapedBuffer result_buffer,
std::shared_ptr<BufferSequencingEvent> definition_event, PjRtDevice* device,
std::vector<std::function<void()>>& compute_callbacks,
std::vector<std::shared_ptr<TrackedDeviceBuffer>>& buffers_to_release)
const {
tensorflow::profiler::TraceMe traceme("MakeOutputBuffers");
std::vector<std::unique_ptr<PjRtBuffer>> outputs;
LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
if (options.untuple_result && result_buffer.on_device_shape().IsTuple()) {
int tuple_count = result_buffer.on_device_shape().tuple_shapes_size();
outputs.reserve(tuple_count);
// Take ownership of each of the output values, leaving only the root table
// in result_buffer.
for (int i = 0; i < tuple_count; ++i) {
ScopedShapedBuffer tuple_buffer = result_buffer.TakeSubTree({i});
outputs.push_back(OutputBufferHelper(&tuple_buffer, definition_event,
client_, device, device_state,
buffers_to_release));
}
if (device_state->allocation_model() == LocalDeviceState::kSynchronous) {
// Don't release the root buffer until after execution completes.
ShapedBuffer root_buffer_holder = result_buffer.release();
se::DeviceMemoryBase root_buffer = root_buffer_holder.root_buffer();
compute_callbacks.push_back(
[root_buffer, allocator{client_->allocator()}, device_ordinal]() {
TF_CHECK_OK(allocator->Deallocate(device_ordinal, root_buffer));
});
}
} else {
outputs.push_back(OutputBufferHelper(&result_buffer, definition_event,
client_, device, device_state,
buffers_to_release));
}
return outputs;
}
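// Runs one execution on a single device. The device is resolved either from
// the stored device assignment (replica, partition) or, for portable
// executables, from the explicit `device` argument together with a
// synthesized 1x1 assignment. After enqueueing the execution this records a
// definition event on the compute stream, converts usage holds (or confirms
// donations) on the argument buffers, and defers any cleanup callbacks to the
// callback stream.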
StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
PjRtStreamExecutorExecutable::ExecuteHelper(
absl::Span<PjRtBuffer* const> argument_handles, int replica, int partition,
const RunId& run_id, const ExecuteOptions& options,
PjRtDevice* device) const {
std::shared_ptr<DeviceAssignment> device_assignment;
if (device == nullptr) {
CHECK(device_assignment_ != nullptr);
const int device_id = (*device_assignment_)(replica, partition);
TF_ASSIGN_OR_RETURN(device, client_->LookupDevice(device_id));
device_assignment = device_assignment_;
} else {
CHECK(device_assignment_ == nullptr);
CHECK_EQ(replica, 0);
CHECK_EQ(partition, 0);
CHECK(addressable_devices_.empty());
device_assignment = std::make_shared<DeviceAssignment>(1, 1);
(*device_assignment)(0, 0) = device->id();
}
CHECK_EQ(device->task_id(), client_->task_id());
int device_ordinal = tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
->local_device_state()
->device_ordinal();
tensorflow::profiler::TraceMe traceme(
"PjRtStreamExecutorExecutable::ExecuteHelper");
VLOG(3) << "Replica " << replica << ", partition " << partition
<< " mapped to device ordinal for execution: " << device_ordinal;
// SPMD sharding produces a single executable for multiple partitions.
int executable_idx = executables_.size() > 1 ? partition : 0;
std::vector<std::function<void()>> compute_callbacks;
std::vector<PjRtStreamExecutorBuffer::ScopedHold> device_buffers;
device_buffers.reserve(argument_handles.size());
StatusOr<ScopedShapedBuffer> result_buffer_or_status = EnqueueExecution(
argument_handles, replica, partition, executable_idx, run_id, options,
device, &device_buffers, std::move(device_assignment), compute_callbacks);
if (!result_buffer_or_status.ok()) {
LOG(ERROR) << "Execution of replica " << replica
<< " failed: " << result_buffer_or_status.status();
return result_buffer_or_status.status();
}
ScopedShapedBuffer result_buffer =
result_buffer_or_status.ConsumeValueOrDie();
LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
se::Stream* stream = device_state->compute_stream();
StatusOr<EventPool::Handle> event_or =
device_state->event_pool().ThenAllocateAndRecordEvent(stream);
if (!event_or.ok()) {
StallStreamOnError(device_state, stream);
for (PjRtStreamExecutorBuffer::ScopedHold& b : device_buffers) {
if (b.type() == PjRtStreamExecutorBuffer::ScopedHold::kDonation) {
        // Even though there was an error, we need to call ConfirmDonation,
        // which renders b invalid, since the computation has been enqueued
        // and b has been donated.
b.ConfirmDonation();
}
}
return event_or.status();
}
auto definition_event = std::make_shared<BufferSequencingEvent>();
definition_event->SetSequencingEvent(event_or.ConsumeValueOrDie(), stream);
std::vector<std::shared_ptr<TrackedDeviceBuffer>> buffers_to_release;
std::vector<std::unique_ptr<PjRtBuffer>> outputs = MakeOutputBuffers(
device_ordinal, options, std::move(result_buffer), definition_event,
device, compute_callbacks, buffers_to_release);
for (PjRtStreamExecutorBuffer::ScopedHold& b : device_buffers) {
    // prefer_to_retain_reference=false because, when using the
    // ComputeSynchronized allocation model, we don't need to retain a
    // reference to the device_buffer during execution: by definition the
    // compute stream is synchronized past the execution.
if (b.type() == PjRtStreamExecutorBuffer::ScopedHold::kUsage) {
RecordUsage(std::move(b), device_state, device_state, definition_event,
stream,
/*prefer_to_retain_reference=*/false, &buffers_to_release);
} else {
CHECK(b.type() == PjRtStreamExecutorBuffer::ScopedHold::kDonation);
b.ConfirmDonation();
}
}
if (!compute_callbacks.empty()) {
device_state->callback_stream()->ThenWaitFor(stream);
device_state->ThenExecuteOnCallbackThread(
device_state->callback_stream(),
[callbacks{std::move(compute_callbacks)},
buffers_to_release{std::move(buffers_to_release)}]() {
for (auto& fn : callbacks) {
fn();
}
});
}
return outputs;
}
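// Executes one replica/partition per addressable device. A typical calling
// pattern might look like the following sketch (names are illustrative, not
// part of this file):
//
//   std::vector<std::vector<PjRtBuffer*>> args = ...;  // one list per device
//   TF_ASSIGN_OR_RETURN(auto per_device_outputs,
//                       executable->Execute(args, ExecuteOptions()));
//
// With a single addressable device the computation runs on the calling
// thread; otherwise each device's execute thread runs its replica/partition
// and the caller blocks until all of them finish. If a launch fails and the
// remaining replicas do not terminate within 10 seconds, the process is
// aborted to escape a possible cross-replica deadlock (b/130629719).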
StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
PjRtStreamExecutorExecutable::Execute(
absl::Span<const std::vector<PjRtBuffer*>> argument_handles,
const ExecuteOptions& options) {
if (device_assignment_ == nullptr) {
return InvalidArgument("Execute expects a non-null device_assignment");
}
RunId run_id;
tensorflow::profiler::TraceMeProducer activity(
"PjRtStreamExecutorExecutable::Execute",
tensorflow::profiler::ContextType::kPjRt, run_id.ToInt());
const int num_addressable_devices = addressable_devices_.size();
if (argument_handles.size() != num_addressable_devices) {
return InvalidArgument(
"Attempted to execute with %d argument lists when local device "
"count is %d (total replica count: %d, partition count: %d)",
argument_handles.size(), num_addressable_devices, num_replicas(),
num_partitions());
}
VLOG(1) << "Executing computation " << name()
<< "; num_replicas=" << num_replicas()
<< " num_partitions=" << num_partitions()
<< " num_addressable_devices=" << num_addressable_devices;
std::vector<StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>> results(
num_addressable_devices);
if (num_addressable_devices == 1) {
    // Fast-path if there is only one device; run the computation on the
    // current thread.
const int replica = addressable_device_logical_ids_[0].replica;
const int partition = addressable_device_logical_ids_[0].partition;
results[0] =
ExecuteHelper(argument_handles[0], replica, partition, run_id, options);
} else {
absl::Mutex mu;
int running = num_addressable_devices;
int failed = 0;
Status first_failure_status;
for (int i = 0; i < num_addressable_devices; ++i) {
const int replica = addressable_device_logical_ids_[i].replica;
const int partition = addressable_device_logical_ids_[i].partition;
PjRtDevice* device = addressable_devices_[i];
const LocalDeviceState& device_state =
*tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
->local_device_state();
device_state.execute_thread()->Schedule([&, replica, partition, i] {
results[i] = ExecuteHelper(argument_handles[i], replica, partition,
run_id, options);
absl::MutexLock lock(&mu);
--running;
if (!results[i].ok()) {
if (failed == 0) {
first_failure_status = results[i].status();
}
++failed;
}
});
}
auto done_running_or_failed = [&]() {
mu.AssertHeld();
return running == 0 || failed > 0;
};
absl::MutexLock lock(&mu);
mu.Await(absl::Condition(&done_running_or_failed));
if (failed > 0) {
auto done_running = [&]() {
mu.AssertHeld();
return running == 0;
};
// If execution does not terminate within a reasonable amount of time,
// we may be stuck at a cross-replica barrier on-device. Terminate the
// process since that's the only way we can escape this situation at the
// moment (b/130629719).
if (!mu.AwaitWithTimeout(absl::Condition(&done_running),
absl::Seconds(10))) {
LOG(FATAL)
<< "Replicated computation launch failed, but not all replicas "
"terminated. Aborting process to work around deadlock. "
"Failure message (there may have been multiple failures, see "
"the error log for all failures): \n\n"
<< first_failure_status.error_message();
}
}
}
VLOG(1) << "Replicated execution complete.";
std::vector<std::vector<std::unique_ptr<PjRtBuffer>>> wrapped_results(
num_addressable_devices);
for (int i = 0; i < num_addressable_devices; ++i) {
const int replica = addressable_device_logical_ids_[i].replica;
const int partition = addressable_device_logical_ids_[i].partition;
auto& statusor = results[i];
if (!statusor.ok()) {
if (num_addressable_devices == 1) {
return statusor.status();
} else {
return AppendStatus(
statusor.status(),
absl::StrFormat("while running replica %d and partition %d of a "
"replicated computation (other "
"replicas may have failed as well).",
replica, partition));
}
}
wrapped_results[i] = std::move(statusor.ValueOrDie());
}
return wrapped_results;
}
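// Executes the computation on a single device that is present in the
// executable's device assignment, using that device's assigned replica and
// partition ids.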
StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
PjRtStreamExecutorExecutable::ExecuteSharded(
absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
const ExecuteOptions& options) {
if (device_assignment_ == nullptr) {
return InvalidArgument("ExecuteShard expects a non-null device_assignment");
}
for (int i = 0; i < addressable_devices_.size(); ++i) {
if (addressable_devices_[i] == device) {
VLOG(1) << "ExecuteShard executes computation " << name()
<< " on assigned replica/partition on device "
<< device->DebugString();
return ExecuteHelper(
argument_handles, addressable_device_logical_ids_[i].replica,
addressable_device_logical_ids_[i].partition, RunId(), options);
}
}
  return InvalidArgument(
      "ExecuteSharded attempted to execute on device id %d which is not "
      "addressable by this client",
      device->id());
}
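// Executes a portable executable (single replica, single partition, no device
// assignment) on the device specified by the caller.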
StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
PjRtStreamExecutorExecutable::ExecutePortable(
absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
const ExecuteOptions& options) {
if (device_assignment_ != nullptr) {
return InvalidArgument("ExecutePortable gets a non-portable executable");
}
if (num_replicas() != 1 || num_partitions() != 1) {
    return InvalidArgument(
        "ExecutePortable expects a single-core executable but got one with "
        "%d replicas and %d partitions",
        num_replicas(), num_partitions());
}
if (device == nullptr) {
return InvalidArgument("ExecutePortable expects a device to be specified");
}
VLOG(1) << "ExecutePortable executes single-core portable executable "
<< name();
return ExecuteHelper(argument_handles,
/*replica=*/0,
/*partition=*/0, RunId(), options, device);
}
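// Returns the HLO modules backing each local executable, or an error if any
// executable does not carry a module.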
StatusOr<std::vector<std::shared_ptr<HloModule>>>
PjRtStreamExecutorExecutable::GetHloModules() const {
std::vector<std::shared_ptr<HloModule>> modules;
modules.reserve(executables().size());
for (const auto& local_exec : executables()) {
if (!local_exec->executable()->has_module()) {
return InvalidArgument("Executable does not have HLO modules.");
}
modules.push_back(local_exec->executable()->shared_module());
}
return std::move(modules);
}
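// Compiles an XlaComputation into a PjRtStreamExecutorExecutable: fills in
// default build options (compile thread pool, device allocator), derives or
// validates the device assignment, chooses compact on-device layouts for the
// arguments, computes the devices addressable by this task, invokes the
// LocalClient compiler, and finally sets up parameter donation.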
StatusOr<std::unique_ptr<PjRtExecutable>> PjRtStreamExecutorClient::Compile(
const XlaComputation& computation, CompileOptions options) {
tensorflow::profiler::TraceMe traceme("PjRtStreamExecutorClient::Compile");
ExecutableBuildOptions& build_options = options.executable_build_options;
if (!build_options.compile_thread_pool()) {
build_options.set_compile_thread_pool(thread_pool());
}
if (!build_options.device_allocator()) {
build_options.set_device_allocator(allocator());
}
int num_replicas;
int num_partitions;
std::shared_ptr<DeviceAssignment> device_assignment;
TF_RETURN_IF_ERROR(ParseDeviceAssignmentCompileOptions(
options.compile_portable_executable, &options.executable_build_options,
[this](int num_replicas, int num_partitions) {
return this->GetDefaultDeviceAssignment(num_replicas, num_partitions);
},
&num_replicas, &num_partitions, &device_assignment));
std::vector<const Shape*> argument_layout_pointers;
TF_RETURN_IF_ERROR(DetermineArgumentLayoutsFromCompileOptions(
computation,
[local_client = client()](Shape shape) {
return local_client->backend()
.transfer_manager()
->ChooseCompactLayoutForShape(shape);
},
options.argument_layouts, &options.executable_build_options,
&argument_layout_pointers));
// Find devices that are addressable by this client/task.
std::vector<PjRtExecutable::LogicalDeviceIds> addressable_device_logical_ids;
std::vector<PjRtDevice*> addressable_devices;
if (device_assignment != nullptr) {
addressable_device_logical_ids.reserve(num_replicas * num_partitions);
addressable_devices.reserve(num_replicas * num_partitions);
for (int replica = 0; replica < num_replicas; ++replica) {
for (int partition = 0; partition < num_partitions; ++partition) {
int device_id = (*device_assignment)(replica, partition);
TF_ASSIGN_OR_RETURN(PjRtDevice * device, LookupDevice(device_id));
if (device->task_id() != task_id()) {
VLOG(3) << "Non-local device: " << device_id;
continue;
}
        PjRtExecutable::LogicalDeviceIds logical_device_ids;
        logical_device_ids.replica = replica;
        logical_device_ids.partition = partition;
        addressable_device_logical_ids.push_back(
            std::move(logical_device_ids));
addressable_devices.push_back(device);
}
}
if (addressable_devices.empty()) {
return InvalidArgument(
"Device assignment (%s) does not have any local devices.",
device_assignment->ToString());
}
if (build_options.device_ordinal() < 0) {
build_options.set_device_ordinal(
addressable_devices.front()->local_hardware_id());
}
}
TF_ASSIGN_OR_RETURN(
std::vector<std::unique_ptr<LocalExecutable>> local_executables,
client()->Compile(computation, argument_layout_pointers, build_options));
auto executable = absl::make_unique<PjRtStreamExecutorExecutable>(
std::move(local_executables), options.parameter_is_tupled_arguments,
std::move(device_assignment), std::move(addressable_device_logical_ids),
std::move(addressable_devices), this);
TF_RETURN_IF_ERROR(
executable->SetUpDonation(options.parameter_is_tupled_arguments));
return std::unique_ptr<PjRtExecutable>(std::move(executable));
}
} // namespace xla