blob: 84452a2ec2c211a65f08bd0694e14e80e013908d [file] [log] [blame]
//===- LowerToLLVMDialect.cpp - conversion from Linalg to LLVM dialect ----===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "mlir/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/EDSC/Builders.h"
#include "mlir/EDSC/Intrinsics.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Types.h"
#include "mlir/LLVMIR/LLVMDialect.h"
#include "mlir/Linalg/IR/LinalgOps.h"
#include "mlir/Linalg/IR/LinalgTypes.h"
#include "mlir/Linalg/Passes.h"
#include "mlir/Linalg/Utils/Intrinsics.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/LowerAffine.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/ErrorHandling.h"
using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace mlir::LLVM;
using namespace mlir::linalg;
using namespace mlir::linalg::intrinsics;
// Short aliases for the EDSC builders of the LLVM-dialect (and a few standard)
// operations emitted by the conversion patterns below. They are used while an
// edsc::ScopedContext is active, which provides the insertion point.
// `ValueBuilder` aliases are consumed as values (they convert to `Value *`);
// `OperationBuilder` aliases expose the created op through getOperation().
using add = ValueBuilder<mlir::LLVM::AddOp>;
using addi = ValueBuilder<mlir::AddIOp>;
using bitcast = ValueBuilder<mlir::LLVM::BitcastOp>;
using cmpi = ValueBuilder<mlir::CmpIOp>;
using constant = ValueBuilder<mlir::LLVM::ConstantOp>;
using extractvalue = ValueBuilder<mlir::LLVM::ExtractValueOp>;
using gep = ValueBuilder<mlir::LLVM::GEPOp>;
using insertvalue = ValueBuilder<mlir::LLVM::InsertValueOp>;
using llvm_call = OperationBuilder<mlir::LLVM::CallOp>;
using llvm_icmp = ValueBuilder<LLVM::ICmpOp>;
using llvm_load = ValueBuilder<LLVM::LoadOp>;
using llvm_store = OperationBuilder<LLVM::StoreOp>;
using llvm_select = ValueBuilder<LLVM::SelectOp>;
using mul = ValueBuilder<mlir::LLVM::MulOp>;
using sub = ValueBuilder<mlir::LLVM::SubOp>;
using undef = ValueBuilder<mlir::LLVM::UndefOp>;
using llvm_alloca = ValueBuilder<LLVM::AllocaOp>;
using llvm_return = OperationBuilder<LLVM::ReturnOp>;
/// Returns the LLVM dialect pointer type whose pointee is the element type of
/// `containerType` after conversion by `lowering`. `T` is any type exposing
/// getElementType() (BufferType and ViewType in this file).
template <typename T>
static LLVMType getPtrToElementType(T containerType,
                                    LLVMTypeConverter &lowering) {
  Type elementType = containerType.getElementType();
  LLVMType llvmElementType =
      lowering.convertType(elementType).cast<LLVMType>();
  return llvmElementType.getPointerTo();
}
// Convert a Linalg-specific type to the corresponding LLVM IR Dialect type.
// Only the Linalg types are handled here; standard types (index, integer,
// floating-point, ...) are left to LLVMTypeConverter. The following
// conversions are supported:
//   - a Buffer is converted into an LLVM structure holding a pointer to the
//     converted element type and the 64-bit number of elements;
//   - a Range is converted into an LLVM structure of three 64-bit integers
//     (min, max, step);
//   - a View is converted into an LLVM structure holding the element pointer,
//     a 64-bit base offset, and two arrays (sizes and strides) of as many
//     64-bit integers as the rank of the view.
// Any other type yields a null Type, signalling that no conversion applies.
static Type convertLinalgType(Type t, LLVMTypeConverter &lowering) {
  auto *context = t.getContext();
  auto int64Ty = lowering.convertType(IntegerType::get(64, context))
                     .cast<LLVM::LLVMType>();

  // A buffer descriptor contains the pointer to a flat region of storage and
  // the number of elements in that region.
  //
  // template <typename Elem>
  // struct {
  //   Elem *ptr;
  //   int64_t size;
  // };
  if (auto bufferType = t.dyn_cast<BufferType>()) {
    auto ptrTy = getPtrToElementType(bufferType, lowering);
    return LLVMType::getStructTy(ptrTy, int64Ty);
  }

  // Range descriptor contains the range bounds and the step as 64-bit integers.
  //
  // struct {
  //   int64_t min;
  //   int64_t max;
  //   int64_t step;
  // };
  if (t.isa<RangeType>())
    return LLVMType::getStructTy(int64Ty, int64Ty, int64Ty);

  // View descriptor contains the pointer to the data buffer, followed by a
  // 64-bit integer containing the distance between the beginning of the buffer
  // and the first element to be accessed through the view, followed by two
  // arrays, each containing as many 64-bit integers as the rank of the View.
  // The first array represents the size, in number of original elements, of the
  // view along the given dimension. When taking the view, the size is the
  // difference between the upper and the lower bound of the range. The second
  // array represents the "stride" (in tensor abstraction sense), i.e. the
  // number of consecutive elements of the underlying buffer that separate two
  // consecutive elements addressable through the view along the given
  // dimension. When taking the view, the strides are constructed as products
  // of the original sizes along the trailing dimensions, multiplied by the view
  // step. For example, a view of a MxN memref with ranges {0:M:1}, {0:N:1},
  // i.e. the view of a complete memref, will have strides N and 1. A view with
  // ranges {0:M:2}, {0:N:3} will have strides 2*N and 3.
  //
  // template <typename Elem, size_t Rank>
  // struct {
  //   Elem *ptr;
  //   int64_t offset;
  //   int64_t sizes[Rank];
  //   int64_t strides[Rank];
  // };
  if (auto viewType = t.dyn_cast<ViewType>()) {
    auto ptrTy = getPtrToElementType(viewType, lowering);
    auto arrayTy = LLVMType::getArrayTy(int64Ty, viewType.getRank());
    return LLVMType::getStructTy(ptrTy, int64Ty, arrayTy, arrayTy);
  }

  // Not a Linalg type: let the caller report that no conversion applies.
  return Type();
}
// Builds an ArrayAttr holding one 64-bit integer attribute per entry of
// `position`, preserving order. The result is used as the position of the
// LLVM extractvalue/insertvalue operations emitted in this file.
static ArrayAttr positionAttr(Builder &builder, ArrayRef<int> position) {
  SmallVector<Attribute, 4> elements;
  elements.reserve(position.size());
  for (size_t idx = 0, count = position.size(); idx < count; ++idx)
    elements.push_back(builder.getI64IntegerAttr(position[idx]));
  return builder.getArrayAttr(elements);
}
// BufferAllocOp creates a new `!linalg.buffer` value.
//
// Lowering:
//   - Add a `malloc` declaration (i64 -> i8*) to the enclosing module if it is
//     missing.
//   - Call it with `size * elementSize` bytes. `size` is the static buffer size
//     when the buffer type has one, and the first converted operand otherwise.
//   - Bitcast the returned pointer to a pointer to the converted element type.
//   - Pack it, together with `size` (an element count, not a byte count), into
//     an undef buffer descriptor {ptr, size}.
class BufferAllocOpConversion : public LLVMOpLowering {
public:
  explicit BufferAllocOpConversion(MLIRContext *context,
                                   LLVMTypeConverter &lowering_)
      : LLVMOpLowering(BufferAllocOp::getOperationName(), context, lowering_) {}
  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto indexType = IndexType::get(op->getContext());
    auto voidPtrTy =
        LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo();
    auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
    // Insert the `malloc` declaration if it is not already present.
    // NOTE(review): the declaration is pushed into the module directly, not
    // created through `rewriter`.
    auto module = op->getParentOfType<ModuleOp>();
    FuncOp mallocFunc = module.lookupSymbol<FuncOp>("malloc");
    if (!mallocFunc) {
      auto mallocType = rewriter.getFunctionType(int64Ty, voidPtrTy);
      mallocFunc =
          FuncOp::create(rewriter.getUnknownLoc(), "malloc", mallocType);
      module.push_back(mallocFunc);
    }
    // Get MLIR types for injecting element pointer.
    auto allocOp = cast<BufferAllocOp>(op);
    auto elementType = allocOp.getElementType();
    // Element size in bytes. Bit widths are rounded up to whole bytes; a vector
    // element takes (bytes per lane) * (number of lanes).
    // NOTE(review): the non-vector branch relies on getIntOrFloatBitWidth(),
    // which presumably only supports integer/float element types (e.g. not
    // index) — confirm which element types buffer_alloc admits.
    uint64_t elementSize = 0;
    if (auto vectorType = elementType.dyn_cast<VectorType>())
      elementSize = vectorType.getNumElements() *
                    llvm::divideCeil(vectorType.getElementTypeBitWidth(), 8);
    else
      elementSize = llvm::divideCeil(elementType.getIntOrFloatBitWidth(), 8);
    auto bufferType = allocOp.getResult()->getType().cast<BufferType>();
    auto elementPtrType = getPtrToElementType(bufferType, lowering);
    auto bufferDescriptorType =
        convertLinalgType(allocOp.getResult()->getType(), lowering);
    // Emit IR for creating a new buffer descriptor with an underlying malloc.
    edsc::ScopedContext context(rewriter, op->getLoc());
    // Number of elements: a constant when the buffer type carries a static
    // size, the dynamic size operand otherwise.
    auto constantSize = bufferType.getBufferSize();
    Value *size =
        constantSize
            ? constant(int64Ty, IntegerAttr::get(indexType, *constantSize))
                  .getValue()
            : operands[0];
    // Bytes to allocate = number of elements * element size in bytes.
    Value *allocSize =
        mul(size, constant(int64Ty, IntegerAttr::get(indexType, elementSize)));
    Value *allocated =
        llvm_call(voidPtrTy, rewriter.getSymbolRefAttr(mallocFunc), allocSize)
            .getOperation()
            ->getResult(0);
    allocated = bitcast(elementPtrType, allocated);
    // Descriptor field 0: typed data pointer; field 1: number of elements.
    Value *desc = undef(bufferDescriptorType);
    desc = insertvalue(bufferDescriptorType, desc, allocated,
                       positionAttr(rewriter, 0));
    desc = insertvalue(bufferDescriptorType, desc, size,
                       positionAttr(rewriter, 1));
    rewriter.replaceOp(op, desc);
    return matchSuccess();
  }
};
// BufferDeallocOp produces no value. It is lowered as follows:
//   - Read the data pointer (field 0) of the buffer descriptor.
//   - Bitcast it to i8*.
//   - Pass it to `free`, whose declaration (i8* -> ()) is added to the
//     enclosing module on first use.
class BufferDeallocOpConversion : public LLVMOpLowering {
public:
  explicit BufferDeallocOpConversion(MLIRContext *context,
                                     LLVMTypeConverter &lowering_)
      : LLVMOpLowering(BufferDeallocOp::getOperationName(), context,
                       lowering_) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto i8PtrTy =
        LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo();

