//===- ConvertStandardToLLVM.cpp - Standard to LLVM dialect conversion-----===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements a pass to convert MLIR standard and builtin dialects
// into the LLVM IR dialect.
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/Functional.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/Utils.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Type.h"
using namespace mlir;
LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx)
: llvmDialect(ctx->getRegisteredDialect<LLVM::LLVMDialect>()) {
assert(llvmDialect && "LLVM IR dialect is not registered");
module = &llvmDialect->getLLVMModule();
}
// Get the LLVM context.
llvm::LLVMContext &LLVMTypeConverter::getLLVMContext() {
return module->getContext();
}
// Extract an LLVM IR type from the LLVM IR dialect type.
LLVM::LLVMType LLVMTypeConverter::unwrap(Type type) {
if (!type)
return nullptr;
auto *mlirContext = type.getContext();
auto wrappedLLVMType = type.dyn_cast<LLVM::LLVMType>();
if (!wrappedLLVMType)
emitError(UnknownLoc::get(mlirContext),
"conversion resulted in a non-LLVM type");
return wrappedLLVMType;
}
LLVM::LLVMType LLVMTypeConverter::getIndexType() {
return LLVM::LLVMType::getIntNTy(
llvmDialect, module->getDataLayout().getPointerSizeInBits());
}
Type LLVMTypeConverter::convertIndexType(IndexType type) {
return getIndexType();
}
Type LLVMTypeConverter::convertIntegerType(IntegerType type) {
return LLVM::LLVMType::getIntNTy(llvmDialect, type.getWidth());
}
Type LLVMTypeConverter::convertFloatType(FloatType type) {
switch (type.getKind()) {
case mlir::StandardTypes::F32:
return LLVM::LLVMType::getFloatTy(llvmDialect);
case mlir::StandardTypes::F64:
return LLVM::LLVMType::getDoubleTy(llvmDialect);
case mlir::StandardTypes::F16:
return LLVM::LLVMType::getHalfTy(llvmDialect);
case mlir::StandardTypes::BF16: {
auto *mlirContext = llvmDialect->getContext();
return emitError(UnknownLoc::get(mlirContext), "unsupported type: BF16"),
Type();
}
default:
llvm_unreachable("non-float type in convertFloatType");
}
}
// Except for signatures (handled separately by `convertFunctionSignature`),
// MLIR function types are converted into LLVM pointer-to-function types.
Type LLVMTypeConverter::convertFunctionType(FunctionType type) {
SignatureConversion conversion(type.getNumInputs());
LLVM::LLVMType converted =
convertFunctionSignature(type, /*isVariadic=*/false, conversion);
return converted.getPointerTo();
}
// Function types are converted to LLVM Function types by recursively
// converting argument and result types. If the MLIR function has zero
// results, the LLVM function has one VoidType result. If the MLIR function
// has more than one result, the results are packed into an LLVM StructType in
// their order of appearance.
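// For example, assuming `index` lowers to i64, the MLIR signature
// `(i32, f32) -> (i32, i64)` would convert along the lines of
// `!llvm<"{ i32, i64 } (i32, float)">` (illustrative type syntax).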
LLVM::LLVMType LLVMTypeConverter::convertFunctionSignature(
FunctionType type, bool isVariadic,
LLVMTypeConverter::SignatureConversion &result) {
// Convert argument types one by one and check for errors.
for (auto &en : llvm::enumerate(type.getInputs()))
if (failed(convertSignatureArg(en.index(), en.value(), result)))
return {};
SmallVector<LLVM::LLVMType, 8> argTypes;
argTypes.reserve(llvm::size(result.getConvertedTypes()));
for (Type type : result.getConvertedTypes())
argTypes.push_back(unwrap(type));
// If the function does not return anything, create the void result type;
// if it returns one element, convert it; otherwise pack the result types into
// a struct.
LLVM::LLVMType resultType =
type.getNumResults() == 0
? LLVM::LLVMType::getVoidTy(llvmDialect)
: unwrap(packFunctionResults(type.getResults()));
if (!resultType)
return {};
return LLVM::LLVMType::getFunctionTy(resultType, argTypes, isVariadic);
}
// Convert a MemRef to an LLVM type. The result is a MemRef descriptor which
// contains:
// 1. the pointer to the allocated data buffer, followed by
// 2. the aligned pointer to the first element to be accessed through the
//    view, followed by
// 3. a lowered `index`-type integer containing the distance between the
//    beginning of the buffer and the first element to be accessed through the
//    view, followed by
// 4. an array containing as many `index`-type integers as the rank of the
//    MemRef: the array represents the size, in number of elements, of the
//    memref along the given dimension. For constant MemRef dimensions, the
//    corresponding size entry is a constant whose runtime value must match
//    the static value, followed by
// 5. a second array containing as many `index`-type integers as the rank of
//    the MemRef: the second array represents the "stride" (in tensor
//    abstraction sense), i.e. the number of elements of the underlying buffer
//    to skip to advance by one along the given dimension.
// TODO(ntv, zinenko): add assertions for the static cases.
//
// template <typename Elem, size_t Rank>
// struct {
// Elem *allocatedPtr;
// Elem *alignedPtr;
// int64_t offset;
// int64_t sizes[Rank]; // omitted when rank == 0
// int64_t strides[Rank]; // omitted when rank == 0
// };
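// For example, under this scheme and assuming `index` lowers to i64,
// `memref<4 x ? x f32>` converts to a descriptor type along the lines of
// `!llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">` (illustrative
// type syntax).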
static constexpr unsigned kAllocatedPtrPosInMemRefDescriptor = 0;
static constexpr unsigned kAlignedPtrPosInMemRefDescriptor = 1;
static constexpr unsigned kOffsetPosInMemRefDescriptor = 2;
static constexpr unsigned kSizePosInMemRefDescriptor = 3;
static constexpr unsigned kStridePosInMemRefDescriptor = 4;
Type LLVMTypeConverter::convertMemRefType(MemRefType type) {
int64_t offset;
SmallVector<int64_t, 4> strides;
bool strideSuccess = succeeded(getStridesAndOffset(type, strides, offset));
assert(strideSuccess &&
"Non-strided layout maps must have been normalized away");
(void)strideSuccess;
LLVM::LLVMType elementType = unwrap(convertType(type.getElementType()));
if (!elementType)
return {};
auto ptrTy = elementType.getPointerTo(type.getMemorySpace());
auto indexTy = getIndexType();
auto rank = type.getRank();
if (rank > 0) {
auto arrayTy = LLVM::LLVMType::getArrayTy(indexTy, type.getRank());
return LLVM::LLVMType::getStructTy(ptrTy, ptrTy, indexTy, arrayTy, arrayTy);
}
return LLVM::LLVMType::getStructTy(ptrTy, ptrTy, indexTy);
}
// Convert an n-D vector type to an LLVM vector type via an (n-1)-D array type
// when n > 1.
// For example, `vector<4 x f32>` converts to `!llvm<"<4 x float>">` and
// `vector<4 x 8 x 16 x f32>` converts to `!llvm<"[4 x [8 x <16 x float>]]">`.
Type LLVMTypeConverter::convertVectorType(VectorType type) {
auto elementType = unwrap(convertType(type.getElementType()));
if (!elementType)
return {};
auto vectorType =
LLVM::LLVMType::getVectorTy(elementType, type.getShape().back());
auto shape = type.getShape();
for (int i = shape.size() - 2; i >= 0; --i)
vectorType = LLVM::LLVMType::getArrayTy(vectorType, shape[i]);
return vectorType;
}
// Dispatch based on the actual type. Return null type on error.
Type LLVMTypeConverter::convertStandardType(Type type) {
if (auto funcType = type.dyn_cast<FunctionType>())
return convertFunctionType(funcType);
if (auto intType = type.dyn_cast<IntegerType>())
return convertIntegerType(intType);
if (auto floatType = type.dyn_cast<FloatType>())
return convertFloatType(floatType);
if (auto indexType = type.dyn_cast<IndexType>())
return convertIndexType(indexType);
if (auto memRefType = type.dyn_cast<MemRefType>())
return convertMemRefType(memRefType);
if (auto vectorType = type.dyn_cast<VectorType>())
return convertVectorType(vectorType);
if (auto llvmType = type.dyn_cast<LLVM::LLVMType>())
return llvmType;
return {};
}
// Convert the element type of the memref `t` to an LLVM type using
// `lowering`, get a pointer LLVM type pointing to the converted `t`, wrap it
// into the MLIR LLVM dialect type and return.
static Type getMemRefElementPtrType(MemRefType t, LLVMTypeConverter &lowering) {
auto elementType = t.getElementType();
auto converted = lowering.convertType(elementType);
if (!converted)
return {};
return converted.cast<LLVM::LLVMType>().getPointerTo(t.getMemorySpace());
}
LLVMOpLowering::LLVMOpLowering(StringRef rootOpName, MLIRContext *context,
LLVMTypeConverter &lowering_,
PatternBenefit benefit)
: ConversionPattern(rootOpName, benefit, context), lowering(lowering_) {}
/*============================================================================*/
/* MemRefDescriptor implementation */
/*============================================================================*/
/// Construct a helper for the given descriptor value.
MemRefDescriptor::MemRefDescriptor(Value *descriptor) : value(descriptor) {
if (value) {
structType = value->getType().cast<LLVM::LLVMType>();
indexType = value->getType().cast<LLVM::LLVMType>().getStructElementType(
kOffsetPosInMemRefDescriptor);
}
}
/// Builds IR creating an `undef` value of the descriptor type.
MemRefDescriptor MemRefDescriptor::undef(OpBuilder &builder, Location loc,
Type descriptorType) {
Value *descriptor =
builder.create<LLVM::UndefOp>(loc, descriptorType.cast<LLVM::LLVMType>());
return MemRefDescriptor(descriptor);
}
/// Builds IR extracting the allocated pointer from the descriptor.
Value *MemRefDescriptor::allocatedPtr(OpBuilder &builder, Location loc) {
return extractPtr(builder, loc, kAllocatedPtrPosInMemRefDescriptor);
}
/// Builds IR inserting the allocated pointer into the descriptor.
void MemRefDescriptor::setAllocatedPtr(OpBuilder &builder, Location loc,
Value *ptr) {
setPtr(builder, loc, kAllocatedPtrPosInMemRefDescriptor, ptr);
}
/// Builds IR extracting the aligned pointer from the descriptor.
Value *MemRefDescriptor::alignedPtr(OpBuilder &builder, Location loc) {
return extractPtr(builder, loc, kAlignedPtrPosInMemRefDescriptor);
}
/// Builds IR inserting the aligned pointer into the descriptor.
void MemRefDescriptor::setAlignedPtr(OpBuilder &builder, Location loc,
Value *ptr) {
setPtr(builder, loc, kAlignedPtrPosInMemRefDescriptor, ptr);
}
/// Builds IR extracting the offset from the descriptor.
Value *MemRefDescriptor::offset(OpBuilder &builder, Location loc) {
return builder.create<LLVM::ExtractValueOp>(
loc, indexType, value,
builder.getI64ArrayAttr(kOffsetPosInMemRefDescriptor));
}
/// Builds IR inserting the offset into the descriptor.
void MemRefDescriptor::setOffset(OpBuilder &builder, Location loc,
Value *offset) {
value = builder.create<LLVM::InsertValueOp>(
loc, structType, value, offset,
builder.getI64ArrayAttr(kOffsetPosInMemRefDescriptor));
}
/// Builds IR extracting the pos-th size from the descriptor.
Value *MemRefDescriptor::size(OpBuilder &builder, Location loc, unsigned pos) {
return builder.create<LLVM::ExtractValueOp>(
loc, indexType, value,
builder.getI64ArrayAttr({kSizePosInMemRefDescriptor, pos}));
}
/// Builds IR inserting the pos-th size into the descriptor.
void MemRefDescriptor::setSize(OpBuilder &builder, Location loc, unsigned pos,
Value *size) {
value = builder.create<LLVM::InsertValueOp>(
loc, structType, value, size,
builder.getI64ArrayAttr({kSizePosInMemRefDescriptor, pos}));
}
/// Builds IR extracting the pos-th stride from the descriptor.
Value *MemRefDescriptor::stride(OpBuilder &builder, Location loc,
unsigned pos) {
return builder.create<LLVM::ExtractValueOp>(
loc, indexType, value,
builder.getI64ArrayAttr({kStridePosInMemRefDescriptor, pos}));
}
/// Builds IR inserting the pos-th stride into the descriptor.
void MemRefDescriptor::setStride(OpBuilder &builder, Location loc, unsigned pos,
Value *stride) {
value = builder.create<LLVM::InsertValueOp>(
loc, structType, value, stride,
builder.getI64ArrayAttr({kStridePosInMemRefDescriptor, pos}));
}
Value *MemRefDescriptor::extractPtr(OpBuilder &builder, Location loc,
unsigned pos) {
Type type = structType.cast<LLVM::LLVMType>().getStructElementType(pos);
return builder.create<LLVM::ExtractValueOp>(loc, type, value,
builder.getI64ArrayAttr(pos));
}
void MemRefDescriptor::setPtr(OpBuilder &builder, Location loc, unsigned pos,
Value *ptr) {
value = builder.create<LLVM::InsertValueOp>(loc, structType, value, ptr,
builder.getI64ArrayAttr(pos));
}
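/// Returns the LLVM type of the aligned pointer field of the descriptor,
/// i.e. a pointer to the converted element type of the memref.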
LLVM::LLVMType MemRefDescriptor::getElementType() {
return value->getType().cast<LLVM::LLVMType>().getStructElementType(
kAlignedPtrPosInMemRefDescriptor);
}
namespace {
// Base class for Standard to LLVM IR op conversions. Matches the Op type
// provided as template argument. Carries a reference to the LLVM dialect in
// case it is necessary for rewriters.
template <typename SourceOp>
class LLVMLegalizationPattern : public LLVMOpLowering {
public:
// Construct a conversion pattern.
explicit LLVMLegalizationPattern(LLVM::LLVMDialect &dialect_,
LLVMTypeConverter &lowering_)
: LLVMOpLowering(SourceOp::getOperationName(), dialect_.getContext(),
lowering_),
dialect(dialect_) {}
// Get the LLVM IR dialect.
LLVM::LLVMDialect &getDialect() const { return dialect; }
// Get the LLVM context.
llvm::LLVMContext &getContext() const { return dialect.getLLVMContext(); }
// Get the LLVM module in which the types are constructed.
llvm::Module &getModule() const { return dialect.getLLVMModule(); }
// Get the MLIR type wrapping the LLVM integer type whose bit width is defined
// by the pointer size used in the LLVM module.
LLVM::LLVMType getIndexType() const {
return LLVM::LLVMType::getIntNTy(
&dialect, getModule().getDataLayout().getPointerSizeInBits());
}
LLVM::LLVMType getVoidType() const {
return LLVM::LLVMType::getVoidTy(&dialect);
}
// Get the MLIR type wrapping the LLVM i8* type.
LLVM::LLVMType getVoidPtrType() const {
return LLVM::LLVMType::getInt8PtrTy(&dialect);
}
// Create an LLVM IR pseudo-operation defining the given index constant.
Value *createIndexConstant(ConversionPatternRewriter &builder, Location loc,
uint64_t value) const {
auto attr = builder.getIntegerAttr(builder.getIndexType(), value);
return builder.create<LLVM::ConstantOp>(loc, getIndexType(), attr);
}
protected:
LLVM::LLVMDialect &dialect;
};
struct FuncOpConversion : public LLVMLegalizationPattern<FuncOp> {
using LLVMLegalizationPattern<FuncOp>::LLVMLegalizationPattern;
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
auto funcOp = cast<FuncOp>(op);
FunctionType type = funcOp.getType();
// Pack the result types into a struct.
Type packedResult;
if (type.getNumResults() != 0)
if (!(packedResult = lowering.packFunctionResults(type.getResults())))
return matchFailure();
LLVM::LLVMType resultType = packedResult
? packedResult.cast<LLVM::LLVMType>()
: LLVM::LLVMType::getVoidTy(&dialect);
SmallVector<LLVM::LLVMType, 4> argTypes;
argTypes.reserve(type.getNumInputs());
SmallVector<unsigned, 4> promotedArgIndices;
promotedArgIndices.reserve(type.getNumInputs());
// Convert the original function arguments. Struct arguments are promoted to
// pointer to struct arguments to allow calling external functions with
// various ABIs (e.g. compiled from C/C++ on platform X).
auto varargsAttr = funcOp.getAttrOfType<BoolAttr>("std.varargs");
TypeConverter::SignatureConversion result(funcOp.getNumArguments());
for (auto en : llvm::enumerate(type.getInputs())) {
auto t = en.value();
auto converted = lowering.convertType(t).dyn_cast<LLVM::LLVMType>();
if (!converted)
return matchFailure();
if (t.isa<MemRefType>()) {
converted = converted.getPointerTo();
promotedArgIndices.push_back(en.index());
}
argTypes.push_back(converted);
}
for (unsigned idx = 0, e = argTypes.size(); idx < e; ++idx)
result.addInputs(idx, argTypes[idx]);
auto llvmType = LLVM::LLVMType::getFunctionTy(
resultType, argTypes, varargsAttr && varargsAttr.getValue());
// Only retain those attributes that are not constructed by build.
SmallVector<NamedAttribute, 4> attributes;
for (const auto &attr : funcOp.getAttrs()) {
if (attr.first.is(SymbolTable::getSymbolAttrName()) ||
attr.first.is(impl::getTypeAttrName()) ||
attr.first.is("std.varargs"))
continue;
attributes.push_back(attr);
}
// Create an LLVM function.
auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
op->getLoc(), funcOp.getName(), llvmType, attributes);
rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
newFuncOp.end());
// Tell the rewriter to convert the region signature.
rewriter.applySignatureConversion(&newFuncOp.getBody(), result);
// Insert loads from memref descriptor pointers in function bodies.
if (!newFuncOp.getBody().empty()) {
Block *firstBlock = &newFuncOp.getBody().front();
rewriter.setInsertionPoint(firstBlock, firstBlock->begin());
for (unsigned idx : promotedArgIndices) {
BlockArgument *arg = firstBlock->getArgument(idx);
Value *loaded = rewriter.create<LLVM::LoadOp>(funcOp.getLoc(), arg);
rewriter.replaceUsesOfBlockArgument(arg, loaded);
}
}
rewriter.eraseOp(op);
return matchSuccess();
}
};
//////////////// Support for Lowering operations on n-D vectors ////////////////
namespace {
// Helper struct to "unroll" operations on n-D vectors in terms of operations on
// 1-D LLVM vectors.
struct NDVectorTypeInfo {
// LLVM array struct which encodes n-D vectors.
LLVM::LLVMType llvmArrayTy;
// LLVM vector type which encodes the inner 1-D vector type.
LLVM::LLVMType llvmVectorTy;
// Multiplicity of llvmArrayTy to llvmVectorTy.
SmallVector<int64_t, 4> arraySizes;
};
} // namespace
// For >1-D vector types, extracts the necessary information to iterate over
// all 1-D subvectors in the underlying LLVM representation of the n-D vector.
// Iterates on the LLVM array type until we hit a non-array type (which is
// asserted to be an LLVM vector type).
static NDVectorTypeInfo extractNDVectorTypeInfo(VectorType vectorType,
LLVMTypeConverter &converter) {
assert(vectorType.getRank() > 1 && "expected >1D vector type");
NDVectorTypeInfo info;
info.llvmArrayTy =
converter.convertType(vectorType).dyn_cast<LLVM::LLVMType>();
if (!info.llvmArrayTy)
return info;
info.arraySizes.reserve(vectorType.getRank() - 1);
auto llvmTy = info.llvmArrayTy;
while (llvmTy.isArrayTy()) {
info.arraySizes.push_back(llvmTy.getArrayNumElements());
llvmTy = llvmTy.getArrayElementType();
}
if (!llvmTy.isVectorTy())
return info;
info.llvmVectorTy = llvmTy;
return info;
}
// Express `linearIndex` in terms of coordinates of `basis`.
// Returns the empty vector when linearIndex is out of the range [0, P) where
// P is the product of all the basis elements.
//
// Prerequisites:
// Basis is an array of nonnegative integers (signed type inherited from
// vector shape type).
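// For example, for basis [4, 8] and linearIndex 11, the result is [1, 3]
// since 11 = 1 * 8 + 3.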
static SmallVector<int64_t, 4> getCoordinates(ArrayRef<int64_t> basis,
unsigned linearIndex) {
SmallVector<int64_t, 4> res;
res.reserve(basis.size());
for (unsigned basisElement : llvm::reverse(basis)) {
res.push_back(linearIndex % basisElement);
linearIndex = linearIndex / basisElement;
}
if (linearIndex > 0)
return {};
std::reverse(res.begin(), res.end());
return res;
}
// Iterates over the n-D index space defined by `info.arraySizes`, converting
// each linear index into coordinates and invoking `fun` with the
// corresponding position attribute.
template <typename Lambda>
void nDVectorIterate(const NDVectorTypeInfo &info, OpBuilder &builder,
Lambda fun) {
unsigned ub = 1;
for (auto s : info.arraySizes)
ub *= s;
for (unsigned linearIndex = 0; linearIndex < ub; ++linearIndex) {
auto coords = getCoordinates(info.arraySizes, linearIndex);
// The linear index is out of bounds; we are done.
if (coords.empty())
break;
assert(coords.size() == info.arraySizes.size());
auto position = builder.getI64ArrayAttr(coords);
fun(position);
}
}
////////////// End Support for Lowering operations on n-D vectors //////////////
// Basic lowering implementation for one-to-one rewriting from Standard Ops to
// LLVM Dialect Ops.
template <typename SourceOp, typename TargetOp>
struct OneToOneLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
using LLVMLegalizationPattern<SourceOp>::LLVMLegalizationPattern;
using Super = OneToOneLLVMOpLowering<SourceOp, TargetOp>;
// Convert the type of the result to an LLVM type, pass operands as is,
// preserve attributes.
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
unsigned numResults = op->getNumResults();
Type packedType;
if (numResults != 0) {
packedType = this->lowering.packFunctionResults(
llvm::to_vector<4>(op->getResultTypes()));
if (!packedType)
return this->matchFailure();
}
auto newOp = rewriter.create<TargetOp>(op->getLoc(), packedType, operands,
op->getAttrs());
// If the operation produced 0 or 1 result, return early.
if (numResults == 0)
return rewriter.eraseOp(op), this->matchSuccess();
if (numResults == 1)
return rewriter.replaceOp(op, newOp.getOperation()->getResult(0)),
this->matchSuccess();
// Otherwise, the operation was converted to an operation producing a
// structure. Extract the individual results from the structure and return
// them as a list.
SmallVector<Value *, 4> results;
results.reserve(numResults);
for (unsigned i = 0; i < numResults; ++i) {
auto type = this->lowering.convertType(op->getResult(i)->getType());
results.push_back(rewriter.create<LLVM::ExtractValueOp>(
op->getLoc(), type, newOp.getOperation()->getResult(0),
rewriter.getI64ArrayAttr(i)));
}
rewriter.replaceOp(op, results);
return this->matchSuccess();
}
};
template <typename SourceOp, unsigned OpCount> struct OpCountValidator {
static_assert(
std::is_base_of<
typename OpTrait::NOperands<OpCount>::template Impl<SourceOp>,
SourceOp>::value,
"wrong operand count");
};
template <typename SourceOp> struct OpCountValidator<SourceOp, 1> {
static_assert(std::is_base_of<OpTrait::OneOperand<SourceOp>, SourceOp>::value,
"expected a single operand");
};
template <typename SourceOp, unsigned OpCount> void ValidateOpCount() {
OpCountValidator<SourceOp, OpCount>();
}
// Basic lowering implementation for rewriting from Standard Ops to LLVM Dialect
// Ops for N-ary ops with one result. This supports higher-dimensional vector
// types.
template <typename SourceOp, typename TargetOp, unsigned OpCount>
struct NaryOpLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
using LLVMLegalizationPattern<SourceOp>::LLVMLegalizationPattern;
using Super = NaryOpLLVMOpLowering<SourceOp, TargetOp, OpCount>;
// Convert the type of the result to an LLVM type, pass operands as is,
// preserve attributes.
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
ValidateOpCount<SourceOp, OpCount>();
static_assert(
std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
"expected single result op");
static_assert(std::is_base_of<OpTrait::SameOperandsAndResultType<SourceOp>,
SourceOp>::value,
"expected same operands and result type");
// Cannot convert ops if their operands are not of LLVM type.
for (Value *operand : operands) {
if (!operand || !operand->getType().isa<LLVM::LLVMType>())
return this->matchFailure();
}
auto loc = op->getLoc();
auto llvmArrayTy = operands[0]->getType().cast<LLVM::LLVMType>();
if (!llvmArrayTy.isArrayTy()) {
auto newOp = rewriter.create<TargetOp>(
op->getLoc(), operands[0]->getType(), operands, op->getAttrs());
rewriter.replaceOp(op, newOp.getResult());
return this->matchSuccess();
}
auto vectorType = op->getResult(0)->getType().dyn_cast<VectorType>();
if (!vectorType)
return this->matchFailure();
auto vectorTypeInfo = extractNDVectorTypeInfo(vectorType, this->lowering);
auto llvmVectorTy = vectorTypeInfo.llvmVectorTy;
if (!llvmVectorTy || llvmArrayTy != vectorTypeInfo.llvmArrayTy)
return this->matchFailure();
Value *desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayTy);
nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayAttr position) {
// For this unrolled `position`, corresponding to the `linearIndex`^th
// element, extract the operand vectors.
SmallVector<Value *, OpCount> extractedOperands;
for (unsigned i = 0; i < OpCount; ++i) {
extractedOperands.push_back(rewriter.create<LLVM::ExtractValueOp>(
loc, llvmVectorTy, operands[i], position));
}
Value *newVal = rewriter.create<TargetOp>(
loc, llvmVectorTy, extractedOperands, op->getAttrs());
desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmArrayTy, desc,
newVal, position);
});
rewriter.replaceOp(op, desc);
return this->matchSuccess();
}
};
template <typename SourceOp, typename TargetOp>
using UnaryOpLLVMOpLowering = NaryOpLLVMOpLowering<SourceOp, TargetOp, 1>;
template <typename SourceOp, typename TargetOp>
using BinaryOpLLVMOpLowering = NaryOpLLVMOpLowering<SourceOp, TargetOp, 2>;
// Specific lowerings.
// FIXME: this should be tablegen'ed.
struct ExpOpLowering : public UnaryOpLLVMOpLowering<ExpOp, LLVM::ExpOp> {
using Super::Super;
};
struct AddIOpLowering : public BinaryOpLLVMOpLowering<AddIOp, LLVM::AddOp> {
using Super::Super;
};
struct SubIOpLowering : public BinaryOpLLVMOpLowering<SubIOp, LLVM::SubOp> {
using Super::Super;
};
struct MulIOpLowering : public BinaryOpLLVMOpLowering<MulIOp, LLVM::MulOp> {
using Super::Super;
};
struct DivISOpLowering : public BinaryOpLLVMOpLowering<DivISOp, LLVM::SDivOp> {
using Super::Super;
};
struct DivIUOpLowering : public BinaryOpLLVMOpLowering<DivIUOp, LLVM::UDivOp> {
using Super::Super;
};
struct RemISOpLowering : public BinaryOpLLVMOpLowering<RemISOp, LLVM::SRemOp> {
using Super::Super;
};
struct RemIUOpLowering : public BinaryOpLLVMOpLowering<RemIUOp, LLVM::URemOp> {
using Super::Super;
};
struct AndOpLowering : public BinaryOpLLVMOpLowering<AndOp, LLVM::AndOp> {
using Super::Super;
};
struct OrOpLowering : public BinaryOpLLVMOpLowering<OrOp, LLVM::OrOp> {
using Super::Super;
};
struct XOrOpLowering : public BinaryOpLLVMOpLowering<XOrOp, LLVM::XOrOp> {
using Super::Super;
};
struct AddFOpLowering : public BinaryOpLLVMOpLowering<AddFOp, LLVM::FAddOp> {
using Super::Super;
};
struct SubFOpLowering : public BinaryOpLLVMOpLowering<SubFOp, LLVM::FSubOp> {
using Super::Super;
};
struct MulFOpLowering : public BinaryOpLLVMOpLowering<MulFOp, LLVM::FMulOp> {
using Super::Super;
};
struct DivFOpLowering : public BinaryOpLLVMOpLowering<DivFOp, LLVM::FDivOp> {
using Super::Super;
};
struct RemFOpLowering : public BinaryOpLLVMOpLowering<RemFOp, LLVM::FRemOp> {
using Super::Super;
};
struct SelectOpLowering
: public OneToOneLLVMOpLowering<SelectOp, LLVM::SelectOp> {
using Super::Super;
};
struct ConstLLVMOpLowering
: public OneToOneLLVMOpLowering<ConstantOp, LLVM::ConstantOp> {
using Super::Super;
};
// Check if the MemRefType `type` is supported by the lowering. We currently
// only support memrefs with identity maps.
static bool isSupportedMemRefType(MemRefType type) {
return type.getAffineMaps().empty() ||
llvm::all_of(type.getAffineMaps(),
[](AffineMap map) { return map.isIdentity(); });
}
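// For example, `memref<4x8xf32>` (implicit identity layout) is supported,
// while `memref<4x8xf32, (d0, d1) -> (d1, d0)>` (transposing layout map) is
// not.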
// An `alloc` is converted into a definition of a memref descriptor value and
// a call to `malloc` to allocate the underlying data buffer. The memref
// descriptor is of the LLVM structure type where:
// 1. the first element is a pointer to the allocated (typed) data buffer,
// 2. the second element is a pointer to the (typed) payload, aligned to the
// specified alignment,
// 3. the remaining elements serve to store all the sizes and strides of the
// memref using LLVM-converted `index` type.
//
// Alignment is obtained by allocating `alignment - 1` more bytes than requested
// and shifting the aligned pointer relative to the allocated memory. If
// alignment is unspecified, the two pointers are equal.
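// An illustrative sketch (SSA names invented) of the IR this pattern emits
// for `%0 = alloc() : memref<4xf32>` without an alignment attribute:
//   %c4 = llvm.mlir.constant(4 : index) : !llvm.i64
//   ... element size in bytes computed via the null-GEP + ptrtoint idiom ...
//   %sz = llvm.mul %c4, %eltSize : !llvm.i64
//   %raw = llvm.call @malloc(%sz) : (!llvm.i64) -> !llvm<"i8*">
//   %buf = llvm.bitcast %raw : !llvm<"i8*"> to !llvm<"float*">
//   ... descriptor populated via llvm.mlir.undef + llvm.insertvalue ...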
struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
using LLVMLegalizationPattern<AllocOp>::LLVMLegalizationPattern;
PatternMatchResult match(Operation *op) const override {
MemRefType type = cast<AllocOp>(op).getType();
if (isSupportedMemRefType(type))
return matchSuccess();
int64_t offset;
SmallVector<int64_t, 4> strides;
auto successStrides = getStridesAndOffset(type, strides, offset);
if (failed(successStrides))
return matchFailure();
// Dynamic strides are ok if they can be deduced from dynamic sizes (which
// is guaranteed when succeeded(successStrides)). A dynamic offset, however,
// can never be alloc'ed.
if (offset == MemRefType::getDynamicStrideOrOffset())
return matchFailure();
return matchSuccess();
}
void rewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
auto allocOp = cast<AllocOp>(op);
MemRefType type = allocOp.getType();
// Get actual sizes of the memref as values: static sizes are constant
// values and dynamic sizes are passed to 'alloc' as operands. In case of
// zero-dimensional memref, assume a scalar (size 1).
SmallVector<Value *, 4> sizes;
sizes.reserve(type.getRank());
unsigned i = 0;
for (int64_t s : type.getShape())
sizes.push_back(s == -1 ? operands[i++]
: createIndexConstant(rewriter, loc, s));
if (sizes.empty())
sizes.push_back(createIndexConstant(rewriter, loc, 1));
// Compute the total number of memref elements.
Value *cumulativeSize = sizes.front();
for (unsigned i = 1, e = sizes.size(); i < e; ++i)
cumulativeSize = rewriter.create<LLVM::MulOp>(
loc, getIndexType(), ArrayRef<Value *>{cumulativeSize, sizes[i]});
// Compute the size of an individual element. This emits the MLIR equivalent
// of the following sizeof(...) implementation in LLVM IR:
// %0 = getelementptr %elementType* null, %indexType 1
// %1 = ptrtoint %elementType* %0 to %indexType
// which is a common pattern of getting the size of a type in bytes.
auto elementType = type.getElementType();
auto convertedPtrType =
lowering.convertType(elementType).cast<LLVM::LLVMType>().getPointerTo();
auto nullPtr = rewriter.create<LLVM::NullOp>(loc, convertedPtrType);
auto one = createIndexConstant(rewriter, loc, 1);
auto gep = rewriter.create<LLVM::GEPOp>(loc, convertedPtrType,
ArrayRef<Value *>{nullPtr, one});
auto elementSize =
rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gep);
cumulativeSize = rewriter.create<LLVM::MulOp>(
loc, getIndexType(), ArrayRef<Value *>{cumulativeSize, elementSize});
// Insert the `malloc` declaration if it is not already present.
auto module = op->getParentOfType<ModuleOp>();
auto mallocFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("malloc");
if (!mallocFunc) {
OpBuilder moduleBuilder(op->getParentOfType<ModuleOp>().getBodyRegion());
mallocFunc = moduleBuilder.create<LLVM::LLVMFuncOp>(
rewriter.getUnknownLoc(), "malloc",
LLVM::LLVMType::getFunctionTy(getVoidPtrType(), getIndexType(),
/*isVarArg=*/false));
}
// Allocate the underlying buffer and store a pointer to it in the MemRef
// descriptor.
Value *align = nullptr;
if (auto alignAttr = allocOp.alignment()) {
align = createIndexConstant(rewriter, loc,
alignAttr.getValue().getSExtValue());
cumulativeSize = rewriter.create<LLVM::SubOp>(
loc, rewriter.create<LLVM::AddOp>(loc, cumulativeSize, align), one);
}
Value *allocated =
rewriter
.create<LLVM::CallOp>(loc, getVoidPtrType(),
rewriter.getSymbolRefAttr(mallocFunc),
cumulativeSize)
.getResult(0);
auto structElementType = lowering.convertType(elementType);
auto elementPtrType = structElementType.cast<LLVM::LLVMType>().getPointerTo(
type.getMemorySpace());
Value *bitcastAllocated = rewriter.create<LLVM::BitcastOp>(
loc, elementPtrType, ArrayRef<Value *>(allocated));
int64_t offset;
SmallVector<int64_t, 4> strides;
auto successStrides = getStridesAndOffset(type, strides, offset);
assert(succeeded(successStrides) && "unexpected non-strided memref");
(void)successStrides;
assert(offset != MemRefType::getDynamicStrideOrOffset() &&
"unexpected dynamic offset");
// 0-D memref corner case: they have size 1 ...
assert(((type.getRank() == 0 && strides.empty() && sizes.size() == 1) ||
(strides.size() == sizes.size())) &&
"unexpected number of strides");
// Create the MemRef descriptor.
auto structType = lowering.convertType(type);
auto memRefDescriptor = MemRefDescriptor::undef(rewriter, loc, structType);
// Field 1: Allocated pointer, used for malloc/free.
memRefDescriptor.setAllocatedPtr(rewriter, loc, bitcastAllocated);
// Field 2: Actual aligned pointer to payload.
Value *bitcastAligned = bitcastAllocated;
if (align) {
// offset = (align - (ptr % align)) % align
Value *intVal = rewriter.create<LLVM::PtrToIntOp>(
loc, this->getIndexType(), allocated);
Value *ptrModAlign = rewriter.create<LLVM::URemOp>(loc, intVal, align);
Value *subbed = rewriter.create<LLVM::SubOp>(loc, align, ptrModAlign);
Value *offset = rewriter.create<LLVM::URemOp>(loc, subbed, align);
Value *aligned = rewriter.create<LLVM::GEPOp>(loc, allocated->getType(),
allocated, offset);
bitcastAligned = rewriter.create<LLVM::BitcastOp>(
loc, elementPtrType, ArrayRef<Value *>(aligned));
}
memRefDescriptor.setAlignedPtr(rewriter, loc, bitcastAligned);
// Field 3: Offset in aligned pointer.
memRefDescriptor.setOffset(rewriter, loc,
createIndexConstant(rewriter, loc, offset));
if (type.getRank() == 0)
// No size/stride descriptor in memref, return the descriptor value.
return rewriter.replaceOp(op, {memRefDescriptor});
// Fields 4 and 5: Sizes and strides of the strided MemRef.
// Store all sizes in the descriptor. Only dynamic sizes are passed in as
// operands to AllocOp.
Value *runningStride = nullptr;
// Iterate strides in reverse order, compute runningStride and strideValues.
auto nStrides = strides.size();
SmallVector<Value *, 4> strideValues(nStrides, nullptr);
for (auto indexedStride : llvm::enumerate(llvm::reverse(strides))) {
int64_t index = nStrides - 1 - indexedStride.index();
if (strides[index] == MemRefType::getDynamicStrideOrOffset())
// Identity layout map is enforced in the match function, so we compute:
// `runningStride *= sizes[index]`
runningStride =
runningStride
? rewriter.create<LLVM::MulOp>(loc, runningStride, sizes[index])
: createIndexConstant(rewriter, loc, 1);
else
runningStride = createIndexConstant(rewriter, loc, strides[index]);
strideValues[index] = runningStride;
}
// Fill size and stride descriptors in memref.
for (auto indexedSize : llvm::enumerate(sizes)) {
int64_t index = indexedSize.index();
memRefDescriptor.setSize(rewriter, loc, index, indexedSize.value());
memRefDescriptor.setStride(rewriter, loc, index, strideValues[index]);
}
// Return the final value of the descriptor.
rewriter.replaceOp(op, {memRefDescriptor});
}
};
// A CallOp automatically promotes MemRefType to a sequence of alloca/store and
// passes the pointer to the MemRef across function boundaries.
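// An illustrative sketch (SSA names invented): for a call
// `call @foo(%m) : (memref<4xf32>) -> ()`, the descriptor is stored on the
// stack and its address is passed instead:
//   %c1 = llvm.mlir.constant(1 : index) : !llvm.i64
//   %ptr = llvm.alloca %c1 x ...  // the descriptor struct type
//   llvm.store %desc, %ptr
//   llvm.call @foo(%ptr) ...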
template <typename CallOpType>
struct CallOpInterfaceLowering : public LLVMLegalizationPattern<CallOpType> {
using LLVMLegalizationPattern<CallOpType>::LLVMLegalizationPattern;
using Super = CallOpInterfaceLowering<CallOpType>;
using Base = LLVMLegalizationPattern<CallOpType>;
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
OperandAdaptor<CallOpType> transformed(operands);
auto callOp = cast<CallOpType>(op);
// Pack the result types into a struct.
Type packedResult;
unsigned numResults = callOp.getNumResults();
auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes());
if (numResults != 0) {
if (!(packedResult = this->lowering.packFunctionResults(resultTypes)))
return this->matchFailure();
}
SmallVector<Value *, 4> opOperands(op->getOperands());
auto promoted = this->lowering.promoteMemRefDescriptors(
op->getLoc(), opOperands, operands, rewriter);
auto newOp = rewriter.create<LLVM::CallOp>(op->getLoc(), packedResult,
promoted, op->getAttrs());
// If < 2 results, packing did not do anything and we can just return.
if (numResults < 2) {
SmallVector<Value *, 4> results(newOp.getResults());
rewriter.replaceOp(op, results);
return this->matchSuccess();
}
// Otherwise, the operation was converted to an operation producing a
// structure. Extract the individual results from the structure and return
// them as a list.
// TODO(aminim, ntv, riverriddle, zinenko): this seems like patching around
// a particular interaction between MemRefType and CallOp lowering. Find a
// way to avoid special casing.
SmallVector<Value *, 4> results;
results.reserve(numResults);
for (unsigned i = 0; i < numResults; ++i) {
auto type = this->lowering.convertType(op->getResult(i)->getType());
results.push_back(rewriter.create<LLVM::ExtractValueOp>(
op->getLoc(), type, newOp.getOperation()->getResult(0),
rewriter.getI64ArrayAttr(i)));
}
rewriter.replaceOp(op, results);
return this->matchSuccess();
}
};
struct CallOpLowering : public CallOpInterfaceLowering<CallOp> {
using Super::Super;
};
struct CallIndirectOpLowering : public CallOpInterfaceLowering<CallIndirectOp> {
using Super::Super;
};
// A `dealloc` is converted into a call to `free` on the underlying data buffer.
// Since the memref descriptor is an SSA value, there is no need to clean it
// up in any way.
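// An illustrative sketch (SSA names invented): `dealloc %m : memref<4xf32>`
// becomes roughly
//   %alloc = llvm.extractvalue %desc[0] : ...  // allocated pointer
//   %casted = llvm.bitcast %alloc : !llvm<"float*"> to !llvm<"i8*">
//   llvm.call @free(%casted) : (!llvm<"i8*">) -> ()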
struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
using LLVMLegalizationPattern<DeallocOp>::LLVMLegalizationPattern;
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
assert(operands.size() == 1 && "dealloc takes one operand");
OperandAdaptor<DeallocOp> transformed(operands);
// Insert the `free` declaration if it is not already present.
auto freeFunc =
op->getParentOfType<ModuleOp>().lookupSymbol<LLVM::LLVMFuncOp>("free");
if (!freeFunc) {
OpBuilder moduleBuilder(op->getParentOfType<ModuleOp>().getBodyRegion());
freeFunc = moduleBuilder.create<LLVM::LLVMFuncOp>(
rewriter.getUnknownLoc(), "free",
LLVM::LLVMType::getFunctionTy(getVoidType(), getVoidPtrType(),
/*isVarArg=*/false));
}
MemRefDescriptor memref(transformed.memref());
Value *casted = rewriter.create<LLVM::BitcastOp>(
op->getLoc(), getVoidPtrType(),
memref.allocatedPtr(rewriter, op->getLoc()));
rewriter.replaceOpWithNewOp<LLVM::CallOp>(
op, ArrayRef<Type>(), rewriter.getSymbolRefAttr(freeFunc), casted);
return matchSuccess();
}
};
struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> {
using LLVMLegalizationPattern<MemRefCastOp>::LLVMLegalizationPattern;
PatternMatchResult match(Operation *op) const override {
auto memRefCastOp = cast<MemRefCastOp>(op);
MemRefType sourceType =
memRefCastOp.getOperand()->getType().cast<MemRefType>();
MemRefType targetType = memRefCastOp.getType();
return (isSupportedMemRefType(targetType) &&
isSupportedMemRefType(sourceType))
? matchSuccess()
: matchFailure();
}
void rewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
auto memRefCastOp = cast<MemRefCastOp>(op);
OperandAdaptor<MemRefCastOp> transformed(operands);
// memref_cast is defined for source and destination memref types with the
// same element type, same mappings, same address space and same rank.
// Therefore a simple bitcast suffices; if these conditions do not hold, the
// behavior is undefined.
auto targetStructType = lowering.convertType(memRefCastOp.getType());
rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, targetStructType,
transformed.source());
}
};
// A `dim` is converted to a constant for static sizes and to an access to the
// size stored in the memref descriptor for dynamic sizes.
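// For example (illustrative), for `%d = dim %m, 1 : memref<4x?xf32>` the
// pattern emits an extraction of sizes[1] from the descriptor, whereas
// `dim %m, 0` lowers to the index constant 4.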
struct DimOpLowering : public LLVMLegalizationPattern<DimOp> {
using LLVMLegalizationPattern<DimOp>::LLVMLegalizationPattern;
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
auto dimOp = cast<DimOp>(op);
OperandAdaptor<DimOp> transformed(operands);
MemRefType type = dimOp.getOperand()->getType().cast<MemRefType>();
auto shape = type.getShape();
int64_t index = dimOp.getIndex();
// Extract dynamic size from the memref descriptor.
if (ShapedType::isDynamic(shape[index]))
rewriter.replaceOp(op, {MemRefDescriptor(transformed.memrefOrTensor())
.size(rewriter, op->getLoc(), index)});
else
// Use constant for static size.
rewriter.replaceOp(
op, createIndexConstant(rewriter, op->getLoc(), shape[index]));
return matchSuccess();
}
};
// Common base for load and store operations on MemRefs. Restricts the match
// to supported MemRef types. Provides functionality to emit code accessing a
// specific element of the underlying data buffer.
template <typename Derived>
struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> {
using LLVMLegalizationPattern<Derived>::LLVMLegalizationPattern;
using Base = LoadStoreOpLowering<Derived>;
PatternMatchResult match(Operation *op) const override {
MemRefType type = cast<Derived>(op).getMemRefType();
return isSupportedMemRefType(type) ? this->matchSuccess()
: this->matchFailure();
}
// Given subscript indices and array sizes in row-major order,
// i_n, i_{n-1}, ..., i_1
// s_n, s_{n-1}, ..., s_1
// obtain a value that corresponds to the linearized subscript
// \sum_k i_k * \prod_{j=1}^{k-1} s_j
// by accumulating the running linearized value.
// Note that `indices` and `allocSizes` are passed in the same order as they
// appear in load/store operations and memref type declarations.
Value *linearizeSubscripts(ConversionPatternRewriter &builder, Location loc,
ArrayRef<Value *> indices,
ArrayRef<Value *> allocSizes) const {
assert(indices.size() == allocSizes.size() &&
"mismatching number of indices and allocation sizes");
assert(!indices.empty() && "cannot linearize a 0-dimensional access");
Value *linearized = indices.front();
for (int i = 1, nSizes = allocSizes.size(); i < nSizes; ++i) {
linearized = builder.create<LLVM::MulOp>(
loc, this->getIndexType(),
ArrayRef<Value *>{linearized, allocSizes[i]});
linearized = builder.create<LLVM::AddOp>(
loc, this->getIndexType(), ArrayRef<Value *>{linearized, indices[i]});
}
return linearized;
}
// This is a strided getElementPtr variant that linearizes subscripts as:
// `base_offset + index_0 * stride_0 + ... + index_n * stride_n`.
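// For example (illustrative), for `memref<4x8xf32>` with the identity layout
// (strides [8, 1], offset 0), an access at indices (%i, %j) computes
//   %off = 0 + %i * 8 + %j * 1
// and feeds %off into an `llvm.getelementptr` on the aligned pointer.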
Value *getStridedElementPtr(Location loc, Type elementTypePtr,
Value *descriptor, ArrayRef<Value *> indices,
ArrayRef<int64_t> strides, int64_t offset,
ConversionPatternRewriter &rewriter) const {
MemRefDescriptor memRefDescriptor(descriptor);
Value *base = memRefDescriptor.alignedPtr(rewriter, loc);
Value *offsetValue = offset == MemRefType::getDynamicStrideOrOffset()
? memRefDescriptor.offset(rewriter, loc)
: this->createIndexConstant(rewriter, loc, offset);
for (int i = 0, e = indices.size(); i < e; ++i) {
Value *stride =
strides[i] == MemRefType::getDynamicStrideOrOffset()
? memRefDescriptor.stride(rewriter, loc, i)
: this->createIndexConstant(rewriter, loc, strides[i]);
Value *additionalOffset =
rewriter.create<LLVM::MulOp>(loc, indices[i], stride);
offsetValue =
rewriter.create<LLVM::AddOp>(loc, offsetValue, additionalOffset);
}
return rewriter.create<LLVM::GEPOp>(loc, elementTypePtr, base, offsetValue);
}
Value *getDataPtr(Location loc, MemRefType type, Value *memRefDesc,
ArrayRef<Value *> indices,
ConversionPatternRewriter &rewriter,
llvm::Module &module) const {
auto ptrType = getMemRefElementPtrType(type, this->lowering);
int64_t offset;
SmallVector<int64_t, 4> strides;
auto successStrides = getStridesAndOffset(type, strides, offset);
assert(succeeded(successStrides) && "unexpected non-strided memref");
(void)successStrides;
return getStridedElementPtr(loc, ptrType, memRefDesc, indices, strides,
offset, rewriter);
}
};
// Load operation is lowered to obtaining a pointer to the indexed element
// and loading it.
struct LoadOpLowering : public LoadStoreOpLowering<LoadOp> {
using Base::Base;
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
auto loadOp = cast<LoadOp>(op);
OperandAdaptor<LoadOp> transformed(operands);
auto type = loadOp.getMemRefType();
Value *dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(),
transformed.indices(), rewriter, getModule());
rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, dataPtr);
return matchSuccess();
}
};
// Store operation is lowered to obtaining a pointer to the indexed element,
// and storing the given value to it.
struct StoreOpLowering : public LoadStoreOpLowering<StoreOp> {
using Base::Base;
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
auto type = cast<StoreOp>(op).getMemRefType();
OperandAdaptor<StoreOp> transformed(operands);
Value *dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(),
transformed.indices(), rewriter, getModule());
rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, transformed.value(),
dataPtr);
return matchSuccess();
}
};
// The lowering of index_cast becomes an integer conversion since `index`
// becomes an integer. If the bit widths of the source and target integer
// types are the same, just erase the cast. If the target type is wider,
// sign-extend the value; otherwise truncate it.
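// For example, assuming `index` lowers to i64, `index_cast %i : index to i32`
// becomes an `llvm.trunc`, while casting an i32 value to `index` becomes an
// `llvm.sext`.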
struct IndexCastOpLowering : public LLVMLegalizationPattern<IndexCastOp> {
using LLVMLegalizationPattern<IndexCastOp>::LLVMLegalizationPattern;
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
IndexCastOpOperandAdaptor transformed(operands);
auto indexCastOp = cast<IndexCastOp>(op);
auto targetType =
this->lowering.convertType(indexCastOp.getResult()->getType())
.cast<LLVM::LLVMType>();
auto sourceType = transformed.in()->getType().cast<LLVM::LLVMType>();
unsigned targetBits = targetType.getUnderlyingType()->getIntegerBitWidth();
unsigned sourceBits = sourceType.getUnderlyingType()->getIntegerBitWidth();
if (targetBits == sourceBits)
rewriter.replaceOp(op, transformed.in());
else if (targetBits < sourceBits)
rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, targetType,
transformed.in());
else
rewriter.replaceOpWithNewOp<LLVM::SExtOp>(op, targetType,
transformed.in());
return matchSuccess();
}
};
// Convert a std.cmp predicate into the LLVM dialect CmpPredicate. The two
// enums share their numerical values, so a simple cast suffices.
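// For example, CmpIPredicate::slt is cast to LLVM::ICmpPredicate::slt, which
// carries the same underlying integer value (assuming the two enums stay in
// sync, as the comment above requires).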
template <typename LLVMPredType, typename StdPredType>
static LLVMPredType convertCmpPredicate(StdPredType pred) {
return static_cast<LLVMPredType>(pred);
}
struct CmpIOpLowering : public LLVMLegalizationPattern<CmpIOp> {
using LLVMLegalizationPattern<CmpIOp>::LLVMLegalizationPattern;
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
auto cmpiOp = cast<CmpIOp>(op);
CmpIOpOperandAdaptor transformed(operands);
rewriter.replaceOpWithNewOp<LLVM::ICmpOp>(
op, lowering.convertType(cmpiOp.getResult()->getType()),
rewriter.getI64IntegerAttr(static_cast<int64_t>(
convertCmpPredicate<LLVM::ICmpPredicate>(cmpiOp.getPredicate()))),
transformed.lhs(), transformed.rhs());
return matchSuccess();
}
};
struct CmpFOpLowering : public LLVMLegalizationPattern<CmpFOp> {
using LLVMLegalizationPattern<CmpFOp>::LLVMLegalizationPattern;
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
auto cmpfOp = cast<CmpFOp>(op);
CmpFOpOperandAdaptor transformed(operands);
rewriter.replaceOpWithNewOp<LLVM::FCmpOp>(
op, lowering.convertType(cmpfOp.getResult()->getType()),
rewriter.getI64IntegerAttr(static_cast<int64_t>(
convertCmpPredicate<LLVM::FCmpPredicate>(cmpfOp.getPredicate()))),
transformed.lhs(), transformed.rhs());
return matchSuccess();
}
};
struct SIToFPLowering
: public OneToOneLLVMOpLowering<SIToFPOp, LLVM::SIToFPOp> {
using Super::Super;
};
struct FPExtLowering : public OneToOneLLVMOpLowering<FPExtOp, LLVM::FPExtOp> {
using Super::Super;
};
struct FPTruncLowering
: public OneToOneLLVMOpLowering<FPTruncOp, LLVM::FPTruncOp> {
using Super::Super;
};
struct SignExtendIOpLowering
: public OneToOneLLVMOpLowering<SignExtendIOp, LLVM::SExtOp> {
using Super::Super;
};
struct TruncateIOpLowering
: public OneToOneLLVMOpLowering<TruncateIOp, LLVM::TruncOp> {
using Super::Super;
};
struct ZeroExtendIOpLowering
: public OneToOneLLVMOpLowering<ZeroExtendIOp, LLVM::ZExtOp> {
using Super::Super;
};
// Base class for LLVM IR lowering terminator operations with successors.
template <typename SourceOp, typename TargetOp>
struct OneToOneLLVMTerminatorLowering
: public LLVMLegalizationPattern<SourceOp> {
using LLVMLegalizationPattern<SourceOp>::LLVMLegalizationPattern;
using Super = OneToOneLLVMTerminatorLowering<SourceOp, TargetOp>;
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> properOperands,
ArrayRef<Block *> destinations,
ArrayRef<ArrayRef<Value *>> operands,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<TargetOp>(op, properOperands, destinations,
operands, op->getAttrs());
return this->matchSuccess();
}
};
// Special lowering pattern for `ReturnOps`. Unlike all other operations,
// `ReturnOp` interacts with the function signature and must have as many
// operands as the function has return values. Because in LLVM IR, functions
// can only return 0 or 1 value, we pack multiple values into a structure
// type. Emit an `UndefOp` followed by `InsertValueOp`s to create such a
// structure, if necessary, before returning it.
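// An illustrative sketch (SSA names invented): `return %a, %b : i32, i64`
// becomes roughly
//   %0 = llvm.mlir.undef : !llvm<"{ i32, i64 }">
//   %1 = llvm.insertvalue %a, %0[0] : !llvm<"{ i32, i64 }">
//   %2 = llvm.insertvalue %b, %1[1] : !llvm<"{ i32, i64 }">
//   llvm.return %2 : !llvm<"{ i32, i64 }">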
struct ReturnOpLowering : public LLVMLegalizationPattern<ReturnOp> {
using LLVMLegalizationPattern<ReturnOp>::LLVMLegalizationPattern;
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
unsigned numArguments = op->getNumOperands();
// If ReturnOp has 0 or 1 operand, create it and return immediately.
if (numArguments == 0) {
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, llvm::ArrayRef<Value *>(),
llvm::ArrayRef<Block *>(),
op->getAttrs());
return matchSuccess();
}
if (numArguments == 1) {
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
op, llvm::ArrayRef<Value *>(operands.front()),
llvm::ArrayRef<Block *>(), op->getAttrs());
return matchSuccess();
}
// Otherwise, we need to pack the arguments into an LLVM struct type before
// returning.
auto packedType =
lowering.packFunctionResults(llvm::to_vector<4>(op->getOperandTypes()));
Value *packed = rewriter.create<LLVM::UndefOp>(op->getLoc(), packedType);
for (unsigned i = 0; i < numArguments; ++i) {
packed = rewriter.create<LLVM::InsertValueOp>(
op->getLoc(), packedType, packed, operands[i],
rewriter.getI64ArrayAttr(i));
}
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, llvm::makeArrayRef(packed),
llvm::ArrayRef<Block *>(),
op->getAttrs());
return matchSuccess();
}
};
// FIXME: this should be tablegen'ed as well.
struct BranchOpLowering
: public OneToOneLLVMTerminatorLowering<BranchOp, LLVM::BrOp> {
using Super::Super;
};
struct CondBranchOpLowering
: public OneToOneLLVMTerminatorLowering<CondBranchOp, LLVM::CondBrOp> {
using Super::Super;
};
// The Splat operation is lowered to an insertelement + a shufflevector
// operation. Only splats to 1-D vector result types are lowered here.
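// An illustrative sketch (SSA names invented): `%s = splat %x : vector<4xf32>`
// becomes roughly
//   %undef = llvm.mlir.undef : !llvm<"<4 x float>">
//   %c0 = llvm.mlir.constant(0 : i32) : !llvm.i32
//   %v = llvm.insertelement %x, %undef[%c0] : !llvm<"<4 x float>">
//   %s = llvm.shufflevector %v, %undef [0, 0, 0, 0] ...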
struct SplatOpLowering : public LLVMLegalizationPattern<SplatOp> {
using LLVMLegalizationPattern<SplatOp>::LLVMLegalizationPattern;
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
auto splatOp = cast<SplatOp>(op);
VectorType resultType = splatOp.getType().dyn_cast<VectorType>();
if (!resultType || resultType.getRank() != 1)
return matchFailure();
// First insert it into an undef vector so we can shuffle it.
auto vectorType = lowering.convertType(splatOp.getType());
Value *undef = rewriter.create<LLVM::UndefOp>(op->getLoc(), vectorType);
auto zero = rewriter.create<LLVM::ConstantOp>(
op->getLoc(), lowering.convertType(rewriter.getIntegerType(32)),
rewriter.getZeroAttr(rewriter.getIntegerType(32)));
auto v = rewriter.create<LLVM::InsertElementOp>(
op->getLoc(), vectorType, undef, splatOp.getOperand(), zero);
int64_t width = splatOp.getType().cast<VectorType>().getDimSize(0);
SmallVector<int32_t, 4> zeroValues(width, 0);
// Shuffle the value across the desired number of elements.
ArrayAttr zeroAttrs = rewriter.getI32ArrayAttr(zeroValues);
rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(op, v, undef, zeroAttrs);
return matchSuccess();
}
};
// The Splat operation is lowered to an insertelement + a shufflevector
// operation. Only splats to vector result types of rank 2 or higher are
// lowered by SplatNdOpLowering; the 1-D case is handled by SplatOpLowering.
struct SplatNdOpLowering : public LLVMLegalizationPattern<SplatOp> {
using LLVMLegalizationPattern<SplatOp>::LLVMLegalizationPattern;
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
auto splatOp = cast<SplatOp>(op);
OperandAdaptor<SplatOp> adaptor(operands);
VectorType resultType = splatOp.getType().dyn_cast<VectorType>();
if (!resultType || resultType.getRank() == 1)
return matchFailure();
// First insert it into an undef vector so we can shuffle it.
auto loc = op->getLoc();
auto vectorTypeInfo = extractNDVectorTypeInfo(resultType, lowering);
auto llvmArrayTy = vectorTypeInfo.llvmArrayTy;
auto llvmVectorTy = vectorTypeInfo.llvmVectorTy;
if (!llvmArrayTy || !llvmVectorTy)
return matchFailure();
// Construct returned value.
Value *desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayTy);
// Construct a 1-D vector with the splatted value that we insert in all the
// places within the returned descriptor.
Value *vdesc = rewriter.create<LLVM::UndefOp>(loc, llvmVectorTy);
auto zero = rewriter.create<LLVM::ConstantOp>(
loc, lowering.convertType(rewriter.getIntegerType(32)),
rewriter.getZeroAttr(rewriter.getIntegerType(32)));
Value *v = rewriter.create<LLVM::InsertElementOp>(loc, llvmVectorTy, vdesc,
adaptor.input(), zero);
// Shuffle the value across the desired number of elements.
int64_t width = resultType.getDimSize(resultType.getRank() - 1);
SmallVector<int32_t, 4> zeroValues(width, 0);
ArrayAttr zeroAttrs = rewriter.getI32ArrayAttr(zeroValues);
v = rewriter.create<LLVM::ShuffleVectorOp>(loc, v, v, zeroAttrs);
// Iterate over the linear index space, convert each linear index into
// coordinates, and insert the splatted 1-D vector at each position.
nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayAttr position) {
desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmArrayTy, desc, v,
position);
});
rewriter.replaceOp(op, desc);
return matchSuccess();
}
};
/// Conversion pattern that transforms a subview op into:
/// 1. An `llvm.mlir.undef` operation to create a memref descriptor
/// 2. Updates to the descriptor to introduce the data ptr, offset, size
/// and stride.
/// The subview op is replaced by the descriptor.
struct SubViewOpLowering : public LLVMLegalizationPattern<SubViewOp> {
using LLVMLegalizationPattern<SubViewOp>::LLVMLegalizationPattern;
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
auto viewOp = cast<SubViewOp>(op);
SubViewOpOperandAdaptor adaptor(operands);
auto sourceMemRefType = viewOp.source()->getType().cast<MemRefType>();
auto sourceElementTy =
lowering.convertType(sourceMemRefType.getElementType())
.dyn_cast_or_null<LLVM::LLVMType>();
auto viewMemRefType = viewOp.getType();
auto targetElementTy = lowering.convertType(viewMemRefType.getElementType())
.dyn_cast<LLVM::LLVMType>();
auto targetDescTy =
lowering.convertType(viewMemRefType).dyn_cast_or_null<LLVM::LLVMType>();
if (!sourceElementTy || !targetDescTy)
return matchFailure();
// Early exit for the 0-D corner case and for cases where the number of
// offset/size/stride operands differs from the source rank.
unsigned rank = sourceMemRefType.getRank();
if (viewMemRefType.getRank() == 0 || rank != adaptor.offsets().size() ||
rank != adaptor.sizes().size() || rank != adaptor.strides().size())
return matchFailure();
int64_t offset;
SmallVector<int64_t, 4> strides;
auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset);
if (failed(successStrides))
return matchFailure();
// Create the descriptor.
MemRefDescriptor sourceMemRef(adaptor.source());
auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);
// Copy the buffer pointer from the old descriptor to the new one.
Value *extracted = sourceMemRef.allocatedPtr(rewriter, loc);
Value *bitcastPtr = rewriter.create<LLVM::BitcastOp>(
loc, targetElementTy.getPointerTo(), extracted);
targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);
extracted = sourceMemRef.alignedPtr(rewriter, loc);
bitcastPtr = rewriter.create<LLVM::BitcastOp>(
loc, targetElementTy.getPointerTo(), extracted);
targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);
// Extract strides needed to compute offset.
SmallVector<Value *, 4> strideValues;
strideValues.reserve(viewMemRefType.getRank());
for (int i = 0, e = viewMemRefType.getRank(); i < e; ++i)
strideValues.push_back(sourceMemRef.stride(rewriter, loc, i));
// Offset: accumulate the subview offsets scaled by the corresponding
// source strides.
Value *baseOffset = sourceMemRef.offset(rewriter, loc);
for (int i = 0, e = viewMemRefType.getRank(); i < e; ++i) {
Value *min = adaptor.offsets()[i];
baseOffset = rewriter.create<LLVM::AddOp>(
loc, baseOffset,
rewriter.create<LLVM::MulOp>(loc, min, strideValues[i]));
}
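// E.g. for a rank-2 source with strides [s0, s1] and subview offsets
// [o0, o1], the loop above computes:
//   newOffset = baseOffset + o0 * s0 + o1 * s1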
targetMemRef.setOffset(rewriter, loc, baseOffset);
// Update sizes and strides.
for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
targetMemRef.setSize(rewriter, loc, i, adaptor.sizes()[i]);
targetMemRef.setStride(rewriter, loc, i,
rewriter.create<LLVM::MulOp>(
loc, adaptor.strides()[i], strideValues[i]));
}
rewriter.replaceOp(op, {targetMemRef});
return matchSuccess();
}
};
/// Conversion pattern that transforms a view op into:
/// 1. An `llvm.mlir.undef` operation to create a memref descriptor
/// 2. Updates to the descriptor to introduce the data ptr, offset, size
/// and stride.
/// The view op is replaced by the descriptor.
struct ViewOpLowering : public LLVMLegalizationPattern<ViewOp> {
using LLVMLegalizationPattern<ViewOp>::LLVMLegalizationPattern;
// Build and return the value for the idx^th shape dimension, either by
// returning the constant shape dimension or by selecting the proper dynamic
// size operand.
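// For instance (a sketch): for shape [4, ?, 8, ?] and idx == 3, exactly one
// dynamic dimension precedes index 3, so this returns dynamicSizes[1].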
Value *getSize(ConversionPatternRewriter &rewriter, Location loc,
ArrayRef<int64_t> shape, ArrayRef<Value *> dynamicSizes,
unsigned idx) const {
assert(idx < shape.size());
if (!ShapedType::isDynamic(shape[idx]))
return createIndexConstant(rewriter, loc, shape[idx]);
// Count the number of dynamic dims in range [0, idx): take_front(idx)
// excludes position `idx` itself.
unsigned nDynamic = llvm::count_if(shape.take_front(idx), [](int64_t v) {
return ShapedType::isDynamic(v);
});
return dynamicSizes[nDynamic];
}
// Build and return the idx^th stride, either by returning the constant stride
// or by computing the dynamic stride from the current `runningStride` and
// `nextSize`. The caller should keep a running stride and update it with the
// result returned by this function.
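// For instance (a sketch): for a contiguous memref with dynamic sizes
// [s0, s1, s2], iterating idx = 2, 1, 0 yields the strides 1, s2, and
// s2 * s1 respectively.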
Value *getStride(ConversionPatternRewriter &rewriter, Location loc,
ArrayRef<int64_t> strides, Value *nextSize,
Value *runningStride, unsigned idx) const {
assert(idx < strides.size());
if (strides[idx] != MemRefType::getDynamicStrideOrOffset())
return createIndexConstant(rewriter, loc, strides[idx]);
if (nextSize)
return runningStride
? rewriter.create<LLVM::MulOp>(loc, runningStride, nextSize)
: nextSize;
assert(!runningStride);
return createIndexConstant(rewriter, loc, 1);
}
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
auto viewOp = cast<ViewOp>(op);
ViewOpOperandAdaptor adaptor(operands);
auto viewMemRefType = viewOp.getType();
auto targetElementTy = lowering.convertType(viewMemRefType.getElementType())
.dyn_cast_or_null<LLVM::LLVMType>();
auto targetDescTy =
lowering.convertType(viewMemRefType).dyn_cast_or_null<LLVM::LLVMType>();
if (!targetElementTy || !targetDescTy)
return op->emitWarning("target descriptor type not converted to LLVM"),
matchFailure();
int64_t offset;
SmallVector<int64_t, 4> strides;
auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset);
if (failed(successStrides))
return op->emitWarning("cannot cast to non-strided shape"),
matchFailure();
// Create the descriptor.
MemRefDescriptor sourceMemRef(adaptor.source());
auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);
// Field 1: Copy the allocated pointer, used for malloc/free.
Value *extracted = sourceMemRef.allocatedPtr(rewriter, loc);
Value *bitcastPtr = rewriter.create<LLVM::BitcastOp>(
loc, targetElementTy.getPointerTo(), extracted);
targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);
// Field 2: Copy the aligned pointer to the payload.
extracted = sourceMemRef.alignedPtr(rewriter, loc);
bitcastPtr = rewriter.create<LLVM::BitcastOp>(
loc, targetElementTy.getPointerTo(), extracted);
targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);
// Field 3: Set the element offset from the aligned pointer.
unsigned numDynamicSizes = llvm::size(viewOp.getDynamicSizes());
(void)numDynamicSizes;
auto sizeAndOffsetOperands = adaptor.operands();
assert(llvm::size(sizeAndOffsetOperands) == numDynamicSizes + 1 ||
offset != MemRefType::getDynamicStrideOrOffset());
Value *baseOffset = (offset != MemRefType::getDynamicStrideOrOffset())
? createIndexConstant(rewriter, loc, offset)
// TODO(ntv): better adaptor.
: sizeAndOffsetOperands.back();
targetMemRef.setOffset(rewriter, loc, baseOffset);
// Early exit for 0-D corner case.
if (viewMemRefType.getRank() == 0)
return rewriter.replaceOp(op, {targetMemRef}), matchSuccess();
// Fields 4 and 5: Update sizes and strides.
if (strides.back() != 1)
return op->emitWarning("cannot cast to non-contiguous shape"),
matchFailure();
Value *stride = nullptr, *nextSize = nullptr;
for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
// Update size.
Value *size = getSize(rewriter, loc, viewMemRefType.getShape(),
sizeAndOffsetOperands, i);
targetMemRef.setSize(rewriter, loc, i, size);
// Update stride.
stride = getStride(rewriter, loc, strides, nextSize, stride, i);
targetMemRef.setStride(rewriter, loc, i, stride);
nextSize = size;
}
rewriter.replaceOp(op, {targetMemRef});
return matchSuccess();
}
};
} // namespace
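// Ensure that blocks appearing multiple times as successors of `bb`'s
// terminator are made distinct, so that block arguments can later be turned
// into LLVM IR PHI nodes keyed by a unique predecessor block.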
static void ensureDistinctSuccessors(Block &bb) {
auto *terminator = bb.getTerminator();
// Find repeated successors with arguments.
llvm::SmallDenseMap<Block *, llvm::SmallVector<int, 4>> successorPositions;
for (int i = 0, e = terminator->getNumSuccessors(); i < e; ++i) {
Block *successor = terminator->getSuccessor(i);
// Blocks with no arguments are safe even if they appear multiple times
// because they don't need PHI nodes.
if (successor->getNumArguments() == 0)
continue;
successorPositions[successor].push_back(i);
}
// If a successor appears in the terminator more than once, create a new
// dummy block that unconditionally branches to the original destination, and
// retarget the terminator to branch to this new block instead. There is no
// need to pass arguments to the dummy block because it will be dominated by
// the original block and can therefore use any values defined in the
// original block.
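// For example (an illustrative sketch):
//   ^bb0: cond_br %cond, ^bb1(%a : i32), ^bb1(%b : i32)
// becomes
//   ^bb0: cond_br %cond, ^bb1(%a : i32), ^dummy
//   ^dummy: br ^bb1(%b : i32)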
for (const auto &successor : successorPositions) {
const auto &positions = successor.second;
// Start from the second occurrence of a block in the successor list.
for (auto position = std::next(positions.begin()), end = positions.end();
position != end; ++position) {
auto *dummyBlock = new Block();
bb.getParent()->push_back(dummyBlock);
auto builder = OpBuilder(dummyBlock);
SmallVector<Value *, 8> operands(
terminator->getSuccessorOperands(*position));
builder.create<BranchOp>(terminator->getLoc(), successor.first, operands);
terminator->setSuccessor(dummyBlock, *position);
// Erase all operands of the retargeted successor. Always erase at
// position 0: the remaining operands shift left after each erasure, so
// erasing at increasing indices would skip every other operand.
for (int i = 0, e = terminator->getNumSuccessorOperands(*position); i < e;
++i)
terminator->eraseSuccessorOperand(*position, 0);
}
}
}
void mlir::LLVM::ensureDistinctSuccessors(ModuleOp m) {
for (auto f : m.getOps<FuncOp>()) {
for (auto &bb : f.getBlocks()) {
::ensureDistinctSuccessors(bb);
}
}
}
/// Collect a set of patterns to convert from the Standard dialect to LLVM.
void mlir::populateStdToLLVMConversionPatterns(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
// FIXME: this should be tablegen'ed
// clang-format off
patterns.insert<
AddFOpLowering,
AddIOpLowering,
AllocOpLowering,
AndOpLowering,
BranchOpLowering,
CallIndirectOpLowering,
CallOpLowering,
CmpFOpLowering,
CmpIOpLowering,
CondBranchOpLowering,
ConstLLVMOpLowering,
DeallocOpLowering,
DimOpLowering,
DivFOpLowering,
DivISOpLowering,
DivIUOpLowering,
ExpOpLowering,
FPExtLowering,
FPTruncLowering,
FuncOpConversion,
IndexCastOpLowering,
LoadOpLowering,
MemRefCastOpLowering,
MulFOpLowering,
MulIOpLowering,
OrOpLowering,
RemFOpLowering,
RemISOpLowering,
RemIUOpLowering,
ReturnOpLowering,
SIToFPLowering,
SelectOpLowering,
SignExtendIOpLowering,
SplatOpLowering,
SplatNdOpLowering,
StoreOpLowering,
SubFOpLowering,
SubIOpLowering,
SubViewOpLowering,
TruncateIOpLowering,
ViewOpLowering,
XOrOpLowering,
ZeroExtendIOpLowering>(*converter.getDialect(), converter);
// clang-format on
}
// Convert types using the stored LLVM IR module.
Type LLVMTypeConverter::convertType(Type t) { return convertStandardType(t); }
// Create an LLVM IR structure type if there is more than one result.
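// E.g. (a sketch): the result list (i32, f32) would be packed into the LLVM
// struct type !llvm<"{ i32, float }">.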
Type LLVMTypeConverter::packFunctionResults(ArrayRef<Type> types) {
assert(!types.empty() && "expected non-empty list of types");
if (types.size() == 1)
return convertType(types.front());
SmallVector<LLVM::LLVMType, 8> resultTypes;
resultTypes.reserve(types.size());
for (auto t : types) {
auto converted = convertType(t).dyn_cast<LLVM::LLVMType>();
if (!converted)
return {};
resultTypes.push_back(converted);
}
return LLVM::LLVMType::getStructTy(llvmDialect, resultTypes);
}
Value *LLVMTypeConverter::promoteOneMemRefDescriptor(Location loc,
Value *operand,
OpBuilder &builder) {
auto *context = builder.getContext();
auto int64Ty = LLVM::LLVMType::getInt64Ty(getDialect());
auto indexType = IndexType::get(context);
// Alloca with proper alignment. Since we do not expect optimizations of
// this alloca op, it need not be hoisted to the function entry block.
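// A sketch of the IR emitted below (assuming a struct-typed descriptor;
// exact printed syntax may differ):
//   %one = llvm.mlir.constant(1 : index) : !llvm.i64
//   %ptr = llvm.alloca %one x !llvm<"{ float*, float*, i64 }">
//   llvm.store %desc, %ptr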
auto ptrType = operand->getType().cast<LLVM::LLVMType>().getPointerTo();
Value *one = builder.create<LLVM::ConstantOp>(loc, int64Ty,
IntegerAttr::get(indexType, 1));
Value *allocated =
builder.create<LLVM::AllocaOp>(loc, ptrType, one, /*alignment=*/0);
// Store into the alloca'ed descriptor.
builder.create<LLVM::StoreOp>(loc, operand, allocated);
return allocated;
}
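// Promote each memref-typed operand: the converted descriptor value is
// stored on the stack via promoteOneMemRefDescriptor and replaced by the
// resulting pointer, while non-memref operands are passed through unchanged.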
SmallVector<Value *, 4> LLVMTypeConverter::promoteMemRefDescriptors(
Location loc, ArrayRef<Value *> opOperands, ArrayRef<Value *> operands,
OpBuilder &builder) {
SmallVector<Value *, 4> promotedOperands;
promotedOperands.reserve(operands.size());
for (auto it : llvm::zip(opOperands, operands)) {
auto *operand = std::get<0>(it);
auto *llvmOperand = std::get<1>(it);
if (!operand->getType().isa<MemRefType>()) {
promotedOperands.push_back(operand);
continue;
}
promotedOperands.push_back(
promoteOneMemRefDescriptor(loc, llvmOperand, builder));
}
return promotedOperands;
}
/// Create an instance of LLVMTypeConverter in the given context.
static std::unique_ptr<LLVMTypeConverter>
makeStandardToLLVMTypeConverter(MLIRContext *context) {
return std::make_unique<LLVMTypeConverter>(context);
}
namespace {
/// A pass converting MLIR operations into the LLVM IR dialect.
struct LLVMLoweringPass : public ModulePass<LLVMLoweringPass> {
// By default, the patterns are those converting Standard operations to the
// LLVM IR dialect.
explicit LLVMLoweringPass(
LLVMPatternListFiller patternListFiller =
populateStdToLLVMConversionPatterns,
LLVMTypeConverterMaker converterBuilder = makeStandardToLLVMTypeConverter)
: patternListFiller(patternListFiller),
typeConverterMaker(converterBuilder) {}
// Run the dialect converter on the module.
void runOnModule() override {
if (!typeConverterMaker || !patternListFiller)
return signalPassFailure();
ModuleOp m = getModule();
LLVM::ensureDistinctSuccessors(m);
std::unique_ptr<LLVMTypeConverter> typeConverter =
typeConverterMaker(&getContext());
if (!typeConverter)
return signalPassFailure();
OwningRewritePatternList patterns;
populateLoopToStdConversionPatterns(patterns, m.getContext());
patternListFiller(*typeConverter, patterns);
ConversionTarget target(getContext());
target.addLegalDialect<LLVM::LLVMDialect>();
if (failed(applyPartialConversion(m, target, patterns, &*typeConverter)))
signalPassFailure();
}
// Callback for creating a list of patterns. It is invoked on every call to
// runOnModule since applyPartialConversion consumes the list.
LLVMPatternListFiller patternListFiller;
// Callback for creating an instance of the type converter. The converter
// constructor needs an MLIRContext, which is not available until runOnModule
// is called.
LLVMTypeConverterMaker typeConverterMaker;
};
} // end namespace
std::unique_ptr<OpPassBase<ModuleOp>> mlir::createLowerToLLVMPass() {
return std::make_unique<LLVMLoweringPass>();
}
std::unique_ptr<OpPassBase<ModuleOp>>
mlir::createLowerToLLVMPass(LLVMPatternListFiller patternListFiller,
LLVMTypeConverterMaker typeConverterMaker) {
return std::make_unique<LLVMLoweringPass>(patternListFiller,
typeConverterMaker);
}
static PassRegistration<LLVMLoweringPass>
pass("lower-to-llvm", "Convert all functions to the LLVM IR dialect");