// blob: 757d496aa4e7775e96a604a5c93f2fe08362fa28 [file] [log] [blame]
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <string>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "llvm/IR/LLVMContext.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" // from @llvm-project
#include "mlir/Dialect/GPU/GPUDialect.h" // from @llvm-project
#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/Location.h" // from @llvm-project
#include "mlir/IR/OperationSupport.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Target/NVVMIR.h" // from @llvm-project
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/dump.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_constants.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_types.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/gpu/kernel_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h"
#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h"
#include "tensorflow/compiler/xla/service/gpu/nvptx_compiler.h"
#include "tensorflow/compiler/xla/service/gpu/stream_assignment.h"
#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
#include "tensorflow/compiler/xla/service/gpu/target_constants.h"
#include "tensorflow/compiler/xla/service/gpu/thunk_schedule.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/mlir_gpu/emission_context.h"
#include "tensorflow/compiler/xla/service/mlir_gpu/failover_compiler.h"
#include "tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.h"
#include "tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.h"
#include "tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/cuda_libdevice_path.h"
#include "tensorflow/stream_executor/gpu/asm_compiler.h"
namespace xla {
namespace mlir_gpu {
namespace {
using ::mlir::BlockArgument;
using ::mlir::dyn_cast;
using ::mlir::FuncOp;
using ::mlir::ModuleOp;
using ::mlir::OwningModuleRef;
using ::mlir::UnknownLoc;
using ::mlir::Value;
using ::mlir::gpu::LaunchFuncOp;
using ::mlir::LLVM::LLVMDialect;
using ::mlir::LLVM::LLVMFuncOp;
using ::mlir::LLVM::LLVMType;
using ::xla::gpu::GpuExecutable;
using ::xla::gpu::GpuHloSchedule;
using ::xla::gpu::GpuVersion;
using ::xla::gpu::StreamAssignment;
using ::xla::gpu::ThunkSchedule;
// Compiler implementation that translates XLA's IR into a matching MLIR
// dialect, carries out all lowering steps on the MLIR representation, and
// finally converts the MLIR to LLVM IR in order to produce thunks that XLA's
// runtime can execute.
class MlirCompilerImpl : public MlirCompiler {
 public:
  // Runs the HLO-level passes on `module` and hands the module back.
  StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
      std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
      const CompileOptions& options) override;

  // Lowers `module` through MLIR and builds the resulting executable.
  StatusOr<std::unique_ptr<Executable>> RunBackend(
      std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
      const CompileOptions& options) override;

  // Module-group compilation entry point.
  StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
      std::unique_ptr<HloModuleGroup> module_group,
      std::vector<std::vector<se::StreamExecutor*>> stream_execs,
      const CompileOptions& options) override;

  // Ahead-of-time compilation entry point.
  StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
  CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
                     const AotCompilationOptions& options) override;

