| /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #ifndef TENSORFLOW_COMPILER_XLA_PJRT_PJRT_STREAM_EXECUTOR_CLIENT_H_ |
| #define TENSORFLOW_COMPILER_XLA_PJRT_PJRT_STREAM_EXECUTOR_CLIENT_H_ |
| |
| #include <array> |
| #include <functional> |
| #include <map> |
| #include <memory> |
| #include <optional> |
| #include <string> |
| #include <tuple> |
| #include <utility> |
| #include <vector> |
| |
| #include "absl/container/flat_hash_map.h" |
| #include "absl/container/flat_hash_set.h" |
| #include "absl/container/inlined_vector.h" |
| #include "absl/strings/string_view.h" |
| #include "absl/synchronization/mutex.h" |
| #include "absl/synchronization/notification.h" |
| #include "absl/types/optional.h" |
| #include "absl/types/span.h" |
| #include "tensorflow/compiler/xla/client/executable_build_options.h" |
| #include "tensorflow/compiler/xla/client/local_client.h" |
| #include "tensorflow/compiler/xla/client/xla_computation.h" |
| #include "tensorflow/compiler/xla/layout.h" |
| #include "tensorflow/compiler/xla/literal.h" |
| #include "tensorflow/compiler/xla/pjrt/local_device_state.h" |
| #include "tensorflow/compiler/xla/pjrt/pjrt_client.h" |
| #include "tensorflow/compiler/xla/pjrt/pjrt_future.h" |
| #include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h" |
| #include "tensorflow/compiler/xla/pjrt/transpose.h" |
| #include "tensorflow/compiler/xla/service/computation_layout.h" |
| #include "tensorflow/compiler/xla/service/computation_placer.h" |
| #include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h" |
| #include "tensorflow/compiler/xla/service/hlo_module.h" |
| #include "tensorflow/compiler/xla/service/shaped_buffer.h" |
| #include "tensorflow/compiler/xla/shape.h" |
| #include "tensorflow/compiler/xla/status.h" |
| #include "tensorflow/compiler/xla/statusor.h" |
| #include "tensorflow/compiler/xla/util.h" |
| #include "tensorflow/compiler/xla/xla_data.pb.h" |
| #include "tensorflow/core/framework/allocator.h" |
| #include "tensorflow/core/lib/core/status.h" |
| #include "tensorflow/core/platform/casts.h" |
| #include "tensorflow/stream_executor/stream.h" |
| |
| namespace xla { |
| |
| // PjRtDevice implementation for devices driven through StreamExecutor. A |
| // device built without a LocalDeviceState records device ordinal -1 and |
| // reports itself as non-addressable. |
| class PjRtStreamExecutorDevice : public PjRtDevice { |
| public: |
| // Takes ownership of `local_device_state`, which may be null. The device |
| // ordinal is read from it before it is moved into the member (the ordinal |
| // member is declared first); a null state yields ordinal -1. |
| explicit PjRtStreamExecutorDevice( |
| int id, std::unique_ptr<LocalDeviceState> local_device_state, |
| std::string device_kind, int process_index = 0) |
| : id_(id), |
| device_ordinal_(local_device_state != nullptr |
| ? local_device_state->device_ordinal() |
| : -1), |
| local_device_state_(std::move(local_device_state)), |
| process_index_(process_index), |
| device_kind_(std::move(device_kind)) {} |
| ~PjRtStreamExecutorDevice() override = default; |
| |
| // Associates this device with its owning client. The client must be set |
| // exactly once: CHECK-fails if client_ is already non-null. |
| void SetClient(PjRtClient* client) { |
| CHECK(client_ == nullptr); |
| client_ = client; |
| } |
| |
| // --- Identity accessors ------------------------------------------------- |
| |
| // The client passed to SetClient(), or nullptr if none has been set yet. |
| PjRtClient* client() const override { return client_; } |
| |
| int id() const override { return id_; } |
| |
| int process_index() const override { return process_index_; } |
| |
| // The device ordinal captured at construction; -1 means not local. |
| int local_hardware_id() const override { return device_ordinal_; } |
| |
| // True iff the device was constructed with a LocalDeviceState, i.e. its |
| // ordinal is not the -1 sentinel. |
| bool IsAddressable() const override { return device_ordinal_ != -1; } |
| |
| absl::string_view device_kind() const override { return device_kind_; } |
| |
| // Return `platform_id` from client. |
| PjRtPlatformId platform_id() const; |
| |
| // Return `platform_name` from client. |
| absl::string_view platform_name() const; |
| |
| // --- Local device state ------------------------------------------------- |
| |
| // For a device local to this host, yields the LocalDeviceState object used |
| // to manipulate the device. Yields nullptr when the device is not local to |
| // this host. |
| LocalDeviceState* local_device_state() const { |
| return local_device_state_.get(); |
| } |
| |
| // Same as local_device_state(), except that a device that is not local to |
| // this host produces an error instead of nullptr. |
| StatusOr<LocalDeviceState*> GetLocalDeviceState() const; |
| |
| // --- Description and transfers (defined out of line) -------------------- |
| |
| std::string ToString() const override; |
| |
| std::string DebugString() const override; |
| |
| Status TransferToInfeed(const LiteralSlice& literal) override; |
| |
| Status TransferFromOutfeed(MutableBorrowingLiteral literal) override; |
| |
| // This implementation creates no tracking event and always returns nullptr. |
| std::unique_ptr<ScopedAsyncTrackingEvent> CreateAsyncTrackingEvent( |
| absl::string_view description) const override { |
| return nullptr; |
| } |
| |
| // Read-only view of attributes_. |
| const absl::flat_hash_map<std::string, PjRtDeviceAttribute>& Attributes() |
| const override { |
| return attributes_; |
| } |
| |
| protected: |
| // Device attributes exposed through Attributes(). Protected rather than |
| // private, so derived classes can access it directly. |
| absl::flat_hash_map<std::string, PjRtDeviceAttribute> attributes_; |
| |
| private: |
| // NOTE: declaration order matters. device_ordinal_ must precede |
| // local_device_state_ so the ordinal is read before the state is moved. |
| const int id_; |
| const int device_ordinal_; // -1 means not local. |
| const std::unique_ptr<LocalDeviceState> local_device_state_; |
| const int process_index_; |
| const std::string device_kind_; |
| PjRtClient* client_ = nullptr; |
| }; |
| |
| // PjRtClient implementation layered on an XLA LocalClient and a set of |
| // PjRtStreamExecutorDevices. Several optional PjRtClient features |
| // (executable serialization, async buffer transfers, defragmentation, |
| // cross-host transfers) default to returning Unimplemented here. |
| class PjRtStreamExecutorClient : public PjRtClient { |
| public: |
| // `allocator` may be null, in which case the platform default allocator is |
| // used. |
| explicit PjRtStreamExecutorClient( |
| std::string platform_name, LocalClient* client, |
| std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices, |
| int process_index, std::unique_ptr<se::DeviceMemoryAllocator> allocator, |
| std::unique_ptr<tensorflow::Allocator> host_memory_allocator, |
| bool should_stage_host_to_device_transfers, |
| std::unique_ptr<gpu::GpuExecutableRunOptions> gpu_run_options); |
| ~PjRtStreamExecutorClient() override = default; |
| |
| int process_index() const override { return process_index_; } |
| |
| // Number of entries in devices_ / addressable_devices_. The container sizes |
| // are converted to int explicitly to match the PjRtClient interface. |
| int device_count() const override { |
| return static_cast<int>(devices_.size()); |
| } |
| int addressable_device_count() const override { |
| return static_cast<int>(addressable_devices_.size()); |
| } |
| absl::Span<PjRtDevice* const> devices() const override { return devices_; } |
| absl::Span<PjRtDevice* const> addressable_devices() const override { |
| return addressable_devices_; |
| } |
| |
| // Looks `device_id` up in id_to_device_. Returns InvalidArgument if no |
| // device with that id is registered. |
| StatusOr<PjRtDevice*> LookupDevice(int device_id) const override { |
| auto it = id_to_device_.find(device_id); |
| if (it != id_to_device_.end()) { |
| return it->second; |
| } |
| return InvalidArgument("No matching device found for device_id %d", |
| device_id); |
| } |
| |
| StatusOr<PjRtDevice*> LookupAddressableDevice( |
| int local_hardware_id) const override; |
| |
| PjRtPlatformId platform_id() const override { return platform_id_; } |
| absl::string_view platform_name() const override { return platform_name_; } |
| absl::string_view platform_version() const override { return "<unknown>"; } |
| PjRtRuntimeType runtime_type() const override { return kStreamExecutor; } |
| |
| // Most platforms expect device-to-device transfers to be enqueued on the |
| // source d2d stream, but some platforms use the destination d2d stream. This |
| // function specifies which one the platform expects. |
| virtual bool EnqueueD2DTransfersOnSrcStream() const { return true; } |
| |
| StatusOr<DeviceAssignment> GetDefaultDeviceAssignment( |
| int num_replicas, int num_partitions) const override; |
| |
| StatusOr<std::unique_ptr<PjRtExecutable>> Compile( |
| const XlaComputation& computation, CompileOptions options) override; |
| StatusOr<std::unique_ptr<PjRtExecutable>> Compile( |
| mlir::ModuleOp mlir_module, CompileOptions options) override; |
| |
| // This implementation provides no fingerprint: it always returns an empty |
| // optional. |
| StatusOr<std::optional<std::string>> ExecutableFingerprint( |
| const PjRtExecutable& executable) const override { |
| return std::optional<std::string>(); |
| } |
| |
| // Executable (de)serialization is not supported by this implementation; |
| // both functions return Unimplemented. |
| StatusOr<std::string> SerializeExecutable( |
| const PjRtExecutable& executable) const override { |
| return Unimplemented("SerializeExecutable not implemented on %s", |
| platform_name()); |
| } |
| |
| StatusOr<std::unique_ptr<PjRtExecutable>> DeserializeExecutable( |
| absl::string_view serialized, CompileOptions options) override { |
| return Unimplemented("DeserializeExecutable not implemented on %s", |
| platform_name()); |
| } |
| |
| StatusOr<std::unique_ptr<HloCostAnalysis>> GetHloCostAnalysis() override; |
| |
| // Creates a buffer on the device without initializing or copying any data. |
| // An optional `definition_event` may be specified that can be used to |
| // ensure the buffer isn't referenced until some external mechanism has |
| // initialized the data. |
| StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer( |
| const Shape& shape, PjRtDevice* device) override; |
| StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer( |
| const Shape& shape, PjRtDevice* device, |
| std::shared_ptr<BufferSequencingEvent> definition_event); |
| |
| // Not supported by this implementation; returns Unimplemented. |
| StatusOr<std::unique_ptr<PjRtClient::AsyncBufferTransferManager>> |
| CreateBuffersForAsyncTransfer(absl::Span<const Shape> shapes, |
| PjRtDevice* device) override { |
| return Unimplemented("Async transfer to buffers not implemented"); |
| } |
| |
| StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer( |
| const void* data, PrimitiveType type, absl::Span<int64_t const> dims, |
| std::optional<absl::Span<int64_t const>> byte_strides, |
| HostBufferSemantics host_buffer_semantics, |
| std::function<void()> on_done_with_host_buffer, |
| PjRtDevice* device) override; |
| |
| StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostLiteral( |
| const LiteralSlice& literal, PjRtDevice* device) override; |
| |
| StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> |
| MakeCrossHostReceiveBuffers(absl::Span<const Shape> shapes, |
| PjRtDevice* device, |
| PjRtCrossHostRecvNotifier notifier) override; |
| |
| StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> |
| MakeCrossHostReceiveBuffersForGather( |
| absl::Span<const Shape> shapes, std::vector<GatherDetails> gather_details, |
| PjRtDevice* device, PjRtCrossHostRecvNotifier notifier) override; |
| |
| StatusOr<std::unique_ptr<PjRtBuffer>> CreateViewOfDeviceBuffer( |
| void* device_ptr, const Shape& shape, PjRtDevice* device, |
| std::function<void()> on_delete_callback) override; |
| |
| // The channel-handle factories below forward directly to the wrapped |
| // LocalClient. |
| StatusOr<ChannelHandle> CreateChannelHandle() override { |
| return client()->CreateChannelHandle(); |
| } |
| StatusOr<ChannelHandle> CreateDeviceToHostChannelHandle() override { |
| return client()->CreateDeviceToHostChannelHandle(); |
| } |
| StatusOr<ChannelHandle> CreateHostToDeviceChannelHandle() override { |
| return client()->CreateHostToDeviceChannelHandle(); |
| } |
| |
| // TODO(zhangqiaorjc): Experimental. Will be removed. |
| Status Defragment() override { |
| return Unimplemented("Defragment not implemented"); |
| } |
| |
| // Returns the LocalDeviceState of the entry at index `device_ordinal` in |
| // addressable_devices_. The index is bounds-checked by at(); the device's |
| // local_device_state() pointer is dereferenced without a null check. |
| LocalDeviceState& device_state(int device_ordinal) const { |
| return *tensorflow::down_cast<PjRtStreamExecutorDevice*>( |
| addressable_devices_.at(device_ordinal)) |
| ->local_device_state(); |
| } |
| // The wrapped LocalClient (not owned through a smart pointer here). |
| LocalClient* client() const { return client_; } |
| se::DeviceMemoryAllocator* allocator() const { return allocator_; } |
| tensorflow::Allocator* host_memory_allocator() const { |
| return host_memory_allocator_.get(); |
| } |
| bool should_stage_host_to_device_transfers() const { |
| return should_stage_host_to_device_transfers_; |
| } |
| |
| gpu::GpuExecutableRunOptions* gpu_run_options() const { |
| return gpu_run_options_.get(); |
| } |
| |
| tensorflow::thread::ThreadPool* thread_pool() { return &thread_pool_; } |
| |
| protected: |
| friend class PjRtStreamExecutorBuffer; |
| |
| // Hook for cross-host receives. The default implementation returns |
| // Unimplemented. |
| virtual Status EnqueueCrossHostReceive( |
| absl::Span<const std::unique_ptr<PjRtBuffer>> buffers, |
| std::shared_ptr<BufferSequencingEvent> definition_event, |
| PjRtCrossHostRecvNotifier notifier, |
| std::optional<std::vector<GatherDetails>> gather_details) const { |
| return Unimplemented("Cross host receives not implemented."); |
| } |
| |
| // Hook for cross-host sends. The default implementation reports |
| // Unimplemented through `on_done` with sends_were_enqueued=false. |
| virtual void CopyToRemoteDevice( |
| PjRtBuffer* buffer, absl::string_view serialized_descriptor, |
| PjRtBuffer::RemoteSendCallback on_done) const { |
| on_done(Unimplemented("Cross host sends not implemented."), |
| /*sends_were_enqueued=*/false); |
| } |
| |
| // Hook for scattered cross-host sends. The default implementation invokes |
| // every supplied callback with Unimplemented and sends_were_enqueued=false. |
| virtual void CopyToRemoteDeviceScattered( |
| PjRtBuffer* buffer, |
| absl::Span<const std::pair<std::string, PjRtBuffer::RemoteSendCallback>> |
| serialized_descriptors_and_callbacks, |
| const PjRtBuffer::ScatterDetails& scatter_details) const { |
| for (const auto& d_and_cb : serialized_descriptors_and_callbacks) { |
| d_and_cb.second( |
| Unimplemented("Scattered cross host sends not implemented."), |
| /*sends_were_enqueued=*/false); |
| } |
| } |
| |
| // Hook for raw device-to-host copies of a sub-range of `buffer`. The default |
| // implementation returns a future holding Unimplemented. |
| virtual PjRtFuture<Status> CopyRawSubBufferToHost(PjRtBuffer* buffer, |
| void* dst, int64_t offset, |
| int64_t transfer_size) { |
| return PjRtFuture<Status>( |
| Unimplemented("Raw copies to host not implemented.")); |
| } |
| |
| // Values produced by GetExecutableExtras() for constructing |
| // PjRtStreamExecutorExecutables. |
| struct ExecutableExtras { |
| std::shared_ptr<DeviceAssignment> device_assignment; |
| std::vector<PjRtExecutable::LogicalDeviceIds> |
| addressable_device_logical_ids; |
| std::vector<PjRtDevice*> addressable_devices; |
| }; |
| // Helper function for creating PjRtStreamExecutorExecutables. Modifies |
| // `options` in-place. |
| StatusOr<ExecutableExtras> GetExecutableExtras(CompileOptions* options); |
| |
| const PjRtPlatformId platform_id_; |
| const std::string platform_name_; |
| LocalClient* client_; |
| |
| // Allocator to be used for staging memory transfers to devices. |
| std::unique_ptr<tensorflow::Allocator> host_memory_allocator_; |
| |
| // Device memory allocator. If owned, the allocator must outlive the devices, |
| // because it is the device destructor that waits for any outstanding work to |
| // complete. (Members are destroyed in reverse declaration order, so these |
| // must stay declared before owned_devices_.) |
| se::DeviceMemoryAllocator* allocator_; |
| std::unique_ptr<se::DeviceMemoryAllocator> owned_allocator_; |
| |
| // Includes all devices, including non-local devices on multi-host platforms. |
| std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> owned_devices_; |
| // Pointers to `owned_devices_`. |
| std::vector<PjRtDevice*> devices_; |
| // Maps Device::id() to the corresponding Device. Includes all devices. |
| std::map<int, PjRtDevice*> id_to_device_; |
| // Local devices indexed by local device ordinal. |
| std::vector<PjRtDevice*> addressable_devices_; |
| int process_index_; |
| |
| // Should we always prefer to stage host-to-device transfers via memory |
| // allocated on host_memory_allocator_? True only on GPU, where we prefer to |
| // transfer via pinned memory. |
| bool should_stage_host_to_device_transfers_; |
| |
| std::unique_ptr<gpu::GpuExecutableRunOptions> gpu_run_options_; |
| |
| tensorflow::thread::ThreadPool thread_pool_; |
| |
| // Cache of transpose plans; all access must hold transpose_mu_. |
| absl::Mutex transpose_mu_; |
| TransposePlanCache transpose_cache_ ABSL_GUARDED_BY(transpose_mu_); |
| }; |
| |
| // Converts a 2D set of Device objects indexed by [replica][partition] into an |
| // xla::DeviceAssignment. The result is a StatusOr, so the conversion can fail; |
| // the failure conditions are determined by the out-of-line definition. |
| StatusOr<DeviceAssignment> DevicesToDeviceAssignment( |
| absl::Span<const std::vector<PjRtDevice*>> devices); |
| |
| // PjRtBuffer implementation for PjRtStreamExecutorClient. The device memory is |
| // referenced through a shared TrackedDeviceBuffer (device_buffer_, guarded by |
| // mu_), and access to it is mediated by the ScopedHold protocol below. |
| class PjRtStreamExecutorBuffer : public PjRtBuffer { |
| public: |
| // Helper class to retain a "hold" on a PjRtStreamExecutorBuffer. A ScopedHold |
| // may not outlive its parent PjRtStreamExecutorBuffer. |
| // |
| // There are three types of hold, as follows: |
| // |
| // 1) Usage hold: a transient hold while an operation using the buffer is |
| // being enqueued onto a stream. |
| // A client acquires a usage hold by calling |
| // PjRtStreamExecutorBuffer::GetBufferWithHold(kUsage) or the convenience |
| // wrapper GetBufferWithUsageHold(). If the enqueue completes successfully the |
| // hold should be released using a call to ConvertUsageHold. If the ScopedHold |
| // is deleted without ConvertUsageHold being called, e.g., on error, the hold |
| // is dropped. It is legal to drop a usage hold instead of calling |
| // ConvertUsageHold, even if the buffer was successfully enqueued, as long as |
| // the client ensures that all necessary synchronization has been done. |
| // |
| // 2) External hold: a potentially long-lived hold while the buffer is being |
| // shared by an external framework, e.g., NumPy. |
| // A client acquires an external hold by calling |
| // PjRtStreamExecutorBuffer::GetBufferWithHold(kExternalReference) or the |
| // convenience wrapper GetBufferWithExternalReference() and releases it by |
| // deleting the ScopedHold. The external framework should not modify the |
| // underlying buffer unless it is confident via its own synchronization that |
| // modifications do not race with reads from the PjRtStreamExecutorBuffer. |
| // |
| // 3) Donation hold: a transient hold while an execution that donates the |
| // buffer is being enqueued onto the compute stream. |
| // A client acquires a donation hold by calling |
| // PjRtStreamExecutorBuffer::GetBufferWithHold(kDonation). If the enqueue |
| // completes successfully the hold should be released using a call to |
| // ConfirmDonation after which the buffer is invalid. If the ScopedHold is |
| // deleted without ConfirmDonation being called, e.g., on error, the hold is |
| // dropped and the buffer remains valid. If the buffer is successfully |
| // enqueued the client *must* call ConfirmDonation. |
| // |
| // Donation holds behave like exclusive write locks: when a donation hold |
| // has been acquired, any attempt to acquire another hold of any type will |
| // block until the donation hold is dropped or confirmed. Acquiring a donation |
| // hold will fail with an error if there is any outstanding external hold, and |
| // will block if there are any outstanding usage holds until those holds are |
| // dropped or converted. |
| // |
| // Calls to PjRtStreamExecutorBuffer::Release (and transitively to |
| // PjRtStreamExecutorBuffer::Delete() and ~PjRtStreamExecutorBuffer()) will |
| // block until all usage and donation holds are either deleted or |
| // converted/confirmed. |
| class ScopedHold { |
| public: |
| // Kind of hold; see the class comment. kMaxValue is not a hold type: it is |
| // the number of hold types and sizes the parent's holds_ array. |
| enum Type { kUsage = 0, kExternalReference, kDonation, kMaxValue }; |
| // Use a State enum instead of encoding the state in an error Status to |
| // avoid creating Status values in non-error cases. Creating a Status |
| // entails several allocations and can add O(us) to every use of a hold. |
| enum State { |
| kUninitialized = 0, |
| kValid, |
| kMoved, |
| kConverted, |
| kReleased, |
| kDonated, |
| kError |
| }; |
| |
| // Move-only: copying is deleted, and no assignment operator is available. |
| ~ScopedHold(); |
| ScopedHold(ScopedHold&& other); |
| ScopedHold(const ScopedHold&) = delete; |
| ScopedHold& operator=(const ScopedHold&) = delete; |
| |
| Type type() const { return type_; } |
| |
| // Translates state_ into a Status: OK for kValid, the stored status_ for |
| // kError, and a state-specific InvalidArgument for every other state. |
| // CHECK-fails on a value outside the State enum. |
| Status status() const { |
| // Lazily create Status values only when they are requested. |
| switch (state_) { |
| case kUninitialized: |
| return InvalidArgument("Buffer has not been initialized"); |
| case kValid: |
| return Status::OK(); |
| case kMoved: |
| return InvalidArgument("Buffer has been moved."); |
| case kConverted: |
| return InvalidArgument("Buffer has been converted"); |
| case kReleased: |
| return InvalidArgument("Buffer has been released"); |
| case kDonated: |
| return InvalidArgument("Buffer has been donated"); |
| case kError: |
| return status_; |
| default: |
| CHECK(false) << "Unexpected state value " << state_; |
| } |
| } |
| // True iff the hold is in state kValid. Cheaper than status().ok() because |
| // no Status is constructed. |
| bool ok() const { return state_ == kValid; } |
| |
| // Access to the underlying device buffer storage. Requires this->ok(): |
| // CHECK-fails if the state is not kValid or if buffer_ is null. |
| const std::shared_ptr<TrackedDeviceBuffer>& buffer() const { |
| CHECK_EQ(state_, kValid); |
| CHECK_NE(buffer_, nullptr); |
| return buffer_; |
| } |
| // Pointer-like access to the held buffer; both forward to buffer() and |
| // therefore share its CHECKs. |
| TrackedDeviceBuffer* operator->() const { return buffer().get(); } |
| const TrackedDeviceBuffer& operator*() const { return *buffer(); } |
| |
| // Converts the hold into a usage event. Only valid for holds of type |
| // kUsage. |
| // |
| // usage_stream: the stream that the buffer was used on. |
| // event: an event that has been recorded on usage_stream after |
| // the buffer was used. |
| // reference_held: true if and only if the caller has caused a |
| // reference to this->buffer() to stay live until after |
| // the host is sure that the usage (transfer or execution) |
| // has completed. |
| void ConvertUsageHold(se::Stream* usage_stream, |
| std::shared_ptr<BufferSequencingEvent> event, |
| bool reference_held); |
| |
| // Confirms that the buffer was successfully donated to an execution. |
| // Only valid for holds of type kDonation. Causes the buffer to become |
| // invalid. |
| void ConfirmDonation(); |
| |
| // Adds the held device buffers in order to 'iterator'. Used to add the |
| // buffers to an ExecutionInput. We require but do not verify that |
| // 'iterator' when passed in is pointing to a sub-tuple of the |
| // ExecutionInput whose on_device_shape matches that of the |
| // TrackedDeviceBuffer. 'end' is used to check that 'iterator' doesn't run |
| // out of bounds. Donates the device buffers if the hold type is kDonation, |
| // otherwise retains ownership of the device buffers. |
| void AddToInput(ShapeTree<MaybeOwningDeviceMemory>::iterator* iterator, |
| const ShapeTree<MaybeOwningDeviceMemory>::iterator& end, |
| ExecutionInput* execution_input, |
| se::DeviceMemoryAllocator* allocator) const; |
| |
| private: |
| friend class PjRtStreamExecutorBuffer; |
| friend class PjRtStreamExecutorClient; |
| |
| // Helper tuple (parent, type, state, status, buffer) that makes it |
| // possible to move a ScopedHold through a closure. |
| using ForClosure = std::tuple<PjRtStreamExecutorBuffer*, Type, State, |
| Status, std::shared_ptr<TrackedDeviceBuffer>>; |
| |
| // Creates a hold in state kUninitialized; the buffer is supplied later |
| // through Acquire(). |
| ScopedHold(PjRtStreamExecutorBuffer* parent, Type type) |
| : parent_(parent), type_(type), state_(kUninitialized) {} |
| // Rebuilds a hold from the tuple produced by ToClosure(). CHECK-fails |
| // unless the tuple carries an OK status and a non-null buffer. |
| explicit ScopedHold(const ForClosure& closure_helper) |
| : parent_(std::get<0>(closure_helper)), |
| type_(std::get<1>(closure_helper)), |
| state_(std::get<2>(closure_helper)), |
| status_(std::get<3>(closure_helper)), |
| buffer_(std::get<4>(closure_helper)) { |
| // Check the buffer is not in an error state. |
| CHECK(status_.ok() && buffer_ != nullptr); |
| } |
| |
| // Sets the hold's state_. |
| void SetState(State state) { state_ = state; } |
| |
| // Sets buffer_ and status_. Called by parent_ to initialize the hold. |
| void Acquire(StatusOr<std::shared_ptr<TrackedDeviceBuffer>>&& buffer_or); |
| // Releases the contents of *this, so *this can subsequently be |
| // deleted without releasing the parent's hold. Should be passed to the |
| // appropriate constructor of another ScopedHold, e.g., when a hold must be |
| // passed through a closure that is incompatible with std::move. |
| ForClosure ToClosure(); |
| |
| PjRtStreamExecutorBuffer* const parent_; |
| const Type type_; |
| |
| // There is an invariant that if ok() then |
| // buffer_ != nullptr. |
| State state_; |
| Status status_; |
| std::shared_ptr<TrackedDeviceBuffer> buffer_; |
| }; |
| |
| // Constructs a buffer of shape `on_device_shape` referring to |
| // `device_buffer`, associated with `client` and `device`. Defined out of |
| // line. |
| PjRtStreamExecutorBuffer(Shape on_device_shape, |
| std::shared_ptr<TrackedDeviceBuffer> device_buffer, |
| PjRtClient* client, PjRtDevice* device); |
| ~PjRtStreamExecutorBuffer() override; |
| |
| // Neither copyable nor movable. |
| PjRtStreamExecutorBuffer(const PjRtStreamExecutorBuffer&) = delete; |
| PjRtStreamExecutorBuffer(PjRtStreamExecutorBuffer&&) = delete; |
| PjRtStreamExecutorBuffer& operator=(const PjRtStreamExecutorBuffer&) = delete; |
| PjRtStreamExecutorBuffer& operator=(PjRtStreamExecutorBuffer&&) = delete; |
| |
| const Shape& on_device_shape() const override { return on_device_shape_; } |
| StatusOr<Shape> logical_on_device_shape() override; |
| PjRtStreamExecutorDevice* device() const override { return device_; } |
| // Platform id and name are forwarded from the owning client. |
| PjRtPlatformId platform_id() const { return client_->platform_id(); } |
| absl::string_view platform_name() const { return client_->platform_name(); } |
| PjRtStreamExecutorClient* client() const override { return client_; } |
| // True iff the on-device shape is a tuple with zero elements. |
| bool IsEmptyTuple() const { |
| return on_device_shape_.IsTuple() && |
| on_device_shape_.tuple_shapes_size() == 0; |
| } |
| |
| StatusOr<std::unique_ptr<ExternalReference>> AcquireExternalReference() |
| override; |
| |
| StatusOr<std::unique_ptr<ExternalReference>> ReleaseDeviceMemoryOwnership( |
| bool wait_for_operations_to_complete) override; |
| |
| using PjRtBuffer::ToLiteralSync; |
| PjRtFuture<Status> ToLiteral(MutableLiteralBase* literal) override; |
| |
| StatusOr<size_t> GetOnDeviceSizeInBytes() const override; |
| |
| PjRtFuture<Status> CopyRawToHost(void* dst, int64_t offset, |
| int64_t transfer_size) override; |
| |
| // Drops the buffer's reference to its associated device memory, leaving the |
| // buffer in an invalid state. The memory will be freed lazily when all async |
| // operations using the buffer have completed, according to the allocation |
| // semantics of the underlying platform. Delete may briefly block if another |
| // thread is in the process of enqueuing an operation on this buffer, but it |
| // will never block for a stream operation to complete. If an external |
| // framework holds a reference to the TrackedDeviceBuffer via |
| // GetBufferWithExternalReference, the memory will not be freed until the |
| // external framework drops the reference. |
| void Delete() override; |
| |
| bool IsDeleted() override; |
| |
| // Returns a view of the PjRtBuffer device memory as a ShapedBuffer. The |
| // PjRtBuffer retains ownership of the device buffers. |
| StatusOr<ShapedBuffer> AsShapedBuffer() const; |
| |
| // Returns a hold on the TrackedDeviceBuffer holding the device |
| // buffers. See comment on ScopedHold. |
| ScopedHold GetBufferWithHold(ScopedHold::Type type); |
| // Convenience wrapper: GetBufferWithHold(ScopedHold::kUsage). |
| ScopedHold GetBufferWithUsageHold() { |
| return GetBufferWithHold(ScopedHold::kUsage); |
| } |
| // Convenience wrapper: GetBufferWithHold(ScopedHold::kExternalReference). |
| ScopedHold GetBufferWithExternalReference() { |
| return GetBufferWithHold(ScopedHold::kExternalReference); |
| } |
| |
| StatusOr<std::unique_ptr<PjRtBuffer>> CopyToDevice( |
| PjRtDevice* dst_device) override; |
| |
| void CopyToRemoteDevice(absl::string_view serialized_descriptor, |
| RemoteSendCallback on_done) override; |
| |
| void CopyToRemoteDeviceScattered( |
| absl::Span<const std::pair<std::string, RemoteSendCallback>> |
| serialized_descriptors_and_callbacks, |
| const ScatterDetails& scatter_details) override; |
| |
| PjRtFuture<Status> GetReadyFuture() override; |
| |
| bool IsOnCpu() const override; |
| |
| // Similar to Delete, drops the buffer's reference to its associated device |
| // memory, leaving the buffer in an invalid state, but returns the |
| // TrackedDeviceBuffer rather than freeing the device memory, so that another |
| // framework can take ownership of it. The buffer returned from Release may |
| // be safely dropped at any time even if it still has pending async |
| // operations. The client should await the future returned by GetReadyFuture() |
| // before calling Release with wait_for_operations_to_complete=false, to |
| // ensure that the host has synchronized past any outstanding write operations |
| // to the buffer. If wait_for_operations_to_complete=true the host will block |
| // until any potentially outstanding asynchronous operations have completed |
| // before returning, in which case it is safe to read or mutate the returned |
| // buffer. If the buffer was shared via an external reference it is the |
| // client's responsibility that accesses via that reference do not interfere |
| // with accesses via the buffer returned from Release. |
| StatusOr<std::shared_ptr<TrackedDeviceBuffer>> Release( |
| bool wait_for_operations_to_complete); |
| |
| private: |
| friend class PjRtClient; |
| |
| // Blocks in mu_.Await until there are no more usage holds. |
| void WaitForOutstandingUsageHolds() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); |
| |
| // Blocks in mu_.Await until there is no donation hold. |
| void WaitForOutstandingDonationHold() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); |
| |
| // Adds a hold of 'type' and returns device_buffer_. Returns an error if |
| // device_buffer_ is null, or if a donation hold was requested when there is |
| // an outstanding external hold. |
| // Requires holds_[kDonation] == 0 (i.e., WaitForOutstandingDonationHold() |
| // must be called first.) |
| StatusOr<std::shared_ptr<TrackedDeviceBuffer>> GetBufferForHoldLocked( |
| ScopedHold::Type type) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); |
| |
| // Adds a hold of hold->type() and initializes `hold` with device_buffer_. |
| // Initializes hold with an error if device_buffer_ is null, or if a donation |
| // hold was requested when there is an outstanding external hold. |
| // Requires holds_[kDonation] == 0 (i.e., WaitForOutstandingDonationHold() |
| // must be called first.) |
| void AcquireHoldLocked(ScopedHold* hold) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); |
| |
| // Drops a usage hold and calls device_buffer_->AddUsageEvent. Does a sanity |
| // check that buffer==device_buffer_ or device_buffer_==nullptr. Called after |
| // device_buffer_ was successfully enqueued on a stream. |
| void ConvertUsageHold(TrackedDeviceBuffer* buffer, se::Stream* usage_stream, |
| std::shared_ptr<BufferSequencingEvent> event, |
| bool reference_held); |
| |
| // Drops a donation hold and makes *this invalid for further use. Does a |
| // sanity check that device_buffer==device_buffer_. Called after |
| // device_buffer_ was successfully donated to an execution. |
| void ConfirmDonation(TrackedDeviceBuffer* device_buffer); |
| |
| // Drops a hold without taking any other action. Does a sanity check that |
| // buffer==device_buffer_ or device_buffer_==nullptr. |
| void DropHold(ScopedHold::Type type, TrackedDeviceBuffer* buffer); |
| |
| // Helper for copying to `dst_device`; defined out of line. Returns the new |
| // destination buffer paired with a BufferSequencingEvent. |
| // NOTE(review): presumably the event marks completion of the copy on |
| // `transfer_stream` -- confirm against the definition. |
| StatusOr<std::pair<std::unique_ptr<PjRtBuffer>, |
| std::shared_ptr<BufferSequencingEvent>>> |
| CopyToDeviceHelper(PjRtDevice* dst_device, LocalDeviceState* dst_local_device, |
| LocalDeviceState* transfer_local_device, |
| se::Stream* transfer_stream, |
| std::shared_ptr<TrackedDeviceBuffer> src_device_buffer); |
| |
| PjRtStreamExecutorClient* const client_; |
| const Shape on_device_shape_; |
| PjRtStreamExecutorDevice* const device_; |
| |
| // mu_ guards every member declared below it. |
| mutable absl::Mutex mu_; |
| std::shared_ptr<TrackedDeviceBuffer> device_buffer_ ABSL_GUARDED_BY(mu_); |
| // Count of holds on the buffer, indexed by ScopedHold::Type. |
| std::array<int, ScopedHold::Type::kMaxValue> holds_ ABSL_GUARDED_BY(mu_); |
| PjRtFuture<Status>::Promise definition_promise_ ABSL_GUARDED_BY(mu_); |
| }; |
| |
| // Wraps one or more XLA LocalExecutables (one per partition, as specified by |
| // the build options). |
| class PjRtStreamExecutorExecutable : public PjRtExecutable { |
| public: |
| PjRtStreamExecutorExecutable( |
| std::vector<std::unique_ptr<LocalExecutable>> executables, |
| bool parameter_is_tupled_arguments, |
| std::shared_ptr<DeviceAssignment> device_assignment, |
| std::vector<LogicalDeviceIds> addressable_device_logical_ids, |
| std::vector<PjRtDevice*> addressable_devices, |
| PjRtStreamExecutorClient* client); |
| |
| ~PjRtStreamExecutorExecutable() override = default; |
| |
| PjRtStreamExecutorClient* client() const override { return client_; } |
| |
| absl::string_view name() const override; |
| |
| int num_replicas() const override { |
| return executables_[0]->build_options().num_replicas(); |
| } |
| |
| int num_partitions() const override { |
| return executables_[0]->build_options().num_partitions(); |
| } |
| |
| int64_t SizeOfGeneratedCodeInBytes() const override { |
| int64_t size = 0; |
| for (auto& executable : executables_) { |
| size += executable->executable()->SizeOfGeneratedCodeInBytes(); |
| } |
| return size; |
| } |
| |
| const DeviceAssignment& device_assignment() const override { |
| return *device_assignment_; |
| } |
| |
| absl::Span<const LogicalDeviceIds> addressable_device_logical_ids() |
| const override { |
| return addressable_device_logical_ids_; |
| } |
| |
| absl::Span<PjRtDevice* const> addressable_devices() const override { |
| return addressable_devices_; |
| } |
| |
| // Return an HloModule per partition. |
| StatusOr<std::vector<std::shared_ptr<HloModule>>> GetHloModules() |
| const override; |
| |
| using PjRtExecutable::Execute; |
| StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>> Execute( |
| absl::Span<const std::vector<PjRtBuffer*>> argument_handles, |
| const ExecuteOptions& options, |
| std::optional<std::vector<PjRtFuture<Status>>>& returned_futures) |
| override; |
| |
| using PjRtExecutable::ExecuteSharded; |
| StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecuteSharded( |
| absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device, |
| const ExecuteOptions& options, |
| std::optional<PjRtFuture<Status>>& returned_future, |
| bool fill_future) override; |
| |
| using PjRtExecutable::ExecutePortable; |
| StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecutePortable( |
| absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device, |
| const ExecuteOptions& options, |
| std::optional<PjRtFuture<Status>>& returned_future, |
| bool fill_future) override; |
| |
| void Delete() override { executables_.clear(); } |
| |
| bool IsDeleted() override { return executables_.empty(); } |
| |
| absl::Span<const std::shared_ptr<LocalExecutable>> executables() const { |
| return executables_; |
| } |
| |
| protected: |
| bool parameter_is_tupled_arguments() const { |
| return parameter_is_tupled_arguments_; |
| } |
| |
| private: |
| friend class PjRtStreamExecutorClient; |
| friend class PjRtTpuClient; |
| friend class InternalPjRtTpuClient; |
| // Initializes information about which arguments to which executables must be |
| // donated due to aliases that were specified by the computation. |
| Status SetUpDonation(bool tuple_inputs); |
| |
| // Returns a sorted list of the parameters that must be donated. Derived |
| // classes may use custom logic. |
| virtual absl::Span<int const> ParametersThatMustBeDonated( |
| int executable_idx) const; |
| |
| virtual StatusOr<std::vector<ExecutionInput>> |
| MakeExecutionInputsAndWaitForEvents( |
| int device_ordinal, const ExecuteOptions& options, |
| absl::Span<const Shape> executable_parameter_shapes, |
| absl::Span<PjRtBuffer* const> argument_handles, |
| absl::Span<const PjRtStreamExecutorBuffer::ScopedHold> device_buffers, |
| absl::flat_hash_set<BufferSequencingEvent*>& events) const; |
| |
| StatusOr<ScopedShapedBuffer> EnqueueExecution( |
| absl::Span<PjRtBuffer* const> argument_handles, int replica, |
| int partition, int executable_idx, const RunId& run_id, |
| const ExecuteOptions& options, PjRtDevice* device, |
| std::vector<PjRtStreamExecutorBuffer::ScopedHold>* device_buffers, |
| std::shared_ptr<DeviceAssignment> device_assignment, |
| std::vector<std::function<void()>>& compute_callbacks) const; |
| |
| virtual std::vector<std::unique_ptr<PjRtBuffer>> MakeOutputBuffers( |
| int device_ordinal, const ExecuteOptions& options, |
| ScopedShapedBuffer result_buffer, |
| std::shared_ptr<BufferSequencingEvent> definition_event, |
| PjRtDevice* device, std::vector<std::function<void()>>& compute_callbacks, |
| std::vector<std::shared_ptr<TrackedDeviceBuffer>>& buffers_to_release) |
| const; |
| |
| StatusOr<Result> ExecuteHelper(absl::Span<PjRtBuffer* const> argument_handles, |
| int replica, int partition, |
| const RunId& run_id, |
| const ExecuteOptions& options, |
| bool fill_future, |
| PjRtDevice* device = nullptr) const; |
| |
| // Create shared pointers so we can free them after the execution: with |
| // asynchronous execution, the process being executed can outlive the |
| // executable itself. |
| PjRtStreamExecutorClient* const client_; |
| // One executable per partition. |
| std::vector<std::shared_ptr<LocalExecutable>> executables_; |
| // On device shapes of the executable parameters. |
| std::vector<std::vector<Shape>> on_device_executable_parameter_shapes_; |
| // Per-executable sorted vector of parameters that have any aliased buffers |
| // and thus must be donated when executing the computation. |
| std::vector<std::vector<int>> parameters_that_must_be_donated_; |
| std::shared_ptr<DeviceAssignment> device_assignment_; |
| |
| // True if the executables were compiled expecting arguments in a single |
| // tuple. |
| const bool parameter_is_tupled_arguments_; |
| |
| // The replica and partition indices of device_assignment_ to be run by this |
| // client. On single-host platforms without partitioning, this is all replicas |
| // (i.e. addressable_device_logical_ids_[i] = (i, 0)), but this may not be the |
| // case on multi-host platforms. If there are 4 replicas and 2 partitions on a |
| // single host platform, size of addressable_device_logical_ids_ is 4*2 = 8. |
| std::vector<LogicalDeviceIds> addressable_device_logical_ids_; |
| |
| // addressable_devices_[i] is the Device to which |
| // addressable_device_logical_ids_[i] is assigned. shared_ptrs instead of |
| // unique_ptrs to play well with the Python bindings (see xla.cc). |
| std::vector<PjRtDevice*> addressable_devices_; |
| }; |
| |
| } // namespace xla |
| |
| #endif // TENSORFLOW_COMPILER_XLA_PJRT_PJRT_STREAM_EXECUTOR_CLIENT_H_ |