blob: 4491e956b8cc6d562ce71db10fef05ac1a5ad4d3 [file] [log] [blame]
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#define EIGEN_USE_THREADS
#include <memory>
#include <string>
#include <utility>
#include "mlir/Dialect/Async/IR/AsyncTypes.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/ExecutionEngine/AsyncRuntime.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
#include "tensorflow/compiler/mlir/tfrt/jit/tf_jitrt.h"
#include "tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_kernels_registration.h"
#include "tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_pipeline.h"
#include "tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_request_context.h"
#include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/platform/dynamic_annotations.h"
#include "tensorflow/core/platform/threadpool.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.h"
#include "tensorflow/core/tfrt/utils/fallback_tensor.h"
#include "tfrt/jitrt/async_runtime.h" // from @tf_runtime
#include "tfrt/jitrt/async_runtime_api.h" // from @tf_runtime
#include "tfrt/jitrt/jitrt.h" // from @tf_runtime
#include "tfrt/jitrt/jitrt_compiler.h" // from @tf_runtime
#include "tfrt/dtype/dtype.h" // from @tf_runtime
#include "tfrt/host_context/async_dispatch.h" // from @tf_runtime
#include "tfrt/host_context/async_value_ref.h" // from @tf_runtime
#include "tfrt/host_context/chain.h" // from @tf_runtime
#include "tfrt/host_context/execution_context.h" // from @tf_runtime
#include "tfrt/host_context/host_buffer.h" // from @tf_runtime
#include "tfrt/host_context/host_context.h" // from @tf_runtime
#include "tfrt/host_context/kernel_registry.h" // from @tf_runtime
#include "tfrt/host_context/kernel_utils.h" // from @tf_runtime
#include "tfrt/host_context/shared_context.h" // from @tf_runtime
#include "tfrt/support/error_util.h" // from @tf_runtime
#include "tfrt/support/forward_decls.h" // from @tf_runtime
#include "tfrt/support/rc_array.h" // from @tf_runtime
#include "tfrt/support/string_util.h" // from @tf_runtime
#include "tfrt/tensor/tensor_metadata.h" // from @tf_runtime
#include "tfrt/tensor/tensor_shape.h" // from @tf_runtime
namespace tensorflow {
namespace {
#if __cplusplus >= 201703L
using ::std::any_cast;
#else
using ::llvm::any_cast;
#endif
using ::llvm::Expected;
using ::llvm::None;
using ::llvm::Optional;
using ::tfrt::Argument;
using ::tfrt::ArrayRef;
using ::tfrt::AsyncValue;
using ::tfrt::AsyncValuePtr;
using ::tfrt::AsyncValueRef;
using ::tfrt::Attribute;
using ::tfrt::Chain;
using ::tfrt::CompilationUnitAttribute;
using ::tfrt::DecodedDiagnostic;
using ::tfrt::DType;
using ::tfrt::EmitErrorAsync;
using ::tfrt::ExecutionContext;
using ::tfrt::HostContext;
using ::tfrt::IndirectAsyncValue;
using ::tfrt::KernelRegistry;
using ::tfrt::MakeAvailableAsyncValueRef;
using ::tfrt::MakeConstructedAsyncValueRef;
using ::tfrt::MakeErrorAsyncValueRef;
using ::tfrt::MakeStringError;
using ::tfrt::RCArray;
using ::tfrt::RCReference;
using ::tfrt::RemainingResults;
using ::tfrt::RepeatedArguments;
using ::tfrt::RequestContext;
using ::tfrt::SharedContext;
using ::tfrt::StrCat;
using ::tfrt::StringAttribute;
using ::tfrt::TaskFunction;
using ::tfrt::jitrt::CompilationOptions;
using ::tfrt::jitrt::CompilationPipelineOptions;
using ::tfrt::jitrt::CreateDefaultJitRtCompilationPipeline;
using ::tfrt::jitrt::EigenThreadPoolAsyncTaskRunner;
using ::tfrt::jitrt::Executable;
using ::tfrt::jitrt::JitExecutable;
using ::tfrt::jitrt::JitExecutableCache;
using ::tfrt::jitrt::MemrefDesc;
using ::tfrt::jitrt::OperandConstraint;
using ::tfrt::jitrt::RegisterDefaultJitRtDialects;
using ::tfrt::jitrt::ReturnErrors;
using ::tfrt::jitrt::ReturnStridedMemref;
using ::tfrt::jitrt::ReturnValueConversion;
using ::tfrt::jitrt::SpecializationListener;
using ::tfrt::jitrt::StaticReturnValueConverter;
using ::tensorflow::profiler::TraceMe;
using ::tensorflow::profiler::TraceMeEncode;
using ::tensorflow::tfd::KernelFallbackCompatRequestState;
using ::tensorflow::tfrt_stub::FallbackTensor;
using ::tensorflow::thread::ThreadPool;
// -------------------------------------------------------------------------- //
// Dedicated thread pool for running compilation tasks.
// -------------------------------------------------------------------------- //
class CompilationThreadPool : public SharedContext {
public:
explicit CompilationThreadPool(HostContext* host) { Reset(); }
static CompilationThreadPool& Get(HostContext* host) {
return host->GetOrCreateSharedContext<CompilationThreadPool>();
}
template <typename Task>
void Schedule(Task&& task) {
// Because compilation tasks can capture move only types, and Tensorflow
// thread pool requires std::function tasks, we have to do manual memory
// management here.
auto ptr = std::make_unique<Task>(std::forward<Task>(task));
thread_pool_->Schedule([ptr = ptr.release()]() {
(*ptr)();
delete ptr;
});
}
// This is an unsafe function intended only for use in tests. It is undefined
// behavior to call it concurrently with `Schedule`.
void Reset() {
thread_pool_ = std::make_unique<ThreadPool>(
Env::Default(), "tf-jitrt-compiler", /*num_threads=*/32);
}
private:
std::unique_ptr<ThreadPool> thread_pool_;
};
// -------------------------------------------------------------------------- //
// JIT compiled kernels use Eigen ThreadPool managed by the kernel fallback as
// an async runtime worker threads.
// -------------------------------------------------------------------------- //
static Expected<Eigen::ThreadPoolInterface*> GetWorkerThreads(
const ExecutionContext& exec_ctx) {
RequestContext* req_ctx = exec_ctx.request_ctx();
auto* fallback = req_ctx->GetDataIfExists<KernelFallbackCompatRequestState>();
if (LLVM_UNLIKELY(!fallback))
return MakeStringError("fallback request state was not found");
// Return user provided intra op thread pool if it is available.
if (LLVM_LIKELY(fallback->intra_op_threadpool()))
return fallback->intra_op_threadpool();
// Otherwise find the default CPU device in the device manager.
Device* host_cpu = fallback->device_manager().HostCPU();
assert(host_cpu && "fallback state must have a valid host cpu device");
const Eigen::ThreadPoolDevice* eigen = host_cpu->eigen_cpu_device();
assert(eigen && "host cpu device must have a valid Eigen thread pool device");
return eigen->getPool();
}
// -------------------------------------------------------------------------- //
// Compile compilation unit attribute to an executable result.
// -------------------------------------------------------------------------- //
// Options for the `tf-jitrt-pipeline`. We do not use MLIR pass options directly
// because they are not copyable or movable, and we need to pass them cheaply
// across the async compilation tasks boundary.
struct TfJitRtPipelineOpts {
bool vectorize;
bool legalize_i1_tensors;
};
// Prints memref descriptor as a tensor type: tensor<NxMxf32>.
static std::string AsTensorType(const MemrefDesc& desc) {
std::string str;
llvm::raw_string_ostream os(str);
os << "tensor<";
for (ssize_t size : desc.sizes) os << size << "x";
os << desc.dtype;
os << ">";
return str;
}
// Print memref descriptor content to trace value specializations.
static std::string AsTensorContent(const MemrefDesc& desc) {
std::string str;
llvm::raw_string_ostream os(str);
auto print_0d = [&](auto type_tag) {
os << desc.dtype << ": " << *static_cast<decltype(type_tag)*>(desc.data);
};
auto print_1d = [&](auto type_tag) {
os << desc.dtype << ": [";
for (size_t i = 0; i < desc.sizes[0]; ++i) {
if (i != 0) os << ",";
os << static_cast<decltype(type_tag)*>(desc.data)[i];
}
os << "]";
};
auto type_dispatch = [&](auto functor) {
switch (desc.dtype) {
case DType::I32:
functor(int32_t{});
break;
case DType::I64:
functor(int64_t{});
break;
default:
os << "<unsupported dtype " << desc.dtype << ">";
}
};
size_t rank = desc.sizes.size();
switch (rank) {
case 0:
type_dispatch(print_0d);
break;
case 1:
type_dispatch(print_1d);
break;
default:
os << "<unsupported rank " << desc.sizes.size() << ">";
}
return str;
}
// Gets the session name from the fallback request state.
static const std::string GetSessionName(RequestContext* req_ctx) {
auto* fallback = req_ctx->GetDataIfExists<KernelFallbackCompatRequestState>();
if (!fallback) return "<unknown>";
return fallback->session_metadata().name();
}
static Expected<AsyncValuePtr<JitExecutable>> CompileImpl(
const CompilationUnitAttribute& kernel, const ExecutionContext& exec_ctx,
const Optional<TfJitRtPipelineOpts>& opts = None) {
// Request context must be initialized with the tf_jitrt state.
auto* state = exec_ctx.request_ctx()->GetDataIfExists<TfJitRtRequestState>();
if (LLVM_UNLIKELY(!state))
return MakeStringError("tf_jitrt state not found in the request context");
// We rely on the unique `id` provided by the CompilationUnitAttribute to look
// up the JitExecutable in the cache. This id is guaranteed to be unique
// within a Bef file. Currently we rely on the fact that the SavedModel
// never unloads a Bef file, and there is a 1-to-1 relationship between the
// ResourceContext and the SavedModel.
//
// TODO(b/206081322): Different compilation options should create unique
// compiled kernel cache keys.
size_t key = kernel.id();
JitExecutableCache* jit_executable_cache = state->jit_executable_cache;
// Maybe return JitExecutable from the cache.
auto cached = jit_executable_cache->Find(key);
if (LLVM_LIKELY(cached)) return cached;
// Get the worker threads from the execution context. Do this before
// allocating an async value to make sure that we can try to instantiate the
// executable.
Expected<Eigen::ThreadPoolInterface*> worker_threads =
GetWorkerThreads(exec_ctx);
if (auto err = worker_threads.takeError()) return std::move(err);
// Allocate a placeholder for the compiled JitExecutable.
JitExecutableCache::Entry entry = jit_executable_cache->Allocate(key);
// We lost the race; some other invocation will do the compilation.
if (!entry.allocated) return entry.ptr;
// Given that compilation happens asynchronously, passing (or capturing) these
// by value prevents use-after-free errors.
struct KernelInfo {
intptr_t id;
std::string entrypoint;
std::string name;
std::string serialized_operation;
} kernel_info;
// We only support functions nested in top level compiled module.
if (kernel.nested_symbols().size() != 1)
return MakeStringError(
"kernel function has to be defined in a top-level module");
// TODO(ecg): use designed initializers + const when C++20 is adopted.
kernel_info.id = kernel.id();
kernel_info.entrypoint = kernel.nested_symbols()[0];
kernel_info.name = kernel.root_symbol();
kernel_info.serialized_operation = kernel.serialized_operation();
// Compilation (specialized executable compilation) events should be rare, so
// we can afford to do detailed tracing for every compilation. If compilation
// events happen too often, it is a much larger problem than the excessive
// tracing.
// Custom runner for compiling specializations that schedules compilation task
// into the dedicated thread pool and adds tracing.
auto runner = [kernel_info](size_t specialization,
ArrayRef<OperandConstraint> constraints,
ArrayRef<MemrefDesc> operands,
TaskFunction compile,
JitExecutable::UserData user_data) {
assert(operands.size() == constraints.size());
// Get the context of the request that triggered specialization compilation.
RequestContext* req_ctx = any_cast<RequestContext*>(user_data);
HostContext* host = req_ctx->host();
// Prepare arguments for the compilation tracing in the caller thread,
// because operands lifetime is shorter than the compilation task.
using SpecializationArg = std::pair<std::string, std::string>;
llvm::SmallVector<SpecializationArg> args;
args.reserve(operands.size());
// Trace types of all operands of the specialization.
for (size_t i = 0; i < operands.size(); ++i)
args.emplace_back(StrCat("%arg", i, " type"), AsTensorType(operands[i]));
// Trace content of all operands that require value specializations.
for (size_t i = 0; i < constraints.size(); ++i) {
if (constraints[i] != OperandConstraint::kValue) continue;
args.emplace_back(StrCat("%arg", i, " value"),
AsTensorContent(operands[i]));
}
// Schedule specialization compilation task into the dedicated thread pool.
CompilationThreadPool& thread_pool = CompilationThreadPool::Get(host);
thread_pool.Schedule(
[kernel_info, specialization, request_id = req_ctx->id(),
session_name = GetSessionName(req_ctx), compile = std::move(compile),
args = std::move(args)]() mutable {
TraceMe trace_me([&] {
return TraceMeEncode("tf_jitrt.CompileSpecialization",
{{"id", request_id},
{"kernel_id", kernel_info.id},
{"executable", kernel_info.name},
{"specialization", specialization}});
});
for (SpecializationArg& arg : args) {
trace_me.AppendMetadata([&] {
return TraceMeEncode({{arg.first, arg.second}});
});
}
trace_me.AppendMetadata([&] {
return TraceMeEncode({{"src", kernel_info.serialized_operation}});
});
auto compile_start_time = absl::Now();
LOG(INFO) << "Started JitExecutable specialization compilation for "
<< kernel_info.name << " (" << session_name << ")";
compile();
auto compile_duration = absl::Now() - compile_start_time;
LOG(INFO) << "JitExecutable specialization compilation for "
<< kernel_info.name << " took "
<< absl::ToInt64Milliseconds(compile_duration) << " ms ("
<< session_name << ")";
if (compile_duration > absl::Seconds(1))
LOG(INFO) << "Expensive JitExecutable specialization compilation ("
<< absl::ToInt64Milliseconds(compile_duration)
<< " ms):\n"
<< kernel_info.serialized_operation;
RecordCompileTime(session_name, kernel_info.name, specialization,
compile_duration);
});
};
HostContext* host = exec_ctx.host();
RequestContext* req_ctx = exec_ctx.request_ctx();
// Compile kernel asynchronously in the compilation thread pool.
CompilationThreadPool& thread_pool = CompilationThreadPool::Get(host);
thread_pool.Schedule([kernel_info, runner, workers = *worker_threads,
ref = entry.ptr.CopyRef(), request_id = req_ctx->id(),
session_name = GetSessionName(req_ctx),
tf_jitrt_opts = opts]() {
TraceMe trace_me([&] {
return TraceMeEncode("tf_jitrt.CompileDefault",
{{"id", request_id},
{"kernel_id", kernel_info.id},
{"executable", kernel_info.name},
{"src", kernel_info.serialized_operation}});
});
// Options for the default JitRt compilation pipeline (lowering to LLVM).
CompilationPipelineOptions copts;
copts.alignment = EIGEN_MAX_ALIGN_BYTES; // Eigen included by tensor.h
copts.num_worker_threads = workers->NumThreads();
copts.cost_driven_async_parallel_for =
GetJitRtFlags().cost_driven_async_parallel_for;
// Options for the JitRt JitExecutable compilation.
CompilationOptions opts;
opts.specialization = GetJitRtFlags().always_specialize
? CompilationOptions::Specialization::kAlways
: CompilationOptions::Specialization::kEnabled;
// Register dialects and interfaces required for the compilation pipeline.
opts.register_dialects = [](mlir::DialectRegistry& registry) {
mlir::RegisterAllTensorFlowDialects(registry);
RegisterDefaultJitRtDialects(registry);
};
// Register a custom pipeline for lowering from Tensorflow dialect to LLVM.
opts.create_compilation_pipeline = [=](mlir::PassManager& pm) {
TfJitRtPipelineOptions opts;
if (tf_jitrt_opts) {
opts.vectorize = tf_jitrt_opts->vectorize;
opts.legalize_i1_tensors = tf_jitrt_opts->legalize_i1_tensors;
} else {
opts.vectorize = GetJitRtFlags().vectorize;
}
// Lower from Tensorflow to Linalg on buffers.
CreateTfJitRtPipeline(pm, opts);
// Use default JitRt compilation pipeline to lower to LLVM.
CreateDefaultJitRtCompilationPipeline(pm, copts);
};
// Register a custom pipeline to propagate specialization information.
opts.create_specialization_pipeline = CreateJitRtSpecializationPipeline;
// When lowering Tensorflow functions to JitRt we convert all input and
// result tensors to memrefs, and add a kernel context input.
opts.calling_convention = CompilationOptions::DefaultCallingConvention(
mlir::bufferization::BufferizeTypeConverter());
// Instantiate new JitExecutable from the MLIR source.
auto compile_start_time = absl::Now();
LOG(INFO) << "Started JitExecutable instantiation compilation for "
<< kernel_info.name << " (" << session_name << ")";
Expected<JitExecutable> jit_executable = JitExecutable::Instantiate(
kernel_info.serialized_operation, kernel_info.entrypoint,
std::move(opts), session_name, runner);
auto compile_duration = absl::Now() - compile_start_time;
LOG(INFO) << "JitExecutable instantiation for " << kernel_info.name
<< " took " << absl::ToInt64Milliseconds(compile_duration)
<< " ms (" << session_name << ")";
if (compile_duration > absl::Seconds(1))
LOG(INFO) << "Expensive JitExecutable instantiation ("
<< absl::ToInt64Milliseconds(compile_duration) << " ms):\n"
<< kernel_info.serialized_operation;
RecordCompileTime(session_name, kernel_info.name, absl::nullopt,
compile_duration);
// Set the entry async value state to error or concrete.
if (auto err = jit_executable.takeError())
ref.SetError(std::move(err));
else
ref.emplace(std::move(*jit_executable));
});
return entry.ptr;
}
// -------------------------------------------------------------------------- //
// TFRT kernel function definition for tf_jitrt.fallback.compile operation.
// -------------------------------------------------------------------------- //
// Compiles kernel into the JitExecutable and updates JitExecutableCache.
static AsyncValueRef<Chain> Compile(StringAttribute device,
CompilationUnitAttribute kernel,
const ExecutionContext& exec_ctx) {
// Trigger kernel compilation, that will update the JitExecutableCache.
Expected<AsyncValuePtr<JitExecutable>> executable =
CompileImpl(kernel, exec_ctx);
// Return error if can't schedule the compilation task.
if (auto err = executable.takeError())
return MakeErrorAsyncValueRef(StrCat(err));
// Mark chain available once we compile the default executable.
auto chain = MakeConstructedAsyncValueRef<Chain>();
executable->AndThen([chain]() { chain.SetStateConcrete(); });
return chain;
}
// -------------------------------------------------------------------------- //
// TFRT kernel function definition for tf_jitrt.test.wait_for_compilation.
// -------------------------------------------------------------------------- //
static AsyncValueRef<Chain> WaitForCompilation(
Argument<Chain> chain, CompilationUnitAttribute kernel,
const ExecutionContext& exec_ctx) {
// Request context must be initialized with the tf_jitrt state.
auto* state = exec_ctx.request_ctx()->GetDataIfExists<TfJitRtRequestState>();
if (!state)
return EmitErrorAsync(exec_ctx,
"tf_jitrt state not found in the request context");
// Wait for the completion of all compilation tasks.
JitExecutableCache* jit_executable_cache = state->jit_executable_cache;
if (auto cached = jit_executable_cache->Find(kernel.id()))
return cached->AllExecutablesCompiled();
return MakeAvailableAsyncValueRef<Chain>();
}
// -------------------------------------------------------------------------- //
// TFRT kernel function for tf_jitrt.test.reset_compilation_thread_pool.
// -------------------------------------------------------------------------- //
static AsyncValueRef<Chain> ResetCompilationThreadPool(
Argument<Chain> chain, const ExecutionContext& exec_ctx) {
// Make sure that we reset the compilation thread pool only from a thread pool
// (concurrent work queue) managed by the HostContext.
return EnqueueWork(exec_ctx, [host = exec_ctx.host()]() -> Chain {
CompilationThreadPool::Get(host).Reset();
return {};
});
}
// -------------------------------------------------------------------------- //
// Execute compiled JitRt kernels with Fallback Runtime interop.
// -------------------------------------------------------------------------- //
using ReturnTensorflowTensor =
ReturnValueConversion<TensorflowConversionContext,
ReturnStridedMemref<ConvertTensor>>;
using TensorflowReturnValueConverter =
StaticReturnValueConverter<TensorflowConversionContext,
ReturnTensorflowTensor>;
// Converts Tensor to the Memref Descriptor and verifies that the Tensor
// value is compatible with the memref type.
static void ConvertTensorToMemrefDesc(const tensorflow::Tensor& tensor,
MemrefDesc* memref) {
memref->dtype = tfd::GetTfrtDtype(tensor.dtype());
memref->data = const_cast<void*>(tensor.data());
memref->offset = 0;
int rank = tensor.dims();
memref->sizes.resize_for_overwrite(rank);
memref->strides.resize_for_overwrite(rank);
// Fill memref sizes and compute strides from the tensor dimensions.
int64_t multiplier = 1;
for (int i = rank - 1; i >= 0; --i) {
int64_t dim_size = tensor.dim_size(i);
memref->sizes[i] = dim_size;
memref->strides[i] = multiplier;
multiplier *= dim_size;
}
}
static void ConvertTensorOperandsToMemrefDesc(
RepeatedArguments<FallbackTensor> operands,
llvm::SmallVectorImpl<MemrefDesc>* memrefs) {
assert(memrefs->empty() && "memrefs must be empty");
memrefs->resize(operands.size());
for (unsigned i = 0; i < operands.size(); ++i)
ConvertTensorToMemrefDesc(operands[i].tensor(), &(*memrefs)[i]);
}
struct DebugListener : public SpecializationListener {
void notifyModuleSpecialized(
ArrayRef<mlir::Type> operands,
ArrayRef<mlir::DictionaryAttr> attrs) const override {
std::string message;
llvm::raw_string_ostream os(message);
os << "Specialized operands:\n";
for (auto& tuple : llvm::enumerate(llvm::zip(operands, attrs))) {
mlir::Type type = std::get<0>(tuple.value());
mlir::Attribute attr = std::get<1>(tuple.value());
os << "%arg" << tuple.index() << ": " << type << " " << attr << "\n";
}
printf("%s", message.c_str());
fflush(stdout);
}
void notifyValueSpecialized(unsigned index, mlir::Type type,
mlir::Attribute value) const override {
std::string message;
llvm::raw_string_ostream(message) << "%arg" << index << " "
<< "value specialized: " << value << "\n";
printf("%s", message.c_str());
fflush(stdout);
}
};
// Emits diagnostics for the kernel invocation and returns error for all
// remaining results.
template <typename Error>
static void ReturnErrors(RemainingResults results, Error error,
const ExecutionContext& exec_ctx) {
EmitError(exec_ctx, StrCat(error));
ReturnErrors(results, std::move(error));
}
static void ExecuteImpl(Executable& executable,
const llvm::SmallVectorImpl<MemrefDesc>& memrefs,
RepeatedArguments<FallbackTensor> operands,
RemainingResults results,
const ExecutionContext& exec_ctx) {
// Bind execution trace to the request context.
TraceMe trace_me([&] {
int64_t id = exec_ctx.request_ctx()->id();
absl::string_view name(executable.name().data(), executable.name().size());
return TraceMeEncode(
"tf_jitrt.Execute",
{{"id", id},
{"executable", name},
{"specialization", !executable.specialization().hasValue()
? "default"
: std::to_string(*executable.specialization())},
{"time_to_compile_ms", executable.time_to_compile().count()}});
});
// TODO(ezhulenev): Conversion context and async task runner might not outlive
// the execution of all async tasks, and should be kept alive until all tasks
// are completed, which will require heap allocation(s).
assert(!executable.IsAsync() && "async executables are not yet supported");
// Keep track of memory address to tensor mapping for result conversion.
TensorflowConversionContext ctx(operands.size(), results.size());
for (auto& t : operands)
ctx.runtime_tensors.insert({t.tensor().data(), &t.tensor()});
TensorflowReturnValueConverter converter(results, ctx);
// Get the worker threads from the execution context.
Expected<Eigen::ThreadPoolInterface*> worker_threads =
GetWorkerThreads(exec_ctx);
if (LLVM_UNLIKELY(!worker_threads))
return ReturnErrors(results, worker_threads.takeError(), exec_ctx);
// Use Eigen thread pool to execute all async tasks.
EigenThreadPoolAsyncTaskRunner async_task_runner(*worker_threads);
Executable::ExecuteOpts opts;
opts.async_task_runner = &async_task_runner;
opts.kernel_context = &ctx;
// Execution error automatically forwarded to all results, we only need to
// notify the HostContext to emit the diagnostics for the kernel invocation.
auto err = executable.Execute(memrefs, converter, opts);
if (LLVM_UNLIKELY(err)) {
EmitError(exec_ctx, StrCat(err));
return;
}
}
// Gets a specialized Executable async value from the JitExecutable, and then
// dispatches it inline or using and-then continuation depending on the async
// value state.
static void ExecuteImpl(JitExecutable& jit_executable,
RepeatedArguments<FallbackTensor> operands,
RemainingResults results,
const ExecutionContext& exec_ctx, bool debug) {
// Convert Tensor operands to memref descriptors.
llvm::SmallVector<MemrefDesc> memrefs;
ConvertTensorOperandsToMemrefDesc(operands, &memrefs);
// Get an executable that might be specialized to the operands.
DebugListener debug_listener;
// Pass request context to the compilation task runner.
JitExecutable::UserData user_data = exec_ctx.request_ctx();
Expected<AsyncValuePtr<Executable>> executable = jit_executable.GetExecutable(
memrefs, user_data, debug ? &debug_listener : nullptr);
if (LLVM_UNLIKELY(!executable))
return ReturnErrors(results, executable.takeError(), exec_ctx);
// If executable is available execute it inline ...
if (LLVM_LIKELY(executable->IsConcrete()))
return ExecuteImpl(executable->get(), memrefs, operands, results, exec_ctx);
// ... or maybe return errors.
if (LLVM_UNLIKELY(executable->IsError()))
return ReturnErrors(results, executable->GetError(), exec_ctx);
// Otherwise execute it when the executable will become available. This
// requires careful lifetime extension of all async values passed as operands
// to the kernel (and also results that will become available asynchronously).
// Allocate indirect async values for all results, we'll forward them to the
// actual async values computed by the executable later.
for (unsigned i = 0; i < results.size(); ++i)
results.AllocateIndirectResultAt(i);
// Call executable when it's ready with the original operands.
executable->AndThen([exec_ctx, executable = *executable,
memrefs = std::move(memrefs),
r = RCArray<AsyncValue>(results.values()),
o = RCArray<AsyncValue>(operands.values())] {
// Allocate storage for the executable results.
llvm::SmallVector<RCReference<AsyncValue>> results_storage;
results_storage.resize(r.size());
// Reconstruct arguments and results from captured async values.
RepeatedArguments<FallbackTensor> operands(o.values());
RemainingResults results(results_storage);
if (executable.IsError()) {
ReturnErrors(results, executable.GetError(), exec_ctx);
} else {
ExecuteImpl(*executable, memrefs, operands, results, exec_ctx);
}
// Forward previously allocated indirect results to the actual results.
for (unsigned i = 0; i < r.size(); ++i)
llvm::cast<IndirectAsyncValue>(*r[i]).ForwardTo(
std::move(results_storage[i]));
});
}
// Gets a JitExecutable async value from the cache, and then dispatches it
// inline or using and-then continuation depending on the async value state.
static void ExecuteImpl(RepeatedArguments<FallbackTensor> operands,
RemainingResults results, const StringAttribute& device,
const CompilationUnitAttribute& kernel,
const ExecutionContext& exec_ctx, bool debug = false,
const Optional<TfJitRtPipelineOpts>& opts = None) {
// Compile kernel module into the JitExecutable.
Expected<AsyncValuePtr<JitExecutable>> jit_executable =
CompileImpl(kernel, exec_ctx, opts);
if (LLVM_UNLIKELY(!jit_executable))
return ReturnErrors(results, jit_executable.takeError(), exec_ctx);
// If kernel is available execute it inline ...
if (LLVM_LIKELY(jit_executable->IsConcrete()))
return ExecuteImpl(**jit_executable, operands, results, exec_ctx, debug);
// ... or maybe return errors.
if (LLVM_UNLIKELY(jit_executable->IsError()))
return ReturnErrors(results, jit_executable->GetError(), exec_ctx);
// Otherwise execute it when the executable will become available. This
// requires careful lifetime extension of all async values passed as operands
// to the kernel (and also results that will become available asynchronously).
// Allocate indirect async values for all results, we'll forward them to the
// actual async values computed by the executable later.
for (unsigned i = 0; i < results.size(); ++i)
results.AllocateIndirectResultAt(i);
// Call executable when it's ready with the original operands.
jit_executable->AndThen([exec_ctx, jit_executable = *jit_executable,
r = RCArray<AsyncValue>(results.values()),
o = RCArray<AsyncValue>(operands.values()), debug] {
// Allocate storage for compiled executable results.
llvm::SmallVector<RCReference<AsyncValue>> results_storage;
results_storage.resize(r.size());
// Reconstruct arguments and results from captured async values.
RepeatedArguments<FallbackTensor> operands(o.values());
RemainingResults results(results_storage);
if (jit_executable.IsError()) {
ReturnErrors(results, jit_executable.GetError(), exec_ctx);
} else {
ExecuteImpl(*jit_executable, operands, results, exec_ctx, debug);
}
// Forward previously entry indirect results to the actual results.
for (unsigned i = 0; i < r.size(); ++i)
llvm::cast<IndirectAsyncValue>(*r[i]).ForwardTo(
std::move(results_storage[i]));
});
}
// -------------------------------------------------------------------------- //
// TFRT kernel function definitions for tf_jitrt.fallback.execute operations.
// -------------------------------------------------------------------------- //
// Compiles kernel into the JitExecutable and executes it with the fallback
// tensors operands.
static void Execute(RepeatedArguments<FallbackTensor> operands,
RemainingResults results, StringAttribute device,
CompilationUnitAttribute kernel,
const ExecutionContext& exec_ctx) {
ExecuteImpl(operands, results, device, kernel, exec_ctx);
}
// Compiles kernel into the JitExecutable and executes it with the fallback
// tensors operands in the debug mode: prints compilation diagnostics to the
// standard output. Should be used only in tests for verifying compiler
// internals.
void ExecuteDebug(RepeatedArguments<FallbackTensor> operands,
RemainingResults results,
Attribute<bool> debug_specializations, StringAttribute device,
CompilationUnitAttribute kernel, Attribute<bool> vectorize,
Attribute<bool> legalize_i1_tensors,
const ExecutionContext& exec_ctx) {
TfJitRtPipelineOpts opts;
opts.vectorize = *vectorize;
opts.legalize_i1_tensors = *legalize_i1_tensors;
ExecuteImpl(operands, results, device, kernel, exec_ctx,
*debug_specializations, opts);
}
} // namespace
void RegisterTfJitRuntimeKernels(KernelRegistry* registry) {
registry->AddKernel("tf_jitrt.fallback.compile", TFRT_KERNEL(Compile));
registry->AddKernel("tf_jitrt.fallback.execute", TFRT_KERNEL(Execute));
registry->AddKernel("tf_jitrt.fallback.debug.execute",
TFRT_KERNEL(ExecuteDebug));
registry->AddKernel("tf_jitrt.test.wait_for_compilation",
TFRT_KERNEL(WaitForCompilation));
registry->AddKernel("tf_jitrt.test.reset_compilation_thread_pool",
TFRT_KERNEL(ResetCompilationThreadPool));
}
} // namespace tensorflow