  // Returns a function computing the byte size of a shape, using the pointer
  // size of this compiler's data layout.
  HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const override {
    int64 bytes_per_pointer = data_layout_.getPointerSize();
    return [bytes_per_pointer](const Shape& s) {
      return ShapeUtil::ByteSizeOf(s, bytes_per_pointer);
    };
  }
};
// Returns the candidate CUDA root directories for `config`, seeded with the
// xla_gpu_cuda_data_dir debug option.
// TODO(b/137624192) Share with NVPTX compiler
static std::vector<std::string> CandidateCudaRoots(
    const HloModuleConfig& config) {
  const auto& preferred_cuda_dir =
      config.debug_options().xla_gpu_cuda_data_dir();
  return tensorflow::CandidateCudaRoots(preferred_cuda_dir);
}
// Logs `msg` as a warning, followed by the list of directories that were
// searched for CUDA and a hint on how to override the search directory.
void PrintCantFindCudaMessage(absl::string_view msg,
                              const HloModuleConfig& hlo_module_config) {
  LOG(WARNING) << msg;
  LOG(WARNING) << "Searched for CUDA in the following directories:";
  const std::vector<std::string> searched_roots =
      CandidateCudaRoots(hlo_module_config);
  for (const std::string& root : searched_roots) {
    LOG(WARNING) << " " << root;
  }
  LOG(WARNING)
      << "You can choose the search directory by setting xla_gpu_cuda_data_dir "
         "in HloModule's DebugOptions. For most apps, setting the environment "
         "variable XLA_FLAGS=--xla_gpu_cuda_data_dir=/path/to/cuda will work.";
}
// Returns the directory containing nvvm libdevice files.
std::string GetLibdeviceDir(const HloModuleConfig& hlo_module_config) {
for (const string& cuda_root : CandidateCudaRoots(hlo_module_config)) {
const std::string libdevice_dir =
tensorflow::io::JoinPath(cuda_root, "nvvm", "libdevice");
VLOG(2) << "Looking for libdevice at " << libdevice_dir;
if (tensorflow::Env::Default()->IsDirectory(libdevice_dir).ok()) {
VLOG(2) << "Found libdevice dir " << libdevice_dir;
return libdevice_dir;
}
}
PrintCantFindCudaMessage(
"Can't find libdevice directory ${CUDA_DIR}/nvvm/libdevice. This may "
"result in compilation or runtime failures, if the program we try to run "
"uses routines from libdevice.",
hlo_module_config);
// GetCudaRootCandidates always includes ".", but if everything fails, we
// return it anyway. Better than returning the empty string.
return ".";
}
// Optimizes `module` and prepares it for IR emission by delegating to the
// NVPTX compiler, then returns the (same) module.
StatusOr<std::unique_ptr<HloModule>> MlirCompilerImpl::RunHloPasses(
    std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
    const CompileOptions& options) {
  // Until we find a reason to do something different, reuse the pass pipeline
  // of the normal GPU backend.
  gpu::NVPTXCompiler nvptx_compiler;
  HloModule* const hlo = module.get();
  TF_RETURN_IF_ERROR(nvptx_compiler.OptimizeHloModule(
      hlo, stream_exec, options.device_allocator));
  TF_RETURN_IF_ERROR(nvptx_compiler.PrepareHloModuleForIrEmitting(hlo));
  return std::move(module);
}
// Buffer-sharing hint for custom calls. Returns:
//   * true when `user` is a GEMM custom call with three operands and
//     `operand` is its third operand (the bias buffer);
//   * for a Cholesky custom call, whether `user_index` addresses the first
//     output (a single index equal to 0);
//   * nullopt in every other case.
// TODO(b/137624192): Move this to custom call handling and share.
absl::optional<bool> CanShareBufferHint(const HloInstruction* user,
                                        const HloInstruction* operand,
                                        const ShapeIndex& user_index) {
  if (user->opcode() != HloOpcode::kCustomCall) {
    return absl::nullopt;
  }

  // Share the bias buffer with the parent instruction.
  const bool is_gemm = user->custom_call_target() == xla::gpu::kGemmCallTarget;
  if (is_gemm && user->operand_count() == 3 && user->operand(2) == operand) {
    return true;
  }

  // The operand of cholesky can be shared with the first output.
  if (user->custom_call_target() == xla::gpu::kCusolverCholeskyCallTarget) {
    const bool is_first_output = user_index.size() == 1 && user_index[0] == 0;
    return is_first_output;
  }

  return absl::nullopt;
}
// Returns the CUDA compute capability of the device behind `stream_exec` as a
// (major, minor) pair. When the capability cannot be queried, a warning is
// logged and (2, 0) is returned.
// TODO(b/137624192): Share this with nvptx backend.
GpuVersion GetGpuVersion(se::StreamExecutor* stream_exec) {
  int capability_major;
  int capability_minor;
  const bool capability_known =
      stream_exec->GetDeviceDescription().cuda_compute_capability(
          &capability_major, &capability_minor);
  if (!capability_known) {
    LOG(WARNING)
        << "Couldn't get compute capability for device; assuming sm_20.";
    capability_major = 2;
    capability_minor = 0;
  }
  return std::make_pair(capability_major, capability_minor);
}
// Returns the constant launch bound along the "x" dimension of `dim` if the
// "y" and "z" bounds are both the constant 1. Returns nullopt, after emitting
// an MLIR error, when "y" or "z" is not the constant 1 or when any inspected
// bound is not produced by a constant op.
//
// A bound that is not the result of an operation (e.g. a block argument) has
// no defining op; such bounds are reported at the value's own location rather
// than by dereferencing a null operation.
static absl::optional<int64> getLaunchBound(const mlir::gpu::KernelDim3& dim) {
  // Starts an error diagnostic at the most precise location known for `bound`.
  auto emit_error = [](Value bound) -> mlir::InFlightDiagnostic {
    if (mlir::Operation* defining_op = bound.getDefiningOp()) {
      return defining_op->emitError();
    }
    return mlir::emitError(bound.getLoc());
  };

  // Extracts the integer value if `bound` is defined by a ConstantOp;
  // otherwise emits an error naming the dimension and returns nullopt.
  auto get_constant = [&emit_error](
                          Value bound,
                          mlir::StringRef name) -> absl::optional<int64> {
    if (auto constant = llvm::dyn_cast_or_null<mlir::ConstantOp>(
            bound.getDefiningOp())) {
      return constant.value().cast<mlir::IntegerAttr>().getInt();
    }
    emit_error(bound) << "bound " << name << " is not constant";
    return absl::nullopt;
  };

  auto dim_y = get_constant(dim.y, "y");
  if (!dim_y.has_value() || dim_y.value() != 1) {
    emit_error(dim.y) << "bound 'y' is not constant 1";
    return absl::nullopt;
  }

  auto dim_z = get_constant(dim.z, "z");
  if (!dim_z.has_value() || dim_z.value() != 1) {
    emit_error(dim.z) << "bound 'z' is not constant 1";
    return absl::nullopt;
  }

  return get_constant(dim.x, "x");
}
// Describes a contiguous range of arguments of a GPU function. Each range
// holds the arguments that a single kernel argument of (previously) memref
// type was lowered into.
struct LaunchFuncArgument {
  // Position of the first argument belonging to the range.
  int kernel_argument_begin;
  // Number of consecutive arguments in the range.
  int kernel_argument_size;
};
// Maps an HLO instruction to the kernel-argument ranges associated with it.
// The mapped value is a vector because the same instruction can be associated
// with more than one range.
using OperandToValueMap =
absl::flat_hash_map<const HloInstruction*, std::vector<LaunchFuncArgument>>;
// Associates every kernel operand of `launchOp` with the HLO instruction it
// originates from and records, in `operand_to_value_map`, the range of
// unpacked kernel arguments that the operand's memref type expands into.
//
// Returns the distinct HLO instructions in the order in which they were first
// encountered. Fails with an internal error (after emitting an MLIR error on
// `launchOp`) if a kernel operand is not a block argument or is not of memref
// type. `kernel` is not inspected here.
static StatusOr<std::vector<const HloInstruction*>> ComputeOperandToValueMap(
    OperandToValueMap* operand_to_value_map, const HloInstruction* instr,
    LaunchFuncOp launchOp, LLVMFuncOp kernel) {
  const auto& hlo_operands = instr->operands();
  std::vector<const HloInstruction*> first_seen_order;
  bool mapping_ok = true;
  // Each memref expands into several kernel arguments; keep a running count so
  // that every range knows where it starts.
  int next_kernel_argument = 0;

  for (unsigned i = 0, e = launchOp.getNumKernelOperands(); i != e; ++i) {
    auto block_arg = launchOp.getKernelOperand(i).dyn_cast<BlockArgument>();
    if (!block_arg) {
      launchOp.emitError("argument to kernel is not a function input");
      mapping_ok = false;
      continue;
    }

    auto memref_type = block_arg.getType().dyn_cast<::mlir::MemRefType>();
    if (!memref_type) {
      launchOp.emitError("only memref-typed arguments are supported");
      mapping_ok = false;
      break;
    }

    // The block argument number is the argument position in the function that
    // surrounds the launch. By construction it matches the HLO operand index;
    // positions past the operands denote the results, i.e. `instr` itself.
    const auto host_index = block_arg.getArgNumber();
    const HloInstruction* hlo =
        (host_index < hlo_operands.size()) ? hlo_operands[host_index] : instr;

    // Remember the instruction the first time it shows up, then append the
    // range of kernel arguments it is backed by.
    auto insertion = operand_to_value_map->try_emplace(hlo);
    if (insertion.second) {
      first_seen_order.push_back(hlo);
    }
    int num_unpacked =
        mlir::MemRefDescriptor::getNumUnpackedValues(memref_type);
    insertion.first->second.push_back({next_kernel_argument, num_unpacked});
    next_kernel_argument += num_unpacked;
  }

  if (!mapping_ok) {
    return InternalError("Mapping operands to kernel arguments has failed.");
  }
  return first_seen_order;
}
// Rewrites `kernel` so that it takes one i8* argument per entry of `buffers`
// (and returns void) instead of the unpacked memref descriptor values that the
// MLIR lowering produced.
//
// For every HLO instruction in `ordered_operands`, its top-level buffer slice
// is looked up in `assignment`; all uses of the original descriptor arguments
// recorded in `operand_to_value_map` are then replaced by values derived from
// the matching buffer argument: the (bitcast) buffer pointer, the slice offset
// and constant sizes/strides taken from the HLO shape. Finally the original
// arguments are erased from the entry block.
//
// Returns an error if a slice cannot be determined or if a slice's allocation
// is not contained in `buffers`.
Status InsertBufferLoadPreduleIntoKernel(
    LLVMFuncOp kernel, const OperandToValueMap& operand_to_value_map,
    const std::vector<const HloInstruction*>& ordered_operands,
    BufferAssignment* assignment,
    const std::vector<const BufferAllocation*>& buffers) {
  mlir::OpBuilder builder(kernel.getBody());
  auto* context = kernel.getContext();
  auto offset_type = LLVMType::getInt64Ty(context);
  auto ptr_type = LLVMType::getInt8PtrTy(context);
  auto void_type = LLVMType::getVoidTy(context);
  auto loc = kernel.getLoc();
  // Must be read before the function type is replaced below.
  auto num_original_args = kernel.getNumArguments();

  // New signature: void(i8*, ..., i8*) with one pointer per buffer.
  std::vector<LLVMType> new_arg_types(buffers.size(), ptr_type);
  kernel->setAttr(kernel.getTypeAttrName(),
                  mlir::TypeAttr::get(LLVMType::getFunctionTy(
                      void_type, new_arg_types, /*isVarArg=*/false)));

  // Snapshot the original block arguments, then append the buffer arguments
  // after them. The originals are erased once they have no users left.
  std::vector<Value> original_args(kernel.args_begin(), kernel.args_end());
  std::vector<mlir::Type> as_mlir_types(new_arg_types.begin(),
                                        new_arg_types.end());
  auto new_args = kernel.front().addArguments(as_mlir_types);
  std::vector<Value> buffer_args(new_args.begin(), new_args.end());

  for (auto operand : ordered_operands) {
    TF_ASSIGN_OR_RETURN(auto slice,
                        assignment->GetUniqueTopLevelSlice(operand));
    auto buffer = std::find(buffers.begin(), buffers.end(), slice.allocation());
    if (buffer == buffers.end()) {
      // Indexing buffer_args with the end position would read out of bounds.
      return InternalError(
          "Buffer allocation of a kernel operand is missing from the list of "
          "kernel buffers.");
    }
    auto index = buffer - buffers.begin();
    auto offset = builder.create<mlir::LLVM::ConstantOp>(
        loc, offset_type, builder.getI64IntegerAttr(slice.offset()));
    auto ptr = buffer_args[index];
    // Replace uses of function arguments pertaining to memref descriptors with
    // values derived from HLO buffers. The instructions inserting these values
    // into memref descriptors were already introduced during the lowering phase
    // as per MLIR calling convention.
    for (auto arg : operand_to_value_map.at(operand)) {
      mlir::MemRefDescriptorView original(
          mlir::ValueRange(original_args)
              .slice(arg.kernel_argument_begin, arg.kernel_argument_size));
      // Allocated and aligned pointers are the same.
      auto casted = builder.create<mlir::LLVM::BitcastOp>(
          loc, original.alignedPtr().getType().cast<LLVMType>(),
          mlir::ValueRange(ptr));
      original.alignedPtr().replaceAllUsesWith(casted);
      original.allocatedPtr().replaceAllUsesWith(casted);
      // Use the offset of the HLO buffer instead of the one expected in the
      // function call.
      original.offset().replaceAllUsesWith(offset);
      // Bind by reference: the shape is only read, so no copy is needed.
      const auto& shape = operand->shape();
      // Unless the operand is a scalar pointer, also fill shape and strides.
      if (shape.dimensions().empty()) {
        continue;
      }
      // TODO(b/137624192) Pass in the descriptor to allow for dynamic shapes.
      assert(shape.IsArray() && shape.is_static());
      for (auto extent : llvm::enumerate(shape.dimensions())) {
        auto extent_value = builder.create<mlir::LLVM::ConstantOp>(
            loc, original.size(extent.index()).getType(),
            builder.getI64IntegerAttr(extent.value()));
        original.size(extent.index()).replaceAllUsesWith(extent_value);
      }
      // Finally, fill the strides. The innermost dimension has stride 1; each
      // outer stride is the product of the extents of the dimensions inside it.
      // TODO(b/137624192): Take assigned layout into account.
      uint64_t accumulator = 0;
      for (int64_t idx = shape.rank() - 1; idx >= 0; --idx) {
        if (accumulator == 0) {
          accumulator = 1;
        } else {
          accumulator *= shape.dimensions(idx + 1);
        }
        auto stride = builder.create<mlir::LLVM::ConstantOp>(
            loc, original.stride(idx).getType(),
            builder.getI64IntegerAttr(accumulator));
        original.stride(idx).replaceAllUsesWith(stride);
      }
    }
  }
  // Now we can remove the original arguments, as they should have no more
  // users. They occupy the leading positions, so always erase position 0.
  for (unsigned i = 0; i < num_original_args; ++i) {
    kernel.front().eraseArgument(0);
  }
  return Status::OK();
}
// Builds the KernelThunk for the HLO instruction `instr` from the function
// `func` that was emitted for it.
//
// `func` must contain exactly one gpu.launch_func op. The launched kernel is
// looked up in `kernel_module` and rewritten in place so that it takes the
// buffer allocations of `instr`'s operands/results (as determined by
// `assignment`) as arguments. The thunk's launch dimensions are taken from the
// constant grid and block sizes of the launch op.
//
// Returns Unimplemented if no kernel was generated or the launch bounds are
// unsupported, and an internal error if there are several launches, if the
// kernel cannot be found, or if the kernel rewrite fails.
StatusOr<std::unique_ptr<gpu::KernelThunk>> TransformKernelToXlaThunk(
    FuncOp func, const HloInstruction* const instr, ModuleOp kernel_module,
    BufferAssignment* assignment) {
  // Find the single LaunchFuncOp and compute a mapping from operands of
  // the hlo instruction to the corresponding values of the kernel
  // function in the target module;
  LaunchFuncOp launchOp;
  auto walkResult = func.walk([&launchOp](LaunchFuncOp op) {
    if (launchOp) {
      op.emitError("multiple kernels for single top-level HLO");
      return mlir::WalkResult::interrupt();
    }
    launchOp = op;
    return mlir::WalkResult::advance();
  });
  if (walkResult.wasInterrupted()) {
    return InternalError("Multiple kernels for single top-level HLO");
  }
  if (!launchOp) {
    // If there was no launchOp, then no kernel was generated, so the lowering
    // from the LHLO ops to the GPU dialect is not implemented yet.
    return Unimplemented("No kernel was generated.");
  }

  auto kernel =
      kernel_module.lookupSymbol<LLVMFuncOp>(launchOp.getKernelName());
  if (!kernel) {
    // Everything below dereferences the kernel, so bail out instead of
    // operating on a null op.
    launchOp.emitError("launched kernel was not found in the kernel module");
    return InternalError("Launched kernel was not found in the kernel module");
  }

  // Store the assignment of operands to block arguments. Note that an operand
  // might be used in multiple argument positions, hence the vector.
  OperandToValueMap operand_to_value_map;
  TF_ASSIGN_OR_RETURN(
      auto ordered_operands,
      ComputeOperandToValueMap(&operand_to_value_map, instr, launchOp, kernel));

  // Get the required buffers to support the inputs. Use a set and vector here
  // to keep the order fixed. This is mostly useful for testing.
  std::unordered_set<const BufferAllocation*> buffers_needed;
  std::vector<const BufferAllocation*> buffers;
  // TODO(b/137624192) Add support for tuples.
  for (auto operand : ordered_operands) {
    TF_ASSIGN_OR_RETURN(auto buffer,
                        assignment->GetUniqueTopLevelSlice(operand));
    if (buffers_needed.insert(buffer.allocation()).second) {
      buffers.push_back(buffer.allocation());
    }
  }

  // TODO(b/137624192) Add support for temp buffer.
  // TODO(b/137624192) Add support for constant buffers.
  // Change the signature to match what the XLA runtime expects from the
  // kernel.
  TF_RETURN_IF_ERROR(InsertBufferLoadPreduleIntoKernel(
      kernel, operand_to_value_map, ordered_operands, assignment, buffers));

  // Finally, create the thunk and set the launch dimensions.
  gpu::Thunk::ThunkInfo info;
  auto thunk = absl::make_unique<gpu::KernelThunk>(info, buffers,
                                                   kernel.getName().str());

  // Set launch bounds.
  mlir::gpu::KernelDim3 block = launchOp.getBlockSizeOperandValues();
  mlir::gpu::KernelDim3 grid = launchOp.getGridSizeOperandValues();
  absl::optional<int64> num_threads = getLaunchBound(block);
  absl::optional<int64> num_blocks = getLaunchBound(grid);
  if (!num_threads || !num_blocks) {
    return Unimplemented("Unsupported launch bounds");
  }
  thunk->SetLaunchDimensions(gpu::LaunchDimensions(*num_blocks, *num_threads));
  return std::move(thunk);
}
// Compiles `module` into a GpuExecutable for the device behind `stream_exec`.
//
// The steps, in order, are:
//   1. Assign streams, build the GPU HLO schedule and run buffer assignment.
//   2. Emit LHLO for every instruction in thunk launch order, keeping any
//      thunk the emitter produces for an instruction.
//   3. Lower LHLO -> GPU dialect -> NVVM, invoking `module_hook_` after each
//      stage, and extract the kernel module.
//   4. Turn every emitted function into a kernel thunk and assemble the thunk
//      sequence in launch order.
//   5. Translate the kernel module to LLVM IR, compile it to PTX and then to a
//      CUBIN, and wrap everything in a GpuExecutable.
//
// `options` is not read in this function. Any failing step aborts the
// compilation and returns its error status.
StatusOr<std::unique_ptr<Executable>> MlirCompilerImpl::RunBackend(
std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
const CompileOptions& options) {
// Determine the HLO schedule, which is an ordering of HLO instructions. This
// is used by buffer assignment to enable buffer reuse, and the same ordering
// must also be used to determine the thunk launch schedule.
std::unique_ptr<StreamAssignment> stream_assignment =
xla::gpu::AssignStreams(*module);
TF_ASSIGN_OR_RETURN(std::unique_ptr<GpuHloSchedule> hlo_schedule,
GpuHloSchedule::Build(*module, *stream_assignment,
data_layout_.getPointerSize()));
// Run buffer analysis on the HLO graph. This analysis figures out which
// temporary buffers are required to run the computation.
TF_ASSIGN_OR_RETURN(std::unique_ptr<BufferAssignment> buffer_assignment,
BufferAssigner::Run(
module.get(), hlo_schedule->ConsumeHloOrdering(),
BufferSizeBytesFunction(),
/*color_alignment=*/
[](LogicalBuffer::Color) {
return xla::gpu::kXlaAllocatedBufferAlignBytes;
},
/*allocate_buffers_for_constants=*/true,
/*colorer=*/BufferAssigner::DefaultColorer(),
/*must_not_live_out=*/{}, &CanShareBufferHint));
DumpHloModuleIfEnabled(*module, *buffer_assignment, "after_optimizations");
// From here on the emission context owns the HLO module; it is accessed via
// emission_context.getHloModule() and released into the executable at the end.
EmissionContext emission_context(std::move(module));
if (error_handler_) {
emission_context.setErrorHandler(error_handler_);
}
OwningModuleRef mlir_module =
ModuleOp::create(UnknownLoc::get(emission_context.getContext()));
LhloDialectEmitter lhlo_emitter(&emission_context, *buffer_assignment,
stream_exec->platform(), *mlir_module);
// Thunks keyed by the HLO instruction they belong to. Filled first with the
// thunks produced directly by the emitter, then with the kernel thunks below.
absl::flat_hash_map<const HloInstruction*, std::unique_ptr<gpu::Thunk>>
hlo_to_thunk;
for (HloInstruction* instruction : hlo_schedule->ThunkLaunchOrder()) {
TF_RETURN_IF_ERROR(instruction->Visit(&lhlo_emitter));
gpu::ThunkSequence thunks = lhlo_emitter.ConsumeThunkSequence();
// At most one thunk per instruction is supported.
TF_RET_CHECK(thunks.size() <= 1) << instruction->ToString();
if (!thunks.empty()) {
auto thunk = std::move(thunks.front());
hlo_to_thunk[instruction] = std::move(thunk);
}
}
// Lowering pipeline. The hook is invoked with the module as it looks at each
// stage: LHLO, GPU, LLVM (and KERNEL further below).
TF_RETURN_IF_ERROR(
module_hook_.invoke(IRHook::LoweringStage::LHLO, *mlir_module));
TF_RETURN_IF_ERROR(LowerLHLOToGPU(*mlir_module));
TF_RETURN_IF_ERROR(
module_hook_.invoke(IRHook::LoweringStage::GPU, *mlir_module));
TF_RETURN_IF_ERROR(LowerKernelBodiesToNVVM(*mlir_module));
TF_RETURN_IF_ERROR(
module_hook_.invoke(IRHook::LoweringStage::LLVM, *mlir_module));
TF_ASSIGN_OR_RETURN(OwningModuleRef kernel_module,
ExtractKernelModule(*mlir_module));
// Create a kernel thunk for every function the emitter recorded. This
// overwrites any entry already stored for the same instruction.
for (auto entry : lhlo_emitter.InstructionToFunctionMap()) {
TF_ASSIGN_OR_RETURN(
auto thunk,
TransformKernelToXlaThunk(entry.second, entry.first, *kernel_module,
buffer_assignment.get()));
hlo_to_thunk[entry.first] = std::move(thunk);
}
// Order the thunks by the launch order of their instructions and record the
// reverse mapping from thunk to instruction. Instructions without a thunk are
// skipped.
absl::flat_hash_map<const gpu::Thunk*, const HloInstruction*> thunk_to_hlo;
gpu::ThunkSequence thunk_sequence;
{
for (HloInstruction* hlo : hlo_schedule->ThunkLaunchOrder()) {
auto it = hlo_to_thunk.find(hlo);
if (it != hlo_to_thunk.end()) {
// Note: this inner `hlo` shadows the loop variable; both refer to the same
// instruction.
const HloInstruction* hlo = it->first;
auto& thunk = it->second;
thunk_to_hlo[thunk.get()] = hlo;
thunk_sequence.push_back(std::move(thunk));
}
}
}
TF_RETURN_IF_ERROR(
module_hook_.invoke(IRHook::LoweringStage::KERNEL, *kernel_module));
// Translate to LLVM IR in a fresh context. The module is further translated
// to textual PTX and a CUBIN blob so there is no need for the context to live
// longer than this function.
llvm::LLVMContext llvmContext;
auto llvmModule = mlir::translateModuleToNVVMIR(*kernel_module, llvmContext);
if (!llvmModule) {
return InternalError("Translation to LLVM failed");
}
llvmModule->setModuleIdentifier(emission_context.getHloModule()->name());
// TODO(herhut): Why is this needed and does not come from the template?
llvmModule->setDataLayout(gpu::nvptx::kDataLayout);
const auto& config = emission_context.getHloModule()->config();
TF_ASSIGN_OR_RETURN(
auto ptx, xla::gpu::nvptx::CompileToPtx(llvmModule.get(),
GetGpuVersion(stream_exec),
config, GetLibdeviceDir(config)));
// Allow to fallback to the driver compilation when ptxas isn't able to
// compile.
StatusOr<std::vector<uint8>> maybe_cubin =
se::CompileGpuAsm(stream_exec->device_ordinal(), ptx.c_str(),
gpu::PtxOptsFromConfig(config));
// On an UNIMPLEMENTED status `cubin` is left empty and only a driver-version
// warning check runs; any other error is returned to the caller.
// NOTE(review): presumably the executable then relies on the PTX alone --
// confirm against GpuExecutable.
std::vector<uint8> cubin;
if (maybe_cubin.ok()) {
cubin = std::move(maybe_cubin).ValueOrDie();
} else if (maybe_cubin.status().code() ==
tensorflow::error::Code::UNIMPLEMENTED) {
xla::gpu::WarnIfBadDriverJITVersion();
} else {
return maybe_cubin.status();
}
auto thunk_schedule = absl::make_unique<ThunkSchedule>(
std::make_unique<gpu::ThunkSequence>(std::move(thunk_sequence)),
std::move(stream_assignment), std::move(thunk_to_hlo));
if (DumpingEnabledForHloModule(*emission_context.getHloModule())) {
DumpToFileInDirOrStdout(*emission_context.getHloModule(), "",
"thunk_schedule", thunk_schedule->ToString());
}
// TODO(b/137624192): Add profiling support.
return {absl::make_unique<GpuExecutable>(
ptx, cubin, GetGpuVersion(stream_exec), std::move(thunk_schedule),
emission_context.releaseHloModule(), std::move(buffer_assignment),
nullptr, nullptr, std::vector<GpuExecutable::ConstantInfo>())};
}
// Compilation of module groups is not supported by the MLIR compiler yet;
// always returns an Unimplemented error and ignores all arguments.
StatusOr<std::vector<std::unique_ptr<Executable>>> MlirCompilerImpl::Compile(
    std::unique_ptr<HloModuleGroup> module_group,
    std::vector<std::vector<se::StreamExecutor*>> stream_execs,
    const CompileOptions& options) {
  return Unimplemented("Not yet implemented in MLIR compiler");
}
// Ahead-of-time compilation is not supported by the MLIR compiler yet; always
// returns an Unimplemented error and ignores all arguments.
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
MlirCompilerImpl::CompileAheadOfTime(
    std::unique_ptr<HloModuleGroup> /*module_group*/,
    const AotCompilationOptions& /*options*/) {
  return Unimplemented("Not yet implemented in MLIR compiler");
}
} // namespace
} // namespace mlir_gpu
} // namespace xla
// Registers, for the CUDA platform, a compiler factory that builds a
// FailoverCompiler from an MlirCompilerImpl (first argument) and an
// NVPTXCompiler (second argument).
// NOTE(review): presumably the second compiler is the fallback for the first
// one -- confirm against FailoverCompiler.
static bool InitModule() {
  auto make_failover_compiler = []() {
    return absl::make_unique<xla::FailoverCompiler>(
        absl::make_unique<xla::mlir_gpu::MlirCompilerImpl>(),
        absl::make_unique<xla::gpu::NVPTXCompiler>());
  };
  xla::Compiler::RegisterCompilerFactory(
      stream_executor::cuda::kCudaPlatformId, make_failover_compiler);
  return true;
}
// Runs the registration during static initialization of this file.
static bool module_initialized = InitModule();