| //===- ConvertToLLVMIR.cpp - MLIR to LLVM IR conversion ---------*- C++ -*-===// |
| // |
| // Copyright 2019 The MLIR Authors. |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| // ============================================================================= |
| // |
// This file implements a translation that converts the CFG functions of an
// MLIR module to LLVM IR. ML functions are not supported: they are neither
// declared nor converted.
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/IR/BuiltinOps.h" |
| #include "mlir/IR/CFGFunction.h" |
| #include "mlir/IR/MLIRContext.h" |
| #include "mlir/IR/Module.h" |
| #include "mlir/StandardOps/StandardOps.h" |
| #include "mlir/Support/FileUtilities.h" |
| #include "mlir/Support/Functional.h" |
| #include "mlir/Target/LLVMIR.h" |
| #include "mlir/Translation.h" |
| #include "llvm/ADT/APFloat.h" |
| #include "llvm/ADT/APInt.h" |
| #include "llvm/ADT/STLExtras.h" |
| #include "llvm/IR/Constant.h" |
| #include "llvm/IR/DerivedTypes.h" |
| #include "llvm/IR/Function.h" |
| #include "llvm/IR/IRBuilder.h" |
| #include "llvm/IR/LLVMContext.h" |
| #include "llvm/IR/Module.h" |
| #include "llvm/IR/Type.h" |
| |
| using namespace mlir; |
| |
| namespace { |
| class ModuleLowerer { |
| public: |
| explicit ModuleLowerer(llvm::LLVMContext &llvmContext) |
| : llvmContext(llvmContext), builder(llvmContext) {} |
| |
| bool runOnModule(Module &m, llvm::Module &llvmModule); |
| |
| private: |
| bool convertBasicBlock(const BasicBlock &bb, bool ignoreArguments = false); |
| bool convertCFGFunction(const CFGFunction &cfgFunc, llvm::Function &llvmFunc); |
| bool convertFunctions(const Module &mlirModule, llvm::Module &llvmModule); |
| bool convertInstruction(const Instruction &inst); |
| |
| void connectPHINodes(const CFGFunction &cfgFunc); |
| |
| /// Type conversion functions. If any conversion fails, report errors to the |
| /// context of the MLIR type and return nullptr. |
| /// \{ |
| llvm::FunctionType *convertFunctionType(FunctionType type); |
| llvm::IntegerType *convertIndexType(IndexType type); |
| llvm::IntegerType *convertIntegerType(IntegerType type); |
| llvm::Type *convertFloatType(FloatType type); |
| llvm::Type *convertType(Type type); |
| /// Convert a MemRefType `type` into an LLVM aggregate structure type. Each |
| /// structure type starts with a pointer to the elemental type of the MemRef |
| /// and continues with as many lowered to LLVM index types as MemRef has |
| /// dynamic dimensions. An instance of this type is called a MemRef decriptor |
| /// and replaces the MemRef everywhere it is used so that any instruction has |
| /// access to its dynamic sizes. |
| /// For example, given that `index` is converted to `i64`, `memref<?x?xf32>` |
| /// is converted to `{float*, i64, i64}` (two dynamic sizes, in order); |
| /// `memref<42x?x42xi32>` is converted to `{i32*, i64}` (only one size is |
| /// dynamic); `memref<2x3x4xf64>` is converted to `{double*}`. |
| llvm::StructType *convertMemRefType(MemRefType type); |
| /// \} |
| |
| /// Get an a constant value of `indexType`. |
| inline llvm::Constant *getIndexConstant(int64_t value); |
| |
| /// Given subscript indices and array sizes in row-major order, |
| /// i_n, i_{n-1}, ..., i_1 |
| /// s_n, s_{n-1}, ..., s_1 |
| /// obtain a value that corresponds to the linearized subscript |
| /// i_n * s_{n-1} * s_{n-2} * ... * s_1 + |
| /// + i_{n-1} * s_{n-2} * s_{n_3} * ... * s_1 + |
| /// + ... + |
| /// + i_2 * s_1 + |
| /// + i_1. |
| llvm::Value *linearizeSubscripts(ArrayRef<llvm::Value *> indices, |
| ArrayRef<llvm::Value *> allocSizes); |
| |
| /// Emit LLVM IR instructions necessary to obtain a pointer to the element of |
| /// `memRef` accessed by `op` with indices `opIndices`. In particular, extract |
| /// any dynamic allocation sizes from the MemRef descriptor, linearize the |
| /// access subscript given the sizes, extract the data pointer from the MemRef |
| /// descriptor and get the pointer to the element indexed by the linearized |
| /// subscript. Return nullptr on errors. |
| llvm::Value *emitMemRefElementAccess( |
| const SSAValue *memRef, const Operation &op, |
| llvm::iterator_range<Operation::const_operand_iterator> opIndices); |
| |
| /// Emit LLVM IR corresponding to the given Alloc `op`. In particular, create |
| /// a Value for the MemRef descriptor, store any dynamic sizes passed to |
| /// the alloc operation in the descriptor, allocate the buffer for the data |
| /// using `allocFunc` and also store it in the descriptor. Return the MemRef |
| /// descriptor. This function returns `nullptr` in case of errors. |
| llvm::Value *emitMemRefAlloc(ConstOpPointer<AllocOp> allocOp); |
| |
| /// Emit LLVM IR corresponding to the given Dealloc `op`. In particular, |
| /// use `freeFunc` to free the memory allocated for the MemRef's buffer. The |
| /// MemRef descriptor allocated on stack will cease to exist when the current |
| /// function returns without any extra action. Returns an LLVM Value (call |
| /// instruction) on success and nullptr on error. |
| llvm::Value *emitMemRefDealloc(ConstOpPointer<DeallocOp> deallocOp); |
| |
| llvm::DenseMap<const Function *, llvm::Function *> functionMapping; |
| llvm::DenseMap<const SSAValue *, llvm::Value *> valueMapping; |
| llvm::DenseMap<const BasicBlock *, llvm::BasicBlock *> blockMapping; |
| llvm::LLVMContext &llvmContext; |
| llvm::IRBuilder<llvm::ConstantFolder, llvm::IRBuilderDefaultInserter> builder; |
| llvm::IntegerType *indexType; |
| |
| /// Allocation function : (index) -> i8*, declaration only. |
| llvm::Constant *allocFunc; |
| /// Deallocation function : (i8*) -> void, declaration only. |
| llvm::Constant *freeFunc; |
| }; |
| |
| llvm::IntegerType *ModuleLowerer::convertIndexType(IndexType type) { |
| return indexType; |
| } |
| |
| llvm::IntegerType *ModuleLowerer::convertIntegerType(IntegerType type) { |
| return builder.getIntNTy(type.getBitWidth()); |
| } |
| |
| llvm::Type *ModuleLowerer::convertFloatType(FloatType type) { |
| MLIRContext *context = type.getContext(); |
| switch (type.getKind()) { |
| case Type::Kind::F32: |
| return builder.getFloatTy(); |
| case Type::Kind::F64: |
| return builder.getDoubleTy(); |
| case Type::Kind::F16: |
| return builder.getHalfTy(); |
| case Type::Kind::BF16: |
| return context->emitError(UnknownLoc::get(context), |
| "Unsupported type: BF16"), |
| nullptr; |
| default: |
| llvm_unreachable("non-float type in convertFloatType"); |
| } |
| } |
| |
| llvm::FunctionType *ModuleLowerer::convertFunctionType(FunctionType type) { |
| // TODO(zinenko): convert tuple to LLVM structure types |
| assert(type.getNumResults() <= 1 && "NYI: tuple returns"); |
| auto resultType = type.getNumResults() == 0 |
| ? llvm::Type::getVoidTy(llvmContext) |
| : convertType(type.getResult(0)); |
| if (!resultType) |
| return nullptr; |
| |
| auto argTypes = |
| functional::map([this](Type inputType) { return convertType(inputType); }, |
| type.getInputs()); |
| if (std::any_of(argTypes.begin(), argTypes.end(), |
| [](const llvm::Type *t) { return t == nullptr; })) |
| return nullptr; |
| |
| return llvm::FunctionType::get(resultType, argTypes, /*isVarArg=*/false); |
| } |
| |
| // MemRefs are converted into LLVM structure types to accomodate dynamic sizes. |
| // The first element of a structure is a pointer to the elemental type of the |
| // MemRef. The following N elements are values of the Index type, one for each |
| // of N dynamic dimensions of the MemRef. |
| llvm::StructType *ModuleLowerer::convertMemRefType(MemRefType type) { |
| llvm::Type *elementType = convertType(type.getElementType()); |
| if (!elementType) |
| return nullptr; |
| elementType = elementType->getPointerTo(); |
| |
| // Extra value for the memory space. |
| unsigned numDynamicSizes = type.getNumDynamicDims(); |
| SmallVector<llvm::Type *, 8> types(numDynamicSizes + 1, indexType); |
| types.front() = elementType; |
| |
| return llvm::StructType::get(llvmContext, types); |
| } |
| |
| llvm::Type *ModuleLowerer::convertType(Type type) { |
| if (auto funcType = type.dyn_cast<FunctionType>()) |
| return convertFunctionType(funcType); |
| if (auto intType = type.dyn_cast<IntegerType>()) |
| return convertIntegerType(intType); |
| if (auto floatType = type.dyn_cast<FloatType>()) |
| return convertFloatType(floatType); |
| if (auto indexType = type.dyn_cast<IndexType>()) |
| return convertIndexType(indexType); |
| if (auto memRefType = type.dyn_cast<MemRefType>()) |
| return convertMemRefType(memRefType); |
| |
| MLIRContext *context = type.getContext(); |
| std::string message; |
| llvm::raw_string_ostream os(message); |
| os << "unsupported type: "; |
| type.print(os); |
| context->emitError(UnknownLoc::get(context), os.str()); |
| return nullptr; |
| } |
| |
| llvm::Constant *ModuleLowerer::getIndexConstant(int64_t value) { |
| return llvm::Constant::getIntegerValue( |
| indexType, llvm::APInt(indexType->getBitWidth(), value)); |
| } |
| |
| // Given subscript indices and array sizes in row-major order, |
| // i_n, i_{n-1}, ..., i_1 |
| // s_n, s_{n-1}, ..., s_1 |
| // obtain a value that corresponds to the linearized subscript |
| // \sum_k i_k * \prod_{j=1}^{k-1} s_j |
| // by accumulating the running linearized value. |
| llvm::Value * |
| ModuleLowerer::linearizeSubscripts(ArrayRef<llvm::Value *> indices, |
| ArrayRef<llvm::Value *> allocSizes) { |
| assert(indices.size() == allocSizes.size() && |
| "mismatching number of indices and allocation sizes"); |
| assert(!indices.empty() && "cannot linearize a 0-dimensional access"); |
| |
| llvm::Value *linearized = indices.front(); |
| for (unsigned i = 1, nSizes = allocSizes.size(); i < nSizes; ++i) { |
| linearized = builder.CreateMul(linearized, allocSizes[i]); |
| linearized = builder.CreateAdd(linearized, indices[i]); |
| } |
| |
| return linearized; |
| } |
| |
| // Check if the MemRefType `type` is supported by the lowering. Emit errors at |
| // the location of `op` and return true. Return false if the type is supported. |
| // TODO(zinenko): this function should disappear when the conversion fully |
| // supports MemRefs. |
| static bool checkSupportedMemRefType(MemRefType type, const Operation &op) { |
| if (!type.getAffineMaps().empty()) { |
| op.emitError("NYI: memrefs with affine maps"); |
| return true; |
| } |
| if (type.getMemorySpace() != 0) { |
| op.emitError("NYI: non-default memory space"); |
| return true; |
| } |
| return false; |
| } |
| |
| llvm::Value *ModuleLowerer::emitMemRefElementAccess( |
| const SSAValue *memRef, const Operation &op, |
| llvm::iterator_range<Operation::const_operand_iterator> opIndices) { |
| auto type = memRef->getType().dyn_cast<MemRefType>(); |
| assert(type && "expected memRef value to have a MemRef type"); |
| if (checkSupportedMemRefType(type, op)) |
| return nullptr; |
| |
| // A MemRef-typed value is remapped to its descriptor. |
| llvm::Value *memRefDescriptor = valueMapping.lookup(memRef); |
| |
| // Get the list of MemRef sizes. Static sizes are defined as values. Dynamic |
| // sizes are extracted from the MemRef descriptor. |
| llvm::SmallVector<llvm::Value *, 4> sizes; |
| unsigned dynanmicSizeIdx = 0; |
| for (int64_t s : type.getShape()) { |
| llvm::Value *size = (s == -1) ? builder.CreateExtractValue( |
| memRefDescriptor, 1 + dynanmicSizeIdx++) |
| : getIndexConstant(s); |
| sizes.push_back(size); |
| } |
| |
| // Obtain the list of access subscripts as values and linearize it given the |
| // list of sizes. |
| auto indices = functional::map( |
| [this](const SSAValue *value) { return valueMapping.lookup(value); }, |
| opIndices); |
| auto subscript = linearizeSubscripts(indices, sizes); |
| |
| // Extract the pointer to the data buffer and use LLVM's getelementptr to |
| // repoint it to the element indexed by the subscript. |
| llvm::Value *data = builder.CreateExtractValue(memRefDescriptor, 0); |
| return builder.CreateGEP(data, subscript); |
| } |
| |
| llvm::Value *ModuleLowerer::emitMemRefAlloc(ConstOpPointer<AllocOp> allocOp) { |
| MemRefType type = allocOp->getType(); |
| if (checkSupportedMemRefType(type, *allocOp->getOperation())) |
| return nullptr; |
| |
| // Get actual sizes of the memref as values: static sizes are constant |
| // values and dynamic sizes are passed to 'alloc' as operands. |
| SmallVector<llvm::Value *, 4> sizes; |
| sizes.reserve(allocOp->getNumOperands()); |
| unsigned i = 0; |
| for (int s : type.getShape()) { |
| llvm::Value *value = (s == -1) |
| ? valueMapping.lookup(allocOp->getOperand(i++)) |
| : getIndexConstant(s); |
| sizes.push_back(value); |
| } |
| assert(!sizes.empty() && "zero-dimensional allocation"); |
| |
| // Compute the total numer of memref elements as Value. |
| llvm::Value *cumulativeSize = sizes.front(); |
| for (unsigned i = 1, e = sizes.size(); i < e; ++i) { |
| cumulativeSize = builder.CreateMul(cumulativeSize, sizes[i]); |
| } |
| |
| // Allocate the MemRef descriptor on stack and load it. |
| llvm::StructType *structType = convertMemRefType(type); |
| llvm::Type *elementType = convertType(type.getElementType()); |
| if (!structType || !elementType) |
| return nullptr; |
| llvm::Value *memRefDescriptor = llvm::UndefValue::get(structType); |
| |
| // Take into account the size of the elemental type before allocation. |
| // Elemental types can be scalars or vectors only. |
| unsigned byteWidth = elementType->getScalarSizeInBits() / 8; |
| assert(byteWidth > 0 && "could not determine size of a MemRef element"); |
| if (elementType->isVectorTy()) { |
| byteWidth *= elementType->getVectorNumElements(); |
| } |
| llvm::Value *byteWidthValue = getIndexConstant(byteWidth); |
| cumulativeSize = builder.CreateMul(cumulativeSize, byteWidthValue); |
| |
| // Allocate the buffer for theMemRef and store a pointer to it in the MemRef |
| // descriptor. |
| llvm::Value *allocated = builder.CreateCall(allocFunc, cumulativeSize); |
| allocated = builder.CreateBitCast(allocated, elementType->getPointerTo()); |
| memRefDescriptor = builder.CreateInsertValue(memRefDescriptor, allocated, 0); |
| |
| // Store dynamically allocated sizes in the descriptor. |
| i = 0; |
| for (auto indexedSize : llvm::enumerate(sizes)) { |
| if (type.getShape()[indexedSize.index()] != -1) |
| continue; |
| memRefDescriptor = builder.CreateInsertValue(memRefDescriptor, |
| indexedSize.value(), 1 + i++); |
| } |
| |
| // Return the final value of the descriptor (each insert returns a new, |
| // updated value, the old is still accessible but has old data). |
| return memRefDescriptor; |
| } |
| |
| llvm::Value * |
| ModuleLowerer::emitMemRefDealloc(ConstOpPointer<DeallocOp> deallocOp) { |
| // Extract the pointer to the MemRef buffer from its descriptor and call |
| // `freeFunc` on it. |
| llvm::Value *memRefDescriptor = valueMapping.lookup(deallocOp->getMemRef()); |
| llvm::Value *data = builder.CreateExtractValue(memRefDescriptor, 0); |
| data = builder.CreateBitCast(data, builder.getInt8PtrTy()); |
| return builder.CreateCall(freeFunc, data); |
| } |
| |
| static llvm::CmpInst::Predicate getLLVMCmpPredicate(CmpIPredicate p) { |
| switch (p) { |
| case CmpIPredicate::EQ: |
| return llvm::CmpInst::Predicate::ICMP_EQ; |
| case CmpIPredicate::NE: |
| return llvm::CmpInst::Predicate::ICMP_NE; |
| case CmpIPredicate::SLT: |
| return llvm::CmpInst::Predicate::ICMP_SLT; |
| case CmpIPredicate::SLE: |
| return llvm::CmpInst::Predicate::ICMP_SLE; |
| case CmpIPredicate::SGT: |
| return llvm::CmpInst::Predicate::ICMP_SGT; |
| case CmpIPredicate::SGE: |
| return llvm::CmpInst::Predicate::ICMP_SGE; |
| case CmpIPredicate::ULT: |
| return llvm::CmpInst::Predicate::ICMP_ULT; |
| case CmpIPredicate::ULE: |
| return llvm::CmpInst::Predicate::ICMP_ULE; |
| case CmpIPredicate::UGT: |
| return llvm::CmpInst::Predicate::ICMP_UGT; |
| case CmpIPredicate::UGE: |
| return llvm::CmpInst::Predicate::ICMP_UGE; |
| default: |
| llvm_unreachable("incorrect comparison predicate"); |
| } |
| } |
| |
| // Convert specific operation instruction types LLVM instructions. |
| // FIXME(zinenko): this should eventually become a separate MLIR pass that |
| // converts MLIR standard operations into LLVM IR dialect; the translation in |
| // that case would become a simple 1:1 instruction and value remapping. |
| bool ModuleLowerer::convertInstruction(const Instruction &inst) { |
| if (auto op = inst.dyn_cast<AddIOp>()) |
| return valueMapping[op->getResult()] = |
| builder.CreateAdd(valueMapping[op->getOperand(0)], |
| valueMapping[op->getOperand(1)]), |
| false; |
| if (auto op = inst.dyn_cast<MulIOp>()) |
| return valueMapping[op->getResult()] = |
| builder.CreateMul(valueMapping[op->getOperand(0)], |
| valueMapping[op->getOperand(1)]), |
| false; |
| if (auto op = inst.dyn_cast<CmpIOp>()) |
| return valueMapping[op->getResult()] = |
| builder.CreateICmp(getLLVMCmpPredicate(op->getPredicate()), |
| valueMapping[op->getOperand(0)], |
| valueMapping[op->getOperand(1)]), |
| false; |
| |
| if (auto op = inst.dyn_cast<AddFOp>()) |
| return valueMapping[op->getResult()] = |
| builder.CreateFAdd(valueMapping.lookup(op->getOperand(0)), |
| valueMapping.lookup(op->getOperand(1))), |
| false; |
| if (auto op = inst.dyn_cast<MulFOp>()) |
| return valueMapping[op->getResult()] = |
| builder.CreateFMul(valueMapping.lookup(op->getOperand(0)), |
| valueMapping.lookup(op->getOperand(1))), |
| false; |
| |
| if (auto constantOp = inst.dyn_cast<ConstantIndexOp>()) { |
| auto attr = constantOp->getValue(); |
| valueMapping[constantOp->getResult()] = getIndexConstant(attr); |
| return false; |
| } |
| if (auto constantOp = inst.dyn_cast<ConstantFloatOp>()) { |
| llvm::Type *type = convertType(constantOp->getType()); |
| if (!type) |
| return true; |
| // TODO(somebody): float attributes have "double" semantics whatever the |
| // type of the constant. This should be fixed at the parser level. |
| if (!type->isFloatTy()) { |
| inst.emitError("NYI: only floats are currently supported"); |
| return true; |
| } |
| bool unused; |
| auto APvalue = constantOp->getValue(); |
| APFloat::opStatus status = APvalue.convert( |
| llvm::APFloat::IEEEsingle(), llvm::APFloat::rmTowardZero, &unused); |
| if (status == APFloat::opInexact) { |
| inst.emitWarning( |
| "Lossy conversion of a float constant to the float type"); |
| // No return intended. |
| } |
| if (status != APFloat::opOK) { |
| inst.emitError("Failed to convert a floating point constant"); |
| return true; |
| } |
| auto value = APvalue.convertToFloat(); |
| valueMapping[constantOp->getResult()] = |
| llvm::ConstantFP::get(type->getContext(), llvm::APFloat(value)); |
| return false; |
| } |
| if (auto constantOp = inst.dyn_cast<ConstantOp>()) { |
| llvm::Type *type = convertType(constantOp->getType()); |
| if (!type) |
| return true; |
| if (!isa<llvm::IntegerType>(type)) { |
| inst.emitError("only integer types are supported"); |
| return true; |
| } |
| auto attr = (constantOp->getValue()).cast<IntegerAttr>(); |
| // Create a new APInt even if we can extract one from the attribute, because |
| // attributes are currently hardcoded to be 64-bit APInts and LLVM will |
| // create an i64 constant from those. |
| valueMapping[constantOp->getResult()] = llvm::Constant::getIntegerValue( |
| type, llvm::APInt(type->getIntegerBitWidth(), attr.getInt())); |
| |
| return false; |
| } |
| |
| if (auto allocOp = inst.dyn_cast<AllocOp>()) { |
| llvm::Value *memRefDescriptor = emitMemRefAlloc(allocOp); |
| if (!memRefDescriptor) |
| return true; |
| |
| valueMapping[allocOp->getResult()] = memRefDescriptor; |
| return false; |
| } |
| if (auto deallocOp = inst.dyn_cast<DeallocOp>()) { |
| return !emitMemRefDealloc(deallocOp); |
| } |
| |
| if (auto loadOp = inst.dyn_cast<LoadOp>()) { |
| llvm::Value *element = emitMemRefElementAccess( |
| loadOp->getMemRef(), *loadOp->getOperation(), loadOp->getIndices()); |
| if (!element) |
| return true; |
| |
| valueMapping[loadOp->getResult()] = builder.CreateLoad(element); |
| return false; |
| } |
| if (auto storeOp = inst.dyn_cast<StoreOp>()) { |
| llvm::Value *element = emitMemRefElementAccess( |
| storeOp->getMemRef(), *storeOp->getOperation(), storeOp->getIndices()); |
| if (!element) |
| return true; |
| |
| builder.CreateStore(valueMapping.lookup(storeOp->getValueToStore()), |
| element); |
| return false; |
| } |
| if (auto dimOp = inst.dyn_cast<DimOp>()) { |
| const SSAValue *container = dimOp->getOperand(); |
| MemRefType type = container->getType().dyn_cast<MemRefType>(); |
| if (!type) |
| return dimOp->emitError("only memref types are supported"), true; |
| |
| auto shape = type.getShape(); |
| auto index = dimOp->getIndex(); |
| assert(index < shape.size() && "out-of-bounds 'dim' operation"); |
| |
| // If the size is a constant, just define that constant. |
| if (shape[index] != -1) { |
| valueMapping[dimOp->getResult()] = getIndexConstant(shape[index]); |
| return false; |
| } |
| |
| // Otherwise, compute the position of the requested index in the list of |
| // dynamic sizes stored in the MemRef descriptor and extract it from there. |
| unsigned numLeadingDynamicSizes = 0; |
| for (unsigned i = 0; i < index; ++i) { |
| if (shape[i] == -1) |
| ++numLeadingDynamicSizes; |
| } |
| llvm::Value *memRefDescriptor = valueMapping.lookup(container); |
| llvm::Value *dynamicSize = builder.CreateExtractValue( |
| memRefDescriptor, 1 + numLeadingDynamicSizes); |
| valueMapping[dimOp->getResult()] = dynamicSize; |
| return false; |
| } |
| |
| if (auto callOp = inst.dyn_cast<CallOp>()) { |
| auto operands = functional::map( |
| [this](const SSAValue *value) { return valueMapping.lookup(value); }, |
| callOp->getOperands()); |
| auto numResults = callOp->getNumResults(); |
| // TODO(zinenko): support tuple returns |
| assert(numResults <= 1 && "NYI: tuple returns"); |
| |
| llvm::Value *result = |
| builder.CreateCall(functionMapping[callOp->getCallee()], operands); |
| if (numResults == 1) |
| valueMapping[callOp->getResult(0)] = result; |
| return false; |
| } |
| |
| // Terminators. |
| if (auto returnInst = inst.dyn_cast<ReturnOp>()) { |
| unsigned numOperands = returnInst->getNumOperands(); |
| // TODO(zinenko): support tuple returns |
| assert(numOperands <= 1u && "NYI: tuple returns"); |
| |
| if (numOperands == 0) |
| builder.CreateRetVoid(); |
| else |
| builder.CreateRet(valueMapping[returnInst->getOperand(0)]); |
| |
| return false; |
| } |
| if (auto branchInst = inst.dyn_cast<BranchOp>()) { |
| builder.CreateBr(blockMapping[branchInst->getDest()]); |
| return false; |
| } |
| if (auto condBranchInst = inst.dyn_cast<CondBranchOp>()) { |
| builder.CreateCondBr(valueMapping[condBranchInst->getCondition()], |
| blockMapping[condBranchInst->getTrueDest()], |
| blockMapping[condBranchInst->getFalseDest()]); |
| return false; |
| } |
| inst.emitError("unsupported operation"); |
| return true; |
| } |
| |
| bool ModuleLowerer::convertBasicBlock(const BasicBlock &bb, |
| bool ignoreArguments) { |
| builder.SetInsertPoint(blockMapping[&bb]); |
| |
| // Before traversing instructions, make block arguments available through |
| // value remapping and PHI nodes, but do not add incoming edges for the PHI |
| // nodes just yet: those values may be defined by this or following blocks. |
| // This step is omitted if "ignoreArguments" is set. The arguments of the |
| // first basic block have been already made available through the remapping of |
| // LLVM function arguments. |
| if (!ignoreArguments) { |
| auto predecessors = bb.getPredecessors(); |
| unsigned numPredecessors = |
| std::distance(predecessors.begin(), predecessors.end()); |
| for (const auto *arg : bb.getArguments()) { |
| llvm::Type *type = convertType(arg->getType()); |
| if (!type) |
| return true; |
| llvm::PHINode *phi = builder.CreatePHI(type, numPredecessors); |
| valueMapping[arg] = phi; |
| } |
| } |
| |
| // Traverse instructions. |
| for (const auto &inst : bb) { |
| if (convertInstruction(inst)) |
| return true; |
| } |
| |
| return false; |
| } |
| |
| // Get the SSA value passed to the current block from the terminator instruction |
| // of its predecessor. |
| static const SSAValue *getPHISourceValue(const BasicBlock *current, |
| const BasicBlock *pred, |
| unsigned numArguments, |
| unsigned index) { |
| const Instruction &terminator = *pred->getTerminator(); |
| if (terminator.isa<BranchOp>()) { |
| return terminator.getOperand(index); |
| } |
| |
| // For conditional branches, we need to check if the current block is reached |
| // through the "true" or the "false" branch and take the relevant operands. |
| auto condBranchOp = terminator.dyn_cast<CondBranchOp>(); |
| assert(condBranchOp && |
| "only branch instructions can be terminators of a basic block that " |
| "has successors"); |
| |
| condBranchOp->emitError("NYI: conditional branches with arguments"); |
| return nullptr; |
| } |
| |
| void ModuleLowerer::connectPHINodes(const CFGFunction &cfgFunc) { |
| // Skip the first block, it cannot be branched to and its arguments correspond |
| // to the arguments of the LLVM function. |
| for (auto it = std::next(cfgFunc.begin()), eit = cfgFunc.end(); it != eit; |
| ++it) { |
| const BasicBlock *bb = &*it; |
| llvm::BasicBlock *llvmBB = blockMapping[bb]; |
| auto phis = llvmBB->phis(); |
| auto numArguments = bb->getNumArguments(); |
| assert(numArguments == std::distance(phis.begin(), phis.end())); |
| for (auto &numberedPhiNode : llvm::enumerate(phis)) { |
| auto &phiNode = numberedPhiNode.value(); |
| unsigned index = numberedPhiNode.index(); |
| for (const auto *pred : bb->getPredecessors()) { |
| phiNode.addIncoming( |
| valueMapping[getPHISourceValue(bb, pred, numArguments, index)], |
| blockMapping[pred]); |
| } |
| } |
| } |
| } |
| |
| bool ModuleLowerer::convertCFGFunction(const CFGFunction &cfgFunc, |
| llvm::Function &llvmFunc) { |
| // Clear the block mapping. Blocks belong to a function, no need to keep |
| // blocks from the previous functions around. Furthermore, we use this |
| // mapping to connect PHI nodes inside the function later. |
| blockMapping.clear(); |
| // First, create all blocks so we can jump to them. |
| for (const auto &bb : cfgFunc) { |
| auto *llvmBB = llvm::BasicBlock::Create(llvmContext); |
| llvmBB->insertInto(&llvmFunc); |
| blockMapping[&bb] = llvmBB; |
| } |
| |
| // Then, convert blocks one by one. |
| for (auto indexedBB : llvm::enumerate(cfgFunc)) { |
| const auto &bb = indexedBB.value(); |
| if (convertBasicBlock(bb, /*ignoreArguments=*/indexedBB.index() == 0)) |
| return true; |
| } |
| |
| // Finally, after all blocks have been traversed and values mapped, connect |
| // the PHI nodes to the results of preceding blocks. |
| connectPHINodes(cfgFunc); |
| return false; |
| } |
| |
| bool ModuleLowerer::convertFunctions(const Module &mlirModule, |
| llvm::Module &llvmModule) { |
| // Declare all functions first because there may be function calls that form a |
| // call graph with cycles. We don't expect MLFunctions here. |
| for (const Function &function : mlirModule) { |
| const Function *functionPtr = &function; |
| if (!isa<ExtFunction>(functionPtr) && !isa<CFGFunction>(functionPtr)) |
| continue; |
| llvm::Constant *llvmFuncCst = llvmModule.getOrInsertFunction( |
| function.getName(), convertFunctionType(function.getType())); |
| assert(isa<llvm::Function>(llvmFuncCst)); |
| functionMapping[functionPtr] = cast<llvm::Function>(llvmFuncCst); |
| } |
| |
| // Convert CFG functions. |
| for (const Function &function : mlirModule) { |
| const Function *functionPtr = &function; |
| auto cfgFunction = dyn_cast<CFGFunction>(functionPtr); |
| if (!cfgFunction) |
| continue; |
| llvm::Function *llvmFunc = functionMapping[cfgFunction]; |
| |
| // Add function arguments to the value remapping table. In CFGFunction, |
| // arguments of the first block are those of the function. |
| assert(!cfgFunction->getBlocks().empty() && |
| "expected at least one basic block in a CFGFunction"); |
| const BasicBlock &firstBlock = *cfgFunction->begin(); |
| for (auto arg : llvm::enumerate(llvmFunc->args())) { |
| valueMapping[firstBlock.getArgument(arg.index())] = &arg.value(); |
| } |
| |
| if (convertCFGFunction(*cfgFunction, *functionMapping[cfgFunction])) |
| return true; |
| } |
| return false; |
| } |
| |
| bool ModuleLowerer::runOnModule(Module &m, llvm::Module &llvmModule) { |
| // Create index type once for the entire module, it needs module info that is |
| // not available in the convert*Type calls. |
| indexType = |
| builder.getIntNTy(llvmModule.getDataLayout().getPointerSizeInBits()); |
| |
| // Declare or obtain (de)allocation functions. |
| allocFunc = llvmModule.getOrInsertFunction("__mlir_alloc", |
| builder.getInt8PtrTy(), indexType); |
| freeFunc = llvmModule.getOrInsertFunction("__mlir_free", builder.getVoidTy(), |
| builder.getInt8PtrTy()); |
| |
| return convertFunctions(m, llvmModule); |
| } |
| } // namespace |
| |
| // Entry point for the lowering procedure. |
| std::unique_ptr<llvm::Module> |
| mlir::convertModuleToLLVMIR(Module &module, llvm::LLVMContext &llvmContext) { |
| auto llvmModule = llvm::make_unique<llvm::Module>("FIXME_name", llvmContext); |
| if (ModuleLowerer(llvmContext).runOnModule(module, *llvmModule)) |
| return nullptr; |
| return llvmModule; |
| } |
| |
| // MLIR to LLVM IR translation registration. |
| static TranslateFromMLIRRegistration MLIRToLLVMIRTranslate( |
| "mlir-to-llvmir", [](Module *module, llvm::StringRef outputFilename) { |
| if (!module) |
| return true; |
| |
| llvm::LLVMContext llvmContext; |
| auto llvmModule = convertModuleToLLVMIR(*module, llvmContext); |
| if (!llvmModule) |
| return true; |
| |
| auto file = openOutputFile(outputFilename); |
| if (!file) |
| return true; |
| |
| llvmModule->print(file->os(), nullptr); |
| file->keep(); |
| return false; |
| }); |