//===- ConvertStandardToLLVM.cpp - Standard to LLVM dialect conversion ---===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements a pass to convert MLIR standard and builtin dialects
// into the LLVM IR dialect.
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/LLVMIR/LLVMDialect.h"
#include "mlir/Pass/Pass.h"
#include "mlir/StandardOps/Ops.h"
#include "mlir/Support/Functional.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/Utils.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Type.h"
using namespace mlir;
LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx)
: llvmDialect(ctx->getRegisteredDialect<LLVM::LLVMDialect>()) {
assert(llvmDialect && "LLVM IR dialect is not registered");
module = &llvmDialect->getLLVMModule();
}
// Get the LLVM context.
llvm::LLVMContext &LLVMTypeConverter::getLLVMContext() {
return module->getContext();
}
// Extract an LLVM IR type from the LLVM IR dialect type.
LLVM::LLVMType LLVMTypeConverter::unwrap(Type type) {
if (!type)
return nullptr;
auto *mlirContext = type.getContext();
auto wrappedLLVMType = type.dyn_cast<LLVM::LLVMType>();
if (!wrappedLLVMType)
emitError(UnknownLoc::get(mlirContext),
"conversion resulted in a non-LLVM type");
return wrappedLLVMType;
}
LLVM::LLVMType LLVMTypeConverter::getIndexType() {
return LLVM::LLVMType::getIntNTy(
llvmDialect, module->getDataLayout().getPointerSizeInBits());
}
Type LLVMTypeConverter::convertIndexType(IndexType type) {
return getIndexType();
}
Type LLVMTypeConverter::convertIntegerType(IntegerType type) {
return LLVM::LLVMType::getIntNTy(llvmDialect, type.getWidth());
}
Type LLVMTypeConverter::convertFloatType(FloatType type) {
switch (type.getKind()) {
case mlir::StandardTypes::F32:
return LLVM::LLVMType::getFloatTy(llvmDialect);
case mlir::StandardTypes::F64:
return LLVM::LLVMType::getDoubleTy(llvmDialect);
case mlir::StandardTypes::F16:
return LLVM::LLVMType::getHalfTy(llvmDialect);
case mlir::StandardTypes::BF16: {
auto *mlirContext = llvmDialect->getContext();
return emitError(UnknownLoc::get(mlirContext), "unsupported type: BF16"),
Type();
}
default:
llvm_unreachable("non-float type in convertFloatType");
}
}
// Function types are converted to LLVM Function types by recursively converting
// argument and result types. If MLIR Function has zero results, the LLVM
// Function has one VoidType result. If MLIR Function has more than one result,
// they are packed into an LLVM StructType in their order of appearance.
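// For example (a sketch of the expected conversion, assuming i1 and f32 lower
// to the LLVM i1 and float types):
//   (i1, f32) -> (i1, i1)  converts to  !llvm<"{ i1, i1 } (i1, float)*">
//   (f32) -> ()            converts to  !llvm<"void (float)*">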
Type LLVMTypeConverter::convertFunctionType(FunctionType type) {
// Convert argument types one by one and check for errors.
SmallVector<LLVM::LLVMType, 8> argTypes;
for (auto t : type.getInputs()) {
auto converted = convertType(t);
if (!converted)
return {};
argTypes.push_back(unwrap(converted));
}
// If the function does not return anything, create the void result type;
// if it returns one element, convert it; otherwise pack the result types into
// a struct.
LLVM::LLVMType resultType =
type.getNumResults() == 0
? LLVM::LLVMType::getVoidTy(llvmDialect)
: unwrap(packFunctionResults(type.getResults()));
if (!resultType)
return {};
return LLVM::LLVMType::getFunctionTy(resultType, argTypes, /*isVarArg=*/false)
.getPointerTo();
}
// Convert a MemRef to an LLVM type. If the memref is statically-shaped, then
// we return a pointer to the converted element type. Otherwise we return an
// LLVM structure type, where the first element of the structure type is a
// pointer to the element type of the MemRef and the following N elements are
// values of the Index type, one for each of the N dynamic dimensions of the
// MemRef.
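// For example (a sketch, assuming `index` lowers to i64):
//   memref<42x?x10x?xf32>  converts to  !llvm<"{ float*, i64, i64 }">
//   memref<42x10xf32>      converts to  !llvm<"float*">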
Type LLVMTypeConverter::convertMemRefType(MemRefType type) {
LLVM::LLVMType elementType = unwrap(convertType(type.getElementType()));
if (!elementType)
return {};
auto ptrType = elementType.getPointerTo();
unsigned numDynamicSizes = type.getNumDynamicDims();
// If memref is statically-shaped we return the underlying pointer type.
if (numDynamicSizes == 0)
return ptrType;
// Otherwise, build the descriptor: the extra slot at the front holds the data
// pointer, followed by one index-typed value per dynamic dimension.
SmallVector<LLVM::LLVMType, 8> types(numDynamicSizes + 1, getIndexType());
types.front() = ptrType;
return LLVM::LLVMType::getStructTy(llvmDialect, types);
}
// Convert an n-D vector type to an LLVM vector type via (n-1)-D array type when
// n > 1.
// For example, `vector<4 x f32>` converts to `!llvm<"<4 x float>">` and
// `vector<4 x 8 x 16 x f32>` converts to `!llvm<"[4 x [8 x <16 x float>]]">`.
Type LLVMTypeConverter::convertVectorType(VectorType type) {
auto elementType = unwrap(convertType(type.getElementType()));
if (!elementType)
return {};
auto vectorType =
LLVM::LLVMType::getVectorTy(elementType, type.getShape().back());
auto shape = type.getShape();
for (int i = shape.size() - 2; i >= 0; --i)
vectorType = LLVM::LLVMType::getArrayTy(vectorType, shape[i]);
return vectorType;
}
// Dispatch based on the actual type. Return null type on error.
Type LLVMTypeConverter::convertStandardType(Type type) {
if (auto funcType = type.dyn_cast<FunctionType>())
return convertFunctionType(funcType);
if (auto intType = type.dyn_cast<IntegerType>())
return convertIntegerType(intType);
if (auto floatType = type.dyn_cast<FloatType>())
return convertFloatType(floatType);
if (auto indexType = type.dyn_cast<IndexType>())
return convertIndexType(indexType);
if (auto memRefType = type.dyn_cast<MemRefType>())
return convertMemRefType(memRefType);
if (auto vectorType = type.dyn_cast<VectorType>())
return convertVectorType(vectorType);
if (auto llvmType = type.dyn_cast<LLVM::LLVMType>())
return llvmType;
return {};
}
// Convert the element type of the memref `t` to an LLVM type using `lowering`,
// get a pointer LLVM type pointing to the converted `t`, wrap it into the MLIR
// LLVM dialect type and return.
static Type getMemRefElementPtrType(MemRefType t, LLVMTypeConverter &lowering) {
auto elementType = t.getElementType();
auto converted = lowering.convertType(elementType);
if (!converted)
return {};
return converted.cast<LLVM::LLVMType>().getPointerTo();
}
LLVMOpLowering::LLVMOpLowering(StringRef rootOpName, MLIRContext *context,
LLVMTypeConverter &lowering_)
: ConversionPattern(rootOpName, /*benefit=*/1, context),
lowering(lowering_) {}
namespace {
// Base class for Standard to LLVM IR op conversions. Matches the Op type
// provided as template argument. Carries a reference to the LLVM dialect in
// case it is necessary for rewriters.
template <typename SourceOp>
class LLVMLegalizationPattern : public LLVMOpLowering {
public:
// Construct a conversion pattern.
explicit LLVMLegalizationPattern(LLVM::LLVMDialect &dialect_,
LLVMTypeConverter &lowering_)
: LLVMOpLowering(SourceOp::getOperationName(), dialect_.getContext(),
lowering_),
dialect(dialect_) {}
// Get the LLVM IR dialect.
LLVM::LLVMDialect &getDialect() const { return dialect; }
// Get the LLVM context.
llvm::LLVMContext &getContext() const { return dialect.getLLVMContext(); }
// Get the LLVM module in which the types are constructed.
llvm::Module &getModule() const { return dialect.getLLVMModule(); }
// Get the MLIR type wrapping the LLVM integer type whose bit width is defined
// by the pointer size used in the LLVM module.
LLVM::LLVMType getIndexType() const {
return LLVM::LLVMType::getIntNTy(
&dialect, getModule().getDataLayout().getPointerSizeInBits());
}
// Get the MLIR type wrapping the LLVM i8* type.
LLVM::LLVMType getVoidPtrType() const {
return LLVM::LLVMType::getInt8PtrTy(&dialect);
}
// Create an LLVM IR pseudo-operation defining the given index constant.
Value *createIndexConstant(ConversionPatternRewriter &builder, Location loc,
uint64_t value) const {
auto attr = builder.getIntegerAttr(builder.getIndexType(), value);
return builder.create<LLVM::ConstantOp>(loc, getIndexType(), attr);
}
// Create an array attribute containing the given integers as index-typed
// integer attribute elements (used as the "position" attribute of
// extract/insert value operations).
static ArrayAttr getIntegerArrayAttr(ConversionPatternRewriter &builder,
ArrayRef<int64_t> values) {
SmallVector<Attribute, 4> attrs;
attrs.reserve(values.size());
for (int64_t pos : values)
attrs.push_back(builder.getIntegerAttr(builder.getIndexType(), pos));
return builder.getArrayAttr(attrs);
}
// Extract the raw data pointer value from a value representing a memref.
static Value *extractMemRefElementPtr(ConversionPatternRewriter &builder,
Location loc,
Value *convertedMemRefValue,
Type elementTypePtr,
bool hasStaticShape) {
if (hasStaticShape)
return convertedMemRefValue;
return builder.create<LLVM::ExtractValueOp>(
loc, elementTypePtr, convertedMemRefValue,
getIntegerArrayAttr(builder, 0));
}
protected:
LLVM::LLVMDialect &dialect;
};
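// Lowers FuncOp to a function with an LLVM-compatible signature: argument
// types are converted one by one, multiple results are packed into a single
// structure type, and the body region is moved into the new function with its
// entry block signature remapped accordingly.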
struct FuncOpConversion : public LLVMLegalizationPattern<FuncOp> {
using LLVMLegalizationPattern<FuncOp>::LLVMLegalizationPattern;
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
auto funcOp = cast<FuncOp>(op);
FunctionType type = funcOp.getType();
// Convert the original function arguments.
TypeConverter::SignatureConversion result(type.getNumInputs());
for (unsigned i = 0, e = type.getNumInputs(); i != e; ++i)
if (failed(lowering.convertSignatureArg(i, type.getInput(i), result)))
return matchFailure();
// Pack the result types into a struct.
Type packedResult;
if (type.getNumResults() != 0) {
if (!(packedResult = lowering.packFunctionResults(type.getResults())))
return matchFailure();
}
// Create a new function with an updated signature.
auto newFuncOp = rewriter.cloneWithoutRegions(funcOp);
rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
newFuncOp.end());
newFuncOp.setType(FunctionType::get(
result.getConvertedTypes(),
packedResult ? ArrayRef<Type>(packedResult) : llvm::None,
funcOp.getContext()));
// Tell the rewriter to convert the region signature.
rewriter.applySignatureConversion(&newFuncOp.getBody(), result);
rewriter.replaceOp(op, llvm::None);
return matchSuccess();
}
};
// Basic lowering implementation for one-to-one rewriting from Standard Ops to
// LLVM Dialect Ops.
template <typename SourceOp, typename TargetOp>
struct OneToOneLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
using LLVMLegalizationPattern<SourceOp>::LLVMLegalizationPattern;
using Super = OneToOneLLVMOpLowering<SourceOp, TargetOp>;
// Convert the type of the result to an LLVM type, pass operands as is,
// preserve attributes.
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
unsigned numResults = op->getNumResults();
Type packedType;
if (numResults != 0) {
packedType = this->lowering.packFunctionResults(
llvm::to_vector<4>(op->getResultTypes()));
assert(packedType && "type conversion failed, such operation should not "
"have been matched");
}
auto newOp = rewriter.create<TargetOp>(op->getLoc(), packedType, operands,
op->getAttrs());
// If the operation produced 0 or 1 result, replace it immediately.
if (numResults == 0)
return rewriter.replaceOp(op, llvm::None), this->matchSuccess();
if (numResults == 1)
return rewriter.replaceOp(op, newOp.getOperation()->getResult(0)),
this->matchSuccess();
// Otherwise, the converted operation produces a structure. Extract the
// individual results from the structure and use them as replacements.
SmallVector<Value *, 4> results;
results.reserve(numResults);
for (unsigned i = 0; i < numResults; ++i) {
auto type = this->lowering.convertType(op->getResult(i)->getType());
results.push_back(rewriter.create<LLVM::ExtractValueOp>(
op->getLoc(), type, newOp.getOperation()->getResult(0),
this->getIntegerArrayAttr(rewriter, i)));
}
rewriter.replaceOp(op, results);
return this->matchSuccess();
}
};
// Specific lowerings.
// FIXME: this should be tablegen'ed.
struct AddIOpLowering : public OneToOneLLVMOpLowering<AddIOp, LLVM::AddOp> {
using Super::Super;
};
struct SubIOpLowering : public OneToOneLLVMOpLowering<SubIOp, LLVM::SubOp> {
using Super::Super;
};
struct MulIOpLowering : public OneToOneLLVMOpLowering<MulIOp, LLVM::MulOp> {
using Super::Super;
};
struct DivISOpLowering : public OneToOneLLVMOpLowering<DivISOp, LLVM::SDivOp> {
using Super::Super;
};
struct DivIUOpLowering : public OneToOneLLVMOpLowering<DivIUOp, LLVM::UDivOp> {
using Super::Super;
};
struct RemISOpLowering : public OneToOneLLVMOpLowering<RemISOp, LLVM::SRemOp> {
using Super::Super;
};
struct RemIUOpLowering : public OneToOneLLVMOpLowering<RemIUOp, LLVM::URemOp> {
using Super::Super;
};
struct AndOpLowering : public OneToOneLLVMOpLowering<AndOp, LLVM::AndOp> {
using Super::Super;
};
struct OrOpLowering : public OneToOneLLVMOpLowering<OrOp, LLVM::OrOp> {
using Super::Super;
};
struct XOrOpLowering : public OneToOneLLVMOpLowering<XOrOp, LLVM::XOrOp> {
using Super::Super;
};
struct AddFOpLowering : public OneToOneLLVMOpLowering<AddFOp, LLVM::FAddOp> {
using Super::Super;
};
struct SubFOpLowering : public OneToOneLLVMOpLowering<SubFOp, LLVM::FSubOp> {
using Super::Super;
};
struct MulFOpLowering : public OneToOneLLVMOpLowering<MulFOp, LLVM::FMulOp> {
using Super::Super;
};
struct DivFOpLowering : public OneToOneLLVMOpLowering<DivFOp, LLVM::FDivOp> {
using Super::Super;
};
struct RemFOpLowering : public OneToOneLLVMOpLowering<RemFOp, LLVM::FRemOp> {
using Super::Super;
};
struct SelectOpLowering
: public OneToOneLLVMOpLowering<SelectOp, LLVM::SelectOp> {
using Super::Super;
};
struct CallOpLowering : public OneToOneLLVMOpLowering<CallOp, LLVM::CallOp> {
using Super::Super;
};
struct CallIndirectOpLowering
: public OneToOneLLVMOpLowering<CallIndirectOp, LLVM::CallOp> {
using Super::Super;
};
struct ConstLLVMOpLowering
: public OneToOneLLVMOpLowering<ConstantOp, LLVM::ConstantOp> {
using Super::Super;
};
// Check if the MemRefType `type` is supported by the lowering. We currently do
// not support memrefs with affine maps or non-default memory spaces.
static bool isSupportedMemRefType(MemRefType type) {
if (!type.getAffineMaps().empty())
return false;
if (type.getMemorySpace() != 0)
return false;
return true;
}
// An `alloc` is converted into a definition of a memref descriptor value and
// a call to `malloc` to allocate the underlying data buffer. The memref
// descriptor is of the LLVM structure type where the first element is a pointer
// to the (typed) data buffer, and the remaining elements serve to store
// dynamic sizes of the memref using LLVM-converted `index` type.
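// For example (a sketch, assuming `index` lowers to i64 and f32 to float),
// `%0 = alloc(%s) : memref<?x42xf32>` becomes roughly:
//   %n     = llvm.mul %s, %c42                 // number of elements
//   %bytes = llvm.mul %n, %c4                  // f32 occupies 4 bytes
//   %raw   = llvm.call @malloc(%bytes) : (!llvm<"i64">) -> !llvm<"i8*">
//   %data  = llvm.bitcast %raw : !llvm<"i8*"> to !llvm<"float*">
// followed by llvm.undef/llvm.insertvalue operations that build the
// !llvm<"{ float*, i64 }"> descriptor out of %data and %s.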
struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
using LLVMLegalizationPattern<AllocOp>::LLVMLegalizationPattern;
PatternMatchResult match(Operation *op) const override {
MemRefType type = cast<AllocOp>(op).getType();
return isSupportedMemRefType(type) ? matchSuccess() : matchFailure();
}
void rewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
auto allocOp = cast<AllocOp>(op);
MemRefType type = allocOp.getType();
// Get actual sizes of the memref as values: static sizes are constant
// values and dynamic sizes are passed to 'alloc' as operands. In case of
// zero-dimensional memref, assume a scalar (size 1).
SmallVector<Value *, 4> sizes;
auto numOperands = allocOp.getNumOperands();
sizes.reserve(numOperands);
unsigned i = 0;
for (int64_t s : type.getShape())
sizes.push_back(s == -1 ? operands[i++]
: createIndexConstant(rewriter, op->getLoc(), s));
if (sizes.empty())
sizes.push_back(createIndexConstant(rewriter, op->getLoc(), 1));
// Compute the total number of memref elements.
Value *cumulativeSize = sizes.front();
for (unsigned i = 1, e = sizes.size(); i < e; ++i)
cumulativeSize = rewriter.create<LLVM::MulOp>(
op->getLoc(), getIndexType(),
ArrayRef<Value *>{cumulativeSize, sizes[i]});
// Compute the total number of bytes to allocate.
auto elementType = type.getElementType();
assert((elementType.isIntOrFloat() || elementType.isa<VectorType>()) &&
"invalid memref element type");
uint64_t elementSize = 0;
if (auto vectorType = elementType.dyn_cast<VectorType>())
elementSize = vectorType.getNumElements() *
llvm::divideCeil(vectorType.getElementTypeBitWidth(), 8);
else
elementSize = llvm::divideCeil(elementType.getIntOrFloatBitWidth(), 8);
cumulativeSize = rewriter.create<LLVM::MulOp>(
op->getLoc(), getIndexType(),
ArrayRef<Value *>{
cumulativeSize,
createIndexConstant(rewriter, op->getLoc(), elementSize)});
// Insert the `malloc` declaration if it is not already present.
auto module = op->getParentOfType<ModuleOp>();
FuncOp mallocFunc = module.lookupSymbol<FuncOp>("malloc");
if (!mallocFunc) {
auto mallocType =
rewriter.getFunctionType(getIndexType(), getVoidPtrType());
mallocFunc =
FuncOp::create(rewriter.getUnknownLoc(), "malloc", mallocType);
module.push_back(mallocFunc);
}
// Allocate the underlying buffer and store a pointer to it in the MemRef
// descriptor.
Value *allocated =
rewriter
.create<LLVM::CallOp>(op->getLoc(), getVoidPtrType(),
rewriter.getSymbolRefAttr(mallocFunc),
cumulativeSize)
.getResult(0);
auto structElementType = lowering.convertType(elementType);
auto elementPtrType =
structElementType.cast<LLVM::LLVMType>().getPointerTo();
allocated = rewriter.create<LLVM::BitcastOp>(op->getLoc(), elementPtrType,
ArrayRef<Value *>(allocated));
// Deal with statically-shaped memrefs: the typed pointer is the converted
// value, no descriptor is needed.
if (numOperands == 0)
return rewriter.replaceOp(op, allocated);
// Create the MemRef descriptor.
auto structType = lowering.convertType(type);
Value *memRefDescriptor = rewriter.create<LLVM::UndefOp>(
op->getLoc(), structType, ArrayRef<Value *>{});
memRefDescriptor = rewriter.create<LLVM::InsertValueOp>(
op->getLoc(), structType, memRefDescriptor, allocated,
getIntegerArrayAttr(rewriter, 0));
// Store dynamically allocated sizes in the descriptor. Dynamic sizes are
// passed in as operands.
for (auto indexedSize : llvm::enumerate(operands)) {
memRefDescriptor = rewriter.create<LLVM::InsertValueOp>(
op->getLoc(), structType, memRefDescriptor, indexedSize.value(),
getIntegerArrayAttr(rewriter, 1 + indexedSize.index()));
}
// Return the final value of the descriptor.
rewriter.replaceOp(op, memRefDescriptor);
}
};
// A `dealloc` is converted into a call to `free` on the underlying data buffer.
// Since the memref descriptor is an SSA value, there is no need to clean it up
// in any way.
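// For example (a sketch), `dealloc %m : memref<?xf32>` becomes roughly:
//   %ptr = llvm.extractvalue %m[0] : !llvm<"{ float*, i64 }">
//   %p8  = llvm.bitcast %ptr : !llvm<"float*"> to !llvm<"i8*">
//   llvm.call @free(%p8) : (!llvm<"i8*">) -> ()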
struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
using LLVMLegalizationPattern<DeallocOp>::LLVMLegalizationPattern;
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
assert(operands.size() == 1 && "dealloc takes one operand");
OperandAdaptor<DeallocOp> transformed(operands);
// Insert the `free` declaration if it is not already present.
FuncOp freeFunc =
op->getParentOfType<ModuleOp>().lookupSymbol<FuncOp>("free");
if (!freeFunc) {
auto freeType = rewriter.getFunctionType(getVoidPtrType(), {});
freeFunc = FuncOp::create(rewriter.getUnknownLoc(), "free", freeType);
op->getParentOfType<ModuleOp>().push_back(freeFunc);
}
auto type = transformed.memref()->getType().cast<LLVM::LLVMType>();
auto hasStaticShape = type.getUnderlyingType()->isPointerTy();
Type elementPtrType = hasStaticShape ? type : type.getStructElementType(0);
Value *bufferPtr =
extractMemRefElementPtr(rewriter, op->getLoc(), transformed.memref(),
elementPtrType, hasStaticShape);
Value *casted = rewriter.create<LLVM::BitcastOp>(
op->getLoc(), getVoidPtrType(), bufferPtr);
rewriter.replaceOpWithNewOp<LLVM::CallOp>(
op, ArrayRef<Type>(), rewriter.getSymbolRefAttr(freeFunc), casted);
return matchSuccess();
}
};
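// A `memref_cast` between supported memref types is lowered by reusing the
// data buffer pointer of the source and rebuilding the descriptor so that it
// matches the static/dynamic shape of the target type.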
struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> {
using LLVMLegalizationPattern<MemRefCastOp>::LLVMLegalizationPattern;
PatternMatchResult match(Operation *op) const override {
auto memRefCastOp = cast<MemRefCastOp>(op);
MemRefType sourceType =
memRefCastOp.getOperand()->getType().cast<MemRefType>();
MemRefType targetType = memRefCastOp.getType();
return (isSupportedMemRefType(targetType) &&
isSupportedMemRefType(sourceType))
? matchSuccess()
: matchFailure();
}
void rewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
auto memRefCastOp = cast<MemRefCastOp>(op);
OperandAdaptor<MemRefCastOp> transformed(operands);
auto targetType = memRefCastOp.getType();
auto sourceType = memRefCastOp.getOperand()->getType().cast<MemRefType>();
// Copy the data buffer pointer.
auto elementTypePtr = getMemRefElementPtrType(targetType, lowering);
Value *buffer =
extractMemRefElementPtr(rewriter, op->getLoc(), transformed.source(),
elementTypePtr, sourceType.hasStaticShape());
// If the target memref is statically-shaped, the raw data pointer is all that
// is needed.
if (targetType.hasStaticShape())
return rewriter.replaceOp(op, buffer);
// Otherwise, the target type is a dynamically-shaped memref and we need to
// create a proper descriptor.
auto structType = lowering.convertType(targetType);
Value *newDescriptor = rewriter.create<LLVM::UndefOp>(
op->getLoc(), structType, ArrayRef<Value *>{});
newDescriptor = rewriter.create<LLVM::InsertValueOp>(
op->getLoc(), structType, newDescriptor, buffer,
getIntegerArrayAttr(rewriter, 0));
// Fill in the dynamic sizes of the new descriptor. If the size was
// dynamic, copy it from the old descriptor. If the size was static, insert
// the constant. Note that the positions of dynamic sizes in the
// descriptors start from 1 (the buffer pointer is at position zero).
int64_t sourceDynamicDimIdx = 1;
int64_t targetDynamicDimIdx = 1;
for (int i = 0, e = sourceType.getRank(); i < e; ++i) {
// Ignore new static sizes (they will be known from the type). If the
// size was dynamic, update the index of dynamic types.
if (targetType.getShape()[i] != -1) {
if (sourceType.getShape()[i] == -1)
++sourceDynamicDimIdx;
continue;
}
auto sourceSize = sourceType.getShape()[i];
Value *size =
sourceSize == -1
? rewriter.create<LLVM::ExtractValueOp>(
op->getLoc(), getIndexType(),
transformed.source(), // NB: dynamic memref
getIntegerArrayAttr(rewriter, sourceDynamicDimIdx++))
: createIndexConstant(rewriter, op->getLoc(), sourceSize);
newDescriptor = rewriter.create<LLVM::InsertValueOp>(
op->getLoc(), structType, newDescriptor, size,
getIntegerArrayAttr(rewriter, targetDynamicDimIdx++));
}
assert(sourceDynamicDimIdx - 1 == sourceType.getNumDynamicDims() &&
"source dynamic dimensions were not processed");
assert(targetDynamicDimIdx - 1 == targetType.getNumDynamicDims() &&
"target dynamic dimensions were not set up");
rewriter.replaceOp(op, newDescriptor);
}
};
// A `dim` is converted to a constant for static sizes and to an access to the
// size stored in the memref descriptor for dynamic sizes.
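// For example (a sketch), with `%m : memref<4x?xf32>`:
//   `dim %m, 0` becomes a constant 4 of the lowered `index` type, while
//   `dim %m, 1` becomes `llvm.extractvalue %m[1]`, the first dynamic-size slot
//   of the descriptor (position 0 holds the data pointer).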
struct DimOpLowering : public LLVMLegalizationPattern<DimOp> {
using LLVMLegalizationPattern<DimOp>::LLVMLegalizationPattern;
PatternMatchResult match(Operation *op) const override {
auto dimOp = cast<DimOp>(op);
MemRefType type = dimOp.getOperand()->getType().cast<MemRefType>();
return isSupportedMemRefType(type) ? matchSuccess() : matchFailure();
}
void rewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
auto dimOp = cast<DimOp>(op);
OperandAdaptor<DimOp> transformed(operands);
MemRefType type = dimOp.getOperand()->getType().cast<MemRefType>();
auto shape = type.getShape();
uint64_t index = dimOp.getIndex();
// Extract dynamic size from the memref descriptor and define static size
// as a constant.
if (shape[index] == -1) {
// Find the position of the dynamic dimension in the list of dynamic sizes
// by counting the number of preceding dynamic dimensions. Start from 1
// because the buffer pointer is at position zero.
int64_t position = 1;
for (uint64_t i = 0; i < index; ++i) {
if (shape[i] == -1)
++position;
}
rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>(
op, getIndexType(), transformed.memrefOrTensor(),
getIntegerArrayAttr(rewriter, position));
} else {
rewriter.replaceOp(
op, createIndexConstant(rewriter, op->getLoc(), shape[index]));
}
}
};
// Common base for load and store operations on MemRefs. Restricts the match
// to supported MemRef types. Provides functionality to emit code accessing a
// specific element of the underlying data buffer.
template <typename Derived>
struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> {
using LLVMLegalizationPattern<Derived>::LLVMLegalizationPattern;
using Base = LoadStoreOpLowering<Derived>;
PatternMatchResult match(Operation *op) const override {
MemRefType type = cast<Derived>(op).getMemRefType();
return isSupportedMemRefType(type) ? this->matchSuccess()
: this->matchFailure();
}
// Given subscript indices and array sizes in row-major order,
// i_n, i_{n-1}, ..., i_1
// s_n, s_{n-1}, ..., s_1
// obtain a value that corresponds to the linearized subscript
// \sum_k i_k * \prod_{j=1}^{k-1} s_j
// by accumulating the running linearized value.
// Note that `indices` and `allocSizes` are passed in the same order as they
// appear in load/store operations and memref type declarations.
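// For example, with sizes (s_3, s_2, s_1) = (A, B, C) and indices
// (i_3, i_2, i_1) = (i, j, k), the accumulation below produces
// (i * B + j) * C + k; the outermost size A is never needed.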
Value *linearizeSubscripts(ConversionPatternRewriter &builder, Location loc,
ArrayRef<Value *> indices,
ArrayRef<Value *> allocSizes) const {
assert(indices.size() == allocSizes.size() &&
"mismatching number of indices and allocation sizes");
assert(!indices.empty() && "cannot linearize a 0-dimensional access");
Value *linearized = indices.front();
for (int i = 1, nSizes = allocSizes.size(); i < nSizes; ++i) {
linearized = builder.create<LLVM::MulOp>(
loc, this->getIndexType(),
ArrayRef<Value *>{linearized, allocSizes[i]});
linearized = builder.create<LLVM::AddOp>(
loc, this->getIndexType(), ArrayRef<Value *>{linearized, indices[i]});
}
return linearized;
}
// Given the MemRef type, a descriptor and a list of indices, extract the data
// buffer pointer from the descriptor, convert multi-dimensional subscripts
// into a linearized index (using dynamic size data from the descriptor if
// necessary) and get the pointer to the buffer element identified by the
// indices.
Value *getElementPtr(Location loc, Type elementTypePtr,
ArrayRef<int64_t> shape, Value *memRefDescriptor,
ArrayRef<Value *> indices,
ConversionPatternRewriter &rewriter) const {
// Get the list of MemRef sizes. Static sizes are defined as constants.
// Dynamic sizes are extracted from the MemRef descriptor, where they start
// from the position 1 (the buffer is at position 0).
SmallVector<Value *, 4> sizes;
unsigned dynamicSizeIdx = 1;
for (int64_t s : shape) {
if (s == -1) {
Value *size = rewriter.create<LLVM::ExtractValueOp>(
loc, this->getIndexType(), memRefDescriptor,
this->getIntegerArrayAttr(rewriter, dynamicSizeIdx++));
sizes.push_back(size);
} else {
sizes.push_back(this->createIndexConstant(rewriter, loc, s));
}
}
// The second and subsequent operands are access subscripts. Obtain the
// linearized address in the buffer.
Value *subscript = linearizeSubscripts(rewriter, loc, indices, sizes);
Value *dataPtr = rewriter.create<LLVM::ExtractValueOp>(
loc, elementTypePtr, memRefDescriptor,
this->getIntegerArrayAttr(rewriter, 0));
return rewriter.create<LLVM::GEPOp>(loc, elementTypePtr,
ArrayRef<Value *>{dataPtr, subscript},
ArrayRef<NamedAttribute>{});
}
// This is a getElementPtr variant, where the value is a direct raw pointer.
// If a shape is empty, we are dealing with a zero-dimensional memref. Return
// the pointer unmodified in this case. Otherwise, linearize subscripts to
// obtain the offset with respect to the base pointer. Use this offset to
// compute and return the element pointer.
Value *getRawElementPtr(Location loc, Type elementTypePtr,
ArrayRef<int64_t> shape, Value *rawDataPtr,
ArrayRef<Value *> indices,
ConversionPatternRewriter &rewriter) const {
if (shape.empty())
return rawDataPtr;
SmallVector<Value *, 4> sizes;
for (int64_t s : shape) {
sizes.push_back(this->createIndexConstant(rewriter, loc, s));
}
Value *subscript = linearizeSubscripts(rewriter, loc, indices, sizes);
return rewriter.create<LLVM::GEPOp>(
loc, elementTypePtr, ArrayRef<Value *>{rawDataPtr, subscript},
ArrayRef<NamedAttribute>{});
}
Value *getDataPtr(Location loc, MemRefType type, Value *dataPtr,
ArrayRef<Value *> indices,
ConversionPatternRewriter &rewriter,
llvm::Module &module) const {
auto ptrType = getMemRefElementPtrType(type, this->lowering);
auto shape = type.getShape();
if (type.hasStaticShape()) {
// NB: If memref was statically-shaped, dataPtr is pointer to raw data.
return getRawElementPtr(loc, ptrType, shape, dataPtr, indices, rewriter);
}
return getElementPtr(loc, ptrType, shape, dataPtr, indices, rewriter);
}
};
// Load operation is lowered to obtaining a pointer to the indexed element
// and loading it.
struct LoadOpLowering : public LoadStoreOpLowering<LoadOp> {
using Base::Base;
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
auto loadOp = cast<LoadOp>(op);
OperandAdaptor<LoadOp> transformed(operands);
auto type = loadOp.getMemRefType();
Value *dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(),
transformed.indices(), rewriter, getModule());
auto elementType = lowering.convertType(type.getElementType());
rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, elementType,
ArrayRef<Value *>{dataPtr});
return matchSuccess();
}
};
// The store operation is lowered to obtaining a pointer to the indexed element,
// and storing the given value to it.
struct StoreOpLowering : public LoadStoreOpLowering<StoreOp> {
using Base::Base;
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
auto type = cast<StoreOp>(op).getMemRefType();
OperandAdaptor<StoreOp> transformed(operands);
Value *dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(),
transformed.indices(), rewriter, getModule());
rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, transformed.value(),
dataPtr);
return matchSuccess();
}
};
// An `index_cast` is lowered to an integer conversion, since `index` itself is
// lowered to an integer type. If the bit width of the source and target
// integer types is the same, just erase the cast. If the target type is wider,
// sign-extend the value; otherwise truncate it.
struct IndexCastOpLowering : public LLVMLegalizationPattern<IndexCastOp> {
using LLVMLegalizationPattern<IndexCastOp>::LLVMLegalizationPattern;
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
IndexCastOpOperandAdaptor transformed(operands);
auto indexCastOp = cast<IndexCastOp>(op);
auto targetType =
this->lowering.convertType(indexCastOp.getResult()->getType())
.cast<LLVM::LLVMType>();
auto sourceType = transformed.in()->getType().cast<LLVM::LLVMType>();
unsigned targetBits = targetType.getUnderlyingType()->getIntegerBitWidth();
unsigned sourceBits = sourceType.getUnderlyingType()->getIntegerBitWidth();
if (targetBits == sourceBits)
rewriter.replaceOp(op, transformed.in());
else if (targetBits < sourceBits)
rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, targetType,
transformed.in());
else
rewriter.replaceOpWithNewOp<LLVM::SExtOp>(op, targetType,
transformed.in());
return matchSuccess();
}
};
// Convert a standard comparison predicate into the LLVM dialect CmpPredicate.
// The two enums share their numerical values, so a plain cast suffices.
template <typename LLVMPredType, typename StdPredType>
static LLVMPredType convertCmpPredicate(StdPredType pred) {
return static_cast<LLVMPredType>(pred);
}
struct CmpIOpLowering : public LLVMLegalizationPattern<CmpIOp> {
using LLVMLegalizationPattern<CmpIOp>::LLVMLegalizationPattern;
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
auto cmpiOp = cast<CmpIOp>(op);
CmpIOpOperandAdaptor transformed(operands);
rewriter.replaceOpWithNewOp<LLVM::ICmpOp>(
op, lowering.convertType(cmpiOp.getResult()->getType()),
rewriter.getI64IntegerAttr(static_cast<int64_t>(
convertCmpPredicate<LLVM::ICmpPredicate>(cmpiOp.getPredicate()))),
transformed.lhs(), transformed.rhs());
return matchSuccess();
}
};
struct CmpFOpLowering : public LLVMLegalizationPattern<CmpFOp> {
using LLVMLegalizationPattern<CmpFOp>::LLVMLegalizationPattern;
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
auto cmpfOp = cast<CmpFOp>(op);
CmpFOpOperandAdaptor transformed(operands);
rewriter.replaceOpWithNewOp<LLVM::FCmpOp>(
op, lowering.convertType(cmpfOp.getResult()->getType()),
rewriter.getI64IntegerAttr(static_cast<int64_t>(
convertCmpPredicate<LLVM::FCmpPredicate>(cmpfOp.getPredicate()))),
transformed.lhs(), transformed.rhs());
return matchSuccess();
}
};
struct SIToFPLowering
: public OneToOneLLVMOpLowering<SIToFPOp, LLVM::SIToFPOp> {
using Super::Super;
};
// Base class for one-to-one lowering of terminator operations with successors
// to the LLVM IR dialect.
template <typename SourceOp, typename TargetOp>
struct OneToOneLLVMTerminatorLowering
: public LLVMLegalizationPattern<SourceOp> {
using LLVMLegalizationPattern<SourceOp>::LLVMLegalizationPattern;
using Super = OneToOneLLVMTerminatorLowering<SourceOp, TargetOp>;
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> properOperands,
ArrayRef<Block *> destinations,
ArrayRef<ArrayRef<Value *>> operands,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<TargetOp>(op, properOperands, destinations,
operands, op->getAttrs());
return this->matchSuccess();
}
};
// Special lowering pattern for `ReturnOps`. Unlike all other operations,
// `ReturnOp` interacts with the function signature and must have as many
// operands as the function has return values. Because LLVM IR functions can
// only return 0 or 1 value, we pack multiple values into a structure type.
// Emit `UndefOp` followed by `InsertValueOp`s to create such a structure if
// necessary before returning it.
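// For example (a sketch, after operand conversion), `return %a, %b` with an
// f32 and an i64 result becomes roughly:
//   %0 = llvm.undef : !llvm<"{ float, i64 }">
//   %1 = llvm.insertvalue %a, %0[0] : !llvm<"{ float, i64 }">
//   %2 = llvm.insertvalue %b, %1[1] : !llvm<"{ float, i64 }">
//   llvm.return %2 : !llvm<"{ float, i64 }">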
struct ReturnOpLowering : public LLVMLegalizationPattern<ReturnOp> {
using LLVMLegalizationPattern<ReturnOp>::LLVMLegalizationPattern;
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
unsigned numArguments = op->getNumOperands();
// If ReturnOp has 0 or 1 operand, create it and return immediately.
if (numArguments == 0) {
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
op, llvm::ArrayRef<Value *>(), llvm::ArrayRef<Block *>(),
llvm::ArrayRef<llvm::ArrayRef<Value *>>(), op->getAttrs());
return matchSuccess();
}
if (numArguments == 1) {
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
op, llvm::ArrayRef<Value *>(operands.front()),
llvm::ArrayRef<Block *>(), llvm::ArrayRef<llvm::ArrayRef<Value *>>(),
op->getAttrs());
return matchSuccess();
}
// Otherwise, we need to pack the arguments into an LLVM struct type before
// returning.
auto packedType =
lowering.packFunctionResults(llvm::to_vector<4>(op->getOperandTypes()));
Value *packed = rewriter.create<LLVM::UndefOp>(op->getLoc(), packedType);
for (unsigned i = 0; i < numArguments; ++i) {
packed = rewriter.create<LLVM::InsertValueOp>(
op->getLoc(), packedType, packed, operands[i],
getIntegerArrayAttr(rewriter, i));
}
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
op, llvm::makeArrayRef(packed), llvm::ArrayRef<Block *>(),
llvm::ArrayRef<llvm::ArrayRef<Value *>>(), op->getAttrs());
return matchSuccess();
}
};
// FIXME: this should be tablegen'ed as well.
struct BranchOpLowering
: public OneToOneLLVMTerminatorLowering<BranchOp, LLVM::BrOp> {
using Super::Super;
};
struct CondBranchOpLowering
: public OneToOneLLVMTerminatorLowering<CondBranchOp, LLVM::CondBrOp> {
using Super::Super;
};
} // namespace
static void ensureDistinctSuccessors(Block &bb) {
auto *terminator = bb.getTerminator();
// Find repeated successors with arguments.
llvm::SmallDenseMap<Block *, llvm::SmallVector<int, 4>> successorPositions;
for (int i = 0, e = terminator->getNumSuccessors(); i < e; ++i) {
Block *successor = terminator->getSuccessor(i);
// Blocks with no arguments are safe even if they appear multiple times
// because they don't need PHI nodes.
if (successor->getNumArguments() == 0)
continue;
successorPositions[successor].push_back(i);
}
// If a successor appears for the second or more time in the terminator,
// create a new dummy block that unconditionally branches to the original
// destination, and retarget the terminator to branch to this new block.
// There is no need to pass arguments to the dummy block because it will be
// dominated by the original block and can therefore use any values defined in
// the original block.
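// For example (a sketch):
//   cond_br %c, ^bb1(%a : i32), ^bb1(%b : i32)
// becomes
//   cond_br %c, ^bb1(%a : i32), ^bb2
// ^bb2:
//   br ^bb1(%b : i32)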
for (const auto &successor : successorPositions) {
const auto &positions = successor.second;
// Start from the second occurrence of a block in the successor list.
for (auto position = std::next(positions.begin()), end = positions.end();
position != end; ++position) {
auto *dummyBlock = new Block();
bb.getParent()->push_back(dummyBlock);
auto builder = OpBuilder(dummyBlock);
SmallVector<Value *, 8> operands(
terminator->getSuccessorOperands(*position));
builder.create<BranchOp>(terminator->getLoc(), successor.first, operands);
terminator->setSuccessor(dummyBlock, *position);
// Drop the operands that were forwarded to the original successor; they are
// now passed by the dummy block instead. Erase from the back so that the
// remaining operand indices stay valid.
for (int i = terminator->getNumSuccessorOperands(*position) - 1; i >= 0; --i)
terminator->eraseSuccessorOperand(*position, i);
}
}
}
void mlir::LLVM::ensureDistinctSuccessors(ModuleOp m) {
for (auto f : m.getOps<FuncOp>()) {
for (auto &bb : f.getBlocks()) {
::ensureDistinctSuccessors(bb);
}
}
}
/// Collect a set of patterns to convert from the Standard dialect to LLVM.
void mlir::populateStdToLLVMConversionPatterns(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
// FIXME: this should be tablegen'ed
patterns.insert<
AddFOpLowering, AddIOpLowering, AndOpLowering, AllocOpLowering,
BranchOpLowering, CallIndirectOpLowering, CallOpLowering, CmpIOpLowering,
CmpFOpLowering, CondBranchOpLowering, ConstLLVMOpLowering,
DeallocOpLowering, DimOpLowering, DivISOpLowering, DivIUOpLowering,
DivFOpLowering, FuncOpConversion, IndexCastOpLowering, LoadOpLowering,
MemRefCastOpLowering, MulFOpLowering, MulIOpLowering, OrOpLowering,
RemISOpLowering, RemIUOpLowering, RemFOpLowering, ReturnOpLowering,
SelectOpLowering, SIToFPLowering, StoreOpLowering, SubFOpLowering,
SubIOpLowering, XOrOpLowering>(*converter.getDialect(), converter);
}
// Convert types using the stored LLVM IR module.
Type LLVMTypeConverter::convertType(Type t) { return convertStandardType(t); }
// Create an LLVM IR structure type if there is more than one result.
Type LLVMTypeConverter::packFunctionResults(ArrayRef<Type> types) {
assert(!types.empty() && "expected non-empty list of types");
if (types.size() == 1)
return convertType(types.front());
SmallVector<LLVM::LLVMType, 8> resultTypes;
resultTypes.reserve(types.size());
for (auto t : types) {
auto converted = convertType(t).dyn_cast<LLVM::LLVMType>();
if (!converted)
return {};
resultTypes.push_back(converted);
}
return LLVM::LLVMType::getStructTy(llvmDialect, resultTypes);
}
/// Create an instance of LLVMTypeConverter in the given context.
static std::unique_ptr<LLVMTypeConverter>
makeStandardToLLVMTypeConverter(MLIRContext *context) {
return llvm::make_unique<LLVMTypeConverter>(context);
}
namespace {
/// A pass converting MLIR operations into the LLVM IR dialect.
struct LLVMLoweringPass : public ModulePass<LLVMLoweringPass> {
// By default, the patterns are those converting Standard operations to the
// LLVM IR dialect.
explicit LLVMLoweringPass(
LLVMPatternListFiller patternListFiller =
populateStdToLLVMConversionPatterns,
LLVMTypeConverterMaker converterBuilder = makeStandardToLLVMTypeConverter)
: patternListFiller(patternListFiller),
typeConverterMaker(converterBuilder) {}
// Run the dialect converter on the module.
void runOnModule() override {
if (!typeConverterMaker || !patternListFiller)
return signalPassFailure();
ModuleOp m = getModule();
LLVM::ensureDistinctSuccessors(m);
std::unique_ptr<LLVMTypeConverter> typeConverter =
typeConverterMaker(&getContext());
if (!typeConverter)
return signalPassFailure();
OwningRewritePatternList patterns;
populateLoopToStdConversionPatterns(patterns, m.getContext());
patternListFiller(*typeConverter, patterns);
ConversionTarget target(getContext());
target.addLegalDialect<LLVM::LLVMDialect>();
target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
return typeConverter->isSignatureLegal(op.getType());
});
if (failed(applyPartialConversion(m, target, patterns, &*typeConverter)))
signalPassFailure();
}
// Callback for creating a list of patterns. It is called every time
// runOnModule is invoked since applyPartialConversion consumes the list.
LLVMPatternListFiller patternListFiller;
// Callback for creating an instance of type converter. The converter
// constructor needs an MLIRContext, which is not available until runOnModule.
LLVMTypeConverterMaker typeConverterMaker;
};
} // end namespace
ModulePassBase *mlir::createConvertToLLVMIRPass() {
return new LLVMLoweringPass;
}
ModulePassBase *
mlir::createConvertToLLVMIRPass(LLVMPatternListFiller patternListFiller,
LLVMTypeConverterMaker typeConverterMaker) {
return new LLVMLoweringPass(patternListFiller, typeConverterMaker);
}
static PassRegistration<LLVMLoweringPass>
pass("lower-to-llvm", "Convert all functions to the LLVM IR dialect");