Add an option to specify whether to enable TFRT GPU related logic.
PiperOrigin-RevId: 442579977
diff --git a/tensorflow/core/tfrt/graph_executor/graph_execution_options.h b/tensorflow/core/tfrt/graph_executor/graph_execution_options.h
index 85e0127..03c2011 100644
--- a/tensorflow/core/tfrt/graph_executor/graph_execution_options.h
+++ b/tensorflow/core/tfrt/graph_executor/graph_execution_options.h
@@ -39,6 +39,9 @@
// optimizations like function inlining will be applied.
bool enable_grappler_function_optimizer = false;
+ // Whether to enable TFRT GPU.
+ bool enable_tfrt_gpu = false;
+
// Runtime configuration. Refer to tensorflow::tfrt_stub::Runtime class for
// more details. It must not be nullptr;
const tensorflow::tfrt_stub::Runtime* runtime = nullptr;
diff --git a/tensorflow/core/tfrt/graph_executor/graph_executor.cc b/tensorflow/core/tfrt/graph_executor/graph_executor.cc
index ec4dc21..f77d635 100644
--- a/tensorflow/core/tfrt/graph_executor/graph_executor.cc
+++ b/tensorflow/core/tfrt/graph_executor/graph_executor.cc
@@ -295,6 +295,7 @@
TfrtGraphExecutionState::Options graph_execution_state_options;
graph_execution_state_options.run_placer_grappler_on_functions =
options.run_placer_grappler_on_functions;
+ graph_execution_state_options.enable_tfrt_gpu = options.enable_tfrt_gpu;
TF_ASSIGN_OR_RETURN(
auto graph_execution_state,
diff --git a/tensorflow/core/tfrt/saved_model/saved_model.cc b/tensorflow/core/tfrt/saved_model/saved_model.cc
index ff59e77..c550cb3 100644
--- a/tensorflow/core/tfrt/saved_model/saved_model.cc
+++ b/tensorflow/core/tfrt/saved_model/saved_model.cc
@@ -303,7 +303,8 @@
StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> ImportSavedModel(
mlir::MLIRContext* context, const tensorflow::MetaGraphDef& meta_graph_def,
const FallbackState& fallback_state, std::string saved_model_dir,
- bool import_user_signatures, bool run_placer_grappler_on_functions) {
+ bool import_user_signatures, bool run_placer_grappler_on_functions,
+ bool enable_tfrt_gpu) {
std::vector<std::string> signature_names;
if (import_user_signatures) {
signature_names = FindNamesForValidSignatures(meta_graph_def);
@@ -321,7 +322,7 @@
TF_ASSIGN_OR_RETURN(auto import_input,
TfrtSavedModelMLIRImportInput::Create(
fallback_state, &meta_graph_def, /*debug_info=*/{},
- run_placer_grappler_on_functions));
+ run_placer_grappler_on_functions, enable_tfrt_gpu));
TF_ASSIGN_OR_RETURN(
auto module,
@@ -484,7 +485,8 @@
&context, meta_graph_def, *fallback_state,
std::string(saved_model_dir),
/*import_user_signatures=*/!options.enable_lazy_loading,
- options.graph_execution_options.run_placer_grappler_on_functions));
+ options.graph_execution_options.run_placer_grappler_on_functions,
+ options.graph_execution_options.enable_tfrt_gpu));
auto import_duration = absl::Now() - import_start_time;
saved_model_import_time_seconds->GetCell(std::string(saved_model_dir))
diff --git a/tensorflow/core/tfrt/saved_model/saved_model_import_input.cc b/tensorflow/core/tfrt/saved_model/saved_model_import_input.cc
index c8b9dc8..5338717 100644
--- a/tensorflow/core/tfrt/saved_model/saved_model_import_input.cc
+++ b/tensorflow/core/tfrt/saved_model/saved_model_import_input.cc
@@ -24,12 +24,13 @@
StatusOr<TfrtSavedModelMLIRImportInput> TfrtSavedModelMLIRImportInput::Create(
const FallbackState& fallback_state, const MetaGraphDef* meta_graph_def,
const GraphDebugInfo& debug_info,
- bool run_placer_grappler_on_nested_functions) {
+ bool run_placer_grappler_on_nested_functions, bool enable_tfrt_gpu) {
DCHECK(meta_graph_def);
TfrtGraphExecutionState::Options options;
options.run_placer_grappler_on_functions =
run_placer_grappler_on_nested_functions;
+ options.enable_tfrt_gpu = enable_tfrt_gpu;
TF_ASSIGN_OR_RETURN(
auto graph_execution_state,
TfrtGraphExecutionState::Create(options, meta_graph_def->graph_def(),
diff --git a/tensorflow/core/tfrt/saved_model/saved_model_import_input.h b/tensorflow/core/tfrt/saved_model/saved_model_import_input.h
index 3a8131e..45af815 100644
--- a/tensorflow/core/tfrt/saved_model/saved_model_import_input.h
+++ b/tensorflow/core/tfrt/saved_model/saved_model_import_input.h
@@ -30,7 +30,8 @@
static StatusOr<TfrtSavedModelMLIRImportInput> Create(
const FallbackState& fallback_state, const MetaGraphDef* meta_graph_def,
const GraphDebugInfo& debug_info,
- bool run_placer_grappler_on_nested_functions = false);
+ bool run_placer_grappler_on_nested_functions = false,
+ bool enable_tfrt_gpu = false);
TfrtSavedModelMLIRImportInput(
const MetaGraphDef* meta_graph_def, const GraphDebugInfo& debug_info,