//====- LateLowering.cpp - Lowering from Toy+Linalg to LLVM -===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements late lowering of IR mixing Toy and Linalg to LLVM.
// It involves intermediate steps:
// - Toy operations are first lowered to loop nests operating on memrefs,
// - yielding a mix of affine and standard dialect.
//
//===----------------------------------------------------------------------===//
#include "toy/Dialect.h"
#include "linalg1/Intrinsics.h"
#include "linalg1/ViewOp.h"
#include "linalg3/ConvertToLLVMDialect.h"
#include "linalg3/TensorOps.h"
#include "linalg3/Transforms.h"
#include "mlir/EDSC/Builders.h"
#include "mlir/EDSC/Helpers.h"
#include "mlir/EDSC/Intrinsics.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/LLVMIR/LLVMDialect.h"
#include "mlir/Parser.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Type.h"
#include <algorithm>
using namespace mlir;
namespace {
/// Reconcile the type of `val` with `destTy` while postponing the real
/// conversion work. If `val` already has type `destTy` it is returned as-is;
/// otherwise a `toy::TypeCastOp` is created at the location of `val` and its
/// result is returned. The intent is that, once producer and consumer are both
/// lowered to a dialect like LLVM and share the same representation, the cast
/// is a no-op and gets eliminated.
Value *typeCast(FuncBuilder &builder, Value *val, Type destTy) {
  const bool alreadyDestTy = (val->getType() == destTy);
  if (!alreadyDestTy) {
    auto castOp = builder.create<toy::TypeCastOp>(val->getLoc(), val, destTy);
    return castOp.getResult();
  }
  return val;
}
/// View a Toy array value as a memref by inserting a type cast. Values that
/// are already memrefs, or that are not Toy arrays at all, are returned
/// unchanged. The Toy array is lowered to a memref during buffer allocation,
/// at which point the cast inserted here becomes useless.
Value *memRefTypeCast(FuncBuilder &builder, Value *val) {
  Type valueTy = val->getType();
  if (valueTy.isa<MemRefType>())
    return val;
  if (auto arrayTy = valueTy.dyn_cast<toy::ToyArrayType>())
    return typeCast(builder, val, arrayTy.toMemref());
  return val;
}
/// Lower a toy.add to an affine loop nest.
///
/// This class inherit from `DialectConversionPattern` and override `rewrite`,
/// similarly to the PatternRewriter introduced in the previous chapter.
/// It will be called by the DialectConversion framework (see `LateLowering`
/// class below).
class AddOpConversion : public DialectConversionPattern {
public:
explicit AddOpConversion(MLIRContext *context)
: DialectConversionPattern(toy::AddOp::getOperationName(), 1, context) {}
/// Lower the `op` by generating IR using the `rewriter` builder. The builder
/// is setup with a new function, the `operands` array has been populated with
/// the rewritten operands for `op` in the new function.
/// The results created by the new IR with the builder are returned, and their
/// number must match the number of result of `op`.
void rewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
auto add = cast<toy::AddOp>(op);
auto loc = add.getLoc();
// Create a `toy.alloc` operation to allocate the output buffer for this op.
Value *result = memRefTypeCast(
rewriter, rewriter.create<toy::AllocOp>(loc, add.getResult()->getType())
.getResult());
Value *lhs = memRefTypeCast(rewriter, operands[0]);
Value *rhs = memRefTypeCast(rewriter, operands[1]);
using namespace edsc;
ScopedContext scope(rewriter, loc);
ValueHandle zero = intrinsics::constant_index(0);
MemRefView vRes(result), vLHS(lhs), vRHS(rhs);
IndexedValue iRes(result), iLHS(lhs), iRHS(rhs);
IndexHandle i, j, M(vRes.ub(0));
if (vRes.rank() == 1) {
LoopNestBuilder({&i}, {zero}, {M},
{1})([&] { iRes(i) = iLHS(i) + iRHS(i); });
} else {
assert(vRes.rank() == 2 && "only rank 1 and 2 are supported right now");
IndexHandle N(vRes.ub(1));
LoopNestBuilder({&i, &j}, {zero, zero}, {M, N},
{1, 1})([&] { iRes(i, j) = iLHS(i, j) + iRHS(i, j); });
}
// Return the newly allocated buffer, with a type.cast to preserve the
// consumers.
rewriter.replaceOp(op, {typeCast(rewriter, result, add.getType())});
}
};
/// Lowers `toy.print` to a loop nest calling `printf` on every individual
/// element of the array.
///
/// A rank-1 operand is printed as a single line. Any other rank takes the
/// two-loop path, which prints one line per index of the first dimension.
/// Every element is emitted with the "%f " format and every line ends with
/// "\n". The `toy.print` op is replaced with no values.
class PrintOpConversion : public DialectConversionPattern {
public:
explicit PrintOpConversion(MLIRContext *context)
: DialectConversionPattern(toy::PrintOp::getOperationName(), 1, context) {
}
/// Replace `op` with IR built through `rewriter`. `operands[0]` is the
/// already-rewritten array to print.
void rewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
// Get or create the declaration of the printf function in the module.
Function *printfFunc = getPrintf(*op->getFunction()->getModule());
auto print = cast<toy::PrintOp>(op);
auto loc = print.getLoc();
// We will operate on a MemRef abstraction, we use a type.cast to get one
// if our operand is still a Toy array.
Value *operand = memRefTypeCast(rewriter, operands[0]);
// Result type of every emitted call: the first result of printf's type.
Type retTy = printfFunc->getType().getResult(0);
// Create our loop nest now
using namespace edsc;
using llvmCall = intrinsics::ValueBuilder<LLVM::CallOp>;
ScopedContext scope(rewriter, loc);
ValueHandle zero = intrinsics::constant_index(0);
// Buffer holding the per-element format string.
ValueHandle fmtCst(getConstantCharBuffer(rewriter, loc, "%f "));
MemRefView vOp(operand);
IndexedValue iOp(operand);
IndexHandle i, j, M(vOp.ub(0));
// Buffer holding the end-of-line string.
ValueHandle fmtEol(getConstantCharBuffer(rewriter, loc, "\n"));
if (vOp.rank() == 1) {
// Rank 1: print every element, then a single end-of-line after the loop.
// clang-format off
LoopBuilder(&i, zero, M, 1)([&]{
llvmCall(retTy,
rewriter.getFunctionAttr(printfFunc),
{fmtCst, iOp(i)});
});
llvmCall(retTy, rewriter.getFunctionAttr(printfFunc), {fmtEol});
// clang-format on
} else {
// NOTE(review): every rank other than 1 lands here and is indexed with two
// indices; unlike AddOpConversion there is no assert that the rank is 2.
IndexHandle N(vOp.ub(1));
// One printed line per `i`: the inner loop prints the elements, then the
// end-of-line is emitted once per outer iteration.
// clang-format off
LoopBuilder(&i, zero, M, 1)([&]{
LoopBuilder(&j, zero, N, 1)([&]{
llvmCall(retTy,
rewriter.getFunctionAttr(printfFunc),
{fmtCst, iOp(i, j)});
});
llvmCall(retTy, rewriter.getFunctionAttr(printfFunc), {fmtEol});
});
// clang-format on
}
// toy.print has nothing to replace its uses with.
rewriter.replaceOp(op, llvm::None);
}
private:
// Turn a string into a toy.alloc (malloc/free abstraction) and a sequence
// of stores into the buffer, and return a MemRef into the buffer.
// The buffer has `data.size() + 1` elements of 8-bit integer type: one store
// per character of `data`, followed by a store of 0 as terminator.
Value *getConstantCharBuffer(FuncBuilder &builder, Location loc,
StringRef data) const {
auto retTy =
builder.getMemRefType(data.size() + 1, builder.getIntegerType(8));
Value *result = builder.create<toy::AllocOp>(loc, retTy).getResult();
using namespace edsc;
using intrinsics::constant_index;
using intrinsics::constant_int;
ScopedContext scope(builder, loc);
// NOTE(review): `vOp` is never read below; confirm whether constructing it
// has side effects before removing it.
MemRefView vOp(result);
IndexedValue iOp(result);
for (uint64_t i = 0; i < data.size(); ++i) {
iOp(constant_index(i)) = constant_int(data[i], 8);
}
// Terminating zero stored right after the last character.
iOp(constant_index(data.size())) = constant_int(0, 8);
return result;
}
/// Return the prototype declaration for printf in the module, create it if
/// necessary.
Function *getPrintf(Module &module) const {
// Reuse an existing function named "printf" when the module has one.
auto *printfFunc = module.getNamedFunction("printf");
if (printfFunc)
return printfFunc;
// Create a function declaration for printf, signature is `i32 (i8*, ...)`
Builder builder(&module);
MLIRContext *context = module.getContext();
auto *llvmDialect =
module.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
auto &llvmModule = llvmDialect->getLLVMModule();
// The LLVM IRBuilder is only used here to obtain the LLVM integer types.
llvm::IRBuilder<> llvmBuilder(llvmModule.getContext());
auto llvmI32Ty = LLVM::LLVMType::get(context, llvmBuilder.getIntNTy(32));
auto llvmI8PtrTy =
LLVM::LLVMType::get(context, llvmBuilder.getIntNTy(8)->getPointerTo());
// Function type: one `i8*` parameter, one `i32` result.
auto printfTy = builder.getFunctionType({llvmI8PtrTy}, {llvmI32Ty});
// The new function is appended to the module's function list below;
// presumably the module takes ownership of it -- TODO confirm.
printfFunc = new Function(builder.getUnknownLoc(), "printf", printfTy);
// It should be variadic, but we don't support it fully just yet.
printfFunc->setAttr("std.varargs", builder.getBoolAttr(true));
module.getFunctions().push_back(printfFunc);
return printfFunc;
}
};
/// Lowers constant to a sequence of store in a buffer.
class ConstantOpConversion : public DialectConversionPattern {
public:
explicit ConstantOpConversion(MLIRContext *context)
: DialectConversionPattern(toy::ConstantOp::getOperationName(), 1,
context) {}
void rewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
toy::ConstantOp cstOp = cast<toy::ConstantOp>(op);
auto loc = cstOp.getLoc();
auto retTy = cstOp.getResult()->getType().cast<toy::ToyArrayType>();
auto shape = retTy.getShape();
Value *result = memRefTypeCast(
rewriter, rewriter.create<toy::AllocOp>(loc, retTy).getResult());
auto cstValue = cstOp.getValue();
auto f64Ty = rewriter.getF64Type();
using namespace edsc;
using intrinsics::constant_float;
using intrinsics::constant_index;
ScopedContext scope(rewriter, loc);
MemRefView vOp(result);
IndexedValue iOp(result);
for (uint64_t i = 0, ie = shape[0]; i < ie; ++i) {
if (shape.size() == 1) {
auto value = cstValue.getValue(ArrayRef<uint64_t>{i})
.cast<FloatAttr>()
.getValue();
iOp(constant_index(i)) = constant_float(value, f64Ty);
continue;
}
for (uint64_t j = 0, je = shape[1]; j < je; ++j) {
auto value = cstValue.getValue(ArrayRef<uint64_t>{i, j})
.cast<FloatAttr>()
.getValue();
iOp(constant_index(i), constant_index(j)) =
constant_float(value, f64Ty);
}
}
rewriter.replaceOp(op, result);
}
};
/// Lower transpose operation to an affine loop nest.
class TransposeOpConversion : public DialectConversionPattern {
public:
explicit TransposeOpConversion(MLIRContext *context)
: DialectConversionPattern(toy::TransposeOp::getOperationName(), 1,
context) {}
void rewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
auto transpose = cast<toy::TransposeOp>(op);
auto loc = transpose.getLoc();
Value *result = memRefTypeCast(
rewriter,
rewriter.create<toy::AllocOp>(loc, transpose.getResult()->getType())
.getResult());
Value *operand = memRefTypeCast(rewriter, operands[0]);
using namespace edsc;
ScopedContext scope(rewriter, loc);
ValueHandle zero = intrinsics::constant_index(0);
MemRefView vRes(result), vOperand(operand);
IndexedValue iRes(result), iOperand(operand);
IndexHandle i, j, M(vRes.ub(0)), N(vRes.ub(1));
// clang-format off
LoopNestBuilder({&i, &j}, {zero, zero}, {M, N}, {1, 1})([&]{
iRes(i, j) = iOperand(j, i);
});
// clang-format on
rewriter.replaceOp(op, {typeCast(rewriter, result, transpose.getType())});
}
};
/// Conversion pattern lowering `toy.return` to the standard `ReturnOp`.
class ReturnOpConversion : public DialectConversionPattern {
public:
  explicit ReturnOpConversion(MLIRContext *context)
      : DialectConversionPattern(toy::ReturnOp::getOperationName(), 1,
                                 context) {}

