blob: c0e30f0cc372214e20ea8373f520c13b28ed45fa [file] [log] [blame]
//===- ConvertToLLVMIR.cpp - MLIR to LLVM IR conversion ---------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements a pass that converts CFG functions to LLVM IR. ML
// functions are not supported and must not be present in the MLIR module.
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/CFGFunction.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/StandardOps/StandardOps.h"
#include "mlir/Support/FileUtilities.h"
#include "mlir/Support/Functional.h"
#include "mlir/Target/LLVMIR.h"
#include "mlir/Translation.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
using namespace mlir;
namespace {
/// Translates the CFG and external functions of an MLIR module into functions
/// of an LLVM IR module. Methods returning bool return true on failure.
class ModuleLowerer {
public:
  explicit ModuleLowerer(llvm::LLVMContext &llvmContext)
      : llvmContext(llvmContext), builder(llvmContext) {}

  /// Converts the functions of `m` into `llvmModule`. Returns true on failure.
  bool runOnModule(Module &m, llvm::Module &llvmModule);

private:
  /// Converts the instructions of `bb` into its pre-created LLVM block. Unless
  /// `ignoreArguments` is set, block arguments are materialized as PHI nodes
  /// whose incoming values are added later by connectPHINodes.
  bool convertBasicBlock(const BasicBlock &bb, bool ignoreArguments = false);
  /// Converts the body of `cfgFunc` into the already-declared `llvmFunc`.
  bool convertCFGFunction(const CFGFunction &cfgFunc, llvm::Function &llvmFunc);
  /// Declares all CFG and external functions in `llvmModule`, then converts
  /// the bodies of the CFG functions.
  bool convertFunctions(const Module &mlirModule, llvm::Module &llvmModule);
  /// Converts a single instruction at the builder's current insertion point.
  bool convertInstruction(const Instruction &inst);
  /// Adds incoming values to the PHI nodes of every non-entry block of
  /// `cfgFunc`; called once all values of the function have been converted.
  void connectPHINodes(const CFGFunction &cfgFunc);

  /// Type conversion functions. If any conversion fails, report errors to the
  /// context of the MLIR type and return nullptr.
  /// \{
  llvm::FunctionType *convertFunctionType(FunctionType type);
  llvm::IntegerType *convertIndexType(IndexType type);
  llvm::IntegerType *convertIntegerType(IntegerType type);
  llvm::Type *convertFloatType(FloatType type);
  llvm::Type *convertType(Type type);
  /// \}

