[xla:jitrt] Basic support for gpu AOT compilation
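
Adds a JitRt-based AOT path to the GPU compiler: GpuCompiler::CompileAheadOfTime
now produces serializable JitRtAotCompilationResults (backed by the new
JitRtExecutableProto), and GpuExecutable::LoadFromObjFile reconstructs an
executable from the saved object file and MLIR module.

A minimal round-trip sketch using the APIs touched by this change, assuming a
Compiler* `compiler`, a CUDA se::StreamExecutor* `stream_exec`, an
HloModuleGroup `module_group`, and AotCompilationOptions `aot_options` with the
executor already set (the surrounding setup and names are illustrative, not
part of this change):

  // Sketch only: runs inside a function returning Status/StatusOr.
  // `aot_options` must target the CUDA platform and carry a StreamExecutor.
  TF_ASSIGN_OR_RETURN(auto aot_results,
                      compiler->CompileAheadOfTime(std::move(module_group),
                                                   aot_options));

  // Serialize the first result; the wire format is JitRtExecutableProto.
  TF_ASSIGN_OR_RETURN(std::string serialized,
                      aot_results[0]->SerializeAsString());

  // Later (possibly in a different process): deserialize and load a
  // GpuExecutable backed by the AOT-compiled object file.
  TF_ASSIGN_OR_RETURN(auto aot_result,
                      compiler->LoadAotCompilationResult(serialized));
  TF_ASSIGN_OR_RETURN(auto executable,
                      aot_result->LoadExecutable(compiler, stream_exec));
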
PiperOrigin-RevId: 467097310
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 5d7332c..9444bba 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -675,6 +675,7 @@
"@com_google_absl//absl/synchronization",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
+ "@llvm-project//mlir:Parser",
"//tensorflow/compiler/xla/mlir_hlo:lhlo_gpu",
"//tensorflow/compiler/xla/service:hlo_execution_profile",
"//tensorflow/compiler/xla:array2d",
@@ -743,6 +744,7 @@
# copybara:uncomment "@tf_runtime//backends/jitrt",
# copybara:uncomment "@tf_runtime//backends/jitrt:diagnostics",
# copybara:uncomment "@tf_runtime//backends/jitrt:jitrt_compiler",
+ # copybara:uncomment "@tf_runtime//:init_tfrt_dialects",
],
"//conditions:default": [],
}),
@@ -1730,7 +1732,10 @@
":runtime_intrinsics",
] + select({
":is_xlir_enabled": [
+ ":jitrt_custom_calls",
"//tensorflow/compiler/mlir/tfrt/transforms/lmhlo_to_gpu:pass_utils",
+ "@tf_runtime//backends/jitrt",
+ "@tf_runtime//backends/jitrt:jitrt_compiler",
],
"//conditions:default": [],
}),
@@ -1852,12 +1857,10 @@
"XLA_FLAGS": "--xla_gpu_jitrt_executable",
},
tags = [
- "broken", # b/239679478: Switching AOT support from BEF to JitRt.
"gpu",
"no_oss",
"no_rocm",
"nomsan", # Pulls in precompiled NVIDIA libraries which cause false positives in msan.
- "notap", # b/239679478: Switching AOT support from BEF to JitRt.
"requires-gpu-nvidia",
],
deps = [
@@ -2598,6 +2601,7 @@
"//tensorflow/compiler/xla/python:xla_client_test_gpu",
"//tensorflow/compiler/xla/service/gpu:cudnn_fused_conv_rewriter_test",
"//tensorflow/compiler/xla/service/gpu:custom_call_test",
+ "//tensorflow/compiler/xla/service/gpu:gpu_aot_compilation_test",
"//tensorflow/compiler/xla/service/gpu/tests:add_preds.hlo.test",
"//tensorflow/compiler/xla/service/gpu/tests:all_reduce.hlo.test",
"//tensorflow/compiler/xla/service/gpu/tests:concat.hlo.test",
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
index 7d4dbb9..f88cb81 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
@@ -188,6 +188,10 @@
#if XLA_ENABLE_XLIR
#include "tensorflow/compiler/mlir/tfrt/transforms/lmhlo_to_gpu/pass_utils.h"
+#include "tensorflow/compiler/xla/service/gpu/jitrt_custom_calls.h"
+#include "tfrt/jitrt/jitrt.h" // from @tf_runtime
+#include "tfrt/jitrt/jitrt_compiler.h" // from @tf_runtime
+namespace jitrt = ::tfrt::jitrt;
#endif // XLA_ENABLE_XLIR
namespace xla {
@@ -288,6 +292,24 @@
using OwnedThunkSchedule = GpuExecutable::OwnedThunkSchedule;
using OwnedJitRtProgram = GpuExecutable::OwnedJitRtProgram;
+StatusOr<std::unique_ptr<Executable>> JitRtAotCompilationResult::LoadExecutable(
+ Compiler* compiler, se::StreamExecutor* executor) const {
+ TF_ASSIGN_OR_RETURN(
+ HloModuleConfig hlo_module_config,
+ HloModule::CreateModuleConfigFromProto(
+ jitrt_executable_.hlo_module_proto(), GetDebugOptionsFromFlags()));
+ TF_ASSIGN_OR_RETURN(
+ std::unique_ptr<HloModule> hlo_module,
+ HloModule::CreateFromProto(jitrt_executable_.hlo_module_proto(),
+ hlo_module_config));
+ auto gpu_compiler = tensorflow::down_cast<GpuCompiler*>(compiler);
+ return GpuExecutable::LoadFromObjFile(
+ std::move(hlo_module), jitrt_executable_.obj_file(),
+ jitrt_executable_.mlir_module(), jitrt_executable_.entry_func_attrs(),
+ GetDebugOptionsFromFlags(), gpu_compiler->GetGpuVersion(executor),
+ executor);
+}
+
GpuCompiler::GpuCompiler(se::Platform::Id platform_id,
const char* target_triple, const char* data_layout)
: platform_id_(platform_id),
@@ -1378,7 +1400,77 @@
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
GpuCompiler::CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
const AotCompilationOptions& options) {
+#if XLA_ENABLE_XLIR
+ CHECK(options.PlatformId() == se::cuda::kCudaPlatformId);
+ CHECK(options.executor() != nullptr);
+ auto stream_exec = options.executor();
+
+ std::vector<std::unique_ptr<HloModule>> modules =
+ module_group->ConsumeModules();
+ std::vector<std::unique_ptr<AotCompilationResult>> results;
+
+ for (const auto& module : modules) {
+ llvm::LLVMContext llvm_context;
+ GpuDeviceInfo gpu_device_info = GetGpuDeviceInfo(stream_exec);
+
+    // Compile the module; the result must hold a JitRt program.
+ CompileModuleResults compile_module_results;
+ TF_RETURN_IF_ERROR(CompileModuleToLlvmIrImpl(
+ module.get(), &llvm_context, target_triple_, data_layout_,
+ stream_exec->platform()->Name(), stream_exec->platform()->id(),
+ gpu_device_info,
+ stream_exec->GetDeviceDescription().cuda_compute_capability(),
+ stream_exec->GetDeviceDescription().rocm_compute_capability(),
+ GetCanShareBuffer(), pointer_size_, &compile_module_results));
+ auto& compiled_executable = compile_module_results.executable;
+
+ if (!std::holds_alternative<OwnedJitRtProgram>(compiled_executable)) {
+ return InternalError("JitRtProgram not provided");
+ }
+
+ const auto& program = std::get<OwnedJitRtProgram>(compiled_executable);
+
+ // Options for the default JitRt compilation pipeline.
+ jitrt::CompilationPipelineOptions copts;
+ copts.num_worker_threads = 1;
+
+ // Options for constructing JitRt JitExecutable.
+ jitrt::CompilationOptions opts;
+ opts.specialization = jitrt::CompilationOptions::Specialization::kDisabled;
+ opts.register_dialects = jitrt::RegisterDefaultJitRtDialects;
+
+ // Register JitRt Gpu runtime custom calls with the linker.
+ opts.runtime_symbol_map = GetSymbolsBinding(JitRtGpuCustomCalls());
+
+ opts.create_compilation_pipeline = [copts](mlir::PassManager& pm) {
+ jitrt::CreateDefaultJitRtCompilationPipeline(pm, copts);
+ };
+
+ // Instantiate new JitExecutable from the MLIR source.
+ auto jit_executable = jitrt::JitExecutable::Instantiate(
+ program->module, program->entry_point, opts);
+ if (auto err = jit_executable.takeError())
+ return InternalError("Failed to compile JitRt program: %s",
+ tfrt::StrCat(err));
+
+ // For static shapes we can always serialize only the default executable.
+ jitrt::Executable& executable = jit_executable->DefaultExecutable().get();
+
+ // Check if JitRt executable saved the compilation result.
+ std::unique_ptr<llvm::MemoryBuffer> obj_file = executable.obj_file();
+ if (!obj_file)
+ return InternalError("JitRt executable didn't save the obj file");
+
+ std::string data(obj_file->getBuffer().data(),
+ obj_file->getBuffer().size());
+ results.emplace_back(std::make_unique<xla::gpu::JitRtAotCompilationResult>(
+ module->ToProto(), data, program->module,
+ compile_module_results.entry_func_attrs));
+ }
+ return std::move(results);
+#else
return Unimplemented("");
+#endif // XLA_ENABLE_XLIR
}
HloCostAnalysis::ShapeSizeFunction GpuCompiler::ShapeSizeBytesFunction() const {
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.h b/tensorflow/compiler/xla/service/gpu/gpu_compiler.h
index 020a674..8d25f5f 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.h
@@ -31,12 +31,49 @@
#include "tensorflow/compiler/xla/service/llvm_compiler.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/stream_executor/stream_executor_pimpl.h"
namespace xla {
namespace gpu {
+// TODO(b/232263665): This class should be shared between GPU and CPU.
+class JitRtAotCompilationResult : public AotCompilationResult {
+ public:
+ static StatusOr<std::unique_ptr<JitRtAotCompilationResult>> FromString(
+ const std::string& serialized) {
+ JitRtExecutableProto jitrt_executable;
+ if (!jitrt_executable.ParseFromString(serialized)) {
+ return InternalError("Failed to parse serialized JitRtExecutableProto.");
+ }
+ return std::unique_ptr<JitRtAotCompilationResult>(
+ new JitRtAotCompilationResult(std::move(jitrt_executable)));
+ }
+
+ JitRtAotCompilationResult(HloModuleProto hlo, const std::string& obj_file,
+ const std::string& mlir_module,
+ EntryFunctionAttributes entry_func_attrs) {
+ *jitrt_executable_.mutable_hlo_module_proto() = hlo;
+ *jitrt_executable_.mutable_entry_func_attrs() = entry_func_attrs;
+ jitrt_executable_.set_obj_file(obj_file);
+ jitrt_executable_.set_mlir_module(mlir_module);
+ }
+
+ StatusOr<std::string> SerializeAsString() const override {
+ return jitrt_executable_.SerializeAsString();
+ }
+
+ StatusOr<std::unique_ptr<Executable>> LoadExecutable(
+ Compiler* compiler, se::StreamExecutor* executor) const override;
+
+ private:
+ explicit JitRtAotCompilationResult(JitRtExecutableProto jitrt_executable)
+ : jitrt_executable_(std::move(jitrt_executable)) {}
+
+ JitRtExecutableProto jitrt_executable_;
+};
+
// The GPU compiler generates efficient GPU executables.
class GpuCompiler : public LLVMCompiler {
public:
@@ -73,6 +110,13 @@
HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const override;
+ // Returns a (deserialized) AotCompilationResult from a serialized
+ // AotCompilationResult.
+ StatusOr<std::unique_ptr<AotCompilationResult>> LoadAotCompilationResult(
+ const std::string& serialized_aot_result) override {
+ return JitRtAotCompilationResult::FromString(serialized_aot_result);
+ }
+
protected:
virtual Status OptimizeHloPostLayoutAssignment(
HloModule* hlo_module, se::StreamExecutor* stream_exec,
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
index eea135a..e41e11e 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
@@ -26,6 +26,7 @@
#include "absl/cleanup/cleanup.h"
#include "absl/container/flat_hash_map.h"
#include "absl/synchronization/mutex.h"
+#include "mlir/Parser/Parser.h" // from @llvm-project
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_constants.h"
@@ -56,6 +57,7 @@
#include "tfrt/jitrt/diagnostics.h" // from @tf_runtime
#include "tfrt/jitrt/jitrt.h" // from @tf_runtime
#include "tfrt/jitrt/jitrt_compiler.h" // from @tf_runtime
+#include "tfrt/init_tfrt_dialects.h" // from @tf_runtime
#endif // XLA_ENABLE_XLIR
namespace xla {
@@ -135,12 +137,23 @@
tfrt::StrCat(err));
// Pass ownership to the GpuExecutable.
- return new JitRtExecutable(std::move(program->buffer_sizes),
- std::move(*jit_executable),
- std::move(program->debug_options));
+ return new JitRtExecutable(
+ std::move(program->buffer_sizes),
+ std::make_unique<jitrt::JitExecutable>(std::move(*jit_executable)),
+ std::move(program->debug_options));
}
- jitrt::JitExecutable& jit_executable() { return jit_executable_; }
+ // Create JitRtExecutable from the AOT compiled binary.
+ static StatusOr<JitRtExecutable*> Create(
+ absl::Span<const int64_t> buffer_sizes, jitrt::Executable executable,
+ DebugOptions debug_options) {
+ // Pass ownership to the GpuExecutable.
+ return new JitRtExecutable(
+ std::vector<int64_t>(buffer_sizes.begin(), buffer_sizes.end()),
+ std::make_unique<jitrt::Executable>(std::move(executable)),
+ std::move(debug_options));
+ }
+
jitrt::Executable& default_executable() { return *default_executable_; }
JitRtKernelsCache& kernels_cache() { return kernels_cache_; }
JitRtGemmConfigCache& gemm_configs_cache() { return gemm_configs_cache_; }
@@ -155,17 +168,30 @@
const DebugOptions& debug_options() const { return debug_options_; }
private:
- explicit JitRtExecutable(std::vector<int64_t> buffer_sizes,
- jitrt::JitExecutable jit_executable,
- DebugOptions debug_options)
+ JitRtExecutable(std::vector<int64_t> buffer_sizes,
+ std::unique_ptr<jitrt::JitExecutable> jit_executable,
+ DebugOptions debug_options)
: buffer_sizes_(std::move(buffer_sizes)),
jit_executable_(std::move(jit_executable)),
- default_executable_(&jit_executable_.DefaultExecutable().get()),
+ default_executable_(&jit_executable_->DefaultExecutable().get()),
+ debug_options_(std::move(debug_options)) {}
+
+ JitRtExecutable(std::vector<int64_t> buffer_sizes,
+ std::unique_ptr<jitrt::Executable> aot_executable,
+ DebugOptions debug_options)
+ : buffer_sizes_(std::move(buffer_sizes)),
+ aot_executable_(std::move(aot_executable)),
+        executable_(aot_executable_.get()),
debug_options_(std::move(debug_options)) {}
std::vector<int64_t> buffer_sizes_;
- jitrt::JitExecutable jit_executable_;
+
+ std::unique_ptr<jitrt::JitExecutable> jit_executable_;
jitrt::Executable* default_executable_; // owned by `jit_executable`
+
+ std::unique_ptr<jitrt::Executable> aot_executable_;
+ jitrt::Executable* executable_;
+
DebugOptions debug_options_;
// Keep a cache of kernels instantiated by this executable.
@@ -1002,5 +1028,113 @@
return output;
}
+GpuExecutable::GpuExecutable(
+ std::shared_ptr<HloModule> hlo_module, GpuVersion gpu_version,
+ xla::EntryFunctionAttributes entry_func_attrs,
+ absl::string_view module_name, Shape xla_output_shape,
+ std::vector<BufferAllocation> allocations,
+ absl::flat_hash_map<ShapeIndex, OutputInfo> output_info,
+ JitRtExecutable* jitrt_executable)
+ : Executable(std::move(hlo_module)),
+ gpu_version_(gpu_version),
+ entry_func_attrs_(entry_func_attrs),
+ module_name_(module_name),
+ output_shape_(xla_output_shape),
+ allocations_(std::move(allocations)),
+ output_info_(std::move(output_info)),
+ jitrt_executable_(jitrt_executable) {
+ XlaDebugInfoManager::Get()->RegisterModule(
+ module().unique_id(), shared_module(), debug_buffer_assignment_);
+}
+
+StatusOr<std::unique_ptr<Executable>> GpuExecutable::LoadFromObjFile(
+ std::shared_ptr<HloModule> hlo_module, absl::string_view obj_file,
+ absl::string_view mlir_module,
+ xla::EntryFunctionAttributes entry_func_attrs, DebugOptions debug_options,
+ GpuVersion gpu_version, se::StreamExecutor* executor) {
+#if XLA_ENABLE_XLIR
+  // Load the MLIR module that the object file was compiled from to recover
+  // XLA allocations and output info details. Also recover buffer sizes from
+  // the entrypoint function signature.
+ mlir::MLIRContext context;
+
+ mlir::DialectRegistry registry;
+ tfrt::RegisterTFRTDialects(registry);
+ tfrt::RegisterTFRTCompiledDialects(registry);
+ context.appendDialectRegistry(registry);
+
+ auto module = mlir::parseSourceString<mlir::ModuleOp>(mlir_module, &context);
+ if (!module) return InternalError("Failed to parse AOT compiled module");
+
+ // Get the XLA module entrypoint function.
+ auto func = mlir::cast<mlir::func::FuncOp>(
+ module->lookupSymbol(hlo_module->entry_computation()->name()));
+
+ // Get the buffer sizes from the entrypoint function signature.
+ std::vector<int64_t> buffer_sizes;
+ buffer_sizes.reserve(func.getNumArguments());
+ for (auto type : func.getArgumentTypes()) {
+ auto memref = type.dyn_cast<mlir::MemRefType>();
+ if (!memref || !memref.hasStaticShape() || memref.getRank() != 1)
+ return InternalError("Illegal entrypoint argument type: %s",
+ tfrt::StrCat(type));
+ buffer_sizes.push_back(memref.getDimSize(0));
+ }
+
+ // Infer XLA allocations and output info from the MLIR module.
+ std::vector<BufferAllocation> allocations;
+ absl::flat_hash_map<ShapeIndex, OutputInfo> output_info;
+ Shape result_xla_shape;
+ TF_RETURN_IF_ERROR(SetUpMlirAllocation(func, buffer_sizes, &allocations,
+ &output_info, &result_xla_shape,
+ /*buffer_param_offset=*/0));
+
+ // Create a named buffer from compiled object file.
+ llvm::StringRef data(obj_file.data(), obj_file.size());
+ auto buffer = llvm::MemoryBuffer::getMemBuffer(data, hlo_module->name());
+
+ // Create a JitRt function signature (all arguments passed as 1d memrefs).
+ llvm::SmallVector<std::unique_ptr<jitrt::Type>> args;
+ llvm::SmallVector<std::unique_ptr<jitrt::Type>> rt_args;
+ rt_args.push_back(std::make_unique<jitrt::KernelContextOperandType>());
+
+ for (int64_t size : buffer_sizes) {
+ auto i8 = tfrt::DType::I8;
+ args.push_back(std::make_unique<jitrt::MemrefType>(size, i8));
+ rt_args.push_back(std::make_unique<jitrt::MemrefType>(size, i8));
+ }
+
+ jitrt::FunctionType signature(std::move(args), /*results=*/{});
+ jitrt::FunctionType rt_signature(std::move(rt_args), /*results=*/{});
+
+ auto symbol_map = GetSymbolsBinding(JitRtGpuCustomCalls());
+
+ // Load JitRt executable from an object file, and link it with Gpu runtime
+ // intrinsics implementing Gpu custom calls.
+ auto executable = jitrt::Executable::LoadFromObjFile(
+ hlo_module->name(), std::move(buffer),
+ hlo_module->entry_computation()->name(), std::move(signature),
+ std::move(rt_signature), symbol_map);
+ if (auto err = executable.takeError())
+ return InternalError("Failed to load JitRt executable: %s",
+ tfrt::StrCat(err));
+
+ // Move jitrt::Executable ownership to the JitRtExecutable.
+ TF_ASSIGN_OR_RETURN(
+ JitRtExecutable * jitrt_executable,
+ JitRtExecutable::Create(buffer_sizes, std::move(*executable),
+ std::move(debug_options)));
+
+ // Construct GpuExecutable for the loaded JitRt executable.
+ std::string name = hlo_module->name();
+ return std::unique_ptr<Executable>(
+ new GpuExecutable(std::move(hlo_module), gpu_version, entry_func_attrs,
+ name, result_xla_shape, std::move(allocations),
+ std::move(output_info), jitrt_executable));
+
+#else // XLA_ENABLE_XLIR
+ return FailedPrecondition("Not built with XLA_ENABLE_XLIR");
+#endif // XLA_ENABLE_XLIR
+}
} // namespace gpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.h b/tensorflow/compiler/xla/service/gpu/gpu_executable.h
index 818cd7f..35a1adc 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_executable.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.h
@@ -141,6 +141,24 @@
absl::flat_hash_map<ShapeIndex, OutputInfo>* output_info,
Shape* output_shape, int buffer_param_offset = 0);
+ // Returns an Executable that is loaded from an object file (XLA program
+ // compiled to a native function using the JitRt stack).
+ static StatusOr<std::unique_ptr<Executable>> LoadFromObjFile(
+ std::shared_ptr<HloModule> hlo_module, absl::string_view obj_file,
+ absl::string_view mlir_module,
+ xla::EntryFunctionAttributes entry_func_attrs, DebugOptions debug_options,
+ GpuVersion gpu_version, stream_executor::StreamExecutor* executor);
+
+ // Constructor to use when loading a GpuExecutable from an object file (native
+ // function compiled for JitRt). Omits setting class members that aren't used
+ // in JitRt execution mode.
+ GpuExecutable(std::shared_ptr<HloModule> hlo_module, GpuVersion gpu_version,
+ xla::EntryFunctionAttributes entry_func_attrs,
+ absl::string_view module_name, Shape xla_output_shape,
+ std::vector<BufferAllocation> allocations,
+ absl::flat_hash_map<ShapeIndex, OutputInfo> output_info,
+ JitRtExecutable* jitrt_executable);
+
static StatusOr<std::unique_ptr<GpuExecutable>> Create(Params params);
~GpuExecutable() override;
diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto
index d9926e2..d0f279b 100644
--- a/tensorflow/compiler/xla/service/hlo.proto
+++ b/tensorflow/compiler/xla/service/hlo.proto
@@ -748,3 +748,25 @@
// xla::Shape in string format.
string result_xla_shape = 2;
}
+
+// Encodes the underlying JitRt executable compiled from the XLA module.
+message JitRtExecutableProto {
+ HloModuleProto hlo_module_proto = 1;
+
+ // XLA-specific attributes of the executable's entry function.
+ EntryFunctionAttributes entry_func_attrs = 2;
+
+  // TODO(b/232263665): We need to know the TargetMachine this executable was
+  // compiled for; otherwise we can accidentally use illegal instructions (e.g.
+  // use AVX512 when it's not available).
+
+  // TODO(b/232263665): A serialized executable has to record which APIs it
+  // must be linked with, including their versions. For example, a Gpu
+  // executable must be linked with a runtime layer that abstracts over CUDA.
+
+ // Serialized object file compiled from the XLA module.
+ bytes obj_file = 3;
+
+  // Serialized MLIR module corresponding to the compiled object file.
+ string mlir_module = 4;
+}
\ No newline at end of file