| //===- ConvertStandardToLLVM.cpp - Standard to LLVM dialect conversion-----===// |
| // |
| // Copyright 2019 The MLIR Authors. |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| // ============================================================================= |
| // |
| // This file implements a pass to convert MLIR standard and builtin dialects |
| // into the LLVM IR dialect. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" |
| #include "mlir/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.h" |
| #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" |
| #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
| #include "mlir/Dialect/StandardOps/Ops.h" |
| #include "mlir/IR/Builders.h" |
| #include "mlir/IR/MLIRContext.h" |
| #include "mlir/IR/Module.h" |
| #include "mlir/IR/PatternMatch.h" |
| #include "mlir/Pass/Pass.h" |
| #include "mlir/Support/Functional.h" |
| #include "mlir/Transforms/DialectConversion.h" |
| #include "mlir/Transforms/Passes.h" |
| #include "mlir/Transforms/Utils.h" |
| |
| #include "llvm/IR/DerivedTypes.h" |
| #include "llvm/IR/IRBuilder.h" |
| #include "llvm/IR/Type.h" |
| |
| using namespace mlir; |
| |
| LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx) |
| : llvmDialect(ctx->getRegisteredDialect<LLVM::LLVMDialect>()) { |
| assert(llvmDialect && "LLVM IR dialect is not registered"); |
| module = &llvmDialect->getLLVMModule(); |
| } |
| |
| // Get the LLVM context. |
| llvm::LLVMContext &LLVMTypeConverter::getLLVMContext() { |
| return module->getContext(); |
| } |
| |
| // Extract an LLVM IR type from the LLVM IR dialect type. |
| LLVM::LLVMType LLVMTypeConverter::unwrap(Type type) { |
| if (!type) |
| return nullptr; |
| auto *mlirContext = type.getContext(); |
| auto wrappedLLVMType = type.dyn_cast<LLVM::LLVMType>(); |
| if (!wrappedLLVMType) |
| emitError(UnknownLoc::get(mlirContext), |
| "conversion resulted in a non-LLVM type"); |
| return wrappedLLVMType; |
| } |
| |
| LLVM::LLVMType LLVMTypeConverter::getIndexType() { |
| return LLVM::LLVMType::getIntNTy( |
| llvmDialect, module->getDataLayout().getPointerSizeInBits()); |
| } |
| |
| Type LLVMTypeConverter::convertIndexType(IndexType type) { |
| return getIndexType(); |
| } |
| |
| Type LLVMTypeConverter::convertIntegerType(IntegerType type) { |
| return LLVM::LLVMType::getIntNTy(llvmDialect, type.getWidth()); |
| } |
| |
| Type LLVMTypeConverter::convertFloatType(FloatType type) { |
| switch (type.getKind()) { |
| case mlir::StandardTypes::F32: |
| return LLVM::LLVMType::getFloatTy(llvmDialect); |
| case mlir::StandardTypes::F64: |
| return LLVM::LLVMType::getDoubleTy(llvmDialect); |
| case mlir::StandardTypes::F16: |
| return LLVM::LLVMType::getHalfTy(llvmDialect); |
| case mlir::StandardTypes::BF16: { |
| auto *mlirContext = llvmDialect->getContext(); |
| return emitError(UnknownLoc::get(mlirContext), "unsupported type: BF16"), |
| Type(); |
| } |
| default: |
| llvm_unreachable("non-float type in convertFloatType"); |
| } |
| } |
| |
| // Function types are converted to LLVM Function types by recursively converting |
| // argument and result types. If MLIR Function has zero results, the LLVM |
| // Function has one VoidType result. If MLIR Function has more than one result, |
| // they are into an LLVM StructType in their order of appearance. |
| Type LLVMTypeConverter::convertFunctionType(FunctionType type) { |
| // Convert argument types one by one and check for errors. |
| SmallVector<LLVM::LLVMType, 8> argTypes; |
| for (auto t : type.getInputs()) { |
| auto converted = convertType(t); |
| if (!converted) |
| return {}; |
| argTypes.push_back(unwrap(converted)); |
| } |
| |
| // If function does not return anything, create the void result type, |
| // if it returns on element, convert it, otherwise pack the result types into |
| // a struct. |
| LLVM::LLVMType resultType = |
| type.getNumResults() == 0 |
| ? LLVM::LLVMType::getVoidTy(llvmDialect) |
| : unwrap(packFunctionResults(type.getResults())); |
| if (!resultType) |
| return {}; |
| return LLVM::LLVMType::getFunctionTy(resultType, argTypes, /*isVarArg=*/false) |
| .getPointerTo(); |
| } |
| |
| // Convert a MemRef to an LLVM type. The result is a MemRef descriptor which |
| // contains: |
| // 1. the pointer to the data buffer, followed by |
| // 2. an array containing as many 64-bit integers as the rank of the MemRef: |
| // the array represents the size, in number of elements, of the memref along |
| // the given dimension. For constant MemRef dimensions, the corresponding size |
| // entry is a constant whose runtime value must match the static value. |
| // TODO(ntv, zinenko): add assertions for the static cases. |
| // |
| // template <typename Elem, size_t Rank> |
| // struct { |
| // Elem *ptr; |
| // int64_t sizes[Rank]; // omitted when rank == 0 |
| // }; |
| static unsigned kPtrPosInMemRefDescriptor = 0; |
| static unsigned kSizePosInMemRefDescriptor = 1; |
| Type LLVMTypeConverter::convertMemRefType(MemRefType type) { |
| assert((type.getAffineMaps().empty() || |
| (type.getAffineMaps().size() == 1 && |
| type.getAffineMaps().back().isIdentity())) && |
| "Non-identity layout maps must have been normalized away"); |
| LLVM::LLVMType elementType = unwrap(convertType(type.getElementType())); |
| if (!elementType) |
| return {}; |
| auto ptrTy = elementType.getPointerTo(type.getMemorySpace()); |
| auto indexTy = getIndexType(); |
| auto rank = type.getRank(); |
| if (rank > 0) { |
| auto arrayTy = LLVM::LLVMType::getArrayTy(indexTy, type.getRank()); |
| return LLVM::LLVMType::getStructTy(ptrTy, arrayTy); |
| } |
| return LLVM::LLVMType::getStructTy(ptrTy); |
| } |
| |
| // Convert an n-D vector type to an LLVM vector type via (n-1)-D array type when |
| // n > 1. |
| // For example, `vector<4 x f32>` converts to `!llvm.type<"<4 x float>">` and |
| // `vector<4 x 8 x 16 f32>` converts to `!llvm<"[4 x [8 x <16 x float>]]">`. |
| Type LLVMTypeConverter::convertVectorType(VectorType type) { |
| auto elementType = unwrap(convertType(type.getElementType())); |
| if (!elementType) |
| return {}; |
| auto vectorType = |
| LLVM::LLVMType::getVectorTy(elementType, type.getShape().back()); |
| auto shape = type.getShape(); |
| for (int i = shape.size() - 2; i >= 0; --i) |
| vectorType = LLVM::LLVMType::getArrayTy(vectorType, shape[i]); |
| return vectorType; |
| } |
| |
| // Dispatch based on the actual type. Return null type on error. |
| Type LLVMTypeConverter::convertStandardType(Type type) { |
| if (auto funcType = type.dyn_cast<FunctionType>()) |
| return convertFunctionType(funcType); |
| if (auto intType = type.dyn_cast<IntegerType>()) |
| return convertIntegerType(intType); |
| if (auto floatType = type.dyn_cast<FloatType>()) |
| return convertFloatType(floatType); |
| if (auto indexType = type.dyn_cast<IndexType>()) |
| return convertIndexType(indexType); |
| if (auto memRefType = type.dyn_cast<MemRefType>()) |
| return convertMemRefType(memRefType); |
| if (auto vectorType = type.dyn_cast<VectorType>()) |
| return convertVectorType(vectorType); |
| if (auto llvmType = type.dyn_cast<LLVM::LLVMType>()) |
| return llvmType; |
| |
| return {}; |
| } |
| |
| // Convert the element type of the memref `t` to to an LLVM type using |
| // `lowering`, get a pointer LLVM type pointing to the converted `t`, wrap it |
| // into the MLIR LLVM dialect type and return. |
| static Type getMemRefElementPtrType(MemRefType t, LLVMTypeConverter &lowering) { |
| auto elementType = t.getElementType(); |
| auto converted = lowering.convertType(elementType); |
| if (!converted) |
| return {}; |
| return converted.cast<LLVM::LLVMType>().getPointerTo(t.getMemorySpace()); |
| } |
| |
| LLVMOpLowering::LLVMOpLowering(StringRef rootOpName, MLIRContext *context, |
| LLVMTypeConverter &lowering_, |
| PatternBenefit benefit) |
| : ConversionPattern(rootOpName, benefit, context), lowering(lowering_) {} |
| |
| namespace { |
| // Base class for Standard to LLVM IR op conversions. Matches the Op type |
| // provided as template argument. Carries a reference to the LLVM dialect in |
| // case it is necessary for rewriters. |
| template <typename SourceOp> |
| class LLVMLegalizationPattern : public LLVMOpLowering { |
| public: |
| // Construct a conversion pattern. |
| explicit LLVMLegalizationPattern(LLVM::LLVMDialect &dialect_, |
| LLVMTypeConverter &lowering_) |
| : LLVMOpLowering(SourceOp::getOperationName(), dialect_.getContext(), |
| lowering_), |
| dialect(dialect_) {} |
| |
| // Get the LLVM IR dialect. |
| LLVM::LLVMDialect &getDialect() const { return dialect; } |
| // Get the LLVM context. |
| llvm::LLVMContext &getContext() const { return dialect.getLLVMContext(); } |
| // Get the LLVM module in which the types are constructed. |
| llvm::Module &getModule() const { return dialect.getLLVMModule(); } |
| |
| // Get the MLIR type wrapping the LLVM integer type whose bit width is defined |
| // by the pointer size used in the LLVM module. |
| LLVM::LLVMType getIndexType() const { |
| return LLVM::LLVMType::getIntNTy( |
| &dialect, getModule().getDataLayout().getPointerSizeInBits()); |
| } |
| |
| // Get the MLIR type wrapping the LLVM i8* type. |
| LLVM::LLVMType getVoidPtrType() const { |
| return LLVM::LLVMType::getInt8PtrTy(&dialect); |
| } |
| |
| // Create an LLVM IR pseudo-operation defining the given index constant. |
| Value *createIndexConstant(ConversionPatternRewriter &builder, Location loc, |
| uint64_t value) const { |
| auto attr = builder.getIntegerAttr(builder.getIndexType(), value); |
| return builder.create<LLVM::ConstantOp>(loc, getIndexType(), attr); |
| } |
| |
| // Extract raw data pointer value from a value representing a memref. |
| static Value *extractMemRefElementPtr(ConversionPatternRewriter &builder, |
| Location loc, |
| Value *convertedMemRefValue, |
| Type elementTypePtr, |
| bool hasStaticShape) { |
| Value *buffer; |
| if (hasStaticShape) |
| return convertedMemRefValue; |
| else |
| return builder.create<LLVM::ExtractValueOp>(loc, elementTypePtr, |
| convertedMemRefValue, |
| builder.getIndexArrayAttr(0)); |
| return buffer; |
| } |
| |
| protected: |
| LLVM::LLVMDialect &dialect; |
| }; |
| |
// Rewrites a FuncOp into a FuncOp whose signature uses LLVM dialect types.
// MemRef arguments are received as pointers to their converted descriptors.
// When the function has a body, a load from each such pointer is inserted at
// the start of the first block, and uses of the argument are redirected to the
// loaded value.
struct FuncOpConversion : public LLVMLegalizationPattern<FuncOp> {
  using LLVMLegalizationPattern<FuncOp>::LLVMLegalizationPattern;

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto funcOp = cast<FuncOp>(op);
    FunctionType type = funcOp.getType();
    SmallVector<Type, 4> argTypes;
    argTypes.reserve(type.getNumInputs());
    // Positions of the arguments whose converted type was wrapped in a pointer.
    SmallVector<unsigned, 4> promotedArgIndices;
    promotedArgIndices.reserve(type.getNumInputs());

    // Convert the original function arguments. MemRef arguments, whose
    // converted type is a descriptor struct, are promoted to pointer-to-struct
    // arguments to allow calling external functions with various ABIs (e.g.
    // compiled from C/C++ on platform X).
    TypeConverter::SignatureConversion result(type.getNumInputs());
    for (auto en : llvm::enumerate(type.getInputs())) {
      auto t = en.value();
      auto converted = lowering.convertType(t);
      if (!converted)
        return matchFailure();
      if (t.isa<MemRefType>()) {
        converted = converted.cast<LLVM::LLVMType>().getPointerTo();
        promotedArgIndices.push_back(en.index());
      }
      argTypes.push_back(converted);
    }
    // Map each original argument one-to-one onto its converted type.
    for (unsigned idx = 0, e = argTypes.size(); idx < e; ++idx)
      result.addInputs(idx, argTypes[idx]);

    // Pack the result types into a struct.
    Type packedResult;
    if (type.getNumResults() != 0) {
      if (!(packedResult = lowering.packFunctionResults(type.getResults())))
        return matchFailure();
    }

    // Create a new function with an updated signature. The original body
    // region is inlined into the clone.
    auto newFuncOp = rewriter.cloneWithoutRegions(funcOp);
    rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
                                newFuncOp.end());
    newFuncOp.setType(FunctionType::get(
        result.getConvertedTypes(),
        packedResult ? ArrayRef<Type>(packedResult) : llvm::None,
        funcOp.getContext()));

    // Tell the rewriter to convert the region signature.
    rewriter.applySignatureConversion(&newFuncOp.getBody(), result);

    // Insert loads from memref descriptor pointers in function bodies. A
    // function with an empty body (a declaration) gets no loads.
    if (!newFuncOp.getBody().empty()) {
      Block *firstBlock = &newFuncOp.getBody().front();
      rewriter.setInsertionPoint(firstBlock, firstBlock->begin());
      for (unsigned idx : promotedArgIndices) {
        BlockArgument *arg = firstBlock->getArgument(idx);
        Value *loaded = rewriter.create<LLVM::LoadOp>(funcOp.getLoc(), arg);
        rewriter.replaceUsesOfBlockArgument(arg, loaded);
      }
    }

    rewriter.replaceOp(op, llvm::None);
    return matchSuccess();
  }
};
| |
| // Basic lowering implementation for one-to-one rewriting from Standard Ops to |
| // LLVM Dialect Ops. |
| template <typename SourceOp, typename TargetOp> |
| struct OneToOneLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> { |
| using LLVMLegalizationPattern<SourceOp>::LLVMLegalizationPattern; |
| using Super = OneToOneLLVMOpLowering<SourceOp, TargetOp>; |
| |
| // Convert the type of the result to an LLVM type, pass operands as is, |
| // preserve attributes. |
| PatternMatchResult |
| matchAndRewrite(Operation *op, ArrayRef<Value *> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| unsigned numResults = op->getNumResults(); |
| |
| Type packedType; |
| if (numResults != 0) { |
| packedType = this->lowering.packFunctionResults( |
| llvm::to_vector<4>(op->getResultTypes())); |
| assert(packedType && "type conversion failed, such operation should not " |
| "have been matched"); |
| } |
| |
| auto newOp = rewriter.create<TargetOp>(op->getLoc(), packedType, operands, |
| op->getAttrs()); |
| |
| // If the operation produced 0 or 1 result, return them immediately. |
| if (numResults == 0) |
| return rewriter.replaceOp(op, llvm::None), this->matchSuccess(); |
| if (numResults == 1) |
| return rewriter.replaceOp(op, newOp.getOperation()->getResult(0)), |
| this->matchSuccess(); |
| |
| // Otherwise, it had been converted to an operation producing a structure. |
| // Extract individual results from the structure and return them as list. |
| SmallVector<Value *, 4> results; |
| results.reserve(numResults); |
| for (unsigned i = 0; i < numResults; ++i) { |
| auto type = this->lowering.convertType(op->getResult(i)->getType()); |
| results.push_back(rewriter.create<LLVM::ExtractValueOp>( |
| op->getLoc(), type, newOp.getOperation()->getResult(0), |
| rewriter.getIndexArrayAttr(i))); |
| } |
| rewriter.replaceOp(op, results); |
| return this->matchSuccess(); |
| } |
| }; |
| |
| // Express `linearIndex` in terms of coordinates of `basis`. |
| // Returns the empty vector when linearIndex is out of the range [0, P] where |
| // P is the product of all the basis coordinates. |
| // |
| // Prerequisites: |
| // Basis is an array of nonnegative integers (signed type inherited from |
| // vector shape type). |
| static SmallVector<int64_t, 4> getCoordinates(ArrayRef<int64_t> basis, |
| unsigned linearIndex) { |
| SmallVector<int64_t, 4> res; |
| res.reserve(basis.size()); |
| for (unsigned basisElement : llvm::reverse(basis)) { |
| res.push_back(linearIndex % basisElement); |
| linearIndex = linearIndex / basisElement; |
| } |
| if (linearIndex > 0) |
| return {}; |
| std::reverse(res.begin(), res.end()); |
| return res; |
| } |
| |
| // Basic lowering implementation for rewriting from Standard Ops to LLVM Dialect |
| // Ops for binary ops with one result. This supports higher-dimensional vector |
| // types. |
| template <typename SourceOp, typename TargetOp> |
| struct BinaryOpLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> { |
| using LLVMLegalizationPattern<SourceOp>::LLVMLegalizationPattern; |
| using Super = BinaryOpLLVMOpLowering<SourceOp, TargetOp>; |
| |
| // Convert the type of the result to an LLVM type, pass operands as is, |
| // preserve attributes. |
| PatternMatchResult |
| matchAndRewrite(Operation *op, ArrayRef<Value *> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| static_assert( |
| std::is_base_of<OpTrait::NOperands<2>::Impl<SourceOp>, SourceOp>::value, |
| "expected binary op"); |
| static_assert( |
| std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value, |
| "expected single result op"); |
| static_assert(std::is_base_of<OpTrait::SameOperandsAndResultType<SourceOp>, |
| SourceOp>::value, |
| "expected single result op"); |
| |
| auto loc = op->getLoc(); |
| auto llvmArrayTy = operands[0]->getType().cast<LLVM::LLVMType>(); |
| |
| if (!llvmArrayTy.isArrayTy()) { |
| auto newOp = rewriter.create<TargetOp>( |
| op->getLoc(), operands[0]->getType(), operands, op->getAttrs()); |
| rewriter.replaceOp(op, newOp.getResult()); |
| return this->matchSuccess(); |
| } |
| |
| // Unroll iterated array type until we hit a non-array type. |
| auto llvmTy = llvmArrayTy; |
| SmallVector<int64_t, 4> arraySizes; |
| while (llvmTy.isArrayTy()) { |
| arraySizes.push_back(llvmTy.getArrayNumElements()); |
| llvmTy = llvmTy.getArrayElementType(); |
| } |
| assert(llvmTy.isVectorTy() && "unexpected binary op over non-vector type"); |
| auto llvmVectorTy = llvmTy; |
| |
| // Iteratively extract a position coordinates with basis `arraySize` from a |
| // `linearIndex` that is incremented at each step. This terminates when |
| // `linearIndex` exceeds the range specified by `arraySize`. |
| // This has the effect of fully unrolling the dimensions of the n-D array |
| // type, getting to the underlying vector element. |
| Value *desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayTy); |
| unsigned ub = 1; |
| for (auto s : arraySizes) |
| ub *= s; |
| for (unsigned linearIndex = 0; linearIndex < ub; ++linearIndex) { |
| auto coords = getCoordinates(arraySizes, linearIndex); |
| // Linear index is out of bounds, we are done. |
| if (coords.empty()) |
| break; |
| |
| auto position = rewriter.getIndexArrayAttr(coords); |
| |
| // For this unrolled `position` corresponding to the `linearIndex`^th |
| // element, extract operand vectors |
| Value *extractedLHS = rewriter.create<LLVM::ExtractValueOp>( |
| loc, llvmVectorTy, operands[0], position); |
| Value *extractedRHS = rewriter.create<LLVM::ExtractValueOp>( |
| loc, llvmVectorTy, operands[1], position); |
| Value *newVal = rewriter.create<TargetOp>( |
| loc, llvmVectorTy, ArrayRef<Value *>{extractedLHS, extractedRHS}, |
| op->getAttrs()); |
| desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmArrayTy, desc, |
| newVal, position); |
| } |
| rewriter.replaceOp(op, desc); |
| return this->matchSuccess(); |
| } |
| }; |
| |
| // Specific lowerings. |
| // FIXME: this should be tablegen'ed. |
| struct AddIOpLowering : public BinaryOpLLVMOpLowering<AddIOp, LLVM::AddOp> { |
| using Super::Super; |
| }; |
| struct SubIOpLowering : public BinaryOpLLVMOpLowering<SubIOp, LLVM::SubOp> { |
| using Super::Super; |
| }; |
| struct MulIOpLowering : public BinaryOpLLVMOpLowering<MulIOp, LLVM::MulOp> { |
| using Super::Super; |
| }; |
| struct DivISOpLowering : public BinaryOpLLVMOpLowering<DivISOp, LLVM::SDivOp> { |
| using Super::Super; |
| }; |
| struct DivIUOpLowering : public BinaryOpLLVMOpLowering<DivIUOp, LLVM::UDivOp> { |
| using Super::Super; |
| }; |
| struct RemISOpLowering : public BinaryOpLLVMOpLowering<RemISOp, LLVM::SRemOp> { |
| using Super::Super; |
| }; |
| struct RemIUOpLowering : public BinaryOpLLVMOpLowering<RemIUOp, LLVM::URemOp> { |
| using Super::Super; |
| }; |
| struct AndOpLowering : public BinaryOpLLVMOpLowering<AndOp, LLVM::AndOp> { |
| using Super::Super; |
| }; |
| struct OrOpLowering : public BinaryOpLLVMOpLowering<OrOp, LLVM::OrOp> { |
| using Super::Super; |
| }; |
| struct XOrOpLowering : public BinaryOpLLVMOpLowering<XOrOp, LLVM::XOrOp> { |
| using Super::Super; |
| }; |
| struct AddFOpLowering : public BinaryOpLLVMOpLowering<AddFOp, LLVM::FAddOp> { |
| using Super::Super; |
| }; |
| struct SubFOpLowering : public BinaryOpLLVMOpLowering<SubFOp, LLVM::FSubOp> { |
| using Super::Super; |
| }; |
| struct MulFOpLowering : public BinaryOpLLVMOpLowering<MulFOp, LLVM::FMulOp> { |
| using Super::Super; |
| }; |
| struct DivFOpLowering : public BinaryOpLLVMOpLowering<DivFOp, LLVM::FDivOp> { |
| using Super::Super; |
| }; |
| struct RemFOpLowering : public BinaryOpLLVMOpLowering<RemFOp, LLVM::FRemOp> { |
| using Super::Super; |
| }; |
| struct SelectOpLowering |
| : public OneToOneLLVMOpLowering<SelectOp, LLVM::SelectOp> { |
| using Super::Super; |
| }; |
| struct ConstLLVMOpLowering |
| : public OneToOneLLVMOpLowering<ConstantOp, LLVM::ConstantOp> { |
| using Super::Super; |
| }; |
| |
| // Check if the MemRefType `type` is supported by the lowering. We currently |
| // only support memrefs with identity maps. |
| static bool isSupportedMemRefType(MemRefType type) { |
| return llvm::all_of(type.getAffineMaps(), |
| [](AffineMap map) { return map.isIdentity(); }); |
| } |
| |
| // An `alloc` is converted into a definition of a memref descriptor value and |
| // a call to `malloc` to allocate the underlying data buffer. The memref |
| // descriptor is of the LLVM structure type where the first element is a pointer |
| // to the (typed) data buffer, and the remaining elements serve to store |
| // dynamic sizes of the memref using LLVM-converted `index` type. |
| struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> { |
| using LLVMLegalizationPattern<AllocOp>::LLVMLegalizationPattern; |
| |
| PatternMatchResult match(Operation *op) const override { |
| MemRefType type = cast<AllocOp>(op).getType(); |
| return isSupportedMemRefType(type) ? matchSuccess() : matchFailure(); |
| } |
| |
| void rewrite(Operation *op, ArrayRef<Value *> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| auto allocOp = cast<AllocOp>(op); |
| MemRefType type = allocOp.getType(); |
| |
| // Get actual sizes of the memref as values: static sizes are constant |
| // values and dynamic sizes are passed to 'alloc' as operands. In case of |
| // zero-dimensional memref, assume a scalar (size 1). |
| SmallVector<Value *, 4> sizes; |
| auto numOperands = allocOp.getNumOperands(); |
| sizes.reserve(numOperands); |
| unsigned i = 0; |
| for (int64_t s : type.getShape()) |
| sizes.push_back(s == -1 ? operands[i++] |
| : createIndexConstant(rewriter, op->getLoc(), s)); |
| if (sizes.empty()) |
| sizes.push_back(createIndexConstant(rewriter, op->getLoc(), 1)); |
| |
| // Compute the total number of memref elements. |
| Value *cumulativeSize = sizes.front(); |
| for (unsigned i = 1, e = sizes.size(); i < e; ++i) |
| cumulativeSize = rewriter.create<LLVM::MulOp>( |
| op->getLoc(), getIndexType(), |
| ArrayRef<Value *>{cumulativeSize, sizes[i]}); |
| |
| // Compute the total amount of bytes to allocate. |
| auto elementType = type.getElementType(); |
| assert((elementType.isIntOrFloat() || elementType.isa<VectorType>()) && |
| "invalid memref element type"); |
| uint64_t elementSize = 0; |
| if (auto vectorType = elementType.dyn_cast<VectorType>()) |
| elementSize = vectorType.getNumElements() * |
| llvm::divideCeil(vectorType.getElementTypeBitWidth(), 8); |
| else |
| elementSize = llvm::divideCeil(elementType.getIntOrFloatBitWidth(), 8); |
| cumulativeSize = rewriter.create<LLVM::MulOp>( |
| op->getLoc(), getIndexType(), |
| ArrayRef<Value *>{ |
| cumulativeSize, |
| createIndexConstant(rewriter, op->getLoc(), elementSize)}); |
| |
| // Insert the `malloc` declaration if it is not already present. |
| auto module = op->getParentOfType<ModuleOp>(); |
| FuncOp mallocFunc = module.lookupSymbol<FuncOp>("malloc"); |
| if (!mallocFunc) { |
| auto mallocType = |
| rewriter.getFunctionType(getIndexType(), getVoidPtrType()); |
| mallocFunc = |
| FuncOp::create(rewriter.getUnknownLoc(), "malloc", mallocType); |
| module.push_back(mallocFunc); |
| } |
| |
| // Allocate the underlying buffer and store a pointer to it in the MemRef |
| // descriptor. |
| Value *allocated = |
| rewriter |
| .create<LLVM::CallOp>(op->getLoc(), getVoidPtrType(), |
| rewriter.getSymbolRefAttr(mallocFunc), |
| cumulativeSize) |
| .getResult(0); |
| auto structElementType = lowering.convertType(elementType); |
| auto elementPtrType = structElementType.cast<LLVM::LLVMType>().getPointerTo( |
| type.getMemorySpace()); |
| allocated = rewriter.create<LLVM::BitcastOp>(op->getLoc(), elementPtrType, |
| ArrayRef<Value *>(allocated)); |
| |
| // Create the MemRef descriptor. |
| auto structType = lowering.convertType(type); |
| Value *memRefDescriptor = rewriter.create<LLVM::UndefOp>( |
| op->getLoc(), structType, ArrayRef<Value *>{}); |
| |
| memRefDescriptor = rewriter.create<LLVM::InsertValueOp>( |
| op->getLoc(), structType, memRefDescriptor, allocated, |
| rewriter.getIndexArrayAttr(kPtrPosInMemRefDescriptor)); |
| |
| // Store dynamically allocated sizes in the descriptor. Static and dynamic |
| // sizes are all passed in as operands. |
| for (auto indexedSize : llvm::enumerate(sizes)) { |
| int64_t index = indexedSize.index(); |
| memRefDescriptor = rewriter.create<LLVM::InsertValueOp>( |
| op->getLoc(), structType, memRefDescriptor, indexedSize.value(), |
| rewriter.getI64ArrayAttr({kSizePosInMemRefDescriptor, index})); |
| } |
| |
| // Return the final value of the descriptor. |
| rewriter.replaceOp(op, memRefDescriptor); |
| } |
| }; |
| |
| // Helper structure which extracts the necessary information from CallOp-like |
| // ops for the purpose of generating an LLVM::CallOp. |
| struct FunctionInfo { |
| FunctionType type; |
| CallInterfaceCallable callable; |
| }; |
| static FunctionInfo getFuncOp(ModuleOp module, CallOp op) { |
| return FunctionInfo{module.lookupSymbol<FuncOp>(op.getCallee()).getType(), |
| SymbolRefAttr::get(op.getCallee(), op.getContext())}; |
| } |
| static FunctionInfo getFuncOp(ModuleOp module, CallIndirectOp op) { |
| if (auto fAttr = op.getCallableForCallee().dyn_cast<SymbolRefAttr>()) |
| return FunctionInfo{module.lookupSymbol<FuncOp>(fAttr.getValue()).getType(), |
| fAttr}; |
| // Else, this must be an SSA value of FunctionType type. |
| Value *fValue = op.getCallableForCallee().get<Value *>(); |
| FunctionType fType = fValue->getType().cast<FunctionType>(); |
| return FunctionInfo{fType, fValue}; |
| } |
| template <typename CallOpType> |
| static LLVM::CallOp |
| createLLVMCall(FunctionInfo fInfo, ConversionPatternRewriter &rewriter, |
| Location loc, Type returnType, ArrayRef<Value *> operands) { |
| if (fInfo.callable.dyn_cast<Value *>()) |
| return rewriter.create<LLVM::CallOp>(loc, returnType, operands); |
| auto fAttr = fInfo.callable.get<SymbolRefAttr>(); |
| auto namedFAttr = rewriter.getNamedAttr("callee", fAttr); |
| return rewriter.create<LLVM::CallOp>(loc, returnType, operands, |
| ArrayRef<NamedAttribute>{namedFAttr}); |
| } |
| |
| // A CallOp automatically promotes MemRefType to a sequence of alloca/store and |
| // passes the pointer to the MemRef across function boundaries. |
| template <typename CallOpType> |
| struct CallOpInterfaceLowering : public LLVMLegalizationPattern<CallOpType> { |
| using LLVMLegalizationPattern<CallOpType>::LLVMLegalizationPattern; |
| using Super = CallOpInterfaceLowering<CallOpType>; |
| using Base = LLVMLegalizationPattern<CallOpType>; |
| |
| PatternMatchResult |
| matchAndRewrite(Operation *op, ArrayRef<Value *> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| OperandAdaptor<CallOpType> transformed(operands); |
| auto callOp = cast<CallOpType>(op); |
| auto module = op->getParentOfType<ModuleOp>(); |
| FunctionInfo fInfo = getFuncOp(module, callOp); |
| auto functionType = fInfo.type; |
| |
| // Pack the result types into a struct. |
| Type packedResult; |
| unsigned numResults = callOp.getNumResults(); |
| if (numResults != 0) { |
| if (!(packedResult = |
| this->lowering.packFunctionResults(functionType.getResults()))) |
| return this->matchFailure(); |
| } |
| |
| SmallVector<Value *, 4> opOperands(op->getOperands()); |
| auto promoted = this->lowering.promoteMemRefDescriptors( |
| op->getLoc(), opOperands, operands, rewriter); |
| auto newOp = createLLVMCall<CallOpType>(fInfo, rewriter, op->getLoc(), |
| packedResult, promoted); |
| |
| // If < 2 results, packingdid not do anything and we can just return. |
| if (numResults < 2) { |
| SmallVector<Value *, 4> results(newOp.getResults()); |
| rewriter.replaceOp(op, results); |
| return this->matchSuccess(); |
| } |
| |
| // Otherwise, it had been converted to an operation producing a structure. |
| // Extract individual results from the structure and return them as list. |
| SmallVector<Value *, 4> results; |
| results.reserve(numResults); |
| for (unsigned i = 0; i < numResults; ++i) { |
| auto type = this->lowering.convertType(op->getResult(i)->getType()); |
| results.push_back(rewriter.create<LLVM::ExtractValueOp>( |
| op->getLoc(), type, newOp.getOperation()->getResult(0), |
| rewriter.getIndexArrayAttr(i))); |
| } |
| rewriter.replaceOp(op, results); |
| |
| return this->matchSuccess(); |
| } |
| }; |
| |
| struct CallOpLowering : public CallOpInterfaceLowering<CallOp> { |
| using Super::Super; |
| }; |
| |
| struct CallIndirectOpLowering : public CallOpInterfaceLowering<CallIndirectOp> { |
| using Super::Super; |
| }; |
| |
| // A `dealloc` is converted into a call to `free` on the underlying data buffer. |
| // The memref descriptor being an SSA value, there is no need to clean it up |
| // in any way. |
| struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> { |
| using LLVMLegalizationPattern<DeallocOp>::LLVMLegalizationPattern; |
| |
| PatternMatchResult |
| matchAndRewrite(Operation *op, ArrayRef<Value *> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| assert(operands.size() == 1 && "dealloc takes one operand"); |
| OperandAdaptor<DeallocOp> transformed(operands); |
| |
| // Insert the `free` declaration if it is not already present. |
| FuncOp freeFunc = |
| op->getParentOfType<ModuleOp>().lookupSymbol<FuncOp>("free"); |
| if (!freeFunc) { |
| auto freeType = rewriter.getFunctionType(getVoidPtrType(), {}); |
| freeFunc = FuncOp::create(rewriter.getUnknownLoc(), "free", freeType); |
| op->getParentOfType<ModuleOp>().push_back(freeFunc); |
| } |
| |
| auto type = transformed.memref()->getType().cast<LLVM::LLVMType>(); |
| auto hasStaticShape = type.isPointerTy(); |
| Type elementPtrType = hasStaticShape ? type : type.getStructElementType(0); |
| Value *bufferPtr = |
| extractMemRefElementPtr(rewriter, op->getLoc(), transformed.memref(), |
| elementPtrType, hasStaticShape); |
| Value *casted = rewriter.create<LLVM::BitcastOp>( |
| op->getLoc(), getVoidPtrType(), bufferPtr); |
| rewriter.replaceOpWithNewOp<LLVM::CallOp>( |
| op, ArrayRef<Type>(), rewriter.getSymbolRefAttr(freeFunc), casted); |
| return matchSuccess(); |
| } |
| }; |
| |
| struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> { |
| using LLVMLegalizationPattern<MemRefCastOp>::LLVMLegalizationPattern; |
| |
| PatternMatchResult match(Operation *op) const override { |
| auto memRefCastOp = cast<MemRefCastOp>(op); |
| MemRefType sourceType = |
| memRefCastOp.getOperand()->getType().cast<MemRefType>(); |
| MemRefType targetType = memRefCastOp.getType(); |
| return (isSupportedMemRefType(targetType) && |
| isSupportedMemRefType(sourceType)) |
| ? matchSuccess() |
| : matchFailure(); |
| } |
| |
| void rewrite(Operation *op, ArrayRef<Value *> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| auto memRefCastOp = cast<MemRefCastOp>(op); |
| OperandAdaptor<MemRefCastOp> transformed(operands); |
| // memref_cast is defined for source and destination memref types with the |
| // same element type, same mappings, same address space and same rank. |
| // Therefore a simple bitcast suffices. If not it is undefined behavior. |
| auto targetStructType = lowering.convertType(memRefCastOp.getType()); |
| rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, targetStructType, |
| transformed.source()); |
| } |
| }; |
| |
| // A `dim` is converted to a constant for static sizes and to an access to the |
| // size stored in the memref descriptor for dynamic sizes. |
| struct DimOpLowering : public LLVMLegalizationPattern<DimOp> { |
| using LLVMLegalizationPattern<DimOp>::LLVMLegalizationPattern; |
| |
| PatternMatchResult match(Operation *op) const override { |
| auto dimOp = cast<DimOp>(op); |
| MemRefType type = dimOp.getOperand()->getType().cast<MemRefType>(); |
| return isSupportedMemRefType(type) ? matchSuccess() : matchFailure(); |
| } |
| |
| void rewrite(Operation *op, ArrayRef<Value *> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| auto dimOp = cast<DimOp>(op); |
| OperandAdaptor<DimOp> transformed(operands); |
| MemRefType type = dimOp.getOperand()->getType().cast<MemRefType>(); |
| |
| auto shape = type.getShape(); |
| int64_t index = dimOp.getIndex(); |
| // Extract dynamic size from the memref descriptor and define static size |
| // as a constant. |
| if (ShapedType::isDynamic(shape[index])) |
| rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>( |
| op, getIndexType(), transformed.memrefOrTensor(), |
| rewriter.getI64ArrayAttr({kSizePosInMemRefDescriptor, index})); |
| else |
| rewriter.replaceOp( |
| op, createIndexConstant(rewriter, op->getLoc(), shape[index])); |
| } |
| }; |
| |
| // Common base for load and store operations on MemRefs. Restricts the match |
| // to supported MemRef types. Provides functionality to emit code accessing a |
| // specific element of the underlying data buffer. |
| template <typename Derived> |
| struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> { |
| using LLVMLegalizationPattern<Derived>::LLVMLegalizationPattern; |
| using Base = LoadStoreOpLowering<Derived>; |
| |
| PatternMatchResult match(Operation *op) const override { |
| MemRefType type = cast<Derived>(op).getMemRefType(); |
| return isSupportedMemRefType(type) ? this->matchSuccess() |
| : this->matchFailure(); |
| } |
| |
| // Given subscript indices and array sizes in row-major order, |
| // i_n, i_{n-1}, ..., i_1 |
| // s_n, s_{n-1}, ..., s_1 |
| // obtain a value that corresponds to the linearized subscript |
| // \sum_k i_k * \prod_{j=1}^{k-1} s_j |
| // by accumulating the running linearized value. |
| // Note that `indices` and `allocSizes` are passed in the same order as they |
| // appear in load/store operations and memref type declarations. |
| Value *linearizeSubscripts(ConversionPatternRewriter &builder, Location loc, |
| ArrayRef<Value *> indices, |
| ArrayRef<Value *> allocSizes) const { |
| assert(indices.size() == allocSizes.size() && |
| "mismatching number of indices and allocation sizes"); |
| assert(!indices.empty() && "cannot linearize a 0-dimensional access"); |
| |
| Value *linearized = indices.front(); |
| for (int i = 1, nSizes = allocSizes.size(); i < nSizes; ++i) { |
| linearized = builder.create<LLVM::MulOp>( |
| loc, this->getIndexType(), |
| ArrayRef<Value *>{linearized, allocSizes[i]}); |
| linearized = builder.create<LLVM::AddOp>( |
| loc, this->getIndexType(), ArrayRef<Value *>{linearized, indices[i]}); |
| } |
| return linearized; |
| } |
| |
| // Given the MemRef type, a descriptor and a list of indices, extract the data |
| // buffer pointer from the descriptor, convert multi-dimensional subscripts |
| // into a linearized index (using dynamic size data from the descriptor if |
| // necessary) and get the pointer to the buffer element identified by the |
| // indices. |
| Value *getElementPtr(Location loc, Type elementTypePtr, |
| ArrayRef<int64_t> shape, Value *memRefDescriptor, |
| ArrayRef<Value *> indices, |
| ConversionPatternRewriter &rewriter) const { |
| // Get the list of MemRef sizes. Static sizes are defined as constants. |
| // Dynamic sizes are extracted from the MemRef descriptor, where they start |
| // from the position 1 (the buffer is at position 0). |
| SmallVector<Value *, 4> sizes; |
| for (auto en : llvm::enumerate(shape)) { |
| int64_t s = en.value(); |
| int64_t index = en.index(); |
| if (s == -1) { |
| Value *size = rewriter.create<LLVM::ExtractValueOp>( |
| loc, this->getIndexType(), memRefDescriptor, |
| rewriter.getI64ArrayAttr({kSizePosInMemRefDescriptor, index})); |
| sizes.push_back(size); |
| } else { |
| sizes.push_back(this->createIndexConstant(rewriter, loc, s)); |
| // TODO(ntv, zinenko): assert dynamic descriptor size is constant. |
| } |
| } |
| |
| // The second and subsequent operands are access subscripts. Obtain the |
| // linearized address in the buffer. |
| Value *subscript = indices.empty() |
| ? nullptr |
| : linearizeSubscripts(rewriter, loc, indices, sizes); |
| |
| Value *dataPtr = rewriter.create<LLVM::ExtractValueOp>( |
| loc, elementTypePtr, memRefDescriptor, |
| rewriter.getIndexArrayAttr(kPtrPosInMemRefDescriptor)); |
| SmallVector<Value *, 2> gepSubValues(1, dataPtr); |
| if (subscript) |
| gepSubValues.push_back(subscript); |
| return rewriter.create<LLVM::GEPOp>(loc, elementTypePtr, gepSubValues, |
| ArrayRef<NamedAttribute>{}); |
| } |
| Value *getDataPtr(Location loc, MemRefType type, Value *dataPtr, |
| ArrayRef<Value *> indices, |
| ConversionPatternRewriter &rewriter, |
| llvm::Module &module) const { |
| auto ptrType = getMemRefElementPtrType(type, this->lowering); |
| auto shape = type.getShape(); |
| return getElementPtr(loc, ptrType, shape, dataPtr, indices, rewriter); |
| } |
| }; |
| |
| // Load operation is lowered to obtaining a pointer to the indexed element |
| // and loading it. |
| struct LoadOpLowering : public LoadStoreOpLowering<LoadOp> { |
| using Base::Base; |
| |
| PatternMatchResult |
| matchAndRewrite(Operation *op, ArrayRef<Value *> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| auto loadOp = cast<LoadOp>(op); |
| OperandAdaptor<LoadOp> transformed(operands); |
| auto type = loadOp.getMemRefType(); |
| |
| Value *dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(), |
| transformed.indices(), rewriter, getModule()); |
| rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, dataPtr); |
| return matchSuccess(); |
| } |
| }; |
| |
| // Store opreation is lowered to obtaining a pointer to the indexed element, |
| // and storing the given value to it. |
| struct StoreOpLowering : public LoadStoreOpLowering<StoreOp> { |
| using Base::Base; |
| |
| PatternMatchResult |
| matchAndRewrite(Operation *op, ArrayRef<Value *> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| auto type = cast<StoreOp>(op).getMemRefType(); |
| OperandAdaptor<StoreOp> transformed(operands); |
| |
| Value *dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(), |
| transformed.indices(), rewriter, getModule()); |
| rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, transformed.value(), |
| dataPtr); |
| return matchSuccess(); |
| } |
| }; |
| |
| // The lowering of index_cast becomes an integer conversion since index becomes |
| // an integer. If the bit width of the source and target integer types is the |
| // same, just erase the cast. If the target type is wider, sign-extend the |
| // value, otherwise truncate it. |
| struct IndexCastOpLowering : public LLVMLegalizationPattern<IndexCastOp> { |
| using LLVMLegalizationPattern<IndexCastOp>::LLVMLegalizationPattern; |
| |
| PatternMatchResult |
| matchAndRewrite(Operation *op, ArrayRef<Value *> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| IndexCastOpOperandAdaptor transformed(operands); |
| auto indexCastOp = cast<IndexCastOp>(op); |
| |
| auto targetType = |
| this->lowering.convertType(indexCastOp.getResult()->getType()) |
| .cast<LLVM::LLVMType>(); |
| auto sourceType = transformed.in()->getType().cast<LLVM::LLVMType>(); |
| unsigned targetBits = targetType.getUnderlyingType()->getIntegerBitWidth(); |
| unsigned sourceBits = sourceType.getUnderlyingType()->getIntegerBitWidth(); |
| |
| if (targetBits == sourceBits) |
| rewriter.replaceOp(op, transformed.in()); |
| else if (targetBits < sourceBits) |
| rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, targetType, |
| transformed.in()); |
| else |
| rewriter.replaceOpWithNewOp<LLVM::SExtOp>(op, targetType, |
| transformed.in()); |
| return matchSuccess(); |
| } |
| }; |
| |
// Maps a std.cmp predicate onto the matching LLVM dialect predicate. Both
// enumerations assign identical numerical values to their cases, so the
// mapping is a value-preserving cast.
template <typename LLVMPredType, typename StdPredType>
static LLVMPredType convertCmpPredicate(StdPredType pred) {
  LLVMPredType converted = static_cast<LLVMPredType>(pred);
  return converted;
}
| |
| struct CmpIOpLowering : public LLVMLegalizationPattern<CmpIOp> { |
| using LLVMLegalizationPattern<CmpIOp>::LLVMLegalizationPattern; |
| |
| PatternMatchResult |
| matchAndRewrite(Operation *op, ArrayRef<Value *> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| auto cmpiOp = cast<CmpIOp>(op); |
| CmpIOpOperandAdaptor transformed(operands); |
| |
| rewriter.replaceOpWithNewOp<LLVM::ICmpOp>( |
| op, lowering.convertType(cmpiOp.getResult()->getType()), |
| rewriter.getI64IntegerAttr(static_cast<int64_t>( |
| convertCmpPredicate<LLVM::ICmpPredicate>(cmpiOp.getPredicate()))), |
| transformed.lhs(), transformed.rhs()); |
| |
| return matchSuccess(); |
| } |
| }; |
| |
| struct CmpFOpLowering : public LLVMLegalizationPattern<CmpFOp> { |
| using LLVMLegalizationPattern<CmpFOp>::LLVMLegalizationPattern; |
| |
| PatternMatchResult |
| matchAndRewrite(Operation *op, ArrayRef<Value *> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| auto cmpfOp = cast<CmpFOp>(op); |
| CmpFOpOperandAdaptor transformed(operands); |
| |
| rewriter.replaceOpWithNewOp<LLVM::FCmpOp>( |
| op, lowering.convertType(cmpfOp.getResult()->getType()), |
| rewriter.getI64IntegerAttr(static_cast<int64_t>( |
| convertCmpPredicate<LLVM::FCmpPredicate>(cmpfOp.getPredicate()))), |
| transformed.lhs(), transformed.rhs()); |
| |
| return matchSuccess(); |
| } |
| }; |
| |
| struct SIToFPLowering |
| : public OneToOneLLVMOpLowering<SIToFPOp, LLVM::SIToFPOp> { |
| using Super::Super; |
| }; |
| |
| struct SignExtendIOpLowering |
| : public OneToOneLLVMOpLowering<SignExtendIOp, LLVM::SExtOp> { |
| using Super::Super; |
| }; |
| |
| struct TruncateIOpLowering |
| : public OneToOneLLVMOpLowering<TruncateIOp, LLVM::TruncOp> { |
| using Super::Super; |
| }; |
| |
| struct ZeroExtendIOpLowering |
| : public OneToOneLLVMOpLowering<ZeroExtendIOp, LLVM::ZExtOp> { |
| using Super::Super; |
| }; |
| |
| // Base class for LLVM IR lowering terminator operations with successors. |
| template <typename SourceOp, typename TargetOp> |
| struct OneToOneLLVMTerminatorLowering |
| : public LLVMLegalizationPattern<SourceOp> { |
| using LLVMLegalizationPattern<SourceOp>::LLVMLegalizationPattern; |
| using Super = OneToOneLLVMTerminatorLowering<SourceOp, TargetOp>; |
| |
| PatternMatchResult |
| matchAndRewrite(Operation *op, ArrayRef<Value *> properOperands, |
| ArrayRef<Block *> destinations, |
| ArrayRef<ArrayRef<Value *>> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| rewriter.replaceOpWithNewOp<TargetOp>(op, properOperands, destinations, |
| operands, op->getAttrs()); |
| return this->matchSuccess(); |
| } |
| }; |
| |
| // Special lowering pattern for `ReturnOps`. Unlike all other operations, |
| // `ReturnOp` interacts with the function signature and must have as many |
| // operands as the function has return values. Because in LLVM IR, functions |
| // can only return 0 or 1 value, we pack multiple values into a structure type. |
| // Emit `UndefOp` followed by `InsertValueOp`s to create such structure if |
| // necessary before returning it |
| struct ReturnOpLowering : public LLVMLegalizationPattern<ReturnOp> { |
| using LLVMLegalizationPattern<ReturnOp>::LLVMLegalizationPattern; |
| |
| PatternMatchResult |
| matchAndRewrite(Operation *op, ArrayRef<Value *> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| unsigned numArguments = op->getNumOperands(); |
| |
| // If ReturnOp has 0 or 1 operand, create it and return immediately. |
| if (numArguments == 0) { |
| rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, llvm::ArrayRef<Value *>(), |
| llvm::ArrayRef<Block *>(), |
| op->getAttrs()); |
| return matchSuccess(); |
| } |
| if (numArguments == 1) { |
| rewriter.replaceOpWithNewOp<LLVM::ReturnOp>( |
| op, llvm::ArrayRef<Value *>(operands.front()), |
| llvm::ArrayRef<Block *>(), op->getAttrs()); |
| return matchSuccess(); |
| } |
| |
| // Otherwise, we need to pack the arguments into an LLVM struct type before |
| // returning. |
| auto packedType = |
| lowering.packFunctionResults(llvm::to_vector<4>(op->getOperandTypes())); |
| |
| Value *packed = rewriter.create<LLVM::UndefOp>(op->getLoc(), packedType); |
| for (unsigned i = 0; i < numArguments; ++i) { |
| packed = rewriter.create<LLVM::InsertValueOp>( |
| op->getLoc(), packedType, packed, operands[i], |
| rewriter.getIndexArrayAttr(i)); |
| } |
| rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, llvm::makeArrayRef(packed), |
| llvm::ArrayRef<Block *>(), |
| op->getAttrs()); |
| return matchSuccess(); |
| } |
| }; |
| |
| // FIXME: this should be tablegen'ed as well. |
| struct BranchOpLowering |
| : public OneToOneLLVMTerminatorLowering<BranchOp, LLVM::BrOp> { |
| using Super::Super; |
| }; |
| struct CondBranchOpLowering |
| : public OneToOneLLVMTerminatorLowering<CondBranchOp, LLVM::CondBrOp> { |
| using Super::Super; |
| }; |
| |
| // The Splat operation is lowered to an insertelement + a shufflevector |
| // operation. Splat to only 1-d vector result types are lowered. |
| struct SplatOpLowering : public LLVMLegalizationPattern<SplatOp> { |
| using LLVMLegalizationPattern<SplatOp>::LLVMLegalizationPattern; |
| |
| PatternMatchResult |
| matchAndRewrite(Operation *op, ArrayRef<Value *> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| auto splatOp = cast<SplatOp>(op); |
| VectorType resultType = splatOp.getType().dyn_cast<VectorType>(); |
| if (!resultType || resultType.getRank() != 1) |
| return matchFailure(); |
| |
| // First insert it into an undef vector so we can shuffle it. |
| auto vectorType = lowering.convertType(splatOp.getType()); |
| Value *undef = rewriter.create<LLVM::UndefOp>(op->getLoc(), vectorType); |
| auto zero = rewriter.create<LLVM::ConstantOp>( |
| op->getLoc(), lowering.convertType(rewriter.getIntegerType(32)), |
| rewriter.getZeroAttr(rewriter.getIntegerType(32))); |
| |
| auto v = rewriter.create<LLVM::InsertElementOp>( |
| op->getLoc(), vectorType, undef, splatOp.getOperand(), zero); |
| |
| int64_t width = splatOp.getType().cast<VectorType>().getDimSize(0); |
| SmallVector<int32_t, 4> zeroValues(width, 0); |
| |
| // Shuffle the value across the desired number of elements. |
| ArrayAttr zeroAttrs = rewriter.getI32ArrayAttr(zeroValues); |
| rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(op, v, undef, zeroAttrs); |
| return matchSuccess(); |
| } |
| }; |
| |
| } // namespace |
| |
| static void ensureDistinctSuccessors(Block &bb) { |
| auto *terminator = bb.getTerminator(); |
| |
| // Find repeated successors with arguments. |
| llvm::SmallDenseMap<Block *, llvm::SmallVector<int, 4>> successorPositions; |
| for (int i = 0, e = terminator->getNumSuccessors(); i < e; ++i) { |
| Block *successor = terminator->getSuccessor(i); |
| // Blocks with no arguments are safe even if they appear multiple times |
| // because they don't need PHI nodes. |
| if (successor->getNumArguments() == 0) |
| continue; |
| successorPositions[successor].push_back(i); |
| } |
| |
| // If a successor appears for the second or more time in the terminator, |
| // create a new dummy block that unconditionally branches to the original |
| // destination, and retarget the terminator to branch to this new block. |
| // There is no need to pass arguments to the dummy block because it will be |
| // dominated by the original block and can therefore use any values defined in |
| // the original block. |
| for (const auto &successor : successorPositions) { |
| const auto &positions = successor.second; |
| // Start from the second occurrence of a block in the successor list. |
| for (auto position = std::next(positions.begin()), end = positions.end(); |
| position != end; ++position) { |
| auto *dummyBlock = new Block(); |
| bb.getParent()->push_back(dummyBlock); |
| auto builder = OpBuilder(dummyBlock); |
| SmallVector<Value *, 8> operands( |
| terminator->getSuccessorOperands(*position)); |
| builder.create<BranchOp>(terminator->getLoc(), successor.first, operands); |
| terminator->setSuccessor(dummyBlock, *position); |
| for (int i = 0, e = terminator->getNumSuccessorOperands(*position); i < e; |
| ++i) |
| terminator->eraseSuccessorOperand(*position, i); |
| } |
| } |
| } |
| |
| void mlir::LLVM::ensureDistinctSuccessors(ModuleOp m) { |
| for (auto f : m.getOps<FuncOp>()) { |
| for (auto &bb : f.getBlocks()) { |
| ::ensureDistinctSuccessors(bb); |
| } |
| } |
| } |
| |
| /// Collect a set of patterns to convert from the Standard dialect to LLVM. |
| void mlir::populateStdToLLVMConversionPatterns( |
| LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { |
| // FIXME: this should be tablegen'ed |
| patterns.insert< |
| AddFOpLowering, AddIOpLowering, AndOpLowering, AllocOpLowering, |
| BranchOpLowering, CallIndirectOpLowering, CallOpLowering, CmpIOpLowering, |
| CmpFOpLowering, CondBranchOpLowering, ConstLLVMOpLowering, |
| DeallocOpLowering, DimOpLowering, DivISOpLowering, DivIUOpLowering, |
| DivFOpLowering, FuncOpConversion, IndexCastOpLowering, LoadOpLowering, |
| MemRefCastOpLowering, MulFOpLowering, MulIOpLowering, OrOpLowering, |
| RemISOpLowering, RemIUOpLowering, RemFOpLowering, ReturnOpLowering, |
| SelectOpLowering, SIToFPLowering, SignExtendIOpLowering, SplatOpLowering, |
| StoreOpLowering, SubFOpLowering, SubIOpLowering, TruncateIOpLowering, |
| XOrOpLowering, ZeroExtendIOpLowering>(*converter.getDialect(), converter); |
| } |
| |
| // Convert types using the stored LLVM IR module. |
| Type LLVMTypeConverter::convertType(Type t) { return convertStandardType(t); } |
| |
| // Create an LLVM IR structure type if there is more than one result. |
| Type LLVMTypeConverter::packFunctionResults(ArrayRef<Type> types) { |
| assert(!types.empty() && "expected non-empty list of type"); |
| |
| if (types.size() == 1) |
| return convertType(types.front()); |
| |
| SmallVector<LLVM::LLVMType, 8> resultTypes; |
| resultTypes.reserve(types.size()); |
| for (auto t : types) { |
| auto converted = convertType(t).dyn_cast<LLVM::LLVMType>(); |
| if (!converted) |
| return {}; |
| resultTypes.push_back(converted); |
| } |
| |
| return LLVM::LLVMType::getStructTy(llvmDialect, resultTypes); |
| } |
| |
| Value *LLVMTypeConverter::promoteOneMemRefDescriptor(Location loc, |
| Value *operand, |
| OpBuilder &builder) { |
| auto *context = builder.getContext(); |
| auto int64Ty = LLVM::LLVMType::getInt64Ty(getDialect()); |
| auto indexType = IndexType::get(context); |
| // Alloca with proper alignment. We do not expect optimizations of this |
| // alloca op and so we omit allocating at the entry block. |
| auto ptrType = operand->getType().cast<LLVM::LLVMType>().getPointerTo(); |
| Value *one = builder.create<LLVM::ConstantOp>(loc, int64Ty, |
| IntegerAttr::get(indexType, 1)); |
| Value *allocated = |
| builder.create<LLVM::AllocaOp>(loc, ptrType, one, /*alignment=*/0); |
| // Store into the alloca'ed descriptor. |
| builder.create<LLVM::StoreOp>(loc, operand, allocated); |
| return allocated; |
| } |
| |
| SmallVector<Value *, 4> LLVMTypeConverter::promoteMemRefDescriptors( |
| Location loc, ArrayRef<Value *> opOperands, ArrayRef<Value *> operands, |
| OpBuilder &builder) { |
| SmallVector<Value *, 4> promotedOperands; |
| promotedOperands.reserve(operands.size()); |
| for (auto it : llvm::zip(opOperands, operands)) { |
| auto *operand = std::get<0>(it); |
| auto *llvmOperand = std::get<1>(it); |
| if (!operand->getType().isa<MemRefType>()) { |
| promotedOperands.push_back(operand); |
| continue; |
| } |
| promotedOperands.push_back( |
| promoteOneMemRefDescriptor(loc, llvmOperand, builder)); |
| } |
| return promotedOperands; |
| } |
| |
| /// Create an instance of LLVMTypeConverter in the given context. |
| static std::unique_ptr<LLVMTypeConverter> |
| makeStandardToLLVMTypeConverter(MLIRContext *context) { |
| return std::make_unique<LLVMTypeConverter>(context); |
| } |
| |
| namespace { |
| /// A pass converting MLIR operations into the LLVM IR dialect. |
| struct LLVMLoweringPass : public ModulePass<LLVMLoweringPass> { |
| // By default, the patterns are those converting Standard operations to the |
| // LLVMIR dialect. |
| explicit LLVMLoweringPass( |
| LLVMPatternListFiller patternListFiller = |
| populateStdToLLVMConversionPatterns, |
| LLVMTypeConverterMaker converterBuilder = makeStandardToLLVMTypeConverter) |
| : patternListFiller(patternListFiller), |
| typeConverterMaker(converterBuilder) {} |
| |
| // Run the dialect converter on the module. |
| void runOnModule() override { |
| if (!typeConverterMaker || !patternListFiller) |
| return signalPassFailure(); |
| |
| ModuleOp m = getModule(); |
| LLVM::ensureDistinctSuccessors(m); |
| std::unique_ptr<LLVMTypeConverter> typeConverter = |
| typeConverterMaker(&getContext()); |
| if (!typeConverter) |
| return signalPassFailure(); |
| |
| OwningRewritePatternList patterns; |
| populateLoopToStdConversionPatterns(patterns, m.getContext()); |
| patternListFiller(*typeConverter, patterns); |
| |
| ConversionTarget target(getContext()); |
| target.addLegalDialect<LLVM::LLVMDialect>(); |
| target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) { |
| return typeConverter->isSignatureLegal(op.getType()); |
| }); |
| if (failed(applyPartialConversion(m, target, patterns, &*typeConverter))) |
| signalPassFailure(); |
| } |
| |
| // Callback for creating a list of patterns. It is called every time in |
| // runOnModule since applyPartialConversion consumes the list. |
| LLVMPatternListFiller patternListFiller; |
| |
| // Callback for creating an instance of type converter. The converter |
| // constructor needs an MLIRContext, which is not available until runOnModule. |
| LLVMTypeConverterMaker typeConverterMaker; |
| }; |
| } // end namespace |
| |
| std::unique_ptr<OpPassBase<ModuleOp>> mlir::createLowerToLLVMPass() { |
| return std::make_unique<LLVMLoweringPass>(); |
| } |
| |
| std::unique_ptr<OpPassBase<ModuleOp>> |
| mlir::createLowerToLLVMPass(LLVMPatternListFiller patternListFiller, |
| LLVMTypeConverterMaker typeConverterMaker) { |
| return std::make_unique<LLVMLoweringPass>(patternListFiller, |
| typeConverterMaker); |
| } |
| |
| static PassRegistration<LLVMLoweringPass> |
| pass("lower-to-llvm", "Convert all functions to the LLVM IR dialect"); |