| /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #define EIGEN_USE_THREADS |
| |
| #include <string> |
| #include <utility> |
| |
| #include "mlir/Dialect/Async/IR/AsyncTypes.h" |
| #include "mlir/ExecutionEngine/AsyncRuntime.h" |
| #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" |
| #include "tensorflow/compiler/mlir/tfrt/jit/tf_cpurt.h" |
| #include "tensorflow/compiler/mlir/tfrt/jit/tf_cpurt_passes.h" |
| #include "tensorflow/compiler/mlir/tfrt/jit/tf_cpurt_request_context.h" |
| #include "tensorflow/core/framework/tensor.h" |
| #include "tensorflow/core/framework/tensor_shape.h" |
| #include "tensorflow/core/platform/dynamic_annotations.h" |
| #include "tensorflow/core/profiler/lib/traceme.h" |
| #include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.h" |
| #include "tensorflow/core/tfrt/utils/fallback_tensor.h" |
| #include "tfrt/cpu/jit/async_runtime.h" // from @tf_runtime |
| #include "tfrt/cpu/jit/async_runtime_api.h" // from @tf_runtime |
| #include "tfrt/cpu/jit/cpurt.h" // from @tf_runtime |
| #include "tfrt/host_context/async_dispatch.h" // from @tf_runtime |
| #include "tfrt/host_context/async_value_ref.h" // from @tf_runtime |
| #include "tfrt/host_context/chain.h" // from @tf_runtime |
| #include "tfrt/host_context/execution_context.h" // from @tf_runtime |
| #include "tfrt/host_context/host_buffer.h" // from @tf_runtime |
| #include "tfrt/host_context/host_context.h" // from @tf_runtime |
| #include "tfrt/host_context/kernel_registry.h" // from @tf_runtime |
| #include "tfrt/host_context/kernel_utils.h" // from @tf_runtime |
| #include "tfrt/support/error_util.h" // from @tf_runtime |
| #include "tfrt/support/forward_decls.h" // from @tf_runtime |
| #include "tfrt/support/rc_array.h" // from @tf_runtime |
| #include "tfrt/support/string_util.h" // from @tf_runtime |
| #include "tfrt/tensor/tensor_metadata.h" // from @tf_runtime |
| #include "tfrt/tensor/tensor_shape.h" // from @tf_runtime |
| |
| namespace tensorflow { |
| namespace tfrt { |
| namespace jit { |
| namespace { |
| |
| using ::llvm::Expected; |
| |
| using ::tfrt::ArrayRef; |
| using ::tfrt::AsyncValue; |
| using ::tfrt::AsyncValuePtr; |
| using ::tfrt::AsyncValueRef; |
| using ::tfrt::Chain; |
| using ::tfrt::CompilationUnitAttribute; |
| using ::tfrt::EnqueueWork; |
| using ::tfrt::ExecutionContext; |
| using ::tfrt::IndirectAsyncValue; |
| using ::tfrt::KernelRegistry; |
| using ::tfrt::MakeConstructedAsyncValueRef; |
| using ::tfrt::MakeErrorAsyncValueRef; |
| using ::tfrt::MakeStringError; |
| using ::tfrt::RCArray; |
| using ::tfrt::RCReference; |
| using ::tfrt::RemainingResults; |
| using ::tfrt::RepeatedArguments; |
| using ::tfrt::RequestContext; |
| using ::tfrt::StrCat; |
| using ::tfrt::StringAttribute; |
| |
| using ::tfrt::cpu::jit::CompilationOptions; |
| using ::tfrt::cpu::jit::EmitErrors; |
| using ::tfrt::cpu::jit::Executable; |
| using ::tfrt::cpu::jit::JitExecutable; |
| using ::tfrt::cpu::jit::JitExecutableCache; |
| using ::tfrt::cpu::jit::MemrefDesc; |
| using ::tfrt::cpu::jit::ReturnAsyncStridedMemref; |
| using ::tfrt::cpu::jit::ReturnStridedMemref; |
| using ::tfrt::cpu::jit::ReturnValueConverter; |
| |
| using ::tensorflow::tfd::KernelFallbackCompatRequestState; |
| using ::tensorflow::tfrt_stub::FallbackTensor; |
| |
| // -------------------------------------------------------------------------- // |
| // JIT compiled kernels use Eigen CPU device as async runtime worker threads. |
| // -------------------------------------------------------------------------- // |
| |
| static Expected<Eigen::ThreadPoolInterface*> GetWorkerThreads( |
| const ExecutionContext& exec_ctx) { |
| RequestContext* req_ctx = exec_ctx.request_ctx(); |
| |
| auto* fallback = req_ctx->GetDataIfExists<KernelFallbackCompatRequestState>(); |
| if (!fallback) return MakeStringError("fallback request state was not found"); |
| |
| Device* host_cpu = fallback->device_manager().HostCPU(); |
| assert(host_cpu && "fallback state must have a valid host cpu device"); |
| |
| const Eigen::ThreadPoolDevice* eigen = host_cpu->eigen_cpu_device(); |
| assert(eigen && "host cpu device must have a valid Eigen thread pool device"); |
| |
| return eigen->getPool(); |
| } |
| |
| // -------------------------------------------------------------------------- // |
| // Compile compilation unit attribute to an executable result. |
| // -------------------------------------------------------------------------- // |
| |
| static Expected<AsyncValuePtr<JitExecutable>> CompileImpl( |
| CompilationUnitAttribute kernel, const ExecutionContext& exec_ctx) { |
| // We only support functions nested in top level compiled module. |
| if (kernel.nested_symbols().size() != 1) |
| return MakeStringError( |
| "kernel function has to be defined in a top-level module"); |
| |
| // Request context must be initialized with the tf_cpurt state. |
| TfCpuRtRequestState* state = |
| exec_ctx.request_ctx()->GetDataIfExists<TfCpuRtRequestState>(); |
| if (!state) |
| return MakeStringError("cpurt state not found in the request context"); |
| |
| JitExecutableCache* jit_executable_cache = state->jit_executable_cache; |
| |
| // TODO(ezhulenev): CompilationUnitAttribute in addition to an `id` should |
| // provide a hash (or something like sha-256 fingerprint) of its content for |
| // cache lookup. Currently we rely on the fact that the SavedModel never |
| // unloads a Bef file, and there is a 1-to-1 relationship between the |
| // ResourceContext and the SavedModel, so the `id` is guaranteed to be a |
| // unique key for the cache lookup. |
| intptr_t key = kernel.id(); |
| |
| // Maybe return JitExecutable from the cache. |
| if (auto cached = jit_executable_cache->Find(key)) return cached; |
| |
| // Get the worker threads from the execution context. Do this before |
| // allocating an async value to make sure that we can try to instantiate the |
| // executable. |
| Expected<Eigen::ThreadPoolInterface*> worker_threads = |
| GetWorkerThreads(exec_ctx); |
| if (auto err = worker_threads.takeError()) return std::move(err); |
| |
| // Allocate a placeholder for the compiled JitExecutable. |
| JitExecutableCache::Entry entry = jit_executable_cache->Allocate(key); |
| |
| // We lost the race; some other invocation will do the compilation. |
| if (!entry.allocated) return entry.ptr; |
| |
| // Compile kernel asynchronously in the host context thread pool. |
| EnqueueWork(exec_ctx, [kernel, workers = *worker_threads, ptr = entry.ptr]() { |
| CompilationOptions opts; |
| // All entry memrefs must have alignment compatible with Tensorflow. |
| opts.alignment = EIGEN_MAX_ALIGN_BYTES; // Eigen included by tensor.h |
| opts.num_worker_threads = workers->NumThreads(); |
| opts.register_dialects = mlir::RegisterAllTensorFlowDialects; |
| opts.register_pass_pipeline = CreateTfCpuRtPipeline; |
| |
| auto entrypoint = kernel.nested_symbols()[0]; |
| auto module = kernel.serialized_operation(); |
| |
| // Instantiate new JitExecutable from the MLIR source. |
| Expected<JitExecutable> jit_executable = |
| JitExecutable::Instantiate(module, entrypoint, opts); |
| |
| // Set the entry async value state to error or concrete. |
| if (auto err = jit_executable.takeError()) |
| ptr.SetError(std::move(err)); |
| else |
| ptr.emplace(std::move(*jit_executable)); |
| }); |
| |
| return entry.ptr; |
| } |
| |
| // Compiles kernel into the JitExecutable and updates JitExecutableCache. |
| static AsyncValueRef<Chain> Compile(StringAttribute device, |
| CompilationUnitAttribute kernel, |
| const ExecutionContext& exec_ctx) { |
| // Trigger kernel compilation, that will update the JitExecutableCache. |
| Expected<AsyncValuePtr<JitExecutable>> executable = |
| CompileImpl(kernel, exec_ctx); |
| |
| // Return immediately if can't compile the kernel. |
| if (auto err = executable.takeError()) |
| return MakeErrorAsyncValueRef(StrCat(err)); |
| |
| // Signal compilation completion using an async chain. |
| auto compiled = MakeConstructedAsyncValueRef<Chain>(); |
| |
| executable->AndThen([executable = *executable, res = compiled.CopyRef()]() { |
| if (executable.IsError()) |
| res.SetError(executable.GetError()); |
| else |
| res.SetStateConcrete(); |
| }); |
| |
| return compiled; |
| } |
| |
| // -------------------------------------------------------------------------- // |
| // Execute compiled CPURT kernels with Fallback Runtime interop. |
| // -------------------------------------------------------------------------- // |
| |
| using TensorflowReturnValueConverter = |
| ReturnValueConverter<TensorflowConversionContext>; |
| |
| // Converts Tensor to the Memref Descriptor and verifies that the Tensor |
| // value is compatible with the memref type. |
| static void ConvertTensorToMemrefDesc(const tensorflow::Tensor& tensor, |
| MemrefDesc* memref) { |
| memref->dtype = tfd::GetTfrtDtype(tensor.dtype()); |
| memref->data = const_cast<void*>(tensor.data()); |
| memref->offset = 0; |
| |
| int rank = tensor.dims(); |
| memref->sizes.resize_for_overwrite(rank); |
| memref->strides.resize_for_overwrite(rank); |
| |
| // Fill memref sizes and compute strides from the tensor dimensions. |
| ssize_t multiplier = 1; |
| for (int i = rank - 1; i >= 0; --i) { |
| ssize_t dim_size = tensor.dim_size(i); |
| memref->sizes[i] = dim_size; |
| memref->strides[i] = multiplier; |
| multiplier *= dim_size; |
| } |
| } |
| |
| static void ConvertTensorOperandsToMemrefDesc( |
| RepeatedArguments<FallbackTensor> operands, |
| llvm::SmallVectorImpl<MemrefDesc>* memrefs) { |
| assert(memrefs->empty() && "memrefs must be empty"); |
| memrefs->resize(operands.size()); |
| |
| for (unsigned i = 0; i < operands.size(); ++i) |
| ConvertTensorToMemrefDesc(operands[i].tensor(), &(*memrefs)[i]); |
| } |
| |
| struct DebugListener : public JitExecutable::Listener { |
| void notifyModuleSpecialized(ArrayRef<mlir::Type> inputs) const override { |
| std::string message; |
| llvm::raw_string_ostream(message) |
| << "Specialized module: " << inputs << "\n"; |
| printf("%s", message.c_str()); |
| fflush(stdout); |
| } |
| |
| void notifyValueSpecialized(unsigned index, mlir::Type type, |
| mlir::Attribute attr) const override { |
| std::string message; |
| llvm::raw_string_ostream(message) << "Arg[" << index << "] " |
| << "value specialized: " << attr << "\n"; |
| printf("%s", message.c_str()); |
| fflush(stdout); |
| } |
| }; |
| |
| static void ExecuteImpl(Executable& executable, |
| const llvm::SmallVectorImpl<MemrefDesc>& memrefs, |
| RepeatedArguments<FallbackTensor> operands, |
| RemainingResults results, |
| const ExecutionContext& exec_ctx) { |
| // Bind execution trace to the request context. |
| profiler::TraceMe trace_me([&] { |
| return profiler::TraceMeEncode("tf_cpurt.Execute", |
| {{"id", exec_ctx.request_ctx()->id()}, |
| {"executable", executable.name()}}); |
| }); |
| |
| // Keep track of memory address to tensor mapping for result conversion. |
| auto ctx = std::make_unique<TensorflowConversionContext>(operands.size()); |
| for (auto& t : operands) |
| ctx->tensor_operands.insert({t.tensor().data(), &t.tensor()}); |
| |
| // Tensorflow -> CPURT only supportes returning Memrefs as Tensors. |
| TensorflowReturnValueConverter converter(results, std::move(ctx)); |
| converter.AddConversion(ReturnAsyncStridedMemref<ConvertTensor>); |
| converter.AddConversion(ReturnStridedMemref<ConvertTensor>); |
| |
| // Get the worker threads from the execution context. |
| Expected<Eigen::ThreadPoolInterface*> worker_threads = |
| GetWorkerThreads(exec_ctx); |
| if (auto err = worker_threads.takeError()) |
| return EmitErrors(results, std::move(err), exec_ctx); |
| |
| // Override async runtime worker threads with fallback Eigen thread pool. |
| Executable::ExecuteOpts opts; |
| opts.async_runtime_worker_threads = *worker_threads; |
| |
| // Error propagation happens in the result converter. |
| if (auto err = executable.Execute(memrefs, converter, exec_ctx, opts)) return; |
| |
| // If executable is async keep operands and conversion context alive until |
| // results become available. |
| if (executable.IsAsync()) |
| RunWhenReady(results.values(), |
| [operands = RCArray<AsyncValue>(operands.values()), |
| ctx = converter.TakeConversionContext()] {}); |
| } |
| |
| // Gets a specialized Executable async value from the JitExecutable, and then |
| // dispatches it inline or using and-then continuation depending on the async |
| // value state. |
| static void ExecuteImpl(JitExecutable& jit_executable, |
| RepeatedArguments<FallbackTensor> operands, |
| RemainingResults results, |
| const ExecutionContext& exec_ctx, bool debug) { |
| // Convert Tensor operands to memref descriptors. |
| llvm::SmallVector<MemrefDesc> memrefs; |
| ConvertTensorOperandsToMemrefDesc(operands, &memrefs); |
| |
| // Get an executable that might be specialized to the operands. |
| DebugListener debug_listener; |
| |
| AsyncValuePtr<Executable> executable = jit_executable.GetExecutable( |
| memrefs, exec_ctx, debug ? &debug_listener : nullptr); |
| |
| // If executable is available execute it inline. |
| if (executable.IsAvailable()) { |
| if (executable.IsError()) { |
| EmitErrors(results, executable.GetError(), exec_ctx); |
| } else { |
| ExecuteImpl(executable.get(), memrefs, operands, results, exec_ctx); |
| } |
| return; |
| } |
| |
| // Otherwise execute it when the executable will become available. This |
| // requires careful lifetime extension of all async values passed as operands |
| // to the kernel (and also results that will become available asynchronously). |
| |
| // Allocate indirect async values for all results, we'll forward them to the |
| // actual async values computed by the executable later. |
| for (unsigned i = 0; i < results.size(); ++i) |
| results.AllocateIndirectResultAt(i); |
| |
| // Call executable when it's ready with the original operands. |
| executable.AndThen([exec_ctx, executable, memrefs = std::move(memrefs), |
| r = RCArray<AsyncValue>(results.values()), |
| o = RCArray<AsyncValue>(operands.values())] { |
| // Allocate storage for the executable results. |
| llvm::SmallVector<RCReference<AsyncValue>> results_storage; |
| results_storage.resize(r.size()); |
| |
| // Reconstruct arguments and results from captured async values. |
| RepeatedArguments<FallbackTensor> operands(o.values()); |
| RemainingResults results(exec_ctx.host(), results_storage); |
| |
| if (executable.IsError()) { |
| EmitErrors(results, executable.GetError(), exec_ctx); |
| } else { |
| ExecuteImpl(*executable, memrefs, operands, results, exec_ctx); |
| } |
| |
| // Forward previously allocated indirect results to the actual results. |
| for (unsigned i = 0; i < r.size(); ++i) |
| llvm::cast<IndirectAsyncValue>(*r[i]).ForwardTo( |
| std::move(results_storage[i])); |
| }); |
| } |
| |
| // Gets a JitExecutable async value from the cache, and then dispatches it |
| // inline or using and-then continuation depending on the async value state. |
| static void ExecuteImpl(RepeatedArguments<FallbackTensor> operands, |
| RemainingResults results, StringAttribute device, |
| CompilationUnitAttribute kernel, |
| const ExecutionContext& exec_ctx, bool debug) { |
| // Compile kernel module into the JitExecutable. |
| Expected<AsyncValuePtr<JitExecutable>> jit_executable = |
| CompileImpl(kernel, exec_ctx); |
| |
| if (auto err = jit_executable.takeError()) |
| return EmitErrors(results, std::move(err), exec_ctx); |
| |
| // If kernel is available execute it inline. |
| if (jit_executable->IsAvailable()) { |
| if (jit_executable->IsError()) { |
| EmitErrors(results, jit_executable->GetError(), exec_ctx); |
| } else { |
| ExecuteImpl(**jit_executable, operands, results, exec_ctx, debug); |
| } |
| return; |
| } |
| |
| // Otherwise execute it when the executable will become available. This |
| // requires careful lifetime extension of all async values passed as operands |
| // to the kernel (and also results that will become available asynchronously). |
| |
| // Allocate indirect async values for all results, we'll forward them to the |
| // actual async values computed by the executable later. |
| for (unsigned i = 0; i < results.size(); ++i) |
| results.AllocateIndirectResultAt(i); |
| |
| // Call executable when it's ready with the original operands. |
| jit_executable->AndThen([exec_ctx, jit_executable = *jit_executable, |
| r = RCArray<AsyncValue>(results.values()), |
| o = RCArray<AsyncValue>(operands.values()), debug] { |
| // Allocate storage for compiled executable results. |
| llvm::SmallVector<RCReference<AsyncValue>> results_storage; |
| results_storage.resize(r.size()); |
| |
| // Reconstruct arguments and results from captured async values. |
| RepeatedArguments<FallbackTensor> operands(o.values()); |
| RemainingResults results(exec_ctx.host(), results_storage); |
| |
| if (jit_executable.IsError()) { |
| EmitErrors(results, jit_executable.GetError(), exec_ctx); |
| } else { |
| ExecuteImpl(*jit_executable, operands, results, exec_ctx, debug); |
| } |
| |
| // Forward previously entry indirect results to the actual results. |
| for (unsigned i = 0; i < r.size(); ++i) |
| llvm::cast<IndirectAsyncValue>(*r[i]).ForwardTo( |
| std::move(results_storage[i])); |
| }); |
| } |
| |
| // Compiles kernel into the JitExecutable and executes it with the fallback |
| // tensors operands. |
| static void Execute(RepeatedArguments<FallbackTensor> operands, |
| RemainingResults results, StringAttribute device, |
| CompilationUnitAttribute kernel, |
| const ExecutionContext& exec_ctx) { |
| ExecuteImpl(operands, results, device, kernel, exec_ctx, /*debug=*/false); |
| } |
| |
| // Compiles kernel into the JitExecutable and executes it with the fallback |
| // tensors operands in the debug mode: prints compilation diagnostics to the |
| // standard output. Should be used only in tests for verifying compiler |
| // internals. |
| static void ExecuteDebug(RepeatedArguments<FallbackTensor> operands, |
| RemainingResults results, StringAttribute device, |
| CompilationUnitAttribute kernel, |
| const ExecutionContext& exec_ctx) { |
| ExecuteImpl(operands, results, device, kernel, exec_ctx, /*debug=*/true); |
| } |
| |
| } // namespace |
| |
| void RegisterTfCpuRuntimeKernels(KernelRegistry* registry) { |
| registry->AddKernel("tf_cpurt.fallback.compile", TFRT_KERNEL(Compile)); |
| registry->AddKernel("tf_cpurt.fallback.execute", TFRT_KERNEL(Execute)); |
| registry->AddKernel("tf_cpurt.fallback.debug.execute", |
| TFRT_KERNEL(ExecuteDebug)); |
| } |
| |
| } // namespace jit |
| } // namespace tfrt |
| } // namespace tensorflow |