    // Make sure `free` is declared in the enclosing module.
    ModuleOp parentModule = op->getParentOfType<ModuleOp>();
    FuncOp freeDecl = parentModule.lookupSymbol<FuncOp>("free");
    if (!freeDecl) {
      auto freeFnTy = rewriter.getFunctionType(i8PtrTy, {});
      freeDecl = FuncOp::create(rewriter.getUnknownLoc(), "free", freeFnTy);
      parentModule.push_back(freeDecl);
    }

    // Pointer-to-element type of the buffer being released.
    auto bufferTy =
        cast<BufferDeallocOp>(op).getOperand()->getType().cast<BufferType>();
    auto dataPtrTy = getPtrToElementType(bufferTy, lowering);

    edsc::ScopedContext scope(rewriter, op->getLoc());
    Value *dataPtr =
        extractvalue(dataPtrTy, operands[0], positionAttr(rewriter, 0));
    Value *rawPtr = bitcast(i8PtrTy, dataPtr);
    llvm_call(ArrayRef<Type>(), rewriter.getSymbolRefAttr(freeDecl), rawPtr);
    rewriter.replaceOp(op, llvm::None);
    return matchSuccess();
  }
};
// BufferSizeOp creates a new `index` value: the size field (position 1) of the
// buffer descriptor, read as a 64-bit integer.
class BufferSizeOpConversion : public LLVMOpLowering {
public:
  explicit BufferSizeOpConversion(MLIRContext *context,
                                  LLVMTypeConverter &lowering_)
      : LLVMOpLowering(BufferSizeOp::getOperationName(), context, lowering_) {}
  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
    edsc::ScopedContext context(rewriter, op->getLoc());
    rewriter.replaceOp(
        op, {extractvalue(int64Ty, operands[0], positionAttr(rewriter, 1))});
    return matchSuccess();
  }
};
// DimOp creates a new `index` value. The value is the entry for the requested
// dimension in the sizes array (field 2) of the view descriptor.
class DimOpConversion : public LLVMOpLowering {
public:
  explicit DimOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
      : LLVMOpLowering(linalg::DimOp::getOperationName(), context, lowering_) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    int dimIndex = static_cast<int>(cast<linalg::DimOp>(op).getIndex());
    auto llvmIndexTy = lowering.convertType(rewriter.getIndexType());
    ArrayAttr sizePosition = positionAttr(rewriter, {2, dimIndex});
    edsc::ScopedContext scope(rewriter, op->getLoc());
    Value *dimSize = extractvalue(llvmIndexTy, operands[0], sizePosition);
    rewriter.replaceOp(op, {dimSize});
    return matchSuccess();
  }
};
namespace {
// Shared base for the linalg.load and linalg.store lowerings. Both turn a view
// descriptor plus one index per dimension into a pointer to the addressed
// element of the underlying buffer.
template <typename Op> class LoadStoreOpConversion : public LLVMOpLowering {
public:
  explicit LoadStoreOpConversion(MLIRContext *context,
                                 LLVMTypeConverter &lowering_)
      : LLVMOpLowering(Op::getOperationName(), context, lowering_) {}
  using Base = LoadStoreOpConversion<Op>;

  // Returns getelementptr(ptr, offset + sum_i(indices[i] * strides[i])).
  // ptr (field 0), offset (field 1) and strides (field 3) are read from
  // `viewDescriptor`. Must be called while an edsc::ScopedContext is active.
  Value *obtainDataPtr(Operation *op, Value *viewDescriptor,
                       ArrayRef<Value *> indices,
                       ConversionPatternRewriter &rewriter) const {
    auto memOp = cast<Op>(op);
    auto elementPtrTy = getPtrToElementType(memOp.getViewType(), lowering);
    auto i64Ty = lowering.convertType(rewriter.getIntegerType(64));

    Value *dataPtr =
        extractvalue(elementPtrTy, viewDescriptor, positionAttr(rewriter, 0));
    Value *linearOffset =
        extractvalue(i64Ty, viewDescriptor, positionAttr(rewriter, 1));
    int rank = memOp.getRank();
    for (int dim = 0; dim < rank; ++dim) {
      Value *dimStride =
          extractvalue(i64Ty, viewDescriptor, positionAttr(rewriter, {3, dim}));
      linearOffset = add(linearOffset, mul(indices[dim], dimStride));
    }
    return gep(elementPtrTy, dataPtr, linearOffset);
  }
};
} // namespace
// linalg.load is lowered to the element address computation (getelementptr)
// followed by an LLVM load of the converted result type. The operands are the
// view descriptor followed by one index per dimension.
class LoadOpConversion : public LoadStoreOpConversion<linalg::LoadOp> {
  using Base::Base;
  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    edsc::ScopedContext edscContext(rewriter, op->getLoc());
    Value *view = operands.front();
    auto resultTy = lowering.convertType(op->getResult(0)->getType());
    Value *elementPtr =
        obtainDataPtr(op, view, operands.drop_front(), rewriter);
    Value *loaded = llvm_load(resultTy, elementPtr);
    rewriter.replaceOp(op, {loaded});
    return matchSuccess();
  }
};
// RangeOp creates a new range descriptor: an LLVM struct {min, max, step}
// filled, in order, from the op's three converted operands.
class RangeOpConversion : public LLVMOpLowering {
public:
  explicit RangeOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
      : LLVMOpLowering(RangeOp::getOperationName(), context, lowering_) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto descriptorTy =
        convertLinalgType(cast<RangeOp>(op).getResult()->getType(), lowering);
    edsc::ScopedContext scope(rewriter, op->getLoc());
    // Start from undef and insert min (0), max (1) and step (2) in turn.
    Value *rangeDesc = undef(descriptorTy);
    for (int field = 0; field < 3; ++field)
      rangeDesc = insertvalue(descriptorTy, rangeDesc, operands[field],
                              positionAttr(rewriter, field));
    rewriter.replaceOp(op, rangeDesc);
    return matchSuccess();
  }
};
// SliceOp creates a new view descriptor from a base view and one indexing per
// base dimension. Each indexing is either a range (the dimension is kept) or
// a scalar index (the dimension is projected away). The new descriptor:
//   - shares the data pointer of the base view;
//   - has offset = base_offset + sum_j(min_j * stride_j), where min_j is the
//     range lower bound or the scalar index of base dimension j;
//   - for the k-th range indexing, found at base dimension j, records
//     size[k] = max_j - min_j and stride[k] = stride_j * step_j.
class SliceOpConversion : public LLVMOpLowering {
public:
  explicit SliceOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
      : LLVMOpLowering(SliceOp::getOperationName(), context, lowering_) {}
  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto sliceOp = cast<SliceOp>(op);
    auto viewDescriptorTy = convertLinalgType(sliceOp.getViewType(), lowering);
    auto viewType = sliceOp.getBaseViewType();
    auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
    // Operand layout: operands[0] is the base view descriptor and
    // operands[1 + j] is the converted indexing for base dimension j.
    // Helper function to create an integer array attribute out of a list of
    // values.
    auto pos = [&rewriter](ArrayRef<int> values) {
      return positionAttr(rewriter, values);
    };
    // Helper function to obtain the ptr of the given `view`.
    auto getViewPtr = [pos, this](ViewType type, Value *view) -> Value * {
      auto elementPtrTy = getPtrToElementType(type, lowering);
      return extractvalue(elementPtrTy, view, pos(0));
    };
    edsc::ScopedContext context(rewriter, op->getLoc());
    // Declare the view descriptor and insert data ptr.
    Value *desc = undef(viewDescriptorTy);
    desc = insertvalue(viewDescriptorTy, desc,
                       getViewPtr(viewType, operands[0]), pos(0));
    // TODO(ntv): extract sizes and emit asserts.
    SmallVector<Value *, 4> strides(viewType.getRank());
    for (int dim = 0, e = viewType.getRank(); dim < e; ++dim) {
      strides[dim] = extractvalue(int64Ty, operands[0], pos({3, dim}));
    }
    // Compute and insert base offset.
    Value *baseOffset = extractvalue(int64Ty, operands[0], pos(1));
    for (int j = 0, e = viewType.getRank(); j < e; ++j) {
      Value *indexing = operands[1 + j];
      Value *min =
          sliceOp.getIndexing(j)->getType().isa<RangeType>()
              ? static_cast<Value *>(extractvalue(int64Ty, indexing, pos(0)))
              : indexing;
      Value *product = mul(min, strides[j]);
      baseOffset = add(baseOffset, product);
    }
    desc = insertvalue(viewDescriptorTy, desc, baseOffset, pos(1));
    // Compute and insert view sizes (max - min along the range). Skip the
    // non-range operands as they will be projected away from the view.
    // The range descriptor is the operand at the indexing's own position `j`.
    // `newDim` numbers the dimensions of the resulting view; the two differ as
    // soon as a scalar indexing precedes a range.
    int newDim = 0;
    for (int j = 0, e = strides.size(); j < e; ++j) {
      if (!sliceOp.getIndexing(j)->getType().isa<RangeType>())
        continue;
      Value *rangeDescriptor = operands[1 + j];
      Value *min = extractvalue(int64Ty, rangeDescriptor, pos(0));
      Value *max = extractvalue(int64Ty, rangeDescriptor, pos(1));
      Value *size = sub(max, min);
      desc = insertvalue(viewDescriptorTy, desc, size, pos({2, newDim}));
      ++newDim;
    }
    // Compute and insert view strides. Step over the strides that correspond
    // to non-range operands as they are projected away from the view.
    newDim = 0;
    for (int j = 0, e = strides.size(); j < e; ++j) {
      if (!sliceOp.getIndexing(j)->getType().isa<RangeType>())
        continue;
      Value *step = extractvalue(int64Ty, operands[1 + j], pos(2));
      Value *stride = mul(strides[j], step);
      desc = insertvalue(viewDescriptorTy, desc, stride, pos({3, newDim}));
      ++newDim;
    }
    rewriter.replaceOp(op, desc);
    return matchSuccess();
  }
};
// linalg.store is lowered to the element address computation (getelementptr)
// followed by an LLVM store. The operands are the stored value, the view
// descriptor, and then one index per dimension.
class StoreOpConversion : public LoadStoreOpConversion<linalg::StoreOp> {
  using Base::Base;
  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    edsc::ScopedContext edscContext(rewriter, op->getLoc());
    Value *storedValue = operands[0];
    Value *elementPtr =
        obtainDataPtr(op, operands[1], operands.drop_front(2), rewriter);
    llvm_store(storedValue, elementPtr);
    rewriter.replaceOp(op, llvm::None);
    return matchSuccess();
  }
};
// ViewOp creates a new view descriptor on top of a buffer descriptor
// (operand 0), with one range descriptor per view dimension (operands 1..n).
// The data pointer is copied from the buffer descriptor and the base offset
// is set to 0. Dimensions are processed from the last one to the first:
//   stride[i] = runningStride * step_i
//   size[i]   = max_i - min_i
// runningStride starts at 1 and is multiplied by max_i before moving on to
// dimension i - 1.
// NOTE(review): the base offset is always 0, and runningStride grows by `max`
// rather than `max - min`. So a non-zero `min` only shrinks the recorded size
// and does not shift the view's origin — confirm this is intended.
class ViewOpConversion : public LLVMOpLowering {
public:
  explicit ViewOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
      : LLVMOpLowering(ViewOp::getOperationName(), context, lowering_) {}
  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto viewOp = cast<ViewOp>(op);
    auto viewDescriptorTy = convertLinalgType(viewOp.getViewType(), lowering);
    auto elementTy = getPtrToElementType(viewOp.getViewType(), lowering);
    auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
    auto pos = [&rewriter](ArrayRef<int> values) {
      return positionAttr(rewriter, values);
    };
    // First operand to `view` is the buffer descriptor.
    Value *bufferDescriptor = operands[0];
    // Declare the descriptor of the view.
    edsc::ScopedContext context(rewriter, op->getLoc());
    Value *desc = undef(viewDescriptorTy);
    // Copy the buffer pointer from the old descriptor to the new one.
    Value *buffer = extractvalue(elementTy, bufferDescriptor, pos(0));
    desc = insertvalue(viewDescriptorTy, desc, buffer, pos(0));
    // Zero base offset.
    auto indexTy = rewriter.getIndexType();
    Value *baseOffset = constant(int64Ty, IntegerAttr::get(indexTy, 0));
    desc = insertvalue(viewDescriptorTy, desc, baseOffset, pos(1));
    // Compute and insert view sizes (max - min along the range).
    // Dimensions are visited from the last to the first so that runningStride
    // holds the product of the `max` bounds of all later dimensions.
    int numRanges = llvm::size(viewOp.ranges());
    Value *runningStride = constant(int64Ty, IntegerAttr::get(indexTy, 1));
    for (int i = numRanges - 1; i >= 0; --i) {
      // Update stride.
      Value *rangeDescriptor = operands[1 + i];
      Value *step = extractvalue(int64Ty, rangeDescriptor, pos(2));
      Value *stride = mul(runningStride, step);
      desc = insertvalue(viewDescriptorTy, desc, stride, pos({3, i}));
      // Update size.
      Value *min = extractvalue(int64Ty, rangeDescriptor, pos(0));
      Value *max = extractvalue(int64Ty, rangeDescriptor, pos(1));
      Value *size = sub(max, min);
      desc = insertvalue(viewDescriptorTy, desc, size, pos({2, i}));
      // Update stride for the next dimension.
      if (i > 0)
        runningStride = mul(runningStride, max);
    }
    rewriter.replaceOp(op, desc);
    return matchSuccess();
  }
};
// Returns the function named `<libFn name>_impl` from libFn's module. If it
// does not exist yet, a body-less FuncOp is created for it and appended to the
// module. That function takes, for each input of libFn, a pointer to that
// input's LLVM type, and returns nothing.
static FuncOp getLLVMLibraryCallImplDefinition(FuncOp libFn) {
  std::string implName = libFn.getName().str() + "_impl";
  ModuleOp parentModule = libFn.getParentOfType<ModuleOp>();
  if (FuncOp existing = parentModule.lookupSymbol<FuncOp>(implName))
    return existing;

  SmallVector<Type, 4> pointerArgTypes;
  for (Type inputTy : libFn.getType().getInputs()) {
    assert(inputTy && inputTy.isa<LLVMType>() &&
           "Expected LLVM Type for argument while generating library Call "
           "Implementation Definition");
    pointerArgTypes.push_back(inputTy.cast<LLVMType>().getPointerTo());
  }
  auto implType = FunctionType::get(pointerArgTypes, {}, libFn.getContext());
  FuncOp implFn = FuncOp::create(libFn.getLoc(), implName, implType);
  parentModule.push_back(implFn);
  return implFn;
}
// Returns the library function called in place of `op`, named by
// LinalgOp::getLibraryCallName(). If the module has no symbol of that name, a
// body-less FuncOp is appended to the module. Its inputs are the converted
// operand types and it returns nothing; its body is not created here. When
// the op has no library call name, a warning is emitted and a null FuncOp is
// returned. `rewriter` is currently unused.
template <typename LinalgOp>
static FuncOp
getLLVMLibraryCallDeclaration(Operation *op, LLVMTypeConverter &lowering,
                              ConversionPatternRewriter &rewriter) {
  auto libraryName = cast<LinalgOp>(op).getLibraryCallName();
  if (libraryName.empty()) {
    op->emitWarning("No library call defined for: ") << *op;
    return FuncOp();
  }

  ModuleOp parentModule = op->getParentOfType<ModuleOp>();
  if (FuncOp existing = parentModule.lookupSymbol<FuncOp>(libraryName))
    return existing;

  // The declaration's signature mirrors the LLVM-converted operand types.
  SmallVector<Type, 4> convertedInputTypes;
  for (Value *operand : op->getOperands())
    convertedInputTypes.push_back(lowering.convertType(operand->getType()));
  assert(op->getNumResults() == 0 &&
         "Library call for linalg operation can be generated only for ops that "
         "have void return types");
  auto declType = FunctionType::get(convertedInputTypes, {}, op->getContext());
  FuncOp decl = FuncOp::create(op->getLoc(), libraryName, declType);
  parentModule.push_back(decl);
  return decl;
}
// Generates the body of the library function `fn`, normally a body-less
// declaration created by getLLVMLibraryCallDeclaration. The generated body:
//   - spills every argument into its own single-element llvm.alloca slot;
//   - calls the `<fn name>_impl` function from getLLVMLibraryCallImplDefinition
//     with the resulting pointers;
//   - returns.
// If `fn` already has a body (for instance a user-provided definition with the
// same name, found by symbol lookup), it is left untouched. Otherwise it would
// receive a second entry block.
static void getLLVMLibraryCallDefinition(FuncOp fn,
                                         LLVMTypeConverter &lowering) {
  // Only functions without a body get a generated one.
  if (!fn.isExternal())
    return;

  // Generate the implementation function definition.
  auto implFn = getLLVMLibraryCallImplDefinition(fn);

  // Generate the function body.
  OpBuilder builder(fn.addEntryBlock());
  edsc::ScopedContext scope(builder, fn.getLoc());
  SmallVector<Value *, 4> implFnArgs;

  // Create a constant 1, used as the element count of every alloca.
  auto one = constant(LLVMType::getInt64Ty(lowering.getDialect()),
                      IntegerAttr::get(IndexType::get(fn.getContext()), 1));
  for (auto arg : fn.getArguments()) {
    // Allocate a stack slot holding the argument value. The slot's address is
    // what gets passed to the implementation function.
    auto alloca =
        llvm_alloca(arg->getType().cast<LLVMType>().getPointerTo(), one)
            .getValue();
    implFnArgs.push_back(alloca);
    llvm_store(arg, alloca);
  }
  llvm_call(ArrayRef<Type>(), builder.getSymbolRefAttr(implFn), implFnArgs);
  llvm_return{ArrayRef<Value *>()};
}
namespace {
// Type converter used by this pass.
//   - Types that the regular LLVMTypeConverter cannot handle fall back to
//     convertLinalgType (buffers, ranges, views).
//   - It also records the library function declarations created while
//     converting Linalg ops, so that their bodies can be emitted once
//     conversion is complete.
class LinalgTypeConverter : public LLVMTypeConverter {
  using LLVMTypeConverter::LLVMTypeConverter;

public:
  Type convertType(Type t) override {
    Type converted = LLVMTypeConverter::convertType(t);
    return converted ? converted : convertLinalgType(t, *this);
  }

