/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tfrt/saved_model/saved_model.h"
#include <algorithm>
#include <cstddef>
#include <functional>
#include <memory>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "absl/time/clock.h"
#include "absl/time/time.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/cc/saved_model/reader.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/upgrade_graph.h"
#include "tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_request_context.h"
#include "tensorflow/compiler/mlir/tfrt/translate/import_model.h"
#include "tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/function_def_utils.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/monitoring/gauge.h"
#include "tensorflow/core/platform/enable_tf2_utils.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/profiler/lib/connected_traceme.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/profiler/lib/traceme_encode.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"
#include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_execute_compat.h"
#include "tensorflow/core/runtime_fallback/util/tensor_util.h"
#include "tensorflow/core/tfrt/runtime/work_queue_interface.h"
#include "tensorflow/core/tfrt/saved_model/saved_model_import_input.h"
#include "tensorflow/core/tfrt/tpu/tpu_resources.h"
// TODO(b/200579737): using FunctionRegistry is simpler than the OSS trick.
#include "tensorflow/core/tfrt/graph_executor/graph_executor.h"
#include "tensorflow/core/tfrt/utils/error_util.h"
#include "tensorflow/core/tfrt/utils/fallback_tensor.h"
#include "tensorflow/core/tfrt/utils/tensor_util.h"
#include "tensorflow/core/tfrt/utils/utils.h"
#include "tfrt/bef_executor/bef_file.h" // from @tf_runtime
#include "tfrt/core_runtime/core_runtime.h" // from @tf_runtime
#include "tfrt/core_runtime/tensor_handle.h" // from @tf_runtime
#include "tfrt/host_context/async_dispatch.h" // from @tf_runtime
#include "tfrt/host_context/async_value.h" // from @tf_runtime
#include "tfrt/host_context/chain.h" // from @tf_runtime
#include "tfrt/host_context/concurrent_work_queue.h" // from @tf_runtime
#include "tfrt/host_context/execution_context.h" // from @tf_runtime
#include "tfrt/host_context/function.h" // from @tf_runtime
#include "tfrt/host_context/host_context.h" // from @tf_runtime
#include "tfrt/host_context/request_deadline_tracker.h" // from @tf_runtime
#include "tfrt/metrics/common_metrics.h" // from @tf_runtime
#include "tfrt/support/error_util.h" // from @tf_runtime
#include "tfrt/support/logging.h" // from @tf_runtime
#include "tfrt/support/ref_count.h" // from @tf_runtime
namespace tensorflow {
namespace tfrt_stub {
namespace {
constexpr absl::string_view kSignatureJoiningDelimiter = "+";
constexpr absl::string_view kTensorNameJoiningDelimiter = "-";
constexpr absl::string_view kArgumentTypeJoiningDelimiter = "^";
using SignatureMap = absl::flat_hash_map<std::string, internal::Signature>;
using ::tensorflow::SessionMetadata;
using ::tensorflow::StatusOr;
struct InitializersAndSignatures {
llvm::SmallVector<std::string, 4> initializers;
SignatureMap signature_map;
};
auto* saved_model_read_meta_graph_time_seconds =
tensorflow::monitoring::Gauge<int64_t, 1>::New(
"/tensorflow/tfrt/saved_model/read_meta_graph_time",
"Record the time to read the meta graph from disk.", "model_name");
auto* saved_model_functionalization_time_seconds =
tensorflow::monitoring::Gauge<int64_t, 1>::New(
"/tensorflow/tfrt/saved_model/functionalization_time",
"Record the functionalization time for the saved model.", "model_name");
auto* saved_model_grappler_time_seconds =
tensorflow::monitoring::Gauge<int64_t, 1>::New(
"/tensorflow/tfrt/saved_model/grappler_time",
"Record the Grappler time for the saved model.", "model_name");
auto* saved_model_import_time_seconds =
tensorflow::monitoring::Gauge<int64_t, 1>::New(
"/tensorflow/tfrt/saved_model/import_time",
"Record the MLIR import time for the saved model.", "model_name");
auto* saved_model_compile_time_seconds =
tensorflow::monitoring::Gauge<int64_t, 1>::New(
"/tensorflow/tfrt/saved_model/compile_time",
"Record the compilation time for the saved model.", "model_name");
auto* saved_model_init_time_seconds =
tensorflow::monitoring::Gauge<int64_t, 1>::New(
"/tensorflow/tfrt/saved_model/init_time",
"Record the initialization time for the saved model.", "model_name");
tensorflow::Tensor CreateScalarStringTensor(absl::string_view str) {
return tensorflow::Tensor(tensorflow::tstring(str));
}
// Creates the tensor for the bound input, which can be a variable or an
// asset. Only assets are currently handled here; any other bound input type
// results in an error.
//
// TODO(chky): For V2 models, the bound input can also be a resource.
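//
// For example (hypothetical paths), for saved_model_dir = "/tmp/model" and an
// asset whose filename is "assets/vocab.txt", the created scalar string
// tensor holds "/tmp/model/assets/vocab.txt".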
StatusOr<tensorflow::Tensor> CreateTensorFromBoundInput(
mlir::Operation* bound_input, absl::string_view saved_model_dir,
absl::flat_hash_map<std::string, tensorflow::Tensor>* variables) {
// Assets are files in the saved model directory. We pass their filenames to
// the functions so that the files can be accessed at runtime.
if (auto asset = llvm::dyn_cast<mlir::tf_saved_model::AssetOp>(bound_input)) {
// The filename in the asset is a relative path, so we prefix it with the
// saved model directory path.
return CreateScalarStringTensor(
tensorflow::io::JoinPath(saved_model_dir, asset.filename().str()));
}
return tensorflow::errors::Internal(
"Failed to create captured tensors: unknown bound input type.");
}
StatusOr<SignatureMap> GetFunctionSignaturesFromTFSavedModelMLIR(
absl::string_view saved_model_dir, mlir::ModuleOp module) {
absl::flat_hash_map<std::string, tensorflow::Tensor> variables;
SignatureMap signatures;
tensorflow::StatusGroup status_group;
TF_RETURN_IF_ERROR(tensorflow::MapFunctionSignaturesFromTFSavedModelMLIR(
module, [&status_group, &variables, &signatures, saved_model_dir](
const tensorflow::TFRTSavedModelSignatureInfo& sig_info) {
auto& signature = signatures[std::string(sig_info.func_name)];
auto copy = [](llvm::ArrayRef<llvm::StringRef> src,
std::vector<std::string>* dst) {
llvm::transform(src, std::back_inserter(*dst),
[](llvm::StringRef x) { return x.str(); });
};
copy(sig_info.input_names, &signature.input_names);
copy(sig_info.output_names, &signature.output_names);
copy(sig_info.input_devices, &signature.input_devices);
DCHECK(signature.input_specs.empty());
signature.input_specs.reserve(sig_info.input_specs.size());
for (auto& spec : sig_info.input_specs) {
signature.input_specs.push_back(TensorSpec(spec.first, spec.second));
}
DCHECK(signature.output_specs.empty());
signature.output_specs.reserve(sig_info.output_specs.size());
for (auto& spec : sig_info.output_specs) {
signature.output_specs.push_back(TensorSpec(spec.first, spec.second));
}
for (auto* bound_input : sig_info.bound_inputs) {
auto statusor_capture = CreateTensorFromBoundInput(
bound_input, saved_model_dir, &variables);
if (!statusor_capture.ok()) {
status_group.Update(statusor_capture.status());
// Insert a placeholder (empty) tensor on error so that the capture
// indices stay aligned.
signature.captures.push_back(tensorflow::Tensor());
} else {
signature.captures.push_back(
std::move(statusor_capture).ValueOrDie());
}
}
}));
if (!status_group.ok()) return status_group.as_concatenated_status();
return signatures;
}
tensorflow::Status RunInitializers(
const InitializersAndSignatures& initializers_and_signatures,
const SessionMetadata& model_metadata, tfrt::BEFFile* bef_file,
const Runtime& runtime, tfrt::ResourceContext* resource_context,
const FallbackState& fallback_state) {
auto* host = runtime.core_runtime()->GetHostContext();
TF_ASSIGN_OR_RETURN(auto request_info,
SetUpRequestContext(/*run_options=*/{}, model_metadata,
host, runtime.work_queue(),
resource_context, fallback_state));
tfrt::ExecutionContext exec_ctx(request_info->tfrt_request_context);
// Run "_tfrt_fallback_init" first to initialize fallback-specific states. It
// is the special function created by compiler, which calls a sequence of
// tfrt_fallback_async.createop to create all fallback ops used in this BEF.
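// As an illustration, the generated function conceptually looks like the
// following textual MLIR (a hypothetical sketch; the exact op keys,
// attributes, and syntax are compiler-internal and may differ):
//
//   func @_tfrt_fallback_init(%in: !tfrt.chain) -> !tfrt.chain {
//     %ch0 = tfrt_fallback_async.createop(%in) key(0) device("/CPU:0")
//         "tf.MatMul"() {T = f32} num_args(2)
//     ...
//     tfrt.return %chN : !tfrt.chain
//   }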
TF_RETURN_IF_ERROR(
RunRuntimeInitializer(exec_ctx, bef_file, "_tfrt_fallback_init"));
for (const auto& init : initializers_and_signatures.initializers) {
// TODO(b/184771263): Consider using `GraphExecutionRunOnFunction()`
// instead.
auto* func = bef_file->GetFunction(init);
assert(func);
const auto& signature = initializers_and_signatures.signature_map.at(init);
auto ready_chain = tfrt::GetReadyChain();
// The actual arguments are the concatenation of the side-effect chain and
// the asset captures.
llvm::SmallVector<tfrt::AsyncValue*, 1> arguments;
auto cleanup = tensorflow::gtl::MakeCleanup([&]() {
for (auto* argument : arguments) argument->DropRef();
});
arguments.push_back(ready_chain.release());
for (const auto& capture : signature.captures) {
arguments.push_back(
tfrt::MakeAvailableAsyncValueRef<FallbackTensor>(capture).release());
}
assert(arguments.size() == func->argument_types().size());
llvm::SmallVector<tfrt::RCReference<tfrt::AsyncValue>, 1> results;
results.resize(func->result_types().size());
assert(results.size() == 1);
func->Execute(exec_ctx, arguments, results);
// Wait for the function execution to finish, as well as the side-effects.
host->Await(results);
if (auto* error = results[0]->GetErrorIfPresent()) {
return CreateTfErrorStatus(*error);
}
}
// After all resources in the original graph have been initialized, run the
// "_tfrt_resource_init" function to store them in the runtime state, so that
// they can later be retrieved efficiently without locking.
TF_RETURN_IF_ERROR(
RunRuntimeInitializer(exec_ctx, bef_file, "_tfrt_resource_init"));
return tensorflow::Status::OK();
}
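// Returns the names of the signatures in `meta_graph_def` whose inputs and
// outputs are all dense tensors of non-ref dtypes; other signatures are
// skipped with a warning. For example (a hypothetical textproto fragment), an
// input that uses the coo_sparse encoding has an empty `name` and is
// therefore rejected:
//
//   inputs {
//     key: "x"
//     value {
//       dtype: DT_FLOAT
//       coo_sparse { values_tensor_name: "x/values:0" }
//     }
//   }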
std::vector<std::string> FindNamesForValidSignatures(
const tensorflow::MetaGraphDef& meta_graph_def) {
std::vector<std::string> valid_signature_names;
auto is_dense_tensor_info = [](const auto& named_tensor_info) {
return !named_tensor_info.second.name().empty();
};
auto is_ref_type_tensor_info = [](const auto& named_tensor_info) {
return tensorflow::IsRefType(named_tensor_info.second.dtype());
};
for (const auto& iter : meta_graph_def.signature_def()) {
const auto& sig_key = iter.first;
const auto& signature = iter.second;
if (!std::all_of(signature.inputs().begin(), signature.inputs().end(),
is_dense_tensor_info) ||
!std::all_of(signature.outputs().begin(), signature.outputs().end(),
is_dense_tensor_info)) {
LOG(WARNING) << "Unsupported signature with non-dense tensors as "
"input/output. Name: "
<< sig_key << "; Signature: " << signature.DebugString();
continue;
}
if (std::any_of(signature.inputs().begin(), signature.inputs().end(),
is_ref_type_tensor_info) ||
std::any_of(signature.outputs().begin(), signature.outputs().end(),
is_ref_type_tensor_info)) {
LOG(WARNING) << "Unsupported signature with ref type tensors as "
"input/output. Name: "
<< sig_key << "; Signature: " << signature.DebugString();
continue;
}
valid_signature_names.push_back(sig_key);
}
return valid_signature_names;
}
StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> ImportSavedModel(
mlir::MLIRContext* context, const tensorflow::MetaGraphDef& meta_graph_def,
const FallbackState& fallback_state, std::string saved_model_dir,
bool import_user_signatures, bool run_placer_grappler_on_functions,
bool enable_tfrt_gpu) {
std::vector<std::string> signature_names;
if (import_user_signatures) {
signature_names = FindNamesForValidSignatures(meta_graph_def);
if (signature_names.empty())
LOG(WARNING) << "No valid signature found for model: " << saved_model_dir;
}
// TfrtSavedModelMLIRImportInput implements the graph processing logic (e.g.
// Placer and Grappler) used in DirectSession, applying graph transformations
// to each subgraph (i.e. signature). It reuses the DirectSession code path to
// avoid problems caused by divergent behavior across code paths, and it is
// injected into the MLIR importer so that the importer imports the
// transformed graph instead of the original one.
TF_ASSIGN_OR_RETURN(auto import_input,
TfrtSavedModelMLIRImportInput::Create(
fallback_state, &meta_graph_def, /*debug_info=*/{},
run_placer_grappler_on_functions, enable_tfrt_gpu));
TF_ASSIGN_OR_RETURN(
auto module,
tensorflow::ConvertSavedModelV1ToMlirLite(
import_input,
/*exported_names=*/absl::MakeSpan(signature_names), context));
LOG(INFO) << "TFRT ImportSavedModel: Functionalization took "
<< absl::ToInt64Milliseconds(
import_input.GetFunctionalizationDuration())
<< " ms.";
LOG(INFO) << "TFRT ImportSavedModel: Grappler took "
<< absl::ToInt64Milliseconds(import_input.GetGrapplerDuration())
<< " ms.";
saved_model_functionalization_time_seconds->GetCell(saved_model_dir)
->Set(absl::ToInt64Seconds(import_input.GetFunctionalizationDuration()));
saved_model_grappler_time_seconds->GetCell(saved_model_dir)
->Set(absl::ToInt64Seconds(import_input.GetGrapplerDuration()));
return module;
}
StatusOr<InitializersAndSignatures> GetInitializersAndSignatures(
mlir::ModuleOp module, absl::string_view saved_model_dir) {
InitializersAndSignatures result;
TF_ASSIGN_OR_RETURN(
result.signature_map,
GetFunctionSignaturesFromTFSavedModelMLIR(saved_model_dir, module));
for (auto session_initializer_name :
mlir::tf_saved_model::GetSessionInitializerExportedName(module)) {
result.initializers.push_back(session_initializer_name.str());
}
return result;
}
tensorflow::Status InitSavedModel(
const InitializersAndSignatures& initializers_and_signatures,
tfrt::BEFFile* bef_file, const SavedModel::Options& options,
tfrt::ResourceContext* resource_context,
const FallbackState& fallback_state) {
TF_RETURN_IF_ERROR(
RunInitializers(initializers_and_signatures,
options.graph_execution_options.model_metadata, bef_file,
*options.graph_execution_options.runtime,
resource_context, fallback_state));
return tensorflow::Status::OK();
}
} // namespace
SavedModel::~SavedModel() {}
tfrt::HostContext* SavedModel::GetHostContext() const {
return runtime_->core_runtime()->GetHostContext();
}
std::unique_ptr<SavedModel> SavedModelImpl::LoadSavedModel(
Options options, absl::string_view saved_model_dir,
const std::unordered_set<std::string>& tags, tensorflow::Status* status) {
LOG(INFO) << "TFRT reading v1 savedmodel: " << saved_model_dir;
auto read_start_time = absl::Now();
tensorflow::MetaGraphDef meta_graph_def;
auto read_status = tensorflow::ReadMetaGraphDefFromSavedModel(
std::string(saved_model_dir), tags, &meta_graph_def);
if (!read_status.ok()) {
*status = read_status;
return nullptr;
}
auto read_meta_graph_duration = absl::Now() - read_start_time;
saved_model_read_meta_graph_time_seconds
->GetCell(std::string(saved_model_dir))
->Set(absl::ToInt64Seconds(read_meta_graph_duration));
LOG(INFO) << "TFRT finished reading meta graph. Took "
<< absl::ToInt64Milliseconds(read_meta_graph_duration) << " ms.";
return LoadSavedModel(std::move(options), saved_model_dir,
std::move(meta_graph_def), status);
}
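// Example usage (a minimal sketch with hypothetical paths and inputs; the
// `Options` construction below is an assumption that may differ across
// versions):
//
//   SavedModel::Options options(runtime);  // `runtime` is a Runtime*.
//   tensorflow::Status status;
//   std::unique_ptr<SavedModel> saved_model = SavedModelImpl::LoadSavedModel(
//       std::move(options), "/path/to/saved_model", /*tags=*/{"serve"},
//       &status);
//   if (saved_model == nullptr) { /* inspect `status` */ }
//   std::vector<tensorflow::Tensor> inputs = ...;  // user-provided inputs
//   std::vector<tensorflow::Tensor> outputs;
//   status = saved_model->Run(/*run_options=*/{}, "serving_default", inputs,
//                             &outputs);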
namespace {
// Gets the signatures from `signature_defs` and inserts them into `signatures`.
void GetSignaturesFromSignatureDef(
SignatureMap& signatures,
const google::protobuf::Map<std::string, tensorflow::SignatureDef>& signature_defs,
const SavedModel::Options& options) {
for (const auto& p : signature_defs) {
const std::string& signature_name = p.first;
const tensorflow::SignatureDef& signature_def = p.second;
DCHECK(signatures.find(signature_name) == signatures.end());
auto& signature = signatures[signature_name];
signature.input_names.reserve(signature_def.inputs().size());
signature.input_specs.reserve(signature_def.inputs().size());
for (const auto& p : signature_def.inputs()) {
const std::string& input_tensor_name = p.first;
const tensorflow::TensorInfo& tensor_info = p.second;
signature.input_names.push_back(input_tensor_name);
signature.input_specs.push_back(
TensorSpec(tensor_info.dtype(), tensor_info.tensor_shape()));
}
signature.input_devices = std::vector<std::string>(
signature_def.inputs().size(),
options.graph_execution_options.compile_options.default_device);
signature.output_names.reserve(signature_def.outputs().size());
signature.output_specs.reserve(signature_def.outputs().size());
for (const auto& p : signature_def.outputs()) {
const std::string& output_tensor_name = p.first;
const tensorflow::TensorInfo& tensor_info = p.second;
signature.output_names.push_back(output_tensor_name);
signature.output_specs.push_back(
TensorSpec(tensor_info.dtype(), tensor_info.tensor_shape()));
}
}
}
} // namespace
std::unique_ptr<SavedModel> SavedModelImpl::LoadSavedModel(
Options options, absl::string_view saved_model_dir,
tensorflow::MetaGraphDef meta_graph_def, tensorflow::Status* status) {
LOG(INFO) << "TFRT loading v1 savedmodel: " << saved_model_dir;
tfrt::metrics::AddTFRTVersionMetric();
UpdateTpuTargetByBridgeCompatibility(options.graph_execution_options,
meta_graph_def.graph_def());
auto statusor_saved_model =
[&]() -> tensorflow::StatusOr<std::unique_ptr<SavedModel>> {
mlir::MLIRContext context;
// Step 1: Import saved model from a proto to an MLIR module.
auto import_start_time = absl::Now();
auto session_options =
CreateDefaultSessionOptions(options.graph_execution_options);
// Set optimize_for_static_graph to true since we won't extend the graph
// later. If optimize_for_static_graph is set to false, FallbackState will
// keep an extra unused copy of the graph, which unnecessarily consumes
// memory.
session_options.config.mutable_experimental()
->set_optimize_for_static_graph(true);
// Create the fallback_state using the original function def library, without
// applying Placer or Grappler. This is OK for now because it is only used
// for captured functions in certain tf.data ops.
const auto& fdef_lib = meta_graph_def.graph_def().library();
TF_ASSIGN_OR_RETURN(auto fallback_state,
FallbackState::Create(session_options, fdef_lib));
TF_ASSIGN_OR_RETURN(
auto mlir_module,
ImportSavedModel(
&context, meta_graph_def, *fallback_state,
std::string(saved_model_dir),
/*import_user_signatures=*/!options.enable_lazy_loading,
options.graph_execution_options.run_placer_grappler_on_functions,
options.graph_execution_options.enable_tfrt_gpu));
auto import_duration = absl::Now() - import_start_time;
saved_model_import_time_seconds->GetCell(std::string(saved_model_dir))
->Set(absl::ToInt64Seconds(import_duration));
LOG(INFO) << "TFRT finished importing savedmodel. Took "
<< absl::ToInt64Milliseconds(import_duration) << " ms.";
// Step 2: Compile the MLIR module from TF dialect to TFRT dialect (in BEF).
auto compile_start_time = absl::Now();
TF_ASSIGN_OR_RETURN(
auto initializers_and_signatures,
GetInitializersAndSignatures(mlir_module.get(), saved_model_dir));
// If lazy loading is enabled, the user signatures are not exported via the
// MLIR module, so we need to get them from the proto.
// TODO(b/187228559): Unify the code paths for populating the signature map.
if (options.enable_lazy_loading) {
GetSignaturesFromSignatureDef(initializers_and_signatures.signature_map,
meta_graph_def.signature_def(), options);
}
tfrt::BefBuffer bef;
TF_RETURN_IF_ERROR(tensorflow::ConvertTfMlirToBef(
options.graph_execution_options.compile_options, mlir_module.get(),
&bef));
auto compile_duration = absl::Now() - compile_start_time;
saved_model_compile_time_seconds->GetCell(std::string(saved_model_dir))
->Set(absl::ToInt64Seconds(compile_duration));
LOG(INFO) << "TFRT finished compiling savedmodel. Took "
<< absl::ToInt64Milliseconds(compile_duration) << " ms.";
// Step 3: Initialize runtime states using special BEF functions.
auto init_start_time = absl::Now();
TF_ASSIGN_OR_RETURN(auto bef_file,
tfrt::CreateBefFileFromBefBuffer(
*options.graph_execution_options.runtime, bef));
auto tpu_model_resource = std::make_unique<tfrt::tpu::TpuModelResource>();
auto resource_context = CreateResourceContext(
*options.graph_execution_options.runtime, tpu_model_resource.get(),
options.graph_execution_options.compile_options.tpu_target);
TF_RETURN_IF_ERROR(InitSavedModel(initializers_and_signatures,
bef_file.get(), options,
resource_context.get(), *fallback_state));
auto init_duration = absl::Now() - init_start_time;
saved_model_init_time_seconds->GetCell(std::string(saved_model_dir))
->Set(absl::ToInt64Seconds(init_duration));
LOG(INFO) << "TFRT finished initializing savedmodel. Took "
<< absl::ToInt64Milliseconds(init_duration) << " ms.";
TF_ASSIGN_OR_RETURN(
auto graph_executor,
GraphExecutor::Create(options.graph_execution_options, *fallback_state,
tpu_model_resource.get(),
std::move(*meta_graph_def.mutable_graph_def())));
// Finally, create the saved model.
return {std::make_unique<SavedModelImpl>(
std::move(options), std::move(meta_graph_def), std::move(bef),
std::move(bef_file),
std::move(initializers_and_signatures.signature_map),
std::move(fallback_state), std::move(tpu_model_resource),
std::move(resource_context), std::move(graph_executor))};
}();
if (!statusor_saved_model.ok()) {
*status = statusor_saved_model.status();
return nullptr;
}
*status = tensorflow::Status::OK();
return std::move(statusor_saved_model).ValueOrDie();
}
SavedModelImpl::SavedModelImpl(
Options options, tensorflow::MetaGraphDef meta_graph_def,
tfrt::BefBuffer bef, tfrt::RCReference<tfrt::BEFFile> bef_file,
SignatureMap signatures, std::unique_ptr<FallbackState> fallback_state,
std::unique_ptr<tfrt::tpu::TpuModelResource> tpu_model_resource,
std::unique_ptr<tfrt::ResourceContext> resource_context,
std::unique_ptr<GraphExecutor> graph_executor)
: SavedModel(options.graph_execution_options.runtime),
options_(std::move(options)),
meta_graph_def_(std::move(meta_graph_def)),
bef_(std::move(bef)),
bef_file_(std::move(bef_file)),
req_deadline_tracker_(
options_.graph_execution_options.runtime->core_runtime()
->GetHostContext()),
signatures_(std::move(signatures)),
fallback_state_(std::move(fallback_state)),
tpu_model_resource_(std::move(tpu_model_resource)),
resource_context_(std::move(resource_context)),
graph_executor_(std::move(graph_executor)) {}
SavedModelImpl::~SavedModelImpl() = default;
std::vector<std::string> SavedModelImpl::GetFunctionNames() const {
std::vector<std::string> result;
for (const auto& entry : signatures_) {
result.push_back(entry.first);
}
return result;
}
const tensorflow::MetaGraphDef& SavedModelImpl::GetMetaGraphDef() const {
return meta_graph_def_;
}
absl::optional<FunctionMetadata> SavedModelImpl::GetFunctionMetadata(
absl::string_view func_name) const {
auto iter = signatures_.find(func_name);
if (iter == signatures_.end()) return absl::nullopt;
return FunctionMetadata(&iter->second);
}
namespace {
tensorflow::Status IsInputSpecsCorrect(
absl::string_view name, const internal::Signature& signature,
absl::Span<const tensorflow::Tensor> inputs) {
TF_RET_CHECK(signature.input_specs.size() == inputs.size())
<< "signature " << name
<< " input size is wrong, expected: " << signature.input_specs.size()
<< ", actual: " << inputs.size();
for (size_t i = 0; i < inputs.size(); ++i) {
const auto& expected_input_spec = signature.input_specs[i];
TF_RET_CHECK(expected_input_spec.dtype == inputs[i].dtype())
<< "signature " << name
<< " input dtype is wrong, expected: " << expected_input_spec.dtype
<< ", actual: " << inputs[i].dtype();
TF_RET_CHECK(expected_input_spec.shape.IsCompatibleWith(inputs[i].shape()))
<< "signature " << name
<< " input shape is wrong, expected : " << expected_input_spec.shape
<< ", actual: " << inputs[i].shape();
}
return tensorflow::Status::OK();
}
} // namespace
tensorflow::Status SavedModelImpl::Run(
const RunOptions& run_options, absl::string_view name,
absl::Span<const tensorflow::Tensor> inputs,
std::vector<tensorflow::Tensor>* outputs) {
TF_RET_CHECK(outputs) << "outputs must be provided";
outputs->clear();
auto sig_iter = signatures_.find(name);
TF_RET_CHECK(sig_iter != signatures_.end())
<< "failed to find signature " << name << " in the graph";
if (run_options.validate_input_specs) {
TF_RETURN_IF_ERROR(IsInputSpecsCorrect(name, sig_iter->second, inputs));
}
std::vector<tensorflow::Tensor> captures;
for (const auto& capture : sig_iter->second.captures) {
captures.push_back(capture);
}
const tfrt::Function* func;
tfrt::ResourceContext* resource_context;
if (options_.enable_lazy_loading) {
// If lazy loading is enabled, no signature is loaded into `bef_file_`, so we
// need to look up the BEF file in the cache or create one.
TF_ASSIGN_OR_RETURN(const LoadingResult& loading_result,
GetOrCreateLoadingResult({std::string(name)}));
func = loading_result.bef_file->GetFunction(
tensorflow::kImportModelDefaultGraphFuncName);
resource_context = loading_result.resource_context.get();
} else {
func = bef_file_->GetFunction({name.data(), name.size()});
resource_context = resource_context_.get();
}
DCHECK(func);
return GraphExecutionRunOnFunction(options_.graph_execution_options,
run_options, name, *func, inputs, captures,
outputs, resource_context, runtime(),
*fallback_state_, req_deadline_tracker_);
}
struct SavedModelImpl::JoinedSignature {
// A unique name for the joined signature.
std::string name;
// The feed nodes for the corresponding inputs. They might not be in the
// original order, and if more than one original input maps to the same feed
// node, only one is kept here.
tensorflow::GraphImportConfig::InputArrays input_nodes;
// The fetch nodes for the outputs, which should be in the original order.
std::vector<std::string> output_nodes;
// The target nodes that should be run but not returned as outputs.
std::vector<std::string> target_nodes;
};
tensorflow::Status SavedModelImpl::RunMultipleSignatures(
const RunOptions& run_options, absl::Span<const std::string> names,
absl::Span<const std::vector<tensorflow::Tensor>> multi_inputs,
std::vector<std::vector<tensorflow::Tensor>>* multi_outputs) {
TF_RET_CHECK(names.size() == multi_inputs.size())
<< "the sizes of names and inputs should be the same";
TF_RET_CHECK(multi_outputs) << "outputs must be provided";
multi_outputs->clear();
// Because feed nodes may overlap among the user-specified inputs,
// `JoinSignatures()` deduplicates them by feed tensor name and produces the
// desired inputs in a new order. The same dedup logic is used here to
// generate the flattened input values in the same order.
//
// Note that no deduplication or reordering is needed for the fetch nodes.
//
// TODO(tfrt-devs): Consider refactoring JoinSignatures so that we don't have
// the implicit requirement that the same dedup logic must be used here and in
// JoinSignatures().
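//
// For example (hypothetical signatures), if "sig_a" and "sig_b" both feed the
// tensor "input:0", the flattened inputs contain "input:0" only once, taken
// from whichever signature is processed first, while the flattened outputs
// keep one entry per signature output, duplicates included.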
std::vector<std::pair<std::string /*tensor_name*/, tensorflow::Tensor>>
flat_inputs;
std::vector<std::string> flat_output_names;
absl::flat_hash_set<std::string> visited_feed_tensor_names;
const auto& signature_defs = meta_graph_def_.signature_def();
for (int i = 0; i < names.size(); ++i) {
const auto& signature_name = names[i];
const auto& input_tensors = multi_inputs[i];
auto sig_iter = signature_defs.find(signature_name);
// Early out if any signature can't be found.
TF_RET_CHECK(sig_iter != signature_defs.end())
<< "failed to find signature " << signature_name << " in the graph";
const auto& signature_def = sig_iter->second;
// `signatures_` keeps the user-specified input names in the same order as
// `input_tensors`.
const auto& signature = signatures_.at(signature_name);
const auto& input_names = signature.input_names;
TF_RETURN_IF_ERROR(
IsInputSpecsCorrect(signature_name, signature, input_tensors));
DCHECK(signature.captures.empty());
TF_RET_CHECK(input_tensors.size() == signature_def.inputs().size())
<< "Incorrect input size for signature: " << signature_name
<< ": expected " << signature_def.inputs().size() << ", but got "
<< input_tensors.size();
DCHECK_EQ(input_names.size(), signature_def.inputs().size());
// Then we find the corresponding tensor names (i.e. node_name:output_idx)
// for the inputs using the SignatureDef proto.
//
// TODO(tfrt-devs): Consider including tensor names in `signatures_` as
// well, so that only `signatures_` is used here.
for (int j = 0; j < input_tensors.size(); ++j) {
const auto& tensor_info = signature_def.inputs().at(input_names[j]);
// TODO(b/184675681): Support other encoding cases.
//
// TODO(b/184679394): Add unit test for this check.
TF_RET_CHECK(tensor_info.encoding_case() == tensorflow::TensorInfo::kName)
<< "Only dense tensor is supported, but got encoding case "
<< tensor_info.encoding_case();
const auto& tensor_name = tensor_info.name();
// Skip if we have already visited this feed tensor. Otherwise, mark it as
// visited and put it in `flat_inputs`. Note that the following code uses
// the same logic as JoinSignatures() to deduplicate inputs by feed tensor
// name, and generates the flat inputs in the same order.
if (visited_feed_tensor_names.contains(tensor_name)) continue;
visited_feed_tensor_names.insert(tensor_name);
flat_inputs.push_back(std::make_pair(tensor_name, input_tensors[j]));
}
for (const auto& output_key : signature.output_names) {
const auto& tensor_info = signature_def.outputs().at(output_key);
VLOG(1) << "Importing Signature Output: output_key = " << output_key
<< ", tensor_info = " << tensor_info.DebugString();
TF_RET_CHECK(tensor_info.encoding_case() == tensorflow::TensorInfo::kName)
<< "Only dense tensor is supported, but got encoding case "
<< tensor_info.encoding_case();
flat_output_names.push_back(tensor_info.name());
}
}
std::vector<tensorflow::Tensor> flat_outputs;
TF_RETURN_IF_ERROR(
graph_executor_->Run(run_options, flat_inputs, flat_output_names,
/*target_tensor_names=*/{}, &flat_outputs));
// The outputs of the compiled function are in the user-specified order,
// though they are flattened. So we just need to regroup the outputs for each
// signature using its number of outputs.
multi_outputs->resize(names.size());
auto cur = flat_outputs.begin();
for (size_t i = 0; i < names.size(); ++i) {
const auto& signature_name = names[i];
const size_t len = signature_defs.at(signature_name).outputs().size();
std::move(cur, cur + len, std::back_inserter(multi_outputs->at(i)));
cur += len;
DCHECK_LE(std::distance(flat_outputs.begin(), cur), flat_outputs.size());
}
return tensorflow::Status::OK();
}
tensorflow::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>>
SavedModelImpl::ImportSubgraph(
mlir::MLIRContext* context,
const tensorflow::GraphImportConfig::InputArrays& input_nodes,
const std::vector<std::string>& output_nodes,
const std::vector<std::string>& target_nodes) {
tensorflow::GraphImportConfig graph_import_config;
graph_import_config.prune_unused_nodes = true;
graph_import_config.enable_shape_inference = false;
graph_import_config.inputs = input_nodes;
graph_import_config.outputs = output_nodes;
graph_import_config.control_outputs = target_nodes;
// Optimize the graph.
TF_ASSIGN_OR_RETURN(
auto optimization_result,
graph_executor_->graph_execution_state().CreateOptimizedGraph(
graph_import_config));
// Convert the optimized graph to an MLIR module.
return tensorflow::ConvertGraphToMlir(
*optimization_result.graph, /*debug_info=*/{},
optimization_result.graph->flib_def(), graph_import_config, context);
}
tensorflow::Status SavedModelImpl::RunByTensorNames(
const RunOptions& run_options,
absl::Span<const std::pair<std::string, tensorflow::Tensor>> inputs,
absl::Span<const std::string> output_tensor_names,
absl::Span<const std::string> target_node_names,
std::vector<tensorflow::Tensor>* outputs) {
// TODO(b/192498110): Validate input type.
return graph_executor_->Run(run_options, inputs, output_tensor_names,
target_node_names, outputs);
}
namespace {
using JoinedSignature = SavedModelImpl::JoinedSignature;
// Returns a joined signature combining the signatures in `names`. For inputs,
// since their corresponding feed nodes may overlap, we deduplicate them by
// node, so the order of inputs for the joined signature may differ from the
// original order. For outputs, overlapping is fine, so we simply flatten them
// in the original order.
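//
// For example (hypothetical signature and tensor names), joining {"a", "b"}
// where both signatures feed "x:0" and "a" additionally feeds "y:0" yields:
//   name:         "a+b" (joined with kSignatureJoiningDelimiter)
//   input_nodes:  {"x:0", "y:0"} (deduplicated)
//   output_nodes: all outputs of "a" followed by all outputs of "b".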
StatusOr<JoinedSignature> JoinSignatures(
absl::Span<const std::string> names, const SignatureMap& signature_map,
const tensorflow::protobuf::Map<std::string, tensorflow::SignatureDef>&
signature_def_map) {
// Join all the names, all the inputs, and all the outputs.
JoinedSignature joined_signature;
joined_signature.name = absl::StrJoin(names, kSignatureJoiningDelimiter);
// Keep the feed tensor names visited.
absl::flat_hash_set<std::string> visited_feed_tensor_names;
for (const auto& name : names) {
const auto& signature_def = signature_def_map.at(name);
// For inputs, we deduplicate possibly overlapping feed nodes and create a
// new input array.
for (const auto& iter : signature_def.inputs()) {
const auto& tensor_info = iter.second;
// Skip if this feed node is already visited.
if (visited_feed_tensor_names.contains(tensor_info.name())) continue;
// Otherwise, we parse its tensor info and collect it for later
// compilation.
visited_feed_tensor_names.insert(tensor_info.name());
// TODO(b/184675681): Support other encoding cases.
//
// TODO(b/184679394): Add unit test for this check.
TF_RET_CHECK(tensor_info.encoding_case() == tensorflow::TensorInfo::kName)
<< "Only dense tensor is supported, but got encoding case "
<< tensor_info.encoding_case();
VLOG(1) << "Importing Signature Input: input_key = " << iter.first
<< ", tensor_info = " << tensor_info.DebugString();
tensorflow::ArrayInfo array_info;
array_info.imported_dtype = tensor_info.dtype();
if (tensor_info.has_tensor_shape()) {
array_info.shape = tensor_info.tensor_shape();
} else {
// If there is no tensor shape in the tensor info, conservatively set
// unknown_rank to true.
array_info.shape.set_unknown_rank(true);
}
joined_signature.input_nodes.insert(
std::pair<std::string, tensorflow::ArrayInfo>(tensor_info.name(),
std::move(array_info)));
}
// For outputs, we simply flatten them in the original order, as it is fine
// to have duplicated fetch nodes.
const internal::Signature& signature = signature_map.at(name);
for (const auto& output_key : signature.output_names) {
const auto& tensor_info = signature_def.outputs().at(output_key);
VLOG(1) << "Importing Signature Output: output_key = " << output_key
<< ", tensor_info = " << tensor_info.DebugString();
TF_RET_CHECK(tensor_info.encoding_case() == tensorflow::TensorInfo::kName)
<< "Only dense tensor is supported, but got encoding case "
<< tensor_info.encoding_case();
joined_signature.output_nodes.push_back(tensor_info.name());
}
}
return joined_signature;
}
} // namespace
// TODO(b/216379787): Reuse `GraphExecutor::LoadClientGraph()`.
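// Imports, compiles, and initializes `joined_signature`, then caches the
// result keyed by its joined name. As an illustration of the lazy-loading
// flow (signature names here are hypothetical): Run("serve_1") calls
// GetOrCreateLoadingResult({"serve_1"}), which on a cache miss joins the
// single signature and calls LoadJoinedSignature() to produce and cache the
// LoadingResult; subsequent calls reuse the cached entry.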
StatusOr<std::reference_wrapper<const SavedModelImpl::LoadingResult>>
SavedModelImpl::LoadJoinedSignature(const JoinedSignature& joined_signature) {
// Step 1: Import the combined subgraph from proto to an MLIR module.
mlir::MLIRContext context;
TF_ASSIGN_OR_RETURN(auto module,
ImportSubgraph(&context, joined_signature.input_nodes,
joined_signature.output_nodes,
joined_signature.target_nodes));
// Step 2: Compile the MLIR module from TF dialect to TFRT dialect (in BEF).
auto loading_result = std::make_unique<LoadingResult>();
loading_result->name = joined_signature.name;
loading_result->resource_context = CreateResourceContext(
runtime(), tpu_model_resource_.get(),
options_.graph_execution_options.compile_options.tpu_target);
TF_RETURN_IF_ERROR(tensorflow::ConvertTfMlirToBef(
options_.graph_execution_options.compile_options, module.get(),
&loading_result->bef));
// Step 3: Initialize runtime states using special BEF functions.
TF_ASSIGN_OR_RETURN(
loading_result->bef_file,
tfrt::CreateBefFileFromBefBuffer(
*options_.graph_execution_options.runtime, loading_result->bef));
TF_RETURN_IF_ERROR(RunInitializers(
/*initializers_and_signatures=*/{},
options_.graph_execution_options.model_metadata,
loading_result->bef_file.get(), *options_.graph_execution_options.runtime,
loading_result->resource_context.get(), *fallback_state_));
// Store loading_result in cache.
const auto* loading_result_ptr = loading_result.get();
loading_result_cache_[joined_signature.name] = std::move(loading_result);
return {*loading_result_ptr};
}
StatusOr<std::reference_wrapper<const SavedModelImpl::LoadingResult>>
SavedModelImpl::GetOrCreateLoadingResult(absl::Span<const std::string> names) {
const auto joined_name = absl::StrJoin(names, kSignatureJoiningDelimiter);
tensorflow::mutex_lock l(loading_result_cache_mu_);
const auto iter = loading_result_cache_.find(joined_name);
if (iter != loading_result_cache_.end()) return {*iter->second};
TF_ASSIGN_OR_RETURN(
const auto joined_signature,
JoinSignatures(names, signatures_, meta_graph_def_.signature_def()));
return LoadJoinedSignature(joined_signature);
}
} // namespace tfrt_stub
} // namespace tensorflow