blob: 4864f394c88d2b49872011a3a514432f4daa2e06 [file] [log] [blame]
//===- Fusion.cpp - Implementation of linalg Fusion -----------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements the linalg dialect Fusion pass.
//
//===----------------------------------------------------------------------===//
#include "mlir/EDSC/Helpers.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Linalg/IR/LinalgOps.h"
#include "mlir/Linalg/IR/LinalgTypes.h"
#include "mlir/Linalg/Passes.h"
#include "mlir/Linalg/Utils/Intrinsics.h"
#include "mlir/Linalg/Utils/Utils.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/STLExtras.h"
#include "mlir/Transforms/FoldUtils.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "linalg-fusion"
using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace mlir::linalg;
using namespace mlir::linalg::intrinsics;
using llvm::dbgs;
/// Implements a simple high-level fusion pass of linalg library operations.
///
/// In each block, linalg ops are processed in reverse textual order.
/// Given a linalg op, fusion occurs by:
/// 1. tiling the op by a given multi-dimensional tile size;
/// 2. inspecting the linalg ops that write into the views read by the op in
/// step 1. This uses the SSA value of the views to determine producer-
/// consumer dependences: only identical SSA views are considered for
/// fusion at this point;
/// 3. greedily fuse the producing linalg ops into the consuming loop tiles;
/// 4. inspect the fused ops and determine whether they have other remaining
/// LinalgOp uses. If not, then erase the original producing linalg op.
///
/// More advanced use cases, analyses as well as profitability heuristics are
/// left for future work.
// Option category under which this pass's command-line flags are grouped; its
// title is derived from DEBUG_TYPE ("linalg-fusion options").
static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options");
// `-linalg-fusion-tile-sizes=<n0>,<n1>,...`: tile sizes read by the pass
// registration at the bottom of this file. The flag is declared ZeroOrMore and
// CommaSeparated, so it may be omitted and accepts a comma-separated list.
static llvm::cl::list<unsigned> clTileSizes(
"linalg-fusion-tile-sizes",
llvm::cl::desc(
"Tile sizes by which to tile linalg operations during linalg fusion"),
llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated,
llvm::cl::cat(clOptionsCategory));
// Builds a copy of `op` restricted to `loopRanges`, which are assumed to be a
// subset of the original loop ranges of `op`.
//
// For every input/output view of `op` (visited in order), the corresponding
// `loopToOperandRangesMaps` entry is used to pick, for each view dimension, the
// loop range that drives it; a SubViewOp over those ranges is then created at
// `loc` with builder `b`. The remaining non-view operands are appended
// unchanged and the new op is created with the attributes of `op`.
//
// `state` is currently not used by this function.
static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op,
                                    ArrayRef<SubViewOp::Range> loopRanges,
                                    OperationFolder &state) {
  ScopedContext scope(b, loc);

  auto operandMaps = loopToOperandRangesMaps(op);
  SmallVector<Value *, 8> newOperands;
  newOperands.reserve(op.getNumInputsAndOutputs());

  // Walk inputs first, then outputs, creating one subview per view.
  SmallVector<Value *, 8> views(op.getInputsAndOutputs());
  for (auto viewEntry : llvm::enumerate(views)) {
    unsigned viewIdx = viewEntry.index();
    auto operandMap = operandMaps[viewIdx];
    LLVM_DEBUG(dbgs() << "map: " << operandMap << "\n");

    SmallVector<SubViewOp::Range, 8> subRanges;
    subRanges.reserve(operandMap.getNumResults());
    for (auto resultEntry : llvm::enumerate(operandMap.getResults())) {
      // loopToOperandRangesMaps are permutations-only, so each result is a
      // plain dimension expression naming the loop that indexes this view dim.
      unsigned loopPos =
          resultEntry.value().cast<AffineDimExpr>().getPosition();
      subRanges.push_back(loopRanges[loopPos]);
      LLVM_DEBUG(dbgs() << "\ni,j: " << viewEntry.index() << ", "
                        << resultEntry.index() << "\t"
                        << "loopPos: " << loopPos << "\t" << subRanges.back());
    }

    // TODO(ntv) opportunities for folding/CSE here rather than build new IR.
    newOperands.push_back(
        b.create<SubViewOp>(loc, viewEntry.value(), subRanges));
  }

  // Non-view operands are forwarded as-is after the cloned views.
  auto nonViewOperands = getAssumedNonViewOperands(op);
  newOperands.append(nonViewOperands.begin(), nonViewOperands.end());
  return op.create(b, loc, newOperands, op.getAttrs());
}
// Identifies one dimension of a view: the view's SSA value together with the
// index of the dimension within that view. Produced by
// `getViewDefiningLoopRange` and consumed when building a `dim` query on it.
struct ViewDimension {
// The view value.
Value *view;
// Index of the dimension of `view` being referred to.
unsigned dimension;
};
// Finds a view dimension of `op` that is indexed by the loop at `loopDepth`.
//
// The inputs and outputs of `op` are scanned in order and, for each of them,
// the results of its `loopToOperandRangesMaps` entry are scanned in order. The
// first (view, dimension) pair whose map result is the dimension expression at
// position `loopDepth` is returned.
//
// Every map result is cast to AffineDimExpr, i.e. the maps are expected to be
// permutations-only. Reaching the end without a match is treated as
// unreachable.
static ViewDimension getViewDefiningLoopRange(LinalgOp op, unsigned loopDepth) {
  auto maps = loopToOperandRangesMaps(op);
  // Iterate over the inputs and outputs in order.
  SmallVector<Value *, 8> ios(op.getInputsAndOutputs());
  for (auto en : llvm::enumerate(ios)) {
    unsigned idx = en.index();
    auto map = maps[idx];
    LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange I/O idx: " << idx << "\n");
    LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange map: " << map << "\n");
    Value *view = en.value();
    for (auto en2 : llvm::enumerate(map.getResults())) {
      if (loopDepth == en2.value().cast<AffineDimExpr>().getPosition()) {
        LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange loopDepth: " << loopDepth
                          << "\n");
        LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange view: " << *view
                          << "\n");
        return ViewDimension{view, static_cast<unsigned>(en2.index())};
      }
    }
  }
  llvm_unreachable("Expect to be able to extract a view defining loop range");
}
// Clones `producer`, restricted to the tile read by `tiledConsumer`, and
// inserts the clone right before `tiledConsumer`.
//
// `producedView` must be an input of `consumer` and an output of `producer`;
// `tiledConsumer` is the tiled form of `consumer`. Returns llvm::None, without
// creating any IR, when:
//   - `producedView` is not an input of `consumer` or not an output of
//     `producer`;
//   - the matching input of `tiledConsumer` is the same value as in `consumer`
//     or is not defined by a SubViewOp (no enclosing tile loops to fuse into).
//
// Otherwise the loop ranges of the producer are derived from the ranges of the
// consumer's subview for every loop that indexes the produced view, and from
// the full extent ([0, dim, 1]) of some producer view for the other loops.
static Optional<LinalgOp> fuse(Value *producedView, LinalgOp producer,
                               LinalgOp consumer, LinalgOp tiledConsumer,
                               OperationFolder &state) {
  auto consumerPos = consumer.getIndexOfInput(producedView);
  if (!consumerPos)
    return llvm::None;
  unsigned consumerIdx = *consumerPos;

  auto producerPos = producer.getIndexOfOutput(producedView);
  if (!producerPos)
    return llvm::None;
  unsigned producerIdx = *producerPos;

  // An untouched input means tiling introduced no loops around the consumer,
  // so the producer cannot be fused at this level.
  Value *tiledInput = tiledConsumer.getInput(consumerIdx);
  if (consumer.getInput(consumerIdx) == tiledInput)
    return llvm::None;

  // Likewise, an input that is not a slice means there are no loops.
  auto consumerSlice =
      dyn_cast_or_null<SubViewOp>(tiledInput->getDefiningOp());
  if (!consumerSlice)
    return llvm::None;

  // loopToOperandRangesMaps are permutations-only by construction:
  // we can always identify a data dimension with a (at least one) loop
  // dimension.
  AffineMap producerMap =
      loopToOperandRangesMaps(producer)[producer.getNumInputs() + producerIdx];
  LLVM_DEBUG(dbgs() << "Consumer Idx: " << consumerIdx << ", consumer map: "
                    << loopToOperandRangesMaps(consumer)[consumerIdx] << "\n");
  LLVM_DEBUG(dbgs() << "Producer Idx: " << producerIdx
                    << ", producer map: " << producerMap << "\n");

  // One (initially empty) range per producer loop of any kind.
  unsigned numProducerLoops = producer.getNumParallelLoops() +
                              producer.getNumReductionLoops() +
                              producer.getNumWindowLoops();
  SmallVector<SubViewOp::Range, 8> loopRanges(numProducerLoops);

  // Loops that index the produced view inherit the range of the matching
  // dimension of the consumer's slice. The remaining ones are completed below.
  for (auto result : llvm::enumerate(producerMap.getResults())) {
    unsigned producerLoop =
        result.value().cast<AffineDimExpr>().getPosition();
    loopRanges[producerLoop] = consumerSlice.getRange(result.index());
  }

  OpBuilder b(tiledConsumer.getOperation());
  auto loc = tiledConsumer.getLoc();

  // Loops not covered by the produced view get the full extent of a producer
  // view dimension they index: [0, dim(view, d), 1].
  for (unsigned loopIdx = 0, numRanges = loopRanges.size();
       loopIdx < numRanges; ++loopIdx) {
    auto &range = loopRanges[loopIdx];
    if (range.min) {
      LLVM_DEBUG(llvm::dbgs() << "existing LoopRange: " << range << "\n");
      continue;
    }
    auto viewDim = getViewDefiningLoopRange(producer, loopIdx);
    range = SubViewOp::Range{
        state.create<ConstantIndexOp>(b, loc, 0),
        linalg::intrinsics::dim(viewDim.view, viewDim.dimension),
        state.create<ConstantIndexOp>(b, loc, 1)};
    LLVM_DEBUG(llvm::dbgs() << "new LoopRange: " << range << "\n");
  }

  return cloneWithLoopRanges(b, loc, producer, loopRanges, state);
}
// Structural preconditions under which fusing `producer` into `consumer` is
// considered safe. Some of these will be lifted in the future with better
// analysis. All of the following must hold:
//   1. `producer` has exactly one output: with several outputs the analysis
//      would need to take the tiling of the other outputs into account.
//   2. That output is the very SSA value `readView`: until subview analysis is
//      available, the same SSA value is required for fusion.
//   3. `producer` and `consumer` live in the same block: no control-flow
//      divergence is supported, only straightline op fusion.
//      TODO(ntv) allow fusion when a dominance relation exists.
static bool isStructurallyFusableProducer(LinalgOp producer, Value *readView,
                                          LinalgOp consumer) {
  // The checks short-circuit in order, so output 0 is only queried once the
  // producer is known to have exactly one output.
  return producer.getNumOutputs() == 1 &&
         producer.getOutput(0) == readView &&
         producer.getOperation()->getBlock() ==
             consumer.getOperation()->getBlock();
}
// Runs tile-and-fuse over every linalg op of `f`, tiling by `tileSizes`.
//
// Linalg ops are collected, then visited in reverse order. Each visited op is
// tiled; every producer with a RAW dependence into it that passes the
// structural and dependence checks below is cloned into the tile via `fuse`.
// If at least one producer was fused, the original consumer and the fused
// producers are recorded and erased once all ops have been visited; otherwise
// the outermost tile loop is erased to undo the tiling.
static void fuseLinalgOps(FuncOp f, ArrayRef<int64_t> tileSizes) {
  OperationFolder state;
  DenseSet<Operation *> eraseSet;

  LLVM_DEBUG(f.print(dbgs() << "\nBefore linalg-fusion: \n"));

  // 1. Record the linalg ops so we can traverse them in reverse order.
  SmallVector<Operation *, 8> linalgOps;
  f.walk<LinalgOp>(
      [&](LinalgOp op) { linalgOps.push_back(op.getOperation()); });

  // 2. Setup the dependences graph, aliases are populated lazily.
  Aliases aliases;
  LinalgDependenceGraph G(aliases, linalgOps);

  // 3. For each original linalg op (in reverse order to allow chained
  // fusions).
  for (auto *op : llvm::reverse(linalgOps)) {
    auto consumer = cast<LinalgOp>(op);
    LLVM_DEBUG(dbgs() << "\n******\nStart processing:\t" << *op);

    // 4. If marked for erasure, it has already been fused. Skip fusing op.
    if (eraseSet.count(op) > 0) {
      LLVM_DEBUG(dbgs() << "\nAlready fused and marked for erasure, skip.");
      continue;
    }

    // 5. Apply loop tiling to enable fusion. If unsuccessful, skip fusing op.
    auto tiledOp = tileLinalgOp(op, tileSizes, state);
    if (!tiledOp) {
      LLVM_DEBUG(dbgs() << "\nTile sizes did not produce loops, skip.");
      continue;
    }

    // 6. For now, we only fuse RAW dependences.
    SmallVector<Operation *, 8> fusedProducers;
    for (auto dependence : G.getDependencesInto(
             consumer, LinalgDependenceGraph::DependenceType::RAW)) {
      auto producer = cast<LinalgOp>(dependence.dependentOpView.op);
      LLVM_DEBUG(dbgs() << "\n***Consider producer:\t"
                        << *producer.getOperation() << "\n");

      // a. For now we require fusion on identical SSA values, this allows us to
      // not worry about partial writes etc.
      // TODO(ntv) support more elaborate fusion with non identical SSA values.
      auto *view = dependence.indexingView;
      if (view != dependence.dependentOpView.view) {
        LLVM_DEBUG(dbgs() << "\nviews are different SSA values, skip.");
        continue;
      }

      // b. Make some simple structural checks that alleviate the need for more
      // complex analyses.
      if (!isStructurallyFusableProducer(producer, view, op)) {
        LLVM_DEBUG(dbgs() << "\n***Not fusable:\t" << *producer.getOperation());
        continue;
      }

      // c. Check for fusion-preventing write that would violate dependences.
      // `view` is a producer write that cannot bypass any other write or read.
      // Covering ops already marked for erasure do not prevent fusion.
      bool preventFusion = false;
      for (auto *coveringOp : G.findCoveringDependences(producer, consumer))
        if (eraseSet.count(coveringOp) == 0) {
          preventFusion = true;
          LLVM_DEBUG(dbgs() << "\n***Found fusion preventing dep via: "
                            << *coveringOp);
          break;
        }
      if (preventFusion)
        continue;

      // d. Try to fuse `producer` just before `tiledOp`. The scoped context is
      // set up here so that IR built inside `fuse` is inserted before the
      // tiled op.
      LLVM_DEBUG(f.print(dbgs() << "\nBefore tiledOp-fusion: \n"));
      auto tOp = tiledOp->op;
      OpBuilder builder(tOp.getOperation());
      ScopedContext scope(builder, tOp.getLoc());
      LLVM_DEBUG(dbgs() << "Try fuse into tiled consumer: " << *tOp << "\n");
      auto maybeFusedProducer = fuse(view, producer, op, tOp, state);
      if (!maybeFusedProducer) {
        LLVM_DEBUG(dbgs() << "\nFusion did not do anything, skip.");
        continue;
      }

      fusedProducers.push_back(producer.getOperation());
    }

    // 7. If no fusion occurred, drop the outer tiled loop, which undoes
    // everything we did.
    if (fusedProducers.empty()) {
      tiledOp->loops[0].erase();
      continue;
    }

    // 8. The tiled+fused IR replaces the original consumer and its fused
    // producers; mark the originals for erasure.
    // NOTE(review): producers are marked without checking whether another,
    // unfused op still reads their output - confirm this is intended.
    eraseSet.insert(op);
    eraseSet.insert(fusedProducers.begin(), fusedProducers.end());
  }

  // 9. Erase all the ops that were replaced by fused versions.
  for (auto *op : eraseSet)
    op->erase();

  LLVM_DEBUG(f.print(dbgs() << "\nAfter linalg-fusion: \n"));
}
namespace {
/// Function pass that applies `fuseLinalgOps` to the current function using
/// `tileSizes`.
struct LinalgFusionPass : public FunctionPass<LinalgFusionPass> {
  /// Creates a pass with no tile sizes set.
  LinalgFusionPass() = default;

