Add an option to specify whether to enable TFRT GPU related logic.

PiperOrigin-RevId: 442579977
diff --git a/tensorflow/core/tfrt/graph_executor/graph_execution_options.h b/tensorflow/core/tfrt/graph_executor/graph_execution_options.h
index 85e0127..03c2011 100644
--- a/tensorflow/core/tfrt/graph_executor/graph_execution_options.h
+++ b/tensorflow/core/tfrt/graph_executor/graph_execution_options.h
@@ -39,6 +39,9 @@
   // optimizations like function inlining will be applied.
   bool enable_grappler_function_optimizer = false;
 
+  // Whether to enable TFRT GPU.
+  bool enable_tfrt_gpu = false;
+
   // Runtime configuration. Refer to tensorflow::tfrt_stub::Runtime class for
   // more details. It must not be nullptr;
   const tensorflow::tfrt_stub::Runtime* runtime = nullptr;
diff --git a/tensorflow/core/tfrt/graph_executor/graph_executor.cc b/tensorflow/core/tfrt/graph_executor/graph_executor.cc
index ec4dc21..f77d635 100644
--- a/tensorflow/core/tfrt/graph_executor/graph_executor.cc
+++ b/tensorflow/core/tfrt/graph_executor/graph_executor.cc
@@ -295,6 +295,7 @@
   TfrtGraphExecutionState::Options graph_execution_state_options;
   graph_execution_state_options.run_placer_grappler_on_functions =
       options.run_placer_grappler_on_functions;
+  graph_execution_state_options.enable_tfrt_gpu = options.enable_tfrt_gpu;
 
   TF_ASSIGN_OR_RETURN(
       auto graph_execution_state,
diff --git a/tensorflow/core/tfrt/saved_model/saved_model.cc b/tensorflow/core/tfrt/saved_model/saved_model.cc
index ff59e77..c550cb3 100644
--- a/tensorflow/core/tfrt/saved_model/saved_model.cc
+++ b/tensorflow/core/tfrt/saved_model/saved_model.cc
@@ -303,7 +303,8 @@
 StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> ImportSavedModel(
     mlir::MLIRContext* context, const tensorflow::MetaGraphDef& meta_graph_def,
     const FallbackState& fallback_state, std::string saved_model_dir,
-    bool import_user_signatures, bool run_placer_grappler_on_functions) {
+    bool import_user_signatures, bool run_placer_grappler_on_functions,
+    bool enable_tfrt_gpu) {
   std::vector<std::string> signature_names;
   if (import_user_signatures) {
     signature_names = FindNamesForValidSignatures(meta_graph_def);
@@ -321,7 +322,7 @@
   TF_ASSIGN_OR_RETURN(auto import_input,
                       TfrtSavedModelMLIRImportInput::Create(
                           fallback_state, &meta_graph_def, /*debug_info=*/{},
-                          run_placer_grappler_on_functions));
+                          run_placer_grappler_on_functions, enable_tfrt_gpu));
 
   TF_ASSIGN_OR_RETURN(
       auto module,
@@ -484,7 +485,8 @@
             &context, meta_graph_def, *fallback_state,
             std::string(saved_model_dir),
             /*import_user_signatures=*/!options.enable_lazy_loading,
-            options.graph_execution_options.run_placer_grappler_on_functions));
+            options.graph_execution_options.run_placer_grappler_on_functions,
+            options.graph_execution_options.enable_tfrt_gpu));
 
     auto import_duration = absl::Now() - import_start_time;
     saved_model_import_time_seconds->GetCell(std::string(saved_model_dir))
diff --git a/tensorflow/core/tfrt/saved_model/saved_model_import_input.cc b/tensorflow/core/tfrt/saved_model/saved_model_import_input.cc
index c8b9dc8..5338717 100644
--- a/tensorflow/core/tfrt/saved_model/saved_model_import_input.cc
+++ b/tensorflow/core/tfrt/saved_model/saved_model_import_input.cc
@@ -24,12 +24,13 @@
 StatusOr<TfrtSavedModelMLIRImportInput> TfrtSavedModelMLIRImportInput::Create(
     const FallbackState& fallback_state, const MetaGraphDef* meta_graph_def,
     const GraphDebugInfo& debug_info,
-    bool run_placer_grappler_on_nested_functions) {
+    bool run_placer_grappler_on_nested_functions, bool enable_tfrt_gpu) {
   DCHECK(meta_graph_def);
 
   TfrtGraphExecutionState::Options options;
   options.run_placer_grappler_on_functions =
       run_placer_grappler_on_nested_functions;
+  options.enable_tfrt_gpu = enable_tfrt_gpu;
   TF_ASSIGN_OR_RETURN(
       auto graph_execution_state,
       TfrtGraphExecutionState::Create(options, meta_graph_def->graph_def(),
diff --git a/tensorflow/core/tfrt/saved_model/saved_model_import_input.h b/tensorflow/core/tfrt/saved_model/saved_model_import_input.h
index 3a8131e..45af815 100644
--- a/tensorflow/core/tfrt/saved_model/saved_model_import_input.h
+++ b/tensorflow/core/tfrt/saved_model/saved_model_import_input.h
@@ -30,7 +30,8 @@
   static StatusOr<TfrtSavedModelMLIRImportInput> Create(
       const FallbackState& fallback_state, const MetaGraphDef* meta_graph_def,
       const GraphDebugInfo& debug_info,
-      bool run_placer_grappler_on_nested_functions = false);
+      bool run_placer_grappler_on_nested_functions = false,
+      bool enable_tfrt_gpu = false);
 
   TfrtSavedModelMLIRImportInput(
       const MetaGraphDef* meta_graph_def, const GraphDebugInfo& debug_info,