  /// Replace `op` with a `ReturnOp`. The returned value is optional: when
  /// present, the rewritten `operands[0]` is forwarded.
  void rewrite(Operation *op, ArrayRef<Value *> operands,
               PatternRewriter &rewriter) const override {
    // No operand on the original op: emit an operand-less return.
    if (op->getNumOperands() == 0) {
      rewriter.replaceOpWithNewOp<ReturnOp>(op);
      return;
    }
    rewriter.replaceOpWithNewOp<ReturnOp>(op, operands[0]);
  }
};
/// This is the main class registering our individual converter classes with
/// the DialectConversion framework in MLIR.
class LateLowering : public DialectConversion {
protected:
/// Initialize the list of converters.
void initConverters(OwningRewritePatternList &patterns,
MLIRContext *context) override {
RewriteListBuilder<AddOpConversion, PrintOpConversion, ConstantOpConversion,
TransposeOpConversion,
ReturnOpConversion>::build(patterns, context);
}
/// Convert a Toy type, this gets called for block and region arguments, and
/// attributes.
Type convertType(Type t) override {
if (auto array = t.dyn_cast<toy::ToyArrayType>())
return array.toMemref();
return t;
}
};
/// Module pass driving the late lowering of Toy down to the LLVM dialect.
///
/// The pass performs, in order:
///  1. the Toy-specific dialect conversion implemented by `LateLowering`;
///  2. retyping of `CallOp` results that are still Toy arrays into memrefs;
///  3. replacement of every `toy.alloc` by an `AllocOp` with matching
///     `DeallocOp`s, and removal of every Toy type cast;
///  4. lowering of Linalg to loops, then conversion to the LLVM dialect.
struct LateLoweringPass : public ModulePass<LateLoweringPass> {
  void runOnModule() override {
    // Perform Toy specific lowering.
    if (failed(LateLowering().convert(&getModule()))) {
      getModule().getContext()->emitError(
          UnknownLoc::get(getModule().getContext()), "Error lowering Toy\n");
      signalPassFailure();
      // The conversion did not succeed: stop here instead of running the
      // remaining steps on a module in an unknown state.
      return;
    }

    // At this point the IR is almost using only standard and affine dialects.
    // A few things remain before we emit LLVM IR. First to reuse as much of
    // MLIR as possible we will try to lower everything to the standard and/or
    // affine dialect: they already include conversion to the LLVM dialect.

    // First patch calls type to return memref instead of ToyArray. Only the
    // first result of each call is inspected.
    for (auto &function : getModule()) {
      function.walk([&](Operation *op) {
        auto callOp = dyn_cast<CallOp>(op);
        if (!callOp)
          return;
        if (!callOp.getNumResults())
          return;
        auto retToyTy =
            callOp.getResult(0)->getType().dyn_cast<toy::ToyArrayType>();
        if (!retToyTy)
          return;
        callOp.getResult(0)->setType(retToyTy.toMemref());
      });
    }

    for (auto &function : getModule()) {
      function.walk([&](Operation *op) {
        // Turns toy.alloc into sequence of alloc/dealloc (later malloc/free).
        if (auto allocOp = dyn_cast<toy::AllocOp>(op)) {
          auto result = allocTensor(allocOp);
          allocOp.replaceAllUsesWith(result);
          allocOp.erase();
          return;
        }
        // Eliminate all type.cast before lowering to LLVM: users are rewired
        // directly to the cast's operand.
        if (auto typeCastOp = dyn_cast<toy::TypeCastOp>(op)) {
          typeCastOp.replaceAllUsesWith(typeCastOp.getOperand());
          typeCastOp.erase();
          return;
        }
      });
    }

    // Lower Linalg to affine.
    for (auto &function : getModule())
      linalg::lowerToLoops(&function);

    // NOTE(review): unconditional dump of the module; looks like a debugging
    // leftover. Kept because removing it changes the pass's visible output.
    getModule().dump();

    // Finally convert to LLVM Dialect.
    linalg::convertLinalg3ToLLVM(getModule());
  }

