[xla:runtime] NFC: Clean up jitrt <-> xla namespace re-export

Remove the xla.h re-export header and replace all uses of ::tfrt::jitrt symbols with their ::xla::runtime equivalents.
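
For example (illustrative only, using one of the symbols touched by this change), call sites that previously pulled types from the jitrt namespace now use xla::runtime directly:

  // Before:
  using ::tfrt::jitrt::MemrefDesc;
  // After:
  using ::xla::runtime::MemrefDesc;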

PiperOrigin-RevId: 467853273
diff --git a/tensorflow/compiler/mlir/tfrt/BUILD b/tensorflow/compiler/mlir/tfrt/BUILD
index f2336b0..2435108 100644
--- a/tensorflow/compiler/mlir/tfrt/BUILD
+++ b/tensorflow/compiler/mlir/tfrt/BUILD
@@ -170,9 +170,12 @@
 #     srcs = ["jit/tf_jitrt_test.cc"],
 #     deps = [
 #         ":tf_jitrt",
+#         "//tensorflow/compiler/xla/runtime:results",
+#         "//tensorflow/compiler/xla/runtime:types",
 #         "@com_google_googletest//:gtest_main",
 #         "@llvm-project//mlir:mlir_c_runner_utils",
 #         "@tf_runtime//backends/jitrt",
+#         "@tf_runtime//backends/jitrt:results",
 #     ],
 # )
 # copybara:uncomment_end
@@ -204,7 +207,11 @@
         "//tensorflow/compiler/mlir/tensorflow:dump_mlir_util",
         "//tensorflow/compiler/mlir/tfrt/jit/transforms:tf_jitrt_passes",
         "//tensorflow/compiler/xla/mlir/utils/runtime:async_runtime_api",
+        "//tensorflow/compiler/xla/runtime:arguments",
         "//tensorflow/compiler/xla/runtime:async_runtime",
+        "//tensorflow/compiler/xla/runtime:executable",
+        "//tensorflow/compiler/xla/runtime:jit_executable",
+        "//tensorflow/compiler/xla/runtime:types",
         "//tensorflow/core:framework",
         "//tensorflow/core:platform_base",
         "//tensorflow/core/platform:dynamic_annotations",
@@ -224,6 +231,7 @@
         "@tf_runtime//:tracing",
         "@tf_runtime//backends/jitrt",
         "@tf_runtime//backends/jitrt:jitrt_compiler",
+        "@tf_runtime//backends/jitrt:results",
     ],
 )
 
diff --git a/tensorflow/compiler/mlir/tfrt/benchmarks/BUILD b/tensorflow/compiler/mlir/tfrt/benchmarks/BUILD
index cff9311..72ba919 100644
--- a/tensorflow/compiler/mlir/tfrt/benchmarks/BUILD
+++ b/tensorflow/compiler/mlir/tfrt/benchmarks/BUILD
@@ -17,6 +17,9 @@
     deps = [
         "//tensorflow/compiler/mlir/tensorflow",
         "//tensorflow/compiler/mlir/tfrt:tf_jitrt_pipeline",
+        "//tensorflow/compiler/xla/runtime:arguments",
+        "//tensorflow/compiler/xla/runtime:jit_executable",
+        "//tensorflow/compiler/xla/runtime:types",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
         "//tensorflow/core/platform:logging",
@@ -30,8 +33,8 @@
         "@tf_runtime//:hostcontext",
         "@tf_runtime//:support",
         "@tf_runtime//:tensor",
-        "@tf_runtime//backends/jitrt",
         "@tf_runtime//backends/jitrt:jitrt_compiler",
+        "@tf_runtime//backends/jitrt:results",
     ],
 )
 
diff --git a/tensorflow/compiler/mlir/tfrt/benchmarks/benchmark.cc b/tensorflow/compiler/mlir/tfrt/benchmarks/benchmark.cc
index 6b4ac21..889e172 100644
--- a/tensorflow/compiler/mlir/tfrt/benchmarks/benchmark.cc
+++ b/tensorflow/compiler/mlir/tfrt/benchmarks/benchmark.cc
@@ -32,7 +32,7 @@
 
 using ::tfrt::HostContext;
 using ::tfrt::jitrt::CompilationPipelineOptions;
-using ::tfrt::jitrt::MemrefType;
+using ::xla::runtime::MemrefType;
 
 const bool kStaticDim = false;
 const bool kDynamicDim = true;
diff --git a/tensorflow/compiler/mlir/tfrt/benchmarks/benchmark.h b/tensorflow/compiler/mlir/tfrt/benchmarks/benchmark.h
index 42ba84d..7344a93 100644
--- a/tensorflow/compiler/mlir/tfrt/benchmarks/benchmark.h
+++ b/tensorflow/compiler/mlir/tfrt/benchmarks/benchmark.h
@@ -27,8 +27,11 @@
 #include "llvm/Support/MemoryBuffer.h"
 #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
 #include "tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_pipeline.h"
+#include "tensorflow/compiler/xla/runtime/arguments.h"
+#include "tensorflow/compiler/xla/runtime/jit_executable.h"
+#include "tensorflow/compiler/xla/runtime/types.h"
 #include "tensorflow/core/platform/test_benchmark.h"
-#include "tfrt/jitrt/jitrt.h"  // from @tf_runtime
+#include "tfrt/jitrt/results.h"  // from @tf_runtime
 #include "tfrt/dtype/dtype.h"  // from @tf_runtime
 #include "tfrt/host_context/host_context.h"  // from @tf_runtime
 #include "tfrt/tensor/dense_host_tensor.h"  // from @tf_runtime
@@ -42,9 +45,9 @@
 
 using ::tfrt::HostContext;
 using ::tfrt::RemainingResults;
-using ::tfrt::jitrt::JitExecutable;
-using ::tfrt::jitrt::MemrefDesc;
-using ::tfrt::jitrt::Type;
+using ::xla::runtime::JitExecutable;
+using ::xla::runtime::MemrefDesc;
+using ::xla::runtime::Type;
 
 // Constants to make shape specification more readable.
 // kStaticDim refers to the static shape in IR taken from ARGS of the benchmark.
diff --git a/tensorflow/compiler/mlir/tfrt/benchmarks/benchmark_mlir_function.cc b/tensorflow/compiler/mlir/tfrt/benchmarks/benchmark_mlir_function.cc
index 4d1c250..5b01a98 100644
--- a/tensorflow/compiler/mlir/tfrt/benchmarks/benchmark_mlir_function.cc
+++ b/tensorflow/compiler/mlir/tfrt/benchmarks/benchmark_mlir_function.cc
@@ -39,12 +39,13 @@
 using ::tfrt::RequestContext;
 using ::tfrt::RequestContextBuilder;
 
-using ::tfrt::jitrt::Executable;
-using ::tfrt::jitrt::HostContextAsyncTaskRunner;
-using ::tfrt::jitrt::JitExecutable;
-using ::tfrt::jitrt::MemrefDesc;
 using ::tfrt::jitrt::RemainingResultsConverter;
 
+using ::xla::runtime::Executable;
+using ::xla::runtime::HostContextAsyncTaskRunner;
+using ::xla::runtime::JitExecutable;
+using ::xla::runtime::MemrefDesc;
+
 // Returns random tensors generated based on the input specs.
 static llvm::SmallVector<Tensor> GetInputTensors(
     llvm::ArrayRef<InputTensorSpec> input_specs) {
diff --git a/tensorflow/compiler/mlir/tfrt/benchmarks/cwise_op_unary_benchmark.h b/tensorflow/compiler/mlir/tfrt/benchmarks/cwise_op_unary_benchmark.h
index fc74378..e2576a7 100644
--- a/tensorflow/compiler/mlir/tfrt/benchmarks/cwise_op_unary_benchmark.h
+++ b/tensorflow/compiler/mlir/tfrt/benchmarks/cwise_op_unary_benchmark.h
@@ -35,12 +35,14 @@
 using ::tfrt::RemainingResults;
 using ::tfrt::RequestContext;
 using ::tfrt::RequestContextBuilder;
-using ::tfrt::jitrt::Executable;
-using ::tfrt::jitrt::HostContextAsyncTaskRunner;
-using ::tfrt::jitrt::JitExecutable;
-using ::tfrt::jitrt::MemrefDesc;
+
 using ::tfrt::jitrt::RemainingResultsConverter;
 
+using ::xla::runtime::Executable;
+using ::xla::runtime::HostContextAsyncTaskRunner;
+using ::xla::runtime::JitExecutable;
+using ::xla::runtime::MemrefDesc;
+
 // -------------------------------------------------------------------------- //
 // Run benchmark by compiling MLIR function using TFRT JitRt API.
 // -------------------------------------------------------------------------- //
diff --git a/tensorflow/compiler/mlir/tfrt/benchmarks/matmul_op_benchmark.h b/tensorflow/compiler/mlir/tfrt/benchmarks/matmul_op_benchmark.h
index 96da5f4..817e115 100644
--- a/tensorflow/compiler/mlir/tfrt/benchmarks/matmul_op_benchmark.h
+++ b/tensorflow/compiler/mlir/tfrt/benchmarks/matmul_op_benchmark.h
@@ -34,11 +34,11 @@
 using ::tfrt::RemainingResults;
 using ::tfrt::RequestContext;
 using ::tfrt::RequestContextBuilder;
-using ::tfrt::jitrt::Executable;
-using ::tfrt::jitrt::HostContextAsyncTaskRunner;
-using ::tfrt::jitrt::JitExecutable;
-using ::tfrt::jitrt::MemrefDesc;
 using ::tfrt::jitrt::RemainingResultsConverter;
+using ::xla::runtime::Executable;
+using ::xla::runtime::HostContextAsyncTaskRunner;
+using ::xla::runtime::JitExecutable;
+using ::xla::runtime::MemrefDesc;
 
 // -------------------------------------------------------------------------- //
 // Run benchmark by compiling MLIR function using TFRT JitRt API.
diff --git a/tensorflow/compiler/mlir/tfrt/jit/python_binding/BUILD b/tensorflow/compiler/mlir/tfrt/jit/python_binding/BUILD
index ec5143b..f80be60 100644
--- a/tensorflow/compiler/mlir/tfrt/jit/python_binding/BUILD
+++ b/tensorflow/compiler/mlir/tfrt/jit/python_binding/BUILD
@@ -43,6 +43,8 @@
         "//tensorflow/compiler/mlir/tensorflow",
         "//tensorflow/compiler/mlir/tfrt:tf_jitrt_pipeline",
         "//tensorflow/compiler/mlir/tfrt/python_tests:python_test_attrs_registration",
+        "//tensorflow/compiler/xla/runtime:executable",
+        "//tensorflow/compiler/xla/runtime:jit_executable",
         "//tensorflow/core/platform:dynamic_annotations",
         "//third_party/eigen3",
         "//third_party/python_runtime:headers",  # build_cleaner: keep
@@ -55,6 +57,7 @@
         "@tf_runtime//:support",
         "@tf_runtime//backends/jitrt",
         "@tf_runtime//backends/jitrt:jitrt_compiler",
+        "@tf_runtime//backends/jitrt:results",
     ],
 )
 
@@ -94,6 +97,5 @@
     deps = [
         "//tensorflow/compiler/xla/runtime:arguments",
         "@tf_runtime//:dtype",
-        "@tf_runtime//backends/jitrt:xla",
     ],
 )
diff --git a/tensorflow/compiler/mlir/tfrt/jit/python_binding/tf_jitrt_executor.cc b/tensorflow/compiler/mlir/tfrt/jit/python_binding/tf_jitrt_executor.cc
index 76b40d0..029e217 100644
--- a/tensorflow/compiler/mlir/tfrt/jit/python_binding/tf_jitrt_executor.cc
+++ b/tensorflow/compiler/mlir/tfrt/jit/python_binding/tf_jitrt_executor.cc
@@ -30,6 +30,7 @@
 #include "tensorflow/core/platform/dynamic_annotations.h"
 #include "tfrt/jitrt/jitrt.h"  // from @tf_runtime
 #include "tfrt/jitrt/jitrt_compiler.h"  // from @tf_runtime
+#include "tfrt/jitrt/results.h"  // from @tf_runtime
 #include "tfrt/dtype/dtype.h"  // from @tf_runtime
 #include "tfrt/host_context/async_value.h"  // from @tf_runtime
 #include "tfrt/host_context/concurrent_work_queue.h"  // from @tf_runtime
@@ -55,14 +56,15 @@
 
 using ::tfrt::jitrt::CompilationPipelineOptions;
 using ::tfrt::jitrt::CreateDefaultJitRtCompilationPipeline;
-using ::tfrt::jitrt::Executable;
-using ::tfrt::jitrt::HostContextAsyncTaskRunner;
-using ::tfrt::jitrt::JitExecutable;
-using ::tfrt::jitrt::MemrefDesc;
 using ::tfrt::jitrt::RegisterDefaultJitRtDialects;
 using ::tfrt::jitrt::RemainingResultsConverter;
 using ::tfrt::jitrt::ReturnStridedMemref;
 
+using ::xla::runtime::Executable;
+using ::xla::runtime::HostContextAsyncTaskRunner;
+using ::xla::runtime::JitExecutable;
+using ::xla::runtime::MemrefDesc;
+
 namespace tensorflow {
 
 TfJitRtExecutor::TfJitRtExecutor()
diff --git a/tensorflow/compiler/mlir/tfrt/jit/python_binding/tf_jitrt_executor.h b/tensorflow/compiler/mlir/tfrt/jit/python_binding/tf_jitrt_executor.h
index 314e6f3..0288fbd 100644
--- a/tensorflow/compiler/mlir/tfrt/jit/python_binding/tf_jitrt_executor.h
+++ b/tensorflow/compiler/mlir/tfrt/jit/python_binding/tf_jitrt_executor.h
@@ -23,7 +23,8 @@
 #include "pybind11/numpy.h"
 #include "pybind11/pybind11.h"
 #include "pybind11/stl.h"
-#include "tfrt/jitrt/jitrt.h"  // from @tf_runtime
+#include "tensorflow/compiler/xla/runtime/executable.h"
+#include "tensorflow/compiler/xla/runtime/jit_executable.h"
 #include "tfrt/host_context/host_context.h"  // from @tf_runtime
 
 namespace tensorflow {
@@ -34,7 +35,7 @@
 class TfJitRtExecutor {
  public:
   using Handle = int64_t;
-  using Specialization = tfrt::jitrt::JitExecutable::Specialization;
+  using Specialization = xla::runtime::JitExecutable::Specialization;
 
   TfJitRtExecutor();
 
@@ -57,7 +58,7 @@
 
  private:
   tfrt::HostContext host_context_;
-  llvm::DenseMap<Handle, tfrt::jitrt::JitExecutable> jit_executables_;
+  llvm::DenseMap<Handle, xla::runtime::JitExecutable> jit_executables_;
 };
 
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/mlir/tfrt/jit/python_binding/tfrt_fallback.cc b/tensorflow/compiler/mlir/tfrt/jit/python_binding/tfrt_fallback.cc
index 5e0c0ad..6db9221 100644
--- a/tensorflow/compiler/mlir/tfrt/jit/python_binding/tfrt_fallback.cc
+++ b/tensorflow/compiler/mlir/tfrt/jit/python_binding/tfrt_fallback.cc
@@ -35,7 +35,7 @@
 
 namespace py = pybind11;
 
-using ::tfrt::jitrt::MemrefDesc;
+using ::xla::runtime::MemrefDesc;
 
 static py::array ConvertTensorToPyArray(const Tensor& tensor) {
   auto tensor_sizes = tensor.shape().dim_sizes();
diff --git a/tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_kernels.cc b/tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_kernels.cc
index b1feca3..a1d490f 100644
--- a/tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_kernels.cc
+++ b/tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_kernels.cc
@@ -32,7 +32,11 @@
 #include "tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_request_context.h"
 #include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h"
 #include "tensorflow/compiler/xla/mlir/utils/runtime/async_runtime_api.h"
+#include "tensorflow/compiler/xla/runtime/arguments.h"
 #include "tensorflow/compiler/xla/runtime/async_runtime.h"
+#include "tensorflow/compiler/xla/runtime/executable.h"
+#include "tensorflow/compiler/xla/runtime/jit_executable.h"
+#include "tensorflow/compiler/xla/runtime/types.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/platform/dynamic_annotations.h"
@@ -42,6 +46,7 @@
 #include "tensorflow/core/tfrt/utils/fallback_tensor.h"
 #include "tfrt/jitrt/jitrt.h"  // from @tf_runtime
 #include "tfrt/jitrt/jitrt_compiler.h"  // from @tf_runtime
+#include "tfrt/jitrt/results.h"  // from @tf_runtime
 #include "tfrt/dtype/dtype.h"  // from @tf_runtime
 #include "tfrt/host_context/async_dispatch.h"  // from @tf_runtime
 #include "tfrt/host_context/async_value_ref.h"  // from @tf_runtime
@@ -100,22 +105,23 @@
 using ::tfrt::StringAttribute;
 using ::tfrt::TaskFunction;
 
-using ::tfrt::jitrt::ArgumentConstraint;
-using ::tfrt::jitrt::ArgumentsRef;
 using ::tfrt::jitrt::CompilationPipelineOptions;
 using ::tfrt::jitrt::CreateDefaultJitRtCompilationPipeline;
-using ::tfrt::jitrt::EigenThreadPoolAsyncTaskRunner;
-using ::tfrt::jitrt::Executable;
-using ::tfrt::jitrt::JitExecutable;
 using ::tfrt::jitrt::JitExecutableCache;
-using ::tfrt::jitrt::MemrefDesc;
 using ::tfrt::jitrt::RegisterDefaultJitRtDialects;
 using ::tfrt::jitrt::ReturnErrors;
 using ::tfrt::jitrt::ReturnStridedMemref;
 using ::tfrt::jitrt::ReturnValueConversion;
-using ::tfrt::jitrt::SpecializationListener;
 using ::tfrt::jitrt::StaticRemainingResultsConverter;
 
+using ::xla::runtime::ArgumentConstraint;
+using ::xla::runtime::ArgumentsRef;
+using ::xla::runtime::EigenThreadPoolAsyncTaskRunner;
+using ::xla::runtime::Executable;
+using ::xla::runtime::JitExecutable;
+using ::xla::runtime::MemrefDesc;
+using ::xla::runtime::SpecializationListener;
+
 using ::tensorflow::profiler::TraceMe;
 using ::tensorflow::profiler::TraceMeEncode;
 using ::tensorflow::tfd::KernelFallbackCompatRequestState;
diff --git a/tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_test.cc b/tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_test.cc
index e294e88..c9981e0 100644
--- a/tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_test.cc
+++ b/tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_test.cc
@@ -22,6 +22,9 @@
 #include "mlir/ExecutionEngine/CRunnerUtils.h"
 #include "testing/base/public/benchmark.h"
 #include <gtest/gtest.h>
+#include "tensorflow/compiler/xla/runtime/results.h"
+#include "tensorflow/compiler/xla/runtime/types.h"
+#include "tfrt/jitrt/results.h"  // from @tf_runtime
 
 namespace tensorflow {
 
@@ -30,11 +33,12 @@
 using ::tfrt::RCReference;
 using ::tfrt::RemainingResults;
 
-using ::tfrt::jitrt::MemrefType;
 using ::tfrt::jitrt::ReturnStridedMemref;
 using ::tfrt::jitrt::ReturnValueConversion;
 using ::tfrt::jitrt::StaticRemainingResultsConverter;
 
+using ::xla::runtime::MemrefType;
+
 using ReturnTensorflowTensor =
     ReturnValueConversion<TensorflowConversionContext,
                           ReturnStridedMemref<ConvertTensor>>;
diff --git a/tensorflow/compiler/mlir/tfrt/tests/jit/BUILD b/tensorflow/compiler/mlir/tfrt/tests/jit/BUILD
index ad84270..8310c25 100644
--- a/tensorflow/compiler/mlir/tfrt/tests/jit/BUILD
+++ b/tensorflow/compiler/mlir/tfrt/tests/jit/BUILD
@@ -29,6 +29,8 @@
     deps = [
         "//tensorflow/compiler/mlir/tensorflow",
         "//tensorflow/compiler/mlir/tfrt:tf_jitrt_pipeline",
+        "//tensorflow/compiler/xla/runtime:executable",
+        "//tensorflow/compiler/xla/runtime:jit_executable",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
         "@llvm-project//llvm:Support",
diff --git a/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_benchmark_test.cc b/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_benchmark_test.cc
index 5a8997f..14941d0 100644
--- a/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_benchmark_test.cc
+++ b/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_benchmark_test.cc
@@ -19,6 +19,8 @@
 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
 #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
 #include "tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_pipeline.h"
+#include "tensorflow/compiler/xla/runtime/executable.h"
+#include "tensorflow/compiler/xla/runtime/jit_executable.h"
 #include "tensorflow/core/platform/test_benchmark.h"
 #include "tfrt/jitrt/jitrt.h"  // from @tf_runtime
 #include "tfrt/jitrt/jitrt_compiler.h"  // from @tf_runtime
@@ -59,10 +61,11 @@
 
 using ::tfrt::jitrt::CompilationPipelineOptions;
 using ::tfrt::jitrt::CreateDefaultJitRtCompilationPipeline;
-using ::tfrt::jitrt::JitExecutable;
 using ::tfrt::jitrt::JitExecutableCache;
 using ::tfrt::jitrt::RegisterDefaultJitRtDialects;
 
+using ::xla::runtime::JitExecutable;
+
 static void BM_InstantiateExecutable(::testing::benchmark::State& state) {
   // Options for the default JitRt compilation pipeline (lowering to LLVM).
   CompilationPipelineOptions copts;
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 2ed28a1..79f111d 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -591,6 +591,10 @@
 #         "@llvm-project//llvm:OrcJIT",
 #         "@llvm-project//mlir:Support",
 #         "//tensorflow/compiler/xla:tfrt_utils",
+#         "//tensorflow/compiler/xla/runtime:arguments",
+#         "//tensorflow/compiler/xla/runtime:types",
+#         "//tensorflow/compiler/xla/runtime:executable",
+#         "//tensorflow/compiler/xla/runtime:jit_executable",
 #         "//tensorflow/compiler/xla:shape_util",
 #         "//tensorflow/compiler/xla/service:custom_call_status_internal",
 #         "//tensorflow/compiler/xla/service:custom_call_target_registry",
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
index 6a8f714..2872649 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
@@ -1435,27 +1435,27 @@
     copts.num_worker_threads = 1;
 
     // Options for constructing JitRt JitExecutable.
-    jitrt::JitExecutable::Options opts;
-    opts.specialization = jitrt::JitExecutable::Specialization::kDisabled;
+    runtime::JitExecutable::Options opts;
+    opts.specialization = runtime::JitExecutable::Specialization::kDisabled;
     opts.compiler.register_dialects = jitrt::RegisterDefaultJitRtDialects;
 
     // Register JitRt Gpu runtime custom calls with the linker.
     opts.compiler.runtime_symbol_map =
-        jitrt::GetSymbolsBinding(JitRtGpuCustomCalls());
+        runtime::GetSymbolsBinding(JitRtGpuCustomCalls());
 
     opts.compiler.create_compilation_pipeline = [copts](mlir::PassManager& pm) {
       jitrt::CreateDefaultJitRtCompilationPipeline(pm, copts);
     };
 
     // Instantiate new JitExecutable from the MLIR source.
-    auto jit_executable = jitrt::JitExecutable::Instantiate(
+    auto jit_executable = runtime::JitExecutable::Instantiate(
         program->module, program->entry_point, opts);
     if (auto err = jit_executable.takeError())
       return InternalError("Failed to compile JitRt program: %s",
                            tfrt::StrCat(err));
 
     // For static shapes we can always serialize only the default executable.
-    jitrt::Executable& executable = jit_executable->DefaultExecutable().get();
+    runtime::Executable& executable = jit_executable->DefaultExecutable().get();
 
     // Check if JitRt executable saved the compilation result.
     std::unique_ptr<llvm::MemoryBuffer> obj_file = executable.obj_file();
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
index f4ceee0..7e9e107 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
@@ -103,8 +103,8 @@
     copts.populate_attr_encodings = PopulateLmhloToXlaAttrEncoding;
 
     // Options for constructing JitRt JitExecutable.
-    jitrt::JitExecutable::Options opts;
-    opts.specialization = jitrt::JitExecutable::Specialization::kDisabled;
+    runtime::JitExecutable::Options opts;
+    opts.specialization = runtime::JitExecutable::Specialization::kDisabled;
     opts.compiler.register_dialects = [](mlir::DialectRegistry& registry) {
       jitrt::RegisterDefaultJitRtDialects(registry);
       // For the encoding of attributes to custom calls.
@@ -113,7 +113,7 @@
 
     // Register JitRt Gpu runtime custom calls with the linker.
     opts.compiler.runtime_symbol_map =
-        jitrt::GetSymbolsBinding(JitRtGpuCustomCalls());
+        runtime::GetSymbolsBinding(JitRtGpuCustomCalls());
 
     // We just use the default compilation pipeline provided by the JitRt.
     // Alternatively instead of having a separate JitRtProgram (LMHLO lowered to
@@ -131,7 +131,7 @@
     opts.compiler.jit_code_opt_level = llvm::CodeGenOpt::None;
 
     // Instantiate new JitExecutable from the MLIR source.
-    auto jit_executable = jitrt::JitExecutable::Instantiate(
+    auto jit_executable = runtime::JitExecutable::Instantiate(
         program->module, program->entry_point, opts);
     if (auto err = jit_executable.takeError())
       return InternalError("Failed to compile JitRt program: %s",
@@ -140,22 +140,22 @@
     // Pass ownership to the GpuExecutable.
     return new JitRtExecutable(
         std::move(program->buffer_sizes),
-        std::make_unique<jitrt::JitExecutable>(std::move(*jit_executable)),
+        std::make_unique<runtime::JitExecutable>(std::move(*jit_executable)),
         std::move(program->debug_options));
   }
 
   // Create JitRtExecutable from the AOT compiled binary.
   static StatusOr<JitRtExecutable*> Create(
-      absl::Span<const int64_t> buffer_sizes, jitrt::Executable executable,
+      absl::Span<const int64_t> buffer_sizes, runtime::Executable executable,
       DebugOptions debug_options) {
     // Pass ownership to the GpuExecutable.
     return new JitRtExecutable(
         std::vector<int64_t>(buffer_sizes.begin(), buffer_sizes.end()),
-        std::make_unique<jitrt::Executable>(std::move(executable)),
+        std::make_unique<runtime::Executable>(std::move(executable)),
         std::move(debug_options));
   }
 
-  jitrt::Executable& default_executable() { return *default_executable_; }
+  runtime::Executable& default_executable() { return *default_executable_; }
   JitRtKernelsCache& kernels_cache() { return kernels_cache_; }
   JitRtGemmConfigCache& gemm_configs_cache() { return gemm_configs_cache_; }
   JitRtCollectiveSupport& collectives() { return collectives_; }
@@ -170,7 +170,7 @@
 
  private:
   JitRtExecutable(std::vector<int64_t> buffer_sizes,
-                  std::unique_ptr<jitrt::JitExecutable> jit_executable,
+                  std::unique_ptr<runtime::JitExecutable> jit_executable,
                   DebugOptions debug_options)
       : buffer_sizes_(std::move(buffer_sizes)),
         jit_executable_(std::move(jit_executable)),
@@ -178,7 +178,7 @@
         debug_options_(std::move(debug_options)) {}
 
   JitRtExecutable(std::vector<int64_t> buffer_sizes,
-                  std::unique_ptr<jitrt::Executable> aot_executable,
+                  std::unique_ptr<runtime::Executable> aot_executable,
                   DebugOptions debug_options)
       : buffer_sizes_(std::move(buffer_sizes)),
         aot_executable_(std::move(aot_executable)),
@@ -187,11 +187,11 @@
 
   std::vector<int64_t> buffer_sizes_;
 
-  std::unique_ptr<jitrt::JitExecutable> jit_executable_;
-  jitrt::Executable* default_executable_;  // owned by `jit_executable`
+  std::unique_ptr<runtime::JitExecutable> jit_executable_;
+  runtime::Executable* default_executable_;  // owned by `jit_executable`
 
-  std::unique_ptr<jitrt::Executable> aot_executable_;
-  jitrt::Executable* executable_;
+  std::unique_ptr<runtime::Executable> aot_executable_;
+  runtime::Executable* executable_;
 
   DebugOptions debug_options_;
 
@@ -615,7 +615,7 @@
   // compiled function will make a copy of all arguments and will write all
  // results after the call to `Execute` completes, so it is safe to keep it on
   // the stack.
-  jitrt::Executable::CallFrame call_frame;
+  runtime::Executable::CallFrame call_frame;
 
  // Each buffer allocation passed as 1d memref to the compiled kernel:
   //   {basePtr, dataPtr, offset, [sizes, ...], [strides, ...]}
@@ -648,14 +648,14 @@
   }
 
   // JitRt executables do not return any values.
-  jitrt::NoResultConverter converter;
+  runtime::NoResultConverter converter;
 
   // Prepare options for executing JitRt program.
-  jitrt::Executable::ExecuteOpts opts;
+  runtime::Executable::ExecuteOpts opts;
 
   // We don't expect to see any async tasks in the JitRt executable.
   opts.async_task_runner =
-      reinterpret_cast<jitrt::AsyncTaskRunner*>(0XDEADBEEF);
+      reinterpret_cast<runtime::AsyncTaskRunner*>(0XDEADBEEF);
 
   // Get the async communications stream for async collectives.
   int device_ordinal = run_options->stream()->parent()->device_ordinal();
@@ -669,7 +669,7 @@
       async_comms_stream.ok() ? async_comms_stream->get() : nullptr);
 
   // Pass auxiliary data to the custom call handlers.
-  jitrt::CustomCall::UserData user_data;
+  runtime::CustomCall::UserData user_data;
   user_data.insert_all(
       run_options, &jitrt_executable->debug_options(),
       &jitrt_executable->kernels_cache(),
@@ -678,9 +678,9 @@
   opts.custom_call_data = &user_data;
 
   // Collect all emitted diagnostic messages.
-  jitrt::DiagnosticEngine diagnostic_engine;
+  runtime::DiagnosticEngine diagnostic_engine;
   std::string diagnostic;
-  diagnostic_engine.AddHandler([&](jitrt::Diagnostic& d) {
+  diagnostic_engine.AddHandler([&](runtime::Diagnostic& d) {
     llvm::raw_string_ostream(diagnostic) << d.str();
     return mlir::success();
   });
@@ -689,7 +689,7 @@
 
   // Get the default executable. We do not support specialization because
   // all shapes are static. Default executable is guaranteed to be available.
-  jitrt::Executable& executable = jitrt_executable->default_executable();
+  runtime::Executable& executable = jitrt_executable->default_executable();
 
   // Execute with the prepared call frame.
   executable.Execute(call_frame, opts);
@@ -1095,24 +1095,24 @@
   auto buffer = llvm::MemoryBuffer::getMemBuffer(data, hlo_module->name());
 
   // Create a JitRt function signature (all arguments passed as 1d memrefs).
-  llvm::SmallVector<std::unique_ptr<jitrt::Type>> args;
-  llvm::SmallVector<std::unique_ptr<jitrt::Type>> rt_args;
-  rt_args.push_back(std::make_unique<jitrt::KernelContextOperandType>());
+  llvm::SmallVector<std::unique_ptr<runtime::Type>> args;
+  llvm::SmallVector<std::unique_ptr<runtime::Type>> rt_args;
+  rt_args.push_back(std::make_unique<runtime::KernelContextOperandType>());
 
   for (int64_t size : buffer_sizes) {
     auto i8 = tfrt::DType::I8;
-    args.push_back(std::make_unique<jitrt::MemrefType>(size, i8));
-    rt_args.push_back(std::make_unique<jitrt::MemrefType>(size, i8));
+    args.push_back(std::make_unique<runtime::MemrefType>(size, i8));
+    rt_args.push_back(std::make_unique<runtime::MemrefType>(size, i8));
   }
 
-  jitrt::FunctionType signature(std::move(args), /*results=*/{});
-  jitrt::FunctionType rt_signature(std::move(rt_args), /*results=*/{});
+  runtime::FunctionType signature(std::move(args), /*results=*/{});
+  runtime::FunctionType rt_signature(std::move(rt_args), /*results=*/{});
 
-  auto symbol_map = jitrt::GetSymbolsBinding(JitRtGpuCustomCalls());
+  auto symbol_map = runtime::GetSymbolsBinding(JitRtGpuCustomCalls());
 
   // Load JitRt executable from an object file, and link it with Gpu runtime
   // intrinsics implementing Gpu custom calls.
-  auto executable = jitrt::Executable::LoadFromObjFile(
+  auto executable = runtime::Executable::LoadFromObjFile(
       hlo_module->name(), std::move(buffer),
       hlo_module->entry_computation()->name(), std::move(signature),
       std::move(rt_signature), symbol_map);
diff --git a/tensorflow/compiler/xla/service/gpu/jitrt_custom_calls.cc b/tensorflow/compiler/xla/service/gpu/jitrt_custom_calls.cc
index 352fcc5..0a3c72f 100644
--- a/tensorflow/compiler/xla/service/gpu/jitrt_custom_calls.cc
+++ b/tensorflow/compiler/xla/service/gpu/jitrt_custom_calls.cc
@@ -23,8 +23,12 @@
 
 #include "llvm/ExecutionEngine/Orc/Mangling.h"
 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
+#include "tensorflow/compiler/xla/runtime/arguments.h"
 #include "tensorflow/compiler/xla/runtime/custom_call.h"
+#include "tensorflow/compiler/xla/runtime/executable.h"
+#include "tensorflow/compiler/xla/runtime/jit_executable.h"
 #include "tensorflow/compiler/xla/runtime/type_id.h"
+#include "tensorflow/compiler/xla/runtime/types.h"
 #include "tensorflow/compiler/xla/service/custom_call_status_internal.h"
 #include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
 #include "tensorflow/compiler/xla/service/gpu/fft_thunk.h"
@@ -45,7 +49,6 @@
 #include "tensorflow/core/platform/human_readable_json.h"
 #include "tensorflow/stream_executor/gpu/gpu_stream.h"
 #include "tensorflow/stream_executor/gpu/gpu_types.h"
-#include "tfrt/jitrt/jitrt.h"  // from @tf_runtime
 #include "tfrt/dtype/dtype.h"  // from @tf_runtime
 
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
@@ -53,17 +56,17 @@
 #include "tensorflow/compiler/xla/service/gpu/triangular_solve_thunk.h"
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
-TFRT_DEFINE_EXPLICIT_DENSE_TYPE_ID(tfrt::jitrt::CustomCall,
+TFRT_DEFINE_EXPLICIT_DENSE_TYPE_ID(xla::runtime::CustomCall,
                                    xla::gpu::JitRtKernelsCache);
-TFRT_DEFINE_EXPLICIT_DENSE_TYPE_ID(tfrt::jitrt::CustomCall,
+TFRT_DEFINE_EXPLICIT_DENSE_TYPE_ID(xla::runtime::CustomCall,
                                    xla::gpu::JitRtGemmConfigCache);
-TFRT_DEFINE_EXPLICIT_DENSE_TYPE_ID(tfrt::jitrt::CustomCall,
+TFRT_DEFINE_EXPLICIT_DENSE_TYPE_ID(xla::runtime::CustomCall,
                                    xla::gpu::JitRtCollectiveSupport);
-TFRT_DEFINE_EXPLICIT_DENSE_TYPE_ID(tfrt::jitrt::CustomCall,
+TFRT_DEFINE_EXPLICIT_DENSE_TYPE_ID(xla::runtime::CustomCall,
                                    xla::gpu::JitRtAsyncCollectiveSupport);
-TFRT_DEFINE_EXPLICIT_DENSE_TYPE_ID(tfrt::jitrt::CustomCall,
+TFRT_DEFINE_EXPLICIT_DENSE_TYPE_ID(xla::runtime::CustomCall,
                                    const xla::ServiceExecutableRunOptions);
-TFRT_DEFINE_EXPLICIT_DENSE_TYPE_ID(tfrt::jitrt::CustomCall,
+TFRT_DEFINE_EXPLICIT_DENSE_TYPE_ID(xla::runtime::CustomCall,
                                    const xla::DebugOptions);
 
 namespace xla {
@@ -83,12 +86,12 @@
 using mlir::success;
 
 using tfrt::MakeStringError;
-using tfrt::jitrt::CustomCall;
-using tfrt::jitrt::DirectCustomCallLibrary;
-using tfrt::jitrt::Executable;
+
+using ::xla::runtime::CustomCall;
+using ::xla::runtime::DirectCustomCallLibrary;
+using ::xla::runtime::Executable;
 
 namespace se = ::stream_executor;
-namespace jitrt = ::tfrt::jitrt;
 namespace lmhlo_gpu = ::mlir::lmhlo_gpu;
 namespace mhlo = ::mlir::mhlo;
 
@@ -122,24 +125,23 @@
 // Add custom call arguments and attributes encoding for custom HLO enums and
 // structs, so that we can pass them to custom calls.
 void PopulateLmhloToXlaAttrEncoding(
-    jitrt::CustomCallAttrEncodingSet& encoding) {
-  encoding.Add<
-      jitrt::EnumAttrEncoding<lmhlo_gpu::ActivationAttr, lmhlo_gpu::Activation,
-                              se::dnn::ActivationMode>>(
+    tfrt::jitrt::CustomCallAttrEncodingSet& encoding) {
+  encoding.Add<tfrt::jitrt::EnumAttrEncoding<lmhlo_gpu::ActivationAttr,
+                                             lmhlo_gpu::Activation,
+                                             se::dnn::ActivationMode>>(
       [](lmhlo_gpu::Activation value) -> se::dnn::ActivationMode {
         return ConvertConvActivationMode(value).value();
       });
 
-  encoding.Add<jitrt::EnumAttrEncoding<lmhlo_gpu::CublasLtMatmulEpilogueAttr,
-                                       lmhlo_gpu::CublasLtMatmulEpilogue,
-                                       se::cuda::BlasLt::Epilogue>>(
-      [](lmhlo_gpu::CublasLtMatmulEpilogue value)
-          -> se::cuda::BlasLt::Epilogue {
-        return cublas_lt::AsBlasLtEpilogue(value).value();
-      });
+  encoding.Add<tfrt::jitrt::EnumAttrEncoding<
+      lmhlo_gpu::CublasLtMatmulEpilogueAttr, lmhlo_gpu::CublasLtMatmulEpilogue,
+      se::cuda::BlasLt::Epilogue>>([](lmhlo_gpu::CublasLtMatmulEpilogue value)
+                                       -> se::cuda::BlasLt::Epilogue {
+    return cublas_lt::AsBlasLtEpilogue(value).value();
+  });
 
-  encoding.Add<
-      jitrt::EnumAttrEncoding<mhlo::FftTypeAttr, mhlo::FftType, se::fft::Type>>(
+  encoding.Add<tfrt::jitrt::EnumAttrEncoding<mhlo::FftTypeAttr, mhlo::FftType,
+                                             se::fft::Type>>(
       [](mhlo::FftType value) -> se::fft::Type {
         switch (value) {
           case mhlo::FftType::FFT:
@@ -156,9 +158,10 @@
       });
 
   using DotDimsAttr = mhlo::DotDimensionNumbersAttr;
-  encoding.Add<jitrt::AggregateAttrEncoding<DotDimsAttr, DotDimensionNumbers>>(
+  encoding.Add<
+      tfrt::jitrt::AggregateAttrEncoding<DotDimsAttr, DotDimensionNumbers>>(
       encoding,
-      jitrt::AggregateAttrDef<DotDimsAttr>()
+      tfrt::jitrt::AggregateAttrDef<DotDimsAttr>()
           .Add("lhs_batch", &DotDimsAttr::getLhsBatchingDimensions)
           .Add("lhs_contract", &DotDimsAttr::getLhsContractingDimensions)
           .Add("rhs_batch", &DotDimsAttr::getRhsBatchingDimensions)
@@ -166,9 +169,9 @@
 
   using ConvDimsAttr = mhlo::ConvDimensionNumbersAttr;
   encoding.Add<
-      jitrt::AggregateAttrEncoding<ConvDimsAttr, ConvDimensionNumbers>>(
+      tfrt::jitrt::AggregateAttrEncoding<ConvDimsAttr, ConvDimensionNumbers>>(
       encoding,
-      jitrt::AggregateAttrDef<ConvDimsAttr>()
+      tfrt::jitrt::AggregateAttrDef<ConvDimsAttr>()
           .Add("input_batch_dim", &ConvDimsAttr::getInputBatchDimension)
           .Add("input_feature_dim", &ConvDimsAttr::getInputFeatureDimension)
           .Add("input_spatial_dims", &ConvDimsAttr::getInputSpatialDimensions)
@@ -183,9 +186,10 @@
                &ConvDimsAttr::getOutputSpatialDimensions));
 
   using ConvConfigAttr = lmhlo_gpu::ConvolutionBackendConfigAttr;
-  encoding.Add<jitrt::AggregateAttrEncoding<ConvConfigAttr, ConvBackendConfig>>(
+  encoding.Add<
+      tfrt::jitrt::AggregateAttrEncoding<ConvConfigAttr, ConvBackendConfig>>(
       encoding,
-      jitrt::AggregateAttrDef<ConvConfigAttr>()
+      tfrt::jitrt::AggregateAttrDef<ConvConfigAttr>()
           .Add("algorithm", &ConvConfigAttr::getAlgorithm)
           .Add("tensor_ops_enabled", &ConvConfigAttr::getTensorOpsEnabled)
           .Add("is_cudnn_frontend", &ConvConfigAttr::getIsCudnnFrontend)
@@ -230,7 +234,7 @@
   return se::DeviceMemoryBase(memref.data, size);
 }
 
-static se::DeviceMemoryBase GetDeviceAddress(jitrt::FlatMemrefView& memref) {
+static se::DeviceMemoryBase GetDeviceAddress(runtime::FlatMemrefView& memref) {
   return se::DeviceMemoryBase(memref.data, memref.size_in_bytes);
 }
 
@@ -295,7 +299,7 @@
 
 // -------------------------------------------------------------------------- //
 
-static Shape ToShape(const jitrt::StridedMemrefView& memref) {
+static Shape ToShape(const runtime::StridedMemrefView& memref) {
   PrimitiveType type = TfrtToPrimitiveType(memref.dtype);
 
   // Recover `minor_to_major` dimensions permutation from strides.
@@ -314,12 +318,15 @@
   return ShapeUtil::MakeShapeWithLayout(type, memref.sizes, minor_to_major);
 }
 
-static StatusOr<GemmConfig> GetGemmConfig(
-    const jitrt::StridedMemrefView& lhs, const jitrt::StridedMemrefView& rhs,
-    const jitrt::StridedMemrefView& out, int64_t algorithm, double alpha_real,
-    double alpha_imag, double beta, ArrayRef<int64_t> lhs_batch,
-    ArrayRef<int64_t> lhs_contract, ArrayRef<int64_t> rhs_batch,
-    ArrayRef<int64_t> rhs_contract) {
+static StatusOr<GemmConfig> GetGemmConfig(const runtime::StridedMemrefView& lhs,
+                                          const runtime::StridedMemrefView& rhs,
+                                          const runtime::StridedMemrefView& out,
+                                          int64_t algorithm, double alpha_real,
+                                          double alpha_imag, double beta,
+                                          ArrayRef<int64_t> lhs_batch,
+                                          ArrayRef<int64_t> lhs_contract,
+                                          ArrayRef<int64_t> rhs_batch,
+                                          ArrayRef<int64_t> rhs_contract) {
   return GemmConfig::For(ToShape(lhs), lhs_batch, lhs_contract, ToShape(rhs),
                          rhs_batch, rhs_contract, ToShape(out), alpha_real,
                          alpha_imag, beta, algorithm,
@@ -362,8 +369,8 @@
   std::vector<DeviceBufferPair> device_buffers;
   device_buffers.reserve(buffer_pairs);
   for (int i = 0; i < buffer_pairs; ++i) {
-    auto source = args.get<jitrt::StridedMemrefView>(i);
-    auto destination = args.get<jitrt::StridedMemrefView>(i + buffer_pairs);
+    auto source = args.get<runtime::StridedMemrefView>(i);
+    auto destination = args.get<runtime::StridedMemrefView>(i + buffer_pairs);
     if (failed(source) || failed(destination)) {
       // Unsupported argument type.
       return failure();
@@ -437,7 +444,7 @@
   // Add MemRef arguments as buffer arguments.
   for (unsigned i = 0; i < args.size(); ++i) {
     // Simple row major memref passed as shapeless buffer.
-    auto memref = args.get<jitrt::FlatMemrefView>(i);
+    auto memref = args.get<runtime::FlatMemrefView>(i);
     if (succeeded(memref)) {
       buffer_args.emplace_back(GetDeviceAddress(*memref));
       continue;
@@ -445,7 +452,7 @@
 
     // Memref layout must be encoded in the compiled device kernel, so we don't
     // have to pass strides or minor to major dimensions order to the kernel.
-    auto strided = args.get<jitrt::StridedMemrefView>(i);
+    auto strided = args.get<runtime::StridedMemrefView>(i);
     if (succeeded(strided)) {
       buffer_args.emplace_back(GetDeviceAddress(*strided));
       continue;
@@ -488,11 +495,12 @@
   LLVM_ATTRIBUTE_ALWAYS_INLINE
   Error operator()(const ServiceExecutableRunOptions* run_options,
                    const DebugOptions* debug_options,
-                   JitRtGemmConfigCache* configs, jitrt::StridedMemrefView lhs,
-                   jitrt::StridedMemrefView rhs, jitrt::StridedMemrefView out,
-                   int64_t algorithm, double alpha_real, double alpha_imag,
-                   double beta, DotDimensionNumbers dot_dims,
-                   int64_t uid) const;
+                   JitRtGemmConfigCache* configs,
+                   runtime::StridedMemrefView lhs,
+                   runtime::StridedMemrefView rhs,
+                   runtime::StridedMemrefView out, int64_t algorithm,
+                   double alpha_real, double alpha_imag, double beta,
+                   DotDimensionNumbers dot_dims, int64_t uid) const;
 
   static Gemm Handler() { return Gemm(); }
 };
@@ -501,9 +509,9 @@
 Error Gemm::operator()(const ServiceExecutableRunOptions* run_options,
                        const DebugOptions* debug_options,
                        JitRtGemmConfigCache* configs,
-                       jitrt::StridedMemrefView lhs,
-                       jitrt::StridedMemrefView rhs,
-                       jitrt::StridedMemrefView out, int64_t algorithm,
+                       runtime::StridedMemrefView lhs,
+                       runtime::StridedMemrefView rhs,
+                       runtime::StridedMemrefView out, int64_t algorithm,
                        double alpha_real, double alpha_imag, double beta,
                        DotDimensionNumbers dot_dims, int64_t uid) const {
   se::DeviceMemoryBase lhs_data = GetDeviceAddress(lhs);
@@ -537,9 +545,9 @@
                              .UserData<const ServiceExecutableRunOptions*>()
                              .UserData<const DebugOptions*>()
                              .UserData<JitRtGemmConfigCache*>()
-                             .Arg<jitrt::StridedMemrefView>()  // lhs
-                             .Arg<jitrt::StridedMemrefView>()  // rhs
-                             .Arg<jitrt::StridedMemrefView>()  // out
+                             .Arg<runtime::StridedMemrefView>()  // lhs
+                             .Arg<runtime::StridedMemrefView>()  // rhs
+                             .Arg<runtime::StridedMemrefView>()  // out
                              .Attr<int64_t>("algorithm")
                              .Attr<double>("alpha_real")
                              .Attr<double>("alpha_imag")
@@ -561,9 +569,9 @@
   LLVM_ATTRIBUTE_ALWAYS_INLINE
   Error operator()(const ServiceExecutableRunOptions* run_options,
                    const DebugOptions* debug_options,
-                   jitrt::StridedMemrefView a, jitrt::StridedMemrefView b,
-                   jitrt::StridedMemrefView c, jitrt::StridedMemrefView d,
-                   Optional<jitrt::StridedMemrefView> bias, int64_t algorithm,
+                   runtime::StridedMemrefView a, runtime::StridedMemrefView b,
+                   runtime::StridedMemrefView c, runtime::StridedMemrefView d,
+                   Optional<runtime::StridedMemrefView> bias, int64_t algorithm,
                    double alpha_real, double alpha_imag, double beta,
                    DotDimensionNumbers dot_dims,
                    se::cuda::BlasLt::Epilogue epilogue,
@@ -575,9 +583,9 @@
 
 Error CublasLtMatmul::operator()(
     const ServiceExecutableRunOptions* run_options,
-    const DebugOptions* debug_options, jitrt::StridedMemrefView a,
-    jitrt::StridedMemrefView b, jitrt::StridedMemrefView c,
-    jitrt::StridedMemrefView d, Optional<jitrt::StridedMemrefView> bias,
+    const DebugOptions* debug_options, runtime::StridedMemrefView a,
+    runtime::StridedMemrefView b, runtime::StridedMemrefView c,
+    runtime::StridedMemrefView d, Optional<runtime::StridedMemrefView> bias,
     int64_t algorithm, double alpha_real, double alpha_imag, double beta,
     DotDimensionNumbers dot_dims, se::cuda::BlasLt::Epilogue epilogue,
     ArrayRef<int32_t> precision, int64_t uid) const {
@@ -616,7 +624,7 @@
 
 // Adds custom call bindings for matmul operations.
 template <typename... Ts>
-static auto BindMatmulAttributes(jitrt::CustomCallBinding<Ts...> binding) {
+static auto BindMatmulAttributes(runtime::CustomCallBinding<Ts...> binding) {
   return std::move(binding)
       .template Attr<int64_t>("algorithm")
       .template Attr<double>("alpha_real")
@@ -634,11 +642,11 @@
       BindMatmulAttributes(CustomCall::Bind("xla.gpu.cublas.lt.matmul")
                                .UserData<const ServiceExecutableRunOptions*>()
                                .UserData<const DebugOptions*>()
-                               .Arg<jitrt::StridedMemrefView>()  // a
-                               .Arg<jitrt::StridedMemrefView>()  // b
-                               .Arg<jitrt::StridedMemrefView>()  // c
-                               .Arg<jitrt::StridedMemrefView>()  // d
-                               .Value(CustomCall::None)          // bias
+                               .Arg<runtime::StridedMemrefView>()  // a
+                               .Arg<runtime::StridedMemrefView>()  // b
+                               .Arg<runtime::StridedMemrefView>()  // c
+                               .Arg<runtime::StridedMemrefView>()  // d
+                               .Value(CustomCall::None)            // bias
                            )
           .To<RuntimeChecks()>(CublasLtMatmul::Handler())
           .release();
@@ -652,11 +660,11 @@
       BindMatmulAttributes(CustomCall::Bind("xla.gpu.cublas.lt.matmul.bias")
                                .UserData<const ServiceExecutableRunOptions*>()
                                .UserData<const DebugOptions*>()
-                               .Arg<jitrt::StridedMemrefView>()  // a
-                               .Arg<jitrt::StridedMemrefView>()  // b
-                               .Arg<jitrt::StridedMemrefView>()  // c
-                               .Arg<jitrt::StridedMemrefView>()  // d
-                               .Arg<jitrt::StridedMemrefView>()  // bias
+                               .Arg<runtime::StridedMemrefView>()  // a
+                               .Arg<runtime::StridedMemrefView>()  // b
+                               .Arg<runtime::StridedMemrefView>()  // c
+                               .Arg<runtime::StridedMemrefView>()  // d
+                               .Arg<runtime::StridedMemrefView>()  // bias
                            )
           .To<RuntimeChecks()>(CublasLtMatmul::Handler())
           .release();
@@ -699,8 +707,8 @@
 static GpuConvDescriptor GetConvDescriptor(
     CudnnConvKind kind,
     // Arguments
-    jitrt::StridedMemrefView operand0, jitrt::StridedMemrefView operand1,
-    jitrt::StridedMemrefView output, jitrt::FlatMemrefView scratch,
+    runtime::StridedMemrefView operand0, runtime::StridedMemrefView operand1,
+    runtime::StridedMemrefView output, runtime::FlatMemrefView scratch,
     // Attributes
     ConvDimensionNumbers dims, Window w, ConvBackendConfig b, ConvAttrs attrs,
     // Conv-specific arguments and attributes
@@ -711,7 +719,7 @@
   descriptor.kind = kind;
 
   // Apply backend config layout to the shape.
-  auto apply_layout = [](jitrt::StridedMemrefView& memref,
+  auto apply_layout = [](runtime::StridedMemrefView& memref,
                          ArrayRef<int64_t> minor_to_major) {
     Shape shape = ToShape(memref);
     return ShapeUtil::MakeShapeWithLayout(shape.element_type(),
@@ -790,10 +798,11 @@
   LLVM_ATTRIBUTE_ALWAYS_INLINE
   Error operator()(
       const ServiceExecutableRunOptions* run_options,
-      const DebugOptions* debug_options, jitrt::StridedMemrefView operand0,
-      jitrt::StridedMemrefView operand1, Optional<jitrt::FlatMemrefView> bias,
-      Optional<jitrt::StridedMemrefView> side_input,
-      jitrt::StridedMemrefView output, jitrt::FlatMemrefView scratch,
+      const DebugOptions* debug_options, runtime::StridedMemrefView operand0,
+      runtime::StridedMemrefView operand1,
+      Optional<runtime::FlatMemrefView> bias,
+      Optional<runtime::StridedMemrefView> side_input,
+      runtime::StridedMemrefView output, runtime::FlatMemrefView scratch,
       ConvDimensionNumbers conv_dims,
       // Window config
       ArrayRef<int64_t> window_strides, ArrayRef<int64_t> padding,
@@ -859,7 +868,7 @@
 
 // Adds custom call bindings for convolution operations.
 template <typename... Ts>
-static auto BindConvAttributes(jitrt::CustomCallBinding<Ts...> binding) {
+static auto BindConvAttributes(runtime::CustomCallBinding<Ts...> binding) {
   return std::move(binding)
       // Convolution dimensions numbers
       .template Attr<ConvDimensionNumbers>("conv_dims")
@@ -882,12 +891,12 @@
       BindConvAttributes(CustomCall::Bind("xla.gpu.conv")
                              .UserData<const ServiceExecutableRunOptions*>()
                              .UserData<const DebugOptions*>()
-                             .Arg<jitrt::StridedMemrefView>()  // operand0
-                             .Arg<jitrt::StridedMemrefView>()  // operand1
-                             .Value(CustomCall::None)          // bias
-                             .Value(CustomCall::None)          // side_input
-                             .Arg<jitrt::StridedMemrefView>()  // output
-                             .Arg<jitrt::FlatMemrefView>()     // scratch
+                             .Arg<runtime::StridedMemrefView>()  // operand0
+                             .Arg<runtime::StridedMemrefView>()  // operand1
+                             .Value(CustomCall::None)            // bias
+                             .Value(CustomCall::None)            // side_input
+                             .Arg<runtime::StridedMemrefView>()  // output
+                             .Arg<runtime::FlatMemrefView>()     // scratch
                          )
           .To(Conv::Handler(kind))
           .release();
@@ -902,12 +911,12 @@
       BindConvAttributes(CustomCall::Bind("xla.gpu.conv.fused")
                              .UserData<const ServiceExecutableRunOptions*>()
                              .UserData<const DebugOptions*>()
-                             .Arg<jitrt::StridedMemrefView>()  // operand0
-                             .Arg<jitrt::StridedMemrefView>()  // operand1
-                             .Arg<jitrt::FlatMemrefView>()     // bias
-                             .Value(CustomCall::None)          // side_input
-                             .Arg<jitrt::StridedMemrefView>()  // output
-                             .Arg<jitrt::FlatMemrefView>()     // scratch
+                             .Arg<runtime::StridedMemrefView>()  // operand0
+                             .Arg<runtime::StridedMemrefView>()  // operand1
+                             .Arg<runtime::FlatMemrefView>()     // bias
+                             .Value(CustomCall::None)            // side_input
+                             .Arg<runtime::StridedMemrefView>()  // output
+                             .Arg<runtime::FlatMemrefView>()     // scratch
                          )
           .Attr<se::dnn::ActivationMode>("activation_mode")
           .To(Conv::Handler(kind))
@@ -923,12 +932,12 @@
       BindConvAttributes(CustomCall::Bind("xla.gpu.conv.fused.side_input")
                              .UserData<const ServiceExecutableRunOptions*>()
                              .UserData<const DebugOptions*>()
-                             .Arg<jitrt::StridedMemrefView>()  // operand0
-                             .Arg<jitrt::StridedMemrefView>()  // operand1
-                             .Arg<jitrt::FlatMemrefView>()     // bias
-                             .Arg<jitrt::StridedMemrefView>()  // side_input
-                             .Arg<jitrt::StridedMemrefView>()  // output
-                             .Arg<jitrt::FlatMemrefView>()     // scratch
+                             .Arg<runtime::StridedMemrefView>()  // operand0
+                             .Arg<runtime::StridedMemrefView>()  // operand1
+                             .Arg<runtime::FlatMemrefView>()     // bias
+                             .Arg<runtime::StridedMemrefView>()  // side_input
+                             .Arg<runtime::StridedMemrefView>()  // output
+                             .Arg<runtime::FlatMemrefView>()     // scratch
                          )
           .Attr<se::dnn::ActivationMode>("activation_mode")
           .Attr<double>("side_input_scale")
@@ -964,7 +973,7 @@
   size_t index = 0;
   for (auto& source : source_buffers.leaves()) {
     // Get the destination buffer.
-    auto dest = args.get<jitrt::StridedMemrefView>(index);
+    auto dest = args.get<runtime::StridedMemrefView>(index);
     if (failed(dest))
       return MakeStringError("Failed to get the destination buffer");
 
@@ -1038,7 +1047,7 @@
   size_t index = 0;
   for (auto& dest : dest_buffers->leaves()) {
     // Get the source buffer.
-    auto source = args.get<jitrt::StridedMemrefView>(index);
+    auto source = args.get<runtime::StridedMemrefView>(index);
     if (failed(source))
       return MakeStringError("Failed to get the source buffer");
 
@@ -1092,15 +1101,16 @@
 template <MemcpyDirection direction>
 struct Memcpy {
   Error operator()(const ServiceExecutableRunOptions* run_options,
-                   jitrt::FlatMemrefView dst, jitrt::FlatMemrefView src) const;
+                   runtime::FlatMemrefView dst,
+                   runtime::FlatMemrefView src) const;
   static Memcpy Handler() { return Memcpy(); }
 };
 }  // namespace
 
 template <MemcpyDirection direction>
 Error Memcpy<direction>::operator()(
-    const ServiceExecutableRunOptions* run_options, jitrt::FlatMemrefView dst,
-    jitrt::FlatMemrefView src) const {
+    const ServiceExecutableRunOptions* run_options, runtime::FlatMemrefView dst,
+    runtime::FlatMemrefView src) const {
   se::Stream* stream = run_options->stream();
 
   if (dst.size_in_bytes != src.size_in_bytes) {
@@ -1139,8 +1149,8 @@
 static bool MemcpyFn(runtime::KernelContext* ctx, void** args, void** attrs) {
   static auto* handler = CustomCall::Bind("xla.gpu.memcpy")
                              .UserData<const ServiceExecutableRunOptions*>()
-                             .Arg<jitrt::FlatMemrefView>()  // dst
-                             .Arg<jitrt::FlatMemrefView>()  // src
+                             .Arg<runtime::FlatMemrefView>()  // dst
+                             .Arg<runtime::FlatMemrefView>()  // src
                              .To<RuntimeChecks()>(Memcpy<direction>::Handler())
                              .release();
 
@@ -1153,7 +1163,7 @@
 
 struct Memset {
   Error operator()(const ServiceExecutableRunOptions* run_options,
-                   jitrt::FlatMemrefView dst,
+                   runtime::FlatMemrefView dst,
                    CustomCall::VariantArg constant) const;
   static Memset Handler() { return Memset(); }
 };
@@ -1161,7 +1171,7 @@
 }  // namespace
 
 Error Memset::operator()(const ServiceExecutableRunOptions* run_options,
-                         jitrt::FlatMemrefView dst,
+                         runtime::FlatMemrefView dst,
                          CustomCall::VariantArg constant) const {
   se::Stream* stream = run_options->stream();
   se::DeviceMemoryBase dst_data = GetDeviceAddress(dst);
@@ -1205,8 +1215,8 @@
 static bool MemsetFn(runtime::KernelContext* ctx, void** args, void** attrs) {
   static auto* handler = CustomCall::Bind("xla.gpu.memset")
                              .UserData<const ServiceExecutableRunOptions*>()
-                             .Arg<jitrt::FlatMemrefView>()   // dst
-                             .Arg<CustomCall::VariantArg>()  // constant
+                             .Arg<runtime::FlatMemrefView>()  // dst
+                             .Arg<CustomCall::VariantArg>()   // constant
                              .To<RuntimeChecks()>(Memset::Handler())
                              .release();
 
@@ -1219,16 +1229,16 @@
 struct Fft {
   LLVM_ATTRIBUTE_ALWAYS_INLINE
   Error operator()(const ServiceExecutableRunOptions* run_options,
-                   jitrt::StridedMemrefView input,
-                   jitrt::StridedMemrefView output,
+                   runtime::StridedMemrefView input,
+                   runtime::StridedMemrefView output,
                    ArrayRef<int64_t> fft_length, se::fft::Type fft_type) const;
   static Fft Handler() { return Fft(); }
 };
 }  // namespace
 
 Error Fft::operator()(const ServiceExecutableRunOptions* run_options,
-                      jitrt::StridedMemrefView input,
-                      jitrt::StridedMemrefView output,
+                      runtime::StridedMemrefView input,
+                      runtime::StridedMemrefView output,
                       ArrayRef<int64_t> fft_length,
                       se::fft::Type fft_type) const {
   // TODO(ezhulenev): Cache FFT plans in the GpuExecutable.
@@ -1270,13 +1280,12 @@
 static bool Fft(runtime::KernelContext* ctx, void** args, void** attrs) {
   static auto* handler = CustomCall::Bind("xla.gpu.fft")
                              .UserData<const ServiceExecutableRunOptions*>()
-                             .Arg<jitrt::StridedMemrefView>()  // input
-                             .Arg<jitrt::StridedMemrefView>()  // output
+                             .Arg<runtime::StridedMemrefView>()  // input
+                             .Arg<runtime::StridedMemrefView>()  // output
                              .Attr<ArrayRef<int64_t>>("fft_length")
                              .Attr<se::fft::Type>("fft_type")
                              .To<RuntimeChecks()>(Fft::Handler())
                              .release();
-
   return succeeded(Executable::Call(ctx, *handler, args, attrs));
 }
 
@@ -1286,19 +1295,20 @@
 struct Cholesky {
   LLVM_ATTRIBUTE_ALWAYS_INLINE
   Error operator()(const ServiceExecutableRunOptions* run_options,
-                   const DebugOptions* debug_options, jitrt::MemrefView operand,
-                   jitrt::MemrefView a, jitrt::MemrefView workspace,
-                   jitrt::MemrefView info, int64_t batch_size, bool is_lower,
-                   int64_t n) const;
+                   const DebugOptions* debug_options,
+                   runtime::MemrefView operand, runtime::MemrefView a,
+                   runtime::MemrefView workspace, runtime::MemrefView info,
+                   int64_t batch_size, bool is_lower, int64_t n) const;
   static Cholesky Handler() { return Cholesky(); }
 };
 }  // namespace
 
 Error Cholesky::operator()(const ServiceExecutableRunOptions* run_options,
                            const DebugOptions* debug_options,
-                           jitrt::MemrefView operand, jitrt::MemrefView a,
-                           jitrt::MemrefView workspace, jitrt::MemrefView info,
-                           int64_t batch_size, bool is_lower, int64_t n) const {
+                           runtime::MemrefView operand, runtime::MemrefView a,
+                           runtime::MemrefView workspace,
+                           runtime::MemrefView info, int64_t batch_size,
+                           bool is_lower, int64_t n) const {
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
   se::DeviceMemoryBase operand_buffer = GetDeviceAddress(operand);
   se::DeviceMemoryBase a_buffer = GetDeviceAddress(a);
@@ -1332,10 +1342,10 @@
   static auto* handler = CustomCall::Bind("xla.gpu.cholesky")
                              .UserData<const ServiceExecutableRunOptions*>()
                              .UserData<const DebugOptions*>()
-                             .Arg<jitrt::MemrefView>()  // operand
-                             .Arg<jitrt::MemrefView>()  // a
-                             .Arg<jitrt::MemrefView>()  // workspace
-                             .Arg<jitrt::MemrefView>()  // info
+                             .Arg<runtime::MemrefView>()  // operand
+                             .Arg<runtime::MemrefView>()  // a
+                             .Arg<runtime::MemrefView>()  // workspace
+                             .Arg<runtime::MemrefView>()  // info
                              .Attr<int64_t>("batch_size")
                              .Attr<bool>("is_lower")
                              .Attr<int64_t>("n")
@@ -1362,9 +1372,10 @@
 
   Error operator()(const ServiceExecutableRunOptions* run_options,
                    const DebugOptions* debug_options,
-                   jitrt::StridedMemrefView a, jitrt::StridedMemrefView b,
-                   jitrt::StridedMemrefView result, jitrt::FlatMemrefView temp,
-                   bool left_side, bool lower, bool unit_diagonal,
+                   runtime::StridedMemrefView a, runtime::StridedMemrefView b,
+                   runtime::StridedMemrefView result,
+                   runtime::FlatMemrefView temp, bool left_side, bool lower,
+                   bool unit_diagonal,
                    TriangularSolveOptions::Transpose transpose_a) const;
   static TriangularSolve Handler() { return TriangularSolve(); }
 };
@@ -1381,10 +1392,10 @@
     return MakeStringError("Expected 4 arguments, got ", args.size());
 
   // Check if all arguments have the correct type.
-  auto a = args.get<jitrt::StridedMemrefView>(0);
-  auto b = args.get<jitrt::StridedMemrefView>(1);
-  auto result = args.get<jitrt::StridedMemrefView>(2);
-  auto temp = args.get<jitrt::FlatMemrefView>(3);
+  auto a = args.get<runtime::StridedMemrefView>(0);
+  auto b = args.get<runtime::StridedMemrefView>(1);
+  auto result = args.get<runtime::StridedMemrefView>(2);
+  auto temp = args.get<runtime::FlatMemrefView>(3);
   if (failed(a) || failed(b) || failed(result) || failed(temp))
     return MakeStringError("Incorrect argument types");
 
@@ -1400,10 +1411,10 @@
 
 Error TriangularSolve::operator()(
     const ServiceExecutableRunOptions* run_options,
-    const DebugOptions* debug_options, jitrt::StridedMemrefView a,
-    jitrt::StridedMemrefView b, jitrt::StridedMemrefView result,
-    jitrt::FlatMemrefView temp, bool left_side, bool lower, bool unit_diagonal,
-    TriangularSolveOptions::Transpose transpose_a) const {
+    const DebugOptions* debug_options, runtime::StridedMemrefView a,
+    runtime::StridedMemrefView b, runtime::StridedMemrefView result,
+    runtime::FlatMemrefView temp, bool left_side, bool lower,
+    bool unit_diagonal, TriangularSolveOptions::Transpose transpose_a) const {
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
   se::Stream* stream = run_options->stream();
 
@@ -1510,11 +1521,11 @@
   for (unsigned i = 0; i < args.size(); ++i) {
     // We use zero-sized memrefs to represent holes in custom calls with target
     // arguments mapping (see `CustomCallTargetArgMapping`).
-    if (auto memref = args.get<jitrt::FlatMemrefView>(i); succeeded(memref)) {
+    if (auto memref = args.get<runtime::FlatMemrefView>(i); succeeded(memref)) {
       buffers.push_back(memref->size_in_bytes == 0 ? nullptr : memref->data);
       continue;
     }
-    if (auto strided = args.get<jitrt::StridedMemrefView>(i);
+    if (auto strided = args.get<runtime::StridedMemrefView>(i);
         succeeded(strided)) {
       int64_t size_in_bytes = GetHostSize(strided->dtype);
       for (int64_t size : strided->sizes) size_in_bytes *= size;
@@ -1561,7 +1572,7 @@
   static auto* handler = CustomCall::Bind("xla.gpu.memcpy")
                              .UserData<const ServiceExecutableRunOptions*>()
                              .UserData<const DebugOptions*>()
-                             .Arg<jitrt::CustomCall::RemainingArgs>()  // args
+                             .Arg<CustomCall::RemainingArgs>()  // args
                              .Attr<StringRef>("call_target_name")
                              .Attr<int32_t>("api_version")
                              .Attr<StringRef>("backend_config")
@@ -2072,13 +2083,13 @@
 struct ReplicaId {
   LLVM_ATTRIBUTE_ALWAYS_INLINE
   Error operator()(const ServiceExecutableRunOptions* run_options,
-                   jitrt::FlatMemrefView result) const;
+                   runtime::FlatMemrefView result) const;
   static ReplicaId Handler() { return ReplicaId(); }
 };
 }  // namespace
 
 Error ReplicaId::operator()(const ServiceExecutableRunOptions* run_options,
-                            jitrt::FlatMemrefView result) const {
+                            runtime::FlatMemrefView result) const {
   VLOG(3) << "Running ReplicaId";
   se::Stream* stream = run_options->stream();
   NcclExecuteParams params(*run_options, stream);
@@ -2100,7 +2111,7 @@
 static bool ReplicaId(runtime::KernelContext* ctx, void** args, void** attrs) {
   static auto* handler = CustomCall::Bind("xla.gpu.replica_id")
                              .UserData<const ServiceExecutableRunOptions*>()
-                             .Arg<jitrt::FlatMemrefView>()  // result
+                             .Arg<runtime::FlatMemrefView>()  // result
                              .To<RuntimeChecks()>(ReplicaId::Handler())
                              .release();
 
@@ -2113,13 +2124,13 @@
 struct PartitionId {
   LLVM_ATTRIBUTE_ALWAYS_INLINE
   Error operator()(const ServiceExecutableRunOptions* run_options,
-                   jitrt::FlatMemrefView result) const;
+                   runtime::FlatMemrefView result) const;
   static PartitionId Handler() { return PartitionId(); }
 };
 }  // namespace
 
 Error PartitionId::operator()(const ServiceExecutableRunOptions* run_options,
-                              jitrt::FlatMemrefView result) const {
+                              runtime::FlatMemrefView result) const {
   VLOG(3) << "Running PartitionId";
   se::Stream* stream = run_options->stream();
   NcclExecuteParams params(*run_options, stream);
@@ -2142,7 +2153,7 @@
                         void** attrs) {
   static auto* handler = CustomCall::Bind("xla.gpu.partition_id")
                              .UserData<const ServiceExecutableRunOptions*>()
-                             .Arg<jitrt::FlatMemrefView>()  // result
+                             .Arg<runtime::FlatMemrefView>()  // result
                              .To<RuntimeChecks()>(PartitionId::Handler())
                              .release();