Loop Fusion pass update: introduce utilities to perform generalized loop fusion based on slicing; encompasses standard loop fusion.
*) Adds simple greedy fusion algorithm to drive experimentation. This algorithm greedily fuses loop nests with single-writer/single-reader memref dependences to improve locality.
*) Adds support for fusing slices of a loop nest computation: fusing one loop nest into another by adjusting the source loop nest's iteration bounds (after it is fused into the destination loop nest). This is accomplished by solving for the source loop nest's IVs in terms of the destination loop nests IVs and symbols using the dependece polyhedron, then creating AffineMaps of these functions for the loop bounds of the fused source loop.
*) Adds utility function 'insertMemRefComputationSlice' which computes and inserts computation slice from loop nest surrounding a source memref access into the loop nest surrounding the destingation memref access.
*) Adds FlatAffineConstraints::toAffineMap function which returns and AffineMap which represents an equality contraint where one dimension identifier is represented as a function of all others in the equality constraint.
*) Adds multiple fusion unit tests.
PiperOrigin-RevId: 225842944
diff --git a/include/mlir/Analysis/AffineAnalysis.h b/include/mlir/Analysis/AffineAnalysis.h
index a5bc373..bc67127 100644
--- a/include/mlir/Analysis/AffineAnalysis.h
+++ b/include/mlir/Analysis/AffineAnalysis.h
@@ -145,9 +145,12 @@
/// the operation statement, indices and memref associated with the access.
/// Returns 'false' if it can be determined conclusively that the accesses do
/// not access the same memref element. Returns 'true' otherwise.
+// TODO(andydavis) Wrap 'dependenceConstraints' and 'dependenceComponents' into
+// a single struct.
+// TODO(andydavis) Make 'dependenceConstraints' optional arg.
bool checkMemrefAccessDependence(
const MemRefAccess &srcAccess, const MemRefAccess &dstAccess,
- unsigned loopDepth,
+ unsigned loopDepth, FlatAffineConstraints *dependenceConstraints,
llvm::SmallVector<DependenceComponent, 2> *dependenceComponents);
} // end namespace mlir
diff --git a/include/mlir/Analysis/AffineStructures.h b/include/mlir/Analysis/AffineStructures.h
index ce53eaf..3261d13 100644
--- a/include/mlir/Analysis/AffineStructures.h
+++ b/include/mlir/Analysis/AffineStructures.h
@@ -384,6 +384,21 @@
AffineExpr toAffineExpr(unsigned idx, MLIRContext *context);
+ // Returns an AffineMap that expresses the identifier at pos as a function of
+ // other dimensional and symbolic identifiers using the 'idx^th' equality
+ // constraint.
+ // If 'nonZeroDimIds' and 'nonZeroSymbolIds' are non-null, they are populated
+ // with the positions of the non-zero equality constraint coefficients which
+ // were used to build the returned AffineMap.
+ // Returns AffineMap::Null on error (i.e. if coefficient is zero or does
+ // not divide other coefficients in the equality constraint).
+ // TODO(andydavis) Remove 'nonZeroDimIds' and 'nonZeroSymbolIds' from this
+ // API when we can manage the mapping of MLValues and ids in the constraint
+ // system.
+ AffineMap toAffineMapFromEq(unsigned idx, unsigned pos, MLIRContext *context,
+ SmallVectorImpl<unsigned> *nonZeroDimIds,
+ SmallVectorImpl<unsigned> *nonZeroSymbolIds);
+
// Adds an inequality (>= 0) from the coefficients specified in inEq.
void addInequality(ArrayRef<int64_t> inEq);
// Adds an equality from the coefficients specified in eq.
diff --git a/include/mlir/Analysis/Utils.h b/include/mlir/Analysis/Utils.h
index 197edb2..796a7aa 100644
--- a/include/mlir/Analysis/Utils.h
+++ b/include/mlir/Analysis/Utils.h
@@ -33,7 +33,9 @@
namespace mlir {
class FlatAffineConstraints;
+class ForStmt;
class MLValue;
+class MemRefAccess;
class OperationStmt;
class Statement;
@@ -139,6 +141,21 @@
bool boundCheckLoadOrStoreOp(LoadOrStoreOpPointer loadOrStoreOp,
bool emitError = true);
+/// Creates a clone of the computation contained in the loop nest surrounding
+/// 'srcAccess', and inserts it at the beginning of the statement block of the
+/// loop containing 'dstAccess'. Returns the top-level loop of the computation
+/// slice on success, returns nullptr otherwise.
+// Computes memref dependence between 'srcAccess' and 'dstAccess' and uses the
+// dependence constraint system to create AffineMaps with which to adjust the
+// loop bounds of the inserted compution slice so that they are functions of the
+// loop IVs and symbols of the loops surrounding 'dstAccess'.
+// TODO(andydavis) Add 'dstLoopDepth' argument for computation slice insertion.
+// Loop depth is a crucial optimization choice that determines where to
+// materialize the results of the backward slice - presenting a trade-off b/w
+// storage and redundant computation in several cases
+// TODO(andydavis) Support computation slices with common surrounding loops.
+ForStmt *insertBackwardComputationSlice(MemRefAccess *srcAccess,
+ MemRefAccess *dstAccess);
} // end namespace mlir
#endif // MLIR_ANALYSIS_UTILS_H
diff --git a/lib/Analysis/AffineAnalysis.cpp b/lib/Analysis/AffineAnalysis.cpp
index 7f53a14..80da93d 100644
--- a/lib/Analysis/AffineAnalysis.cpp
+++ b/lib/Analysis/AffineAnalysis.cpp
@@ -1152,7 +1152,7 @@
// The access functions would be the following:
//
// src: (%i0 * 2 - %i1 * 4 + %N, %i1 * 3 - %M)
-// src: (%i2 * 7 + %i3 * 9 - %M, %i3 * 11 - %K)
+// dst: (%i2 * 7 + %i3 * 9 - %M, %i3 * 11 - %K)
//
// The iteration domains for the src/dst accesses would be the following:
//
@@ -1166,7 +1166,7 @@
// symbol pos: 0 1 2
//
// Equality constraints are built by equating each result of src/destination
-// access functions. For this example, the folloing two equality constraints
+// access functions. For this example, the following two equality constraints
// will be added to the dependence constraint system:
//
// [src_dim0, src_dim1, dst_dim0, dst_dim1, sym0, sym1, sym2, const]
@@ -1190,7 +1190,7 @@
// TODO(andydavis) Support AffineExprs mod/floordiv/ceildiv.
bool mlir::checkMemrefAccessDependence(
const MemRefAccess &srcAccess, const MemRefAccess &dstAccess,
- unsigned loopDepth,
+ unsigned loopDepth, FlatAffineConstraints *dependenceConstraints,
llvm::SmallVector<DependenceComponent, 2> *dependenceComponents) {
// Return 'false' if these accesses do not acces the same memref.
if (srcAccess.memref != dstAccess.memref)
@@ -1247,28 +1247,31 @@
unsigned numCols = numIds + 1;
// Create flat affine constraints reserving space for 'numEq' and 'numIneq'.
- FlatAffineConstraints dependenceDomain(numIneq, numEq, numCols, numDims,
- numSymbols);
+ dependenceConstraints->reset(numIneq, numEq, numCols, numDims, numSymbols,
+ /*numLocals=*/0);
// Create memref access constraint by equating src/dst access functions.
// Note that this check is conservative, and will failure in the future
// when local variables for mod/div exprs are supported.
if (!addMemRefAccessConstraints(srcAccessMap, dstAccessMap, valuePosMap,
- &dependenceDomain))
+ dependenceConstraints))
return true;
// Add 'src' happens before 'dst' ordering constraints.
addOrderingConstraints(srcIterationDomainContext, dstIterationDomainContext,
- valuePosMap, loopDepth, &dependenceDomain);
+ valuePosMap, loopDepth, dependenceConstraints);
// Add src and dst domain constraints.
addDomainConstraints(srcIterationDomainContext, dstIterationDomainContext,
- valuePosMap, &dependenceDomain);
+ valuePosMap, dependenceConstraints);
// Return false if the solution space is empty: no dependence.
- if (dependenceDomain.isEmpty()) {
+ if (dependenceConstraints->isEmpty()) {
return false;
}
// Compute dependence direction vector and return true.
- computeDirectionVector(srcIterationDomainContext, dstIterationDomainContext,
- loopDepth, &dependenceDomain, dependenceComponents);
+ if (dependenceComponents != nullptr) {
+ computeDirectionVector(srcIterationDomainContext, dstIterationDomainContext,
+ loopDepth, dependenceConstraints,
+ dependenceComponents);
+ }
return true;
}
diff --git a/lib/Analysis/AffineStructures.cpp b/lib/Analysis/AffineStructures.cpp
index 4a344a1..9d14405 100644
--- a/lib/Analysis/AffineStructures.cpp
+++ b/lib/Analysis/AffineStructures.cpp
@@ -745,7 +745,6 @@
// add two equalities overall: d_0 - i0 - 1 == 0, d1 - i0 - 8*i2 == 0.
for (unsigned r = 0, e = flatExprs.size(); r < e; r++) {
const auto &flatExpr = flatExprs[r];
-
// eqToAdd is the equality corresponding to the flattened affine expression.
SmallVector<int64_t, 8> eqToAdd(getNumCols(), 0);
// Set the coefficient for this result to one.
@@ -1100,6 +1099,54 @@
return posLimit - posStart;
}
+// Returns an AffineMap which represents 'pos' in equality constraint 'idx',
+// as a function of dim and symbols identifers in all other positions.
+// TODO(andydavis) Add local variable support to this function.
+AffineMap FlatAffineConstraints::toAffineMapFromEq(
+ unsigned idx, unsigned pos, MLIRContext *context,
+ SmallVectorImpl<unsigned> *nonZeroDimIds,
+ SmallVectorImpl<unsigned> *nonZeroSymbolIds) {
+ assert(getNumLocalIds() == 0);
+ assert(idx < getNumEqualities());
+ int64_t v = atEq(idx, pos);
+ // Return if coefficient at (idx, pos) is zero or does not divide constant.
+ if (v == 0 || (atEq(idx, getNumIds()) % v != 0))
+ return AffineMap::Null();
+ // Check that coefficient at 'pos' divides all other coefficient in row 'idx'.
+ for (unsigned j = 0, e = getNumIds(); j < e; ++j) {
+ if (j != pos && (atEq(idx, j) % v != 0))
+ return AffineMap::Null();
+ }
+ // Build AffineExpr solving for identifier 'pos' in terms of all others.
+ auto expr = getAffineConstantExpr(0, context);
+ unsigned mapNumDims = 0;
+ unsigned mapNumSymbols = 0;
+ for (unsigned j = 0, e = getNumIds(); j < e; ++j) {
+ if (j == pos)
+ continue;
+ int64_t c = atEq(idx, j);
+ if (c == 0)
+ continue;
+ // Divide 'c' by 'v' from 'pos' for which we are solving.
+ c /= v;
+ if (j < numDims) {
+ expr = expr + getAffineDimExpr(mapNumDims++, context) * c;
+ nonZeroDimIds->push_back(j);
+ } else {
+ expr =
+ expr + getAffineSymbolExpr(mapNumDims + mapNumSymbols++, context) * c;
+ nonZeroSymbolIds->push_back(j);
+ }
+ expr = expr * (-1);
+ }
+ // Add constant term to AffineExpr.
+ int64_t c = atEq(idx, getNumIds());
+ if (c > 0) {
+ expr = expr + (c / v) * (-1);
+ }
+ return AffineMap::get(mapNumDims, mapNumSymbols, {expr}, {});
+}
+
void FlatAffineConstraints::addEquality(ArrayRef<int64_t> eq) {
assert(eq.size() == getNumCols());
unsigned offset = equalities.size();
diff --git a/lib/Analysis/MemRefDependenceCheck.cpp b/lib/Analysis/MemRefDependenceCheck.cpp
index 7a58818..2e3df2d 100644
--- a/lib/Analysis/MemRefDependenceCheck.cpp
+++ b/lib/Analysis/MemRefDependenceCheck.cpp
@@ -148,8 +148,10 @@
unsigned numCommonLoops =
getNumCommonSurroundingLoops(srcLoops, dstLoops);
for (unsigned d = 1; d <= numCommonLoops + 1; ++d) {
+ FlatAffineConstraints dependenceConstraints;
llvm::SmallVector<DependenceComponent, 2> dependenceComponents;
bool ret = checkMemrefAccessDependence(srcAccess, dstAccess, d,
+ &dependenceConstraints,
&dependenceComponents);
// TODO(andydavis) Print dependence type (i.e. RAW, etc) and print
// distance vectors as: ([2, 3], [0, 10]). Also, shorten distance
diff --git a/lib/Analysis/Utils.cpp b/lib/Analysis/Utils.cpp
index 293d912..3fe22e9 100644
--- a/lib/Analysis/Utils.cpp
+++ b/lib/Analysis/Utils.cpp
@@ -310,3 +310,132 @@
bool emitError);
template bool mlir::boundCheckLoadOrStoreOp(OpPointer<StoreOp> storeOp,
bool emitError);
+
+// Returns in 'positions' the StmtBlock positions of 'stmt' in each ancestor
+// StmtBlock from the StmtBlock containing statement, stopping at 'limitBlock'.
+static void findStmtPosition(const Statement *stmt, StmtBlock *limitBlock,
+ SmallVectorImpl<unsigned> *positions) {
+ StmtBlock *block = stmt->getBlock();
+ while (block != limitBlock) {
+ int stmtPosInBlock = block->findStmtPosInBlock(*stmt);
+ assert(stmtPosInBlock >= 0);
+ positions->push_back(stmtPosInBlock);
+ stmt = block->getContainingStmt();
+ block = stmt->getBlock();
+ }
+ std::reverse(positions->begin(), positions->end());
+}
+
+// Returns the Statement in a possibly nested set of StmtBlocks, where the
+// position of the statement is represented by 'positions', which has a
+// StmtBlock position for each level of nesting.
+static Statement *getStmtAtPosition(ArrayRef<unsigned> positions,
+ unsigned level, StmtBlock *block) {
+ unsigned i = 0;
+ for (auto &stmt : *block) {
+ if (i != positions[level]) {
+ ++i;
+ continue;
+ }
+ if (level == positions.size() - 1)
+ return &stmt;
+ if (auto *childForStmt = dyn_cast<ForStmt>(&stmt))
+ return getStmtAtPosition(positions, level + 1, childForStmt);
+
+ if (auto *ifStmt = dyn_cast<IfStmt>(&stmt)) {
+ auto *ret = getStmtAtPosition(positions, level + 1, ifStmt->getThen());
+ if (ret != nullptr)
+ return ret;
+ if (auto *elseClause = ifStmt->getElse())
+ return getStmtAtPosition(positions, level + 1, elseClause);
+ }
+ }
+ return nullptr;
+}
+
+// TODO(andydavis) Support a 'dstLoopDepth' argument for computation slice
+// insertion (currently the computation slice is inserted at the same
+// loop depth as 'dstAccess.opStmt'.
+ForStmt *mlir::insertBackwardComputationSlice(MemRefAccess *srcAccess,
+ MemRefAccess *dstAccess) {
+ FlatAffineConstraints dependenceConstraints;
+ if (!checkMemrefAccessDependence(*srcAccess, *dstAccess, /*loopDepth=*/0,
+ &dependenceConstraints,
+ /*dependenceComponents=*/nullptr)) {
+ return nullptr;
+ }
+ // Get loop nest surrounding src operation.
+ SmallVector<ForStmt *, 4> srcLoopNest;
+ getLoopIVs(*srcAccess->opStmt, &srcLoopNest);
+ unsigned srcLoopNestSize = srcLoopNest.size();
+
+ // Get loop nest surrounding dst operation.
+ SmallVector<ForStmt *, 4> dstLoopNest;
+ getLoopIVs(*dstAccess->opStmt, &dstLoopNest);
+ unsigned dstLoopNestSize = dstLoopNest.size();
+
+ // Solve for src IVs in terms of dst IVs, symbols and constants.
+ SmallVector<AffineMap, 4> srcIvMaps(srcLoopNestSize, AffineMap::Null());
+ std::vector<SmallVector<MLValue *, 2>> srcIvOperands(srcLoopNestSize);
+ for (unsigned i = 0; i < srcLoopNestSize; ++i) {
+ auto cst = dependenceConstraints.clone();
+ for (int j = srcLoopNestSize - 1; j >= 0; --j) {
+ if (i != j)
+ cst->projectOut(j);
+ }
+ if (cst->getNumEqualities() != 1) {
+ srcIvMaps[i] = AffineMap::Null();
+ continue;
+ }
+ SmallVector<unsigned, 2> nonZeroDimIds;
+ SmallVector<unsigned, 2> nonZeroSymbolIds;
+ srcIvMaps[i] = cst->toAffineMapFromEq(0, 0, srcAccess->opStmt->getContext(),
+ &nonZeroDimIds, &nonZeroSymbolIds);
+ if (srcIvMaps[i] == AffineMap::Null())
+ continue;
+ // Add operands for all non-zero dst dims and symbols.
+ // TODO(andydavis) Add local variable support.
+ for (auto dimId : nonZeroDimIds) {
+ srcIvOperands[i].push_back(dstLoopNest[dimId - 1]);
+ }
+ // TODO(andydavis) Add symbols from the access function. Ideally, we
+ // should be able to query the constaint system for the MLValue associated
+ // with a symbol identifiers in 'nonZeroSymbolIds'.
+ }
+
+ // Find the stmt block positions of 'srcAccess->opStmt' within 'srcLoopNest'.
+ SmallVector<unsigned, 4> positions;
+ findStmtPosition(srcAccess->opStmt, srcLoopNest[0]->getBlock(), &positions);
+
+ // Clone src loop nest and insert it a the beginning of the statement block
+ // of the same loop in which containts 'dstAccess->opStmt'.
+ auto *dstForStmt = dstLoopNest[dstLoopNestSize - 1];
+ MLFuncBuilder b(dstForStmt, dstForStmt->begin());
+ DenseMap<const MLValue *, MLValue *> operandMap;
+ auto *sliceLoopNest = cast<ForStmt>(b.clone(*srcLoopNest[0], operandMap));
+
+ // Lookup stmt in cloned 'sliceLoopNest' at 'positions'.
+ Statement *sliceStmt =
+ getStmtAtPosition(positions, /*level=*/0, sliceLoopNest);
+ // Get loop nest surrounding 'sliceStmt'.
+ SmallVector<ForStmt *, 4> sliceSurroundingLoops;
+ getLoopIVs(*sliceStmt, &sliceSurroundingLoops);
+ unsigned sliceSurroundingLoopsSize = sliceSurroundingLoops.size();
+
+ // Update loop bounds for loops in 'sliceLoopNest'.
+ for (unsigned i = dstLoopNestSize; i < sliceSurroundingLoopsSize; ++i) {
+ auto *forStmt = sliceSurroundingLoops[i];
+ unsigned index = i - dstLoopNestSize;
+ AffineMap lbMap = srcIvMaps[index];
+ if (lbMap == AffineMap::Null())
+ continue;
+ forStmt->setLowerBound(srcIvOperands[index], lbMap);
+ // Create upper bound map with is lower bound map + 1;
+ assert(lbMap.getNumResults() == 1);
+ AffineExpr ubResultExpr = lbMap.getResult(0) + 1;
+ AffineMap ubMap = AffineMap::get(lbMap.getNumDims(), lbMap.getNumSymbols(),
+ {ubResultExpr}, {});
+ forStmt->setUpperBound(srcIvOperands[index], ubMap);
+ }
+ return sliceLoopNest;
+}
diff --git a/lib/Transforms/LoopFusion.cpp b/lib/Transforms/LoopFusion.cpp
index 3db290f..521fca8 100644
--- a/lib/Transforms/LoopFusion.cpp
+++ b/lib/Transforms/LoopFusion.cpp
@@ -20,7 +20,9 @@
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/AffineAnalysis.h"
+#include "mlir/Analysis/AffineStructures.h"
#include "mlir/Analysis/LoopAnalysis.h"
+#include "mlir/Analysis/Utils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
@@ -31,16 +33,25 @@
#include "mlir/Transforms/LoopUtils.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/Support/raw_ostream.h"
+
+using llvm::SetVector;
using namespace mlir;
namespace {
-/// Loop fusion pass. This pass fuses adjacent loops in MLFunctions which
-/// access the same memref with no dependences.
-// See MatchTestPattern for details on candidate loop selection.
+/// Loop fusion pass. This pass currently supports a greedy fusion policy,
+/// which fuses loop nests with single-writer/single-reader memref dependences
+/// with the goal of improving locality.
+
+// TODO(andydavis) Support fusion of source loop nests which write to multiple
+// memrefs, where each memref can have multiple users (if profitable).
// TODO(andydavis) Extend this pass to check for fusion preventing dependences,
// and add support for more general loop fusion algorithms.
+
struct LoopFusion : public FunctionPass {
LoopFusion() : FunctionPass(&LoopFusion::passID) {}
@@ -48,51 +59,12 @@
static char passID;
};
-// LoopCollector walks the statements in an MLFunction and builds a map from
-// StmtBlocks to a list of loops within the StmtBlock, and a map from ForStmts
-// to the list of loads and stores with its StmtBlock.
-class LoopCollector : public StmtWalker<LoopCollector> {
-public:
- DenseMap<StmtBlock *, SmallVector<ForStmt *, 2>> loopMap;
- DenseMap<ForStmt *, SmallVector<OperationStmt *, 2>> loadsAndStoresMap;
- bool hasIfStmt = false;
-
- void visitForStmt(ForStmt *forStmt) {
- loopMap[forStmt->getBlock()].push_back(forStmt);
- }
-
- void visitIfStmt(IfStmt *ifStmt) { hasIfStmt = true; }
-
- void visitOperationStmt(OperationStmt *opStmt) {
- if (auto *parentStmt = opStmt->getParentStmt()) {
- if (auto *parentForStmt = dyn_cast<ForStmt>(parentStmt)) {
- if (opStmt->isa<LoadOp>() || opStmt->isa<StoreOp>()) {
- loadsAndStoresMap[parentForStmt].push_back(opStmt);
- }
- }
- }
- }
-};
-
} // end anonymous namespace
char LoopFusion::passID = 0;
FunctionPass *mlir::createLoopFusionPass() { return new LoopFusion; }
-// TODO(andydavis) Remove the following test code when more general loop
-// fusion is supported.
-struct FusionCandidate {
- // Loop nest of ForStmts with 'accessA' in the inner-most loop.
- SmallVector<ForStmt *, 2> forStmtsA;
- // Load or store operation within loop nest 'forStmtsA'.
- MemRefAccess accessA;
- // Loop nest of ForStmts with 'accessB' in the inner-most loop.
- SmallVector<ForStmt *, 2> forStmtsB;
- // Load or store operation within loop nest 'forStmtsB'.
- MemRefAccess accessB;
-};
-
static void getSingleMemRefAccess(OperationStmt *loadOrStoreOpStmt,
MemRefAccess *access) {
if (auto loadOp = loadOrStoreOpStmt->dyn_cast<LoadOp>()) {
@@ -116,137 +88,348 @@
}
}
-// Checks if 'forStmtA' and 'forStmtB' match specific test criterion:
-// constant loop bounds, no nested loops, single StoreOp in 'forStmtA' and
-// a single LoadOp in 'forStmtB'.
-// Returns true if the test pattern matches, false otherwise.
-static bool MatchTestPatternLoopPair(LoopCollector *lc,
- FusionCandidate *candidate,
- ForStmt *forStmtA, ForStmt *forStmtB) {
- if (forStmtA == nullptr || forStmtB == nullptr)
- return false;
- // Return if 'forStmtA' and 'forStmtB' do not have matching constant
- // bounds and step.
- if (!forStmtA->hasConstantBounds() || !forStmtB->hasConstantBounds() ||
- forStmtA->getConstantLowerBound() != forStmtB->getConstantLowerBound() ||
- forStmtA->getConstantUpperBound() != forStmtB->getConstantUpperBound() ||
- forStmtA->getStep() != forStmtB->getStep())
- return false;
+// FusionCandidate encapsulates source and destination memref access within
+// loop nests which are candidates for loop fusion.
+struct FusionCandidate {
+ // Load or store access within src loop nest to be fused into dst loop nest.
+ MemRefAccess srcAccess;
+ // Load or store access within dst loop nest.
+ MemRefAccess dstAccess;
+};
- // Return if 'forStmtA' or 'forStmtB' have nested loops.
- if (lc->loopMap.count(forStmtA) > 0 || lc->loopMap.count(forStmtB))
- return false;
-
- // Return if 'forStmtA' or 'forStmtB' do not have exactly one load or store.
- if (lc->loadsAndStoresMap[forStmtA].size() != 1 ||
- lc->loadsAndStoresMap[forStmtB].size() != 1)
- return false;
-
- // Get load/store access for forStmtA.
- getSingleMemRefAccess(lc->loadsAndStoresMap[forStmtA][0],
- &candidate->accessA);
- // Return if 'accessA' is not a store.
- if (!candidate->accessA.opStmt->isa<StoreOp>())
- return false;
-
- // Get load/store access for forStmtB.
- getSingleMemRefAccess(lc->loadsAndStoresMap[forStmtB][0],
- &candidate->accessB);
-
- // Return if accesses do not access the same memref.
- if (candidate->accessA.memref != candidate->accessB.memref)
- return false;
-
- candidate->forStmtsA.push_back(forStmtA);
- candidate->forStmtsB.push_back(forStmtB);
- return true;
+static FusionCandidate buildFusionCandidate(OperationStmt *srcStoreOpStmt,
+ OperationStmt *dstLoadOpStmt) {
+ FusionCandidate candidate;
+ // Get store access for src loop nest.
+ getSingleMemRefAccess(srcStoreOpStmt, &candidate.srcAccess);
+ // Get load access for dst loop nest.
+ getSingleMemRefAccess(dstLoadOpStmt, &candidate.dstAccess);
+ return candidate;
}
-// Returns the child ForStmt of 'parent' if unique, returns false otherwise.
-ForStmt *getSingleForStmtChild(ForStmt *parent) {
- if (parent->getStatements().size() == 1 && isa<ForStmt>(parent->front()))
- return dyn_cast<ForStmt>(&parent->front());
- return nullptr;
-}
+namespace {
-// Checks for a specific ForStmt/OpStatment test pattern in 'f', returns true
-// on success and resturns fusion candidate in 'candidate'. Returns false
-// otherwise.
-// Currently supported test patterns:
-// *) Adjacent loops with a StoreOp the only op in first loop, and a LoadOp the
-// only op in the second loop (both load/store accessing the same memref).
-// *) As above, but with one level of perfect loop nesting.
+// LoopNestStateCollector walks loop nests and collects load and store
+// operations, and whether or not an IfStmt was encountered in the loop nest.
+class LoopNestStateCollector : public StmtWalker<LoopNestStateCollector> {
+public:
+ SmallVector<ForStmt *, 4> forStmts;
+ SmallVector<OperationStmt *, 4> loadOpStmts;
+ SmallVector<OperationStmt *, 4> storeOpStmts;
+ bool hasIfStmt = false;
+
+ void visitForStmt(ForStmt *forStmt) { forStmts.push_back(forStmt); }
+
+ void visitIfStmt(IfStmt *ifStmt) { hasIfStmt = true; }
+
+ void visitOperationStmt(OperationStmt *opStmt) {
+ if (opStmt->isa<LoadOp>())
+ loadOpStmts.push_back(opStmt);
+ if (opStmt->isa<StoreOp>())
+ storeOpStmts.push_back(opStmt);
+ }
+};
+
+// GreedyFusionPolicy greedily fuses loop nests which have a producer/consumer
+// relationship on a memref, with the goal of improving locality. Currently,
+// this the producer/consumer relationship is required to be unique in the
+// MLFunction (there are TODOs to relax this constraint in the future).
//
-// TODO(andydavis) Look into using ntv@ pattern matcher here.
-static bool MatchTestPattern(MLFunction *f, FusionCandidate *candidate) {
- LoopCollector lc;
- lc.walk(f);
- // Return if an IfStmt was found or if less than two ForStmts were found.
- if (lc.hasIfStmt || lc.loopMap.count(f) == 0 || lc.loopMap[f].size() < 2)
- return false;
- auto *forStmtA = lc.loopMap[f][0];
- auto *forStmtB = lc.loopMap[f][1];
- if (!MatchTestPatternLoopPair(&lc, candidate, forStmtA, forStmtB)) {
- // Check for one level of loop nesting.
- candidate->forStmtsA.push_back(forStmtA);
- candidate->forStmtsB.push_back(forStmtB);
- return MatchTestPatternLoopPair(&lc, candidate,
- getSingleForStmtChild(forStmtA),
- getSingleForStmtChild(forStmtB));
- }
- return true;
-}
+// The steps of the algorithm are as follows:
+//
+// *) Initialize. While visiting each statement in the MLFunction do:
+// *) Assign each top-level ForStmt a 'position' which is its initial
+// position in the MLFunction's StmtBlock at the start of the pass.
+// *) Gather memref load/store state aggregated by top-level statement. For
+// example, all loads and stores contained in a loop nest are aggregated
+// under the loop nest's top-level ForStmt.
+// *) Add each top-level ForStmt to a worklist.
+//
+// *) Run. The algorithm processes the worklist with the following steps:
+// *) The worklist is processed in reverse order (starting from the last
+// top-level ForStmt in the MLFunction).
+// *) Pop a ForStmt of the worklist. This 'dstForStmt' will be a candidate
+// destination ForStmt into which fusion will be attempted.
+// *) Add each LoadOp currently in 'dstForStmt' into list 'dstLoadOps'.
+// *) For each LoadOp in 'dstLoadOps' do:
+// *) Lookup dependent loop nests at earlier positions in the MLFunction
+// which have a single store op to the same memref.
+// *) Check if dependences would be violated by the fusion. For example,
+// the src loop nest may load from memrefs which are different than
+// the producer-consumer memref between src and dest loop nests.
+// *) Get a computation slice of 'srcLoopNest', which adjust its loop
+// bounds to be functions of 'dstLoopNest' IVs and symbols.
+// *) Fuse the 'srcLoopNest' computation slice into the 'dstLoopNest',
+// just before the dst load op user.
+// *) Add the newly fused load/store operation statements to the state,
+// and also add newly fuse load ops to 'dstLoopOps' to be considered
+// as fusion dst load ops in another iteration.
+// *) Remove old src loop nest and its associated state.
+//
+// Given a graph where top-level statements are vertices in the set 'V' and
+// edges in the set 'E' are dependences between vertices, this algorithm
+// takes O(V) time for initialization, and has runtime O(V * E).
+// TODO(andydavis) Reduce this time complexity to O(V + E).
+//
+// This greedy algorithm is not 'maximally' but there is a TODO to fix this.
+//
+// TODO(andydavis) Experiment with other fusion policies.
+struct GreedyFusionPolicy {
+ // Convenience wrapper with information about 'stmt' ready to access.
+ struct StmtInfo {
+ Statement *stmt;
+ bool isOrContainsIfStmt = false;
+ };
+ // The worklist of top-level loop nest positions.
+ SmallVector<unsigned, 4> worklist;
+ // Mapping from top-level position to StmtInfo.
+ DenseMap<unsigned, StmtInfo> posToStmtInfo;
+ // Mapping from memref MLValue to set of top-level positions of loop nests
+ // which contain load ops on that memref.
+ DenseMap<MLValue *, DenseSet<unsigned>> memrefToLoadPosSet;
+ // Mapping from memref MLValue to set of top-level positions of loop nests
+ // which contain store ops on that memref.
+ DenseMap<MLValue *, DenseSet<unsigned>> memrefToStorePosSet;
+ // Mapping from top-level loop nest to the set of load ops it contains.
+ DenseMap<ForStmt *, SetVector<OperationStmt *>> forStmtToLoadOps;
+ // Mapping from top-level loop nest to the set of store ops it contains.
+ DenseMap<ForStmt *, SetVector<OperationStmt *>> forStmtToStoreOps;
-// FuseLoops implements the code generation mechanics of loop fusion.
-// Fuses the operations statments from the inner-most loop in 'c.forStmtsB',
-// by cloning them into the inner-most loop in 'c.forStmtsA', then erasing
-// old statements and loops.
-static void fuseLoops(const FusionCandidate &c) {
- MLFuncBuilder builder(c.forStmtsA.back(),
- StmtBlock::iterator(c.forStmtsA.back()->end()));
- DenseMap<const MLValue *, MLValue *> operandMap;
- assert(c.forStmtsA.size() == c.forStmtsB.size());
- for (unsigned i = 0, e = c.forStmtsA.size(); i < e; i++) {
- // Map loop IVs to 'forStmtB[i]' to loop IV for 'forStmtA[i]'.
- operandMap[c.forStmtsB[i]] = c.forStmtsA[i];
+ GreedyFusionPolicy(MLFunction *f) { init(f); }
+
+ void run() {
+ if (hasIfStmts())
+ return;
+
+ while (!worklist.empty()) {
+ // Pop the position of a loop nest into which fusion will be attempted.
+ unsigned dstPos = worklist.back();
+ worklist.pop_back();
+ // Skip if 'dstPos' is not tracked (was fused into another loop nest).
+ if (posToStmtInfo.count(dstPos) == 0)
+ continue;
+ // Get the top-level ForStmt at 'dstPos'.
+ auto *dstForStmt = getForStmtAtPos(dstPos);
+ // Skip if this ForStmt contains no load ops.
+ if (forStmtToLoadOps.count(dstForStmt) == 0)
+ continue;
+
+ // Greedy Policy: iterate through load ops in 'dstForStmt', greedily
+ // fusing in src loop nests which have a single store op on the same
+ // memref, until a fixed point is reached where there is nothing left to
+ // fuse.
+ SetVector<OperationStmt *> dstLoadOps = forStmtToLoadOps[dstForStmt];
+ while (!dstLoadOps.empty()) {
+ auto *dstLoadOpStmt = dstLoadOps.pop_back_val();
+
+ auto dstLoadOp = dstLoadOpStmt->cast<LoadOp>();
+ auto *memref = cast<MLValue>(dstLoadOp->getMemRef());
+ // Skip if not single src store / dst load pair on 'memref'.
+ if (memrefToLoadPosSet[memref].size() != 1 ||
+ memrefToStorePosSet[memref].size() != 1)
+ continue;
+ unsigned srcPos = *memrefToStorePosSet[memref].begin();
+ if (srcPos >= dstPos)
+ continue;
+ auto *srcForStmt = getForStmtAtPos(srcPos);
+ // Skip if 'srcForStmt' has more than one store op.
+ if (forStmtToStoreOps[srcForStmt].size() > 1)
+ continue;
+ // Skip if fusion would violated dependences between 'memref' access
+ // for loop nests between 'srcPos' and 'dstPos':
+ // For each src load op: check for store ops in range (srcPos, dstPos).
+ // For each src store op: check for load ops in range (srcPos, dstPos).
+ if (moveWouldViolateDependences(srcPos, dstPos))
+ continue;
+ auto *srcStoreOpStmt = forStmtToStoreOps[srcForStmt].front();
+ // Build fusion candidate out of 'srcStoreOpStmt' and 'dstLoadOpStmt'.
+ FusionCandidate candidate =
+ buildFusionCandidate(srcStoreOpStmt, dstLoadOpStmt);
+ // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'.
+ auto *sliceLoopNest = mlir::insertBackwardComputationSlice(
+ &candidate.srcAccess, &candidate.dstAccess);
+ if (sliceLoopNest != nullptr) {
+ // Remove 'srcPos' mappings from 'state'.
+ moveAccessesAndRemovePos(srcPos, dstPos);
+ // Record all load/store accesses in 'sliceLoopNest' at 'dstPos'.
+ LoopNestStateCollector collector;
+ collector.walkForStmt(sliceLoopNest);
+ // Record mappings for loads and stores from 'collector'.
+ for (auto *opStmt : collector.loadOpStmts) {
+ addLoadOpStmtAt(dstPos, opStmt, dstForStmt);
+ // Add newly fused load ops to 'dstLoadOps' to be considered for
+ // fusion on subsequent iterations.
+ dstLoadOps.insert(opStmt);
+ }
+ for (auto *opStmt : collector.storeOpStmts) {
+ addStoreOpStmtAt(dstPos, opStmt, dstForStmt);
+ }
+ for (auto *forStmt : collector.forStmts) {
+ promoteIfSingleIteration(forStmt);
+ }
+ // Remove old src loop nest.
+ srcForStmt->erase();
+ }
+ }
+ }
}
- // Clone the body of inner-most loop in 'forStmtsB', into the body of
- // inner-most loop in 'forStmtsA'.
- SmallVector<Statement *, 2> stmtsToErase;
- auto *innerForStmtB = c.forStmtsB.back();
- for (auto &stmt : *innerForStmtB) {
- builder.clone(stmt, operandMap);
- stmtsToErase.push_back(&stmt);
+
+ // Walk MLFunction 'f' assigning each top-level statement a position, and
+ // gathering state on load and store ops.
+ void init(MLFunction *f) {
+ unsigned pos = 0;
+ for (auto &stmt : *f) {
+ if (auto *forStmt = dyn_cast<ForStmt>(&stmt)) {
+ // Record all loads and store accesses in 'forStmt' at 'pos'.
+ LoopNestStateCollector collector;
+ collector.walkForStmt(forStmt);
+ // Create StmtInfo for 'forStmt' for top-level loop nests.
+ addStmtInfoAt(pos, forStmt, collector.hasIfStmt);
+ // Record mappings for loads and stores from 'collector'.
+ for (auto *opStmt : collector.loadOpStmts) {
+ addLoadOpStmtAt(pos, opStmt, forStmt);
+ }
+ for (auto *opStmt : collector.storeOpStmts) {
+ addStoreOpStmtAt(pos, opStmt, forStmt);
+ }
+ // Add 'pos' associated with 'forStmt' to worklist.
+ worklist.push_back(pos);
+ }
+ if (auto *opStmt = dyn_cast<OperationStmt>(&stmt)) {
+ if (auto loadOp = opStmt->dyn_cast<LoadOp>()) {
+ // Create StmtInfo for top-level load op.
+ addStmtInfoAt(pos, &stmt, /*hasIfStmt=*/false);
+ addLoadOpStmtAt(pos, opStmt, /*containingForStmt=*/nullptr);
+ }
+ if (auto storeOp = opStmt->dyn_cast<StoreOp>()) {
+ // Create StmtInfo for top-level store op.
+ addStmtInfoAt(pos, &stmt, /*hasIfStmt=*/false);
+ addStoreOpStmtAt(pos, opStmt, /*containingForStmt=*/nullptr);
+ }
+ }
+ if (auto *ifStmt = dyn_cast<IfStmt>(&stmt)) {
+ addStmtInfoAt(pos, &stmt, /*hasIfStmt=*/true);
+ }
+ ++pos;
+ }
}
- // Erase 'forStmtB' and its statement list.
- for (auto it = stmtsToErase.rbegin(); it != stmtsToErase.rend(); ++it)
- (*it)->erase();
- // Erase 'forStmtsB' loop nest.
- for (int i = static_cast<int>(c.forStmtsB.size()) - 1; i >= 0; --i)
- c.forStmtsB[i]->erase();
-}
+
+ // Check if fusing loop nest at 'srcPos' into the loop nest at 'dstPos'
+ // would violated any dependences w.r.t other loop nests in that range.
+ bool moveWouldViolateDependences(unsigned srcPos, unsigned dstPos) {
+ // Lookup src ForStmt at 'srcPos'.
+ auto *srcForStmt = getForStmtAtPos(srcPos);
+ // For each src load op: check for store ops in range (srcPos, dstPos).
+ if (forStmtToLoadOps.count(srcForStmt) > 0) {
+ for (auto *opStmt : forStmtToLoadOps[srcForStmt]) {
+ auto loadOp = opStmt->cast<LoadOp>();
+ auto *memref = cast<MLValue>(loadOp->getMemRef());
+ for (unsigned pos = srcPos + 1; pos < dstPos; ++pos) {
+ if (memrefToStorePosSet.count(memref) > 0 &&
+ memrefToStorePosSet[memref].count(pos) > 0)
+ return true;
+ }
+ }
+ }
+ // For each src store op: check for load ops in range (srcPos, dstPos).
+ if (forStmtToStoreOps.count(srcForStmt) > 0) {
+ for (auto *opStmt : forStmtToStoreOps[srcForStmt]) {
+ auto storeOp = opStmt->cast<StoreOp>();
+ auto *memref = cast<MLValue>(storeOp->getMemRef());
+ for (unsigned pos = srcPos + 1; pos < dstPos; ++pos) {
+ if (memrefToLoadPosSet.count(memref) > 0 &&
+ memrefToLoadPosSet[memref].count(pos) > 0)
+ return true;
+ }
+ }
+ }
+ return false;
+ }
+
+ // Update mappings of memref loads and stores at 'srcPos' to 'dstPos'.
+ void moveAccessesAndRemovePos(unsigned srcPos, unsigned dstPos) {
+ // Lookup ForStmt at 'srcPos'.
+ auto *srcForStmt = getForStmtAtPos(srcPos);
+ // Move load op accesses from src to dst.
+ if (forStmtToLoadOps.count(srcForStmt) > 0) {
+ for (auto *opStmt : forStmtToLoadOps[srcForStmt]) {
+ auto loadOp = opStmt->cast<LoadOp>();
+ auto *memref = cast<MLValue>(loadOp->getMemRef());
+ // Remove 'memref' to 'srcPos' mapping.
+ memrefToLoadPosSet[memref].erase(srcPos);
+ }
+ }
+ // Move store op accesses from src to dst.
+ if (forStmtToStoreOps.count(srcForStmt) > 0) {
+ for (auto *opStmt : forStmtToStoreOps[srcForStmt]) {
+ auto storeOp = opStmt->cast<StoreOp>();
+ auto *memref = cast<MLValue>(storeOp->getMemRef());
+ // Remove 'memref' to 'srcPos' mapping.
+ memrefToStorePosSet[memref].erase(srcPos);
+ }
+ }
+ // Remove old state.
+ forStmtToLoadOps.erase(srcForStmt);
+ forStmtToStoreOps.erase(srcForStmt);
+ posToStmtInfo.erase(srcPos);
+ }
+
+ ForStmt *getForStmtAtPos(unsigned pos) {
+ assert(posToStmtInfo.count(pos) > 0);
+ assert(isa<ForStmt>(posToStmtInfo[pos].stmt));
+ return cast<ForStmt>(posToStmtInfo[pos].stmt);
+ }
+
+ void addStmtInfoAt(unsigned pos, Statement *stmt, bool hasIfStmt) {
+ StmtInfo stmtInfo;
+ stmtInfo.stmt = stmt;
+ stmtInfo.isOrContainsIfStmt = hasIfStmt;
+ // Add mapping from 'pos' to StmtInfo for 'forStmt'.
+ posToStmtInfo[pos] = stmtInfo;
+ }
+
+ // Adds the following mappings:
+ // *) 'containingForStmt' to load 'opStmt'
+ // *) 'memref' of load 'opStmt' to 'topLevelPos'.
+ void addLoadOpStmtAt(unsigned topLevelPos, OperationStmt *opStmt,
+ ForStmt *containingForStmt) {
+ if (containingForStmt != nullptr) {
+ // Add mapping from 'containingForStmt' to 'opStmt' for load op.
+ forStmtToLoadOps[containingForStmt].insert(opStmt);
+ }
+ auto loadOp = opStmt->cast<LoadOp>();
+ auto *memref = cast<MLValue>(loadOp->getMemRef());
+ // Add mapping from 'memref' to 'topLevelPos' for load.
+ memrefToLoadPosSet[memref].insert(topLevelPos);
+ }
+
+ // Adds the following mappings:
+ // *) 'containingForStmt' to store 'opStmt'
+ // *) 'memref' of store 'opStmt' to 'topLevelPos'.
+ void addStoreOpStmtAt(unsigned topLevelPos, OperationStmt *opStmt,
+ ForStmt *containingForStmt) {
+ if (containingForStmt != nullptr) {
+ // Add mapping from 'forStmt' to 'opStmt' for store op.
+ forStmtToStoreOps[containingForStmt].insert(opStmt);
+ }
+ auto storeOp = opStmt->cast<StoreOp>();
+ auto *memref = cast<MLValue>(storeOp->getMemRef());
+ // Add mapping from 'memref' to 'topLevelPos' for store.
+ memrefToStorePosSet[memref].insert(topLevelPos);
+ }
+
+ bool hasIfStmts() {
+ for (auto &pair : posToStmtInfo)
+ if (pair.second.isOrContainsIfStmt)
+ return true;
+ return false;
+ }
+};
+
+} // end anonymous namespace
PassResult LoopFusion::runOnMLFunction(MLFunction *f) {
- FusionCandidate candidate;
- if (!MatchTestPattern(f, &candidate))
- return failure();
-
- // TODO(andydavis) Add checks for fusion-preventing dependences and ordering
- // constraints which would prevent fusion.
- // TODO(andydavis) This check is overly conservative for now. Support fusing
- // statements with compatible dependences (i.e. statements where the
- // dependence between the statements does not reverse direction when the
- // statements are fused into the same loop).
- llvm::SmallVector<DependenceComponent, 2> dependenceComponents;
- // TODO(andydavis) Check dependences at differnt loop nest depths.
- if (!checkMemrefAccessDependence(candidate.accessA, candidate.accessB,
- /*loopNestDepth=*/0,
- &dependenceComponents)) {
- // Current conservatinve test policy: No dependence exists between accesses
- // in different loop nests -> fuse loops.
- fuseLoops(candidate);
- }
-
+ GreedyFusionPolicy(f).run();
return success();
}
diff --git a/test/Transforms/loop-fusion.mlir b/test/Transforms/loop-fusion.mlir
index 2b8ce07..d0de62e 100644
--- a/test/Transforms/loop-fusion.mlir
+++ b/test/Transforms/loop-fusion.mlir
@@ -1,141 +1,555 @@
-// RUN: mlir-opt %s -loop-fusion | FileCheck %s
+// RUN: mlir-opt %s -loop-fusion -split-input-file -verify | FileCheck %s
-// CHECK: [[MAP0:#map[0-9]+]] = (d0) -> (d0 * 2 + 2)
-// CHECK: [[MAP1:#map[0-9]+]] = (d0) -> (d0 * 3 + 1)
-// CHECK: [[MAP2:#map[0-9]+]] = (d0) -> (d0 * 2)
-// CHECK: [[MAP3:#map[0-9]+]] = (d0) -> (d0 * 2 + 1)
-// CHECK: [[MAP4:#map[0-9]+]] = (d0, d1)[s0, s1] -> (d0 * 2 - d1 - s0 * 7 + 3, d0 * 9 + d1 * 3 + s1 * 13 - 10)
-// CHECK: [[MAP6:#map[0-9]+]] = (d0, d1)[s0, s1] -> (d0 * 2 - 1, d1 * 3 + s0 + s1 * 3)
+// TODO(andydavis) Add more tests:
+// *) Add nested fusion test cases when non-constant loop bound support is
+// added to iteration domain in dependence check.
+// *) Add a test w/ floordiv/ceildiv/mod when supported in dependence check.
+// *) Add tests which check fused computation slice indexing and loop bounds.
+// TODO(andydavis) Test clean up: move memref allocs to mlfunc args.
-// The dependence check for this test builds the following set of constraints,
-// where the equality contraint equates the two accesses to the memref (from
-// different loops), and the inequality constraints represent the upper and
-// lower bounds for each loop. After elimination, this linear system can be
-// shown to be non-empty (i.e. x0 = x1 = 1 is a solution). As such, the
-// dependence check between accesses in the two loops will return true, and
-// the loops (according to the current test loop fusion algorithm) should not be
-// fused.
-//
-// x0 x1 x2
-// 2 -3 1 = 0
-// 1 0 0 >= 0
-// -1 0 100 >= 0
-// 0 1 0 >= 0
-// 0 -1 100 >= 0
-//
-// CHECK-LABEL: mlfunc @loop_fusion_1d_should_not_fuse_loops() {
-mlfunc @loop_fusion_1d_should_not_fuse_loops() {
- %m = alloc() : memref<100xf32, (d0) -> (d0)>
- // Check that the first loop remains unfused.
- // CHECK: for %i0 = 0 to 100 {
- // CHECK-NEXT: [[I0:%[0-9]+]] = affine_apply [[MAP0]](%i0)
- // CHECK: store {{.*}}, %{{[0-9]+}}{{\[}}[[I0]]{{\]}}
- // CHECK-NEXT: }
- for %i0 = 0 to 100 {
- %a0 = affine_apply (d0) -> (d0 * 2 + 2) (%i0)
- %c1 = constant 1.0 : f32
- store %c1, %m[%a0] : memref<100xf32, (d0) -> (d0)>
+// -----
+
+// CHECK: [[MAP0:#map[0-9]+]] = (d0) -> (d0)
+
+// CHECK-LABEL: mlfunc @should_fuse_raw_dep_for_locality() {
+mlfunc @should_fuse_raw_dep_for_locality() {
+ %m = alloc() : memref<10xf32>
+ %cf7 = constant 7.0 : f32
+
+ for %i0 = 0 to 10 {
+ store %cf7, %m[%i0] : memref<10xf32>
}
- // Check that the second loop remains unfused.
- // CHECK: for %i1 = 0 to 100 {
- // CHECK-NEXT: [[I1:%[0-9]+]] = affine_apply [[MAP1]](%i1)
- // CHECK-NEXT: load %{{[0-9]+}}{{\[}}[[I1]]{{\]}}
- // CHECK-NEXT: }
- for %i1 = 0 to 100 {
- %a1 = affine_apply (d0) -> (d0 * 3 + 1) (%i1)
- %v0 = load %m[%a1] : memref<100xf32, (d0) -> (d0)>
+ for %i1 = 0 to 10 {
+ %v0 = load %m[%i1] : memref<10xf32>
}
- return
-}
-
-// The dependence check for this test builds the following set of constraints:
-//
-// x0 x1 x2
-// 2 -2 -1 = 0
-// 1 0 0 >= 0
-// -1 0 100 >= 0
-// 0 1 0 >= 0
-// 0 -1 100 >= 0
-//
-// After elimination, this linear system can be shown to have no solutions, and
-// so no dependence exists and the loops should be fused in this test (according
-// to the current trivial test loop fusion policy).
-//
-//
-// CHECK-LABEL: mlfunc @loop_fusion_1d_should_fuse_loops() {
-mlfunc @loop_fusion_1d_should_fuse_loops() {
- %m = alloc() : memref<100xf32, (d0) -> (d0)>
- // Should fuse statements from the second loop into the first loop.
- // CHECK: for %i0 = 0 to 100 {
- // CHECK-NEXT: [[I0:%[0-9]+]] = affine_apply [[MAP2]](%i0)
- // CHECK: store {{.*}}, %{{[0-9]+}}{{\[}}[[I0]]{{\]}}
- // CHECK-NEXT: [[I1:%[0-9]+]] = affine_apply [[MAP3]](%i0)
- // CHECK-NEXT: load %{{[0-9]+}}{{\[}}[[I1]]{{\]}}
+ // CHECK: for %i0 = 0 to 10 {
+ // CHECK-NEXT: %1 = affine_apply [[MAP0]](%i0)
+ // CHECK-NEXT: store %cst, %0[%1] : memref<10xf32>
+ // CHECK-NEXT: %2 = load %0[%i0] : memref<10xf32>
// CHECK-NEXT: }
// CHECK-NEXT: return
- for %i0 = 0 to 100 {
- %a0 = affine_apply (d0) -> (d0 * 2) (%i0)
- %c1 = constant 1.0 : f32
- store %c1, %m[%a0] : memref<100xf32, (d0) -> (d0)>
- }
-
- for %i1 = 0 to 100 {
- %a1 = affine_apply (d0) -> (d0 * 2 + 1) (%i1)
-
- %v0 = load %m[%a1] : memref<100xf32, (d0) -> (d0)>
- }
return
}
-// TODO(andydavis) Add LoopFusion tests based on fusion policy and cost model.
+// -----
-// The dependence check for this test builds the following set of
-// equality constraints (one for each memref dimension). Note: inequality
-// constraints for loop bounds not shown.
-//
-// i0 i1 i2 i3 s0 s1 s2 c
-// 2 -1 -2 0 -7 0 0 4 = 0
-// 9 3 0 -3 0 12 -3 -10 = 0
-//
-// The second equality will fail the GCD test and so the system has no solution,
-// so the loops should be fused under the current test policy.
-//
-// CHECK-LABEL: mlfunc @loop_fusion_2d_should_fuse_loops() {
-mlfunc @loop_fusion_2d_should_fuse_loops() {
- %m = alloc() : memref<10x10xf32>
+// CHECK: [[MAP0:#map[0-9]+]] = (d0) -> (d0)
- %s0 = constant 7 : index
- %s1 = constant 11 : index
- %s2 = constant 13 : index
- // Should fuse statements from the second loop into the first loop.
- // CHECK: for %i0 = 0 to 100 {
- // CHECK-NEXT: for %i1 = 0 to 50 {
- // CHECK-NEXT: [[I0:%[0-9]+]] = affine_apply [[MAP4]](%i0, %i1)[%c7, %c11]
- // CHECK: store {{.*}}, %{{[0-9]+}}{{\[}}[[I0]]#0, [[I0]]#1{{\]}}
- // CHECK-NEXT: [[I1:%[0-9]+]] = affine_apply [[MAP6]](%i0, %i1)[%c11, %c13]
- // CHECK-NEXT: load %{{[0-9]+}}{{\[}}[[I1]]#0, [[I1]]#1{{\]}}
+// TODO(andydavis) Turn this into a proper reduction when constraints on
+// the current greedy fusion policy are relaxed.
+// CHECK-LABEL: mlfunc @should_fuse_reduction_to_pointwise() {
+mlfunc @should_fuse_reduction_to_pointwise() {
+ %a = alloc() : memref<10x10xf32>
+ %b = alloc() : memref<10xf32>
+ %c = alloc() : memref<10xf32>
+ %d = alloc() : memref<10xf32>
+
+ %cf7 = constant 7.0 : f32
+
+ for %i0 = 0 to 10 {
+ for %i1 = 0 to 10 {
+ %v0 = load %d[%i0] : memref<10xf32>
+ %v1 = load %a[%i0, %i1] : memref<10x10xf32>
+ %v3 = addf %v0, %v1 : f32
+ store %v3, %b[%i0] : memref<10xf32>
+ }
+ }
+ for %i2 = 0 to 10 {
+ %v4 = load %b[%i2] : memref<10xf32>
+ store %v4, %c[%i2] : memref<10xf32>
+ }
+
+ // Should fuse in entire inner loop on %i1 from source loop nest, as %i1
+ // is not used in the access function of the store/load on %b.
+ // CHECK: for %i0 = 0 to 10 {
+ // CHECK-NEXT: %4 = affine_apply [[MAP0]](%i0)
+ // CHECK-NEXT: for %i1 = 0 to 10 {
+ // CHECK-NEXT: %5 = load %3[%4] : memref<10xf32>
+ // CHECK-NEXT: %6 = load %0[%4, %i1] : memref<10x10xf32>
+ // CHECK-NEXT: %7 = addf %5, %6 : f32
+ // CHECK-NEXT: store %7, %1[%4] : memref<10xf32>
+ // CHECK-NEXT: }
+ // CHECK-NEXT: %8 = load %1[%i0] : memref<10xf32>
+ // CHECK-NEXT: store %8, %2[%i0] : memref<10xf32>
+ // CHECK-NEXT: }
+ // CHECK-NEXT: return
+ return
+}
+
+// -----
+
+// CHECK: [[MAP_SHIFT_MINUS_ONE:#map[0-9]+]] = (d0) -> (d0 - 1)
+// CHECK: [[MAP_SHIFT_BY_ONE:#map[0-9]+]] = (d0, d1) -> (d0 + 1, d1 + 1)
+
+// CHECK-LABEL: mlfunc @should_fuse_loop_nests_with_shifts() {
+mlfunc @should_fuse_loop_nests_with_shifts() {
+ %a = alloc() : memref<10x10xf32>
+ %cf7 = constant 7.0 : f32
+
+ for %i0 = 0 to 10 {
+ for %i1 = 0 to 10 {
+ %a0 = affine_apply (d0, d1) -> (d0 + 1, d1 + 1) (%i0, %i1)
+ store %cf7, %a[%a0#0, %a0#1] : memref<10x10xf32>
+ }
+ }
+ for %i2 = 0 to 10 {
+ for %i3 = 0 to 10 {
+ %v0 = load %a[%i2, %i3] : memref<10x10xf32>
+ }
+ }
+
+ // CHECK: for %i0 = 0 to 10 {
+ // CHECK-NEXT: for %i1 = 0 to 10 {
+ // CHECK-NEXT: %1 = affine_apply [[MAP_SHIFT_MINUS_ONE]](%i0)
+ // CHECK-NEXT: %2 = affine_apply [[MAP_SHIFT_MINUS_ONE]](%i1)
+ // CHECK-NEXT: %3 = affine_apply [[MAP_SHIFT_BY_ONE]](%1, %2)
+ // CHECK-NEXT: store %cst, %0[%3#0, %3#1] : memref<10x10xf32>
+ // CHECK-NEXT: %4 = load %0[%i0, %i1] : memref<10x10xf32>
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: return
- for %i0 = 0 to 100 {
- for %i1 = 0 to 50 {
- %a0 = affine_apply
- (d0, d1)[s0, s1] ->
- (d0 * 2 -d1 + -7 * s0 + 3 , d0 * 9 + d1 * 3 + 13 * s1 - 10)
- (%i0, %i1)[%s0, %s1]
- %c1 = constant 1.0 : f32
- store %c1, %m[%a0#0, %a0#1] : memref<10x10xf32>
+ return
+}
+
+// -----
+
+// CHECK: [[MAP_IDENTITY:#map[0-9]+]] = (d0) -> (d0)
+
+// CHECK-LABEL: mlfunc @should_fuse_loop_nest() {
+mlfunc @should_fuse_loop_nest() {
+ %a = alloc() : memref<10x10xf32>
+ %b = alloc() : memref<10x10xf32>
+ %cf7 = constant 7.0 : f32
+
+ for %i0 = 0 to 10 {
+ for %i1 = 0 to 10 {
+ store %cf7, %a[%i0, %i1] : memref<10x10xf32>
+ }
+ }
+ for %i2 = 0 to 10 {
+ for %i3 = 0 to 10 {
+ %v0 = load %a[%i3, %i2] : memref<10x10xf32>
+ store %v0, %b[%i2, %i3] : memref<10x10xf32>
+ }
+ }
+ for %i4 = 0 to 10 {
+ for %i5 = 0 to 10 {
+ %v1 = load %b[%i4, %i5] : memref<10x10xf32>
}
}
- for %i2 = 0 to 100 {
- for %i3 = 0 to 50 {
- %a1 = affine_apply
- (d0, d1)[s0, s1] ->
- (d0 * 2 - 1, d1 * 3 + s0 + s1 * 3) (%i2, %i3)[%s1, %s2]
- %v0 = load %m[%a1#0, %a1#1] : memref<10x10xf32>
- }
+ // CHECK: for %i0 = 0 to 10 {
+ // CHECK-NEXT: for %i1 = 0 to 10 {
+ // CHECK-NEXT: %2 = affine_apply [[MAP_IDENTITY]](%i1)
+ // CHECK-NEXT: %3 = affine_apply [[MAP_IDENTITY]](%i0)
+ // CHECK-NEXT: store %cst, %0[%2, %3] : memref<10x10xf32>
+ // CHECK-NEXT: %4 = affine_apply [[MAP_IDENTITY]](%i0)
+ // CHECK-NEXT: %5 = affine_apply [[MAP_IDENTITY]](%i1)
+ // CHECK-NEXT: %6 = load %0[%5, %4] : memref<10x10xf32>
+ // CHECK-NEXT: store %6, %1[%4, %5] : memref<10x10xf32>
+ // CHECK-NEXT: %7 = load %1[%i0, %i1] : memref<10x10xf32>
+ // CHECK-NEXT: }
+ // CHECK-NEXT: }
+ // CHECK-NEXT: return
+ return
+}
+
+// -----
+
+// CHECK: [[MAP0:#map[0-9]+]] = (d0) -> (d0)
+
+// CHECK-LABEL: mlfunc @should_fuse_across_intermediate_loop_with_no_deps() {
+mlfunc @should_fuse_across_intermediate_loop_with_no_deps() {
+ %a = alloc() : memref<10xf32>
+ %b = alloc() : memref<10xf32>
+ %c = alloc() : memref<10xf32>
+
+ %cf7 = constant 7.0 : f32
+
+ for %i0 = 0 to 10 {
+ %v0 = load %a[%i0] : memref<10xf32>
+ store %v0, %b[%i0] : memref<10xf32>
}
+ for %i1 = 0 to 10 {
+ store %cf7, %c[%i1] : memref<10xf32>
+ }
+ for %i2 = 0 to 10 {
+ %v1 = load %b[%i2] : memref<10xf32>
+ }
+
+ // Should fuse first loop (past second loop with no dependences) into third.
+ // CHECK: for %i0 = 0 to 10 {
+ // CHECK-NEXT: store %cst, %2[%i0] : memref<10xf32>
+ // CHECK-NEXT: }
+ // CHECK: for %i1 = 0 to 10 {
+ // CHECK-NEXT: %3 = affine_apply [[MAP0]](%i1)
+ // CHECK-NEXT: %4 = load %0[%3] : memref<10xf32>
+ // CHECK-NEXT: store %4, %1[%3] : memref<10xf32>
+ // CHECK-NEXT: %5 = load %1[%i1] : memref<10xf32>
+ // CHECK-NEXT: }
+ // CHECK-NEXT: return
+ return
+}
+
+// -----
+
+// CHECK: [[MAP0:#map[0-9]+]] = (d0) -> (d0)
+
+// CHECK-LABEL: mlfunc @should_fuse_all_loops() {
+mlfunc @should_fuse_all_loops() {
+ %a = alloc() : memref<10xf32>
+ %b = alloc() : memref<10xf32>
+ %cf7 = constant 7.0 : f32
+
+ // Set up flow dependences from first and second loops to third.
+ for %i0 = 0 to 10 {
+ store %cf7, %a[%i0] : memref<10xf32>
+ }
+ for %i1 = 0 to 10 {
+ store %cf7, %b[%i1] : memref<10xf32>
+ }
+ for %i2 = 0 to 10 {
+ %v0 = load %a[%i2] : memref<10xf32>
+ %v1 = load %b[%i2] : memref<10xf32>
+ }
+
+ // Should fuse first and second loops into third.
+ // CHECK: for %i0 = 0 to 10 {
+ // CHECK-NEXT: %2 = affine_apply [[MAP0]](%i0)
+ // CHECK-NEXT: store %cst, %0[%2] : memref<10xf32>
+ // CHECK-NEXT: %3 = affine_apply [[MAP0]](%i0)
+ // CHECK-NEXT: store %cst, %1[%3] : memref<10xf32>
+ // CHECK-NEXT: %4 = load %0[%i0] : memref<10xf32>
+ // CHECK-NEXT: %5 = load %1[%i0] : memref<10xf32>
+ // CHECK-NEXT: }
+ // CHECK-NEXT: return
+ return
+}
+
+// -----
+
+// CHECK: [[MAP0:#map[0-9]+]] = (d0) -> (d0)
+
+// CHECK-LABEL: mlfunc @should_fuse_first_and_second_loops() {
+mlfunc @should_fuse_first_and_second_loops() {
+ %a = alloc() : memref<10xf32>
+ %b = alloc() : memref<10xf32>
+ %c = alloc() : memref<10xf32>
+
+ %cf7 = constant 7.0 : f32
+
+ for %i0 = 0 to 10 {
+ store %cf7, %a[%i0] : memref<10xf32>
+ }
+ for %i1 = 0 to 10 {
+ %v0 = load %a[%i1] : memref<10xf32>
+ store %cf7, %b[%i1] : memref<10xf32>
+ }
+ for %i2 = 0 to 10 {
+ %v1 = load %c[%i2] : memref<10xf32>
+ }
+
+ // Should fuse first loop into the second (last loop should not be fused).
+ // CHECK: for %i0 = 0 to 10 {
+ // CHECK-NEXT: %3 = affine_apply [[MAP0]](%i0)
+ // CHECK-NEXT: store %cst, %0[%3] : memref<10xf32>
+ // CHECK-NEXT: %4 = load %0[%i0] : memref<10xf32>
+ // CHECK-NEXT: store %cst, %1[%i0] : memref<10xf32>
+ // CHECK-NEXT: }
+ // CHECK: for %i1 = 0 to 10 {
+ // CHECK-NEXT: %5 = load %2[%i1] : memref<10xf32>
+ // CHECK-NEXT: }
+ // CHECK-NEXT: return
return
-}
\ No newline at end of file
+}
+
+// -----
+
+// CHECK-LABEL: mlfunc @should_not_fuse_would_create_cycle() {
+mlfunc @should_not_fuse_would_create_cycle() {
+ %a = alloc() : memref<10xf32>
+ %b = alloc() : memref<10xf32>
+ %c = alloc() : memref<10xf32>
+
+ %cf7 = constant 7.0 : f32
+
+ // Set up the following dependences:
+ // 1) loop0 -> loop1 on memref '%a'
+ // 2) loop0 -> loop2 on memref '%b'
+ // 3) loop1 -> loop2 on memref '%c'
+ for %i0 = 0 to 10 {
+ %v0 = load %a[%i0] : memref<10xf32>
+ store %cf7, %b[%i0] : memref<10xf32>
+ }
+ for %i1 = 0 to 10 {
+ store %cf7, %a[%i1] : memref<10xf32>
+ %v1 = load %c[%i1] : memref<10xf32>
+ }
+ for %i2 = 0 to 10 {
+ %v2 = load %b[%i2] : memref<10xf32>
+ store %cf7, %c[%i2] : memref<10xf32>
+ }
+ // Should not fuse: fusing loop first loop into last would create a cycle.
+ // CHECK: for %i0 = 0 to 10 {
+ // CHECK-NEXT: %3 = load %0[%i0] : memref<10xf32>
+ // CHECK-NEXT: store %cst, %1[%i0] : memref<10xf32>
+ // CHECK-NEXT: }
+ // CHECK: for %i1 = 0 to 10 {
+ // CHECK-NEXT: store %cst, %0[%i1] : memref<10xf32>
+ // CHECK-NEXT: %4 = load %2[%i1] : memref<10xf32>
+ // CHECK-NEXT: }
+ // CHECK: for %i2 = 0 to 10 {
+ // CHECK-NEXT: %5 = load %1[%i2] : memref<10xf32>
+ // CHECK-NEXT: store %cst, %2[%i2] : memref<10xf32>
+ // CHECK-NEXT: }
+ // CHECK-NEXT: return
+ return
+}
+
+// -----
+
+// CHECK-LABEL: mlfunc @should_not_fuse_raw_dep_would_be_violated() {
+mlfunc @should_not_fuse_raw_dep_would_be_violated() {
+ %m = alloc() : memref<10xf32>
+ %cf7 = constant 7.0 : f32
+
+ for %i0 = 0 to 10 {
+ store %cf7, %m[%i0] : memref<10xf32>
+ }
+ for %i1 = 0 to 10 {
+ %v0 = load %m[%i1] : memref<10xf32>
+ }
+ for %i2 = 0 to 10 {
+ %v1 = load %m[%i2] : memref<10xf32>
+ }
+ // Fusing loop %i0 to %i2 would violate the RAW dependence between %i0 and %i1
+ // CHECK: for %i0 = 0 to 10 {
+ // CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32>
+ // CHECK-NEXT: }
+ // CHECK: for %i1 = 0 to 10 {
+ // CHECK-NEXT: %1 = load %0[%i1] : memref<10xf32>
+ // CHECK-NEXT: }
+ // CHECK: for %i2 = 0 to 10 {
+ // CHECK-NEXT: %2 = load %0[%i2] : memref<10xf32>
+ // CHECK-NEXT: }
+ // CHECK-NEXT: return
+ return
+}
+
+// -----
+
+// CHECK-LABEL: mlfunc @should_not_fuse_waw_dep_would_be_violated() {
+mlfunc @should_not_fuse_waw_dep_would_be_violated() {
+ %m = alloc() : memref<10xf32>
+ %cf7 = constant 7.0 : f32
+
+ for %i0 = 0 to 10 {
+ store %cf7, %m[%i0] : memref<10xf32>
+ }
+ for %i1 = 0 to 10 {
+ store %cf7, %m[%i1] : memref<10xf32>
+ }
+ for %i2 = 0 to 10 {
+ %v1 = load %m[%i2] : memref<10xf32>
+ }
+ // Fusing loop %i0 to %i2 would violate the WAW dependence between %i0 and %i1
+ // CHECK: for %i0 = 0 to 10 {
+ // CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32>
+ // CHECK-NEXT: }
+ // CHECK: for %i1 = 0 to 10 {
+ // CHECK-NEXT: store %cst, %0[%i1] : memref<10xf32>
+ // CHECK-NEXT: }
+ // CHECK: for %i2 = 0 to 10 {
+ // CHECK-NEXT: %1 = load %0[%i2] : memref<10xf32>
+ // CHECK-NEXT: }
+ // CHECK-NEXT: return
+ return
+}
+
+// -----
+
+// CHECK-LABEL: mlfunc @should_not_fuse_war_dep_would_be_violated() {
+mlfunc @should_not_fuse_war_dep_would_be_violated() {
+ %a = alloc() : memref<10xf32>
+ %b = alloc() : memref<10xf32>
+ %cf7 = constant 7.0 : f32
+
+ for %i0 = 0 to 10 {
+ %v0 = load %a[%i0] : memref<10xf32>
+ store %v0, %b[%i0] : memref<10xf32>
+ }
+ for %i1 = 0 to 10 {
+ store %cf7, %a[%i1] : memref<10xf32>
+ }
+ for %i2 = 0 to 10 {
+ %v1 = load %b[%i2] : memref<10xf32>
+ }
+ // Fusing loop %i0 to %i2 would violate the WAR dependence between %i0 and %i1
+ // CHECK: for %i0 = 0 to 10 {
+ // CHECK-NEXT: %2 = load %0[%i0] : memref<10xf32>
+ // CHECK-NEXT: store %2, %1[%i0] : memref<10xf32>
+ // CHECK-NEXT: }
+ // CHECK: for %i1 = 0 to 10 {
+ // CHECK-NEXT: store %cst, %0[%i1] : memref<10xf32>
+ // CHECK-NEXT: }
+ // CHECK: for %i2 = 0 to 10 {
+ // CHECK-NEXT: %3 = load %1[%i2] : memref<10xf32>
+ // CHECK-NEXT: }
+ // CHECK-NEXT: return
+ return
+}
+
+// -----
+
+// CHECK-LABEL: mlfunc @should_not_fuse_if_top_level_access() {
+mlfunc @should_not_fuse_if_top_level_access() {
+ %m = alloc() : memref<10xf32>
+ %cf7 = constant 7.0 : f32
+
+ for %i0 = 0 to 10 {
+ store %cf7, %m[%i0] : memref<10xf32>
+ }
+ for %i1 = 0 to 10 {
+ %v0 = load %m[%i1] : memref<10xf32>
+ }
+
+ %c0 = constant 4 : index
+ %v1 = load %m[%c0] : memref<10xf32>
+ // Top-level load to '%m' should prevent fusion.
+ // CHECK: for %i0 = 0 to 10 {
+ // CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32>
+ // CHECK-NEXT: }
+ // CHECK: for %i1 = 0 to 10 {
+ // CHECK-NEXT: %1 = load %0[%i1] : memref<10xf32>
+ // CHECK-NEXT: }
+ return
+}
+
+// -----
+
+// CHECK: [[MAP0:#map[0-9]+]] = (d0) -> (d0)
+
+// CHECK-LABEL: mlfunc @should_fuse_no_top_level_access() {
+mlfunc @should_fuse_no_top_level_access() {
+ %m = alloc() : memref<10xf32>
+ %cf7 = constant 7.0 : f32
+
+ for %i0 = 0 to 10 {
+ store %cf7, %m[%i0] : memref<10xf32>
+ }
+ for %i1 = 0 to 10 {
+ %v0 = load %m[%i1] : memref<10xf32>
+ }
+ // CHECK: for %i0 = 0 to 10 {
+ // CHECK-NEXT: %1 = affine_apply #map0(%i0)
+ // CHECK-NEXT: store %cst, %0[%1] : memref<10xf32>
+ // CHECK-NEXT: %2 = load %0[%i0] : memref<10xf32>
+ // CHECK-NEXT: }
+ // CHECK-NEXT: return
+ return
+}
+
+// -----
+
+#set0 = (d0) : (1 == 0)
+
+// CHECK-LABEL: mlfunc @should_not_fuse_if_stmt_at_top_level() {
+mlfunc @should_not_fuse_if_stmt_at_top_level() {
+ %m = alloc() : memref<10xf32>
+ %cf7 = constant 7.0 : f32
+
+ for %i0 = 0 to 10 {
+ store %cf7, %m[%i0] : memref<10xf32>
+ }
+ for %i1 = 0 to 10 {
+ %v0 = load %m[%i1] : memref<10xf32>
+ }
+ %c0 = constant 4 : index
+ if #set0(%c0) {
+ }
+ // Top-level IfStmt should prevent fusion.
+ // CHECK: for %i0 = 0 to 10 {
+ // CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32>
+ // CHECK-NEXT: }
+ // CHECK: for %i1 = 0 to 10 {
+ // CHECK-NEXT: %1 = load %0[%i1] : memref<10xf32>
+ // CHECK-NEXT: }
+ return
+}
+
+// -----
+
+#set0 = (d0) : (1 == 0)
+
+// CHECK-LABEL: mlfunc @should_not_fuse_if_stmt_in_loop_nest() {
+mlfunc @should_not_fuse_if_stmt_in_loop_nest() {
+ %m = alloc() : memref<10xf32>
+ %cf7 = constant 7.0 : f32
+ %c4 = constant 4 : index
+
+ for %i0 = 0 to 10 {
+ store %cf7, %m[%i0] : memref<10xf32>
+ }
+ for %i1 = 0 to 10 {
+ if #set0(%c4) {
+ }
+ %v0 = load %m[%i1] : memref<10xf32>
+ }
+
+ // IfStmt in ForStmt should prevent fusion.
+ // CHECK: for %i0 = 0 to 10 {
+ // CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32>
+ // CHECK-NEXT: }
+ // CHECK: for %i1 = 0 to 10 {
+ // CHECK-NEXT: if #set0(%c4) {
+ // CHECK-NEXT: }
+ // CHECK-NEXT: %1 = load %0[%i1] : memref<10xf32>
+ // CHECK-NEXT: }
+ return
+}
+
+// -----
+
+// CHECK: [[MAP0:#map[0-9]+]] = (d0) -> (d0)
+// CHECK: [[MAP1:#map[0-9]+]] = (d0, d1, d2) -> (d0, d1, d2)
+// CHECK: [[MAP2:#map[0-9]+]] = (d0, d1, d2) -> (d1, d2, d0)
+
+// CHECK-LABEL: mlfunc @remap_ivs() {
+mlfunc @remap_ivs() {
+ %m = alloc() : memref<10x20x30xf32>
+
+ %cf7 = constant 7.0 : f32
+ for %i0 = 0 to 10 {
+ for %i1 = 0 to 20 {
+ for %i2 = 0 to 30 {
+ %a0 = affine_apply (d0, d1, d2) -> (d0, d1, d2) (%i0, %i1, %i2)
+ store %cf7, %m[%a0#0, %a0#1, %a0#2] : memref<10x20x30xf32>
+ }
+ }
+ }
+ for %i3 = 0 to 30 {
+ for %i4 = 0 to 10 {
+ for %i5 = 0 to 20 {
+ %a1 = affine_apply (d0, d1, d2) -> (d1, d2, d0) (%i3, %i4, %i5)
+ %v0 = load %m[%a1#0, %a1#1, %a1#2] : memref<10x20x30xf32>
+ }
+ }
+ }
+// CHECK: for %i0 = 0 to 30 {
+// CHECK-NEXT: for %i1 = 0 to 10 {
+// CHECK-NEXT: for %i2 = 0 to 20 {
+// CHECK-NEXT: %1 = affine_apply [[MAP0]](%i1)
+// CHECK-NEXT: %2 = affine_apply [[MAP0]](%i2)
+// CHECK-NEXT: %3 = affine_apply [[MAP0]](%i0)
+// CHECK-NEXT: %4 = affine_apply [[MAP1]](%1, %2, %3)
+// CHECK-NEXT: store %cst, %0[%4#0, %4#1, %4#2] : memref<10x20x30xf32>
+// CHECK-NEXT: %5 = affine_apply [[MAP2]](%i0, %i1, %i2)
+// CHECK-NEXT: %6 = load %0[%5#0, %5#1, %5#2] : memref<10x20x30xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: return
+
+ return
+}