blob: f77d6355a5e30986e2fc8efeb3b6c7e7c1f28443 [file] [log] [blame]
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tfrt/graph_executor/graph_executor.h"
#include <algorithm>
#include <array>
#include <cstdint>
#include <functional>
#include <memory>
#include <numeric>
#include <optional>
#include <string>
#include <utility>
#include <vector>
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/time/clock.h"
#include "absl/time/time.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
#include "tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_request_context.h"
#include "tensorflow/compiler/mlir/tfrt/translate/import_model.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/platform/threadpool_interface.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/connected_traceme.h"
#include "tensorflow/core/profiler/lib/traceme_encode.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_execute_compat.h"
#include "tensorflow/core/tfrt/fallback/fallback_state.h"
#include "tensorflow/core/tfrt/graph_executor/graph_execution_options.h"
#include "tensorflow/core/tfrt/runtime/runtime.h"
#include "tensorflow/core/tfrt/runtime/work_queue_interface.h"
#include "tensorflow/core/tfrt/utils/error_util.h"
#include "tensorflow/core/tfrt/utils/fallback_tensor.h"
#include "tensorflow/core/tfrt/utils/utils.h"
#include "tfrt/core_runtime/core_runtime.h" // from @tf_runtime
#include "tfrt/host_context/async_dispatch.h" // from @tf_runtime
#include "tfrt/host_context/async_value.h" // from @tf_runtime
#include "tfrt/host_context/async_value_ref.h" // from @tf_runtime
#include "tfrt/host_context/chain.h" // from @tf_runtime
#include "tfrt/host_context/concurrent_work_queue.h" // from @tf_runtime
#include "tfrt/host_context/execution_context.h" // from @tf_runtime
#include "tfrt/host_context/function.h" // from @tf_runtime
#include "tfrt/host_context/host_context.h" // from @tf_runtime
#include "tfrt/host_context/request_deadline_tracker.h" // from @tf_runtime
#include "tfrt/host_context/resource_context.h" // from @tf_runtime
#include "tfrt/support/forward_decls.h" // from @tf_runtime
#include "tfrt/support/ref_count.h" // from @tf_runtime
#include "tfrt/support/string_util.h" // from @tf_runtime
namespace tensorflow {
namespace tfrt_stub {
namespace {
// Message attached to the DeadlineExceeded status returned by
// GraphExecutionRunOnFunction() when the deadline has already passed or the
// request was cancelled.
constexpr char kDeadlineExceededMessage[] = "Deadline exceeded.";
// Separator placed between tensor names of the same kind (inputs, outputs or
// targets) when building the client-graph cache key.
constexpr char kTensorNameJoiningDelimiter[] = "-";
// Separator placed between the input, output and target sections of the
// client-graph cache key.
constexpr char kArgumentTypeJoiningDelimiter[] = "^";
} // namespace
// Builds the per-request state needed to execute a TFRT function.
//
// A `tfrt::RequestContextBuilder` is created for `host` and
// `resource_context`; `work_queue` then initializes the request, which may
// yield a per-request queue and an intra-op thread pool. The kernel-fallback
// and JitRt request state are installed next, the priority from `run_options`
// is applied, and finally the request context is built.
//
// Returns the populated `RequestInfo`. Any error from a setup step is
// propagated as-is; a failure to build the request context is reported as an
// Internal error.
StatusOr<std::unique_ptr<RequestInfo>> SetUpRequestContext(
    const GraphExecutionRunOptions& run_options,
    const SessionMetadata& model_metadata, tfrt::HostContext* host,
    tensorflow::tfrt_stub::WorkQueueInterface* work_queue,
    tfrt::ResourceContext* resource_context,
    const tensorflow::tfrt_stub::FallbackState& fallback_state) {
  DCHECK(host);
  DCHECK(work_queue);

  // TODO(tfrt-devs): Consider using an ID unique within each model to reduce
  // contention.
  tfrt::RequestContextBuilder builder(host, resource_context,
                                      tfrt::GetUniqueInt());

  // TODO(b/198671794): Both the intra-op thread pool and the per-request queue
  // should be passed through Run() directly.
  tensorflow::thread::ThreadPoolInterface* intra_op_pool = nullptr;
  TF_ASSIGN_OR_RETURN(auto per_request_queue,
                      work_queue->InitializeRequest(&builder, &intra_op_pool));

  auto info = std::make_unique<RequestInfo>();

  // Tasks for tensorflow::Executor::Args::Runner go to the per-request queue
  // when one was provided and to the original queue otherwise.
  auto* task_queue = work_queue;
  if (per_request_queue) task_queue = per_request_queue.get();
  info->runner = [task_queue](std::function<void()> task) {
    task_queue->AddTask(std::move(task));
  };
  info->request_queue = std::move(per_request_queue);

  TF_RETURN_IF_ERROR(tensorflow::tfd::SetUpKernelFallbackCompatRequestContext(
      &builder, &fallback_state.device_manager(),
      &fallback_state.process_function_library_runtime(), intra_op_pool,
      model_metadata, &info->runner));
  TF_RETURN_IF_ERROR(tensorflow::SetUpTfJitRtRequestContext(&builder));

  tfrt::RequestOptions options;
  options.priority = run_options.priority;
  builder.set_request_options(options);

  auto built_context = std::move(builder).build();
  if (!built_context) {
    return tensorflow::errors::Internal(
        tfrt::StrCat(built_context.takeError()));
  }

  info->tfrt_request_context = std::move(built_context.get());
  return info;
}
// Executes the compiled function `func` once for a single request and appends
// its results to `*outputs`.
//
// Steps, in order:
//   1. Build a request context via SetUpRequestContext(), using
//      `run_options.work_queue` when it is set and `runtime.work_queue()`
//      otherwise.
//   2. If `run_options.deadline` is set, fail fast when it has already passed;
//      otherwise register the request with `req_deadline_tracker`.
//   3. Pack a ready chain followed by `inputs` into the argument list and
//      enqueue `func` on the selected work queue.
//   4. Await the enqueued work, then the chain and all results, and append one
//      tensor per non-chain result to `*outputs`.
//
// Returns:
//   - any error from setting up the request context;
//   - DeadlineExceeded if the deadline already passed, or if the request
//     context reports cancellation after execution;
//   - Internal if the number of arguments (chain + inputs) does not match
//     `func.argument_types().size()`;
//   - otherwise the summary of all errors found on the chain and the results.
//
// `captures` is only checked (in debug builds) to be empty; it is not passed
// to the function. `*outputs` is appended to, not cleared.
tensorflow::Status GraphExecutionRunOnFunction(
const GraphExecutionOptions& options,
const GraphExecutionRunOptions& run_options,
absl::string_view signature_name, const tfrt::Function& func,
absl::Span<const tensorflow::Tensor> inputs,
absl::Span<const tensorflow::Tensor> captures,
std::vector<tensorflow::Tensor>* outputs,
tfrt::ResourceContext* resource_context, const Runtime& runtime,
const FallbackState& fallback_state,
tfrt::RequestDeadlineTracker& req_deadline_tracker) {
auto* host = runtime.core_runtime()->GetHostContext();
// A work queue supplied through `run_options` takes precedence over the
// runtime's default queue.
TF_ASSIGN_OR_RETURN(
auto request_info,
SetUpRequestContext(run_options, options.model_metadata, host,
run_options.work_queue ? run_options.work_queue
: runtime.work_queue(),
resource_context, fallback_state));
// Trace event for this run, tagged with the request id, the signature name
// and "<model name>:<model version>".
tensorflow::profiler::TraceMeProducer traceme(
// To TraceMeConsumers in RunHandlerThreadPool::WorkerLoop.
[request_id = request_info->tfrt_request_context->id(), signature_name,
&options] {
return tensorflow::profiler::TraceMeEncode(
"TfrtModelRun",
{{"_r", 1},
{"id", request_id},
{"signature", signature_name},
{"model_id", absl::StrCat(options.model_metadata.name(), ":",
options.model_metadata.version())}});
},
tensorflow::profiler::ContextType::kTfrtExecutor,
request_info->tfrt_request_context->id());
// Only configure timer when the deadline is set.
if (run_options.deadline.has_value()) {
auto deadline = run_options.deadline.value();
// Fail fast without executing anything if the deadline is already in the
// past.
if (absl::ToChronoTime(absl::Now()) > deadline) {
return tensorflow::errors::DeadlineExceeded(kDeadlineExceededMessage);
}
req_deadline_tracker.CancelRequestOnDeadline(
deadline, request_info->tfrt_request_context);
}
tfrt::ExecutionContext exec_ctx{request_info->tfrt_request_context};
// Work queue priority: the queue from `run_options`, then the per-request
// queue created during request setup, then the runtime's default queue.
if (run_options.work_queue) {
// TODO(b/198671794): Avoid creating `request_queue` when the `work_queue`
// in `run_options` is specified.
exec_ctx.set_work_queue(run_options.work_queue);
} else if (request_info->request_queue) {
exec_ctx.set_work_queue(request_info->request_queue.get());
} else {
exec_ctx.set_work_queue(runtime.work_queue());
}
// `arguments` holds raw pointers whose references were release()d below; the
// cleanup drops one reference on each of them.
llvm::SmallVector<tfrt::AsyncValue*, 4> arguments;
auto cleanup = tensorflow::gtl::MakeCleanup([&]() {
for (auto* argument : arguments) argument->DropRef();
});
// The first argument is a chain for side-effects. Since SavedModel::Run()
// only returns when side-effects are visible, we can use a ready chain here.
arguments.push_back(tfrt::GetReadyChain().release());
// Each input tensor is wrapped in an already-available FallbackTensor value.
for (const auto& input : inputs) {
arguments.push_back(
tfrt::MakeAvailableAsyncValueRef<FallbackTensor>(input).release());
}
DCHECK(captures.empty()) << "signature should have no captures, which is "
"guaranteed by the compiler";
// The argument count includes the leading chain.
if (arguments.size() != func.argument_types().size())
return tensorflow::errors::Internal("incorrect number of inputs.");
// One slot per function result; slot 0 is the side-effect chain.
llvm::SmallVector<tfrt::RCReference<tfrt::AsyncValue>, 4> chain_and_results;
chain_and_results.resize(func.result_types().size());
// Hand over the execution to thread pool.
// The lambda captures locals by reference; the Await on `executed` below
// waits for it before those locals go out of scope.
std::array<tfrt::RCReference<tfrt::AsyncValue>, 1> executed = {
EnqueueWork(exec_ctx, [&]() -> tfrt::Chain {
func.Execute(exec_ctx, arguments, chain_and_results);
return {};
})};
// Wait for the function execution before checking chain and results.
exec_ctx.work_queue().Await(executed);
// Wait for all results including the side-effect chain. This ensures that all
// side-effects are visible when SavedModel::Run() returns.
exec_ctx.work_queue().Await(chain_and_results);
DCHECK(!chain_and_results.empty());
tfrt::RCReference<tfrt::AsyncValue>& chain = chain_and_results[0];
auto results = llvm::drop_begin(chain_and_results, 1);
// Collect errors from the chain and from every result instead of stopping at
// the first one.
tensorflow::StatusGroup status_group;
if (chain->IsError()) {
status_group.Update(CreateTfErrorStatus(chain->GetError()));
}
for (tfrt::RCReference<tfrt::AsyncValue>& result : results) {
DCHECK(result->IsAvailable());
if (result->IsError()) {
status_group.Update(CreateTfErrorStatus(result->GetError()));
// Keep `outputs` aligned with the results by appending a placeholder tensor
// for the failed result.
outputs->push_back(tensorflow::Tensor());
continue;
}
// The result must be a host tensor. This is guaranteed as the compiler
// will insert necessary device transfer operations in the graph.
DCHECK(result->IsType<FallbackTensor>());
const auto& host_tensor = result->get<FallbackTensor>().tensor();
// Make a copy of tensor here as the different result AsyncValues might
// point to the same underlying tensor.
outputs->push_back(host_tensor);
}
// TODO(b/171926578): Explicitly clear the context data. Remove it after the
// b/171926578 is fixed.
exec_ctx.request_ctx()->ClearData();
// Check if error is due to cancellation.
// TODO(tfrt-devs): report cancellation reason from runtime.
// A cancelled request is reported as DeadlineExceeded, taking precedence over
// the statuses collected above.
if (request_info->tfrt_request_context->IsCancelled()) {
// Currently a request can only be cancelled by an expired timer.
return tensorflow::errors::DeadlineExceeded(kDeadlineExceededMessage);
}
return status_group.as_summary_status();
}
// Creates a fresh `tfrt::ResourceContext`, lets `runtime` populate it with its
// runtime resources, and additionally registers TPU resources backed by
// `tpu_model_resource` when `tpu_target` is `kTpurt`.
std::unique_ptr<tfrt::ResourceContext> CreateResourceContext(
    const tensorflow::tfrt_stub::Runtime& runtime,
    tfrt::tpu::TpuModelResource* tpu_model_resource,
    tensorflow::TfrtTpuInfraTarget tpu_target) {
  auto context = std::make_unique<tfrt::ResourceContext>();
  tfrt::ResourceContext* const raw_context = context.get();

  runtime.CreateRuntimeResources(raw_context);

  // TODO(b/178227859): We should make TPU resource init code pluggable, as
  // opposed to linking it in. We can do this by adding a callback with
  // `Runtime::AddCreateRuntimeResourceFn`.
  const bool targets_tpurt =
      (tpu_target == tensorflow::TfrtTpuInfraTarget::kTpurt);
  if (targets_tpurt) {
    AddTpuResources(raw_context, tpu_model_resource);
  }

  return context;
}
// Factory for `GraphExecutor`.
//
// Rejects `options` without a runtime (InvalidArgument), builds the
// `TfrtGraphExecutionState` for `graph_def` from the relevant fields of
// `options`, and returns a new executor owning that state. Errors from creating
// the execution state are propagated unchanged.
StatusOr<std::unique_ptr<GraphExecutor>> GraphExecutor::Create(
    Options options, const FallbackState& fallback_state,
    tfrt::tpu::TpuModelResource* tpu_model_resource,
    tensorflow::GraphDef graph_def) {
  if (!options.runtime) {
    return errors::InvalidArgument("options.runtime must be non-null ");
  }

  // Forward only the settings the execution state cares about.
  TfrtGraphExecutionState::Options state_options;
  state_options.enable_tfrt_gpu = options.enable_tfrt_gpu;
  state_options.run_placer_grappler_on_functions =
      options.run_placer_grappler_on_functions;

  TF_ASSIGN_OR_RETURN(
      auto execution_state,
      TfrtGraphExecutionState::Create(state_options, std::move(graph_def),
                                      fallback_state));

  return std::make_unique<GraphExecutor>(std::move(options), fallback_state,
                                         tpu_model_resource,
                                         std::move(execution_state));
}
namespace {
// Produces a sorted copy of `names` together with the permutation that was
// applied: after the call, `sorted_names[i] == names[original_indices[i]]` and
// `sorted_names` is in ascending order. Both output vectors are expected to be
// empty on entry.
void CreateSortedNamesAndOriginalIndices(absl::Span<const std::string> names,
                                         std::vector<std::string>& sorted_names,
                                         std::vector<int>& original_indices) {
  DCHECK(sorted_names.empty());
  DCHECK(original_indices.empty());

  const auto count = names.size();

  // Start from the identity permutation 0, 1, ..., count - 1.
  original_indices.resize(count);
  std::iota(original_indices.begin(), original_indices.end(), 0);

  // Reorder the indices so that the names they refer to are ascending.
  const auto name_less = [&names](int lhs, int rhs) {
    return names[lhs] < names[rhs];
  };
  std::sort(original_indices.begin(), original_indices.end(), name_less);

  // Materialize the names in permuted order.
  sorted_names.reserve(count);
  for (auto it = original_indices.begin(); it != original_indices.end();
       ++it) {
    DCHECK_LT(*it, names.size());
    sorted_names.push_back(names[*it]);
  }
}
} // namespace
// Runs the graph for one request.
//
// Args:
//   run_options: per-request options forwarded to the function execution.
//   inputs: (tensor name, tensor) pairs to feed.
//   output_tensor_names: names of the tensors to fetch.
//   target_tensor_names: names of nodes to run without fetching a value.
//   outputs: receives the fetched tensors; on success `(*outputs)[i]`
//     corresponds to `output_tensor_names[i]`.
//
// The names are sorted before the client graph is looked up (or compiled), so
// that the same set of inputs/outputs/targets maps to the same cached graph
// regardless of the order in which the caller lists them. Inputs are passed to
// the compiled function in sorted-name order and the results are scattered
// back into the caller's original output order.
//
// Returns any error from loading or executing the client graph, or Internal if
// the compiled graph has no entry function or yields fewer results than
// requested outputs.
tensorflow::Status GraphExecutor::Run(
    const RunOptions& run_options,
    absl::Span<const std::pair<std::string, tensorflow::Tensor>> inputs,
    absl::Span<const std::string> output_tensor_names,
    absl::Span<const std::string> target_tensor_names,
    std::vector<tensorflow::Tensor>* outputs) {
  // TODO(b/192498110): Validate input type.

  // Sort the input/output names to have a stable order, so that the
  // `joined_name`, which is used as the cache key, will be the same as long as
  // the same set of inputs/outputs are specified.
  std::vector<std::string> input_names;
  input_names.reserve(inputs.size());
  for (const auto& p : inputs) input_names.push_back(p.first);
  std::vector<std::string> sorted_input_names;
  std::vector<int> input_original_indices;
  CreateSortedNamesAndOriginalIndices(input_names, sorted_input_names,
                                      input_original_indices);

  // We also need to create sorted input dtypes as they are needed for the
  // compilation.
  std::vector<tensorflow::DataType> sorted_input_dtypes;
  sorted_input_dtypes.reserve(inputs.size());
  for (int original_index : input_original_indices) {
    sorted_input_dtypes.push_back(inputs.at(original_index).second.dtype());
  }

  std::vector<std::string> sorted_output_names;
  std::vector<int> output_original_indices;
  CreateSortedNamesAndOriginalIndices(output_tensor_names, sorted_output_names,
                                      output_original_indices);

  // For target node names, we only need to sort them. The original indices are
  // not needed.
  std::vector<std::string> sorted_target_node_names(target_tensor_names.begin(),
                                                    target_tensor_names.end());
  std::sort(sorted_target_node_names.begin(), sorted_target_node_names.end());

  // Load the client graph.
  TF_ASSIGN_OR_RETURN(const LoadedClientGraph& loaded_client_graph,
                      GetOrCreateLoadedClientGraph(
                          sorted_input_names, sorted_input_dtypes,
                          sorted_output_names, sorted_target_node_names));

  const auto* func = loaded_client_graph.bef_file->GetFunction(
      tensorflow::kImportModelDefaultGraphFuncName);
  // Report a missing entry function as an error rather than relying on a
  // debug-only check: `func` is dereferenced below.
  if (func == nullptr) {
    return tensorflow::errors::Internal(
        "Entry function not found in the compiled client graph: ",
        loaded_client_graph.name);
  }

  // Create the actual arguments to the compiled function, which are sorted
  // according to the input tensor names.
  std::vector<tensorflow::Tensor> flat_inputs;
  flat_inputs.reserve(inputs.size());
  for (int original_index : input_original_indices) {
    flat_inputs.push_back(inputs.at(original_index).second);
  }

  std::vector<tensorflow::Tensor> flat_outputs;
  TF_RETURN_IF_ERROR(GraphExecutionRunOnFunction(
      options_, run_options, loaded_client_graph.name, *func, flat_inputs,
      /*captures=*/{}, &flat_outputs,
      loaded_client_graph.resource_context.get(), runtime(), fallback_state_,
      req_deadline_tracker_));

  // The scatter loop below reads one result per requested output and writes it
  // at the output's original position; fewer results than requested outputs
  // would make both accesses run out of bounds.
  if (flat_outputs.size() < output_original_indices.size()) {
    return tensorflow::errors::Internal(
        "The compiled function produced ", flat_outputs.size(),
        " outputs but ", output_original_indices.size(), " were requested.");
  }

  // Create the outputs from the actual function results, which are sorted
  // according to the output tensor names.
  auto flat_output_iter = flat_outputs.begin();
  outputs->resize(flat_outputs.size());
  for (int original_index : output_original_indices) {
    (*outputs)[original_index] = std::move(*flat_output_iter);
    ++flat_output_iter;
  }
  return tensorflow::Status::OK();
}
// Extends the underlying graph execution state with the nodes in `graph` and
// returns the resulting status.
tensorflow::Status GraphExecutor::Extend(const GraphDef& graph) {
  tensorflow::Status extend_status = graph_execution_state_->Extend(graph);
  return extend_status;
}
// Compiles `client_graph` into a ready-to-run `LoadedClientGraph`.
//
// The loaded graph gets its own resource context. The client graph is imported
// into an MLIR module, compiled to a BEF buffer, opened as a BEF file, and then
// initialized via InitBef(). The first failing step's error is returned.
StatusOr<std::unique_ptr<GraphExecutor::LoadedClientGraph>>
GraphExecutor::LoadClientGraph(const GraphExecutor::ClientGraph& client_graph) {
  auto loaded = std::make_unique<LoadedClientGraph>();
  loaded->name = client_graph.name;
  loaded->resource_context = CreateResourceContext(
      runtime(), tpu_model_resource_, options_.compile_options.tpu_target);

  // Step 1: Import the client graph from proto to an MLIR module.
  mlir::MLIRContext mlir_context;
  TF_ASSIGN_OR_RETURN(
      auto mlir_module,
      ImportClientGraphToMlirModule(client_graph, &mlir_context));

  // Step 2: Compile the MLIR module from TF dialect to TFRT dialect (in BEF).
  TF_ASSIGN_OR_RETURN(loaded->bef, CompileMlirModuleToBef(mlir_module.get()));

  // Step 3: Initialize runtime states using special BEF functions.
  TF_ASSIGN_OR_RETURN(
      loaded->bef_file,
      tfrt::CreateBefFileFromBefBuffer(runtime(), loaded->bef));
  TF_RETURN_IF_ERROR(
      InitBef(loaded->bef_file.get(), loaded->resource_context.get()));

  return loaded;
}
// Converts `client_graph` into an MLIR module owned by `context`.
//
// The import configuration takes its inputs, outputs and control outputs from
// `client_graph`, prunes unused nodes and leaves shape inference disabled. The
// graph execution state produces an optimized graph for that configuration,
// which is then converted to MLIR.
tensorflow::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>>
GraphExecutor::ImportClientGraphToMlirModule(
    const GraphExecutor::ClientGraph& client_graph,
    mlir::MLIRContext* context) const {
  tensorflow::GraphImportConfig import_config;
  import_config.inputs = client_graph.input_nodes;
  import_config.outputs = client_graph.output_nodes;
  import_config.control_outputs = client_graph.target_nodes;
  import_config.prune_unused_nodes = true;
  import_config.enable_shape_inference = false;

  // Optimize the graph.
  TF_ASSIGN_OR_RETURN(
      auto optimized,
      graph_execution_state_->CreateOptimizedGraph(import_config));

  // Convert the optimized graph to an MLIR module.
  auto& graph = *optimized.graph;
  return tensorflow::ConvertGraphToMlir(graph, /*debug_info=*/{},
                                        graph.flib_def(), import_config,
                                        context);
}
// Lowers `module` to BEF using the executor's compile options and returns the
// resulting buffer, or the conversion error.
StatusOr<tfrt::BefBuffer> GraphExecutor::CompileMlirModuleToBef(
    mlir::ModuleOp module) const {
  tfrt::BefBuffer bef_buffer;
  tensorflow::Status conversion_status = tensorflow::ConvertTfMlirToBef(
      options_.compile_options, module, &bef_buffer);
  if (!conversion_status.ok()) {
    return conversion_status;
  }
  return bef_buffer;
}
// Runs the compiler-generated initializer functions of `bef_file` against
// `resource_context`, using a request context built with default run options
// and model metadata on the runtime's work queue.
tensorflow::Status GraphExecutor::InitBef(
    tfrt::BEFFile* bef_file, tfrt::ResourceContext* resource_context) {
  auto* host_context = runtime().core_runtime()->GetHostContext();
  TF_ASSIGN_OR_RETURN(
      auto init_request,
      SetUpRequestContext(/*run_options=*/{}, /*model_metadata=*/{},
                          host_context, runtime().work_queue(),
                          resource_context, fallback_state_));
  tfrt::ExecutionContext exec_ctx(init_request->tfrt_request_context);

  // The initializers must run in this order:
  //  * "_tfrt_fallback_init" is the special function created by the compiler
  //    that calls a sequence of tfrt_fallback_async.createop to create all
  //    fallback ops used in this BEF.
  //  * "_tfrt_resource_init" then sets the resources of the original graph in
  //    runtime states, so that later they can be efficiently retrieved without
  //    any locking.
  for (const char* initializer :
       {"_tfrt_fallback_init", "_tfrt_resource_init"}) {
    TF_RETURN_IF_ERROR(RunRuntimeInitializer(exec_ctx, bef_file, initializer));
  }
  return tensorflow::Status::OK();
}
// Returns the cached `LoadedClientGraph` for the given tensor names, compiling
// and caching it first when it is not present yet.
//
// The cache key is built from the input, output and target names only, e.g.
//   input1-input2^output1-output2^target1-target2
// `input_tensor_dtypes[i]` is the dtype recorded for `input_tensor_names[i]`
// when a new client graph has to be compiled. The cache mutex is held for the
// whole lookup, including compilation on a miss.
StatusOr<std::reference_wrapper<const GraphExecutor::LoadedClientGraph>>
GraphExecutor::GetOrCreateLoadedClientGraph(
    absl::Span<const std::string> input_tensor_names,
    absl::Span<const tensorflow::DataType> input_tensor_dtypes,
    absl::Span<const std::string> output_tensor_names,
    absl::Span<const std::string> target_tensor_names) {
  const std::string inputs_part =
      absl::StrJoin(input_tensor_names, kTensorNameJoiningDelimiter);
  const std::string outputs_part =
      absl::StrJoin(output_tensor_names, kTensorNameJoiningDelimiter);
  const std::string targets_part =
      absl::StrJoin(target_tensor_names, kTensorNameJoiningDelimiter);
  const std::string cache_key =
      absl::StrCat(inputs_part, kArgumentTypeJoiningDelimiter, outputs_part,
                   kArgumentTypeJoiningDelimiter, targets_part);

  tensorflow::mutex_lock lock(loaded_client_graphs_mu_);

  // Cache hit; return immediately.
  const auto cached = loaded_client_graphs_.find(cache_key);
  if (cached != loaded_client_graphs_.end()) {
    return std::cref(*cached->second);
  }

  // Cache miss; populate a `ClientGraph` and load it.
  DCHECK_EQ(input_tensor_names.size(), input_tensor_dtypes.size());
  tensorflow::GraphImportConfig::InputArrays input_arrays;
  for (size_t index = 0; index < input_tensor_names.size(); ++index) {
    tensorflow::ArrayInfo info;
    info.imported_dtype = input_tensor_dtypes[index];
    info.shape.set_unknown_rank(true);
    input_arrays[input_tensor_names[index]] = info;
  }

  ClientGraph client_graph{
      cache_key,
      std::move(input_arrays),
      {output_tensor_names.begin(), output_tensor_names.end()},
      {target_tensor_names.begin(), target_tensor_names.end()}};
  TF_ASSIGN_OR_RETURN(auto newly_loaded, LoadClientGraph(client_graph));

  // Store the new loaded client graph in cache and return. The reference is
  // taken before ownership moves into the cache.
  const LoadedClientGraph& loaded_ref = *newly_loaded;
  loaded_client_graphs_[cache_key] = std::move(newly_loaded);
  return std::cref(loaded_ref);
}
} // namespace tfrt_stub
} // namespace tensorflow