//===- ConvertToLLVMIR.cpp - MLIR to LLVM IR conversion ---------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements a pass that converts CFG functions to LLVM IR. The
// module must not contain any ML functions.
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/CFGFunction.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/StandardOps/StandardOps.h"
#include "mlir/Support/FileUtilities.h"
#include "mlir/Support/Functional.h"
#include "mlir/Target/LLVMIR.h"
#include "mlir/Translation.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
using namespace mlir;
namespace {
class ModuleLowerer {
public:
explicit ModuleLowerer(llvm::LLVMContext &llvmContext)
: llvmContext(llvmContext), builder(llvmContext) {}
bool runOnModule(Module &m, llvm::Module &llvmModule);
private:
bool convertBasicBlock(const BasicBlock &bb, bool ignoreArguments = false);
bool convertCFGFunction(const CFGFunction &cfgFunc, llvm::Function &llvmFunc);
bool convertFunctions(const Module &mlirModule, llvm::Module &llvmModule);
bool convertInstruction(const Instruction &inst);
void connectPHINodes(const CFGFunction &cfgFunc);
/// Type conversion functions. If any conversion fails, report errors to the
/// context of the MLIR type and return nullptr.
/// \{
llvm::FunctionType *convertFunctionType(FunctionType type);
llvm::IntegerType *convertIndexType(IndexType type);
llvm::IntegerType *convertIntegerType(IntegerType type);
llvm::Type *convertFloatType(FloatType type);
llvm::Type *convertType(Type type);
/// Convert a MemRefType `type` into an LLVM aggregate structure type. The
/// structure type starts with a pointer to the elemental type of the MemRef
/// and continues with as many lowered-to-LLVM index types as the MemRef has
/// dynamic dimensions. An instance of this type is called a MemRef descriptor
/// and replaces the MemRef everywhere it is used so that any instruction has
/// access to its dynamic sizes.
/// For example, given that `index` is converted to `i64`, `memref<?x?xf32>`
/// is converted to `{float*, i64, i64}` (two dynamic sizes, in order);
/// `memref<42x?x42xi32>` is converted to `{i32*, i64}` (only one size is
/// dynamic); `memref<2x3x4xf64>` is converted to `{double*}`.
llvm::StructType *convertMemRefType(MemRefType type);
/// \}
/// Get a constant value of `indexType`.
inline llvm::Constant *getIndexConstant(int64_t value);
/// Given subscript indices and array sizes in row-major order,
/// i_n, i_{n-1}, ..., i_1
/// s_n, s_{n-1}, ..., s_1
/// obtain a value that corresponds to the linearized subscript
/// i_n * s_{n-1} * s_{n-2} * ... * s_1 +
/// + i_{n-1} * s_{n-2} * s_{n-3} * ... * s_1 +
/// + ... +
/// + i_2 * s_1 +
/// + i_1.
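/// For example (illustrative), with indices (i_2, i_1) = (1, 2) and sizes
/// (s_2, s_1) = (3, 4), the linearized subscript is i_2 * s_1 + i_1 =
/// 1 * 4 + 2 = 6; note that s_2, the outermost size, is never used.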
llvm::Value *linearizeSubscripts(ArrayRef<llvm::Value *> indices,
ArrayRef<llvm::Value *> allocSizes);
/// Emit LLVM IR instructions necessary to obtain a pointer to the element of
/// `memRef` accessed by `op` with indices `opIndices`. In particular, extract
/// any dynamic allocation sizes from the MemRef descriptor, linearize the
/// access subscript given the sizes, extract the data pointer from the MemRef
/// descriptor and get the pointer to the element indexed by the linearized
/// subscript. Return nullptr on errors.
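/// For example (illustrative, assuming `index` lowers to `i64`), a load from
/// `memref<?x?xf32>` at (%i, %j) extracts the two sizes from the descriptor
/// `{float*, i64, i64}`, computes %i * s_1 + %j, extracts the `float*` data
/// pointer and emits a `getelementptr` with the linearized subscript.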
llvm::Value *emitMemRefElementAccess(
const SSAValue *memRef, const Operation &op,
llvm::iterator_range<Operation::const_operand_iterator> opIndices);
/// Emit LLVM IR corresponding to the given Alloc `op`. In particular, create
/// a Value for the MemRef descriptor, store any dynamic sizes passed to
/// the alloc operation in the descriptor, allocate the buffer for the data
/// using `allocFunc` and store a pointer to it in the descriptor as well.
/// Return the MemRef descriptor, or `nullptr` in case of errors.
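/// For example (illustrative, assuming `index` lowers to `i64`), allocating
/// `memref<?x4xf32>` with dynamic size %n calls `__mlir_alloc(%n * 4 * 4)`
/// (number of elements times 4 bytes per f32), bitcasts the result to
/// `float*`, and stores both the pointer and %n in the descriptor
/// `{float*, i64}`.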
llvm::Value *emitMemRefAlloc(ConstOpPointer<AllocOp> allocOp);
/// Emit LLVM IR corresponding to the given Dealloc `op`. In particular,
/// use `freeFunc` to free the memory allocated for the MemRef's buffer. The
/// MemRef descriptor itself is an SSA value and requires no extra action
/// when the current function returns. Returns an LLVM Value (the call
/// instruction) on success and nullptr on error.
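/// For example (illustrative), deallocating the descriptor above extracts
/// the `float*` at position 0, bitcasts it to `i8*` and passes it to
/// `__mlir_free`.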
llvm::Value *emitMemRefDealloc(ConstOpPointer<DeallocOp> deallocOp);
llvm::DenseMap<const Function *, llvm::Function *> functionMapping;
llvm::DenseMap<const SSAValue *, llvm::Value *> valueMapping;
llvm::DenseMap<const BasicBlock *, llvm::BasicBlock *> blockMapping;
llvm::LLVMContext &llvmContext;
llvm::IRBuilder<llvm::ConstantFolder, llvm::IRBuilderDefaultInserter> builder;
llvm::IntegerType *indexType;
/// Allocation function: (index) -> i8*, declaration only.
llvm::Constant *allocFunc;
/// Deallocation function: (i8*) -> void, declaration only.
llvm::Constant *freeFunc;
};
llvm::IntegerType *ModuleLowerer::convertIndexType(IndexType type) {
return indexType;
}
llvm::IntegerType *ModuleLowerer::convertIntegerType(IntegerType type) {
return builder.getIntNTy(type.getBitWidth());
}
llvm::Type *ModuleLowerer::convertFloatType(FloatType type) {
MLIRContext *context = type.getContext();
switch (type.getKind()) {
case Type::Kind::F32:
return builder.getFloatTy();
case Type::Kind::F64:
return builder.getDoubleTy();
case Type::Kind::F16:
return builder.getHalfTy();
case Type::Kind::BF16:
return context->emitError(UnknownLoc::get(context),
"Unsupported type: BF16"),
nullptr;
default:
llvm_unreachable("non-float type in convertFloatType");
}
}
llvm::FunctionType *ModuleLowerer::convertFunctionType(FunctionType type) {
// TODO(zinenko): convert tuple to LLVM structure types
assert(type.getNumResults() <= 1 && "NYI: tuple returns");
auto resultType = type.getNumResults() == 0
? llvm::Type::getVoidTy(llvmContext)
: convertType(type.getResult(0));
if (!resultType)
return nullptr;
auto argTypes =
functional::map([this](Type inputType) { return convertType(inputType); },
type.getInputs());
if (std::any_of(argTypes.begin(), argTypes.end(),
[](const llvm::Type *t) { return t == nullptr; }))
return nullptr;
return llvm::FunctionType::get(resultType, argTypes, /*isVarArg=*/false);
}
// MemRefs are converted into LLVM structure types to accommodate dynamic
// sizes. The first element of the structure is a pointer to the elemental type
// of the MemRef. The following N elements are values of the Index type, one
// for each of the N dynamic dimensions of the MemRef.
llvm::StructType *ModuleLowerer::convertMemRefType(MemRefType type) {
llvm::Type *elementType = convertType(type.getElementType());
if (!elementType)
return nullptr;
elementType = elementType->getPointerTo();
// One extra slot in front for the pointer to the data buffer.
unsigned numDynamicSizes = type.getNumDynamicDims();
SmallVector<llvm::Type *, 8> types(numDynamicSizes + 1, indexType);
types.front() = elementType;
return llvm::StructType::get(llvmContext, types);
}
llvm::Type *ModuleLowerer::convertType(Type type) {
if (auto funcType = type.dyn_cast<FunctionType>())
return convertFunctionType(funcType);
if (auto intType = type.dyn_cast<IntegerType>())
return convertIntegerType(intType);
if (auto floatType = type.dyn_cast<FloatType>())
return convertFloatType(floatType);
if (auto indexType = type.dyn_cast<IndexType>())
return convertIndexType(indexType);
if (auto memRefType = type.dyn_cast<MemRefType>())
return convertMemRefType(memRefType);
MLIRContext *context = type.getContext();
std::string message;
llvm::raw_string_ostream os(message);
os << "unsupported type: ";
type.print(os);
context->emitError(UnknownLoc::get(context), os.str());
return nullptr;
}
llvm::Constant *ModuleLowerer::getIndexConstant(int64_t value) {
return llvm::Constant::getIntegerValue(
indexType, llvm::APInt(indexType->getBitWidth(), value));
}
// Given subscript indices and array sizes in row-major order,
// i_n, i_{n-1}, ..., i_1
// s_n, s_{n-1}, ..., s_1
// obtain a value that corresponds to the linearized subscript
// \sum_k i_k * \prod_{j=1}^{k-1} s_j
// by accumulating the running linearized value.
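// The loop below evaluates the sum in Horner form:
//   (...(i_n * s_{n-1} + i_{n-1}) * s_{n-2} + ...) * s_1 + i_1.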
llvm::Value *
ModuleLowerer::linearizeSubscripts(ArrayRef<llvm::Value *> indices,
ArrayRef<llvm::Value *> allocSizes) {
assert(indices.size() == allocSizes.size() &&
"mismatching number of indices and allocation sizes");
assert(!indices.empty() && "cannot linearize a 0-dimensional access");
llvm::Value *linearized = indices.front();
for (unsigned i = 1, nSizes = allocSizes.size(); i < nSizes; ++i) {
linearized = builder.CreateMul(linearized, allocSizes[i]);
linearized = builder.CreateAdd(linearized, indices[i]);
}
return linearized;
}
// Check if the MemRefType `type` is supported by the lowering. If it is not,
// emit an error at the location of `op` and return true. Return false if the
// type is supported.
// TODO(zinenko): this function should disappear when the conversion fully
// supports MemRefs.
static bool checkSupportedMemRefType(MemRefType type, const Operation &op) {
if (!type.getAffineMaps().empty()) {
op.emitError("NYI: memrefs with affine maps");
return true;
}
if (type.getMemorySpace() != 0) {
op.emitError("NYI: non-default memory space");
return true;
}
return false;
}
llvm::Value *ModuleLowerer::emitMemRefElementAccess(
const SSAValue *memRef, const Operation &op,
llvm::iterator_range<Operation::const_operand_iterator> opIndices) {
auto type = memRef->getType().dyn_cast<MemRefType>();
assert(type && "expected memRef value to have a MemRef type");
if (checkSupportedMemRefType(type, op))
return nullptr;
// A MemRef-typed value is remapped to its descriptor.
llvm::Value *memRefDescriptor = valueMapping.lookup(memRef);
// Get the list of MemRef sizes. Static sizes are materialized as constant
// values. Dynamic sizes are extracted from the MemRef descriptor.
llvm::SmallVector<llvm::Value *, 4> sizes;
unsigned dynamicSizeIdx = 0;
for (int64_t s : type.getShape()) {
llvm::Value *size = (s == -1) ? builder.CreateExtractValue(
memRefDescriptor, 1 + dynamicSizeIdx++)
: getIndexConstant(s);
sizes.push_back(size);
}
// Obtain the list of access subscripts as values and linearize it given the
// list of sizes.
auto indices = functional::map(
[this](const SSAValue *value) { return valueMapping.lookup(value); },
opIndices);
auto subscript = linearizeSubscripts(indices, sizes);
// Extract the pointer to the data buffer and use LLVM's getelementptr to
// repoint it to the element indexed by the subscript.
llvm::Value *data = builder.CreateExtractValue(memRefDescriptor, 0);
return builder.CreateGEP(data, subscript);
}
llvm::Value *ModuleLowerer::emitMemRefAlloc(ConstOpPointer<AllocOp> allocOp) {
MemRefType type = allocOp->getType();
if (checkSupportedMemRefType(type, *allocOp->getOperation()))
return nullptr;
// Get actual sizes of the memref as values: static sizes are constant
// values and dynamic sizes are passed to 'alloc' as operands.
SmallVector<llvm::Value *, 4> sizes;
sizes.reserve(allocOp->getNumOperands());
unsigned i = 0;
for (int64_t s : type.getShape()) {
llvm::Value *value = (s == -1)
? valueMapping.lookup(allocOp->getOperand(i++))
: getIndexConstant(s);
sizes.push_back(value);
}
assert(!sizes.empty() && "zero-dimensional allocation");
// Compute the total number of memref elements as a Value.
llvm::Value *cumulativeSize = sizes.front();
for (unsigned i = 1, e = sizes.size(); i < e; ++i) {
cumulativeSize = builder.CreateMul(cumulativeSize, sizes[i]);
}
// Create an SSA value for the MemRef descriptor; it starts out undefined and
// is filled in field by field below.
llvm::StructType *structType = convertMemRefType(type);
llvm::Type *elementType = convertType(type.getElementType());
if (!structType || !elementType)
return nullptr;
llvm::Value *memRefDescriptor = llvm::UndefValue::get(structType);
// Take into account the size of the elemental type before allocation.
// Elemental types can be scalars or vectors only.
unsigned byteWidth = elementType->getScalarSizeInBits() / 8;
assert(byteWidth > 0 && "could not determine size of a MemRef element");
if (elementType->isVectorTy()) {
byteWidth *= elementType->getVectorNumElements();
}
llvm::Value *byteWidthValue = getIndexConstant(byteWidth);
cumulativeSize = builder.CreateMul(cumulativeSize, byteWidthValue);
// Allocate the buffer for the MemRef and store a pointer to it in the MemRef
// descriptor.
llvm::Value *allocated = builder.CreateCall(allocFunc, cumulativeSize);
allocated = builder.CreateBitCast(allocated, elementType->getPointerTo());
memRefDescriptor = builder.CreateInsertValue(memRefDescriptor, allocated, 0);
// Store the dynamic sizes in the descriptor.
i = 0;
for (auto indexedSize : llvm::enumerate(sizes)) {
if (type.getShape()[indexedSize.index()] != -1)
continue;
memRefDescriptor = builder.CreateInsertValue(memRefDescriptor,
indexedSize.value(), 1 + i++);
}
// Return the final value of the descriptor: each `insertvalue` returns a
// new, updated value; the old one remains accessible but holds stale data.
return memRefDescriptor;
}
llvm::Value *
ModuleLowerer::emitMemRefDealloc(ConstOpPointer<DeallocOp> deallocOp) {
// Extract the pointer to the MemRef buffer from its descriptor and call
// `freeFunc` on it.
llvm::Value *memRefDescriptor = valueMapping.lookup(deallocOp->getMemRef());
llvm::Value *data = builder.CreateExtractValue(memRefDescriptor, 0);
data = builder.CreateBitCast(data, builder.getInt8PtrTy());
return builder.CreateCall(freeFunc, data);
}
static llvm::CmpInst::Predicate getLLVMCmpPredicate(CmpIPredicate p) {
switch (p) {
case CmpIPredicate::EQ:
return llvm::CmpInst::Predicate::ICMP_EQ;
case CmpIPredicate::NE:
return llvm::CmpInst::Predicate::ICMP_NE;
case CmpIPredicate::SLT:
return llvm::CmpInst::Predicate::ICMP_SLT;
case CmpIPredicate::SLE:
return llvm::CmpInst::Predicate::ICMP_SLE;
case CmpIPredicate::SGT:
return llvm::CmpInst::Predicate::ICMP_SGT;
case CmpIPredicate::SGE:
return llvm::CmpInst::Predicate::ICMP_SGE;
case CmpIPredicate::ULT:
return llvm::CmpInst::Predicate::ICMP_ULT;
case CmpIPredicate::ULE:
return llvm::CmpInst::Predicate::ICMP_ULE;
case CmpIPredicate::UGT:
return llvm::CmpInst::Predicate::ICMP_UGT;
case CmpIPredicate::UGE:
return llvm::CmpInst::Predicate::ICMP_UGE;
default:
llvm_unreachable("incorrect comparison predicate");
}
}
// Convert specific operation instruction types to LLVM instructions.
// FIXME(zinenko): this should eventually become a separate MLIR pass that
// converts MLIR standard operations into LLVM IR dialect; the translation in
// that case would become a simple 1:1 instruction and value remapping.
bool ModuleLowerer::convertInstruction(const Instruction &inst) {
if (auto op = inst.dyn_cast<AddIOp>())
return valueMapping[op->getResult()] =
builder.CreateAdd(valueMapping[op->getOperand(0)],
valueMapping[op->getOperand(1)]),
false;
if (auto op = inst.dyn_cast<MulIOp>())
return valueMapping[op->getResult()] =
builder.CreateMul(valueMapping[op->getOperand(0)],
valueMapping[op->getOperand(1)]),
false;
if (auto op = inst.dyn_cast<CmpIOp>())
return valueMapping[op->getResult()] =
builder.CreateICmp(getLLVMCmpPredicate(op->getPredicate()),
valueMapping[op->getOperand(0)],
valueMapping[op->getOperand(1)]),
false;
if (auto op = inst.dyn_cast<AddFOp>())
return valueMapping[op->getResult()] =
builder.CreateFAdd(valueMapping.lookup(op->getOperand(0)),
valueMapping.lookup(op->getOperand(1))),
false;
if (auto op = inst.dyn_cast<MulFOp>())
return valueMapping[op->getResult()] =
builder.CreateFMul(valueMapping.lookup(op->getOperand(0)),
valueMapping.lookup(op->getOperand(1))),
false;
if (auto constantOp = inst.dyn_cast<ConstantIndexOp>()) {
auto attr = constantOp->getValue();
valueMapping[constantOp->getResult()] = getIndexConstant(attr);
return false;
}
if (auto constantOp = inst.dyn_cast<ConstantFloatOp>()) {
llvm::Type *type = convertType(constantOp->getType());
if (!type)
return true;
// TODO(somebody): float attributes have "double" semantics regardless of the
// type of the constant. This should be fixed at the parser level.
if (!type->isFloatTy()) {
inst.emitError("NYI: only floats are currently supported");
return true;
}
bool unused;
auto APvalue = constantOp->getValue();
APFloat::opStatus status = APvalue.convert(
llvm::APFloat::IEEEsingle(), llvm::APFloat::rmTowardZero, &unused);
if (status == APFloat::opInexact) {
inst.emitWarning(
"Lossy conversion of a float constant to the float type");
// Intentionally no return: proceed with the lossy value.
}
if (status != APFloat::opOK) {
inst.emitError("Failed to convert a floating point constant");
return true;
}
auto value = APvalue.convertToFloat();
valueMapping[constantOp->getResult()] =
llvm::ConstantFP::get(type->getContext(), llvm::APFloat(value));
return false;
}
if (auto constantOp = inst.dyn_cast<ConstantOp>()) {
llvm::Type *type = convertType(constantOp->getType());
if (!type)
return true;
if (!isa<llvm::IntegerType>(type)) {
inst.emitError("only integer types are supported");
return true;
}
auto attr = (constantOp->getValue()).cast<IntegerAttr>();
// Create a new APInt even if we can extract one from the attribute, because
// attributes are currently hardcoded to be 64-bit APInts and LLVM will
// create an i64 constant from those.
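// For example (illustrative), the value 42 of an i32 constant is stored in
// the attribute as a 64-bit APInt and is rebuilt here as a 32-bit one.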
valueMapping[constantOp->getResult()] = llvm::Constant::getIntegerValue(
type, llvm::APInt(type->getIntegerBitWidth(), attr.getInt()));
return false;
}
if (auto allocOp = inst.dyn_cast<AllocOp>()) {
llvm::Value *memRefDescriptor = emitMemRefAlloc(allocOp);
if (!memRefDescriptor)
return true;
valueMapping[allocOp->getResult()] = memRefDescriptor;
return false;
}
if (auto deallocOp = inst.dyn_cast<DeallocOp>()) {
return !emitMemRefDealloc(deallocOp);
}
if (auto loadOp = inst.dyn_cast<LoadOp>()) {
llvm::Value *element = emitMemRefElementAccess(
loadOp->getMemRef(), *loadOp->getOperation(), loadOp->getIndices());
if (!element)
return true;
valueMapping[loadOp->getResult()] = builder.CreateLoad(element);
return false;
}
if (auto storeOp = inst.dyn_cast<StoreOp>()) {
llvm::Value *element = emitMemRefElementAccess(
storeOp->getMemRef(), *storeOp->getOperation(), storeOp->getIndices());
if (!element)
return true;
builder.CreateStore(valueMapping.lookup(storeOp->getValueToStore()),
element);
return false;
}
if (auto dimOp = inst.dyn_cast<DimOp>()) {
const SSAValue *container = dimOp->getOperand();
MemRefType type = container->getType().dyn_cast<MemRefType>();
if (!type)
return dimOp->emitError("only memref types are supported"), true;
auto shape = type.getShape();
auto index = dimOp->getIndex();
assert(index < shape.size() && "out-of-bounds 'dim' operation");
// If the size is static, materialize it directly as a constant.
if (shape[index] != -1) {
valueMapping[dimOp->getResult()] = getIndexConstant(shape[index]);
return false;
}
// Otherwise, compute the position of the requested index in the list of
// dynamic sizes stored in the MemRef descriptor and extract it from there.
unsigned numLeadingDynamicSizes = 0;
for (unsigned i = 0; i < index; ++i) {
if (shape[i] == -1)
++numLeadingDynamicSizes;
}
llvm::Value *memRefDescriptor = valueMapping.lookup(container);
llvm::Value *dynamicSize = builder.CreateExtractValue(
memRefDescriptor, 1 + numLeadingDynamicSizes);
valueMapping[dimOp->getResult()] = dynamicSize;
return false;
}
if (auto callOp = inst.dyn_cast<CallOp>()) {
auto operands = functional::map(
[this](const SSAValue *value) { return valueMapping.lookup(value); },
callOp->getOperands());
auto numResults = callOp->getNumResults();
// TODO(zinenko): support tuple returns
assert(numResults <= 1 && "NYI: tuple returns");
llvm::Value *result =
builder.CreateCall(functionMapping[callOp->getCallee()], operands);
if (numResults == 1)
valueMapping[callOp->getResult(0)] = result;
return false;
}
// Terminators.
if (auto returnInst = inst.dyn_cast<ReturnOp>()) {
unsigned numOperands = returnInst->getNumOperands();
// TODO(zinenko): support tuple returns
assert(numOperands <= 1u && "NYI: tuple returns");
if (numOperands == 0)
builder.CreateRetVoid();
else
builder.CreateRet(valueMapping[returnInst->getOperand(0)]);
return false;
}
if (auto branchInst = inst.dyn_cast<BranchOp>()) {
builder.CreateBr(blockMapping[branchInst->getDest()]);
return false;
}
if (auto condBranchInst = inst.dyn_cast<CondBranchOp>()) {
builder.CreateCondBr(valueMapping[condBranchInst->getCondition()],
blockMapping[condBranchInst->getTrueDest()],
blockMapping[condBranchInst->getFalseDest()]);
return false;
}
inst.emitError("unsupported operation");
return true;
}
bool ModuleLowerer::convertBasicBlock(const BasicBlock &bb,
bool ignoreArguments) {
builder.SetInsertPoint(blockMapping[&bb]);
// Before traversing instructions, make block arguments available through
// value remapping and PHI nodes, but do not add incoming edges for the PHI
// nodes just yet: those values may be defined by this or following blocks.
// This step is omitted if "ignoreArguments" is set. The arguments of the
// first basic block have already been made available through the remapping
// of LLVM function arguments.
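// For example (illustrative), a block bb1(%a : i32) with two predecessors
// starts with a `phi i32` that reserves two incoming entries; they are
// filled in later by connectPHINodes.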
if (!ignoreArguments) {
auto predecessors = bb.getPredecessors();
unsigned numPredecessors =
std::distance(predecessors.begin(), predecessors.end());
for (const auto *arg : bb.getArguments()) {
llvm::Type *type = convertType(arg->getType());
if (!type)
return true;
llvm::PHINode *phi = builder.CreatePHI(type, numPredecessors);
valueMapping[arg] = phi;
}
}
// Traverse instructions.
for (const auto &inst : bb) {
if (convertInstruction(inst))
return true;
}
return false;
}
// Get the SSA value passed to the current block from the terminator instruction
// of its predecessor.
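// For example (illustrative), if the predecessor terminates with
// `br bb1(%x, %y : i32, i32)`, then %x is the PHI source for argument 0 of
// bb1 and %y for argument 1.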
static const SSAValue *getPHISourceValue(const BasicBlock *current,
const BasicBlock *pred,
unsigned numArguments,
unsigned index) {
const Instruction &terminator = *pred->getTerminator();
if (terminator.isa<BranchOp>()) {
return terminator.getOperand(index);
}
// For conditional branches, we need to check if the current block is reached
// through the "true" or the "false" branch and take the relevant operands.
auto condBranchOp = terminator.dyn_cast<CondBranchOp>();
assert(condBranchOp &&
"only branch instructions can be terminators of a basic block that "
"has successors");
condBranchOp->emitError("NYI: conditional branches with arguments");
return nullptr;
}
void ModuleLowerer::connectPHINodes(const CFGFunction &cfgFunc) {
// Skip the first block, it cannot be branched to and its arguments correspond
// to the arguments of the LLVM function.
for (auto it = std::next(cfgFunc.begin()), eit = cfgFunc.end(); it != eit;
++it) {
const BasicBlock *bb = &*it;
llvm::BasicBlock *llvmBB = blockMapping[bb];
auto phis = llvmBB->phis();
auto numArguments = bb->getNumArguments();
assert(numArguments == std::distance(phis.begin(), phis.end()));
for (auto &numberedPhiNode : llvm::enumerate(phis)) {
auto &phiNode = numberedPhiNode.value();
unsigned index = numberedPhiNode.index();
for (const auto *pred : bb->getPredecessors()) {
phiNode.addIncoming(
valueMapping[getPHISourceValue(bb, pred, numArguments, index)],
blockMapping[pred]);
}
}
}
}
bool ModuleLowerer::convertCFGFunction(const CFGFunction &cfgFunc,
llvm::Function &llvmFunc) {
// Clear the block mapping. Blocks belong to a function, so there is no need
// to keep blocks from previously converted functions around. Furthermore, we
// use this mapping to connect PHI nodes inside the function later.
blockMapping.clear();
// First, create all blocks so we can jump to them.
for (const auto &bb : cfgFunc) {
auto *llvmBB = llvm::BasicBlock::Create(llvmContext);
llvmBB->insertInto(&llvmFunc);
blockMapping[&bb] = llvmBB;
}
// Then, convert blocks one by one.
for (auto indexedBB : llvm::enumerate(cfgFunc)) {
const auto &bb = indexedBB.value();
if (convertBasicBlock(bb, /*ignoreArguments=*/indexedBB.index() == 0))
return true;
}
// Finally, after all blocks have been traversed and values mapped, connect
// the PHI nodes to the results of preceding blocks.
connectPHINodes(cfgFunc);
return false;
}
bool ModuleLowerer::convertFunctions(const Module &mlirModule,
llvm::Module &llvmModule) {
// Declare all functions first because there may be function calls that form a
// call graph with cycles. We don't expect MLFunctions here.
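// For example (illustrative), mutually recursive functions @f and @g must
// both be declared before either body is converted so that their calls can
// be resolved through `functionMapping`.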
for (const Function &function : mlirModule) {
const Function *functionPtr = &function;
if (!isa<ExtFunction>(functionPtr) && !isa<CFGFunction>(functionPtr))
continue;
llvm::Constant *llvmFuncCst = llvmModule.getOrInsertFunction(
function.getName(), convertFunctionType(function.getType()));
assert(isa<llvm::Function>(llvmFuncCst));
functionMapping[functionPtr] = cast<llvm::Function>(llvmFuncCst);
}
// Convert CFG functions.
for (const Function &function : mlirModule) {
const Function *functionPtr = &function;
auto cfgFunction = dyn_cast<CFGFunction>(functionPtr);
if (!cfgFunction)
continue;
llvm::Function *llvmFunc = functionMapping[cfgFunction];
// Add function arguments to the value remapping table. In CFGFunction,
// arguments of the first block are those of the function.
assert(!cfgFunction->getBlocks().empty() &&
"expected at least one basic block in a CFGFunction");
const BasicBlock &firstBlock = *cfgFunction->begin();
for (auto arg : llvm::enumerate(llvmFunc->args())) {
valueMapping[firstBlock.getArgument(arg.index())] = &arg.value();
}
if (convertCFGFunction(*cfgFunction, *functionMapping[cfgFunction]))
return true;
}
return false;
}
bool ModuleLowerer::runOnModule(Module &m, llvm::Module &llvmModule) {
// Create the index type once for the entire module; it needs module info
// that is not available in the convert*Type calls.
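// For example (illustrative), a data layout with 64-bit pointers makes
// `index` lower to `i64`.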
indexType =
builder.getIntNTy(llvmModule.getDataLayout().getPointerSizeInBits());
// Declare or obtain (de)allocation functions.
allocFunc = llvmModule.getOrInsertFunction("__mlir_alloc",
builder.getInt8PtrTy(), indexType);
freeFunc = llvmModule.getOrInsertFunction("__mlir_free", builder.getVoidTy(),
builder.getInt8PtrTy());
return convertFunctions(m, llvmModule);
}
} // namespace
// Entry point for the lowering procedure.
std::unique_ptr<llvm::Module>
mlir::convertModuleToLLVMIR(Module &module, llvm::LLVMContext &llvmContext) {
auto llvmModule = llvm::make_unique<llvm::Module>("FIXME_name", llvmContext);
if (ModuleLowerer(llvmContext).runOnModule(module, *llvmModule))
return nullptr;
return llvmModule;
}
// MLIR to LLVM IR translation registration.
static TranslateFromMLIRRegistration MLIRToLLVMIRTranslate(
"mlir-to-llvmir", [](Module *module, llvm::StringRef outputFilename) {
if (!module)
return true;
llvm::LLVMContext llvmContext;
auto llvmModule = convertModuleToLLVMIR(*module, llvmContext);
if (!llvmModule)
return true;
auto file = openOutputFile(outputFilename);
if (!file)
return true;
llvmModule->print(file->os(), nullptr);
file->keep();
return false;
});