[xla:jitrt] Basic support for gpu AOT compilation
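
This change teaches the GPU backend to compile XLA modules ahead of time with
the JitRt stack and to reload the result later without recompilation:
  * GpuCompiler::CompileAheadOfTime (CUDA only, requires a StreamExecutor)
    lowers each HLO module through the default JitRt compilation pipeline and
    wraps the compiled object file, the serialized MLIR module, and the entry
    function attributes into a JitRtExecutableProto via
    JitRtAotCompilationResult.
  * GpuCompiler::LoadAotCompilationResult and
    JitRtAotCompilationResult::LoadExecutable restore a runnable GpuExecutable
    from that proto through the new GpuExecutable::LoadFromObjFile.
  * gpu_aot_compilation_test is re-enabled (drops the "broken"/"notap" tags).

A rough sketch of the intended round trip (the `compiler`, `executor`,
`module_group`, and `aot_options` variables are illustrative and assumed to be
set up elsewhere; error handling omitted):

  TF_ASSIGN_OR_RETURN(auto aot_results,
                      compiler->CompileAheadOfTime(std::move(module_group),
                                                   aot_options));
  TF_ASSIGN_OR_RETURN(std::string serialized,
                      aot_results[0]->SerializeAsString());
  // ... persist `serialized`, then later ...
  TF_ASSIGN_OR_RETURN(auto aot_result,
                      compiler->LoadAotCompilationResult(serialized));
  TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
                      aot_result->LoadExecutable(compiler, executor));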

PiperOrigin-RevId: 467097310
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 5d7332c..9444bba 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -675,6 +675,7 @@
         "@com_google_absl//absl/synchronization",
         "@llvm-project//mlir:FuncDialect",
         "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Parser",
         "//tensorflow/compiler/xla/mlir_hlo:lhlo_gpu",
         "//tensorflow/compiler/xla/service:hlo_execution_profile",
         "//tensorflow/compiler/xla:array2d",
@@ -743,6 +744,7 @@
             # copybara:uncomment "@tf_runtime//backends/jitrt",
             # copybara:uncomment "@tf_runtime//backends/jitrt:diagnostics",
             # copybara:uncomment "@tf_runtime//backends/jitrt:jitrt_compiler",
+            # copybara:uncomment "@tf_runtime//:init_tfrt_dialects",
         ],
         "//conditions:default": [],
     }),
@@ -1730,7 +1732,10 @@
         ":runtime_intrinsics",
     ] + select({
         ":is_xlir_enabled": [
+            ":jitrt_custom_calls",
             "//tensorflow/compiler/mlir/tfrt/transforms/lmhlo_to_gpu:pass_utils",
+            "@tf_runtime//backends/jitrt",
+            "@tf_runtime//backends/jitrt:jitrt_compiler",
         ],
         "//conditions:default": [],
     }),
@@ -1852,12 +1857,10 @@
         "XLA_FLAGS": "--xla_gpu_jitrt_executable",
     },
     tags = [
-        "broken",  # b/239679478: Switching AOT support from BEF to JitRt.
         "gpu",
         "no_oss",
         "no_rocm",
         "nomsan",  # Pulls in precompiled NVIDIA libraries which cause false positives in msan.
-        "notap",  # b/239679478: Switching AOT support from BEF to JitRt.
         "requires-gpu-nvidia",
     ],
     deps = [
@@ -2598,6 +2601,7 @@
         "//tensorflow/compiler/xla/python:xla_client_test_gpu",
         "//tensorflow/compiler/xla/service/gpu:cudnn_fused_conv_rewriter_test",
         "//tensorflow/compiler/xla/service/gpu:custom_call_test",
+        "//tensorflow/compiler/xla/service/gpu:gpu_aot_compilation_test",
         "//tensorflow/compiler/xla/service/gpu/tests:add_preds.hlo.test",
         "//tensorflow/compiler/xla/service/gpu/tests:all_reduce.hlo.test",
         "//tensorflow/compiler/xla/service/gpu/tests:concat.hlo.test",
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
index 7d4dbb9..f88cb81 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
@@ -188,6 +188,10 @@
 
 #if XLA_ENABLE_XLIR
 #include "tensorflow/compiler/mlir/tfrt/transforms/lmhlo_to_gpu/pass_utils.h"
+#include "tensorflow/compiler/xla/service/gpu/jitrt_custom_calls.h"
+#include "tfrt/jitrt/jitrt.h"  // from @tf_runtime
+#include "tfrt/jitrt/jitrt_compiler.h"  // from @tf_runtime
+namespace jitrt = ::tfrt::jitrt;
 #endif  // XLA_ENABLE_XLIR
 
 namespace xla {
@@ -288,6 +292,24 @@
 using OwnedThunkSchedule = GpuExecutable::OwnedThunkSchedule;
 using OwnedJitRtProgram = GpuExecutable::OwnedJitRtProgram;
 
+StatusOr<std::unique_ptr<Executable>> JitRtAotCompilationResult::LoadExecutable(
+    Compiler* compiler, se::StreamExecutor* executor) const {
+  TF_ASSIGN_OR_RETURN(
+      HloModuleConfig hlo_module_config,
+      HloModule::CreateModuleConfigFromProto(
+          jitrt_executable_.hlo_module_proto(), GetDebugOptionsFromFlags()));
+  TF_ASSIGN_OR_RETURN(
+      std::unique_ptr<HloModule> hlo_module,
+      HloModule::CreateFromProto(jitrt_executable_.hlo_module_proto(),
+                                 hlo_module_config));
+  auto gpu_compiler = tensorflow::down_cast<GpuCompiler*>(compiler);
+  return GpuExecutable::LoadFromObjFile(
+      std::move(hlo_module), jitrt_executable_.obj_file(),
+      jitrt_executable_.mlir_module(), jitrt_executable_.entry_func_attrs(),
+      GetDebugOptionsFromFlags(), gpu_compiler->GetGpuVersion(executor),
+      executor);
+}
+
 GpuCompiler::GpuCompiler(se::Platform::Id platform_id,
                          const char* target_triple, const char* data_layout)
     : platform_id_(platform_id),
@@ -1378,7 +1400,77 @@
 StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
 GpuCompiler::CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
                                 const AotCompilationOptions& options) {
+#if XLA_ENABLE_XLIR
+  CHECK(options.PlatformId() == se::cuda::kCudaPlatformId);
+  CHECK(options.executor() != nullptr);
+  auto stream_exec = options.executor();
+
+  std::vector<std::unique_ptr<HloModule>> modules =
+      module_group->ConsumeModules();
+  std::vector<std::unique_ptr<AotCompilationResult>> results;
+
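+  // Compile each module in the group to a JitRt executable and wrap the
+  // resulting object file into a serializable AotCompilationResult.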
+  for (const auto& module : modules) {
+    llvm::LLVMContext llvm_context;
+    GpuDeviceInfo gpu_device_info = GetGpuDeviceInfo(stream_exec);
+
+    // Compile the module to LLVM IR and (with JitRt enabled) a JitRt program.
+    CompileModuleResults compile_module_results;
+    TF_RETURN_IF_ERROR(CompileModuleToLlvmIrImpl(
+        module.get(), &llvm_context, target_triple_, data_layout_,
+        stream_exec->platform()->Name(), stream_exec->platform()->id(),
+        gpu_device_info,
+        stream_exec->GetDeviceDescription().cuda_compute_capability(),
+        stream_exec->GetDeviceDescription().rocm_compute_capability(),
+        GetCanShareBuffer(), pointer_size_, &compile_module_results));
+    auto& compiled_executable = compile_module_results.executable;
+
+    if (!std::holds_alternative<OwnedJitRtProgram>(compiled_executable)) {
+      return InternalError("JitRtProgram not provided");
+    }
+
+    const auto& program = std::get<OwnedJitRtProgram>(compiled_executable);
+
+    // Options for the default JitRt compilation pipeline.
+    jitrt::CompilationPipelineOptions copts;
+    copts.num_worker_threads = 1;
+
+    // Options for constructing JitRt JitExecutable.
+    jitrt::CompilationOptions opts;
+    opts.specialization = jitrt::CompilationOptions::Specialization::kDisabled;
+    opts.register_dialects = jitrt::RegisterDefaultJitRtDialects;
+
+    // Register JitRt Gpu runtime custom calls with the linker.
+    opts.runtime_symbol_map = GetSymbolsBinding(JitRtGpuCustomCalls());
+
+    opts.create_compilation_pipeline = [copts](mlir::PassManager& pm) {
+      jitrt::CreateDefaultJitRtCompilationPipeline(pm, copts);
+    };
+
+    // Instantiate new JitExecutable from the MLIR source.
+    auto jit_executable = jitrt::JitExecutable::Instantiate(
+        program->module, program->entry_point, opts);
+    if (auto err = jit_executable.takeError())
+      return InternalError("Failed to compile JitRt program: %s",
+                           tfrt::StrCat(err));
+
+    // For static shapes we can always serialize only the default executable.
+    jitrt::Executable& executable = jit_executable->DefaultExecutable().get();
+
+    // Check if JitRt executable saved the compilation result.
+    std::unique_ptr<llvm::MemoryBuffer> obj_file = executable.obj_file();
+    if (!obj_file)
+      return InternalError("JitRt executable didn't save the obj file");
+
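+    // Bundle the object file with the HLO proto, the serialized MLIR module,
+    // and the entry function attributes so the executable can be reloaded
+    // later without recompilation.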
+    std::string data(obj_file->getBuffer().data(),
+                     obj_file->getBuffer().size());
+    results.emplace_back(std::make_unique<xla::gpu::JitRtAotCompilationResult>(
+        module->ToProto(), data, program->module,
+        compile_module_results.entry_func_attrs));
+  }
+  return std::move(results);
+#else
   return Unimplemented("");
+#endif  // XLA_ENABLE_XLIR
 }
 
 HloCostAnalysis::ShapeSizeFunction GpuCompiler::ShapeSizeBytesFunction() const {
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.h b/tensorflow/compiler/xla/service/gpu/gpu_compiler.h
index 020a674..8d25f5f 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.h
@@ -31,12 +31,49 @@
 #include "tensorflow/compiler/xla/service/llvm_compiler.h"
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
 #include "tensorflow/stream_executor/stream_executor_pimpl.h"
 
 namespace xla {
 namespace gpu {
 
+// TODO(b/232263665): It should be shared between GPU and CPU.
+class JitRtAotCompilationResult : public AotCompilationResult {
+ public:
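+  // Recreates a compilation result from the string produced by
+  // SerializeAsString().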
+  static StatusOr<std::unique_ptr<JitRtAotCompilationResult>> FromString(
+      const std::string& serialized) {
+    JitRtExecutableProto jitrt_executable;
+    if (!jitrt_executable.ParseFromString(serialized)) {
+      return InternalError("Failed to parse serialized JitRtExecutableProto.");
+    }
+    return std::unique_ptr<JitRtAotCompilationResult>(
+        new JitRtAotCompilationResult(std::move(jitrt_executable)));
+  }
+
+  JitRtAotCompilationResult(HloModuleProto hlo, const std::string& obj_file,
+                            const std::string& mlir_module,
+                            EntryFunctionAttributes entry_func_attrs) {
+    *jitrt_executable_.mutable_hlo_module_proto() = hlo;
+    *jitrt_executable_.mutable_entry_func_attrs() = entry_func_attrs;
+    jitrt_executable_.set_obj_file(obj_file);
+    jitrt_executable_.set_mlir_module(mlir_module);
+  }
+
+  StatusOr<std::string> SerializeAsString() const override {
+    return jitrt_executable_.SerializeAsString();
+  }
+
+  StatusOr<std::unique_ptr<Executable>> LoadExecutable(
+      Compiler* compiler, se::StreamExecutor* executor) const override;
+
+ private:
+  explicit JitRtAotCompilationResult(JitRtExecutableProto jitrt_executable)
+      : jitrt_executable_(std::move(jitrt_executable)) {}
+
+  JitRtExecutableProto jitrt_executable_;
+};
+
 // The GPU compiler generates efficient GPU executables.
 class GpuCompiler : public LLVMCompiler {
  public:
@@ -73,6 +110,13 @@
 
   HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const override;
 
+  // Returns a (deserialized) AotCompilationResult from a serialized
+  // AotCompilationResult.
+  StatusOr<std::unique_ptr<AotCompilationResult>> LoadAotCompilationResult(
+      const std::string& serialized_aot_result) override {
+    return JitRtAotCompilationResult::FromString(serialized_aot_result);
+  }
+
  protected:
   virtual Status OptimizeHloPostLayoutAssignment(
       HloModule* hlo_module, se::StreamExecutor* stream_exec,
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
index eea135a..e41e11e 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
@@ -26,6 +26,7 @@
 #include "absl/cleanup/cleanup.h"
 #include "absl/container/flat_hash_map.h"
 #include "absl/synchronization/mutex.h"
+#include "mlir/Parser/Parser.h"  // from @llvm-project
 #include "tensorflow/compiler/xla/map_util.h"
 #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
 #include "tensorflow/compiler/xla/service/gpu/gpu_constants.h"
@@ -56,6 +57,7 @@
 #include "tfrt/jitrt/diagnostics.h"  // from @tf_runtime
 #include "tfrt/jitrt/jitrt.h"  // from @tf_runtime
 #include "tfrt/jitrt/jitrt_compiler.h"  // from @tf_runtime
+#include "tfrt/init_tfrt_dialects.h"  // from @tf_runtime
 #endif  // XLA_ENABLE_XLIR
 
 namespace xla {
@@ -135,12 +137,23 @@
                            tfrt::StrCat(err));
 
     // Pass ownership to the GpuExecutable.
-    return new JitRtExecutable(std::move(program->buffer_sizes),
-                               std::move(*jit_executable),
-                               std::move(program->debug_options));
+    return new JitRtExecutable(
+        std::move(program->buffer_sizes),
+        std::make_unique<jitrt::JitExecutable>(std::move(*jit_executable)),
+        std::move(program->debug_options));
   }
 
-  jitrt::JitExecutable& jit_executable() { return jit_executable_; }
+  // Create JitRtExecutable from the AOT compiled binary.
+  static StatusOr<JitRtExecutable*> Create(
+      absl::Span<const int64_t> buffer_sizes, jitrt::Executable executable,
+      DebugOptions debug_options) {
+    // Pass ownership to the GpuExecutable.
+    return new JitRtExecutable(
+        std::vector<int64_t>(buffer_sizes.begin(), buffer_sizes.end()),
+        std::make_unique<jitrt::Executable>(std::move(executable)),
+        std::move(debug_options));
+  }
+
   jitrt::Executable& default_executable() { return *default_executable_; }
   JitRtKernelsCache& kernels_cache() { return kernels_cache_; }
   JitRtGemmConfigCache& gemm_configs_cache() { return gemm_configs_cache_; }
@@ -155,17 +168,30 @@
   const DebugOptions& debug_options() const { return debug_options_; }
 
  private:
-  explicit JitRtExecutable(std::vector<int64_t> buffer_sizes,
-                           jitrt::JitExecutable jit_executable,
-                           DebugOptions debug_options)
+  JitRtExecutable(std::vector<int64_t> buffer_sizes,
+                  std::unique_ptr<jitrt::JitExecutable> jit_executable,
+                  DebugOptions debug_options)
       : buffer_sizes_(std::move(buffer_sizes)),
         jit_executable_(std::move(jit_executable)),
-        default_executable_(&jit_executable_.DefaultExecutable().get()),
+        default_executable_(&jit_executable_->DefaultExecutable().get()),
+        debug_options_(std::move(debug_options)) {}
+
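+  // Constructs a JitRtExecutable around an AOT-compiled jitrt::Executable
+  // loaded from an object file (there is no JitExecutable in this mode).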
+  JitRtExecutable(std::vector<int64_t> buffer_sizes,
+                  std::unique_ptr<jitrt::Executable> aot_executable,
+                  DebugOptions debug_options)
+      : buffer_sizes_(std::move(buffer_sizes)),
+        default_executable_(aot_executable.get()),
+        aot_executable_(std::move(aot_executable)),
         debug_options_(std::move(debug_options)) {}
 
   std::vector<int64_t> buffer_sizes_;
-  jitrt::JitExecutable jit_executable_;
+
+  std::unique_ptr<jitrt::JitExecutable> jit_executable_;
   jitrt::Executable* default_executable_;  // owned by `jit_executable`
+
+  // In AOT compilation mode we directly own the loaded executable, and
+  // `default_executable_` points to it.
+  std::unique_ptr<jitrt::Executable> aot_executable_;
+
   DebugOptions debug_options_;
 
   // Keep a cache of kernels instantiated by this executable.
@@ -1002,5 +1028,113 @@
   return output;
 }
 
+GpuExecutable::GpuExecutable(
+    std::shared_ptr<HloModule> hlo_module, GpuVersion gpu_version,
+    xla::EntryFunctionAttributes entry_func_attrs,
+    absl::string_view module_name, Shape xla_output_shape,
+    std::vector<BufferAllocation> allocations,
+    absl::flat_hash_map<ShapeIndex, OutputInfo> output_info,
+    JitRtExecutable* jitrt_executable)
+    : Executable(std::move(hlo_module)),
+      gpu_version_(gpu_version),
+      entry_func_attrs_(entry_func_attrs),
+      module_name_(module_name),
+      output_shape_(xla_output_shape),
+      allocations_(std::move(allocations)),
+      output_info_(std::move(output_info)),
+      jitrt_executable_(jitrt_executable) {
+  XlaDebugInfoManager::Get()->RegisterModule(
+      module().unique_id(), shared_module(), debug_buffer_assignment_);
+}
+
+StatusOr<std::unique_ptr<Executable>> GpuExecutable::LoadFromObjFile(
+    std::shared_ptr<HloModule> hlo_module, absl::string_view obj_file,
+    absl::string_view mlir_module,
+    xla::EntryFunctionAttributes entry_func_attrs, DebugOptions debug_options,
+    GpuVersion gpu_version, se::StreamExecutor* executor) {
+#if XLA_ENABLE_XLIR
+  // Parse the MLIR module that accompanies the compiled object file to recover
+  // XLA allocations and output info details, and recover buffer sizes from the
+  // entrypoint function signature.
+  mlir::MLIRContext context;
+
+  mlir::DialectRegistry registry;
+  tfrt::RegisterTFRTDialects(registry);
+  tfrt::RegisterTFRTCompiledDialects(registry);
+  context.appendDialectRegistry(registry);
+
+  auto module = mlir::parseSourceString<mlir::ModuleOp>(mlir_module, &context);
+  if (!module) return InternalError("Failed to parse AOT compiled module");
+
+  // Get the XLA module entrypoint function.
+  auto func = mlir::cast<mlir::func::FuncOp>(
+      module->lookupSymbol(hlo_module->entry_computation()->name()));
+
+  // Get the buffer sizes from the entrypoint function signature.
+  std::vector<int64_t> buffer_sizes;
+  buffer_sizes.reserve(func.getNumArguments());
+  for (auto type : func.getArgumentTypes()) {
+    auto memref = type.dyn_cast<mlir::MemRefType>();
+    if (!memref || !memref.hasStaticShape() || memref.getRank() != 1)
+      return InternalError("Illegal entrypoint argument type: %s",
+                           tfrt::StrCat(type));
+    buffer_sizes.push_back(memref.getDimSize(0));
+  }
+
+  // Infer XLA allocations and output info from the MLIR module.
+  std::vector<BufferAllocation> allocations;
+  absl::flat_hash_map<ShapeIndex, OutputInfo> output_info;
+  Shape result_xla_shape;
+  TF_RETURN_IF_ERROR(SetUpMlirAllocation(func, buffer_sizes, &allocations,
+                                         &output_info, &result_xla_shape,
+                                         /*buffer_param_offset=*/0));
+
+  // Create a named buffer from compiled object file.
+  llvm::StringRef data(obj_file.data(), obj_file.size());
+  auto buffer = llvm::MemoryBuffer::getMemBuffer(data, hlo_module->name());
+
+  // Create a JitRt function signature (all arguments passed as 1d memrefs).
+  llvm::SmallVector<std::unique_ptr<jitrt::Type>> args;
+  llvm::SmallVector<std::unique_ptr<jitrt::Type>> rt_args;
+  rt_args.push_back(std::make_unique<jitrt::KernelContextOperandType>());
+
+  for (int64_t size : buffer_sizes) {
+    auto i8 = tfrt::DType::I8;
+    args.push_back(std::make_unique<jitrt::MemrefType>(size, i8));
+    rt_args.push_back(std::make_unique<jitrt::MemrefType>(size, i8));
+  }
+
+  jitrt::FunctionType signature(std::move(args), /*results=*/{});
+  jitrt::FunctionType rt_signature(std::move(rt_args), /*results=*/{});
+
+  auto symbol_map = GetSymbolsBinding(JitRtGpuCustomCalls());
+
+  // Load JitRt executable from an object file, and link it with Gpu runtime
+  // intrinsics implementing Gpu custom calls.
+  auto executable = jitrt::Executable::LoadFromObjFile(
+      hlo_module->name(), std::move(buffer),
+      hlo_module->entry_computation()->name(), std::move(signature),
+      std::move(rt_signature), symbol_map);
+  if (auto err = executable.takeError())
+    return InternalError("Failed to load JitRt executable: %s",
+                         tfrt::StrCat(err));
+
+  // Move jitrt::Executable ownership to the JitRtExecutable.
+  TF_ASSIGN_OR_RETURN(
+      JitRtExecutable * jitrt_executable,
+      JitRtExecutable::Create(buffer_sizes, std::move(*executable),
+                              std::move(debug_options)));
+
+  // Construct GpuExecutable for the loaded JitRt executable.
+  std::string name = hlo_module->name();
+  return std::unique_ptr<Executable>(
+      new GpuExecutable(std::move(hlo_module), gpu_version, entry_func_attrs,
+                        name, result_xla_shape, std::move(allocations),
+                        std::move(output_info), jitrt_executable));
+
+#else   // XLA_ENABLE_XLIR
+  return FailedPrecondition("Not built with XLA_ENABLE_XLIR");
+#endif  // XLA_ENABLE_XLIR
+}
 }  // namespace gpu
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.h b/tensorflow/compiler/xla/service/gpu/gpu_executable.h
index 818cd7f..35a1adc 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_executable.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.h
@@ -141,6 +141,24 @@
       absl::flat_hash_map<ShapeIndex, OutputInfo>* output_info,
       Shape* output_shape, int buffer_param_offset = 0);
 
+  // Returns an Executable that is loaded from an object file (XLA program
+  // compiled to a native function using the JitRt stack).
+  static StatusOr<std::unique_ptr<Executable>> LoadFromObjFile(
+      std::shared_ptr<HloModule> hlo_module, absl::string_view obj_file,
+      absl::string_view mlir_module,
+      xla::EntryFunctionAttributes entry_func_attrs, DebugOptions debug_options,
+      GpuVersion gpu_version, stream_executor::StreamExecutor* executor);
+
+  // Constructor to use when loading a GpuExecutable from an object file (native
+  // function compiled for JitRt). Omits setting class members that aren't used
+  // in JitRt execution mode.
+  GpuExecutable(std::shared_ptr<HloModule> hlo_module, GpuVersion gpu_version,
+                xla::EntryFunctionAttributes entry_func_attrs,
+                absl::string_view module_name, Shape xla_output_shape,
+                std::vector<BufferAllocation> allocations,
+                absl::flat_hash_map<ShapeIndex, OutputInfo> output_info,
+                JitRtExecutable* jitrt_executable);
+
   static StatusOr<std::unique_ptr<GpuExecutable>> Create(Params params);
   ~GpuExecutable() override;
 
diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto
index d9926e2..d0f279b 100644
--- a/tensorflow/compiler/xla/service/hlo.proto
+++ b/tensorflow/compiler/xla/service/hlo.proto
@@ -748,3 +748,25 @@
   // xla::Shape in string format.
   string result_xla_shape = 2;
 }
+
+// Encodes the underlying JitRt executable compiled from the XLA module.
+message JitRtExecutableProto {
+  HloModuleProto hlo_module_proto = 1;
+
+  // XLA-specific attributes of the executable's entry function.
+  EntryFunctionAttributes entry_func_attrs = 2;
+
+  // TODO(b/232263665): We need to know the TargetMachine this executable was
+  // compiled for, otherwise we can accidentally use illegal instructions (e.g.
+  // use AVX512 when it's not available).
+
+  // TODO(b/232263665): Serialized executable has to know what APIs it has to
+  // be linked with, including the version. For example a Gpu executable must
+  // be linked with a runtime layer that abstracts over CUDA.
+
+  // Serialized object file compiled from the XLA module.
+  bytes obj_file = 3;
+
+  // Serialized MLIR module corresponding to the compiled object file.
+  string mlir_module = 4;
+}
\ No newline at end of file