  // Records `fn`; repeated insertions are ignored and insertion order is kept.
  void addLibraryFnDeclaration(FuncOp fn) { libraryFnDeclarations.insert(fn); }

  ArrayRef<FuncOp> getLibraryFnDeclarations() {
    return libraryFnDeclarations.getArrayRef();
  }

private:
  /// Library function declarations collected during dialect conversion.
  llvm::SetVector<FuncOp> libraryFnDeclarations;
};
} // end anonymous namespace
// LinalgOpConversion<LinalgOp> replaces the op with an LLVM call to the
// function named by `LinalgOp::getLibraryCallName()`. The converted operands
// are passed through unchanged.
// Only the callee declaration is created here. It is registered with the
// LinalgTypeConverter so its body can be generated after conversion. The
// implementation of the function can be either in the same module or in an
// externally linked library.
template <typename LinalgOp> class LinalgOpConversion : public LLVMOpLowering {
public:
  explicit LinalgOpConversion(MLIRContext *context,
                              LinalgTypeConverter &lowering_)
      : LLVMOpLowering(LinalgOp::getOperationName(), context, lowering_) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    FuncOp callee =
        getLLVMLibraryCallDeclaration<LinalgOp>(op, lowering, rewriter);
    if (!callee)
      return matchFailure();
    // The downcast is safe: the constructor only accepts a LinalgTypeConverter.
    auto &linalgConverter = static_cast<LinalgTypeConverter &>(lowering);
    linalgConverter.addLibraryFnDeclaration(callee);
    auto calleeAttr =
        rewriter.getNamedAttr("callee", rewriter.getSymbolRefAttr(callee));
    rewriter.replaceOpWithNewOp<LLVM::CallOp>(
        op, operands, ArrayRef<NamedAttribute>{calleeAttr});
    return matchSuccess();
  }
};
/// Registers every Linalg-to-LLVM lowering pattern of this file in `patterns`.
/// Each pattern is constructed from `ctx` and `converter`.
static void
populateLinalgToLLVMConversionPatterns(LinalgTypeConverter &converter,
                                       OwningRewritePatternList &patterns,
                                       MLIRContext *ctx) {
  // Buffer and descriptor query ops.
  patterns.insert<BufferAllocOpConversion, BufferDeallocOpConversion,
                  BufferSizeOpConversion, DimOpConversion>(ctx, converter);
  // Named ops lowered to library calls.
  patterns.insert<LinalgOpConversion<DotOp>, LinalgOpConversion<FillOp>,
                  LinalgOpConversion<MatmulOp>>(ctx, converter);
  // Element access and range/view construction.
  patterns.insert<LoadOpConversion, RangeOpConversion, SliceOpConversion,
                  StoreOpConversion, ViewOpConversion>(ctx, converter);
}
namespace {
// Module pass lowering Linalg, together with affine, loop and standard ops, to
// the LLVM IR dialect. The lowering is implemented in runOnModule.
struct LowerLinalgToLLVMPass : public ModulePass<LowerLinalgToLLVMPass> {
  void runOnModule();
};
} // namespace
// This is currently written as a standalone function because the lowering to
// affine will look different than lowering to LLVM and it is still unclear how
// everything will be eventually structured.
//
// Rewrites every SubViewOp in `f` into ops inserted right before it:
//   - for each dimension d, a range whose upper bound is clamped to the view
//     size: range(min_d, select(dim(view, d) < max_d, dim(view, d), max_d),
//     step_d);
//   - a slice of the view with those ranges.
// All uses of the subview are redirected to the slice, and the subview is
// erased.
static void lowerLinalgSubViewOps(FuncOp &f) {
  f.walk<SubViewOp>([&](SubViewOp op) {
    OpBuilder b(op);
    ScopedContext scope(b, op.getLoc());
    auto *view = op.getView();
    SmallVector<Value *, 8> ranges;
    for (auto en : llvm::enumerate(op.getRanges())) {
      using edsc::op::operator<;
      using linalg::intrinsics::dim;
      // Despite its name, `rank` is the index of the current dimension.
      unsigned rank = en.index();
      auto sliceRange = en.value();
      auto size = dim(view, rank);
      ValueHandle ub(sliceRange.max);
      // Clamp the upper bound to the size of the view along this dimension.
      auto max = edsc::intrinsics::select(size < ub, size, ub);
      ranges.push_back(range(sliceRange.min, max, sliceRange.step));
    }
    op.replaceAllUsesWith(slice(view, ranges));
    op.erase();
  });
}
// Lowers the module in three steps:
//   1. Rewrite subview ops into range/slice form.
//   2. Convert affine, loop, standard and Linalg ops to the LLVM dialect in a
//      single partial conversion.
//   3. Generate bodies for the library functions declared during step 2.
// If the conversion fails, the pass is marked as failed and step 3 is skipped.
void LowerLinalgToLLVMPass::runOnModule() {
  auto module = getModule();
  for (auto f : module.getOps<FuncOp>())
    lowerLinalgSubViewOps(f);

  // Convert to the LLVM IR dialect using the converter defined above.
  OwningRewritePatternList patterns;
  LinalgTypeConverter converter(&getContext());
  populateAffineToStdConversionPatterns(patterns, &getContext());
  populateLoopToStdConversionPatterns(patterns, &getContext());
  populateStdToLLVMConversionPatterns(converter, patterns);
  populateLinalgToLLVMConversionPatterns(converter, patterns, &getContext());

  ConversionTarget target(getContext());
  target.addLegalDialect<LLVM::LLVMDialect>();
  target.addDynamicallyLegalOp<FuncOp>(
      [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
  if (failed(applyPartialConversion(module, target, patterns, &converter))) {
    // Do not keep mutating the module once the pass has failed.
    signalPassFailure();
    return;
  }

  // Emit the function body of any Library function that was declared.
  for (auto fn : converter.getLibraryFnDeclarations())
    getLLVMLibraryCallDefinition(fn, converter);
}
/// Creates the Linalg-to-LLVM lowering pass. The pass is heap-allocated with
/// `new` and returned as a raw pointer; the caller takes ownership.
ModulePassBase *mlir::linalg::createLowerLinalgToLLVMPass() {
  auto *pass = new LowerLinalgToLLVMPass();
  return pass;
}
// Registers LowerLinalgToLLVMPass with the pass registry under the argument
// `linalg-lower-to-llvm-dialect`.
static PassRegistration<LowerLinalgToLLVMPass>
    pass("linalg-lower-to-llvm-dialect",
         "Lower the operations from the linalg dialect into the LLVM dialect");