| //===- Fusion.cpp - Implementation of linalg Fusion -----------------------===// |
| // |
| // Copyright 2019 The MLIR Authors. |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| // ============================================================================= |
| // |
| // This file implements the linalg dialect Fusion pass. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/EDSC/Helpers.h" |
| #include "mlir/IR/AffineExpr.h" |
| #include "mlir/IR/AffineMap.h" |
| #include "mlir/IR/OpImplementation.h" |
| #include "mlir/Linalg/Analysis/DependenceAnalysis.h" |
| #include "mlir/Linalg/IR/LinalgOps.h" |
| #include "mlir/Linalg/IR/LinalgTypes.h" |
| #include "mlir/Linalg/Passes.h" |
| #include "mlir/Linalg/Utils/Intrinsics.h" |
| #include "mlir/Linalg/Utils/Utils.h" |
| #include "mlir/Pass/Pass.h" |
| #include "mlir/Support/LLVM.h" |
| #include "mlir/Support/STLExtras.h" |
| #include "mlir/Transforms/FoldUtils.h" |
| |
| #include "llvm/ADT/SetVector.h" |
| #include "llvm/Support/CommandLine.h" |
| #include "llvm/Support/Debug.h" |
| |
| #define DEBUG_TYPE "linalg-fusion" |
| |
| using namespace mlir; |
| using namespace mlir::edsc; |
| using namespace mlir::edsc::intrinsics; |
| using namespace mlir::linalg; |
| using namespace mlir::linalg::intrinsics; |
| |
| using llvm::dbgs; |
| |
| /// Implements a simple high-level fusion pass of linalg library operations. |
| /// |
| /// In each block, linalg ops are processed in reverse textual order. |
| /// Given a linalg op, fusion occurs by: |
| /// 1. tiling the op by a given multi-dimensional tile size; |
| /// 2. inspecting the linalg ops that write into the views read by the op in |
| /// step 1. This uses the SSA value of the views to determine producer- |
| /// consumer dependences: only identical SSA views are considered for |
| /// fusion at this point; |
| /// 3. greedily fuse the producing linalg ops into the consuming loop tiles; |
| /// 4. inspect the fused ops and determine whether they have other remaining |
| /// LinalgOp uses. If not, then erase the original producing linalg op. |
| /// |
| /// More advanced use cases, analyses as well as profitability heuristics are |
| /// left for future work. |
| |
| static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options"); |
| static llvm::cl::list<unsigned> clTileSizes( |
| "linalg-fusion-tile-sizes", |
| llvm::cl::desc( |
| "Tile sizes by which to tile linalg operations during linalg fusion"), |
| llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated, |
| llvm::cl::cat(clOptionsCategory)); |
| |
| // Return a cloned version of `op` that operates on `loopRanges`, assumed to be |
| // a subset of the original loop ranges of `op`. |
| // This is achieved by applying the `loopToOperandRangesMaps` permutation maps |
| // to the `loopRanges` in order to obtain view ranges. |
| static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op, |
| ArrayRef<SubViewOp::Range> loopRanges, |
| OperationFolder &state) { |
| ScopedContext scope(b, loc); |
| |
| auto maps = loopToOperandRangesMaps(op); |
| SmallVector<Value *, 8> clonedViews; |
| clonedViews.reserve(op.getNumInputsAndOutputs()); |
| // Iterate over the inputs and outputs in order. |
| // Extract the subranges from the linearized ranges. |
| SmallVector<Value *, 8> ios(op.getInputsAndOutputs()); |
| for (auto en : llvm::enumerate(ios)) { |
| unsigned idx = en.index(); |
| auto map = maps[idx]; |
| LLVM_DEBUG(dbgs() << "map: " << map << "\n"); |
| Value *view = en.value(); |
| SmallVector<SubViewOp::Range, 8> viewRanges(map.getNumResults()); |
| for (auto en2 : llvm::enumerate(map.getResults())) { |
| unsigned d = en2.index(); |
| // loopToOperandRangesMaps are permutations-only. |
| unsigned loopPos = en2.value().cast<AffineDimExpr>().getPosition(); |
| viewRanges[d] = loopRanges[loopPos]; |
| LLVM_DEBUG(dbgs() << "\ni,j: " << en.index() << ", " << en2.index() |
| << "\t" |
| << "loopPos: " << loopPos << "\t" << viewRanges[d]); |
| } |
| // TODO(ntv) opportunities for folding/CSE here rather than build new IR. |
| clonedViews.push_back(b.create<SubViewOp>(loc, view, viewRanges)); |
| } |
| auto operands = getAssumedNonViewOperands(op); |
| clonedViews.append(operands.begin(), operands.end()); |
| return op.create(b, loc, clonedViews, op.getAttrs()); |
| } |
| |
| struct ViewDimension { |
| Value *view; |
| unsigned dimension; |
| }; |
| |
| static ViewDimension getViewDefiningLoopRange(LinalgOp op, unsigned loopDepth) { |
| auto maps = loopToOperandRangesMaps(op); |
| SmallVector<Value *, 8> clonedViews; |
| clonedViews.reserve(op.getNumInputsAndOutputs()); |
| // Iterate over the inputs and outputs in order. |
| // Extract the subranges from the linearized ranges. |
| SmallVector<Value *, 8> ios(op.getInputsAndOutputs()); |
| for (auto en : llvm::enumerate(ios)) { |
| unsigned idx = en.index(); |
| auto map = maps[idx]; |
| LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange I/O idx: " << idx << "\n"); |
| LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange map: " << map << "\n"); |
| Value *view = en.value(); |
| SmallVector<Value *, 8> viewRanges(map.getNumResults(), nullptr); |
| for (auto en2 : llvm::enumerate(map.getResults())) { |
| if (loopDepth == en2.value().cast<AffineDimExpr>().getPosition()) { |
| LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange loopDepth: " << loopDepth |
| << "\n"); |
| LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange view: " << *view |
| << "\n"); |
| return ViewDimension{view, static_cast<unsigned>(en2.index())}; |
| } |
| } |
| } |
| llvm_unreachable("Expect to be able to extract a view defining loop range"); |
| } |
| |
| static Optional<LinalgOp> fuse(Value *producedView, LinalgOp producer, |
| LinalgOp consumer, LinalgOp tiledConsumer, |
| OperationFolder &state) { |
| auto maybeConsumerIdx = consumer.getIndexOfInput(producedView); |
| if (!maybeConsumerIdx.hasValue()) |
| return llvm::None; |
| unsigned consumerIdx = maybeConsumerIdx.getValue(); |
| |
| auto maybeProducerIdx = producer.getIndexOfOutput(producedView); |
| if (!maybeProducerIdx.hasValue()) |
| return llvm::None; |
| unsigned producerIdx = maybeProducerIdx.getValue(); |
| |
| // If the view is the same between consumer and tiledConsumer, this means we |
| // don't have loops and the producer cannot be fused at this level. |
| if (consumer.getInput(consumerIdx) == tiledConsumer.getInput(consumerIdx)) |
| return llvm::None; |
| |
| auto tiledConsumerSubView = dyn_cast_or_null<SubViewOp>( |
| tiledConsumer.getInput(consumerIdx)->getDefiningOp()); |
| |
| // If we don't have a slice, this also means we don't have loops and the |
| // producer cannot be fused at this level. |
| if (!tiledConsumerSubView) |
| return llvm::None; |
| |
| // loopToOperandRangesMaps are permutations-only by construction: |
| // we can always identify a data dimension with a (at least one) loop |
| // dimension. |
| AffineMap producerMap = |
| loopToOperandRangesMaps(producer)[producer.getNumInputs() + producerIdx]; |
| LLVM_DEBUG(dbgs() << "Consumer Idx: " << consumerIdx << ", consumer map: " |
| << loopToOperandRangesMaps(consumer)[consumerIdx] << "\n"); |
| LLVM_DEBUG(dbgs() << "Producer Idx: " << producerIdx |
| << ", producer map: " << producerMap << "\n"); |
| |
| unsigned nPar = producer.getNumParallelLoops(); |
| unsigned nRed = producer.getNumReductionLoops(); |
| unsigned nWin = producer.getNumWindowLoops(); |
| SmallVector<SubViewOp::Range, 8> loopRanges(nPar + nRed + nWin); |
| |
| // Iterate over dimensions identified by the producer map for `producerIdx`. |
| // This defines a subset of the loop ranges that we need to complete later. |
| for (auto en : llvm::enumerate(producerMap.getResults())) { |
| unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition(); |
| loopRanges[posInProducerLoop] = tiledConsumerSubView.getRange(en.index()); |
| } |
| |
| OpBuilder b(tiledConsumer.getOperation()); |
| auto loc = tiledConsumer.getLoc(); |
| // Iterate over all dimensions. For the dimensions not identified by the |
| // producer map for `producerIdx`, we need to explicitly compute the view that |
| // defines the loop ranges using the `producer`. |
| for (unsigned i = 0, nLoops = loopRanges.size(); i < nLoops; ++i) { |
| if (loopRanges[i].min) |
| LLVM_DEBUG(llvm::dbgs() |
| << "existing LoopRange: " << loopRanges[i] << "\n"); |
| else { |
| auto viewDim = getViewDefiningLoopRange(producer, i); |
| loopRanges[i] = SubViewOp::Range{ |
| state.create<ConstantIndexOp>(b, loc, 0), |
| linalg::intrinsics::dim(viewDim.view, viewDim.dimension), |
| state.create<ConstantIndexOp>(b, loc, 1)}; |
| LLVM_DEBUG(llvm::dbgs() << "new LoopRange: " << loopRanges[i] << "\n"); |
| } |
| } |
| |
| return cloneWithLoopRanges(b, loc, producer, loopRanges, state); |
| } |
| |
| // Encode structural fusion safety preconditions. |
| // Some of these will be lifted in the future with better analysis. |
| static bool isStructurallyFusableProducer(LinalgOp producer, Value *readView, |
| LinalgOp consumer) { |
| // If a producer has multiple outputs, the analysis needs to take the tiling |
| // of other outputs into account. |
| if (producer.getNumOutputs() != 1) |
| return false; |
| // Until subview analysis is available, same SSA value is required for fusion. |
| if (producer.getOutput(0) != readView) |
| return false; |
| // No control-flow divergence supported. Only straightline op fusion allowed. |
| // TODO(ntv) allow fusion when a dominance relation exists. |
| if (producer.getOperation()->getBlock() != |
| consumer.getOperation()->getBlock()) |
| return false; |
| return true; |
| } |
| |
| static void fuseLinalgOps(FuncOp f, ArrayRef<int64_t> tileSizes) { |
| OperationFolder state; |
| DenseSet<Operation *> eraseSet; |
| |
| LLVM_DEBUG(f.print(dbgs() << "\nBefore linalg-fusion: \n")); |
| |
| // 1. Record the linalg ops so we can traverse them in reverse order. |
| SmallVector<Operation *, 8> linalgOps; |
| f.walk<LinalgOp>( |
| [&](LinalgOp op) { linalgOps.push_back(op.getOperation()); }); |
| |
| // 2. Setup the dependences graph, aliases are populated lazily. |
| Aliases aliases; |
| LinalgDependenceGraph G(aliases, linalgOps); |
| |
| // 2. For each original linalg op (in reverse order to allow chained |
| // fusions). |
| for (auto *op : llvm::reverse(linalgOps)) { |
| auto consumer = cast<LinalgOp>(op); |
| LLVM_DEBUG(dbgs() << "\n******\nStart processing:\t" << *op); |
| // 3. If marked for erasure, it has already been fused. Skip fusing op. |
| if (eraseSet.count(op) > 0) { |
| LLVM_DEBUG(dbgs() << "\nAlready fused and marked for erasure, skip."); |
| continue; |
| } |
| |
| // 4. Apply loop tiling to enable fusion. If unsuccessful, skip fusing op. |
| auto tiledOp = tileLinalgOp(op, tileSizes, state); |
| if (!tiledOp) { |
| LLVM_DEBUG(dbgs() << "\nTile sizes did not produce loops, skip."); |
| continue; |
| } |
| |
| // 5. For now, we only fuse RAW dependences. |
| SmallVector<Operation *, 8> fusedProducers; |
| SmallVector<Value *, 8> fusedViews; |
| for (auto dependence : G.getDependencesInto( |
| consumer, LinalgDependenceGraph::DependenceType::RAW)) { |
| auto producer = cast<LinalgOp>(dependence.dependentOpView.op); |
| LLVM_DEBUG(dbgs() << "\n***Consider producer:\t" |
| << *producer.getOperation() << "\n"); |
| |
| // a. For now we require fusion on identical SSA values, this allows us to |
| // not worry about partial writes etc. |
| // TODO(ntv) support more elaborate fusion with non identical SSA values. |
| auto *view = dependence.indexingView; |
| if (view != dependence.dependentOpView.view) { |
| LLVM_DEBUG(dbgs() << "\nviews are different SSA values, skip."); |
| continue; |
| } |
| // b. Make some simple structural checks that alleviate the need for more |
| // complex analyses. |
| if (!isStructurallyFusableProducer(producer, view, op)) { |
| LLVM_DEBUG(dbgs() << "\n***Not fusable:\t" << *producer.getOperation()); |
| continue; |
| } |
| // c. Check for fusion-preventing write that would violate dependences. |
| // `view` is a producer write that cannot bypass any other write or read. |
| bool preventFusion = false; |
| for (auto *op : G.findCoveringDependences(producer, consumer)) |
| if (eraseSet.count(op) == 0) { |
| preventFusion = true; |
| LLVM_DEBUG(dbgs() << "\n***Found fusion preventing dep via: " << *op); |
| break; |
| } |
| if (preventFusion) |
| continue; |
| |
| // 6. Try to fuse `producer` just before `tiledOp`. |
| LLVM_DEBUG(f.print(dbgs() << "\nBefore tiledOp-fusion: \n")); |
| |
| auto tOp = tiledOp->op; |
| OpBuilder builder(tOp.getOperation()); |
| ScopedContext scope(builder, tOp.getLoc()); |
| LLVM_DEBUG(dbgs() << "Try fuse into tiled consumer: " << *tOp << "\n"); |
| auto maybeFusedProducer = fuse(view, producer, op, tOp, state); |
| if (!maybeFusedProducer) { |
| LLVM_DEBUG(dbgs() << "\nFusion did not do anything, skip."); |
| continue; |
| } |
| |
| fusedProducers.push_back(producer.getOperation()); |
| fusedViews.push_back(view); |
| } |
| |
| // 7. If no fusion occurred, or a drop the outer tiled loop which undoes |
| // everything we did. |
| if (fusedProducers.empty()) { |
| tiledOp->loops[0].erase(); |
| continue; |
| } |
| |
| eraseSet.insert(op); |
| eraseSet.insert(fusedProducers.begin(), fusedProducers.end()); |
| } |
| |
| for (auto *op : eraseSet) |
| op->erase(); |
| |
| LLVM_DEBUG(f.print(dbgs() << "\nAfter linalg-fusion: \n")); |
| } |
| |
| namespace { |
| struct LinalgFusionPass : public FunctionPass<LinalgFusionPass> { |
| LinalgFusionPass() = default; |
| LinalgFusionPass(ArrayRef<int64_t> sizes); |
| |
| void runOnFunction() { fuseLinalgOps(getFunction(), tileSizes); } |
| |
| SmallVector<int64_t, 8> tileSizes; |
| }; |
| } // namespace |
| |
| LinalgFusionPass::LinalgFusionPass(ArrayRef<int64_t> sizes) |
| : LinalgFusionPass() { |
| if (!sizes.empty()) |
| this->tileSizes.assign(sizes.begin(), sizes.end()); |
| } |
| |
| FunctionPassBase * |
| mlir::linalg::createLinalgFusionPass(ArrayRef<int64_t> tileSizes) { |
| return new LinalgFusionPass(tileSizes); |
| } |
| |
| static PassRegistration<LinalgFusionPass> |
| pass("linalg-fusion", "Fuse operations in the linalg dialect", [] { |
| auto *pass = new LinalgFusionPass(); |
| pass->tileSizes.assign(clTileSizes.begin(), clTileSizes.end()); |
| return pass; |
| }); |