| /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "tensorflow/core/tpu/kernels/tpu_reshard_variables_op.h" |
| |
| #include "tensorflow/compiler/jit/xla_device.h" |
| #include "tensorflow/compiler/jit/xla_launch_util.h" |
| #include "tensorflow/compiler/jit/xla_tensor.h" |
| #include "tensorflow/compiler/tf2xla/shape_util.h" |
| #include "tensorflow/compiler/tf2xla/tf2xla_util.h" |
| #include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h" |
| #include "tensorflow/core/framework/allocator.h" |
| #include "tensorflow/core/framework/node_def_util.h" |
| #include "tensorflow/core/framework/op.h" |
| #include "tensorflow/core/framework/op_kernel.h" |
| #include "tensorflow/core/framework/resource_mgr.h" |
| #include "tensorflow/core/framework/resource_var.h" |
| #include "tensorflow/core/framework/tensor.h" |
| #include "tensorflow/core/framework/types.h" |
| #include "tensorflow/core/platform/casts.h" |
| #include "tensorflow/core/profiler/lib/traceme.h" |
| #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h" |
| #include "tensorflow/core/tpu/kernels/tpu_op_consts.h" |
| #include "tensorflow/core/tpu/kernels/tpu_program_group.h" |
| #include "tensorflow/core/tpu/kernels/tpu_reshard_variables_op_util.h" |
| #include "tensorflow/core/tpu/tpu_configuration.h" |
| #include "tensorflow/core/tpu/tpu_defs.h" |
| #include "tensorflow/core/tpu/tpu_execute.h" |
| #include "tensorflow/core/util/stream_executor_util.h" |
| #include "tensorflow/stream_executor/device_memory_allocator.h" |
| #include "tensorflow/stream_executor/tpu/tpu_executor.h" |
| #include "tensorflow/stream_executor/tpu/tpu_executor_interface.h" |
| #include "tensorflow/stream_executor/tpu/tpu_node_context.h" |
| |
| namespace tensorflow { |
| |
| namespace reshard_util = ::tensorflow::tpu::reshard_variables; |
| |
| TPUReshardVariablesOpKernel::TPUReshardVariablesOpKernel( |
| OpKernelConstruction* context) |
| : AsyncOpKernel(context, /*is_deferred=*/true) {
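| // The "N" attribute gives the number of variables to be resharded.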
| OP_REQUIRES_OK(context, context->GetAttr("N", &num_vars_)); |
| } |
| |
| void TPUReshardVariablesOpKernel::ComputeAsync(OpKernelContext* context, |
| DoneCallback done) { |
| // When TPU launches are asynchronous, performing the launch on this thread
| // avoids a thread hop, which has an observable latency cost.
| OP_REQUIRES_OK_ASYNC(context, DoWork(context), done); |
| done(); |
| } |
| |
| Status TPUReshardVariablesOpKernel::DoWork(OpKernelContext* context) { |
| VLOG(1) << "Cloud TPU: TPUReshardVariablesOpKernel::DoWork"; |
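| // Inputs 0..N-1 are the variables to reshard; input N is new_format_key and
| // input N+1 is the format state variable.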
| TF_RET_CHECK(context->input_dtype(num_vars_) == DT_STRING); |
| const Tensor* new_format_key; |
| TF_RETURN_IF_ERROR(context->input("new_format_key", &new_format_key)); |
| TF_RETURN_IF_ERROR(reshard_util::CheckIsValidKey(*new_format_key)); |
| |
| TF_RET_CHECK(context->input_dtype(num_vars_ + 1) == DT_RESOURCE); |
| const ResourceHandle& handle = HandleFromInput(context, num_vars_ + 1); |
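| // Look up the variable that tracks the current sharding state, creating an
| // uninitialized one on first use.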
| core::RefCountPtr<Var> format_state_var; |
| TF_RETURN_IF_ERROR(LookupOrCreateResource<Var>( |
| context, handle, &format_state_var, [new_format_key](Var** ptr) { |
| *ptr = new Var(new_format_key->dtype()); |
| return Status::OK(); |
| })); |
| mutex_lock ml(*format_state_var->mu()); |
| const bool initialized = format_state_var->is_initialized; |
| if (initialized) { |
| TF_RETURN_IF_ERROR( |
| reshard_util::CheckIsValidKey(*format_state_var->tensor())); |
| } |
| |
| const bool state_is_default = |
| !initialized || reshard_util::IsDefaultKey(*format_state_var->tensor()); |
| const bool new_format_is_default = |
| reshard_util::IsDefaultKey(*new_format_key); |
| |
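| // Nothing to do if both the current state and the requested format are the
| // default (unsharded) layout, or if the requested key matches the current
| // one (the keys are compared via element 2 of the key tensor).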
| if ((state_is_default && new_format_is_default) || |
| (initialized && format_state_var->tensor()->vec<tstring>()(2) == |
| new_format_key->vec<tstring>()(2))) { |
| VLOG(1) << "Sharding unchanged, nothing to do."; |
| return Status::OK(); |
| } |
| |
| if (!state_is_default) { |
| // Convert the current format to default (unsharded). |
| VLOG(1) << "Unsharding with key: " |
| << format_state_var->tensor()->vec<tstring>()(2); |
| TF_RETURN_IF_ERROR( |
| DoTpuExecute(context, *format_state_var->tensor(), |
| tpu::CompilationCacheFetchTarget::UNSHARDING)); |
| } |
| |
| if (!new_format_is_default) { |
| // Convert to the new format.
| VLOG(1) << "Sharding with key: " << new_format_key->vec<tstring>()(2); |
| TF_RETURN_IF_ERROR(DoTpuExecute( |
| context, *new_format_key, tpu::CompilationCacheFetchTarget::SHARDING)); |
| } |
| |
| // Record the new format as the current state.
| *format_state_var->tensor() = *new_format_key; |
| format_state_var->is_initialized = true; |
| return Status::OK(); |
| } |
| |
| Status TPUReshardVariablesOpKernel::DoTpuExecute( |
| OpKernelContext* context, const Tensor& format_key, |
| tpu::CompilationCacheFetchTarget fetch_target) { |
| const XlaDevice::Metadata* metadata; |
| TF_RETURN_IF_ERROR(XlaDevice::GetMetadata(context, &metadata)); |
| const int device_ordinal = metadata->device_ordinal(); |
| |
| // We are guaranteed that the underlying object won't be deleted out from
| // under us.
| TF_ASSIGN_OR_RETURN(std::unique_ptr<tpu::TpuNodeContext> node_interfaces, |
| tpu::TpuNodeContext::Create(device_ordinal)); |
| |
| profiler::TraceMe trace_me( |
| [device_ordinal] { |
| return profiler::TraceMeEncode("TPUReshardVariablesOpKernel", |
| {{"device_ordinal", device_ordinal}}); |
| }, |
| /*level=*/2); |
| profiler::TraceMe trace_me_init("TPUReshardVariablesOpKernel::Init", |
| /*level=*/2); |
| |
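| // Fetch the compiled sharding/unsharding program for this key from the
| // compilation cache.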
| string rendezvous_key_base; |
| std::unique_ptr<tpu::CompilationCacheEntryRef> entry_ref; |
| TF_RETURN_IF_ERROR(reshard_util::GetComputationCacheEntry( |
| format_key, &rendezvous_key_base, &entry_ref, fetch_target)); |
| tpu::TpuCompilationCacheEntry entry = entry_ref->get(); |
| if (entry.tpu_program_group() == nullptr) { |
| VLOG(2) << "Sharding/unsharding program does not exist, so this is the "
| "default sharding.";
| return Status::OK(); |
| } |
| |
| const tpu::TpuProgramGroupInterface* tpu_program_group = |
| entry.tpu_program_group(); |
| const int core_index = entry.core_index(); |
| const TPUExecutableInfoProto& executable_info_proto = |
| tpu_program_group->executable_info(core_index); |
| const TPUExecutableInfoProto* executable = &executable_info_proto; |
| |
| xla::Backend* const backend = node_interfaces->backend(); |
| xla::TransferManager* const transfer_manager = backend->transfer_manager(); |
| |
| CHECK(context->op_device_context()); |
| se::Stream* stream = context->op_device_context()->stream(); |
| |
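| // The sharding/unsharding program takes a single input: a tuple of the
| // variables' buffers.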
| TF_RET_CHECK(executable->input_shapes_size() == 1); |
| xla::Shape host_shape(executable->input_shapes(0)); |
| std::vector<VariableInfo> variables; |
| for (int i = 0; i < num_vars_; ++i) { |
| TF_RET_CHECK(context->input_dtype(i) == DT_RESOURCE); |
| const ResourceHandle& handle = HandleFromInput(context, i); |
| Var* variable; |
| TF_RETURN_IF_ERROR(LookupResource(context, handle, &variable)); |
| variables.push_back(VariableInfo(i, handle.name(), variable)); |
| } |
| |
| // Block until previous TPUExecute ops have finished so that the memory they
| // used can be freed.
| TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); |
| // Lock variables to prevent concurrent access. |
| TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variables))); |
| |
| // Build input buffers. |
| TF_ASSIGN_OR_RETURN(auto input_buffers, reshard_util::BuildInputBuffers( |
| context, variables, host_shape, |
| backend, device_ordinal, stream)); |
| xla::ShapedBuffer shaped_buffer(std::move(host_shape), input_buffers.shape(), |
| device_ordinal); |
| shaped_buffer.set_buffers(input_buffers.Map<se::DeviceMemoryBase>( |
| [](const xla::MaybeOwningDeviceMemory& buffer) { |
| return buffer.AsDeviceMemoryBase(); |
| })); |
| |
| // Write the input root tuple index table. If the buffers can be accessed
| // now, write on a borrowed transfer stream and make the compute stream wait
| // for it; otherwise write on the compute stream itself.
| TF_ASSIGN_OR_RETURN(auto transfer_stream_ptr, |
| backend->BorrowStream(device_ordinal)); |
| if (transfer_manager->CanShapedBufferBeAccessedNow(stream->parent(), |
| shaped_buffer)) { |
| TF_RETURN_IF_ERROR(transfer_manager->WriteRootTupleIndexTable( |
| transfer_stream_ptr.get(), shaped_buffer)); |
| stream->ThenWaitFor(transfer_stream_ptr.get()); |
| } else { |
| TF_RETURN_IF_ERROR( |
| transfer_manager->WriteRootTupleIndexTable(stream, shaped_buffer)); |
| } |
| VLOG(4) << "Input buffers: " << shaped_buffer.ToString(); |
| |
| TF_RET_CHECK(!executable->has_session_module())
| << "Session module is not supported in the sharding/unsharding program.";
| |
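| // Create an event that is recorded once the program has executed; it is
| // passed to UpdateOutputVariables so that consumers of the new buffers can
| // wait for them to be defined.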
| auto definition_event = std::make_shared<se::Event>(stream->parent()); |
| TF_RET_CHECK(definition_event->Init()) |
| << "TPU definition event initialization failed"; |
| |
| trace_me_init.Stop(); |
| |
| // Execute the program. |
| std::unique_ptr<xla::DeviceAssignment> device_assignment; |
| if (executable->has_device_assignment()) { |
| TF_ASSIGN_OR_RETURN( |
| device_assignment, |
| xla::DeviceAssignment::Deserialize(executable->device_assignment())); |
| } |
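| // Wrap the input buffers as the single ExecutionInput of the program;
| // owning buffers may be donated to the computation.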
| std::vector<xla::ExecutionInput> input; |
| input.emplace_back(xla::ExecutionInput(std::move(input_buffers), |
| shaped_buffer.on_host_shape())); |
| |
| const TPUHostTransferInfoProto& host_transfer_info = |
| tpu_program_group->host_transfer_info(core_index); |
| |
| TF_ASSIGN_OR_RETURN( |
| xla::ExecutionOutput output, |
| TPUExecute(*executable, host_transfer_info, |
| *tpu_program_group->hlo_metadatas()[core_index], |
| std::move(input), rendezvous_key_base, GetXLARandomSeed(), |
| node_interfaces.get(), device_assignment.get(), |
| context->cancellation_manager(), context, stream, |
| transfer_stream_ptr.get(), |
| tpu_program_group->tpu_program(core_index))); |
| |
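| // Mark the output buffers as defined once the enqueued execution completes.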
| stream->ThenRecordEvent(definition_event.get()); |
| |
| // Take ownership of the result buffers from the execution output.
| xla::ScopedShapedBuffer result = output.ConsumeResult(); |
| |
| // Only perform compaction when sharding.
| // NOTE: Compaction is not supported on some TPUs; see b/168322060 for
| // details.
| if (node_interfaces->CompactionSupported(device_ordinal) && |
| fetch_target == tpu::CompilationCacheFetchTarget::SHARDING) { |
| // Block until program execution is done so that input, output, and program
| // cache memory can actually be released.
| TF_RETURN_IF_ERROR(transfer_stream_ptr->BlockHostUntilDone()); |
| TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); |
| { |
| // Explicitly release any RAII objects owning on-device allocations. |
| auto unused = output.ConsumeToBeReleased(); |
| } |
| // Release the input buffers held by the variables.
| for (int i = 0; i < variables.size(); ++i) { |
| *variables[i].var()->tensor() = |
| Tensor(variables[i].var()->tensor()->dtype()); |
| } |
| // Flush the on-device program memory cache.
| TF_RETURN_IF_ERROR( |
| reshard_util::FlushProgramMemory(backend->platform(), device_ordinal)); |
| TF_RETURN_IF_ERROR(reshard_util::PerformCompaction(stream)); |
| } |
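| // Assign the new buffers to the variables; the variables were locked above
| // and are unlocked when `variables` goes out of scope.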
| return reshard_util::UpdateOutputVariables( |
| context, std::move(result), executable->output_tensor_shapes(), backend, |
| stream, device_ordinal, variables, definition_event); |
| } |
| |
| TPUReshardVariablesOpKernel::~TPUReshardVariablesOpKernel() = default; |
| |
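| // Kernel registration for open-source builds; PLATFORM_GOOGLE builds are
| // expected to register this kernel separately.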
| #if !defined(PLATFORM_GOOGLE) |
| REGISTER_KERNEL_BUILDER(Name("TPUReshardVariables") |
| .Device(DEVICE_TPU_NODE) |
| .HostMemory("format_state_var") |
| .HostMemory("new_format_key"), |
| TPUReshardVariablesOpKernel); |
| #endif |
| |
| } // namespace tensorflow |