blob: 554e3cb47a983b1f8ff67b931ec3eeaef381092d [file] [log] [blame]
//===- GreedyPatternRewriteDriver.cpp - A greedy rewriter -----------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements mlir::applyPatternsGreedily.
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/DenseMap.h"
using namespace mlir;
namespace {
class WorklistRewriter;
/// This is a worklist-driven driver for the PatternMatcher, which repeatedly
/// applies the locally optimal patterns in a roughly "bottom up" way.
class GreedyPatternRewriteDriver {
public:
explicit GreedyPatternRewriteDriver(OwningRewritePatternList &&patterns)
: matcher(std::move(patterns)) {
worklist.reserve(64);
}
void simplifyFunction(Function *currentFunction, WorklistRewriter &rewriter);
void addToWorklist(Operation *op) {
// Check to see if the worklist already contains this op.
if (worklistMap.count(op))
return;
worklistMap[op] = worklist.size();
worklist.push_back(op);
}
Operation *popFromWorklist() {
auto *op = worklist.back();
worklist.pop_back();
// This operation is no longer in the worklist, keep worklistMap up to date.
if (op)
worklistMap.erase(op);
return op;
}
/// If the specified operation is in the worklist, remove it. If not, this is
/// a no-op.
void removeFromWorklist(Operation *op) {
auto it = worklistMap.find(op);
if (it != worklistMap.end()) {
assert(worklist[it->second] == op && "malformed worklist data structure");
worklist[it->second] = nullptr;
}
}
private:
/// The low-level pattern matcher.
PatternMatcher matcher;
/// The worklist for this transformation keeps track of the operations that
/// need to be revisited, plus their index in the worklist. This allows us to
/// efficiently remove operations from the worklist when they are removed even
/// if they aren't the root of a pattern.
std::vector<Operation *> worklist;
DenseMap<Operation *, unsigned> worklistMap;
/// As part of canonicalization, we move constants to the top of the entry
/// block of the current function and de-duplicate them. This keeps track of
/// constants we have done this for.
DenseMap<std::pair<Attribute, Type>, Operation *> uniquedConstants;
};
}; // end anonymous namespace
/// This is a listener object that updates our worklists and other data
/// structures in response to operations being added and removed.
namespace {
class WorklistRewriter : public PatternRewriter {
public:
WorklistRewriter(GreedyPatternRewriteDriver &driver, MLIRContext *context)
: PatternRewriter(context), driver(driver) {}
virtual void setInsertionPoint(Operation *op) = 0;
// If an operation is about to be removed, make sure it is not in our
// worklist anymore because we'd get dangling references to it.
void notifyOperationRemoved(Operation *op) override {
driver.removeFromWorklist(op);
}
// When the root of a pattern is about to be replaced, it can trigger
// simplifications to its users - make sure to add them to the worklist
// before the root is changed.
void notifyRootReplaced(Operation *op) override {
for (auto *result : op->getResults())
// TODO: Add a result->getUsers() iterator.
for (auto &user : result->getUses()) {
if (auto *op = dyn_cast<Operation>(user.getOwner()))
driver.addToWorklist(op);
}
// TODO: Walk the operand list dropping them as we go. If any of them
// drop to zero uses, then add them to the worklist to allow them to be
// deleted as dead.
}
GreedyPatternRewriteDriver &driver;
};
} // end anonymous namespace
void GreedyPatternRewriteDriver::simplifyFunction(Function *currentFunction,
WorklistRewriter &rewriter) {
// These are scratch vectors used in the constant folding loop below.
SmallVector<Attribute, 8> operandConstants, resultConstants;
while (!worklist.empty()) {
auto *op = popFromWorklist();
// Nulls get added to the worklist when operations are removed, ignore them.
if (op == nullptr)
continue;
// If we have a constant op, unique it into the entry block.
if (auto constant = op->dyn_cast<ConstantOp>()) {
// If this constant is dead, remove it, being careful to keep
// uniquedConstants up to date.
if (constant->use_empty()) {
auto it =
uniquedConstants.find({constant->getValue(), constant->getType()});
if (it != uniquedConstants.end() && it->second == op)
uniquedConstants.erase(it);
constant->erase();
continue;
}
// Check to see if we already have a constant with this type and value:
auto &entry = uniquedConstants[std::make_pair(constant->getValue(),
constant->getType())];
if (entry) {
// If this constant is already our uniqued one, then leave it alone.
if (entry == op)
continue;
// Otherwise replace this redundant constant with the uniqued one. We
// know this is safe because we move constants to the top of the
// function when they are uniqued, so we know they dominate all uses.
constant->replaceAllUsesWith(entry->getResult(0));
constant->erase();
continue;
}
// If we have no entry, then we should unique this constant as the
// canonical version. To ensure safe dominance, move the operation to the
// top of the function.
entry = op;
// TODO: If we make terminators into Operations then we could turn this
// into a nice Operation::moveBefore(Operation*) method. We just need the
// guarantee that a block is non-empty.
if (auto *cfgFunc = dyn_cast<CFGFunction>(currentFunction)) {
auto &entryBB = cfgFunc->front();
cast<Instruction>(op)->moveBefore(&entryBB, entryBB.begin());
} else {
auto *mlFunc = cast<MLFunction>(currentFunction);
cast<OperationStmt>(op)->moveBefore(mlFunc, mlFunc->begin());
}
continue;
}
// If the operation has no side effects, and no users, then it is trivially
// dead - remove it.
if (op->hasNoSideEffect() && op->use_empty()) {
op->erase();
continue;
}
// Check to see if any operands to the instruction is constant and whether
// the operation knows how to constant fold itself.
operandConstants.clear();
for (auto *operand : op->getOperands()) {
Attribute operandCst;
if (auto *operandOp = operand->getDefiningOperation()) {
if (auto operandConstantOp = operandOp->dyn_cast<ConstantOp>())
operandCst = operandConstantOp->getValue();
}
operandConstants.push_back(operandCst);
}
// If constant folding was successful, create the result constants, RAUW the
// operation and remove it.
resultConstants.clear();
if (!op->constantFold(operandConstants, resultConstants)) {
rewriter.setInsertionPoint(op);
for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
auto *res = op->getResult(i);
if (res->use_empty()) // ignore dead uses.
continue;
// If we already have a canonicalized version of this constant, just
// reuse it. Otherwise create a new one.
SSAValue *cstValue;
auto it = uniquedConstants.find({resultConstants[i], res->getType()});
if (it != uniquedConstants.end())
cstValue = it->second->getResult(0);
else
cstValue = rewriter.create<ConstantOp>(
op->getLoc(), resultConstants[i], res->getType());
// Add all the users of the result to the worklist so we make sure to
// revisit them.
//
// TODO: Add a result->getUsers() iterator.
for (auto &operand : op->getResult(i)->getUses()) {
if (auto *op = dyn_cast<Operation>(operand.getOwner()))
addToWorklist(op);
}
res->replaceAllUsesWith(cstValue);
}
assert(op->hasNoSideEffect() && "Constant folded op with side effects?");
op->erase();
continue;
}
// If this is a commutative binary operation with a constant on the left
// side move it to the right side.
if (operandConstants.size() == 2 && operandConstants[0] &&
!operandConstants[1] && op->isCommutative()) {
auto *newLHS = op->getOperand(1);
op->setOperand(1, op->getOperand(0));
op->setOperand(0, newLHS);
}
// Check to see if we have any patterns that match this node.
auto match = matcher.findMatch(op);
if (!match.first)
continue;
// Make sure that any new operations are inserted at this point.
rewriter.setInsertionPoint(op);
// We know that any pattern that matched is RewritePattern because we
// initialized the matcher with RewritePatterns.
auto *rewritePattern = static_cast<RewritePattern *>(match.first);
rewritePattern->rewrite(op, std::move(match.second), rewriter);
}
uniquedConstants.clear();
}
static void processMLFunction(MLFunction *fn,
OwningRewritePatternList &&patterns) {
class MLFuncRewriter : public WorklistRewriter {
public:
MLFuncRewriter(GreedyPatternRewriteDriver &driver, MLFuncBuilder &builder)
: WorklistRewriter(driver, builder.getContext()), builder(builder) {}
// Implement the hook for creating operations, and make sure that newly
// created ops are added to the worklist for processing.
Operation *createOperation(const OperationState &state) override {
auto *result = builder.createOperation(state);
driver.addToWorklist(result);
return result;
}
void setInsertionPoint(Operation *op) override {
// Any new operations should be added before this statement.
builder.setInsertionPoint(cast<OperationStmt>(op));
}
private:
MLFuncBuilder &builder;
};
GreedyPatternRewriteDriver driver(std::move(patterns));
fn->walk([&](OperationStmt *stmt) { driver.addToWorklist(stmt); });
MLFuncBuilder mlBuilder(fn);
MLFuncRewriter rewriter(driver, mlBuilder);
driver.simplifyFunction(fn, rewriter);
}
static void processCFGFunction(CFGFunction *fn,
OwningRewritePatternList &&patterns) {
class CFGFuncRewriter : public WorklistRewriter {
public:
CFGFuncRewriter(GreedyPatternRewriteDriver &driver, CFGFuncBuilder &builder)
: WorklistRewriter(driver, builder.getContext()), builder(builder) {}
// Implement the hook for creating operations, and make sure that newly
// created ops are added to the worklist for processing.
Operation *createOperation(const OperationState &state) override {
auto *result = builder.createOperation(state);
driver.addToWorklist(result);
return result;
}
void setInsertionPoint(Operation *op) override {
// Any new operations should be added before this instruction.
builder.setInsertionPoint(cast<Instruction>(op));
}
private:
CFGFuncBuilder &builder;
};
GreedyPatternRewriteDriver driver(std::move(patterns));
for (auto &bb : *fn)
for (auto &op : bb)
driver.addToWorklist(&op);
CFGFuncBuilder cfgBuilder(fn);
CFGFuncRewriter rewriter(driver, cfgBuilder);
driver.simplifyFunction(fn, rewriter);
}
/// Rewrite the specified function by repeatedly applying the highest benefit
/// patterns in a greedy work-list driven manner.
///
void mlir::applyPatternsGreedily(Function *fn,
OwningRewritePatternList &&patterns) {
if (auto *cfg = dyn_cast<CFGFunction>(fn)) {
processCFGFunction(cfg, std::move(patterns));
} else {
processMLFunction(cast<MLFunction>(fn), std::move(patterns));
}
}