/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>
#include <memory>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "llvm/IR/LLVMContext.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"  // from @llvm-project
#include "mlir/Dialect/GPU/GPUDialect.h"  // from @llvm-project
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"  // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/OperationSupport.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Target/NVVMIR.h"  // from @llvm-project
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/dump.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_constants.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_types.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/gpu/kernel_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h"
#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h"
#include "tensorflow/compiler/xla/service/gpu/nvptx_compiler.h"
#include "tensorflow/compiler/xla/service/gpu/stream_assignment.h"
#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
#include "tensorflow/compiler/xla/service/gpu/target_constants.h"
#include "tensorflow/compiler/xla/service/gpu/thunk_schedule.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/mlir_gpu/emission_context.h"
#include "tensorflow/compiler/xla/service/mlir_gpu/failover_compiler.h"
#include "tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.h"
#include "tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.h"
#include "tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/cuda_libdevice_path.h"
#include "tensorflow/stream_executor/gpu/asm_compiler.h"

namespace xla {
namespace mlir_gpu {
namespace {

using ::mlir::BlockArgument;
using ::mlir::dyn_cast;
using ::mlir::FuncOp;
using ::mlir::ModuleOp;
using ::mlir::OwningModuleRef;
using ::mlir::UnknownLoc;
using ::mlir::Value;
using ::mlir::gpu::LaunchFuncOp;
using ::mlir::LLVM::LLVMDialect;
using ::mlir::LLVM::LLVMFuncOp;
using ::mlir::LLVM::LLVMType;
using ::xla::gpu::GpuExecutable;
using ::xla::gpu::GpuHloSchedule;
using ::xla::gpu::GpuVersion;
using ::xla::gpu::StreamAssignment;
using ::xla::gpu::ThunkSchedule;

// A Compiler implementation that converts XLA's HLO IR to a matching MLIR
// dialect, performs all lowering on the MLIR IR, and finally converts the MLIR
// module to LLVM IR to generate a thunk suitable for XLA's runtime.
class MlirCompilerImpl : public MlirCompiler {
 public:
  StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
      std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
      const CompileOptions& options) override;

  StatusOr<std::unique_ptr<Executable>> RunBackend(
      std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
      const CompileOptions& options) override;

  StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
      std::unique_ptr<HloModuleGroup> module_group,
      std::vector<std::vector<se::StreamExecutor*>> stream_execs,
      const CompileOptions& options) override;

  StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
  CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
                     const AotCompilationOptions& options) override;

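  // Shape sizes are computed with the pointer size taken from the LLVM data
  // layout, so that buffer assignment agrees with the sizes the generated
  // code assumes (tuple elements are stored as pointers).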
  HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const override {
    int64 pointer_size = data_layout_.getPointerSize();
    return [pointer_size](const Shape& shape) {
      return ShapeUtil::ByteSizeOf(shape, pointer_size);
    };
  }
};

// TODO(b/137624192) Share with NVPTX compiler
static std::vector<std::string> CandidateCudaRoots(
    const HloModuleConfig& config) {
  return tensorflow::CandidateCudaRoots(
      config.debug_options().xla_gpu_cuda_data_dir());
}

void PrintCantFindCudaMessage(absl::string_view msg,
                              const HloModuleConfig& hlo_module_config) {
  LOG(WARNING) << msg;
  LOG(WARNING) << "Searched for CUDA in the following directories:";

  for (const auto& dir : CandidateCudaRoots(hlo_module_config)) {
    LOG(WARNING) << " " << dir;
  }
  LOG(WARNING)
      << "You can choose the search directory by setting xla_gpu_cuda_data_dir "
         "in HloModule's DebugOptions. For most apps, setting the environment "
         "variable XLA_FLAGS=--xla_gpu_cuda_data_dir=/path/to/cuda will work.";
}

// Returns the directory containing nvvm libdevice files.
std::string GetLibdeviceDir(const HloModuleConfig& hlo_module_config) {
  for (const string& cuda_root : CandidateCudaRoots(hlo_module_config)) {
    const std::string libdevice_dir =
        tensorflow::io::JoinPath(cuda_root, "nvvm", "libdevice");
    VLOG(2) << "Looking for libdevice at " << libdevice_dir;
    if (tensorflow::Env::Default()->IsDirectory(libdevice_dir).ok()) {
      VLOG(2) << "Found libdevice dir " << libdevice_dir;
      return libdevice_dir;
    }
  }
  PrintCantFindCudaMessage(
      "Can't find libdevice directory ${CUDA_DIR}/nvvm/libdevice. This may "
      "result in compilation or runtime failures, if the program we try to run "
      "uses routines from libdevice.",
      hlo_module_config);

  // CandidateCudaRoots always includes ".", but if everything fails, we
  // return it anyway. Better than returning the empty string.
  return ".";
}

StatusOr<std::unique_ptr<HloModule>> MlirCompilerImpl::RunHloPasses(
    std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
    const CompileOptions& options) {
  // Until we find a reason to do something different, run the same passes
  // that the normal GPU backend runs.
  gpu::NVPTXCompiler xla_compiler;
  TF_RETURN_IF_ERROR(xla_compiler.OptimizeHloModule(module.get(), stream_exec,
                                                    options.device_allocator));
  TF_RETURN_IF_ERROR(xla_compiler.PrepareHloModuleForIrEmitting(module.get()));

  return std::move(module);
}

// TODO(b/137624192): Move this to custom call handling and share.
absl::optional<bool> CanShareBufferHint(const HloInstruction* user,
                                        const HloInstruction* operand,
                                        const ShapeIndex& user_index) {
  if (user->opcode() == HloOpcode::kCustomCall) {
    // Share the bias buffer with the parent instruction.
    if (user->custom_call_target() == xla::gpu::kGemmCallTarget) {
      if (user->operand_count() == 3 && user->operand(2) == operand) {
        return true;
      }
    }
    // The operand of cholesky can be shared with the first output.
    if (user->custom_call_target() == xla::gpu::kCusolverCholeskyCallTarget) {
      return user_index.size() == 1 && user_index[0] == 0;
    }
  }
  return absl::nullopt;
}

// TODO(b/137624192): Share this with nvptx backend.
GpuVersion GetGpuVersion(se::StreamExecutor* stream_exec) {
  int cc_major, cc_minor;
  const auto& device_description = stream_exec->GetDeviceDescription();
  if (!device_description.cuda_compute_capability(&cc_major, &cc_minor)) {
    LOG(WARNING)
        << "Couldn't get compute capability for device; assuming sm_20.";
    cc_major = 2;
    cc_minor = 0;
  }
  return std::make_pair(cc_major, cc_minor);
}

// Return the constant launch bound along the "x" dimension in "dim" if all the
// other dimensions are 1. Return nullopt otherwise or when any of the bounds
// is not constant.
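// For example, a launch whose block size was defined by constant ops as
// (x = 128, y = 1, z = 1) yields 128, while a non-constant bound or a non-unit
// "y" or "z" bound yields nullopt.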
static absl::optional<int64> getLaunchBound(const mlir::gpu::KernelDim3& dim) {
  auto get_constant = [](mlir::Operation* op,
                         mlir::StringRef name) -> absl::optional<int64> {
    if (auto constant = llvm::dyn_cast_or_null<mlir::ConstantOp>(op)) {
      return constant.value().cast<mlir::IntegerAttr>().getInt();
    }
    op->emitError() << "bound " << name << " is not constant";
    return absl::nullopt;
  };
  auto y_op = dim.y.getDefiningOp();
  auto dim_y = get_constant(y_op, "y");
  if (!dim_y.has_value() || dim_y.value() != 1) {
    y_op->emitError() << "bound 'y' is not constant 1";
    return absl::nullopt;
  }
  auto z_op = dim.z.getDefiningOp();
  auto dim_z = get_constant(z_op, "z");
  if (!dim_z.has_value() || dim_z.value() != 1) {
    z_op->emitError() << "bound 'z' is not constant 1";
    return absl::nullopt;
  }
  return get_constant(dim.x.getDefiningOp(), "x");
}

// Indices of a range of arguments in a GPU function. This is used to keep
// track of the range of kernel arguments that a single (previously
// memref-typed) kernel argument was lowered into.
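// For example, with the default MLIR memref-to-LLVM lowering, a statically
// shaped rank-R memref expands into 2 * R + 3 values (allocated pointer,
// aligned pointer, offset, R sizes, R strides), so kernel_argument_size would
// be 2 * R + 3.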
struct LaunchFuncArgument {
  int kernel_argument_begin;
  int kernel_argument_size;
};

using OperandToValueMap =
    absl::flat_hash_map<const HloInstruction*, std::vector<LaunchFuncArgument>>;

static StatusOr<std::vector<const HloInstruction*>> ComputeOperandToValueMap(
    OperandToValueMap* operand_to_value_map, const HloInstruction* instr,
    LaunchFuncOp launchOp, LLVMFuncOp kernel) {
  auto operands = instr->operands();
  std::vector<const HloInstruction*> ordered_operands;
  bool has_failed = false;
  // A memref expands into multiple kernel operands; accumulate their count so
  // that the corresponding argument range can be found later.
  int cur_operand_position = 0;

  for (int kernel_index = 0; kernel_index < launchOp.getNumKernelOperands();
       ++kernel_index) {
    auto launchop_operand =
        launchOp.getKernelOperand(kernel_index).dyn_cast<BlockArgument>();
    if (!launchop_operand) {
      launchOp.emitError("argument to kernel is not a function input");
      has_failed = true;
      continue;
    }
    auto memref_type =
        launchop_operand.getType().dyn_cast<::mlir::MemRefType>();
    if (!memref_type) {
      launchOp.emitError("only memref-typed arguments are supported");
      has_failed = true;
      break;
    }
    // host_index is the argument position to the surrounding function that
    // contains the launch. This index corresponds to HLO operand indices
    // by construction.
    auto host_index = launchop_operand.getArgNumber();
    // The trailing arguments to the outer function are the results.
    auto operand =
        (host_index < operands.size()) ? operands[host_index] : instr;
    if (!operand_to_value_map->count(operand)) {
      ordered_operands.push_back(operand);
    }
    // Associate the HLO operand with the argument values of the kernel
    // function.
    int num_unpacked =
        mlir::MemRefDescriptor::getNumUnpackedValues(memref_type);
    (*operand_to_value_map)[operand].push_back(
        {cur_operand_position, num_unpacked});
    cur_operand_position += num_unpacked;
  }
  if (has_failed) {
    return InternalError("Mapping operands to kernel arguments has failed.");
  }
  return ordered_operands;
}

Status InsertBufferLoadPreludeIntoKernel(
    LLVMFuncOp kernel, const OperandToValueMap& operand_to_value_map,
    const std::vector<const HloInstruction*>& ordered_operands,
    BufferAssignment* assignment,
    const std::vector<const BufferAllocation*>& buffers) {
  mlir::OpBuilder builder(kernel.getBody());
  auto* context = kernel.getContext();
  auto offset_type = LLVMType::getInt64Ty(context);
  auto ptr_type = LLVMType::getInt8PtrTy(context);
  auto void_type = LLVMType::getVoidTy(context);
  auto loc = kernel.getLoc();

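  // Rewrite the kernel signature so that it matches what the XLA runtime
  // passes: one untyped i8* pointer per buffer allocation in `buffers`. The
  // original memref-descriptor arguments are detached from their uses below
  // and erased at the end.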
  auto num_original_args = kernel.getNumArguments();
  std::vector<LLVMType> new_arg_types(buffers.size(), ptr_type);
  kernel->setAttr(kernel.getTypeAttrName(),
                  mlir::TypeAttr::get(LLVMType::getFunctionTy(
                      void_type, new_arg_types, /*isVarArg=*/false)));
  std::vector<Value> original_args(kernel.args_begin(), kernel.args_end());

  std::vector<mlir::Type> as_mlir_types(new_arg_types.begin(),
                                        new_arg_types.end());
  auto new_args = kernel.front().addArguments(as_mlir_types);
  std::vector<Value> buffer_args(new_args.begin(), new_args.end());

  for (auto operand : ordered_operands) {
    TF_ASSIGN_OR_RETURN(auto slice,
                        assignment->GetUniqueTopLevelSlice(operand));
    auto buffer = std::find(buffers.begin(), buffers.end(), slice.allocation());
    auto index = buffer - buffers.begin();
    auto offset = builder.create<mlir::LLVM::ConstantOp>(
        loc, offset_type, builder.getI64IntegerAttr(slice.offset()));
    auto ptr = buffer_args[index];

    // Replace uses of the function arguments that pertain to memref
    // descriptors with values derived from the HLO buffers. The instructions
    // that insert these values into memref descriptors were already introduced
    // during lowering, as per the MLIR calling convention.
    for (auto arg : operand_to_value_map.at(operand)) {
      mlir::MemRefDescriptorView original(
          mlir::ValueRange(original_args)
              .slice(arg.kernel_argument_begin, arg.kernel_argument_size));

      // The allocated and aligned pointers are the same: both are set to the
      // runtime buffer pointer, cast to the pointer type the descriptor
      // expects.
      auto casted = builder.create<mlir::LLVM::BitcastOp>(
          loc, original.alignedPtr().getType().cast<LLVMType>(),
          mlir::ValueRange(ptr));
      original.alignedPtr().replaceAllUsesWith(casted);
      original.allocatedPtr().replaceAllUsesWith(casted);

      // Use the offset of the HLO buffer instead of the one expected in the
      // function call.
      original.offset().replaceAllUsesWith(offset);

      // Fill in the sizes and strides, unless the operand is a scalar.
      auto shape = operand->shape();
      if (shape.dimensions().empty()) {
        continue;
      }

      // TODO(b/137624192) Pass in the descriptor to allow for dynamic shapes.
      assert(shape.IsArray() && shape.is_static());
      for (auto extent : llvm::enumerate(shape.dimensions())) {
        auto dim_size = builder.create<mlir::LLVM::ConstantOp>(
            loc, original.size(extent.index()).getType(),
            builder.getI64IntegerAttr(extent.value()));
        original.size(extent.index()).replaceAllUsesWith(dim_size);
      }
      // Finally, fill the strides.
      // TODO(b/137624192): Take assigned layout into account.
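      // Strides are computed for a dense row-major layout: the innermost
      // dimension gets stride 1 and each outer stride is the running product
      // of the inner extents.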
      uint64_t accumulator = 0;
      for (int64_t idx = shape.rank() - 1; idx >= 0; --idx) {
        if (accumulator == 0) {
          accumulator = 1;
        } else {
          accumulator *= shape.dimensions(idx + 1);
        }
        auto stride = builder.create<mlir::LLVM::ConstantOp>(
            loc, original.stride(idx).getType(),
            builder.getI64IntegerAttr(accumulator));
        original.stride(idx).replaceAllUsesWith(stride);
      }
    }
  }

  // Now we can remove the original arguments, as they should have no more
  // users.
  for (int i = 0; i < num_original_args; ++i) {
    kernel.front().eraseArgument(0);
  }

  return Status::OK();
}

StatusOr<std::unique_ptr<gpu::KernelThunk>> TransformKernelToXlaThunk(
    FuncOp func, const HloInstruction* const instr, ModuleOp kernel_module,
    BufferAssignment* assignment) {
  // Find the single LaunchFuncOp and compute a mapping from the operands of
  // the HLO instruction to the corresponding argument values of the kernel
  // function in the target module.
  LaunchFuncOp launchOp;
  auto walkResult = func.walk([&launchOp](LaunchFuncOp op) {
    if (launchOp) {
      op.emitError("multiple kernels for single top-level HLO");
      return mlir::WalkResult::interrupt();
    }
    launchOp = op;
    return mlir::WalkResult::advance();
  });
  if (walkResult.wasInterrupted()) {
    return InternalError("Multiple kernels for single top-level HLO");
  }
  if (!launchOp) {
    // If there was no launchOp, then no kernel was generated, so the lowering
    // from the LHLO ops to the GPU dialect is not implemented yet.
    return Unimplemented("No kernel was generated.");
  }

  auto kernel =
      kernel_module.lookupSymbol<LLVMFuncOp>(launchOp.getKernelName());

  // Store the assignment of operands to block arguments. Note that an operand
  // might be used in multiple argument positions, hence the vector.
  OperandToValueMap operand_to_value_map;
  TF_ASSIGN_OR_RETURN(
      auto ordered_operands,
      ComputeOperandToValueMap(&operand_to_value_map, instr, launchOp, kernel));

  // Get the required buffers to support the inputs. Use a set and vector here
  // to keep the order fixed. This is mostly useful for testing.
  std::unordered_set<const BufferAllocation*> buffers_needed;
  std::vector<const BufferAllocation*> buffers;
  // TODO(b/137624192) Add support for tuples.
  for (auto operand : ordered_operands) {
    TF_ASSIGN_OR_RETURN(auto buffer,
                        assignment->GetUniqueTopLevelSlice(operand));
    if (buffers_needed.insert(buffer.allocation()).second) {
      buffers.push_back(buffer.allocation());
    }
  }

  // TODO(b/137624192) Add support for temp buffer.
  // TODO(b/137624192) Add support for constant buffers.

  // Change the signature to match what the XLA runtime expects from the
  // kernel.
  TF_RETURN_IF_ERROR(InsertBufferLoadPreludeIntoKernel(
      kernel, operand_to_value_map, ordered_operands, assignment, buffers));

  // Finally, create the thunk and set the launch dimensions.
  gpu::Thunk::ThunkInfo info;
  auto thunk = absl::make_unique<gpu::KernelThunk>(info, buffers,
                                                   kernel.getName().str());

  // Set launch bounds.
  mlir::gpu::KernelDim3 block = launchOp.getBlockSizeOperandValues();
  mlir::gpu::KernelDim3 grid = launchOp.getGridSizeOperandValues();
  absl::optional<int64> num_threads = getLaunchBound(block);
  absl::optional<int64> num_blocks = getLaunchBound(grid);
  if (!num_threads || !num_blocks) {
    return Unimplemented("Unsupported launch bounds");
  }
  thunk->SetLaunchDimensions(gpu::LaunchDimensions(*num_blocks, *num_threads));
  return std::move(thunk);
}

StatusOr<std::unique_ptr<Executable>> MlirCompilerImpl::RunBackend(
    std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
    const CompileOptions& options) {
  // Determine the HLO schedule, which is an ordering of HLO instructions. This
  // is used by buffer assignment to enable buffer reuse, and the same ordering
  // must also be used to determine the thunk launch schedule.
  std::unique_ptr<StreamAssignment> stream_assignment =
      xla::gpu::AssignStreams(*module);
  TF_ASSIGN_OR_RETURN(std::unique_ptr<GpuHloSchedule> hlo_schedule,
                      GpuHloSchedule::Build(*module, *stream_assignment,
                                            data_layout_.getPointerSize()));

  // Run buffer analysis on the HLO graph. This analysis figures out which
  // temporary buffers are required to run the computation.
  TF_ASSIGN_OR_RETURN(std::unique_ptr<BufferAssignment> buffer_assignment,
                      BufferAssigner::Run(
                          module.get(), hlo_schedule->ConsumeHloOrdering(),
                          BufferSizeBytesFunction(),
                          /*color_alignment=*/
                          [](LogicalBuffer::Color) {
                            return xla::gpu::kXlaAllocatedBufferAlignBytes;
                          },
                          /*allocate_buffers_for_constants=*/true,
                          /*colorer=*/BufferAssigner::DefaultColorer(),
                          /*must_not_live_out=*/{}, &CanShareBufferHint));
  DumpHloModuleIfEnabled(*module, *buffer_assignment, "after_optimizations");

  EmissionContext emission_context(std::move(module));
  if (error_handler_) {
    emission_context.setErrorHandler(error_handler_);
  }

  OwningModuleRef mlir_module =
      ModuleOp::create(UnknownLoc::get(emission_context.getContext()));
  LhloDialectEmitter lhlo_emitter(&emission_context, *buffer_assignment,
                                  stream_exec->platform(), *mlir_module);

  absl::flat_hash_map<const HloInstruction*, std::unique_ptr<gpu::Thunk>>
      hlo_to_thunk;
  for (HloInstruction* instruction : hlo_schedule->ThunkLaunchOrder()) {
    TF_RETURN_IF_ERROR(instruction->Visit(&lhlo_emitter));
    gpu::ThunkSequence thunks = lhlo_emitter.ConsumeThunkSequence();
    TF_RET_CHECK(thunks.size() <= 1) << instruction->ToString();
    if (!thunks.empty()) {
      auto thunk = std::move(thunks.front());
      hlo_to_thunk[instruction] = std::move(thunk);
    }
  }

  TF_RETURN_IF_ERROR(
      module_hook_.invoke(IRHook::LoweringStage::LHLO, *mlir_module));

  TF_RETURN_IF_ERROR(LowerLHLOToGPU(*mlir_module));

  TF_RETURN_IF_ERROR(
      module_hook_.invoke(IRHook::LoweringStage::GPU, *mlir_module));

  TF_RETURN_IF_ERROR(LowerKernelBodiesToNVVM(*mlir_module));

  TF_RETURN_IF_ERROR(
      module_hook_.invoke(IRHook::LoweringStage::LLVM, *mlir_module));

  TF_ASSIGN_OR_RETURN(OwningModuleRef kernel_module,
                      ExtractKernelModule(*mlir_module));

  for (auto entry : lhlo_emitter.InstructionToFunctionMap()) {
    TF_ASSIGN_OR_RETURN(
        auto thunk,
        TransformKernelToXlaThunk(entry.second, entry.first, *kernel_module,
                                  buffer_assignment.get()));
    hlo_to_thunk[entry.first] = std::move(thunk);
  }

  absl::flat_hash_map<const gpu::Thunk*, const HloInstruction*> thunk_to_hlo;
  gpu::ThunkSequence thunk_sequence;
  {
    for (HloInstruction* hlo : hlo_schedule->ThunkLaunchOrder()) {
      auto it = hlo_to_thunk.find(hlo);
      if (it != hlo_to_thunk.end()) {
        auto& thunk = it->second;
        thunk_to_hlo[thunk.get()] = hlo;
        thunk_sequence.push_back(std::move(thunk));
      }
    }
  }

  TF_RETURN_IF_ERROR(
      module_hook_.invoke(IRHook::LoweringStage::KERNEL, *kernel_module));

  // Translate to LLVM IR in a fresh context. The module is further translated
  // to textual PTX and a CUBIN blob so there is no need for the context to
  // live longer than this function.
  llvm::LLVMContext llvmContext;
  auto llvmModule = mlir::translateModuleToNVVMIR(*kernel_module, llvmContext);

  if (!llvmModule) {
    return InternalError("Translation to LLVM failed");
  }

  llvmModule->setModuleIdentifier(emission_context.getHloModule()->name());
  // TODO(herhut): Why is this needed, and why does it not come from the
  // template?
  llvmModule->setDataLayout(gpu::nvptx::kDataLayout);

  const auto& config = emission_context.getHloModule()->config();
  TF_ASSIGN_OR_RETURN(
      auto ptx, xla::gpu::nvptx::CompileToPtx(llvmModule.get(),
                                              GetGpuVersion(stream_exec),
                                              config, GetLibdeviceDir(config)));
  // Fall back to driver compilation when ptxas is unable to compile the PTX:
  // in that case the cubin stays empty and the driver JIT-compiles the PTX at
  // load time.
  StatusOr<std::vector<uint8>> maybe_cubin =
      se::CompileGpuAsm(stream_exec->device_ordinal(), ptx.c_str(),
                        gpu::PtxOptsFromConfig(config));
  std::vector<uint8> cubin;
  if (maybe_cubin.ok()) {
    cubin = std::move(maybe_cubin).ValueOrDie();
  } else if (maybe_cubin.status().code() ==
             tensorflow::error::Code::UNIMPLEMENTED) {
    xla::gpu::WarnIfBadDriverJITVersion();
  } else {
    return maybe_cubin.status();
  }

  auto thunk_schedule = absl::make_unique<ThunkSchedule>(
      std::make_unique<gpu::ThunkSequence>(std::move(thunk_sequence)),
      std::move(stream_assignment), std::move(thunk_to_hlo));

  if (DumpingEnabledForHloModule(*emission_context.getHloModule())) {
    DumpToFileInDirOrStdout(*emission_context.getHloModule(), "",
                            "thunk_schedule", thunk_schedule->ToString());
  }

  // TODO(b/137624192): Add profiling support.
  return {absl::make_unique<GpuExecutable>(
      ptx, cubin, GetGpuVersion(stream_exec), std::move(thunk_schedule),
      emission_context.releaseHloModule(), std::move(buffer_assignment),
      nullptr, nullptr, std::vector<GpuExecutable::ConstantInfo>())};
}

StatusOr<std::vector<std::unique_ptr<Executable>>> MlirCompilerImpl::Compile(
    std::unique_ptr<HloModuleGroup> module_group,
    std::vector<std::vector<se::StreamExecutor*>> stream_execs,
    const CompileOptions& options) {
  return Unimplemented("Not yet implemented in MLIR compiler");
}

StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
MlirCompilerImpl::CompileAheadOfTime(
    std::unique_ptr<HloModuleGroup> /*module_group*/,
    const AotCompilationOptions& /*options*/) {
  return Unimplemented("Not yet implemented in MLIR compiler");
}

}  // namespace
}  // namespace mlir_gpu
}  // namespace xla

static bool InitModule() {
  xla::Compiler::RegisterCompilerFactory(
      stream_executor::cuda::kCudaPlatformId, []() {
        return absl::make_unique<xla::FailoverCompiler>(
            absl::make_unique<xla::mlir_gpu::MlirCompilerImpl>(),
            absl::make_unique<xla::gpu::NVPTXCompiler>());
      });
  return true;
}
static bool module_initialized = InitModule();