| /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.h" |
| |
| #include <memory> |
| |
| #define EIGEN_USE_THREADS |
| |
| #include "absl/base/thread_annotations.h" |
| #include "absl/container/flat_hash_map.h" |
| #include "absl/strings/string_view.h" |
| #include "absl/synchronization/mutex.h" |
| #include "absl/types/optional.h" |
| #include "absl/types/span.h" |
| #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" |
| #include "tensorflow/compiler/xla/client/executable_build_options.h" |
| #include "tensorflow/compiler/xla/client/xla_computation.h" |
| #include "tensorflow/compiler/xla/layout.h" |
| #include "tensorflow/compiler/xla/literal.h" |
| #include "tensorflow/compiler/xla/pjrt/pjrt_client.h" |
| #include "tensorflow/compiler/xla/pjrt/semaphore.h" |
| #include "tensorflow/compiler/xla/pjrt/utils.h" |
| #include "tensorflow/compiler/xla/pjrt/worker_thread.h" |
| #include "tensorflow/compiler/xla/service/buffer_assignment.h" |
| #include "tensorflow/compiler/xla/service/computation_placer.h" |
| #include "tensorflow/compiler/xla/service/cpu/cpu_executable.h" |
| #include "tensorflow/compiler/xla/service/cpu/cpu_xfeed.h" |
| #include "tensorflow/compiler/xla/service/executable.h" |
| #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" |
| #include "tensorflow/compiler/xla/shape.h" |
| #include "tensorflow/compiler/xla/statusor.h" |
| #include "tensorflow/compiler/xla/xla_data.pb.h" |
| #include "tensorflow/core/platform/denormal.h" |
| #include "tensorflow/core/platform/setround.h" |
| #include "tensorflow/core/profiler/lib/connected_traceme.h" |
| #include "tfrt/host_context/async_dispatch.h" // from @tf_runtime |
| #include "tfrt/host_context/async_value_ref.h" // from @tf_runtime |
| #include "tfrt/host_context/concurrent_work_queue.h" // from @tf_runtime |
| #include "tfrt/host_context/host_allocator.h" // from @tf_runtime |
| #include "tfrt/host_context/host_context.h" // from @tf_runtime |
| #include "tfrt/support/forward_decls.h" // from @tf_runtime |
| |
| namespace xla { |
| |
| static const char kCpuPlatformName[] = "cpu"; |
| static constexpr size_t kSmallDataTransferByteSize = 102400; // 100 KiB |
| |
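// Returns an already-available CpuEvent that is shared process-wide. The event
// lives in a function-local static and is intentionally never destroyed, so
// callers can cheaply copy a reference to it instead of allocating a fresh
// available event every time.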
| static tfrt::AsyncValueRef<CpuEvent> GetOrCreateReadyEvent( |
| tfrt::HostContext* host_context) { |
| static const auto* ready_event = new tfrt::AsyncValueRef<CpuEvent>( |
| tfrt::MakeAvailableAsyncValueRef<CpuEvent>(host_context)); |
| return ready_event->CopyRef(); |
| } |
| |
| TfrtCpuDevice::TfrtCpuDevice(int id, bool asynchronous) |
| : id_(id), |
| max_inflight_computations_semaphore_(/*capacity=*/asynchronous ? 32 : 1) { |
| } |
| |
| absl::string_view TfrtCpuDevice::device_kind() const { |
| return kCpuPlatformName; |
| } |
| |
| std::string TfrtCpuDevice::DebugString() const { |
| return absl::StrCat("TFRT_CPU_", id()); |
| } |
| |
| Status TfrtCpuDevice::TransferToInfeed(const LiteralSlice& literal) { |
| return TransferLiteralToInfeedOnCpu(local_hardware_id(), literal); |
| } |
| |
| Status TfrtCpuDevice::TransferFromOutfeed(MutableBorrowingLiteral literal) { |
| return TransferLiteralFromOutfeedOnCpu(local_hardware_id(), literal); |
| } |
| |
| static int CpuDeviceCount() { |
  // By default we fix the number of devices to one. However, we let the user
  // override this behavior so that tests on the host can run models in
  // parallel across multiple devices, e.g. for pmap.
| return GetDebugOptionsFromFlags().xla_force_host_platform_device_count(); |
| } |
| |
| static StatusOr<std::vector<std::unique_ptr<TfrtCpuDevice>>> GetTfrtCpuDevices( |
| bool asynchronous) { |
| std::vector<std::unique_ptr<TfrtCpuDevice>> devices; |
| for (int i = 0; i < CpuDeviceCount(); ++i) { |
| auto device = std::make_unique<TfrtCpuDevice>( |
| /*id=*/i, asynchronous); |
| devices.push_back(std::move(device)); |
| } |
| return std::move(devices); |
| } |
| |
| StatusOr<std::unique_ptr<PjRtClient>> GetTfrtCpuClient(bool asynchronous) { |
  // TODO(zhangqiaorjc): Allow users to set the number of threads.
| // `num_blocking_threads=16` is picked arbitrarily for now. |
| // Need at least CpuDeviceCount threads to launch one collective. |
| int num_threads = std::max(DefaultThreadPoolSize(), CpuDeviceCount()); |
| auto host_context = std::make_unique<tfrt::HostContext>( |
| [](const tfrt::DecodedDiagnostic& diag) { |
| LOG(ERROR) << "Encountered runtime error: " << diag.message << "\n"; |
| }, |
| tfrt::CreateMallocAllocator(), |
| tfrt::CreateMultiThreadedWorkQueue( |
| /*num_threads=*/num_threads, |
| /*num_blocking_threads=*/16)); |
| |
| TF_ASSIGN_OR_RETURN(std::vector<std::unique_ptr<TfrtCpuDevice>> devices, |
| GetTfrtCpuDevices(asynchronous)); |
| |
| return std::unique_ptr<PjRtClient>(std::make_unique<TfrtCpuClient>( |
| /*process_index=*/0, std::move(devices), std::move(host_context))); |
| } |
| |
| TfrtCpuClient::TfrtCpuClient( |
| int process_index, std::vector<std::unique_ptr<TfrtCpuDevice>> devices, |
| std::unique_ptr<tfrt::HostContext> host_ctx) |
| : process_index_(process_index), |
| owned_devices_(std::move(devices)), |
| host_ctx_(std::move(host_ctx)), |
| computation_placer_(std::make_unique<ComputationPlacer>()), |
| eigen_intraop_pool_(new tensorflow::thread::ThreadPool( |
| tensorflow::Env::Default(), "XLAEigen", DefaultThreadPoolSize())), |
| eigen_intraop_device_( |
| new Eigen::ThreadPoolDevice(eigen_intraop_pool_->AsEigenThreadPool(), |
| eigen_intraop_pool_->NumThreads())), |
| last_collective_launch_event_( |
| tfrt::MakeAvailableAsyncValueRef<CpuEvent>(host_ctx_.get())) { |
| for (const std::unique_ptr<TfrtCpuDevice>& device : owned_devices_) { |
| devices_.push_back(device.get()); |
| CHECK(id_to_device_.insert({device->id(), device.get()}).second) |
| << "Duplicate device id: " << device->id(); |
| |
| device->SetClient(this); |
| if (device->IsAddressable()) { |
| int idx = device->local_hardware_id(); |
| if (idx >= addressable_devices_.size()) { |
| addressable_devices_.resize(idx + 1); |
| } |
| CHECK(addressable_devices_[idx] == nullptr) << idx; |
| addressable_devices_[idx] = device.get(); |
| } |
| } |
| for (int idx = 0; idx < addressable_devices_.size(); ++idx) { |
| CHECK(addressable_devices_[idx] != nullptr) << idx; |
| } |
| LOG(INFO) << "TfrtCpuClient created."; |
| } |
| |
| StatusOr<PjRtDevice*> TfrtCpuClient::LookupDevice(int device_id) const { |
| auto it = id_to_device_.find(device_id); |
| if (it != id_to_device_.end()) { |
| return it->second; |
| } |
| return InvalidArgument("No matching device found for device_id %d", |
| device_id); |
| } |
| |
| StatusOr<PjRtDevice*> TfrtCpuClient::LookupAddressableDevice( |
| int local_hardware_id) const { |
| for (auto* device : addressable_devices_) { |
| if (local_hardware_id == device->local_hardware_id()) { |
| return device; |
| } |
| } |
| return InvalidArgument("No matching device found for local_hardware_id %d", |
| local_hardware_id); |
| } |
| |
| StatusOr<DeviceAssignment> TfrtCpuClient::GetDefaultDeviceAssignment( |
| int num_replicas, int num_partitions) const { |
| return computation_placer_->AssignDevices(num_replicas, num_partitions); |
| } |
| |
| StatusOr<std::unique_ptr<HloCostAnalysis>> TfrtCpuClient::GetHloCostAnalysis() { |
  return std::make_unique<HloCostAnalysis>(cpu::CpuExecutable::ShapeSizeBytes);
| } |
| |
| StatusOr<absl::optional<std::string>> TfrtCpuClient::ExecutableFingerprint( |
| const PjRtExecutable& executable) const { |
| return absl::optional<std::string>(); |
| } |
| |
| static StatusOr<std::unique_ptr<xla::Executable>> JitCompile( |
| const XlaComputation& computation, |
| const absl::Span<const Shape* const> argument_layouts, |
| const ExecutableBuildOptions& build_options, |
| const ExecutionOptions& execution_options) { |
| TF_ASSIGN_OR_RETURN(ProgramShape program_shape, |
| computation.GetProgramShape()); |
| // Unoptimized HloModuleConfig. |
| TF_ASSIGN_OR_RETURN( |
| std::unique_ptr<HloModuleConfig> hlo_module_config, |
| CreateModuleConfig(program_shape, argument_layouts, &execution_options, |
| execution_options.num_replicas(), |
| /*num_threads=*/absl::nullopt, |
| /*aot_options=*/nullptr)); |
| |
| // Unoptimized HloModule. |
| const xla::HloModuleProto& hlo_module_proto = computation.proto(); |
| TF_ASSIGN_OR_RETURN( |
| std::unique_ptr<HloModule> hlo_module, |
| xla::HloModule::CreateFromProto(hlo_module_proto, *hlo_module_config)); |
| VLOG(3) << "Unoptimized HLO module: " << hlo_module->ToString(); |
| |
  // Run HLO passes.
| cpu::CpuCompiler compiler; |
| xla::Compiler::CompileOptions dummy; |
| TF_ASSIGN_OR_RETURN(hlo_module, |
| compiler.RunHloPasses(std::move(hlo_module), |
| /*stream_exec=*/nullptr, dummy)); |
| |
| // Run backend. |
| return compiler.RunBackend(std::move(hlo_module), /*stream_exec=*/nullptr, |
| dummy); |
| } |
| |
| // Find the root instruction of the entry computation. |
| static const InstructionValueSet& GetRootValueSet( |
| const BufferAssignment& assignment, const HloModule& module) { |
| return assignment.dataflow_analysis().GetInstructionValueSet( |
| module.entry_computation()->root_instruction()); |
| } |
| |
// The buffer table is indexed by buffer allocation indices. The output buffer
// is made up of a subset of those buffer allocations (for tuples, this
// includes the tuple index table). This helper finds the buffer allocation
// indices in the buffer assignment that make up the output buffer. It is used
// by CreateResultShapedBuffer to reconstruct the output buffer from the
// buffer table allocated by MemoryForAllocation.
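// For example (illustrative), if the entry computation returns the tuple
// (f32[2], f32[3]), this helper returns two indices, one per leaf buffer; the
// top-level tuple index table is found separately via
// GetUniqueTopLevelOutputSlice().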
| static StatusOr<absl::InlinedVector<BufferAllocation::Index, 4>> |
| FindResultBufferAllocationIndex(const BufferAssignment& assignment, |
| const HloModule& module) { |
| absl::InlinedVector<BufferAllocation::Index, 4> buffer_indices; |
| const InstructionValueSet& root_value_set = |
| GetRootValueSet(assignment, module); |
| const Shape& result_shape = module.result_shape(); |
| if (!result_shape.IsTuple()) { |
| // Find the buffer allocation that corresponds to the output buffer. |
| const HloValueSet& sources = root_value_set.element({}); |
    // The points-to set is unambiguous, so it should be a singleton.
| CHECK_EQ(1, sources.values().size()); |
| const HloValue* value_source = sources.values()[0]; |
| HloInstruction* src = value_source->instruction(); |
| TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice, |
| assignment.GetUniqueSlice(src, value_source->index())); |
| const BufferAllocation::Index buffer_index = slice.index(); |
| buffer_indices.push_back(buffer_index); |
| return {std::move(buffer_indices)}; |
| } |
| buffer_indices.reserve(result_shape.tuple_shapes_size()); |
| for (int i = 0; i < result_shape.tuple_shapes_size(); ++i) { |
    // Find the buffer allocations that correspond to the output tuple,
    // including the tuple index table.
    const HloValueSet& sources = root_value_set.element({i});
    // The points-to set is unambiguous, so it should be a singleton.
| CHECK_EQ(1, sources.values().size()); |
| const HloValue* value_source = sources.values()[0]; |
| HloInstruction* src = value_source->instruction(); |
| TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice, |
| assignment.GetUniqueSlice(src, value_source->index())); |
| const BufferAllocation::Index buffer_index = slice.index(); |
| buffer_indices.push_back(buffer_index); |
| } |
| return {std::move(buffer_indices)}; |
| } |
| |
| StatusOr<std::unique_ptr<PjRtExecutable>> TfrtCpuClient::Compile( |
| const XlaComputation& computation, CompileOptions options) { |
| tensorflow::profiler::TraceMe traceme("TfrtCpuClient::Compile"); |
| ExecutableBuildOptions& build_options = options.executable_build_options; |
| |
| int num_replicas; |
| int num_partitions; |
| std::shared_ptr<DeviceAssignment> device_assignment; |
| TF_RETURN_IF_ERROR(ParseDeviceAssignmentCompileOptions( |
| options.compile_portable_executable, &options.executable_build_options, |
| [this](int num_replicas, int num_partitions) { |
| return this->GetDefaultDeviceAssignment(num_replicas, num_partitions); |
| }, |
| &num_replicas, &num_partitions, &device_assignment)); |
| |
| std::vector<const Shape*> argument_layout_pointers; |
| TF_RETURN_IF_ERROR(DetermineArgumentLayoutsFromCompileOptions( |
| computation, &LayoutUtil::GetWithDefaultLayout, options.argument_layouts, |
| &options.executable_build_options, &argument_layout_pointers)); |
| |
| std::vector<PjRtExecutable::LogicalDeviceIds> addressable_device_logical_ids; |
| std::vector<PjRtDevice*> addressable_devices; |
| if (device_assignment != nullptr) { |
| addressable_device_logical_ids.reserve(num_replicas * num_partitions); |
| addressable_devices.reserve(num_replicas * num_partitions); |
| for (int replica = 0; replica < num_replicas; ++replica) { |
| for (int partition = 0; partition < num_partitions; ++partition) { |
| int device_id = (*device_assignment)(replica, partition); |
| TF_ASSIGN_OR_RETURN(PjRtDevice * device, LookupDevice(device_id)); |
| if (device->process_index() != process_index()) { |
| VLOG(3) << "Non-local device: " << device_id; |
| continue; |
| } |
        PjRtExecutable::LogicalDeviceIds logical_device_ids;
        logical_device_ids.replica = replica;
        logical_device_ids.partition = partition;
        addressable_device_logical_ids.push_back(
            std::move(logical_device_ids));
| addressable_devices.push_back(device); |
| } |
| } |
| if (addressable_devices.empty()) { |
| return InvalidArgument( |
| "Device assignment (%s) does not have any local devices.", |
| device_assignment->ToString()); |
| } |
| |
| if (build_options.device_ordinal() < 0) { |
| build_options.set_device_ordinal( |
| addressable_devices.front()->local_hardware_id()); |
| } |
| } |
| |
| TF_ASSIGN_OR_RETURN(ProgramShape program_shape, |
| computation.GetProgramShape()); |
| ExecutionOptions execution_options = |
| CreateExecutionOptions(build_options, &program_shape); |
| TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> cpu_executable, |
| JitCompile(computation, argument_layout_pointers, |
| build_options, execution_options)); |
| auto cpu_executable_ptr = |
| tensorflow::down_cast<cpu::CpuExecutable*>(cpu_executable.get()); |
| |
| // `buffer_table[result_slice.index()]` points to result buffer: |
| // If output is a tuple, it points to the buffer index table. |
| // If output is a non-tuple, it points to the buffer itself. |
| TF_ASSIGN_OR_RETURN( |
| const BufferAllocation::Slice result_slice, |
| cpu_executable_ptr->buffer_assignment().GetUniqueTopLevelOutputSlice()); |
| |
| // `result_buffer_indices` has the buffer allocation indices that make up the |
| // output buffer (could be tuple). |
| TF_ASSIGN_OR_RETURN( |
| auto result_buffer_indices, |
| FindResultBufferAllocationIndex(cpu_executable_ptr->buffer_assignment(), |
| cpu_executable->module())); |
| |
| auto executable = std::make_unique<TfrtCpuExecutable>( |
| num_replicas, num_partitions, std::move(device_assignment), |
| options.parameter_is_tupled_arguments, std::move(cpu_executable), |
| result_slice.index(), std::move(result_buffer_indices), |
| std::move(addressable_device_logical_ids), std::move(addressable_devices), |
| this); |
| TF_RETURN_IF_ERROR( |
| executable->SetUpDonation(options.parameter_is_tupled_arguments)); |
| |
| return std::unique_ptr<PjRtExecutable>(std::move(executable)); |
| } |
| |
| StatusOr<std::unique_ptr<TfrtCpuBuffer>> AllocateDestinationBuffer( |
| const Shape& on_device_shape, |
| absl::InlinedVector<tfrt::AsyncValueRef<CpuEvent>, 4> definition_events, |
| TfrtCpuDevice* device, TfrtCpuClient* client) { |
| absl::InlinedVector<std::shared_ptr<MaybeOwningCpuMemory>, 4> buffers; |
| if (!on_device_shape.IsTuple()) { |
| size_t byte_size = ShapeUtil::ByteSizeOf(on_device_shape); |
| auto device_buffer = MaybeOwningCpuMemory::AllocateShared(byte_size); |
| buffers.push_back(std::move(device_buffer)); |
| return std::make_unique<TfrtCpuBuffer>( |
| on_device_shape, |
| std::make_shared<TrackedTfrtCpuDeviceBuffer>( |
| /*is_tuple=*/false, std::move(buffers), |
| std::move(definition_events)), |
| client, device); |
| } |
| // Tuple case. |
| buffers.reserve(on_device_shape.tuple_shapes().size()); |
| for (const auto& leaf_shape : on_device_shape.tuple_shapes()) { |
| size_t byte_size = ShapeUtil::ByteSizeOf(leaf_shape); |
| auto device_buffer = MaybeOwningCpuMemory::AllocateShared(byte_size); |
| buffers.push_back(std::move(device_buffer)); |
| } |
| return std::make_unique<TfrtCpuBuffer>( |
| on_device_shape, |
| std::make_shared<TrackedTfrtCpuDeviceBuffer>( |
| /*is_tuple=*/true, std::move(buffers), std::move(definition_events)), |
| client, device); |
| } |
| |
| StatusOr<std::unique_ptr<PjRtBuffer>> TfrtCpuClient::CreateViewOfDeviceBuffer( |
| void* device_ptr, const Shape& shape, PjRtDevice* device, |
| std::function<void()> on_delete_callback) { |
| absl::InlinedVector<std::shared_ptr<MaybeOwningCpuMemory>, 4> buffers; |
| size_t byte_size = ShapeUtil::ByteSizeOf(shape); |
| auto non_owning_buffer = |
| std::make_shared<MaybeOwningCpuMemory>(device_ptr, byte_size); |
| buffers.push_back(std::move(non_owning_buffer)); |
| absl::InlinedVector<tfrt::AsyncValueRef<CpuEvent>, 4> empty_definition_events; |
| auto tracked_device_buffer = std::make_shared<TrackedTfrtCpuDeviceBuffer>( |
| /*is_tuple=*/false, std::move(buffers), |
| std::move(empty_definition_events), std::move(on_delete_callback)); |
| return std::unique_ptr<PjRtBuffer>(std::make_unique<TfrtCpuBuffer>( |
| shape, std::move(tracked_device_buffer), this, |
| tensorflow::down_cast<TfrtCpuDevice*>(device))); |
| } |
| |
| StatusOr<std::unique_ptr<PjRtBuffer>> TfrtCpuClient::CreateUninitializedBuffer( |
| const Shape& shape, PjRtDevice* device) { |
| tensorflow::profiler::TraceMe traceme( |
| "TfrtCpuClient::CreateUninitializedBuffer"); |
| VLOG(1) << "TfrtCpuClient::CreateUninitializedBuffer: shape: " |
| << shape.DebugString() << " device: " << device->DebugString(); |
| return AllocateDestinationBuffer( |
| shape, /*definition_events=*/{}, |
| tensorflow::down_cast<TfrtCpuDevice*>(device), this); |
| } |
| |
| StatusOr<std::unique_ptr<PjRtBuffer>> TfrtCpuClient::BufferFromHostBuffer( |
| const void* data, const Shape& shape, |
| HostBufferSemantics host_buffer_semantics, |
| std::function<void()> on_done_with_host_buffer, PjRtDevice* device) { |
| tensorflow::profiler::TraceMe traceme("TfrtCpuClient::BufferFromHostBuffer"); |
| VLOG(2) << "TfrtCpuClient::BufferFromHostBuffer: shape: " << shape.ToString() |
| << " device: " << device->DebugString(); |
| if (shape.IsTuple()) { |
| return InvalidArgument("Use BufferFromHostLiteral to transfer a tuple"); |
| } |
| bool has_default_layout = |
| !shape.has_layout() || |
| LayoutUtil::IsMonotonicWithDim0Major(shape.layout()); |
| // If the input buffer has a default layout and is sufficiently aligned, we |
| // can simply point to the input array's data without any further copies. At |
| // the time of writing we require a 16-byte alignment because XLA may generate |
| // code which requires it. |
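  // For example (illustrative), with a 16-byte minimum alignment a host
  // pointer at address 0x1000 passes the mask check below, while one at
  // address 0x1008 fails it and forces a copy.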
| bool can_use_zero_copy = |
| has_default_layout && |
| host_buffer_semantics == HostBufferSemantics::kZeroCopy && |
| ((absl::bit_cast<std::uintptr_t>(data) & |
| (cpu_function_runtime::kMinAlign - 1)) == 0); |
| absl::InlinedVector<std::shared_ptr<MaybeOwningCpuMemory>, 4> buffers; |
| absl::InlinedVector<tfrt::AsyncValueRef<CpuEvent>, 4> definition_events; |
| std::function<void()> on_delete_callback; |
| size_t byte_size = ShapeUtil::ByteSizeOf(shape); |
| if (can_use_zero_copy) { |
| auto device_buffer = std::make_shared<MaybeOwningCpuMemory>( |
| const_cast<void*>(data), byte_size); |
| buffers.push_back(std::move(device_buffer)); |
| on_delete_callback = std::move(on_done_with_host_buffer); |
| } else { |
| auto device_buffer = MaybeOwningCpuMemory::AllocateShared(byte_size); |
| auto dst_data_ptr = device_buffer->data(); |
| buffers.push_back(std::move(device_buffer)); |
| if (!has_default_layout) { |
      // If the layout is not the default, relayout into the default layout
      // and copy synchronously.
| BorrowingLiteral literal(static_cast<const char*>(data), shape); |
| Literal transposed_literal = |
| literal.Relayout(LayoutUtil::GetDefaultLayoutForShape(shape), {}); |
| CHECK_EQ(byte_size, transposed_literal.size_bytes()); |
| std::memcpy(dst_data_ptr, transposed_literal.untyped_data(), byte_size); |
| if (on_done_with_host_buffer) { |
| on_done_with_host_buffer(); |
| on_done_with_host_buffer = nullptr; |
| } |
| } else { |
| bool should_sync_copy = |
| host_buffer_semantics == |
| HostBufferSemantics::kImmutableOnlyDuringCall || |
| (byte_size < kSmallDataTransferByteSize); |
| if (should_sync_copy) { |
| std::memcpy(dst_data_ptr, data, byte_size); |
| if (on_done_with_host_buffer) { |
| on_done_with_host_buffer(); |
| on_done_with_host_buffer = nullptr; |
| } |
| } else { |
| tfrt::AsyncValueRef<CpuEvent> copy_event = |
| tfrt::MakeConstructedAsyncValueRef<CpuEvent>(host_ctx_.get()); |
| definition_events.push_back(copy_event.CopyRef()); |
| tfrt::EnqueueWork( |
| host_ctx_.get(), |
| [dst_data_ptr, data, byte_size, copy_event = std::move(copy_event), |
| on_done_with_host_buffer = |
| std::move(on_done_with_host_buffer)]() mutable { |
| tensorflow::profiler::TraceMe traceme("H2D Dispatch"); |
| std::memcpy(dst_data_ptr, data, byte_size); |
| if (on_done_with_host_buffer) { |
| on_done_with_host_buffer(); |
| on_done_with_host_buffer = nullptr; |
| } |
| // Signal copy is complete. |
| copy_event.SetStateConcrete(); |
| }); |
| } |
| } |
| } |
| auto tracked_device_buffer = std::make_shared<TrackedTfrtCpuDeviceBuffer>( |
| /*is_tuple=*/false, std::move(buffers), std::move(definition_events), |
| std::move(on_delete_callback)); |
| return std::unique_ptr<PjRtBuffer>(std::make_unique<TfrtCpuBuffer>( |
| shape, std::move(tracked_device_buffer), this, |
| tensorflow::down_cast<TfrtCpuDevice*>(device))); |
| } |
| |
| StatusOr<std::unique_ptr<PjRtBuffer>> TfrtCpuClient::BufferFromHostLiteral( |
| const LiteralSlice& literal, PjRtDevice* device) { |
| tensorflow::profiler::TraceMe traceme("TfrtCpuClient::BufferFromHostLiteral"); |
| VLOG(1) << "TfrtCpuClient::BufferFromHostLiteral: shape: " |
| << literal.shape().DebugString() |
| << " device: " << device->DebugString(); |
| const Shape& shape = literal.shape(); |
| |
  // Add a placeholder definition event for each leaf buffer when creating the
  // buffer. The events are made available only after the enqueued H2D copies
  // complete.
| absl::InlinedVector<tfrt::AsyncValueRef<CpuEvent>, 4> definition_events; |
| absl::InlinedVector<tfrt::RCReference<tfrt::AsyncValue>, 4> avs; |
| int num_leaf_buffers = shape.IsTuple() ? shape.tuple_shapes_size() : 1; |
| for (int i = 0; i < num_leaf_buffers; ++i) { |
| tfrt::AsyncValueRef<CpuEvent> definition_event = |
| tfrt::MakeConstructedAsyncValueRef<CpuEvent>(GetHostContext()); |
| definition_events.push_back(definition_event.CopyRef()); |
| avs.push_back(std::move(definition_event)); |
| } |
| TF_ASSIGN_OR_RETURN(std::unique_ptr<TfrtCpuBuffer> output_buffer, |
| AllocateDestinationBuffer( |
| shape, std::move(definition_events), |
| tensorflow::down_cast<TfrtCpuDevice*>(device), this)); |
| |
| if (!shape.IsTuple()) { |
| TfrtCpuBuffer::ScopedHold device_buffer( |
| output_buffer->GetBufferWithUsageHold()); |
| CHECK(device_buffer.ok()); |
    // It is OK to capture the device buffer pointer because `output_buffer`
    // can't be deleted until all the usage holds have gone away.
| tfrt::EnqueueWork( |
| GetHostContext(), |
| [literal, av = avs[0].CopyRef(), |
| movable_device_buffer{device_buffer.ToClosure()}, shape]() mutable { |
| tensorflow::profiler::TraceMe traceme("H2D Dispatch"); |
| TfrtCpuBuffer::ScopedHold device_buffer(movable_device_buffer); |
| const std::shared_ptr<MaybeOwningCpuMemory>& b = |
| device_buffer->Buffers()[0]; |
| CHECK_EQ(literal.size_bytes(), b->size()); |
| std::memcpy(b->data(), literal.untyped_data(), b->size()); |
| // Signal copy is complete. |
| av->SetStateConcrete(); |
| }); |
| } else { |
    // For tuples, transfer each leaf literal individually in parallel.
| for (int i = 0; i < shape.tuple_shapes_size(); ++i) { |
| TfrtCpuBuffer::ScopedHold device_buffer( |
| output_buffer->GetBufferWithUsageHold()); |
| CHECK(device_buffer.ok()); |
      // It is OK to capture the device buffer pointer because `output_buffer`
      // can't be deleted until all the usage holds have gone away.
| tfrt::EnqueueWork( |
| GetHostContext(), |
| [i, literal, av = avs[i].CopyRef(), shape, |
| movable_device_buffer{device_buffer.ToClosure()}]() mutable { |
| tensorflow::profiler::TraceMe traceme("H2D Dispatch"); |
| TfrtCpuBuffer::ScopedHold device_buffer(movable_device_buffer); |
| auto slice = LiteralSlice(literal, {i}); |
| const std::shared_ptr<MaybeOwningCpuMemory>& b = |
| device_buffer->Buffers()[i]; |
| CHECK_EQ(slice.size_bytes(), b->size()); |
| std::memcpy(b->data(), slice.untyped_data(), slice.size_bytes()); |
| // Signal copy is complete. |
| av->SetStateConcrete(); |
| }); |
| } |
| } |
| return std::unique_ptr<PjRtBuffer>(std::move(output_buffer)); |
| } |
| |
| TfrtCpuBuffer::ScopedHold::~ScopedHold() { |
| if (ok()) { |
| parent_->DropHold(type_, buffer().get()); |
| } |
| } |
| |
| TfrtCpuBuffer::ScopedHold::ScopedHold(ScopedHold&& other) |
| : parent_(other.parent_), |
| type_(other.type_), |
| state_(other.state_), |
| status_(std::move(other.status_)), |
| buffer_(std::move(other.buffer_)) { |
| // Preserve the invariant that status is invalid if buffer == nullptr. |
| other.SetState(kMoved); |
| } |
| |
| void TfrtCpuBuffer::ScopedHold::Acquire( |
| StatusOr<std::shared_ptr<TrackedTfrtCpuDeviceBuffer>>&& buffer_or) { |
| CHECK(!ok()); |
| if (buffer_or.ok()) { |
| buffer_ = buffer_or.ValueOrDie(); |
| SetState(kValid); |
| } else { |
| status_ = buffer_or.status(); |
| buffer_ = nullptr; |
| SetState(kError); |
| } |
| // Check the invariant holds. |
| CHECK(!ok() || buffer_ != nullptr); |
| } |
| |
| TfrtCpuBuffer::ScopedHold::ForClosure TfrtCpuBuffer::ScopedHold::ToClosure() { |
| CHECK(ok()); |
| ForClosure for_closure(parent_, type_, state_, std::move(status_), |
| std::move(buffer_)); |
| SetState(kReleased); |
| return for_closure; |
| } |
| |
| void TfrtCpuBuffer::ScopedHold::ConvertUsageHold( |
| absl::Span<tfrt::AsyncValueRef<CpuEvent>> events) { |
| CHECK(ok()); |
| CHECK_EQ(type_, kUsage); |
| parent_->ConvertUsageHold(buffer().get(), events); |
| SetState(kConverted); |
| } |
| |
| void TfrtCpuBuffer::ScopedHold::ConfirmDonation() { |
| CHECK(ok()); |
| CHECK_EQ(type_, kDonation); |
| parent_->ConfirmDonation(buffer().get()); |
| SetState(kDonated); |
| } |
| |
| TfrtCpuBuffer::TfrtCpuBuffer( |
| Shape on_device_shape, |
| std::shared_ptr<TrackedTfrtCpuDeviceBuffer> tracked_device_buffer, |
| TfrtCpuClient* client, TfrtCpuDevice* device) |
| : client_(client), |
| on_device_shape_(std::move(on_device_shape)), |
| device_(device), |
| tracked_device_buffer_(std::move(tracked_device_buffer)) { |
| for (int i = 0; i < ScopedHold::Type::kMaxValue; ++i) { |
| holds_[i] = 0; |
| } |
| } |
| |
| TfrtCpuBuffer::~TfrtCpuBuffer() { |
| Delete(); |
| for (int i = 0; i < ScopedHold::Type::kMaxValue; ++i) { |
| CHECK_EQ(holds_[i], 0); |
| } |
| } |
| |
| StatusOr<size_t> TfrtCpuBuffer::GetOnDeviceSizeInBytes() const { |
| return ShapeUtil::ByteSizeOf(on_device_shape_); |
| } |
| |
| namespace { |
| |
| // Implements PjRtBuffer::ExternalReference as a wrapped |
| // ScopedHold::kExternalReference. |
| class ScopedHoldAsExternalReference : public PjRtBuffer::ExternalReference { |
| public: |
| explicit ScopedHoldAsExternalReference(TfrtCpuBuffer::ScopedHold hold) |
| : external_reference_(std::move(hold)) { |
| CHECK(external_reference_.type() == |
| TfrtCpuBuffer::ScopedHold::kExternalReference); |
| data_ptr_ = external_reference_->Buffers()[0]->data(); |
| } |
| |
| ~ScopedHoldAsExternalReference() override = default; |
| |
| private: |
| TfrtCpuBuffer::ScopedHold external_reference_; |
| }; |
| |
| } // namespace |
| |
| StatusOr<std::unique_ptr<PjRtBuffer::ExternalReference>> |
| TfrtCpuBuffer::AcquireExternalReference() { |
| ScopedHold hold = GetBufferWithExternalReference(); |
| Status hold_status = hold.status(); |
| if (!hold_status.ok()) return hold_status; |
| return std::unique_ptr<ExternalReference>( |
| std::make_unique<ScopedHoldAsExternalReference>(std::move(hold))); |
| } |
| |
| class TrackedCpuDeviceBufferExternalReference |
| : public PjRtBuffer::ExternalReference { |
| public: |
| explicit TrackedCpuDeviceBufferExternalReference( |
| std::shared_ptr<TrackedTfrtCpuDeviceBuffer> tracked_device_buffer) |
| : tracked_device_buffer_(std::move(tracked_device_buffer)) { |
| data_ptr_ = tracked_device_buffer_->Buffers()[0]->data(); |
| } |
| |
| ~TrackedCpuDeviceBufferExternalReference() override = default; |
| |
| private: |
| std::shared_ptr<TrackedTfrtCpuDeviceBuffer> tracked_device_buffer_; |
| }; |
| |
| StatusOr<std::unique_ptr<PjRtBuffer::ExternalReference>> |
| TfrtCpuBuffer::ReleaseDeviceMemoryOwnership( |
| bool wait_for_operations_to_complete) { |
| if (on_device_shape_.IsTuple()) { |
| return InvalidArgument( |
| "ReleaseDeviceMemoryOwnership allowed only for non-tuple"); |
| } |
| TF_ASSIGN_OR_RETURN( |
| std::shared_ptr<TrackedTfrtCpuDeviceBuffer> tracked_device_buffer, |
| Release(wait_for_operations_to_complete)); |
| |
| std::unique_ptr<PjRtBuffer::ExternalReference> ref; |
| if (tracked_device_buffer) { |
| ref = std::make_unique<TrackedCpuDeviceBufferExternalReference>( |
| std::move(tracked_device_buffer)); |
| } |
| return ref; |
| } |
| |
| void TfrtCpuBuffer::Delete() { |
  // When wait_for_operations_to_complete is false, Release should never fail.
| TF_CHECK_OK(Release(/*wait_for_operations_to_complete=*/false).status()); |
| } |
| |
| bool TfrtCpuBuffer::IsDeleted() { |
| absl::MutexLock lock(&mu_); |
| return tracked_device_buffer_ == nullptr; |
| } |
| |
| void TfrtCpuBuffer::WaitForOutstandingUsageHolds() { |
| auto not_in_usage_hold = [&]() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) { |
| return holds_[ScopedHold::kUsage] == 0; |
| }; |
| mu_.Await(absl::Condition(¬_in_usage_hold)); |
| } |
| |
| void TfrtCpuBuffer::WaitForOutstandingDonationHold() { |
| auto not_in_donation_hold = [&]() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) { |
| return holds_[ScopedHold::kDonation] == 0; |
| }; |
| mu_.Await(absl::Condition(¬_in_donation_hold)); |
| } |
| |
| StatusOr<std::shared_ptr<TrackedTfrtCpuDeviceBuffer>> TfrtCpuBuffer::Release( |
| bool wait_for_operations_to_complete) { |
| std::shared_ptr<TrackedTfrtCpuDeviceBuffer> device_buffer; |
| absl::InlinedVector<tfrt::AsyncValueRef<CpuEvent>, 4> events; |
| { |
| absl::MutexLock lock(&mu_); |
    // We first wait for a donation hold to complete if there is one in
    // progress. If the donation succeeds via ConfirmDonation() then it will
    // set tracked_device_buffer_ to nullptr before returning to this thread.
| WaitForOutstandingDonationHold(); |
| if (tracked_device_buffer_ == nullptr) { |
| // Buffer has been deleted. |
| return std::shared_ptr<TrackedTfrtCpuDeviceBuffer>(); |
| } |
    // Set tracked_device_buffer_ to null now so that no other thread can add
    // a hold while we are in WaitForOutstandingUsageHolds() below.
| std::swap(tracked_device_buffer_, device_buffer); |
| WaitForOutstandingUsageHolds(); |
| // Now that all holds have completed and no more can be added, we can get |
| // the final set of usage events. |
| events = device_buffer->LockUseAndTransferUsageEvents(); |
| } |
| if (wait_for_operations_to_complete) { |
| // Block the host until all usage events have completed. Usage events |
| // dominate definition events, so this also waits for the buffer to be |
| // defined. Return the first error encountered. |
| Status first_error; |
| for (const auto& av : events) { |
| client_->GetHostContext()->Await(av.CopyRCRef()); |
| if (auto* error = av.GetErrorIfPresent()) { |
| first_error.Update(InternalError("Error Execute: %s", error->message)); |
| } |
| } |
| if (!first_error.ok()) return std::move(first_error); |
| } |
| return std::move(device_buffer); |
| } |
| |
| StatusOr<std::shared_ptr<TrackedTfrtCpuDeviceBuffer>> |
| TfrtCpuBuffer::GetBufferForHoldLocked(ScopedHold::Type type) { |
| // All callers should have called WaitForOutstandingDonationHold(). |
| CHECK_EQ(holds_[ScopedHold::kDonation], 0); |
| if (type == ScopedHold::kDonation) { |
| if (tracked_device_buffer_ == nullptr) { |
| return InvalidArgument("Donation requested for invalid buffer"); |
| } |
| if (holds_[ScopedHold::kExternalReference] > 0) { |
| return InvalidArgument( |
| "Donation requested for buffer with external reference"); |
| } |
| // First add the donation hold. |
| ++holds_[type]; |
| // Then wait for any usage holds to be dropped or converted. No new usage |
| // holds can be added until we drop the donation hold so this wait will |
| // complete eventually. |
| WaitForOutstandingUsageHolds(); |
| // Because we added a donation hold, nobody could release the buffer while |
| // we were waiting. |
| CHECK(tracked_device_buffer_ != nullptr); |
| } else { |
| if (tracked_device_buffer_ == nullptr) { |
| return InvalidArgument("Buffer has been deleted or donated."); |
| } else { |
| ++holds_[type]; |
| } |
| } |
| return tracked_device_buffer_; |
| } |
| |
| void TfrtCpuBuffer::AcquireHoldLocked(ScopedHold* hold) { |
| hold->Acquire(GetBufferForHoldLocked(hold->type())); |
| } |
| |
| TfrtCpuBuffer::ScopedHold TfrtCpuBuffer::GetBufferWithHold( |
| ScopedHold::Type type) { |
| absl::MutexLock lock(&mu_); |
| // Ensure that at most one donation hold can be in progress at a time. |
| WaitForOutstandingDonationHold(); |
| ScopedHold hold(this, type); |
| AcquireHoldLocked(&hold); |
| return hold; |
| } |
| |
| void TfrtCpuBuffer::ConvertUsageHold( |
| TrackedTfrtCpuDeviceBuffer* buffer, |
| absl::Span<tfrt::AsyncValueRef<CpuEvent>> events) { |
| absl::MutexLock lock(&mu_); |
| CHECK(tracked_device_buffer_.get() == buffer || |
| tracked_device_buffer_ == nullptr); |
| buffer->AddUsageEvents(events); |
| CHECK_GT(holds_[ScopedHold::kUsage], 0); |
| --holds_[ScopedHold::kUsage]; |
| } |
| |
| void TfrtCpuBuffer::ConfirmDonation(TrackedTfrtCpuDeviceBuffer* device_buffer) { |
| { |
| absl::MutexLock lock(&mu_); |
| CHECK_EQ(holds_[ScopedHold::kUsage], 0); |
| CHECK_EQ(holds_[ScopedHold::kExternalReference], 0); |
| CHECK_EQ(holds_[ScopedHold::kDonation], 1); |
| holds_[ScopedHold::kDonation] = 0; |
| CHECK(tracked_device_buffer_.get() == device_buffer); |
    // As a sanity check, ensure no more usage events can be added to the
    // buffer.
    device_buffer->LockUseAndTransferUsageEvents();
    // Give up ownership of the device memory so we don't free it when the
    // last reference to device_buffer goes away.
| device_buffer->ReleaseDeviceMemory(); |
| // Make *this invalid so it can't be used again. Any threads blocking in |
| // Release or GetBufferWithHold will see an invalid buffer and return. |
| tracked_device_buffer_.reset(); |
| } |
| } |
| |
| void TfrtCpuBuffer::DropHold(ScopedHold::Type type, |
| TrackedTfrtCpuDeviceBuffer* buffer) { |
| absl::MutexLock lock(&mu_); |
| CHECK(tracked_device_buffer_.get() == buffer || |
| tracked_device_buffer_ == nullptr); |
| CHECK_GT(holds_[type], 0); |
| --holds_[type]; |
| if (type == ScopedHold::kDonation) { |
| CHECK_EQ(holds_[ScopedHold::kDonation], 0); |
| CHECK_EQ(holds_[ScopedHold::kUsage], 0); |
| CHECK_EQ(holds_[ScopedHold::kExternalReference], 0); |
| } |
| } |
| |
| static ShapedBuffer AsShapedBuffer( |
| int device_ordinal, const Shape& on_device_shape, |
| absl::Span<const std::shared_ptr<MaybeOwningCpuMemory>> buffers) { |
| ShapedBuffer shaped_buffer(on_device_shape, device_ordinal); |
| ShapeTree<se::DeviceMemoryBase>::iterator iterator = |
| shaped_buffer.buffers().begin(); |
| for (const auto& buf : buffers) { |
| CHECK(iterator != shaped_buffer.buffers().end()); |
| iterator->second = se::DeviceMemoryBase(buf->data(), buf->size()); |
| ++iterator; |
| } |
| CHECK(iterator == shaped_buffer.buffers().end()); |
| return shaped_buffer; |
| } |
| |
| StatusOr<Shape> TfrtCpuBuffer::logical_on_device_shape() { |
| if (on_device_shape_.is_static()) { |
| return on_device_shape_; |
| } |
| ScopedHold device_buffer(this, ScopedHold::kUsage); |
| { |
| absl::MutexLock lock(&mu_); |
| // We can't perform any other action while a donation hold is in progress. |
| WaitForOutstandingDonationHold(); |
| if (tracked_device_buffer_ == nullptr) { |
| return InvalidArgument( |
| "logical_on_device_shape() called on deleted or donated buffer"); |
| } |
| AcquireHoldLocked(&device_buffer); |
| } |
| |
| // Wait for definition events. |
| for (const auto& av : device_buffer->DefinitionEvents()) { |
| client_->GetHostContext()->Await(av.CopyRCRef()); |
| if (auto* error = av.GetErrorIfPresent()) { |
| return InternalError("Error Execute: %s", error->message); |
| } |
| } |
| |
| ShapedBuffer shaped_buffer = AsShapedBuffer( |
| device_->local_hardware_id(), on_device_shape_, device_buffer->Buffers()); |
| Shape ret_shape = on_device_shape_; |
| TF_RETURN_IF_ERROR(ReadDynamicShapesOnCpu( |
| &shaped_buffer, &ret_shape, cpu::CpuExecutable::ShapeSizeBytes)); |
| return ret_shape; |
| } |
| |
| static std::vector<tfrt::RCReference<tfrt::AsyncValue>> GetAsyncValues( |
| absl::Span<const tfrt::AsyncValueRef<CpuEvent>> events) { |
| std::vector<tfrt::RCReference<tfrt::AsyncValue>> avs; |
| avs.reserve(events.size()); |
| for (const auto& ev : events) { |
| avs.push_back(ev.CopyRCRef()); |
| } |
| return avs; |
| } |
| |
// Enqueues `callee` onto the TFRT non-blocking work queue when all `values`
// are ready.
| static void EnqueueWorkWhenReady( |
| tfrt::HostContext* host_ctx, |
| tfrt::ArrayRef<tfrt::RCReference<tfrt::AsyncValue>> values, |
| llvm::unique_function<void()> callee) { |
| tfrt::RunWhenReady(values, [host_ctx, callee = std::move(callee)]() mutable { |
| tfrt::EnqueueWork(host_ctx, std::move(callee)); |
| }); |
| } |
| |
| void TfrtCpuBuffer::ToLiteral(MutableLiteralBase* literal, |
| std::function<void(Status)> on_ready) { |
| tensorflow::profiler::TraceMe traceme("TfrtCpuBuffer::ToLiteral"); |
| if (IsEmptyTuple()) { |
| on_ready(InvalidArgument("ToLiteral called on empty tuple")); |
| return; |
| } |
| TfrtCpuBuffer::ScopedHold device_buffer(this, ScopedHold::kUsage); |
| { |
| absl::MutexLock lock(&mu_); |
| // We can't perform any other action while a donation hold is in progress. |
| WaitForOutstandingDonationHold(); |
| if (tracked_device_buffer_ == nullptr) { |
| on_ready(InvalidArgument( |
| "CopyToHostAsync() called on deleted or donated buffer")); |
| return; |
| } |
| AcquireHoldLocked(&device_buffer); |
| } |
| auto host_ctx = client_->GetHostContext(); |
| |
| std::vector<tfrt::RCReference<tfrt::AsyncValue>> device_buffer_wait_avs = |
| GetAsyncValues(device_buffer.buffer()->DefinitionEvents()); |
| |
| bool should_sync_copy = device_buffer_wait_avs.empty() && |
| literal->size_bytes() < kSmallDataTransferByteSize; |
| if (should_sync_copy) { |
| if (!on_device_shape().IsTuple()) { |
| const std::shared_ptr<MaybeOwningCpuMemory>& b = |
| device_buffer.buffer()->Buffers()[0]; |
| std::memcpy(literal->untyped_data(), b->data(), b->size()); |
| } else { |
| // Tuple case. |
| int num_leaves = literal->shape().tuple_shapes().size(); |
| for (int i = 0; i < num_leaves; ++i) { |
| const std::shared_ptr<MaybeOwningCpuMemory>& b = |
| device_buffer.buffer()->Buffers()[i]; |
| std::memcpy(literal->untyped_data({i}), b->data(), b->size()); |
| } |
| } |
| // Unblock ToLiteral caller. |
| on_ready(Status::OK()); |
| } else { |
    // Wait for buffer definition events to finish before D2H dispatch. D2H
    // dispatches should run in parallel; e.g., the completion of one Execute
    // event may trigger D2H transfers for multiple outputs, which should
    // proceed on different threads in parallel.
| EnqueueWorkWhenReady( |
| host_ctx, device_buffer_wait_avs, |
| [this, movable_device_buffer{device_buffer.ToClosure()}, |
| device_buffer_wait_avs = std::move(device_buffer_wait_avs), literal, |
| on_ready{std::move(on_ready)}] { |
| tensorflow::profiler::TraceMe traceme("D2H Dispatch"); |
| TfrtCpuBuffer::ScopedHold device_buffer(movable_device_buffer); |
          // Errors in the src buffer are surfaced to the user.
| for (const auto& av : device_buffer_wait_avs) { |
| if (auto* error = av->GetErrorIfPresent()) { |
              on_ready(InternalError("Error converting to literal: %s",
                                     error->message));
| return; |
| } |
| } |
| |
| if (!on_device_shape().IsTuple()) { |
| const std::shared_ptr<MaybeOwningCpuMemory>& b = |
| device_buffer.buffer()->Buffers()[0]; |
| std::memcpy(literal->untyped_data(), b->data(), b->size()); |
| } else { |
| // Tuple case. |
| int num_leaves = literal->shape().tuple_shapes().size(); |
| for (int i = 0; i < num_leaves; ++i) { |
| const std::shared_ptr<MaybeOwningCpuMemory>& b = |
| device_buffer.buffer()->Buffers()[i]; |
| std::memcpy(literal->untyped_data({i}), b->data(), b->size()); |
| } |
| } |
| |
| // Unblock ToLiteral caller. |
| on_ready(Status::OK()); |
| }); |
| } |
| } |
| |
// TODO(zhangqiaorjc): Consider disallowing multiple CPU devices and assigning
// multiple pmap replicas to the same CPU device for multi-CPU pmap testing.
| StatusOr<std::unique_ptr<PjRtBuffer>> TfrtCpuBuffer::CopyToDevice( |
| PjRtDevice* dst_device) { |
| tensorflow::profiler::TraceMe traceme("TfrtCpuBuffer::CopyToDevice"); |
| // TODO(zhangqiaorjc): Remove this restriction after removing the test that |
| // explicitly asserts this. |
| if (dst_device == device_) { |
| return InvalidArgument( |
| "CopyToDevice cannot accept the same source and destination devices"); |
| } |
| |
| // Copying across PjRtClients involves a copy through the host. |
| if (dst_device->client() != client_) { |
| TF_ASSIGN_OR_RETURN(std::shared_ptr<Literal> literal, ToLiteral()); |
| // Avoid use-after-free on `literal` due to unsequenced move and use. |
| Literal* literal_pointer = literal.get(); |
| return dst_device->client()->BufferFromHostBuffer( |
| literal_pointer->untyped_data(), literal_pointer->shape(), |
| TfrtCpuClient::HostBufferSemantics::kZeroCopy, |
| [literal{std::move(literal)}]() { /* frees literal */ }, dst_device); |
| } |
| |
| // Copy each leaf buffer to a destination buffer. |
| TfrtCpuBuffer::ScopedHold src_device_buffer( |
| this, TfrtCpuBuffer::ScopedHold::kUsage); |
| { |
| absl::MutexLock lock(&mu_); |
| WaitForOutstandingDonationHold(); |
| if (tracked_device_buffer_ == nullptr) { |
| return InvalidArgument( |
| "CopyToDevice called on deleted or donated buffer"); |
| } |
| AcquireHoldLocked(&src_device_buffer); |
| } |
| |
| int num_leaf_buffers = src_device_buffer->Buffers().size(); |
| absl::InlinedVector<std::shared_ptr<MaybeOwningCpuMemory>, 4> src_buffers; |
| absl::InlinedVector<std::shared_ptr<MaybeOwningCpuMemory>, 4> dst_buffers; |
| absl::InlinedVector<tfrt::AsyncValueRef<CpuEvent>, 4> definition_events; |
| absl::InlinedVector<tfrt::RCReference<tfrt::IndirectAsyncValue>, 4> |
| indirect_avs; |
| absl::InlinedVector<tfrt::AsyncValueRef<CpuEvent>, 4> src_usage_events; |
| src_buffers.reserve(num_leaf_buffers); |
| dst_buffers.reserve(num_leaf_buffers); |
| definition_events.reserve(num_leaf_buffers); |
| indirect_avs.reserve(num_leaf_buffers); |
| src_usage_events.reserve(num_leaf_buffers); |
| |
| for (int i = 0; i < num_leaf_buffers; ++i) { |
| auto src_buffer = src_device_buffer->Buffers()[i]; |
| auto dst_buffer = MaybeOwningCpuMemory::AllocateShared(src_buffer->size()); |
| src_buffers.push_back(std::move(src_buffer)); |
| dst_buffers.push_back(std::move(dst_buffer)); |
| tfrt::RCReference<tfrt::IndirectAsyncValue> definition_event = |
| tfrt::MakeIndirectAsyncValue(client_->GetHostContext()); |
| definition_events.push_back( |
| tfrt::AsyncValueRef<CpuEvent>(definition_event.CopyRef())); |
| indirect_avs.push_back(definition_event.CopyRef()); |
| src_usage_events.push_back( |
| tfrt::AsyncValueRef<CpuEvent>(std::move(definition_event))); |
| } |
| |
  // Wait for src buffer definition events to finish before D2D dispatch.
  // Errors are propagated asynchronously via the dst buffer's definition
  // events.
| std::vector<tfrt::RCReference<tfrt::AsyncValue>> |
| src_device_buffer_definition_events_avs = |
| GetAsyncValues(src_device_buffer.buffer()->DefinitionEvents()); |
| |
| // Add d2d as usage event on src_buffer. |
| src_device_buffer.ConvertUsageHold(absl::MakeSpan(src_usage_events)); |
| |
| EnqueueWorkWhenReady( |
| client()->GetHostContext(), src_device_buffer_definition_events_avs, |
| [client = client_, num_leaf_buffers, src_buffers = std::move(src_buffers), |
| dst_buffers_copies = dst_buffers, indirect_avs = std::move(indirect_avs), |
| src_device_buffer_definition_events_avs = |
| std::move(src_device_buffer_definition_events_avs)]() mutable { |
| tensorflow::profiler::TraceMe traceme("D2D Dispatch"); |
| for (const auto& av : src_device_buffer_definition_events_avs) { |
| if (auto* error = av->GetErrorIfPresent()) { |
| for (int i = 0; i < num_leaf_buffers; ++i) { |
              // Any errors discovered in the src buffer are propagated to the
              // dst buffer's definition events, which will surface to users
              // in dst_buffer->ToLiteral().
| indirect_avs[i]->ForwardTo(av.CopyRef()); |
| } |
| return; |
| } |
| } |
| auto copy_ready = GetOrCreateReadyEvent(client->GetHostContext()); |
| for (int i = 0; i < num_leaf_buffers; ++i) { |
| std::memcpy(dst_buffers_copies[i]->data(), src_buffers[i]->data(), |
| src_buffers[i]->size()); |
| indirect_avs[i]->ForwardTo(copy_ready.CopyRCRef()); |
| } |
| }); |
| |
| return std::unique_ptr<PjRtBuffer>(std::make_unique<TfrtCpuBuffer>( |
| on_device_shape_, |
| std::make_shared<TrackedTfrtCpuDeviceBuffer>( |
| on_device_shape_.IsTuple(), std::move(dst_buffers), |
| std::move(definition_events)), |
| client(), tensorflow::down_cast<TfrtCpuDevice*>(dst_device))); |
| } |
| |
| Status TfrtCpuBuffer::BlockHostUntilReady() { |
| tensorflow::profiler::TraceMe traceme("TfrtCpuBuffer::BlockHostUntilReady"); |
| std::shared_ptr<TrackedTfrtCpuDeviceBuffer> device_buffer; |
| { |
| absl::MutexLock lock(&mu_); |
| if (tracked_device_buffer_ == nullptr) { |
| return InvalidArgument( |
| "BlockHostUntilReady() called on deleted or donated buffer"); |
| } |
| device_buffer = tracked_device_buffer_; |
| } |
| |
| // Wait for all definition events to complete. |
| Status status; |
| for (const auto& ev : device_buffer->DefinitionEvents()) { |
| client_->GetHostContext()->Await(ev.CopyRCRef()); |
| if (auto* error = ev.GetErrorIfPresent()) { |
| status.Update(FailedPrecondition( |
| "Error in BlockHostUntilReady waiting for definition events: %s", |
| error->message)); |
| } |
| } |
| return status; |
| } |
| |
| TfrtCpuExecutable::TfrtCpuExecutable( |
| int num_replicas, int num_partitions, |
| std::shared_ptr<DeviceAssignment> device_assignment, |
| bool parameter_is_tupled_arguments, |
| std::unique_ptr<Executable> cpu_executable, |
| BufferAllocation::Index result_buffer_index, |
| absl::InlinedVector<BufferAllocation::Index, 4> result_buffer_indices, |
| std::vector<LogicalDeviceIds> addressable_device_logical_ids, |
| std::vector<PjRtDevice*> addressable_devices, TfrtCpuClient* client) |
| : client_(client), |
| num_replicas_(num_replicas), |
| num_partitions_(num_partitions), |
| device_assignment_(std::move(device_assignment)), |
| parameter_is_tupled_arguments_(parameter_is_tupled_arguments), |
| cpu_executable_(std::move(cpu_executable)), |
| result_buffer_index_(result_buffer_index), |
| result_buffer_indices_(std::move(result_buffer_indices)), |
| addressable_device_logical_ids_( |
| std::move(addressable_device_logical_ids)), |
| addressable_devices_(std::move(addressable_devices)) { |
| auto hlo_cost_analysis = |
| std::make_unique<HloCostAnalysis>(cpu::CpuExecutable::ShapeSizeBytes); |
  // Cache the result to avoid a std::map lookup in flop_count() on the
  // critical path. The magic constant 1000 was determined by correlating
  // computation size with its flop estimate. It is a crude heuristic for
  // finding computations cheaper than the thread context-switch time (~5us).
| cheap_computation_ = hlo_cost_analysis->flop_count() < 1000; |
| |
| const auto& computation_layout = |
| cpu_executable_->module().entry_computation_layout(); |
| if (computation_layout.parameter_count() == 0) { |
| return; |
| } |
  // Assume the compiled program expects either many non-tupled arguments or a
  // single tupled argument. Nested tuples are not yet supported.
| if (computation_layout.parameter_count() > 1 || |
| !computation_layout.parameter_shape(0).IsTuple()) { |
| input_buffer_sizes_in_bytes_.reserve(computation_layout.parameter_count()); |
| for (int i = 0; i < computation_layout.parameter_count(); ++i) { |
| input_buffer_sizes_in_bytes_.push_back( |
| ShapeUtil::ByteSizeOf(computation_layout.parameter_shape(i))); |
| } |
| } else { |
| input_buffer_sizes_in_bytes_.reserve( |
| computation_layout.parameter_shape(0).tuple_shapes_size()); |
| for (int i = 0; |
| i < computation_layout.parameter_shape(0).tuple_shapes_size(); ++i) { |
| input_buffer_sizes_in_bytes_.push_back(ShapeUtil::ByteSizeOf( |
| computation_layout.parameter_shape(0).tuple_shapes(i))); |
| } |
| } |
| } |
| |
| void TfrtCpuExecutable::Delete() {} |
| |
| StatusOr<absl::optional<std::string>> TfrtCpuExecutable::Fingerprint() const { |
| return absl::optional<std::string>(); |
| } |
| |
| Status TfrtCpuExecutable::SetUpDonation(bool tuple_inputs) { |
| TF_ASSIGN_OR_RETURN(parameters_that_must_be_donated_, |
| GetParametersThatMustBeDonated( |
| *cpu_executable_->shared_module(), tuple_inputs)); |
| return Status::OK(); |
| } |
| |
| bool TfrtCpuExecutable::MustDonateParameter(int parameter) const { |
| return parameters_that_must_be_donated_.contains(parameter); |
| } |
| |
| // The following few helpers are adapted from XLA:CPU to create a buffer table |
| // and assemble the buffer pointers in order to call into CpuExecutable. |
| static std::shared_ptr<MaybeOwningCpuMemory> MemoryForAllocation( |
| const BufferAllocation& allocation, |
| absl::Span<const std::shared_ptr<TrackedTfrtCpuDeviceBuffer>> arguments) { |
| if (allocation.is_entry_computation_parameter()) { |
| const std::shared_ptr<TrackedTfrtCpuDeviceBuffer>& arg = |
| arguments[allocation.parameter_number()]; |
| std::shared_ptr<MaybeOwningCpuMemory> out = |
| arg->Buffer(allocation.param_shape_index()); |
| CHECK_EQ(allocation.size(), out->size()) |
| << "Size mismatch on param " << allocation.parameter_number() |
| << " at shape index " << allocation.param_shape_index().ToString(); |
| return out; |
| } else if (allocation.is_constant()) { |
| return std::make_shared<MaybeOwningCpuMemory>(); |
| } else if (allocation.is_thread_local()) { |
| return std::make_shared<MaybeOwningCpuMemory>(); |
| } |
| |
| // Output and temporary buffer. |
| int64 buffer_size = allocation.size(); |
| auto out = MaybeOwningCpuMemory::AllocateShared(buffer_size); |
| |
| // Since the output buffer and all the temporary buffers were written into |
| // by the JITed code, msan has no way of knowing their memory was |
| // initialized. Mark them initialized so that msan doesn't flag loads from |
| // these buffers. |
| TF_ANNOTATE_MEMORY_IS_INITIALIZED(out->data(), buffer_size); |
| return out; |
| } |
| |
| static StatusOr<std::vector<std::shared_ptr<MaybeOwningCpuMemory>>> |
| CreateBufferTable( |
| const BufferAssignment& assignment, |
| absl::Span<const std::shared_ptr<TrackedTfrtCpuDeviceBuffer>> arguments) { |
| std::vector<std::shared_ptr<MaybeOwningCpuMemory>> buffers( |
| assignment.Allocations().size()); |
| for (BufferAllocation::Index i = 0; i < assignment.Allocations().size(); |
| ++i) { |
| const BufferAllocation& allocation = assignment.GetAllocation(i); |
| buffers[i] = MemoryForAllocation(allocation, arguments); |
| } |
| |
| TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice, |
| assignment.GetUniqueTopLevelOutputSlice()); |
| VLOG(3) << "result index: " << result_slice.index(); |
| return std::move(buffers); |
| } |
| |
| static StatusOr<absl::InlinedVector<std::shared_ptr<MaybeOwningCpuMemory>, 4>> |
| CreateResultShapedBuffer( |
| absl::Span<const BufferAllocation::Index> buffer_indices, |
| absl::Span<const std::shared_ptr<MaybeOwningCpuMemory>> buffer_table, |
| absl::Span<const std::shared_ptr<TrackedTfrtCpuDeviceBuffer>> arguments) { |
| absl::InlinedVector<std::shared_ptr<MaybeOwningCpuMemory>, 4> output_buffers; |
| output_buffers.reserve(buffer_indices.size()); |
| for (int i = 0; i < buffer_indices.size(); ++i) { |
| output_buffers.push_back(buffer_table[buffer_indices[i]]); |
| } |
| return {std::move(output_buffers)}; |
| } |
| |
| Status TfrtCpuExecutable::CheckBufferCompatibilities( |
| absl::Span<const std::shared_ptr<TrackedTfrtCpuDeviceBuffer>> input_buffers) |
| const { |
| if (input_buffers.size() != input_buffer_sizes_in_bytes_.size()) { |
| return InvalidArgument( |
| "Execution supplied %lld buffers but compiled program expected %lld " |
| "buffers", |
| input_buffers.size(), input_buffer_sizes_in_bytes_.size()); |
| } |
| for (int i = 0; i < input_buffers.size(); ++i) { |
| const auto& buffer = input_buffers[i]; |
| if (input_buffer_sizes_in_bytes_[i] != buffer->Buffers()[0]->size()) { |
| return InvalidArgument( |
| "Executable expected parameter %d of size %lld but got buffer with " |
| "incompatible size %lld", |
| i, input_buffer_sizes_in_bytes_[i], buffer->Buffers()[0]->size()); |
| } |
| } |
| return Status::OK(); |
| } |
| |
| StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> |
| TfrtCpuExecutable::ExecuteHelper( |
| absl::Span<PjRtBuffer* const> argument_handles, int replica, int partition, |
| const RunId& run_id, const ExecuteOptions& options, |
| tfrt::AsyncValueRef<CpuEvent> last_collective_launch_event, |
| TfrtCpuDevice* device) { |
| tensorflow::profiler::TraceMe traceme("TfrtCpuExecutable::ExecuteHelper"); |
| auto* host_context = client_->GetHostContext(); |
| |
| std::shared_ptr<DeviceAssignment> device_assignment; |
| if (device == nullptr) { |
| CHECK(device_assignment_ != nullptr); |
| const int device_id = (*device_assignment_)(replica, partition); |
| TF_ASSIGN_OR_RETURN(PjRtDevice * pjrt_device, |
| client_->LookupDevice(device_id)); |
| device = tensorflow::down_cast<TfrtCpuDevice*>(pjrt_device); |
| device_assignment = device_assignment_; |
| } else { |
| CHECK(device_assignment_ == nullptr); |
| CHECK_EQ(replica, 0); |
| CHECK_EQ(partition, 0); |
| CHECK(addressable_devices_.empty()); |
| device_assignment = std::make_shared<DeviceAssignment>(1, 1); |
| (*device_assignment)(0, 0) = device->id(); |
| } |
| CHECK_EQ(device->process_index(), client_->process_index()); |
| |
| // Handle inputs. |
| if (options.arguments_are_tupled) { |
| if (!parameter_is_tupled_arguments_) { |
| return InvalidArgument( |
| "Arguments may only be supplied as a tuple when the executable was " |
| "compiled with a single tupled parameter"); |
| } |
| if (argument_handles.size() != 1) { |
| return InvalidArgument( |
| "Option arguments_are_tupled was true but %d buffers were passed to " |
| "execution", |
| argument_handles.size()); |
| } |
| } |
| |
| absl::InlinedVector<TfrtCpuBuffer::ScopedHold, 4> device_buffers; |
| absl::InlinedVector<std::shared_ptr<TrackedTfrtCpuDeviceBuffer>, 4> |
| tracked_buffers; |
| device_buffers.reserve(argument_handles.size()); |
| tracked_buffers.reserve(argument_handles.size()); |
| // To avoid clobbering inputs, we must ensure that |
| // `extra_deps` = inputs' definition events + donated inputs' usage events. |
  // This also ensures that the returned `execute_event` dominates all inputs'
  // events, and thus the output buffers only need to contain `execute_event`
  // as their single definition event.
| std::vector<tfrt::AsyncValueRef<CpuEvent>> input_deps; |
| input_deps.reserve(argument_handles.size()); |
| for (int i = 0; i < argument_handles.size(); ++i) { |
| PjRtBuffer* handle = argument_handles[i]; |
| auto* tfrt_buffer = tensorflow::down_cast<TfrtCpuBuffer*>(handle); |
| if (tfrt_buffer->device() != device) { |
| return InvalidArgument( |
| "Buffer passed to Execute() as argument %d to replica %d is on " |
| "device %s, but replica is assigned to device %s.", |
| i, replica, tfrt_buffer->device()->DebugString(), |
| device->DebugString()); |
| } |
| |
| bool must_donate = MustDonateParameter(i); |
| device_buffers.emplace_back(tfrt_buffer->GetBufferWithHold( |
| must_donate ? TfrtCpuBuffer::ScopedHold::kDonation |
| : TfrtCpuBuffer::ScopedHold::kUsage)); |
| TfrtCpuBuffer::ScopedHold& device_buffer = device_buffers.back(); |
| if (!device_buffer.ok()) { |
| return InvalidArgument( |
| "Invalid buffer passed to Execute() as argument %d to replica %d: " |
| "%s", |
| i, replica, device_buffer.status().ToString()); |
| } |
| |
| // Definition events are never modified after buffer construction. |
| for (const auto& ev : device_buffer->DefinitionEvents()) { |
| if (!ev.IsAvailable()) { |
| input_deps.push_back(ev.CopyRef()); |
| } |
| } |
| // If we are trying to donate this buffer, we must wait on its usage |
| // events as well as its definition events to ensure that all reads on |
| // this buffer (e.g., d2h transfer) have been completed before it can be |
| // mutated. Usage holds on this buffer are excluded during a donation hold |
| // so we know that its usage events won't be modified while we are |
| // enqueueing. |
| if (must_donate) { |
| for (const auto& ev : device_buffer->UsageEvents()) { |
| if (!ev.IsAvailable()) { |
| input_deps.push_back(ev.CopyRef()); |
| } |
| } |
| } |
| tracked_buffers.push_back(device_buffer.buffer()); |
| } |
| |
| TF_RETURN_IF_ERROR(CheckBufferCompatibilities(tracked_buffers)); |
| |
| // Tuplize the inputs if the compiler expects a single tuple argument but |
| // the runtime received multiple untupled inputs. |
| if (parameter_is_tupled_arguments_ && !options.arguments_are_tupled) { |
| absl::InlinedVector<std::shared_ptr<MaybeOwningCpuMemory>, 4> leaf_buffers; |
| leaf_buffers.reserve(tracked_buffers.size()); |
| for (const auto& tracked_buffer : tracked_buffers) { |
| auto span = tracked_buffer->Buffers(); |
| leaf_buffers.insert(leaf_buffers.end(), span.begin(), span.end()); |
| } |
| |
| // Tuplize into a single input. |
| tracked_buffers.clear(); |
| absl::InlinedVector<tfrt::AsyncValueRef<CpuEvent>, 4> |
| empty_definition_events; |
| tracked_buffers.push_back(std::make_shared<TrackedTfrtCpuDeviceBuffer>( |
| /*is_tuple=*/true, std::move(leaf_buffers), |
| std::move(empty_definition_events))); |
| } |
| |
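| // Build the flat buffer table (parameters, temporaries, and outputs) that |
| // the generated function indexes into, and extract the entries that make |
| // up the result. |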
| auto* cpu_executable = |
| tensorflow::down_cast<cpu::CpuExecutable*>(cpu_executable_.get()); |
| TF_ASSIGN_OR_RETURN( |
| std::vector<std::shared_ptr<MaybeOwningCpuMemory>> buffer_table, |
| CreateBufferTable(cpu_executable->buffer_assignment(), tracked_buffers)); |
| TF_ASSIGN_OR_RETURN(auto result_buffers, |
| CreateResultShapedBuffer(result_buffer_indices_, |
| buffer_table, tracked_buffers)); |
| |
| // The choice of where we wait is arbitrary; the reason for the wait is |
| // pacing to avoid problems such as memory fragmentation and running ahead |
| // too far, not for correctness. Placing it before the executable launch |
| // allows the inputs for the next executable to be fetched even if the |
| // launch is delayed. |
| auto compute_reservation = std::make_unique<Semaphore::ScopedReservation>( |
| device->max_inflight_computations_semaphore().ScopedAcquire(1)); |
| |
| // `execute_event` indicates whether the CPU computation is complete and |
| // whether there was an error. |
| tfrt::AsyncValueRef<CpuEvent> execute_event; |
| |
| // Call the computation function following the calling convention: the |
| // generated function takes the result buffer, the run options, an argument |
| // array, the buffer table, and a profile counters array. The argument and |
| // profile counter slots are null here because all buffers are passed |
| // through the buffer table. |
| std::vector<void*> buffer_pointers; |
| buffer_pointers.reserve(buffer_table.size()); |
| for (const auto& buffer : buffer_table) { |
| buffer_pointers.push_back(buffer->data()); |
| } |
| void* result_buffer = buffer_pointers[result_buffer_index_]; |
| |
| ExecutableRunOptions run_options; |
| run_options.set_run_id(run_id); |
| run_options.set_device_ordinal(device->local_hardware_id()); |
| // Need to keep device_assignment alive until execution completes. |
| run_options.set_device_assignment(device_assignment.get()); |
| run_options.set_intra_op_thread_pool(client_->eigen_intraop_device()); |
| |
| // Schedule only one collective at a time. |
| bool is_a_collective_launch = !!last_collective_launch_event; |
| if (is_a_collective_launch) { |
| input_deps.push_back(std::move(last_collective_launch_event)); |
| } |
| |
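| // Dispatch inline when all inputs are already available and the |
| // computation is known to be cheap; otherwise enqueue it on the work queue |
| // to run once every input dependency becomes available. |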
| if (input_deps.empty() && cheap_computation_) { |
| // Synchronously call generated function. |
| execute_event = GetOrCreateReadyEvent(host_context); |
| |
| // Set denormal and rounding behavior to match the default TF |
| // ThreadPool behavior. |
| tensorflow::port::ScopedFlushDenormal flush; |
| tensorflow::port::ScopedSetRound round(FE_TONEAREST); |
| |
| // Call generated function. |
| cpu_executable->compute_function()(result_buffer, &run_options, nullptr, |
| buffer_pointers.data(), nullptr); |
| } else { |
| // TODO(zhangqiaorjc): Only async launch expensive computations. Need |
| // heuristics to decide what computation is expensive. |
| // Asynchronously call generated function. |
| execute_event = tfrt::MakeConstructedAsyncValueRef<CpuEvent>(host_context); |
| |
| // We only created enough threads for one collective to complete. |
| // The next collective launch will not be scheduled onto the threadpool |
| // until this one completes. |
| if (is_a_collective_launch) { |
| client_->SetLastCollectiveLaunchEvent(execute_event.CopyRef()); |
| } |
| std::vector<tfrt::RCReference<tfrt::AsyncValue>> input_deps_avs = |
| GetAsyncValues(input_deps); |
| EnqueueWorkWhenReady( |
| host_context, input_deps_avs, |
| [cpu_executable, result_buffer, |
| buffer_pointers = std::move(buffer_pointers), |
| buffer_table = std::move(buffer_table), |
| run_options = std::move(run_options), |
| cpu_executable_copy = cpu_executable_, |
| device_assignment = std::move(device_assignment), |
| compute_reservation = std::move(compute_reservation), |
| tracked_buffers = std::move(tracked_buffers), |
| execute_event = execute_event.CopyRef(), |
| input_deps_avs = std::move(input_deps_avs)]() mutable { |
| for (const auto& av : input_deps_avs) { |
| if (auto* error = av->GetErrorIfPresent()) { |
| execute_event.SetError(absl::StrCat( |
| "Error dispatching computation: ", error->message)); |
| return; |
| } |
| } |
| |
| // Set denormal and rounding behavior to match the default TF |
| // ThreadPool behavior. |
| tensorflow::port::ScopedFlushDenormal flush; |
| tensorflow::port::ScopedSetRound round(FE_TONEAREST); |
| |
| // Call generated function. |
| cpu_executable->compute_function()(result_buffer, &run_options, |
| nullptr, buffer_pointers.data(), |
| nullptr); |
| // CPU computation completes. |
| execute_event.SetStateConcrete(); |
| }); |
| } |
| |
| // Handle input event recording. |
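| // Usage holds are converted into usage events on `execute_event`; donation |
| // holds are confirmed now that the donated buffers have been handed to the |
| // computation. |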
| for (TfrtCpuBuffer::ScopedHold& b : device_buffers) { |
| if (b.type() == TfrtCpuBuffer::ScopedHold::kUsage) { |
| std::array<tfrt::AsyncValueRef<CpuEvent>, 1> usage_events{ |
| execute_event.CopyRef()}; |
| b.ConvertUsageHold(absl::MakeSpan(usage_events)); |
| } else { |
| CHECK(b.type() == TfrtCpuBuffer::ScopedHold::kDonation); |
| b.ConfirmDonation(); |
| } |
| } |
| |
| // Create output TFRT buffers. |
| const Shape& result_shape = cpu_executable_->result_shape(); |
| std::vector<std::unique_ptr<PjRtBuffer>> res; |
| if (options.untuple_result && result_shape.IsTuple()) { |
| res.reserve(result_buffers.size()); |
| for (int i = 0; i < result_buffers.size(); ++i) { |
| absl::InlinedVector<std::shared_ptr<MaybeOwningCpuMemory>, 4> sub_buffer; |
| sub_buffer.push_back(std::move(result_buffers[i])); |
| // Program execution writes to output buffers so it's a definition event. |
| absl::InlinedVector<tfrt::AsyncValueRef<CpuEvent>, 4> definition_events; |
| definition_events.push_back(execute_event.CopyRef()); |
| auto leaf_tracked_device_buffer = |
| std::make_shared<TrackedTfrtCpuDeviceBuffer>( |
| /*is_tuple=*/false, std::move(sub_buffer), |
| std::move(definition_events)); |
| auto leaf_buffer = std::make_unique<TfrtCpuBuffer>( |
| result_shape.tuple_shapes(i), std::move(leaf_tracked_device_buffer), |
| client_, device); |
| res.push_back(std::move(leaf_buffer)); |
| } |
| } else { |
| // Program execution writes to output buffers so it's a definition event. |
| absl::InlinedVector<tfrt::AsyncValueRef<CpuEvent>, 4> definition_events; |
| definition_events.push_back(execute_event.CopyRef()); |
| auto tracked_device_buffer = std::make_shared<TrackedTfrtCpuDeviceBuffer>( |
| /*is_tuple=*/result_shape.IsTuple(), std::move(result_buffers), |
| std::move(definition_events)); |
| auto tfrt_output_buffer = std::make_unique<TfrtCpuBuffer>( |
| result_shape, std::move(tracked_device_buffer), client_, device); |
| res.push_back(std::move(tfrt_output_buffer)); |
| } |
| return res; |
| } |
| |
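| // Runs the computation once per addressable device. With a single device |
| // the computation runs inline on the calling thread; with multiple devices |
| // each replica/partition is enqueued on the work queue and this thread |
| // blocks until all of them finish or one of them fails. |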
| StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>> |
| TfrtCpuExecutable::Execute( |
| absl::Span<const std::vector<PjRtBuffer*>> argument_handles, |
| const ExecuteOptions& options) { |
| tensorflow::profiler::TraceMe traceme("TfrtCpuExecutable::Execute"); |
| if (device_assignment_ == nullptr) { |
| return InvalidArgument("Execute expects a non-null device_assignment"); |
| } |
| |
| RunId run_id; |
| tensorflow::profiler::TraceMeProducer activity( |
| "TfrtCpuExecutable::Execute", tensorflow::profiler::ContextType::kPjRt, |
| run_id.ToInt()); |
| |
| const int num_addressable_devices = addressable_devices_.size(); |
| |
| if (argument_handles.size() != num_addressable_devices) { |
| return InvalidArgument( |
| "Attempted to execute with %d argument lists when local device " |
| "count is %d (total replica count: %d, partition count: %d)", |
| argument_handles.size(), num_addressable_devices, num_replicas(), |
| num_partitions()); |
| } |
| |
| VLOG(1) << "Executing computation " << name() |
| << "; num_replicas=" << num_replicas() |
| << " num_partitions=" << num_partitions() |
| << " num_addressable_devices=" << num_addressable_devices; |
| |
| std::vector<StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>> results( |
| num_addressable_devices); |
| if (num_addressable_devices == 1) { |
| // Fast path: there is only one device, so run the computation on the |
| // current thread. |
| const int replica = addressable_device_logical_ids_[0].replica; |
| const int partition = addressable_device_logical_ids_[0].partition; |
| results[0] = ExecuteHelper( |
| argument_handles[0], replica, partition, run_id, options, |
| /*last_collective_launch_event=*/tfrt::AsyncValueRef<CpuEvent>()); |
| } else { |
| // Gang-schedule collectives to ensure that collectives with the same RunId |
| // are run at the same time. We conservatively run only one collective at a |
| // time, because we may not have enough threads to run an arbitrary number |
| // of collectives concurrently. |
| tfrt::AsyncValueRef<CpuEvent> last_collective_launch_event = |
| client_->GetLastCollectiveLaunchEvent(); |
| |
| absl::Mutex mu; |
| int running = num_addressable_devices; |
| int failed = 0; |
| Status first_failure_status; |
| |
| for (int i = 0; i < num_addressable_devices; ++i) { |
| const int replica = addressable_device_logical_ids_[i].replica; |
| const int partition = addressable_device_logical_ids_[i].partition; |
| tfrt::EnqueueWork(client_->GetHostContext(), [&, replica, partition, i] { |
| results[i] = |
| ExecuteHelper(argument_handles[i], replica, partition, run_id, |
| options, last_collective_launch_event.CopyRef()); |
| |
| absl::MutexLock lock(&mu); |
| --running; |
| if (!results[i].ok()) { |
| if (failed == 0) { |
| first_failure_status = results[i].status(); |
| } |
| ++failed; |
| } |
| }); |
| } |
| |
| auto done_running_or_failed = [&]() { |
| mu.AssertHeld(); |
| return running == 0 || failed > 0; |
| }; |
| absl::MutexLock lock(&mu); |
| mu.Await(absl::Condition(&done_running_or_failed)); |
| } |
| VLOG(1) << "Replicated execution complete."; |
| |
| std::vector<std::vector<std::unique_ptr<PjRtBuffer>>> wrapped_results( |
| num_addressable_devices); |
| for (int i = 0; i < num_addressable_devices; ++i) { |
| const int replica = addressable_device_logical_ids_[i].replica; |
| const int partition = addressable_device_logical_ids_[i].partition; |
| auto& statusor = results[i]; |
| if (!statusor.ok()) { |
| if (num_addressable_devices == 1) { |
| return statusor.status(); |
| } else { |
| return AppendStatus( |
| statusor.status(), |
| absl::StrFormat("while running replica %d and partition %d of a " |
| "replicated computation (other " |
| "replicas may have failed as well).", |
| replica, partition)); |
| } |
| } |
| wrapped_results[i] = std::move(statusor.ValueOrDie()); |
| } |
| return wrapped_results; |
| } |
| |
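| // Runs the computation on `device`, which must be one of this executable's |
| // addressable devices, using the replica/partition that the device |
| // assignment maps to it. |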
| StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> |
| TfrtCpuExecutable::ExecuteSharded( |
| absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device, |
| const ExecuteOptions& options) { |
| tensorflow::profiler::TraceMe traceme("TfrtCpuExecutable::ExecuteSharded"); |
| if (device_assignment_ == nullptr) { |
| return InvalidArgument("ExecuteShard expects a non-null device_assignment"); |
| } |
| for (int i = 0; i < addressable_devices_.size(); ++i) { |
| if (addressable_devices_[i] == device) { |
| VLOG(1) << "ExecuteShard executes computation " << name() |
| << " on assigned replica/partition on device " |
| << device->DebugString(); |
| return ExecuteHelper( |
| argument_handles, addressable_device_logical_ids_[i].replica, |
| addressable_device_logical_ids_[i].partition, RunId(), options, |
| /*last_collective_launch_event=*/tfrt::AsyncValueRef<CpuEvent>()); |
| } |
| } |
| return InvalidArgument( |
| "ExecuteShard attempted to execute on device id %d which is not " |
| "addressable by this client", |
| device->id()); |
| } |
| |
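| // Runs a portable executable (one compiled without a device assignment, |
| // with a single replica and partition) on an explicitly provided device. |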
| StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> |
| TfrtCpuExecutable::ExecutePortable( |
| absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device, |
| const ExecuteOptions& options) { |
| tensorflow::profiler::TraceMe traceme("TfrtCpuExecutable::ExecutePortable"); |
| if (device_assignment_ != nullptr) { |
| return InvalidArgument("ExecutePortable gets a non-portable executable"); |
| } |
| if (num_replicas() != 1 || num_partitions() != 1) { |
| return InvalidArgument( |
| "ExecutePortable expects a single-core executable but gets " |
| "one with %d replica %d partition", |
| num_replicas(), num_partitions()); |
| } |
| if (device == nullptr) { |
| return InvalidArgument("ExecutePortable expects a device to be specified"); |
| } |
| VLOG(1) << "ExecutePortable executes single-core portable executable " |
| << name(); |
| return ExecuteHelper( |
| argument_handles, |
| /*replica=*/0, |
| /*partition=*/0, RunId(), options, |
| /*last_collective_launch_event=*/tfrt::AsyncValueRef<CpuEvent>(), |
| tensorflow::down_cast<TfrtCpuDevice*>(device)); |
| } |
| } // namespace xla |