Integer set + operands / affine if op canonicalization
- turn canonicalizeMapAndOperands into a template that works on both
sets and maps, and use it to introduce a utility to canonicalize an
affine integer set and its operands
- add pattern to canonicalize affine if op's.
- rename IntegerSet::getNumOperands -> IntegerSet::getNumInputs to be
consistent with AffineMap
- add missing accessors for IntegerSet
Doesn't need extensive testing since canonicalizeSetAndOperands just
reuses canonicalizeMapAndOperands' logic, and the latter is tested on
affine.apply map + operands; the new method works the same way on an
integer set + operands of an affine if op for example.
Signed-off-by: Uday Bondhugula <uday@polymagelabs.com>
Closes #112
COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/112 from bondhugula:set-canonicalize eff72f23250b96fa7d9f5caff3877440f5de2cec
PiperOrigin-RevId: 267532876
diff --git a/third_party/mlir/include/mlir/Dialect/AffineOps/AffineOps.h b/third_party/mlir/include/mlir/Dialect/AffineOps/AffineOps.h
index a6af20e..03b945c 100644
--- a/third_party/mlir/include/mlir/Dialect/AffineOps/AffineOps.h
+++ b/third_party/mlir/include/mlir/Dialect/AffineOps/AffineOps.h
@@ -522,6 +522,10 @@
/// 2. drop unused dims and symbols from map
void canonicalizeMapAndOperands(AffineMap *map,
llvm::SmallVectorImpl<Value *> *operands);
+/// Canonicalizes an integer set the same way canonicalizeMapAndOperands does
+/// for affine maps.
+void canonicalizeSetAndOperands(IntegerSet *set,
+ llvm::SmallVectorImpl<Value *> *operands);
/// Returns a composed AffineApplyOp by composing `map` and `operands` with
/// other AffineApplyOps supplying those operands. The operands of the resulting
diff --git a/third_party/mlir/include/mlir/Dialect/AffineOps/AffineOps.td b/third_party/mlir/include/mlir/Dialect/AffineOps/AffineOps.td
index 237692c..4961ce8 100644
--- a/third_party/mlir/include/mlir/Dialect/AffineOps/AffineOps.td
+++ b/third_party/mlir/include/mlir/Dialect/AffineOps/AffineOps.td
@@ -223,6 +223,11 @@
IntegerSet getIntegerSet();
void setIntegerSet(IntegerSet newSet);
+ /// Sets the integer set with its operands. The size of 'operands' must not
+ /// exceed the current number of operands for this instance, as the operands
+ /// list of AffineIf is not resizable.
+ void setConditional(IntegerSet set, ArrayRef<Value *> operands);
+
OpBuilder getThenBodyBuilder() {
assert(!thenRegion().empty() && "Unexpected empty 'then' region.");
Block &body = thenRegion().front();
@@ -234,6 +239,8 @@
return OpBuilder(&body, std::prev(body.end()));
}
}];
+
+ let hasCanonicalizer = 1;
}
def AffineTerminatorOp :
diff --git a/third_party/mlir/include/mlir/IR/IntegerSet.h b/third_party/mlir/include/mlir/IR/IntegerSet.h
index b7662f0..e989f91 100644
--- a/third_party/mlir/include/mlir/IR/IntegerSet.h
+++ b/third_party/mlir/include/mlir/IR/IntegerSet.h
@@ -72,12 +72,22 @@
/// Returns true if this is the canonical integer set.
bool isEmptyIntegerSet() const;
+ /// This method substitutes any uses of dimensions and symbols (e.g.
+ /// dim#0 with dimReplacements[0]) in subexpressions and returns the modified
+ /// integer set. Because this can be used to eliminate dims and
+ /// symbols, the client needs to specify the number of dims and symbols in
+ /// the result. The returned map always has the same number of results.
+ IntegerSet replaceDimsAndSymbols(ArrayRef<AffineExpr> dimReplacements,
+ ArrayRef<AffineExpr> symReplacements,
+ unsigned numResultDims,
+ unsigned numResultSyms);
+
explicit operator bool() { return set; }
bool operator==(IntegerSet other) const { return set == other.set; }
unsigned getNumDims() const;
unsigned getNumSymbols() const;
- unsigned getNumOperands() const;
+ unsigned getNumInputs() const;
unsigned getNumConstraints() const;
unsigned getNumEqualities() const;
unsigned getNumInequalities() const;
@@ -96,6 +106,10 @@
MLIRContext *getContext() const;
+ /// Walk all of the AffineExpr's in this set's constraints. Each node in an
+ /// expression tree is visited in postorder.
+ void walkExprs(llvm::function_ref<void(AffineExpr)> callback) const;
+
void print(raw_ostream &os) const;
void dump() const;
diff --git a/third_party/mlir/lib/Analysis/AffineStructures.cpp b/third_party/mlir/lib/Analysis/AffineStructures.cpp
index f660fff..2804ac6 100644
--- a/third_party/mlir/lib/Analysis/AffineStructures.cpp
+++ b/third_party/mlir/lib/Analysis/AffineStructures.cpp
@@ -308,7 +308,7 @@
// Construct from an IntegerSet.
FlatAffineConstraints::FlatAffineConstraints(IntegerSet set)
- : numReservedCols(set.getNumOperands() + 1),
+ : numReservedCols(set.getNumInputs() + 1),
numIds(set.getNumDims() + set.getNumSymbols()), numDims(set.getNumDims()),
numSymbols(set.getNumSymbols()) {
equalities.reserve(set.getNumEqualities() * numReservedCols);
diff --git a/third_party/mlir/lib/Dialect/AffineOps/AffineOps.cpp b/third_party/mlir/lib/Dialect/AffineOps/AffineOps.cpp
index c6abc05..2161ae0 100644
--- a/third_party/mlir/lib/Dialect/AffineOps/AffineOps.cpp
+++ b/third_party/mlir/lib/Dialect/AffineOps/AffineOps.cpp
@@ -620,26 +620,27 @@
// A symbol may appear as a dim in affine.apply operations. This function
// canonicalizes dims that are valid symbols into actual symbols.
+template <class MapOrSet>
static void
-canonicalizePromotedSymbols(AffineMap *map,
+canonicalizePromotedSymbols(MapOrSet *mapOrSet,
llvm::SmallVectorImpl<Value *> *operands) {
- if (!map || operands->empty())
+ if (!mapOrSet || operands->empty())
return;
- assert(map->getNumInputs() == operands->size() &&
- "map inputs must match number of operands");
+ assert(mapOrSet->getNumInputs() == operands->size() &&
+ "map/set inputs must match number of operands");
- auto *context = map->getContext();
+ auto *context = mapOrSet->getContext();
SmallVector<Value *, 8> resultOperands;
resultOperands.reserve(operands->size());
SmallVector<Value *, 8> remappedSymbols;
remappedSymbols.reserve(operands->size());
unsigned nextDim = 0;
unsigned nextSym = 0;
- unsigned oldNumSyms = map->getNumSymbols();
- SmallVector<AffineExpr, 8> dimRemapping(map->getNumDims());
- for (unsigned i = 0, e = map->getNumInputs(); i != e; ++i) {
- if (i < map->getNumDims()) {
+ unsigned oldNumSyms = mapOrSet->getNumSymbols();
+ SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims());
+ for (unsigned i = 0, e = mapOrSet->getNumInputs(); i != e; ++i) {
+ if (i < mapOrSet->getNumDims()) {
if (isValidSymbol((*operands)[i])) {
// This is a valid symbol that appears as a dim, canonicalize it.
dimRemapping[i] = getAffineSymbolExpr(oldNumSyms + nextSym++, context);
@@ -655,42 +656,49 @@
resultOperands.append(remappedSymbols.begin(), remappedSymbols.end());
*operands = resultOperands;
- *map = map->replaceDimsAndSymbols(dimRemapping, {}, nextDim,
- oldNumSyms + nextSym);
+ *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, {}, nextDim,
+ oldNumSyms + nextSym);
- assert(map->getNumInputs() == operands->size() &&
- "map inputs must match number of operands");
+ assert(mapOrSet->getNumInputs() == operands->size() &&
+ "map/set inputs must match number of operands");
}
-void mlir::canonicalizeMapAndOperands(
- AffineMap *map, llvm::SmallVectorImpl<Value *> *operands) {
- if (!map || operands->empty())
+// Works for either an affine map or an integer set.
+template <class MapOrSet>
+static void
+canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet,
+ llvm::SmallVectorImpl<Value *> *operands) {
+ static_assert(std::is_same<MapOrSet, AffineMap>::value ||
+ std::is_same<MapOrSet, IntegerSet>::value,
+ "Argument must be either of AffineMap or IntegerSet type");
+
+ if (!mapOrSet || operands->empty())
return;
- assert(map->getNumInputs() == operands->size() &&
- "map inputs must match number of operands");
+ assert(mapOrSet->getNumInputs() == operands->size() &&
+ "map/set inputs must match number of operands");
- canonicalizePromotedSymbols(map, operands);
+ canonicalizePromotedSymbols<MapOrSet>(mapOrSet, operands);
// Check to see what dims are used.
- llvm::SmallBitVector usedDims(map->getNumDims());
- llvm::SmallBitVector usedSyms(map->getNumSymbols());
- map->walkExprs([&](AffineExpr expr) {
+ llvm::SmallBitVector usedDims(mapOrSet->getNumDims());
+ llvm::SmallBitVector usedSyms(mapOrSet->getNumSymbols());
+ mapOrSet->walkExprs([&](AffineExpr expr) {
if (auto dimExpr = expr.dyn_cast<AffineDimExpr>())
usedDims[dimExpr.getPosition()] = true;
else if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>())
usedSyms[symExpr.getPosition()] = true;
});
- auto *context = map->getContext();
+ auto *context = mapOrSet->getContext();
SmallVector<Value *, 8> resultOperands;
resultOperands.reserve(operands->size());
llvm::SmallDenseMap<Value *, AffineExpr, 8> seenDims;
- SmallVector<AffineExpr, 8> dimRemapping(map->getNumDims());
+ SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims());
unsigned nextDim = 0;
- for (unsigned i = 0, e = map->getNumDims(); i != e; ++i) {
+ for (unsigned i = 0, e = mapOrSet->getNumDims(); i != e; ++i) {
if (usedDims[i]) {
// Remap dim positions for duplicate operands.
auto it = seenDims.find((*operands)[i]);
@@ -704,37 +712,47 @@
}
}
llvm::SmallDenseMap<Value *, AffineExpr, 8> seenSymbols;
- SmallVector<AffineExpr, 8> symRemapping(map->getNumSymbols());
+ SmallVector<AffineExpr, 8> symRemapping(mapOrSet->getNumSymbols());
unsigned nextSym = 0;
- for (unsigned i = 0, e = map->getNumSymbols(); i != e; ++i) {
+ for (unsigned i = 0, e = mapOrSet->getNumSymbols(); i != e; ++i) {
if (!usedSyms[i])
continue;
// Handle constant operands (only needed for symbolic operands since
// constant operands in dimensional positions would have already been
// promoted to symbolic positions above).
IntegerAttr operandCst;
- if (matchPattern((*operands)[i + map->getNumDims()],
+ if (matchPattern((*operands)[i + mapOrSet->getNumDims()],
m_Constant(&operandCst))) {
symRemapping[i] =
getAffineConstantExpr(operandCst.getValue().getSExtValue(), context);
continue;
}
// Remap symbol positions for duplicate operands.
- auto it = seenSymbols.find((*operands)[i + map->getNumDims()]);
+ auto it = seenSymbols.find((*operands)[i + mapOrSet->getNumDims()]);
if (it == seenSymbols.end()) {
symRemapping[i] = getAffineSymbolExpr(nextSym++, context);
- resultOperands.push_back((*operands)[i + map->getNumDims()]);
- seenSymbols.insert(
- std::make_pair((*operands)[i + map->getNumDims()], symRemapping[i]));
+ resultOperands.push_back((*operands)[i + mapOrSet->getNumDims()]);
+ seenSymbols.insert(std::make_pair((*operands)[i + mapOrSet->getNumDims()],
+ symRemapping[i]));
} else {
symRemapping[i] = it->second;
}
}
- *map =
- map->replaceDimsAndSymbols(dimRemapping, symRemapping, nextDim, nextSym);
+ *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, symRemapping,
+ nextDim, nextSym);
*operands = resultOperands;
}
+void mlir::canonicalizeMapAndOperands(
+ AffineMap *map, llvm::SmallVectorImpl<Value *> *operands) {
+ canonicalizeMapOrSetAndOperands<AffineMap>(map, operands);
+}
+
+void mlir::canonicalizeSetAndOperands(
+ IntegerSet *set, llvm::SmallVectorImpl<Value *> *operands) {
+ canonicalizeMapOrSetAndOperands<IntegerSet>(set, operands);
+}
+
namespace {
/// Simplify AffineApply operations.
///
@@ -1540,7 +1558,7 @@
// Verify that there are enough operands for the condition.
IntegerSet condition = conditionAttr.getValue();
- if (op.getNumOperands() != condition.getNumOperands())
+ if (op.getNumOperands() != condition.getNumInputs())
return op.emitOpError(
"operand count and condition integer set dimension and "
"symbol count must match");
@@ -1639,6 +1657,44 @@
setAttr(getConditionAttrName(), IntegerSetAttr::get(newSet));
}
+void AffineIfOp::setConditional(IntegerSet set, ArrayRef<Value *> operands) {
+ setIntegerSet(set);
+ getOperation()->setOperands(operands);
+}
+
+namespace {
+// This is a pattern to canonicalize an affine if op's conditional (integer
+// set + operands).
+struct AffineIfOpCanonicalizer : public OpRewritePattern<AffineIfOp> {
+ using OpRewritePattern<AffineIfOp>::OpRewritePattern;
+
+ PatternMatchResult matchAndRewrite(AffineIfOp ifOp,
+ PatternRewriter &rewriter) const override {
+ auto set = ifOp.getIntegerSet();
+ SmallVector<Value *, 4> operands(ifOp.getOperands());
+
+ canonicalizeSetAndOperands(&set, &operands);
+
+ // Any canonicalization change always leads to either a reduction in the
+ // number of operands or a change in the number of symbolic operands
+ // (promotion of dims to symbols).
+ if (operands.size() < ifOp.getIntegerSet().getNumInputs() ||
+ set.getNumSymbols() > ifOp.getIntegerSet().getNumSymbols()) {
+ ifOp.setConditional(set, operands);
+ rewriter.updatedRootInPlace(ifOp);
+ return matchSuccess();
+ }
+
+ return matchFailure();
+ }
+};
+} // end anonymous namespace
+
+void AffineIfOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+ MLIRContext *context) {
+ results.insert<AffineIfOpCanonicalizer>(context);
+}
+
//===----------------------------------------------------------------------===//
// AffineLoadOp
//===----------------------------------------------------------------------===//
diff --git a/third_party/mlir/lib/IR/IntegerSet.cpp b/third_party/mlir/lib/IR/IntegerSet.cpp
index 74a1297..139ca50 100644
--- a/third_party/mlir/lib/IR/IntegerSet.cpp
+++ b/third_party/mlir/lib/IR/IntegerSet.cpp
@@ -24,7 +24,7 @@
unsigned IntegerSet::getNumDims() const { return set->dimCount; }
unsigned IntegerSet::getNumSymbols() const { return set->symbolCount; }
-unsigned IntegerSet::getNumOperands() const {
+unsigned IntegerSet::getNumInputs() const {
return set->dimCount + set->symbolCount;
}
@@ -70,3 +70,23 @@
MLIRContext *IntegerSet::getContext() const {
return getConstraint(0).getContext();
}
+
+/// Walk all of the AffineExpr's in this set. Each node in an expression
+/// tree is visited in postorder.
+void IntegerSet::walkExprs(
+ llvm::function_ref<void(AffineExpr)> callback) const {
+ for (auto expr : getConstraints())
+ expr.walk(callback);
+}
+
+IntegerSet IntegerSet::replaceDimsAndSymbols(
+ ArrayRef<AffineExpr> dimReplacements, ArrayRef<AffineExpr> symReplacements,
+ unsigned numResultDims, unsigned numResultSyms) {
+ SmallVector<AffineExpr, 8> constraints;
+ constraints.reserve(getNumConstraints());
+ for (auto cst : getConstraints())
+ constraints.push_back(
+ cst.replaceDimsAndSymbols(dimReplacements, symReplacements));
+
+ return get(numResultDims, numResultSyms, constraints, getEqFlags());
+}