Create a SavedModelSignatureDefImporterLite, which does not require a Session
and does not apply any graph transformations.

SavedModelSignatureDefImporter is refactored to use
SavedModelSignatureDefImporterLite, but there is no functional change to
SavedModelSignatureDefImporter.
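
For context, a rough sketch of how the two entry points differ after this
change. Both importer classes live in an anonymous namespace inside
import_model.cc, so the call sites below are illustrative only, not a public
API; the local variables (meta_graph_def, debug_info, exported_names, context,
bundle) are assumed to be set up by the caller.

    // Lite path: needs only the MetaGraphDef and debug info -- no Session,
    // and no variable-lifting passes are run on the resulting module.
    StatusOr<mlir::OwningModuleRef> lite_module =
        SavedModelSignatureDefImporterLite::Convert(
            meta_graph_def, debug_info, exported_names, &context,
            /*upgrade_legacy=*/true);

    // Full path: unchanged behavior. Internally it now delegates to the Lite
    // importer and then lifts variables using the Session held by the bundle.
    StatusOr<mlir::OwningModuleRef> module =
        SavedModelSignatureDefImporter::Convert(
            bundle, exported_names, &context, /*upgrade_legacy=*/true);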

PiperOrigin-RevId: 340899850
Change-Id: Ifd48b4ade028282bb8915e76e045f1ad907318f5
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/hash_table_v1.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/hash_table_v1.py
index 3044a9b..ef16b52 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/hash_table_v1.py
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/hash_table_v1.py
@@ -33,8 +33,8 @@
 # CHECK-SAME: min_consumer
 # CHECK-SAME: producer
 
-# CHECK: "tf_saved_model.session_initializer"() {initializer = [[init:@.*]]} : () -> ()
 # CHECK: "tf_saved_model.global_tensor"()
+# CHECK: "tf_saved_model.session_initializer"() {initializer = [[init:@.*]]} : () -> ()
 
 # CHECK:      func [[init]]
 # CHECK-NEXT: [[R5:%.*]] = "tf.Const"()
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
index cf1161d..6b604fd 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
+++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
@@ -3274,33 +3274,38 @@
 
 // A helper class to import a TensorFlow model expressed in SavedModel V1 into
 // an MLIR Module in SavedModel dialect.
-class SavedModelSignatureDefImporter {
+class SavedModelSignatureDefImporterLite {
  public:
   // Main entry point: converts all functions (specified by SignatureDefs) in
   // the given meta graph to an MLIR Module.
   static StatusOr<mlir::OwningModuleRef> Convert(
-      const SavedModelBundle& bundle, absl::Span<std::string> exported_names,
-      mlir::MLIRContext* context, bool upgrade_legacy) {
+      const MetaGraphDef& meta_graph_def, const GraphDebugInfo& debug_info,
+      absl::Span<std::string> exported_names, mlir::MLIRContext* context,
+      bool upgrade_legacy) {
     LoadImporterDialects(*context);
-    SavedModelSignatureDefImporter importer(bundle, exported_names, context);
+    SavedModelSignatureDefImporterLite importer(meta_graph_def, debug_info,
+                                                exported_names, context);
     TF_RETURN_IF_ERROR(importer.InitializeGraph(upgrade_legacy));
-    return importer.ConvertSignatures();
+    TF_ASSIGN_OR_RETURN(auto module, importer.ConvertSignatures());
+
+    SortSavedModelModule(*module);
+    MarkSavedModelFunctionVisibility(*module);
+
+    return module;
   }
 
- private:
-  SavedModelSignatureDefImporter(const SavedModelBundle& bundle,
-                                 absl::Span<std::string> exported_names,
-                                 mlir::MLIRContext* context)
-      : bundle_(bundle),
+  SavedModelSignatureDefImporterLite(const MetaGraphDef& meta_graph_def,
+                                     const GraphDebugInfo& debug_info,
+                                     absl::Span<std::string> exported_names,
+                                     mlir::MLIRContext* context)
+      : meta_graph_def_(meta_graph_def),
+        debug_info_(debug_info),
         graph_(std::make_unique<Graph>(OpRegistry::Global())),
-        debug_info_(),
         exported_names_(exported_names),
         module_(mlir::ModuleOp::create(mlir::UnknownLoc::get(context))),
-        symbol_table_(module_.get()) {
-    // debug_info might not be loaded with loader_lite.
-    if (bundle_.debug_info != nullptr) debug_info_ = *bundle_.debug_info;
-  }
+        symbol_table_(module_.get()) {}
 
+ private:
   // Initializes Graph from saved model GraphDef. If `upgrade_legacy` is set,
   // functionalization is run on the Graph.
   Status InitializeGraph(bool upgrade_legacy);
@@ -3326,34 +3331,33 @@
       const std::vector<std::pair<std::string, TensorInfo>>& outputs,
       const std::vector<std::string> control_outputs);
 
-  // Lifts the variables in `module_`.
-  Status LiftVariables();
-
   // Moves the functions in `sub_module` to `module_` and skips the duplicate
   // functions.
   void MoveConvertedFunctionsToModule(mlir::ModuleOp sub_module);
 
   GraphImportConfig::InputArrays ParseInputArrays(
-      const std::vector<std::pair<std::string, TensorInfo>>& inputs);
+      llvm::ArrayRef<std::pair<std::string, TensorInfo>> inputs);
 
   const Graph& graph() const { return *graph_; }
   const GraphDebugInfo& debug_info() const { return debug_info_; }
 
-  const SavedModelBundle& bundle_;
+ private:
+  const MetaGraphDef& meta_graph_def_;
+  const GraphDebugInfo& debug_info_;
   std::unique_ptr<Graph> graph_;
-  GraphDebugInfo debug_info_;
   absl::Span<std::string> exported_names_;
   mlir::OwningModuleRef module_;
   mlir::SymbolTable symbol_table_;
 };
 
-Status SavedModelSignatureDefImporter::InitializeGraph(bool upgrade_legacy) {
+Status SavedModelSignatureDefImporterLite::InitializeGraph(
+    bool upgrade_legacy) {
   GraphConstructorOptions options;
   options.allow_internal_ops = true;
   options.add_default_attributes = true;
 
   TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(
-      options, bundle_.meta_graph_def.graph_def(), graph_.get()));
+      options, meta_graph_def_.graph_def(), graph_.get()));
 
   // TODO(jpienaar): Remove need to const_cast.
   if (upgrade_legacy) {
@@ -3366,11 +3370,11 @@
   return Status::OK();
 }
 
-StatusOr<std::vector<SavedModelSignatureDefImporter::AssetInfo>>
-SavedModelSignatureDefImporter::ConvertAssets() {
+StatusOr<std::vector<SavedModelSignatureDefImporterLite::AssetInfo>>
+SavedModelSignatureDefImporterLite::ConvertAssets() {
   std::vector<AssetFileDef> asset_file_defs;
   TF_RETURN_IF_ERROR(
-      internal::GetAssetFileDefs(bundle_.meta_graph_def, &asset_file_defs));
+      internal::GetAssetFileDefs(meta_graph_def_, &asset_file_defs));
 
   std::vector<AssetInfo> results;
   results.reserve(asset_file_defs.size());
@@ -3393,7 +3397,7 @@
   return results;
 }
 
-void SavedModelSignatureDefImporter::MoveConvertedFunctionsToModule(
+void SavedModelSignatureDefImporterLite::MoveConvertedFunctionsToModule(
     mlir::ModuleOp sub_module) {
   // Iterate through all functions and insert the ones that do not already exist
   // in `module_`.
@@ -3403,11 +3407,10 @@
   }
 }
 
-Status SavedModelSignatureDefImporter::ConvertInitializer(
+Status SavedModelSignatureDefImporterLite::ConvertInitializer(
     const std::vector<AssetInfo>& assets) {
   std::string init_node_name;
-  TF_RETURN_IF_ERROR(
-      internal::GetInitOp("", bundle_.meta_graph_def, &init_node_name));
+  TF_RETURN_IF_ERROR(internal::GetInitOp("", meta_graph_def_, &init_node_name));
 
   if (init_node_name.empty()) return Status::OK();
 
@@ -3454,51 +3457,7 @@
 }
 
 StatusOr<mlir::OwningModuleRef>
-SavedModelSignatureDefImporter::ConvertSignatures() {
-  const auto& signatures = bundle_.GetSignatures();
-  PopulateTfVersions(module_.get(), graph().versions());
-
-  // debug_info might not be loaded with loader_lite.
-  GraphDebugInfo debug_info;
-  if (bundle_.debug_info != nullptr) debug_info = *bundle_.debug_info;
-
-  llvm::StringSet<> exported_name_set;
-  exported_name_set.insert(exported_names_.begin(), exported_names_.end());
-
-  for (const auto& key_and_signature_def : signatures) {
-    const std::string& sig_def_key = key_and_signature_def.first;
-    const SignatureDef& signature_def = key_and_signature_def.second;
-
-    // It is safe to skip "__saved_model_init_op" since it is an internal
-    // signature that is not user-accessible.
-    if (sig_def_key == "__saved_model_init_op") {
-      continue;
-    }
-    if (!exported_name_set.empty() &&
-        exported_name_set.count(sig_def_key) == 0) {
-      continue;
-    }
-
-    TF_RETURN_IF_ERROR(ConvertSignature(sig_def_key, signature_def));
-  }
-
-  TF_ASSIGN_OR_RETURN(auto assets, ConvertAssets());
-  TF_RETURN_IF_ERROR(ConvertInitializer(assets));
-
-  mlir::OpBuilder builder(module_->getBodyRegion());
-  module_->setAttr("tf_saved_model.semantics", builder.getUnitAttr());
-
-  module_->setAttr("tf_saved_model.under_construction", builder.getUnitAttr());
-  TF_RETURN_IF_ERROR(LiftVariables());
-  module_->removeAttr("tf_saved_model.under_construction");
-
-  SortSavedModelModule(*module_);
-  MarkSavedModelFunctionVisibility(*module_);
-
-  return std::move(module_);
-}
-
-StatusOr<mlir::OwningModuleRef> SavedModelSignatureDefImporter::ConvertGraph(
+SavedModelSignatureDefImporterLite::ConvertGraph(
     const std::string& name,
     const std::vector<std::pair<std::string, TensorInfo>>& inputs,
     const std::vector<std::pair<std::string, TensorInfo>>& outputs,
@@ -3514,7 +3473,7 @@
                                    graph().flib_def(), specs, name);
 }
 
-Status SavedModelSignatureDefImporter::ConvertSignature(
+Status SavedModelSignatureDefImporterLite::ConvertSignature(
     const std::string& sig_def_key, const SignatureDef& signature_def) {
   // Create local vectors for the input and output and sort them to be
   // deterministic. We don't want anyone to really depend on the order, client
@@ -3564,33 +3523,9 @@
   return Status::OK();
 }
 
-Status SavedModelSignatureDefImporter::LiftVariables() {
-  mlir::StatusScopedDiagnosticHandler diag_handler(module_->getContext());
-
-  mlir::PassManager pm(module_->getContext());
-  SetCrashReproducer(pm);
-  pm.addNestedPass<mlir::FuncOp>(
-      mlir::tf_executor::CreateTFExecutorGraphPruningPass());
-  pm.addNestedPass<mlir::FuncOp>(
-      mlir::CreateExecutorDialectToFunctionalConversionPass());
-  pm.addPass(
-      mlir::tf_saved_model::CreateRemoveVariablesInSessionInitializerPass());
-  pm.addNestedPass<mlir::FuncOp>(
-      mlir::TF::
-          CreateConvertReadonlyReferenceVariablesToResourceVariablesPass());
-  pm.addPass(mlir::TF::CreatePromoteVarHandlesToArgsPass());
-  pm.addPass(
-      mlir::tf_saved_model::CreateLiftVariablesPass(bundle_.GetSession()));
-  pm.addNestedPass<mlir::FuncOp>(
-      mlir::tf_saved_model::CreateDedupBoundInputBindingPass());
-  if (mlir::failed(pm.run(*module_)))
-    return diag_handler.Combine(errors::Internal("Failed to lift variables."));
-
-  return Status::OK();
-}
-
-GraphImportConfig::InputArrays SavedModelSignatureDefImporter::ParseInputArrays(
-    const std::vector<std::pair<std::string, TensorInfo>>& inputs) {
+GraphImportConfig::InputArrays
+SavedModelSignatureDefImporterLite::ParseInputArrays(
+    llvm::ArrayRef<std::pair<std::string, TensorInfo>> inputs) {
   GraphImportConfig::InputArrays results;
   for (const auto& iter : inputs) {
     const auto& tensor_info = iter.second;
@@ -3608,6 +3543,105 @@
   return results;
 }
 
+StatusOr<mlir::OwningModuleRef>
+SavedModelSignatureDefImporterLite::ConvertSignatures() {
+  const auto& signatures = meta_graph_def_.signature_def();
+  PopulateTfVersions(module_.get(), graph().versions());
+
+  llvm::DenseSet<llvm::StringRef> exported_name_set;
+  exported_name_set.insert(exported_names_.begin(), exported_names_.end());
+
+  for (const auto& key_and_signature_def : signatures) {
+    const std::string& sig_def_key = key_and_signature_def.first;
+    const SignatureDef& signature_def = key_and_signature_def.second;
+
+    // It is safe to skip "__saved_model_init_op" since it is an internal
+    // signature that is not user-accessible. This signature will be handled in
+    // ConvertInitializer().
+    if (sig_def_key == "__saved_model_init_op") {
+      continue;
+    }
+    if (!exported_name_set.empty() &&
+        exported_name_set.count(sig_def_key) == 0) {
+      continue;
+    }
+
+    TF_RETURN_IF_ERROR(ConvertSignature(sig_def_key, signature_def));
+  }
+
+  TF_ASSIGN_OR_RETURN(auto assets, ConvertAssets());
+  TF_RETURN_IF_ERROR(ConvertInitializer(assets));
+
+  mlir::OpBuilder builder(module_->getBodyRegion());
+  module_->setAttr("tf_saved_model.semantics", builder.getUnitAttr());
+
+  SortSavedModelModule(*module_);
+  MarkSavedModelFunctionVisibility(*module_);
+
+  return std::move(module_);
+}
+
+// A helper class to import a TensorFlow model expressed in SavedModel V1 into
+// an MLIR Module in SavedModel dialect. In addition to importing the model, it
+// performs a few graph transformations, including:
+//  1) Convert read-only ref variables to resource variables
+//  2) Lift resource variables to global_tensors by using a TF session.
+class SavedModelSignatureDefImporter {
+ public:
+  // Main entry point: converts all functions (specified by SignatureDefs) in
+  // the given meta graph to an MLIR Module.
+  static StatusOr<mlir::OwningModuleRef> Convert(
+      const SavedModelBundle& bundle, absl::Span<std::string> exported_names,
+      mlir::MLIRContext* context, bool upgrade_legacy) {
+    // debug_info might not be loaded with loader_lite.
+    GraphDebugInfo debug_info;
+    if (bundle.debug_info != nullptr) debug_info = *bundle.debug_info;
+
+    TF_ASSIGN_OR_RETURN(auto module,
+                        SavedModelSignatureDefImporterLite::Convert(
+                            bundle.meta_graph_def, debug_info, exported_names,
+                            context, upgrade_legacy));
+
+    mlir::OpBuilder builder(module->getContext());
+    module->setAttr("tf_saved_model.under_construction", builder.getUnitAttr());
+    TF_RETURN_IF_ERROR(LiftVariables(bundle, *module));
+    module->removeAttr("tf_saved_model.under_construction");
+
+    return module;
+  }
+
+ private:
+  // Lifts the variables in `module`.
+  static Status LiftVariables(const SavedModelBundle& bundle,
+                              mlir::ModuleOp module);
+};
+
+Status SavedModelSignatureDefImporter::LiftVariables(
+    const SavedModelBundle& bundle, mlir::ModuleOp module) {
+  mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext());
+
+  mlir::PassManager pm(module.getContext());
+  SetCrashReproducer(pm);
+  pm.addNestedPass<mlir::FuncOp>(
+      mlir::tf_executor::CreateTFExecutorGraphPruningPass());
+  pm.addNestedPass<mlir::FuncOp>(
+      mlir::CreateExecutorDialectToFunctionalConversionPass());
+  pm.addPass(
+      mlir::tf_saved_model::CreateRemoveVariablesInSessionInitializerPass());
+  pm.addNestedPass<mlir::FuncOp>(
+      mlir::TF::
+          CreateConvertReadonlyReferenceVariablesToResourceVariablesPass());
+  pm.addPass(mlir::TF::CreatePromoteVarHandlesToArgsPass());
+  pm.addPass(
+      mlir::tf_saved_model::CreateLiftVariablesPass(bundle.GetSession()));
+  pm.addNestedPass<mlir::FuncOp>(
+      mlir::tf_saved_model::CreateDedupBoundInputBindingPass());
+  if (mlir::failed(pm.run(module)))
+    return diag_handler.Combine(errors::Internal("Failed to lift variables."));
+
+  return Status::OK();
+}
+
 }  // namespace
 
 StatusOr<mlir::OwningModuleRef> ConvertGraphdefToMlir(