Enable multi-level Linalg fusion

This CL adds support for SubViewOp in the alias analysis so that multiple Linalg fusion passes can compose. The debugging messages are also made more readable, which came in handy when tracking down this issue.

A 2-level fusion test is added to capture the new behavior.

PiperOrigin-RevId: 259720246
diff --git a/lib/Linalg/Analysis/DependenceAnalysis.cpp b/lib/Linalg/Analysis/DependenceAnalysis.cpp
index 10b5284b..f44bea3 100644
--- a/lib/Linalg/Analysis/DependenceAnalysis.cpp
+++ b/lib/Linalg/Analysis/DependenceAnalysis.cpp
@@ -44,15 +44,25 @@
            "Buffer or block argument expected");
     return it->getSecond();
   }
-  if (auto slice = dyn_cast_or_null<SliceOp>(v->getDefiningOp())) {
-    auto it = aliases.insert(std::make_pair(v, find(slice.getBaseView())));
-    return it.first->second;
+
+  while (true) {
+    if (isa<BlockArgument>(v))
+      return v;
+    if (auto slice = dyn_cast_or_null<SliceOp>(v->getDefiningOp())) {
+      auto it = aliases.insert(std::make_pair(v, find(slice.getBaseView())));
+      return it.first->second;
+    }
+    if (auto view = dyn_cast_or_null<ViewOp>(v->getDefiningOp())) {
+      auto it = aliases.insert(std::make_pair(v, view.getSupportingBuffer()));
+      return it.first->second;
+    }
+    if (auto view = dyn_cast_or_null<SubViewOp>(v->getDefiningOp())) {
+      v = view.getView();
+      continue;
+    }
+    llvm::errs() << "View alias analysis reduces to: " << *v << "\n";
+    llvm_unreachable("unsupported view alias case");
   }
-  if (auto view = dyn_cast_or_null<ViewOp>(v->getDefiningOp())) {
-    auto it = aliases.insert(std::make_pair(v, view.getSupportingBuffer()));
-    return it.first->second;
-  }
-  llvm_unreachable("unsupported view alias case");
 }
 
 LinalgDependenceGraph::LinalgDependenceGraph(Aliases &aliases,
diff --git a/lib/Linalg/Transforms/Fusion.cpp b/lib/Linalg/Transforms/Fusion.cpp
index 480d19f..4864f39 100644
--- a/lib/Linalg/Transforms/Fusion.cpp
+++ b/lib/Linalg/Transforms/Fusion.cpp
@@ -98,7 +98,8 @@
       // loopToOperandRangesMaps are permutations-only.
       unsigned loopPos = en2.value().cast<AffineDimExpr>().getPosition();
       viewRanges[d] = loopRanges[loopPos];
-      LLVM_DEBUG(dbgs() << "i,j: " << en.index() << ", " << en2.index() << "\t"
+      LLVM_DEBUG(dbgs() << "\ni,j: " << en.index() << ", " << en2.index()
+                        << "\t"
                         << "loopPos: " << loopPos << "\t" << viewRanges[d]);
     }
     // TODO(ntv) opportunities for folding/CSE here rather than build new IR.
@@ -124,12 +125,18 @@
   for (auto en : llvm::enumerate(ios)) {
     unsigned idx = en.index();
     auto map = maps[idx];
-    LLVM_DEBUG(dbgs() << "map: " << map << "\n");
+    LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange I/O idx: " << idx << "\n");
+    LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange map: " << map << "\n");
     Value *view = en.value();
     SmallVector<Value *, 8> viewRanges(map.getNumResults(), nullptr);
     for (auto en2 : llvm::enumerate(map.getResults())) {
-      if (loopDepth == en2.value().cast<AffineDimExpr>().getPosition())
+      if (loopDepth == en2.value().cast<AffineDimExpr>().getPosition()) {
+        LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange loopDepth: " << loopDepth
+                          << "\n");
+        LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange view: " << *view
+                          << "\n");
         return ViewDimension{view, static_cast<unsigned>(en2.index())};
+      }
     }
   }
   llvm_unreachable("Expect to be able to extract a view defining loop range");