  /// Creates a pass that tiles by `sizes` (defined out of line below).
  explicit LinalgFusionPass(ArrayRef<int64_t> sizes);

  /// Pass entry point: tiles and fuses the linalg ops of the current function.
  void runOnFunction() override { fuseLinalgOps(getFunction(), tileSizes); }

  /// Tile sizes handed to `fuseLinalgOps`. Public because the command-line
  /// registration fills it in after default construction.
  SmallVector<int64_t, 8> tileSizes;
};
} // namespace
// Builds a fusion pass that tiles by `sizes`. When `sizes` is empty the pass is
// left exactly as default-constructed.
LinalgFusionPass::LinalgFusionPass(ArrayRef<int64_t> sizes)
    : LinalgFusionPass() {
  if (sizes.empty())
    return;
  tileSizes.assign(sizes.begin(), sizes.end());
}
// Public factory for the linalg fusion pass, configured with `tileSizes`.
// The pass is heap-allocated and returned as a raw pointer; presumably the
// caller (e.g. a pass manager) takes ownership - verify against callers.
FunctionPassBase *
mlir::linalg::createLinalgFusionPass(ArrayRef<int64_t> tileSizes) {
  auto *fusionPass = new LinalgFusionPass(tileSizes);
  return fusionPass;
}
// Registers the pass under the `linalg-fusion` flag. The factory lambda builds
// a default pass and then copies in the tile sizes parsed from the
// `-linalg-fusion-tile-sizes` command-line option.
static PassRegistration<LinalgFusionPass>
    pass("linalg-fusion", "Fuse operations in the linalg dialect", [] {
      auto *fusionPass = new LinalgFusionPass();
      fusionPass->tileSizes.assign(clTileSizes.begin(), clTileSizes.end());
      return fusionPass;
    });