  /// LLVM counterparts of the MLIR functions declared in the LLVM module.
  llvm::DenseMap<const Function *, llvm::Function *> functionMapping;
  /// LLVM counterparts of the MLIR SSA values converted so far.
  llvm::DenseMap<const SSAValue *, llvm::Value *> valueMapping;
  /// LLVM counterparts of the blocks of the function being converted; cleared
  /// at the start of each function.
  llvm::DenseMap<const BasicBlock *, llvm::BasicBlock *> blockMapping;
  llvm::LLVMContext &llvmContext;
  llvm::IRBuilder<llvm::ConstantFolder, llvm::IRBuilderDefaultInserter> builder;
  /// Pointer-sized integer type used for MLIR `index` values. It is computed
  /// from the LLVM module's data layout by runOnModule and is null until then.
  llvm::IntegerType *indexType = nullptr;
};
/// MLIR `index` maps onto the pointer-sized integer type that runOnModule
/// precomputes once per module; the MLIR type carries no width of its own.
llvm::IntegerType *ModuleLowerer::convertIndexType(IndexType type) {
  llvm::IntegerType *pointerSizedInt = this->indexType;
  return pointerSizedInt;
}
/// Maps an MLIR integer type onto the LLVM integer type of the same bit width.
llvm::IntegerType *ModuleLowerer::convertIntegerType(IntegerType type) {
  const unsigned width = type.getBitWidth();
  return builder.getIntNTy(width);
}
/// Maps MLIR floating-point types onto LLVM half/float/double. BF16 is not
/// supported by this lowering: an error is reported on the type's context and
/// nullptr is returned.
llvm::Type *ModuleLowerer::convertFloatType(FloatType type) {
  switch (type.getKind()) {
  case Type::Kind::F16:
    return builder.getHalfTy();
  case Type::Kind::F32:
    return builder.getFloatTy();
  case Type::Kind::F64:
    return builder.getDoubleTy();
  case Type::Kind::BF16: {
    MLIRContext *ctx = type.getContext();
    ctx->emitError(UnknownLoc::get(ctx), "Unsupported type: BF16");
    return nullptr;
  }
  default:
    llvm_unreachable("non-float type in convertFloatType");
  }
}
/// Converts an MLIR function type into a non-variadic LLVM function type.
/// Zero results map to `void`; a single result maps to its converted type.
/// Returns nullptr if the result or any input type fails to convert; every
/// input is still converted, so errors are reported for each failing one.
llvm::FunctionType *ModuleLowerer::convertFunctionType(FunctionType type) {
  // TODO(zinenko): convert tuple to LLVM structure types
  assert(type.getNumResults() <= 1 && "NYI: tuple returns");
  llvm::Type *llvmResult = nullptr;
  if (type.getNumResults() == 0)
    llvmResult = llvm::Type::getVoidTy(llvmContext);
  else
    llvmResult = convertType(type.getResult(0));
  if (llvmResult == nullptr)
    return nullptr;

  auto llvmInputs = functional::map(
      [this](Type input) { return convertType(input); }, type.getInputs());
  for (const llvm::Type *llvmInput : llvmInputs)
    if (llvmInput == nullptr)
      return nullptr;

  return llvm::FunctionType::get(llvmResult, llvmInputs, /*isVarArg=*/false);
}
/// Dispatches on the MLIR type kind to the matching convert*Type helper. Any
/// other type is reported as unsupported (at an unknown location) and yields
/// nullptr.
llvm::Type *ModuleLowerer::convertType(Type type) {
  if (auto fnTy = type.dyn_cast<FunctionType>())
    return convertFunctionType(fnTy);
  if (auto intTy = type.dyn_cast<IntegerType>())
    return convertIntegerType(intTy);
  if (auto fpTy = type.dyn_cast<FloatType>())
    return convertFloatType(fpTy);
  if (auto idxTy = type.dyn_cast<IndexType>())
    return convertIndexType(idxTy);

  // Build a diagnostic that embeds the printed form of the offending type; the
  // stream is flushed into `diagnostic` when it goes out of scope.
  std::string diagnostic;
  {
    llvm::raw_string_ostream stream(diagnostic);
    stream << "unsupported type: ";
    type.print(stream);
  }
  MLIRContext *ctx = type.getContext();
  ctx->emitError(UnknownLoc::get(ctx), diagnostic);
  return nullptr;
}
/// Maps an MLIR integer comparison predicate onto the LLVM ICmp predicate with
/// the same meaning (signed orderings -> ICMP_S*, unsigned -> ICMP_U*).
static llvm::CmpInst::Predicate getLLVMCmpPredicate(CmpIPredicate p) {
  using LLVMPred = llvm::CmpInst::Predicate;
  switch (p) {
  case CmpIPredicate::EQ:  return LLVMPred::ICMP_EQ;
  case CmpIPredicate::NE:  return LLVMPred::ICMP_NE;
  case CmpIPredicate::SLT: return LLVMPred::ICMP_SLT;
  case CmpIPredicate::SLE: return LLVMPred::ICMP_SLE;
  case CmpIPredicate::SGT: return LLVMPred::ICMP_SGT;
  case CmpIPredicate::SGE: return LLVMPred::ICMP_SGE;
  case CmpIPredicate::ULT: return LLVMPred::ICMP_ULT;
  case CmpIPredicate::ULE: return LLVMPred::ICMP_ULE;
  case CmpIPredicate::UGT: return LLVMPred::ICMP_UGT;
  case CmpIPredicate::UGE: return LLVMPred::ICMP_UGE;
  default:
    llvm_unreachable("incorrect comparison predicate");
  }
}
// Convert specific operation instruction types LLVM instructions.
// FIXME(zinenko): this should eventually become a separate MLIR pass that
// converts MLIR standard operations into LLVM IR dialect; the translation in
// that case would become a simple 1:1 instruction and value remapping.
bool ModuleLowerer::convertInstruction(const Instruction &inst) {
if (auto op = inst.dyn_cast<AddIOp>())
return valueMapping[op->getResult()] =
builder.CreateAdd(valueMapping[op->getOperand(0)],
valueMapping[op->getOperand(1)]),
false;
if (auto op = inst.dyn_cast<MulIOp>())
return valueMapping[op->getResult()] =
builder.CreateMul(valueMapping[op->getOperand(0)],
valueMapping[op->getOperand(1)]),
false;
if (auto op = inst.dyn_cast<CmpIOp>())
return valueMapping[op->getResult()] =
builder.CreateICmp(getLLVMCmpPredicate(op->getPredicate()),
valueMapping[op->getOperand(0)],
valueMapping[op->getOperand(1)]),
false;
if (auto constantOp = inst.dyn_cast<ConstantOp>()) {
llvm::Type *type = convertType(constantOp->getType());
if (!type)
return true;
assert(isa<llvm::IntegerType>(type) &&
"only integer LLVM types are supported");
auto attr = (constantOp->getValue()).cast<IntegerAttr>();
// Create a new APInt even if we can extract one from the attribute, because
// attributes are currently hardcoded to be 64-bit APInts and LLVM will
// create an i64 constant from those.
valueMapping[constantOp->getResult()] = llvm::Constant::getIntegerValue(
type, llvm::APInt(type->getIntegerBitWidth(), attr.getInt()));
return false;
}
if (auto callOp = inst.dyn_cast<CallOp>()) {
auto operands = functional::map(
[this](const SSAValue *value) { return valueMapping.lookup(value); },
callOp->getOperands());
auto numResults = callOp->getNumResults();
// TODO(zinenko): support tuple returns
assert(numResults <= 1 && "NYI: tuple returns");
llvm::Value *result =
builder.CreateCall(functionMapping[callOp->getCallee()], operands);
if (numResults == 1)
valueMapping[callOp->getResult(0)] = result;
return false;
}
// Terminators.
if (auto returnInst = inst.dyn_cast<ReturnOp>()) {
unsigned numOperands = returnInst->getNumOperands();
// TODO(zinenko): support tuple returns
assert(numOperands <= 1u && "NYI: tuple returns");
if (numOperands == 0)
builder.CreateRetVoid();
else
builder.CreateRet(valueMapping[returnInst->getOperand(0)]);
return false;
}
if (auto branchInst = inst.dyn_cast<BranchOp>()) {
builder.CreateBr(blockMapping[branchInst->getDest()]);
return false;
}
if (auto condBranchInst = inst.dyn_cast<CondBranchOp>()) {
builder.CreateCondBr(valueMapping[condBranchInst->getCondition()],
blockMapping[condBranchInst->getTrueDest()],
blockMapping[condBranchInst->getFalseDest()]);
return false;
}
inst.emitError("unsupported operation");
return true;
}
bool ModuleLowerer::convertBasicBlock(const BasicBlock &bb,
bool ignoreArguments) {
builder.SetInsertPoint(blockMapping[&bb]);
// Before traversing instructions, make block arguments available through
// value remapping and PHI nodes, but do not add incoming edges for the PHI
// nodes just yet: those values may be defined by this or following blocks.
// This step is omitted if "ignoreArguments" is set. The arguments of the
// first basic block have been already made available through the remapping of
// LLVM function arguments.
if (!ignoreArguments) {
auto predecessors = bb.getPredecessors();
unsigned numPredecessors =
std::distance(predecessors.begin(), predecessors.end());
for (const auto *arg : bb.getArguments()) {
llvm::Type *type = convertType(arg->getType());
if (!type)
return true;
llvm::PHINode *phi = builder.CreatePHI(type, numPredecessors);
valueMapping[arg] = phi;
}
}
// Traverse instructions.
for (const auto &inst : bb) {
if (convertInstruction(inst))
return true;
}
return false;
}
// Get the SSA value that the terminator of `pred` passes to the `index`-th
// argument (out of `numArguments`) of its successor `current`.
//
// For a conditional branch, the operands forwarded to `current` depend on
// whether `current` is the "true" or the "false" destination. Previously this
// case emitted an error and returned nullptr, which the caller has no way to
// propagate and which yielded a PHI node with a null incoming value.
static const SSAValue *getPHISourceValue(const BasicBlock *current,
                                         const BasicBlock *pred,
                                         unsigned numArguments,
                                         unsigned index) {
  assert(index < numArguments && "block argument index out of range");
  const Instruction &terminator = *pred->getTerminator();
  // An unconditional branch forwards all of its operands to its destination.
  if (terminator.isa<BranchOp>()) {
    return terminator.getOperand(index);
  }

  // For conditional branches, we need to check if the current block is reached
  // through the "true" or the "false" branch and take the relevant operands.
  auto condBranchOp = terminator.dyn_cast<CondBranchOp>();
  assert(condBranchOp &&
         "only branch instructions can be terminators of a basic block that "
         "has successors");
  // All LLVM PHI entries coming from the same predecessor block must carry
  // the same value. That cannot be guaranteed when both edges of a conditional
  // branch target `current` with possibly different operands.
  assert(condBranchOp->getTrueDest() != condBranchOp->getFalseDest() &&
         "successors with arguments in LLVM conditional branches must be "
         "different blocks");
  return condBranchOp->getTrueDest() == current
             ? condBranchOp->getTrueOperand(index)
             : condBranchOp->getFalseOperand(index);
}
void ModuleLowerer::connectPHINodes(const CFGFunction &cfgFunc) {
// Skip the first block, it cannot be branched to and its arguments correspond
// to the arguments of the LLVM function.
for (auto it = std::next(cfgFunc.begin()), eit = cfgFunc.end(); it != eit;
++it) {
const BasicBlock *bb = &*it;
llvm::BasicBlock *llvmBB = blockMapping[bb];
auto phis = llvmBB->phis();
auto numArguments = bb->getNumArguments();
assert(numArguments == std::distance(phis.begin(), phis.end()));
for (auto &numberedPhiNode : llvm::enumerate(phis)) {
auto &phiNode = numberedPhiNode.value();
unsigned index = numberedPhiNode.index();
for (const auto *pred : bb->getPredecessors()) {
phiNode.addIncoming(
valueMapping[getPHISourceValue(bb, pred, numArguments, index)],
blockMapping[pred]);
}
}
}
}
bool ModuleLowerer::convertCFGFunction(const CFGFunction &cfgFunc,
llvm::Function &llvmFunc) {
// Clear the block mapping. Blocks belong to a function, no need to keep
// blocks from the previous functions around. Furthermore, we use this
// mapping to connect PHI nodes inside the function later.
blockMapping.clear();
// First, create all blocks so we can jump to them.
for (const auto &bb : cfgFunc) {
auto *llvmBB = llvm::BasicBlock::Create(llvmContext);
llvmBB->insertInto(&llvmFunc);
blockMapping[&bb] = llvmBB;
}
// Then, convert blocks one by one.
for (auto indexedBB : llvm::enumerate(cfgFunc)) {
const auto &bb = indexedBB.value();
if (convertBasicBlock(bb, /*ignoreArguments=*/indexedBB.index() == 0))
return true;
}
// Finally, after all blocks have been traversed and values mapped, connect
// the PHI nodes to the results of preceding blocks.
connectPHINodes(cfgFunc);
return false;
}
/// Declares every CFG and external function of `mlirModule` in `llvmModule`,
/// then converts the bodies of the CFG functions. Other kinds of functions
/// (ML functions) are skipped and get no LLVM counterpart. Returns true on
/// failure, including when a function signature cannot be converted.
bool ModuleLowerer::convertFunctions(const Module &mlirModule,
                                     llvm::Module &llvmModule) {
  // Declare all functions first because there may be function calls that form a
  // call graph with cycles. We don't expect MLFunctions here.
  for (const Function &function : mlirModule) {
    const Function *functionPtr = &function;
    if (!isa<ExtFunction>(functionPtr) && !isa<CFGFunction>(functionPtr))
      continue;
    // The signature may use types this lowering does not support. The type
    // converter has already reported the error; propagate the failure instead
    // of handing a null type to getOrInsertFunction.
    llvm::FunctionType *llvmFuncType = convertFunctionType(function.getType());
    if (!llvmFuncType)
      return true;
    llvm::Constant *llvmFuncCst =
        llvmModule.getOrInsertFunction(function.getName(), llvmFuncType);
    assert(isa<llvm::Function>(llvmFuncCst));
    functionMapping[functionPtr] = cast<llvm::Function>(llvmFuncCst);
  }

  // Convert CFG functions.
  for (const Function &function : mlirModule) {
    const Function *functionPtr = &function;
    auto cfgFunction = dyn_cast<CFGFunction>(functionPtr);
    if (!cfgFunction)
      continue;
    llvm::Function *llvmFunc = functionMapping[cfgFunction];

    // Add function arguments to the value remapping table. In CFGFunction,
    // arguments of the first block are those of the function.
    assert(!cfgFunction->getBlocks().empty() &&
           "expected at least one basic block in a CFGFunction");
    const BasicBlock &firstBlock = *cfgFunction->begin();
    for (auto arg : llvm::enumerate(llvmFunc->args()))
      valueMapping[firstBlock.getArgument(arg.index())] = &arg.value();

    if (convertCFGFunction(*cfgFunction, *llvmFunc))
      return true;
  }
  return false;
}
/// Converts the functions of `m` into `llvmModule`. Returns true on failure.
bool ModuleLowerer::runOnModule(Module &m, llvm::Module &llvmModule) {
  // The width of `index` depends on the target data layout, which is only
  // reachable through the LLVM module. Compute the type once here so the
  // convert*Type helpers can simply reuse it.
  const auto &layout = llvmModule.getDataLayout();
  const unsigned pointerBits = layout.getPointerSizeInBits();
  indexType = builder.getIntNTy(pointerBits);
  return convertFunctions(m, llvmModule);
}
} // namespace
// Entry point for the lowering procedure: translates `module` into a fresh LLVM
// module living in `llvmContext`. Returns nullptr if the conversion fails.
std::unique_ptr<llvm::Module>
mlir::convertModuleToLLVMIR(Module &module, llvm::LLVMContext &llvmContext) {
  auto llvmModule = llvm::make_unique<llvm::Module>("FIXME_name", llvmContext);
  ModuleLowerer lowerer(llvmContext);
  const bool failed = lowerer.runOnModule(module, *llvmModule);
  if (failed)
    return nullptr;
  return llvmModule;
}
// MLIR to LLVM IR translation registration. The callback returns true on
// failure: a null input module, a failed conversion, or an output file that
// cannot be opened. Otherwise it prints the LLVM module to the output file.
static TranslateFromMLIRRegistration MLIRToLLVMIRTranslate(
    "mlir-to-llvmir", [](Module *module, llvm::StringRef outputFilename) {
      if (module == nullptr)
        return true;
      // Declared before the LLVM module so that it is destroyed after it.
      llvm::LLVMContext context;
      std::unique_ptr<llvm::Module> translated =
          convertModuleToLLVMIR(*module, context);
      if (translated == nullptr)
        return true;
      auto output = openOutputFile(outputFilename);
      if (output == nullptr)
        return true;
      translated->print(output->os(), /*AAW=*/nullptr);
      output->keep();
      return false;
    });