@@ -148,44 +155,57 @@
     return llvm::None;
   unsigned producerIdx = maybeProducerIdx.getValue();
 
-  auto sliceOp = dyn_cast_or_null<SubViewOp>(
-      tiledConsumer.getInput(consumerIdx)->getDefiningOp());
-  // If we don't have a slice, this also means we don't have loops and the
-  // producer cannot be fused at this level.
-  if (!sliceOp)
+  // If the view is the same between consumer and tiledConsumer, this means we
+  // don't have loops and the producer cannot be fused at this level.
+  if (consumer.getInput(consumerIdx) == tiledConsumer.getInput(consumerIdx))
     return llvm::None;
 
+  auto tiledConsumerSubView = dyn_cast_or_null<SubViewOp>(
+      tiledConsumer.getInput(consumerIdx)->getDefiningOp());
+
+  // If the tiled input is not defined by a SubViewOp, this also means we don't
+  // have loops and the producer cannot be fused at this level.
+  if (!tiledConsumerSubView)
+    return llvm::None;
+
+  // loopToOperandRangesMaps are permutations-only by construction:
+  //   we can always identify a data dimension with (at least one) loop
+  //   dimension.
   AffineMap producerMap =
       loopToOperandRangesMaps(producer)[producer.getNumInputs() + producerIdx];
-  LLVM_DEBUG(dbgs() << "Consumer Idx: " << consumerIdx << "\tmap: "
+  LLVM_DEBUG(dbgs() << "Consumer Idx: " << consumerIdx << ", consumer map: "
                     << loopToOperandRangesMaps(consumer)[consumerIdx] << "\n");
   LLVM_DEBUG(dbgs() << "Producer Idx: " << producerIdx
-                    << "\tmap: " << producerMap << "\n");
+                    << ", producer map: " << producerMap << "\n");
 
   unsigned nPar = producer.getNumParallelLoops();
   unsigned nRed = producer.getNumReductionLoops();
   unsigned nWin = producer.getNumWindowLoops();
   SmallVector<SubViewOp::Range, 8> loopRanges(nPar + nRed + nWin);
-  DenseSet<unsigned> fromSlice;
+
+  // Iterate over dimensions identified by the producer map for `producerIdx`.
+  // This defines a subset of the loop ranges that we need to complete later.
   for (auto en : llvm::enumerate(producerMap.getResults())) {
-    // loopToOperandRangesMaps are permutations-only.
     unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition();
-    loopRanges[posInProducerLoop] = sliceOp.getRange(en.index());
-    fromSlice.insert(posInProducerLoop);
+    loopRanges[posInProducerLoop] = tiledConsumerSubView.getRange(en.index());
   }
 
   OpBuilder b(tiledConsumer.getOperation());
   auto loc = tiledConsumer.getLoc();
-  for (unsigned i = 0; i < loopRanges.size(); ++i) {
-    if (fromSlice.count(i))
-      LLVM_DEBUG(llvm::dbgs() << "LR: " << loopRanges[i] << "\n");
+  // Iterate over all dimensions. For the dimensions not identified by the
+  // producer map for `producerIdx`, we need to explicitly compute the view that
+  // defines the loop ranges using the `producer`.
+  for (unsigned i = 0, nLoops = loopRanges.size(); i < nLoops; ++i) {
+    if (loopRanges[i].min)
+      LLVM_DEBUG(llvm::dbgs()
+                 << "existing LoopRange: " << loopRanges[i] << "\n");
     else {
       auto viewDim = getViewDefiningLoopRange(producer, i);
       loopRanges[i] = SubViewOp::Range{
           state.create<ConstantIndexOp>(b, loc, 0),
           linalg::intrinsics::dim(viewDim.view, viewDim.dimension),
           state.create<ConstantIndexOp>(b, loc, 1)};
-      LLVM_DEBUG(llvm::dbgs() << "new LR: " << loopRanges[i] << "\n");
+      LLVM_DEBUG(llvm::dbgs() << "new LoopRange: " << loopRanges[i] << "\n");
     }
   }
 
@@ -215,6 +235,8 @@
   OperationFolder state;
   DenseSet<Operation *> eraseSet;
 
+  LLVM_DEBUG(f.print(dbgs() << "\nBefore linalg-fusion: \n"));
+
   // 1. Record the linalg ops so we can traverse them in reverse order.
   SmallVector<Operation *, 8> linalgOps;
   f.walk<LinalgOp>(
@@ -249,7 +271,7 @@
              consumer, LinalgDependenceGraph::DependenceType::RAW)) {
       auto producer = cast<LinalgOp>(dependence.dependentOpView.op);
       LLVM_DEBUG(dbgs() << "\n***Consider producer:\t"
-                        << *producer.getOperation());
+                        << *producer.getOperation() << "\n");
 
       // a. For now we require fusion on identical SSA values, this allows us to
       // not worry about partial writes etc.
