Move the function-running logic to the graph_executor folder

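This moves `RequestInfo`, `SetUpRequestContext()`, and the body of
`SavedModelImpl::RunInternal()` out of saved_model.cc into the new
graph_executor library, exposing the run logic as
`GraphExecutionRunOnFunction()` so it can be shared beyond SavedModel.

A minimal caller sketch of the new entry point (illustrative only; it
mirrors the updated call site in saved_model.cc and assumes the options,
runtime, fallback state, and deadline tracker are set up the way
SavedModelImpl sets them up):

  // Hypothetical call site; names below mirror SavedModelImpl's members.
  std::vector<tensorflow::Tensor> outputs;
  TF_RETURN_IF_ERROR(tensorflow::tfrt_stub::GraphExecutionRunOnFunction(
      options_.graph_execution_options, run_options, name, *func, inputs,
      captures, &outputs, resource_context, runtime(), *fallback_state_,
      req_deadline_tracker_));
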
PiperOrigin-RevId: 418679819
Change-Id: Ic301c3199b3b3e2a582495f0aa39e5a384127e7f
diff --git a/tensorflow/compiler/mlir/tfrt/BUILD b/tensorflow/compiler/mlir/tfrt/BUILD
index 0c47e92..cb91ddf 100644
--- a/tensorflow/compiler/mlir/tfrt/BUILD
+++ b/tensorflow/compiler/mlir/tfrt/BUILD
@@ -450,6 +450,7 @@
         # copybara:uncomment "//learning/brain/experimental/tfrt/visualization:__pkg__",
         "//tensorflow/compiler/mlir/tfrt/tests/saved_model:__pkg__",
         "//tensorflow/core/tfrt/eager:__pkg__",
+        "//tensorflow/core/tfrt/graph_executor:__pkg__",
         "//tensorflow/core/tfrt/saved_model:__pkg__",
         "@tf_runtime//:__subpackages__",
     ],
diff --git a/tensorflow/core/runtime_fallback/kernel/BUILD b/tensorflow/core/runtime_fallback/kernel/BUILD
index 6b07ea6..ec93f4b 100644
--- a/tensorflow/core/runtime_fallback/kernel/BUILD
+++ b/tensorflow/core/runtime_fallback/kernel/BUILD
@@ -346,6 +346,7 @@
         "//tensorflow/compiler/mlir/tfrt/benchmarks:__pkg__",
         "//tensorflow/core/runtime_fallback:internal",
         "//tensorflow/core/tfrt/eager:__pkg__",
+        "//tensorflow/core/tfrt/graph_executor:__pkg__",
         "//tensorflow/core/tfrt/saved_model:__pkg__",
     ],
     deps = [
diff --git a/tensorflow/core/runtime_fallback/opdefs/BUILD b/tensorflow/core/runtime_fallback/opdefs/BUILD
index 29b056f..7f7cfeb 100644
--- a/tensorflow/core/runtime_fallback/opdefs/BUILD
+++ b/tensorflow/core/runtime_fallback/opdefs/BUILD
@@ -89,6 +89,7 @@
     hdrs = ["tfrt_fallback_util.h"],
     visibility = [
         "//tensorflow/core/runtime_fallback:internal",
+        "//tensorflow/core/tfrt/graph_executor:__pkg__",
         "//tensorflow/core/tfrt/saved_model:__pkg__",
     ],
     deps = [
diff --git a/tensorflow/core/tfrt/graph_executor/BUILD b/tensorflow/core/tfrt/graph_executor/BUILD
index ea2ad7a..3bff67f 100644
--- a/tensorflow/core/tfrt/graph_executor/BUILD
+++ b/tensorflow/core/tfrt/graph_executor/BUILD
@@ -20,3 +20,49 @@
         "@com_google_absl//absl/types:optional",
     ],
 )
+
+cc_library(
+    name = "graph_executor",
+    srcs = ["graph_executor.cc"],
+    hdrs = ["graph_executor.h"],
+    tags = ["no_oss"],
+    deps = [
+        ":graph_execution_options",
+        "//tensorflow/cc/saved_model:reader",
+        "//tensorflow/compiler/mlir/tensorflow:import_model",
+        "//tensorflow/compiler/mlir/tensorflow:translate_lib",
+        "//tensorflow/compiler/mlir/tensorflow:upgrade_graph",
+        "//tensorflow/compiler/mlir/tfrt:tf_cpurt_request_context",
+        "//tensorflow/compiler/mlir/tfrt:tfrt_compile_options",
+        "//tensorflow/core:core_cpu_base",
+        "//tensorflow/core:lib",
+        "//tensorflow/core/common_runtime:core_cpu_internal",
+        "//tensorflow/core/framework:tensor",
+        "//tensorflow/core/platform:errors",
+        "//tensorflow/core/platform:path",
+        "//tensorflow/core/profiler/lib:connected_traceme",
+        "//tensorflow/core/profiler/lib:traceme_encode",
+        "//tensorflow/core/protobuf:for_core_protos_cc",
+        "//tensorflow/core/runtime_fallback/kernel:kernel_fallback_execute_compat",
+        "//tensorflow/core/runtime_fallback/kernel:kernel_fallback_op_handler",
+        "//tensorflow/core/runtime_fallback/opdefs:tfrt_fallback_util",
+        "//tensorflow/core/runtime_fallback/runtime:runtime_fallback_alwayslink",
+        "//tensorflow/core/runtime_fallback/util:tensor_util",
+        "//tensorflow/core/tfrt/fallback:fallback_state",
+        "//tensorflow/core/tfrt/runtime",
+        "//tensorflow/core/tfrt/runtime:work_queue_interface",
+        "//tensorflow/core/tfrt/utils",
+        "//tensorflow/core/tfrt/utils:error_util",
+        "//tensorflow/core/tfrt/utils:fallback_tensor",
+        "//tensorflow/core/tfrt/utils:tfrt_graph_execution_state",
+        "//third_party/eigen3",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/time",
+        "@com_google_absl//absl/types:span",
+        "@tf_runtime//:befexecutor",
+        "@tf_runtime//:core_runtime",
+        "@tf_runtime//:hostcontext",
+        "@tf_runtime//:support",
+        "@tf_runtime//:tensor",
+    ],
+)
diff --git a/tensorflow/core/tfrt/graph_executor/graph_executor.cc b/tensorflow/core/tfrt/graph_executor/graph_executor.cc
new file mode 100644
index 0000000..35c819e
--- /dev/null
+++ b/tensorflow/core/tfrt/graph_executor/graph_executor.cc
@@ -0,0 +1,265 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/tfrt/graph_executor/graph_executor.h"
+
+#include <array>
+#include <cstdint>
+#include <functional>
+#include <memory>
+#include <optional>
+#include <utility>
+#include <vector>
+
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
+#include "absl/types/span.h"
+#include "tensorflow/compiler/mlir/tfrt/jit/tf_cpurt_request_context.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
+#include "tensorflow/core/platform/errors.h"
+#include "tensorflow/core/platform/status.h"
+#include "tensorflow/core/platform/statusor.h"
+#include "tensorflow/core/platform/threadpool_interface.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/profiler/lib/connected_traceme.h"
+#include "tensorflow/core/profiler/lib/traceme_encode.h"
+#include "tensorflow/core/protobuf/config.pb.h"
+#include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_execute_compat.h"
+#include "tensorflow/core/tfrt/fallback/fallback_state.h"
+#include "tensorflow/core/tfrt/graph_executor/graph_execution_options.h"
+#include "tensorflow/core/tfrt/runtime/runtime.h"
+#include "tensorflow/core/tfrt/runtime/work_queue_interface.h"
+#include "tensorflow/core/tfrt/utils/error_util.h"
+#include "tensorflow/core/tfrt/utils/fallback_tensor.h"
+#include "tensorflow/core/tfrt/utils/utils.h"
+#include "tfrt/core_runtime/core_runtime.h"  // from @tf_runtime
+#include "tfrt/host_context/async_dispatch.h"  // from @tf_runtime
+#include "tfrt/host_context/async_value.h"  // from @tf_runtime
+#include "tfrt/host_context/async_value_ref.h"  // from @tf_runtime
+#include "tfrt/host_context/chain.h"  // from @tf_runtime
+#include "tfrt/host_context/concurrent_work_queue.h"  // from @tf_runtime
+#include "tfrt/host_context/execution_context.h"  // from @tf_runtime
+#include "tfrt/host_context/function.h"  // from @tf_runtime
+#include "tfrt/host_context/host_context.h"  // from @tf_runtime
+#include "tfrt/host_context/request_deadline_tracker.h"  // from @tf_runtime
+#include "tfrt/host_context/resource_context.h"  // from @tf_runtime
+#include "tfrt/support/forward_decls.h"  // from @tf_runtime
+#include "tfrt/support/ref_count.h"  // from @tf_runtime
+#include "tfrt/support/string_util.h"  // from @tf_runtime
+
+namespace tensorflow {
+namespace tfrt_stub {
+namespace {
+
+constexpr char kDeadlineExceededMessage[] = "Deadline exceeded.";
+
+}  // namespace
+
+StatusOr<std::unique_ptr<RequestInfo>> SetUpRequestContext(
+    const GraphExecutionRunOptions& run_options,
+    const SessionMetadata& model_metadata, tfrt::HostContext* host,
+    tensorflow::tfrt_stub::WorkQueueInterface* work_queue,
+    tfrt::ResourceContext* resource_context,
+    const tensorflow::tfrt_stub::FallbackState& fallback_state) {
+  DCHECK(host);
+  DCHECK(work_queue);
+  // Create request context and prepare deadline tracker.
+  // TODO(tfrt-devs): Consider using an ID unique within each model to reduce
+  // contention.
+  tfrt::RequestContextBuilder request_context_builder(host, resource_context,
+                                                      tfrt::GetUniqueInt());
+
+  // TODO(b/198671794): `intra_op_threadpool` should be passed through Run()
+  // directly.
+  tensorflow::thread::ThreadPoolInterface* intra_op_threadpool = nullptr;
+
+  // TODO(b/198671794): The per-request queue should be passed through Run()
+  // directly.
+  TF_ASSIGN_OR_RETURN(auto request_queue,
+                      work_queue->InitializeRequest(&request_context_builder,
+                                                    &intra_op_threadpool));
+
+  auto request_info = std::make_unique<RequestInfo>();
+
+  // If a per-request queue is not provided, use the original queue in the
+  // tensorflow::Executor::Args::Runner.
+  auto* inter_op_queue = request_queue ? request_queue.get() : work_queue;
+  request_info->runner = [inter_op_queue](std::function<void()> f) {
+    inter_op_queue->AddTask(std::move(f));
+  };
+
+  request_info->request_queue = std::move(request_queue);
+
+  TF_RETURN_IF_ERROR(tensorflow::tfd::SetUpKernelFallbackCompatRequestContext(
+      &request_context_builder, &fallback_state.device_manager(),
+      &fallback_state.process_function_library_runtime(), intra_op_threadpool,
+      model_metadata, &request_info->runner));
+
+  TF_RETURN_IF_ERROR(
+      tensorflow::SetUpTfCpuRtRequestContext(&request_context_builder));
+  tfrt::RequestOptions request_options;
+  request_options.priority = run_options.priority;
+  request_context_builder.set_request_options(request_options);
+
+  auto expected_req_ctx = std::move(request_context_builder).build();
+  if (!expected_req_ctx) {
+    return tensorflow::errors::Internal(
+        tfrt::StrCat(expected_req_ctx.takeError()));
+  }
+
+  request_info->tfrt_request_context = std::move(expected_req_ctx.get());
+
+  return request_info;
+}
+
+tensorflow::Status GraphExecutionRunOnFunction(
+    const GraphExecutionOptions& options,
+    const GraphExecutionRunOptions& run_options,
+    absl::string_view signature_name, const tfrt::Function& func,
+    absl::Span<const tensorflow::Tensor> inputs,
+    absl::Span<const tensorflow::Tensor> captures,
+    std::vector<tensorflow::Tensor>* outputs,
+    tfrt::ResourceContext* resource_context, const Runtime& runtime,
+    const FallbackState& fallback_state,
+    tfrt::RequestDeadlineTracker& req_deadline_tracker) {
+  auto* host = runtime.core_runtime()->GetHostContext();
+
+  TF_ASSIGN_OR_RETURN(
+      auto request_info,
+      SetUpRequestContext(run_options, options.model_metadata, host,
+                          run_options.work_queue ? run_options.work_queue
+                                                 : runtime.work_queue(),
+                          resource_context, fallback_state));
+
+  tensorflow::profiler::TraceMeProducer traceme(
+      // To TraceMeConsumers in RunHandlerThreadPool::WorkerLoop.
+      [request_id = request_info->tfrt_request_context->id(), signature_name,
+       options] {
+        return tensorflow::profiler::TraceMeEncode(
+            "TfrtModelRun",
+            {{"_r", 1},
+             {"id", request_id},
+             {"signature", signature_name},
+             {"model_id", absl::StrCat(options.model_metadata.name(),
+                                       options.model_metadata.version())}});
+      },
+      tensorflow::profiler::ContextType::kTfrtExecutor,
+      request_info->tfrt_request_context->id());
+
+  // Only configure timer when the deadline is set.
+  if (run_options.deadline.has_value()) {
+    auto deadline = run_options.deadline.value();
+    if (absl::ToChronoTime(absl::Now()) > deadline) {
+      return tensorflow::errors::DeadlineExceeded(kDeadlineExceededMessage);
+    }
+    req_deadline_tracker.CancelRequestOnDeadline(
+        deadline, request_info->tfrt_request_context);
+  }
+
+  tfrt::ExecutionContext exec_ctx{request_info->tfrt_request_context};
+  if (run_options.work_queue) {
+    // TODO(b/198671794): Avoid creating `request_queue` when the `work_queue`
+    // in `run_options` is specified.
+    exec_ctx.set_work_queue(run_options.work_queue);
+  } else if (request_info->request_queue) {
+    exec_ctx.set_work_queue(request_info->request_queue.get());
+  } else {
+    exec_ctx.set_work_queue(runtime.work_queue());
+  }
+
+  llvm::SmallVector<tfrt::AsyncValue*, 4> arguments;
+  auto cleanup = tensorflow::gtl::MakeCleanup([&]() {
+    for (auto* argument : arguments) argument->DropRef();
+  });
+
+  // The first argument is a chain for side-effects. Since this function only
+  // returns after side-effects are visible, we can use a ready chain here.
+  arguments.push_back(tfrt::GetReadyChain().release());
+
+  for (const auto& input : inputs) {
+    arguments.push_back(
+        tfrt::MakeAvailableAsyncValueRef<FallbackTensor>(input).release());
+  }
+
+  DCHECK(captures.empty()) << "signature should have no captures, which is "
+                              "guaranteed by the compiler";
+
+  if (arguments.size() != func.argument_types().size())
+    return tensorflow::errors::Internal("incorrect number of inputs.");
+
+  llvm::SmallVector<tfrt::RCReference<tfrt::AsyncValue>, 4> chain_and_results;
+  chain_and_results.resize(func.result_types().size());
+
+  // Hand over the execution to thread pool.
+  std::array<tfrt::RCReference<tfrt::AsyncValue>, 1> executed = {
+      EnqueueWork(exec_ctx, [&]() -> tfrt::Chain {
+        func.Execute(exec_ctx, arguments, chain_and_results);
+        return {};
+      })};
+
+  // Wait for the function execution before checking chain and results.
+  exec_ctx.work_queue().Await(executed);
+
+  // Wait for all results, including the side-effect chain. This ensures that
+  // all side-effects are visible when this function returns.
+  exec_ctx.work_queue().Await(chain_and_results);
+
+  DCHECK(!chain_and_results.empty());
+
+  tfrt::RCReference<tfrt::AsyncValue>& chain = chain_and_results[0];
+  auto results = llvm::drop_begin(chain_and_results, 1);
+
+  tensorflow::StatusGroup status_group;
+
+  if (chain->IsError()) {
+    status_group.Update(CreateTfErrorStatus(chain->GetError()));
+  }
+
+  for (tfrt::RCReference<tfrt::AsyncValue>& result : results) {
+    DCHECK(result->IsAvailable());
+
+    if (result->IsError()) {
+      status_group.Update(CreateTfErrorStatus(result->GetError()));
+      outputs->push_back(tensorflow::Tensor());
+      continue;
+    }
+
+    // The result must be a host tensor. This is guaranteed as the compiler
+    // will insert necessary device transfer operations in the graph.
+    DCHECK(result->IsType<FallbackTensor>());
+    const auto& host_tensor = result->get<FallbackTensor>().tensor();
+    // Make a copy of tensor here as the different result AsyncValues might
+    // point to the same underlying tensor.
+    outputs->push_back(host_tensor);
+  }
+
+  // TODO(b/171926578): Explicitly clear the context data. Remove this once
+  // b/171926578 is fixed.
+  exec_ctx.request_ctx()->ClearData();
+
+  // Check if error is due to cancellation.
+  // TODO(tfrt-devs): report cancellation reason from runtime.
+  if (request_info->tfrt_request_context->IsCancelled()) {
+    // Currently a request can only be cancelled by an expired timer.
+    return tensorflow::errors::DeadlineExceeded(kDeadlineExceededMessage);
+  }
+
+  return status_group.as_summary_status();
+}
+
+}  // namespace tfrt_stub
+}  // namespace tensorflow
diff --git a/tensorflow/core/tfrt/graph_executor/graph_executor.h b/tensorflow/core/tfrt/graph_executor/graph_executor.h
new file mode 100644
index 0000000..d46d82d
--- /dev/null
+++ b/tensorflow/core/tfrt/graph_executor/graph_executor.h
@@ -0,0 +1,63 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_TFRT_GRAPH_EXECUTOR_GRAPH_EXECUTOR_H_
+#define TENSORFLOW_CORE_TFRT_GRAPH_EXECUTOR_GRAPH_EXECUTOR_H_
+
+#include <functional>
+#include <memory>
+
+#include "tensorflow/core/protobuf/config.pb.h"
+#include "tensorflow/core/tfrt/fallback/fallback_state.h"
+#include "tensorflow/core/tfrt/graph_executor/graph_execution_options.h"
+#include "tensorflow/core/tfrt/runtime/work_queue_interface.h"
+#include "tfrt/host_context/execution_context.h"  // from @tf_runtime
+#include "tfrt/host_context/function.h"  // from @tf_runtime
+#include "tfrt/host_context/request_deadline_tracker.h"  // from @tf_runtime
+#include "tfrt/support/ref_count.h"  // from @tf_runtime
+
+namespace tensorflow {
+namespace tfrt_stub {
+
+// Contains request related info.
+struct RequestInfo {
+  tfrt::RCReference<tfrt::RequestContext> tfrt_request_context;
+  std::unique_ptr<WorkQueueInterface> request_queue;
+  std::function<void(std::function<void()>)> runner;
+};
+
+// Creates a `RequestInfo` from the given run options and execution state.
+StatusOr<std::unique_ptr<RequestInfo>> SetUpRequestContext(
+    const GraphExecutionRunOptions& run_options,
+    const SessionMetadata& model_metadata, tfrt::HostContext* host,
+    tensorflow::tfrt_stub::WorkQueueInterface* work_queue,
+    tfrt::ResourceContext* resource_context,
+    const FallbackState& fallback_state);
+
+// Runs the given function with the provided inputs, filling `outputs`.
+tensorflow::Status GraphExecutionRunOnFunction(
+    const GraphExecutionOptions& options,
+    const GraphExecutionRunOptions& run_options,
+    absl::string_view signature_name, const tfrt::Function& func,
+    absl::Span<const tensorflow::Tensor> inputs,
+    absl::Span<const tensorflow::Tensor> captures,
+    std::vector<tensorflow::Tensor>* outputs,
+    tfrt::ResourceContext* resource_context, const Runtime& runtime,
+    const FallbackState& fallback_state,
+    tfrt::RequestDeadlineTracker& req_deadline_tracker);
+
+}  // namespace tfrt_stub
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_TFRT_GRAPH_EXECUTOR_GRAPH_EXECUTOR_H_
diff --git a/tensorflow/core/tfrt/saved_model/BUILD b/tensorflow/core/tfrt/saved_model/BUILD
index 925cc14..8276a02 100644
--- a/tensorflow/core/tfrt/saved_model/BUILD
+++ b/tensorflow/core/tfrt/saved_model/BUILD
@@ -56,6 +56,7 @@
         "//tensorflow/core/runtime_fallback/runtime:runtime_fallback_alwayslink",
         "//tensorflow/core/runtime_fallback/util:tensor_util",
         "//tensorflow/core/tfrt/fallback:fallback_state",
+        "//tensorflow/core/tfrt/graph_executor",
         "//tensorflow/core/tfrt/graph_executor:graph_execution_options",
         "//tensorflow/core/tfrt/runtime",
         "//tensorflow/core/tfrt/runtime:work_queue_interface",
diff --git a/tensorflow/core/tfrt/saved_model/saved_model.cc b/tensorflow/core/tfrt/saved_model/saved_model.cc
index d424d8b..e45d37a 100644
--- a/tensorflow/core/tfrt/saved_model/saved_model.cc
+++ b/tensorflow/core/tfrt/saved_model/saved_model.cc
@@ -51,6 +51,7 @@
 #include "tensorflow/core/tfrt/saved_model/saved_model_import_input.h"
 #include "tensorflow/core/tfrt/tpu/tpu_resources.h"
 // TODO(b/200579737): using FunctionRegistry is simpler than the OSS trick.
+#include "tensorflow/core/tfrt/graph_executor/graph_executor.h"
 #include "tensorflow/core/tfrt/utils/bridge_graph_analysis.h"
 #include "tensorflow/core/tfrt/utils/error_util.h"
 #include "tensorflow/core/tfrt/utils/fallback_tensor.h"
@@ -119,74 +120,10 @@
         "/tensorflow/tfrt/saved_model/init_time",
         "Record the initialization time for the savedmodel.", "model_name");
 
-constexpr char kDeadlineExceededMessage[] = "Deadline exceeded.";
-
 tensorflow::Tensor CreateScalarStringTensor(absl::string_view str) {
   return tensorflow::Tensor(tensorflow::tstring(str));
 }
 
-struct RequestInfo {
-  tfrt::RCReference<tfrt::RequestContext> tfrt_request_context;
-  std::unique_ptr<WorkQueueInterface> request_queue;
-  std::function<void(std::function<void()>)> runner;
-};
-
-StatusOr<std::unique_ptr<RequestInfo>> SetUpRequestContext(
-    const SavedModel::RunOptions& run_options,
-    const SessionMetadata& model_metadata, tfrt::HostContext* host,
-    WorkQueueInterface* work_queue, tfrt::ResourceContext* resource_context,
-    const FallbackState& fallback_state) {
-  DCHECK(host);
-  DCHECK(work_queue);
-  // Create request context and prepare deadline tracker.
-  // TODO(tfrt-devs): Consider using an ID unique within each model to reduce
-  // contention.
-  tfrt::RequestContextBuilder request_context_builder(host, resource_context,
-                                                      tfrt::GetUniqueInt());
-
-  // TODO(b/198671794): `intra_op_threadpool` should be passed through Run()
-  // directly.
-  tensorflow::thread::ThreadPoolInterface* intra_op_threadpool = nullptr;
-
-  // TODO(b/198671794): The per-request queue should be passed through Run()
-  // directly.
-  TF_ASSIGN_OR_RETURN(auto request_queue,
-                      work_queue->InitializeRequest(&request_context_builder,
-                                                    &intra_op_threadpool));
-
-  auto request_info = std::make_unique<RequestInfo>();
-
-  // If a per-request queue is not provided, use the original queue in the
-  // tensorflow::Executor::Args::Runner.
-  auto* inter_op_queue = request_queue ? request_queue.get() : work_queue;
-  request_info->runner = [inter_op_queue](std::function<void()> f) {
-    inter_op_queue->AddTask(std::move(f));
-  };
-
-  request_info->request_queue = std::move(request_queue);
-
-  TF_RETURN_IF_ERROR(tensorflow::tfd::SetUpKernelFallbackCompatRequestContext(
-      &request_context_builder, &fallback_state.device_manager(),
-      &fallback_state.process_function_library_runtime(), intra_op_threadpool,
-      model_metadata, &request_info->runner));
-
-  TF_RETURN_IF_ERROR(
-      tensorflow::SetUpTfCpuRtRequestContext(&request_context_builder));
-  tfrt::RequestOptions request_options;
-  request_options.priority = run_options.priority;
-  request_context_builder.set_request_options(request_options);
-
-  auto expected_req_ctx = std::move(request_context_builder).build();
-  if (!expected_req_ctx) {
-    return tensorflow::errors::Internal(
-        tfrt::StrCat(expected_req_ctx.takeError()));
-  }
-
-  request_info->tfrt_request_context = std::move(expected_req_ctx.get());
-
-  return request_info;
-}
-
 // Create the tensor for the bound input, which can be a variable or an asset.
 //
 // TODO(chky): For V2 models, the bound input can also be a resource.
@@ -277,7 +214,8 @@
       RunRuntimeInitializer(exec_ctx, bef_file, "_tfrt_fallback_init"));
 
   for (const auto& init : initializers_and_signatures.initializers) {
-    // TODO(b/184771263): Consider using `RunInternal()` instead.
+    // TODO(b/184771263): Consider using `GraphExecutionRunOnFunction()`
+    // instead.
 
     auto* func = bef_file->GetFunction(init);
     assert(func);
@@ -762,8 +700,10 @@
   }
   DCHECK(func);
 
-  return RunInternal(run_options, name, *func, inputs, captures, outputs,
-                     resource_context);
+  return GraphExecutionRunOnFunction(options_.graph_execution_options,
+                                     run_options, name, *func, inputs, captures,
+                                     outputs, resource_context, runtime(),
+                                     *fallback_state_, req_deadline_tracker_);
 }
 
 namespace {
@@ -900,9 +840,11 @@
 
   std::vector<tensorflow::Tensor> flat_outputs;
 
-  TF_RETURN_IF_ERROR(RunInternal(run_options, loading_result.name, *func,
-                                 flat_inputs, /*captures=*/{}, &flat_outputs,
-                                 loading_result.resource_context.get()));
+  TF_RETURN_IF_ERROR(GraphExecutionRunOnFunction(
+      options_.graph_execution_options, run_options, loading_result.name, *func,
+      flat_inputs, /*captures=*/{}, &flat_outputs,
+      loading_result.resource_context.get(), runtime(), *fallback_state_,
+      req_deadline_tracker_));
 
   // The outputs of the compiled function are in the user-specified order,
   // though they are flattened. So we just need to regroup the outputs for each
@@ -1012,9 +954,11 @@
   }
 
   std::vector<tensorflow::Tensor> flat_outputs;
-  TF_RETURN_IF_ERROR(RunInternal(
-      run_options, loading_result.name, *func, flat_inputs,
-      /*captures=*/{}, &flat_outputs, loading_result.resource_context.get()));
+  TF_RETURN_IF_ERROR(GraphExecutionRunOnFunction(
+      options_.graph_execution_options, run_options, loading_result.name, *func,
+      flat_inputs,
+      /*captures=*/{}, &flat_outputs, loading_result.resource_context.get(),
+      runtime(), *fallback_state_, req_deadline_tracker_));
 
   // Create the outputs from the actual function results, which are sorted
   // according to the output tensor names.
@@ -1204,139 +1148,5 @@
   return LoadJoinedSignature(joined_signature);
 }
 
-tensorflow::Status SavedModelImpl::RunInternal(
-    const RunOptions& run_options, absl::string_view signature_name,
-    const tfrt::Function& func, absl::Span<const tensorflow::Tensor> inputs,
-    absl::Span<const tensorflow::Tensor> captures,
-    std::vector<tensorflow::Tensor>* outputs,
-    tfrt::ResourceContext* resource_context) {
-  auto* host = runtime().core_runtime()->GetHostContext();
-
-  TF_ASSIGN_OR_RETURN(
-      auto request_info,
-      SetUpRequestContext(run_options,
-                          options_.graph_execution_options.model_metadata, host,
-                          run_options.work_queue ? run_options.work_queue
-                                                 : runtime().work_queue(),
-                          resource_context, *fallback_state_));
-
-  tensorflow::profiler::TraceMeProducer traceme(
-      // To TraceMeConsumers in RunHandlerThreadPool::WorkerLoop.
-      [request_id = request_info->tfrt_request_context->id(), signature_name,
-       this] {
-        return tensorflow::profiler::TraceMeEncode(
-            "TfrtModelRun",
-            {{"_r", 1},
-             {"id", request_id},
-             {"signature", signature_name},
-             {"model_id",
-              absl::StrCat(
-                  options_.graph_execution_options.model_metadata.name(),
-                  options_.graph_execution_options.model_metadata.version())}});
-      },
-      tensorflow::profiler::ContextType::kTfrtExecutor,
-      request_info->tfrt_request_context->id());
-
-  // Only configure timer when the deadline is set.
-  if (run_options.deadline.has_value()) {
-    auto deadline = run_options.deadline.value();
-    if (absl::ToChronoTime(absl::Now()) > deadline) {
-      return tensorflow::errors::DeadlineExceeded(kDeadlineExceededMessage);
-    }
-    req_deadline_tracker_.CancelRequestOnDeadline(
-        deadline, request_info->tfrt_request_context);
-  }
-
-  tfrt::ExecutionContext exec_ctx{request_info->tfrt_request_context};
-  if (run_options.work_queue) {
-    // TODO(b/198671794): Avoid creating `request_queue` when the `work_queue`
-    // in `run_options` is specified.
-    exec_ctx.set_work_queue(run_options.work_queue);
-  } else if (request_info->request_queue) {
-    exec_ctx.set_work_queue(request_info->request_queue.get());
-  } else {
-    exec_ctx.set_work_queue(runtime().work_queue());
-  }
-
-  llvm::SmallVector<tfrt::AsyncValue*, 4> arguments;
-  auto cleanup = tensorflow::gtl::MakeCleanup([&]() {
-    for (auto* argument : arguments) argument->DropRef();
-  });
-
-  // The first argument is a chain for side-effects. Since SavedModel::Run()
-  // only returns when side-effects are visible, we can use a ready chain here.
-  arguments.push_back(tfrt::GetReadyChain().release());
-
-  for (const auto& input : inputs) {
-    arguments.push_back(
-        tfrt::MakeAvailableAsyncValueRef<FallbackTensor>(input).release());
-  }
-
-  DCHECK(captures.empty()) << "signature should have no captures, which is "
-                              "guaranteed by the compiler";
-
-  if (arguments.size() != func.argument_types().size())
-    return tensorflow::errors::Internal("incorrect number of inputs.");
-
-  llvm::SmallVector<tfrt::RCReference<tfrt::AsyncValue>, 4> chain_and_results;
-  chain_and_results.resize(func.result_types().size());
-
-  // Hand over the execution to thread pool.
-  std::array<tfrt::RCReference<tfrt::AsyncValue>, 1> executed = {
-      EnqueueWork(exec_ctx, [&]() -> tfrt::Chain {
-        func.Execute(exec_ctx, arguments, chain_and_results);
-        return {};
-      })};
-
-  // Wait for the function execution before checking chain and results.
-  exec_ctx.work_queue().Await(executed);
-
-  // Wait for all results including the side-effect chain. This ensures that all
-  // side-effects are visible when SavedModel::Run() returns.
-  exec_ctx.work_queue().Await(chain_and_results);
-
-  DCHECK(!chain_and_results.empty());
-
-  tfrt::RCReference<tfrt::AsyncValue>& chain = chain_and_results[0];
-  auto results = llvm::drop_begin(chain_and_results, 1);
-
-  tensorflow::StatusGroup status_group;
-
-  if (chain->IsError()) {
-    status_group.Update(CreateTfErrorStatus(chain->GetError()));
-  }
-
-  for (tfrt::RCReference<tfrt::AsyncValue>& result : results) {
-    DCHECK(result->IsAvailable());
-
-    if (result->IsError()) {
-      status_group.Update(CreateTfErrorStatus(result->GetError()));
-      outputs->push_back(tensorflow::Tensor());
-      continue;
-    }
-
-    // The result must be a host tensor. This is guaranteed as the compiler
-    // will insert necessary device transfer operations in the graph.
-    DCHECK(result->IsType<FallbackTensor>());
-    const auto& host_tensor = result->get<FallbackTensor>().tensor();
-    // Make a copy of tensor here as the different result AsyncValues might
-    // point to the same underlying tensor.
-    outputs->push_back(host_tensor);
-  }
-
-  // TODO(b/171926578): Explicitly clear the context data. Remove it after the
-  // b/171926578 is fixed.
-  exec_ctx.request_ctx()->ClearData();
-
-  // Check if error is due to cancellation.
-  // TODO(tfrt-devs): report cancellation reason from runtime.
-  if (request_info->tfrt_request_context->IsCancelled()) {
-    // Currently a request can only be cancelled by an expired timer.
-    return tensorflow::errors::DeadlineExceeded(kDeadlineExceededMessage);
-  }
-
-  return status_group.as_summary_status();
-}
-
 }  // namespace tfrt_stub
 }  // namespace tensorflow
diff --git a/tensorflow/core/tfrt/tpu/BUILD b/tensorflow/core/tfrt/tpu/BUILD
index 8f63c6d..d20c1f5 100644
--- a/tensorflow/core/tfrt/tpu/BUILD
+++ b/tensorflow/core/tfrt/tpu/BUILD
@@ -8,6 +8,7 @@
     name = "tpu_resources",
     hdrs = ["tpu_resources.h"],
     visibility = [
+        "//tensorflow/core/tfrt/graph_executor:__pkg__",
         "//tensorflow/core/tfrt/saved_model:__pkg__",
     ],
     deps = if_google([