  /// Allocate buffers (malloc/free) for Toy operations. This can't be done as
  /// part of dialect conversion framework since we need to insert `dealloc`
  /// operations just before the return, but the conversion framework is
  /// operating in a brand new function: we don't have the return to hook the
  /// dealloc operations.
  ///
  /// Returns the result of the `AllocOp` created in place of `alloc`. The
  /// caller is responsible for rewiring the uses of `alloc` and erasing it.
  /// An op whose result is neither a MemRef nor a Toy array is reported with
  /// an op error before aborting.
  Value *allocTensor(toy::AllocOp alloc) {
    FuncBuilder builder(alloc);
    auto retTy = alloc.getResult()->getType();
    auto memRefTy = retTy.dyn_cast<MemRefType>();
    if (!memRefTy) {
      // Use dyn_cast rather than cast so that an unexpected type reaches the
      // diagnostic below instead of tripping the assertion inside cast<>.
      if (auto toyArrayTy = retTy.dyn_cast<toy::ToyArrayType>())
        memRefTy = toyArrayTy.toMemref();
    }
    if (!memRefTy) {
      alloc.emitOpError("is expected to allocate a Toy array or a MemRef");
      llvm_unreachable("fatal error");
    }

    auto loc = alloc.getLoc();
    Value *result = builder.create<AllocOp>(loc, memRefTy).getResult();

    // Insert a `dealloc` operation right before the `return` operations,
    // unless it is returned itself in which case the caller is responsible for
    // it.
    builder.getFunction()->walk([&](Operation *op) {
      auto returnOp = dyn_cast<ReturnOp>(op);
      if (!returnOp)
        return;
      if (returnOp.getNumOperands() && returnOp.getOperand(0) == alloc)
        return;
      builder.setInsertionPoint(returnOp);
      builder.create<DeallocOp>(loc, result);
    });
    return result;
  }
};
} // end anonymous namespace
namespace toy {
Pass *createLateLoweringPass() { return new LateLoweringPass(); }
std::unique_ptr<DialectConversion> makeToyLateLowering() {
return llvm::make_unique<LateLowering>();
}
} // namespace toy