@@ -278,9 +300,12 @@
         continue;
 
       // 6. Try to fuse `producer` just before `tiledOp`.
+      LLVM_DEBUG(f.print(dbgs() << "\nBefore tiledOp-fusion: \n"));
+
       auto tOp = tiledOp->op;
       OpBuilder builder(tOp.getOperation());
       ScopedContext scope(builder, tOp.getLoc());
+      LLVM_DEBUG(dbgs() << "Try fuse into tiled consumer: " << *tOp << "\n");
       auto maybeFusedProducer = fuse(view, producer, op, tOp, state);
       if (!maybeFusedProducer) {
         LLVM_DEBUG(dbgs() << "\nFusion did not do anything, skip.");
@@ -310,7 +335,7 @@
 
 namespace {
 struct LinalgFusionPass : public FunctionPass<LinalgFusionPass> {
-  LinalgFusionPass();
+  LinalgFusionPass() = default;
   LinalgFusionPass(ArrayRef<int64_t> sizes);
 
   void runOnFunction() { fuseLinalgOps(getFunction(), tileSizes); }
@@ -319,9 +344,6 @@
 };
 } // namespace
 
-LinalgFusionPass::LinalgFusionPass()
-    : tileSizes(clTileSizes.begin(), clTileSizes.end()) {}
-
 LinalgFusionPass::LinalgFusionPass(ArrayRef<int64_t> sizes)
     : LinalgFusionPass() {
   if (!sizes.empty())
@@ -334,4 +356,8 @@
 }
 
 static PassRegistration<LinalgFusionPass>
-    pass("linalg-fusion", "Fuse operations in the linalg dialect");
+    pass("linalg-fusion", "Fuse operations in the linalg dialect", [] {
+      auto *pass = new LinalgFusionPass();
+      pass->tileSizes.assign(clTileSizes.begin(), clTileSizes.end());
+      return pass;
+    });
diff --git a/test/Linalg/fusion-2-level.mlir b/test/Linalg/fusion-2-level.mlir
new file mode 100644
index 0000000..29c87c7
--- /dev/null
+++ b/test/Linalg/fusion-2-level.mlir
@@ -0,0 +1,14 @@
+// RUN: mlir-opt %s -linalg-fusion -linalg-fusion-tile-sizes=16 -cse | mlir-opt -linalg-fusion -linalg-fusion-tile-sizes=8 | FileCheck %s
+
+func @f1(%A: !linalg.view<?x?xf32>, %B: !linalg.view<?x?xf32>, %C: !linalg.view<?x?xf32>, %D: !linalg.view<?x?xf32>, %E: !linalg.view<?x?xf32>) -> !linalg.view<?x?xf32> {
+  linalg.matmul(%A, %B, %C) : !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
+  linalg.matmul(%C, %D, %E) : !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
+  return %E : !linalg.view<?x?xf32>
+}
+// CHECK-LABEL: func @f1
+//   CHECK-DAG: %[[c8:.*]] = constant 8
+//   CHECK-DAG: %[[c16:.*]] = constant 16
+//       CHECK:   loop.for %{{.*}} step %[[c16]] {
+//       CHECK:     loop.for %{{.*}} %[[c8]] {
+//       CHECK:       linalg.matmul
+//       CHECK:       linalg.matmul
\ No newline at end of file