blob: b2d9486996c25e843098e110b52222db3fb1fe04 [file] [log] [blame]
//===- LowerIfAndFor.cpp - Lower If and For instructions to CFG -----------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file lowers If and For instructions within a function into their lower
// level CFG equivalent blocks.
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass.h"
#include "mlir/StandardOps/StandardOps.h"
#include "mlir/Transforms/Passes.h"
using namespace mlir;
namespace {
class LowerIfAndForPass : public FunctionPass {
public:
LowerIfAndForPass() : FunctionPass(&passID) {}
PassResult runOnFunction(Function *function) override;
void lowerForInst(ForInst *forInst);
void lowerIfInst(IfInst *ifInst);
static char passID;
};
} // end anonymous namespace
char LowerIfAndForPass::passID = 0;
// Given a range of values, emit the code that reduces them with "min" or "max"
// depending on the provided comparison predicate. The predicate defines which
// comparison to perform, "lt" for "min", "gt" for "max" and is used for the
// `cmpi` operation followed by the `select` operation:
//
// %cond = cmpi "predicate" %v0, %v1
// %result = select %cond, %v0, %v1
//
// Multiple values are scanned in a linear sequence. This creates a data
// dependences that wouldn't exist in a tree reduction, but is easier to
// recognize as a reduction by the subsequent passes.
static Value *buildMinMaxReductionSeq(
Location loc, CmpIPredicate predicate,
llvm::iterator_range<OperationInst::result_iterator> values,
FuncBuilder &builder) {
assert(!llvm::empty(values) && "empty min/max chain");
auto valueIt = values.begin();
Value *value = *valueIt++;
for (; valueIt != values.end(); ++valueIt) {
auto cmpOp = builder.create<CmpIOp>(loc, predicate, value, *valueIt);
value = builder.create<SelectOp>(loc, cmpOp->getResult(), value, *valueIt);
}
return value;
}
// Convert a "for" loop to a flow of blocks.
//
// Create an SESE region for the loop (including its body) and append it to the
// end of the current region. The loop region consists of the initialization
// block that sets up the initial value of the loop induction variable (%iv) and
// computes the loop bounds that are loop-invariant in MLFunctions; the
// condition block that checks the exit condition of the loop; the body SESE
// region; and the end block that post-dominates the loop. The end block of the
// loop becomes the new end of the current SESE region. The body of the loop is
// constructed recursively after starting a new region (it may be, for example,
// a nested loop). Induction variable modification is appended to the body SESE
// region that always loops back to the condition block.
//
// +--------------------------------+
// | <code before the ForInst> |
// | <compute initial %iv value> |
// | br cond(%iv) |
// +--------------------------------+
// |
// -------| |
// | v v
// | +--------------------------------+
// | | cond(%iv): |
// | | <compare %iv to upper bound> |
// | | cond_br %r, body, end |
// | +--------------------------------+
// | | |
// | | -------------|
// | v |
// | +--------------------------------+ |
// | | body: | |
// | | <body contents> | |
// | | %new_iv =<add step to %iv> | |
// | | br cond(%new_iv) | |
// | +--------------------------------+ |
// | | |
// |----------- |--------------------
// v
// +--------------------------------+
// | end: |
// | <code after the ForInst> |
// +--------------------------------+
//
void LowerIfAndForPass::lowerForInst(ForInst *forInst) {
auto loc = forInst->getLoc();
// Start by splitting the block containing the 'for' into two parts. The part
// before will get the init code, the part after will be the end point.
auto *initBlock = forInst->getBlock();
auto *endBlock = initBlock->splitBlock(forInst);
// Create the condition block, with its argument for the loop induction
// variable. We set it up below.
auto *conditionBlock = new Block();
conditionBlock->insertBefore(endBlock);
auto *iv = conditionBlock->addArgument(IndexType::get(forInst->getContext()));
// Create the body block, moving the body of the forInst over to it.
auto *bodyBlock = new Block();
bodyBlock->insertBefore(endBlock);
auto *oldBody = forInst->getBody();
bodyBlock->getInstructions().splice(bodyBlock->begin(),
oldBody->getInstructions(),
oldBody->begin(), oldBody->end());
// The code in the body of the forInst now uses 'iv' as its indvar.
forInst->replaceAllUsesWith(iv);
// Append the induction variable stepping logic and branch back to the exit
// condition block. Construct an affine map f : (x -> x+step) and apply this
// map to the induction variable.
FuncBuilder builder(bodyBlock);
auto affStep = builder.getAffineConstantExpr(forInst->getStep());
auto affDim = builder.getAffineDimExpr(0);
auto affStepMap = builder.getAffineMap(1, 0, {affDim + affStep}, {});
auto stepOp = builder.create<AffineApplyOp>(loc, affStepMap, iv);
builder.create<BranchOp>(loc, conditionBlock, stepOp->getResult(0));
// Now that the body block done, fill in the code to compute the bounds of the
// induction variable in the init block.
builder.setInsertionPointToEnd(initBlock);
// Compute loop bounds using an affine_apply.
SmallVector<Value *, 8> operands(forInst->getLowerBoundOperands());
auto lbAffineApply =
builder.create<AffineApplyOp>(loc, forInst->getLowerBoundMap(), operands);
Value *lowerBound = buildMinMaxReductionSeq(
loc, CmpIPredicate::SGT, lbAffineApply->getResults(), builder);
operands.assign(forInst->getUpperBoundOperands().begin(),
forInst->getUpperBoundOperands().end());
auto ubAffineApply =
builder.create<AffineApplyOp>(loc, forInst->getUpperBoundMap(), operands);
Value *upperBound = buildMinMaxReductionSeq(
loc, CmpIPredicate::SLT, ubAffineApply->getResults(), builder);
builder.create<BranchOp>(loc, conditionBlock, lowerBound);
// With the body block done, we can fill in the condition block.
builder.setInsertionPointToEnd(conditionBlock);
auto comparison =
builder.create<CmpIOp>(loc, CmpIPredicate::SLT, iv, upperBound);
builder.create<CondBranchOp>(loc, comparison, bodyBlock, ArrayRef<Value *>(),
endBlock, ArrayRef<Value *>());
// Ok, we're done!
forInst->erase();
}
// Convert an "if" instruction into a flow of basic blocks.
//
// Create an SESE region for the if instruction (including its "then" and
// optional "else" instruction blocks) and append it to the end of the current
// region. The conditional region consists of a sequence of condition-checking
// blocks that implement the short-circuit scheme, followed by a "then" SESE
// region and an "else" SESE region, and the continuation block that
// post-dominates all blocks of the "if" instruction. The flow of blocks that
// correspond to the "then" and "else" clauses are constructed recursively,
// enabling easy nesting of "if" instructions and if-then-else-if chains.
//
// +--------------------------------+
// | <code before the IfInst> |
// | %zero = constant 0 : index |
// | %v = affine_apply #expr1(%ops) |
// | %c = cmpi "sge" %v, %zero |
// | cond_br %c, %next, %else |
// +--------------------------------+
// | |
// | --------------|
// v |
// +--------------------------------+ |
// | next: | |
// | <repeat the check for expr2> | |
// | cond_br %c, %next2, %else | |
// +--------------------------------+ |
// | | |
// ... --------------|
// | <Per-expression checks> |
// v |
// +--------------------------------+ |
// | last: | |
// | <repeat the check for exprN> | |
// | cond_br %c, %then, %else | |
// +--------------------------------+ |
// | | |
// | --------------|
// v |
// +--------------------------------+ |
// | then: | |
// | <then contents> | |
// | br continue | |
// +--------------------------------+ |
// | |
// |---------- |-------------
// | V
// | +--------------------------------+
// | | else: |
// | | <else contents> |
// | | br continue |
// | +--------------------------------+
// | |
// ------| |
// v v
// +--------------------------------+
// | continue: |
// | <code after the IfInst> |
// +--------------------------------+
//
void LowerIfAndForPass::lowerIfInst(IfInst *ifInst) {
auto loc = ifInst->getLoc();
// Start by splitting the block containing the 'if' into two parts. The part
// before will contain the condition, the part after will be the continuation
// point.
auto *condBlock = ifInst->getBlock();
auto *continueBlock = condBlock->splitBlock(ifInst);
// Create a block for the 'then' code, inserting it between the cond and
// continue blocks. Move the instructions over from the IfInst and add a
// branch to the continuation point.
Block *thenBlock = new Block();
thenBlock->insertBefore(continueBlock);
auto *oldThen = ifInst->getThen();
thenBlock->getInstructions().splice(thenBlock->begin(),
oldThen->getInstructions(),
oldThen->begin(), oldThen->end());
FuncBuilder builder(thenBlock);
builder.create<BranchOp>(loc, continueBlock);
// Handle the 'else' block the same way, but we skip it if we have no else
// code.
Block *elseBlock = continueBlock;
if (auto *oldElse = ifInst->getElse()) {
elseBlock = new Block();
elseBlock->insertBefore(continueBlock);
elseBlock->getInstructions().splice(elseBlock->begin(),
oldElse->getInstructions(),
oldElse->begin(), oldElse->end());
builder.setInsertionPointToEnd(elseBlock);
builder.create<BranchOp>(loc, continueBlock);
}
// Ok, now we just have to handle the condition logic.
auto integerSet = ifInst->getCondition().getIntegerSet();
// Implement short-circuit logic. For each affine expression in the 'if'
// condition, convert it into an affine map and call `affine_apply` to obtain
// the resulting value. Perform the equality or the greater-than-or-equality
// test between this value and zero depending on the equality flag of the
// condition. If the test fails, jump immediately to the false branch, which
// may be the else block if it is present or the continuation block otherwise.
// If the test succeeds, jump to the next block testing the next conjunct of
// the condition in the similar way. When all conjuncts have been handled,
// jump to the 'then' block instead.
builder.setInsertionPointToEnd(condBlock);
Value *zeroConstant = builder.create<ConstantIndexOp>(loc, 0);
for (auto tuple :
llvm::zip(integerSet.getConstraints(), integerSet.getEqFlags())) {
AffineExpr constraintExpr = std::get<0>(tuple);
bool isEquality = std::get<1>(tuple);
// Create the fall-through block for the next condition right before the
// 'thenBlock'.
auto *nextBlock = new Block();
nextBlock->insertBefore(thenBlock);
// Build and apply an affine map.
auto affineMap =
builder.getAffineMap(integerSet.getNumDims(),
integerSet.getNumSymbols(), constraintExpr, {});
SmallVector<Value *, 8> operands(ifInst->getOperands());
auto affineApplyOp =
builder.create<AffineApplyOp>(loc, affineMap, operands);
Value *affResult = affineApplyOp->getResult(0);
// Compare the result of the apply and branch.
auto comparisonOp = builder.create<CmpIOp>(
loc, isEquality ? CmpIPredicate::EQ : CmpIPredicate::SGE, affResult,
zeroConstant);
builder.create<CondBranchOp>(loc, comparisonOp->getResult(), nextBlock,
/*trueArgs*/ ArrayRef<Value *>(), elseBlock,
/*falseArgs*/ ArrayRef<Value *>());
builder.setInsertionPointToEnd(nextBlock);
}
// We will have ended up with an empty block as our continuation block (or, in
// the degenerate case where there were zero conditions, we have the original
// condition block). Redirect that to the thenBlock.
condBlock = builder.getInsertionBlock();
if (condBlock->empty()) {
condBlock->replaceAllUsesWith(thenBlock);
condBlock->eraseFromFunction();
} else {
builder.create<BranchOp>(loc, thenBlock);
}
// Ok, we're done!
ifInst->erase();
}
// Entry point of the function convertor.
//
// Conversion is performed by recursively visiting instructions of a Function.
// It reasons in terms of single-entry single-exit (SESE) regions that are not
// materialized in the code. Instead, the pointer to the last block of the
// region is maintained throughout the conversion as the insertion point of the
// IR builder since we never change the first block after its creation. "Block"
// instructions such as loops and branches create new SESE regions for their
// bodies, and surround them with additional basic blocks for the control flow.
// Individual operations are simply appended to the end of the last basic block
// of the current region. The SESE invariant allows us to easily handle nested
// structures of arbitrary complexity.
//
// During the conversion, we maintain a mapping between the Values present in
// the original function and their Value images in the function under
// construction. When an Value is used, it gets replaced with the
// corresponding Value that has been defined previously. The value flow
// starts with function arguments converted to basic block arguments.
PassResult LowerIfAndForPass::runOnFunction(Function *function) {
SmallVector<Instruction *, 8> instsToRewrite;
// Collect all the If and For statements. We do this as a prepass to avoid
// invalidating the walker with our rewrite.
function->walkInsts([&](Instruction *inst) {
if (isa<IfInst>(inst) || isa<ForInst>(inst))
instsToRewrite.push_back(inst);
});
// Rewrite all of the ifs and fors. We walked the instructions in preorder,
// so we know that we will rewrite them in the same order.
for (auto *inst : instsToRewrite)
if (auto *ifInst = dyn_cast<IfInst>(inst))
lowerIfInst(ifInst);
else
lowerForInst(cast<ForInst>(inst));
// Change the kind of the function to indicate it has no If's or For's.
function->setKind(Function::Kind::CFGFunc);
return success();
}
/// Lowers If and For instructions within a function into their lower level CFG
/// equivalent blocks.
FunctionPass *mlir::createLowerIfAndForPass() {
return new LowerIfAndForPass();
}
static PassRegistration<LowerIfAndForPass>
pass("lower-if-and-for",
"Lower If and For instructions to CFG equivalents");