Create a SavedModelSignatureDefImporterLite which does not require a Session
and does not apply any graph transformation.
SavedModelSignatureDefImporter is refactored to use
SavedModelSignatureDefImporterLite but there is no functional change to
SavedModelSignatureDefImporter.
PiperOrigin-RevId: 340899850
Change-Id: Ifd48b4ade028282bb8915e76e045f1ad907318f5
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/hash_table_v1.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/hash_table_v1.py
index 3044a9b..ef16b52 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/hash_table_v1.py
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/hash_table_v1.py
@@ -33,8 +33,8 @@
# CHECK-SAME: min_consumer
# CHECK-SAME: producer
-# CHECK: "tf_saved_model.session_initializer"() {initializer = [[init:@.*]]} : () -> ()
# CHECK: "tf_saved_model.global_tensor"()
+# CHECK: "tf_saved_model.session_initializer"() {initializer = [[init:@.*]]} : () -> ()
# CHECK: func [[init]]
# CHECK-NEXT: [[R5:%.*]] = "tf.Const"()
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
index cf1161d..6b604fd 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
+++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
@@ -3274,33 +3274,38 @@
// A helper class to import a TensorFlow model expressed in SavedModel V1 into
// an MLIR Module in SavedModel dialect.
-class SavedModelSignatureDefImporter {
+class SavedModelSignatureDefImporterLite {
public:
// Main entry point: converts all functions (specified by SignatureDefs) in
// the given meta graph to an MLIR Module.
static StatusOr<mlir::OwningModuleRef> Convert(
- const SavedModelBundle& bundle, absl::Span<std::string> exported_names,
- mlir::MLIRContext* context, bool upgrade_legacy) {
+ const MetaGraphDef& meta_graph_def, const GraphDebugInfo& debug_info,
+ absl::Span<std::string> exported_names, mlir::MLIRContext* context,
+ bool upgrade_legacy) {
LoadImporterDialects(*context);
- SavedModelSignatureDefImporter importer(bundle, exported_names, context);
+ SavedModelSignatureDefImporterLite importer(meta_graph_def, debug_info,
+ exported_names, context);
TF_RETURN_IF_ERROR(importer.InitializeGraph(upgrade_legacy));
- return importer.ConvertSignatures();
+ TF_ASSIGN_OR_RETURN(auto module, importer.ConvertSignatures());
+
+ SortSavedModelModule(*module);
+ MarkSavedModelFunctionVisibility(*module);
+
+ return module;
}
- private:
- SavedModelSignatureDefImporter(const SavedModelBundle& bundle,
- absl::Span<std::string> exported_names,
- mlir::MLIRContext* context)
- : bundle_(bundle),
+ SavedModelSignatureDefImporterLite(const MetaGraphDef& meta_graph_def,
+ const GraphDebugInfo& debug_info,
+ absl::Span<std::string> exported_names,
+ mlir::MLIRContext* context)
+ : meta_graph_def_(meta_graph_def),
+ debug_info_(debug_info),
graph_(std::make_unique<Graph>(OpRegistry::Global())),
- debug_info_(),
exported_names_(exported_names),
module_(mlir::ModuleOp::create(mlir::UnknownLoc::get(context))),
- symbol_table_(module_.get()) {
- // debug_info might not be loaded with loader_lite.
- if (bundle_.debug_info != nullptr) debug_info_ = *bundle_.debug_info;
- }
+ symbol_table_(module_.get()) {}
+ private:
// Initializes Graph from saved model GraphDef. If `upgrade_legacy` is set,
// functionalization is ran on the Graph.
Status InitializeGraph(bool upgrade_legacy);
@@ -3326,34 +3331,33 @@
const std::vector<std::pair<std::string, TensorInfo>>& outputs,
const std::vector<std::string> control_outputs);
- // Lifts the variables in `module_`.
- Status LiftVariables();
-
// Moves the functions in `sub_module` to `module_` and skips the duplicate
// functions.
void MoveConvertedFunctionsToModule(mlir::ModuleOp sub_module);
GraphImportConfig::InputArrays ParseInputArrays(
- const std::vector<std::pair<std::string, TensorInfo>>& inputs);
+ llvm::ArrayRef<std::pair<std::string, TensorInfo>> inputs);
const Graph& graph() const { return *graph_; }
const GraphDebugInfo& debug_info() const { return debug_info_; }
- const SavedModelBundle& bundle_;
+ private:
+ const MetaGraphDef& meta_graph_def_;
+ const GraphDebugInfo& debug_info_;
std::unique_ptr<Graph> graph_;
- GraphDebugInfo debug_info_;
absl::Span<std::string> exported_names_;
mlir::OwningModuleRef module_;
mlir::SymbolTable symbol_table_;
};
-Status SavedModelSignatureDefImporter::InitializeGraph(bool upgrade_legacy) {
+Status SavedModelSignatureDefImporterLite::InitializeGraph(
+ bool upgrade_legacy) {
GraphConstructorOptions options;
options.allow_internal_ops = true;
options.add_default_attributes = true;
TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(
- options, bundle_.meta_graph_def.graph_def(), graph_.get()));
+ options, meta_graph_def_.graph_def(), graph_.get()));
// TODO(jpienaar): Remove need to const_cast.
if (upgrade_legacy) {
@@ -3366,11 +3370,11 @@
return Status::OK();
}
-StatusOr<std::vector<SavedModelSignatureDefImporter::AssetInfo>>
-SavedModelSignatureDefImporter::ConvertAssets() {
+StatusOr<std::vector<SavedModelSignatureDefImporterLite::AssetInfo>>
+SavedModelSignatureDefImporterLite::ConvertAssets() {
std::vector<AssetFileDef> asset_file_defs;
TF_RETURN_IF_ERROR(
- internal::GetAssetFileDefs(bundle_.meta_graph_def, &asset_file_defs));
+ internal::GetAssetFileDefs(meta_graph_def_, &asset_file_defs));
std::vector<AssetInfo> results;
results.reserve(asset_file_defs.size());
@@ -3393,7 +3397,7 @@
return results;
}
-void SavedModelSignatureDefImporter::MoveConvertedFunctionsToModule(
+void SavedModelSignatureDefImporterLite::MoveConvertedFunctionsToModule(
mlir::ModuleOp sub_module) {
// Iterate through all functions and insert the ones that do not already exist
// in `module_`.
@@ -3403,11 +3407,10 @@
}
}
-Status SavedModelSignatureDefImporter::ConvertInitializer(
+Status SavedModelSignatureDefImporterLite::ConvertInitializer(
const std::vector<AssetInfo>& assets) {
std::string init_node_name;
- TF_RETURN_IF_ERROR(
- internal::GetInitOp("", bundle_.meta_graph_def, &init_node_name));
+ TF_RETURN_IF_ERROR(internal::GetInitOp("", meta_graph_def_, &init_node_name));
if (init_node_name.empty()) return Status::OK();
@@ -3454,51 +3457,7 @@
}
StatusOr<mlir::OwningModuleRef>
-SavedModelSignatureDefImporter::ConvertSignatures() {
- const auto& signatures = bundle_.GetSignatures();
- PopulateTfVersions(module_.get(), graph().versions());
-
- // debug_info might not be loaded with loader_lite.
- GraphDebugInfo debug_info;
- if (bundle_.debug_info != nullptr) debug_info = *bundle_.debug_info;
-
- llvm::StringSet<> exported_name_set;
- exported_name_set.insert(exported_names_.begin(), exported_names_.end());
-
- for (const auto& key_and_signature_def : signatures) {
- const std::string& sig_def_key = key_and_signature_def.first;
- const SignatureDef& signature_def = key_and_signature_def.second;
-
- // It is safe to skip "__saved_model_init_op" since it is an internal
- // signature that is not user-accessible.
- if (sig_def_key == "__saved_model_init_op") {
- continue;
- }
- if (!exported_name_set.empty() &&
- exported_name_set.count(sig_def_key) == 0) {
- continue;
- }
-
- TF_RETURN_IF_ERROR(ConvertSignature(sig_def_key, signature_def));
- }
-
- TF_ASSIGN_OR_RETURN(auto assets, ConvertAssets());
- TF_RETURN_IF_ERROR(ConvertInitializer(assets));
-
- mlir::OpBuilder builder(module_->getBodyRegion());
- module_->setAttr("tf_saved_model.semantics", builder.getUnitAttr());
-
- module_->setAttr("tf_saved_model.under_construction", builder.getUnitAttr());
- TF_RETURN_IF_ERROR(LiftVariables());
- module_->removeAttr("tf_saved_model.under_construction");
-
- SortSavedModelModule(*module_);
- MarkSavedModelFunctionVisibility(*module_);
-
- return std::move(module_);
-}
-
-StatusOr<mlir::OwningModuleRef> SavedModelSignatureDefImporter::ConvertGraph(
+SavedModelSignatureDefImporterLite::ConvertGraph(
const std::string& name,
const std::vector<std::pair<std::string, TensorInfo>>& inputs,
const std::vector<std::pair<std::string, TensorInfo>>& outputs,
@@ -3514,7 +3473,7 @@
graph().flib_def(), specs, name);
}
-Status SavedModelSignatureDefImporter::ConvertSignature(
+Status SavedModelSignatureDefImporterLite::ConvertSignature(
const std::string& sig_def_key, const SignatureDef& signature_def) {
// Create local vectors for the input and output and sort them to be
// deterministic. We don't want anyone to really depend on the order, client
@@ -3564,33 +3523,9 @@
return Status::OK();
}
-Status SavedModelSignatureDefImporter::LiftVariables() {
- mlir::StatusScopedDiagnosticHandler diag_handler(module_->getContext());
-
- mlir::PassManager pm(module_->getContext());
- SetCrashReproducer(pm);
- pm.addNestedPass<mlir::FuncOp>(
- mlir::tf_executor::CreateTFExecutorGraphPruningPass());
- pm.addNestedPass<mlir::FuncOp>(
- mlir::CreateExecutorDialectToFunctionalConversionPass());
- pm.addPass(
- mlir::tf_saved_model::CreateRemoveVariablesInSessionInitializerPass());
- pm.addNestedPass<mlir::FuncOp>(
- mlir::TF::
- CreateConvertReadonlyReferenceVariablesToResourceVariablesPass());
- pm.addPass(mlir::TF::CreatePromoteVarHandlesToArgsPass());
- pm.addPass(
- mlir::tf_saved_model::CreateLiftVariablesPass(bundle_.GetSession()));
- pm.addNestedPass<mlir::FuncOp>(
- mlir::tf_saved_model::CreateDedupBoundInputBindingPass());
- if (mlir::failed(pm.run(*module_)))
- return diag_handler.Combine(errors::Internal("Failed to lift variables."));
-
- return Status::OK();
-}
-
-GraphImportConfig::InputArrays SavedModelSignatureDefImporter::ParseInputArrays(
- const std::vector<std::pair<std::string, TensorInfo>>& inputs) {
+GraphImportConfig::InputArrays
+SavedModelSignatureDefImporterLite::ParseInputArrays(
+ llvm::ArrayRef<std::pair<std::string, TensorInfo>> inputs) {
GraphImportConfig::InputArrays results;
for (const auto& iter : inputs) {
const auto& tensor_info = iter.second;
@@ -3608,6 +3543,105 @@
return results;
}
+StatusOr<mlir::OwningModuleRef>
+SavedModelSignatureDefImporterLite::ConvertSignatures() {
+ const auto& signatures = meta_graph_def_.signature_def();
+ PopulateTfVersions(module_.get(), graph().versions());
+
+ llvm::DenseSet<llvm::StringRef> exported_name_set;
+ exported_name_set.insert(exported_names_.begin(), exported_names_.end());
+
+ for (const auto& key_and_signature_def : signatures) {
+ const std::string& sig_def_key = key_and_signature_def.first;
+ const SignatureDef& signature_def = key_and_signature_def.second;
+
+ // It is safe to skip "__saved_model_init_op" since it is an internal
+ // signature that is not user-accessible. This signature will be handled in
+ // ConvertInitializer().
+ if (sig_def_key == "__saved_model_init_op") {
+ continue;
+ }
+ if (!exported_name_set.empty() &&
+ exported_name_set.count(sig_def_key) == 0) {
+ continue;
+ }
+
+ TF_RETURN_IF_ERROR(ConvertSignature(sig_def_key, signature_def));
+ }
+
+ TF_ASSIGN_OR_RETURN(auto assets, ConvertAssets());
+ TF_RETURN_IF_ERROR(ConvertInitializer(assets));
+
+ mlir::OpBuilder builder(module_->getBodyRegion());
+ module_->setAttr("tf_saved_model.semantics", builder.getUnitAttr());
+
+ SortSavedModelModule(*module_);
+ MarkSavedModelFunctionVisibility(*module_);
+
+ return std::move(module_);
+}
+
+// A helper class to import a TensorFlow model expressed in SavedModel V1 into
+// an MLIR Module in SavedModel dialect. In addition to importing the model, it
+// performs a few graph transformations, including:
+// 1) Convert read-only ref variables to resource variables
+// 2) Lift resource variables to global_tensors by using a TF session.
+class SavedModelSignatureDefImporter {
+ public:
+ // Main entry point: converts all functions (specified by SignatureDefs) in
+ // the given meta graph to an MLIR Module.
+ static StatusOr<mlir::OwningModuleRef> Convert(
+ const SavedModelBundle& bundle, absl::Span<std::string> exported_names,
+ mlir::MLIRContext* context, bool upgrade_legacy) {
+ // debug_info might not be loaded with loader_lite.
+ GraphDebugInfo debug_info;
+ if (bundle.debug_info != nullptr) debug_info = *bundle.debug_info;
+
+ TF_ASSIGN_OR_RETURN(auto module,
+ SavedModelSignatureDefImporterLite::Convert(
+ bundle.meta_graph_def, debug_info, exported_names,
+ context, upgrade_legacy));
+
+ mlir::OpBuilder builder(module->getContext());
+ module->setAttr("tf_saved_model.under_construction", builder.getUnitAttr());
+ TF_RETURN_IF_ERROR(LiftVariables(bundle, *module));
+ module->removeAttr("tf_saved_model.under_construction");
+
+ return module;
+ }
+
+ private:
+ // Lifts the variables in `module`.
+ static Status LiftVariables(const SavedModelBundle& bundle,
+ mlir::ModuleOp module);
+};
+
+Status SavedModelSignatureDefImporter::LiftVariables(
+ const SavedModelBundle& bundle, mlir::ModuleOp module) {
+ mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext());
+
+ mlir::PassManager pm(module.getContext());
+ SetCrashReproducer(pm);
+ pm.addNestedPass<mlir::FuncOp>(
+ mlir::tf_executor::CreateTFExecutorGraphPruningPass());
+ pm.addNestedPass<mlir::FuncOp>(
+ mlir::CreateExecutorDialectToFunctionalConversionPass());
+ pm.addPass(
+ mlir::tf_saved_model::CreateRemoveVariablesInSessionInitializerPass());
+ pm.addNestedPass<mlir::FuncOp>(
+ mlir::TF::
+ CreateConvertReadonlyReferenceVariablesToResourceVariablesPass());
+ pm.addPass(mlir::TF::CreatePromoteVarHandlesToArgsPass());
+ pm.addPass(
+ mlir::tf_saved_model::CreateLiftVariablesPass(bundle.GetSession()));
+ pm.addNestedPass<mlir::FuncOp>(
+ mlir::tf_saved_model::CreateDedupBoundInputBindingPass());
+ if (mlir::failed(pm.run(module)))
+ return diag_handler.Combine(errors::Internal("Failed to lift variables."));
+
+ return Status::OK();
+}
+
} // namespace
StatusOr<mlir::OwningModuleRef> ConvertGraphdefToMlir(