Adds affine.min operation which returns the minimum value from a multi-result affine map. This operation is useful for things like computing the dynamic value of affine loop bounds, and is trivial to constant fold.
PiperOrigin-RevId: 279959714
Change-Id: I11698cd11c3fe192b72d966925305bc5ac51a536
diff --git a/third_party/mlir/g3doc/Dialects/Affine.md b/third_party/mlir/g3doc/Dialects/Affine.md
index 541daf8..6049457 100644
--- a/third_party/mlir/g3doc/Dialects/Affine.md
+++ b/third_party/mlir/g3doc/Dialects/Affine.md
@@ -560,6 +560,29 @@
```
+#### 'affine.min' operation
+
+Syntax:
+
+``` {.ebnf}
+operation ::= ssa-id `=` `affine.min` affine-map dim-and-symbol-use-list
+```
+
+The `affine.min` operation applies an
+[affine mapping](#affine-expressions) to a list of SSA values, and returns the
+minimum value of all result expressions. The number of dimension and symbol
+arguments to affine.min must be equal to the respective number of dimensional
+and symbolic inputs to the affine mapping; the `affine.min` operation always
+returns one value. The input operands and result must all have 'index' type.
+
+Example:
+
+```mlir {.mlir}
+
+%0 = affine.min (d0)[s0] -> (1000, d0 + 512, s0) (%arg0)[%arg1]
+
+```
+
#### `affine.terminator` operation
Syntax:
diff --git a/third_party/mlir/include/mlir/Dialect/AffineOps/AffineOps.td b/third_party/mlir/include/mlir/Dialect/AffineOps/AffineOps.td
index f54c514..1b6c777 100644
--- a/third_party/mlir/include/mlir/Dialect/AffineOps/AffineOps.td
+++ b/third_party/mlir/include/mlir/Dialect/AffineOps/AffineOps.td
@@ -248,6 +248,24 @@
let hasCanonicalizer = 1;
}
+def AffineMinOp : Affine_Op<"min"> {
+ let summary = "min operation";
+ let description = [{
+ The "min" operation computes the minimum value result from a multi-result
+ affine map.
+
+ Example:
+
+ %0 = affine.min (d0) -> (1000, d0 + 512) (%i0) : index
+ }];
+ let arguments = (ins AffineMapAttr:$map, Variadic<Index>:$operands);
+ let results = (outs Index);
+ let extraClassDeclaration = [{
+ static StringRef getMapAttrName() { return "map"; }
+ }];
+ let hasFolder = 1;
+}
+
def AffineTerminatorOp :
Affine_Op<"terminator", [Terminator]> {
let summary = "affine terminator operation";
diff --git a/third_party/mlir/lib/Dialect/AffineOps/AffineOps.cpp b/third_party/mlir/lib/Dialect/AffineOps/AffineOps.cpp
index acec1dd..77ee9cf 100644
--- a/third_party/mlir/lib/Dialect/AffineOps/AffineOps.cpp
+++ b/third_party/mlir/lib/Dialect/AffineOps/AffineOps.cpp
@@ -1937,5 +1937,80 @@
results.insert<SimplifyAffineOp<AffineStoreOp>>(context);
}
+//===----------------------------------------------------------------------===//
+// AffineMinOp
+//===----------------------------------------------------------------------===//
+//
+// %0 = affine.min (d0) -> (1000, d0 + 512) (%i0)
+//
+
+static ParseResult parseAffineMinOp(OpAsmParser &parser,
+ OperationState &result) {
+ auto &builder = parser.getBuilder();
+ auto indexType = builder.getIndexType();
+ SmallVector<OpAsmParser::OperandType, 8> dim_infos;
+ SmallVector<OpAsmParser::OperandType, 8> sym_infos;
+ AffineMapAttr mapAttr;
+ return failure(
+ parser.parseAttribute(mapAttr, AffineMinOp::getMapAttrName(),
+ result.attributes) ||
+ parser.parseOperandList(dim_infos, OpAsmParser::Delimiter::Paren) ||
+ parser.parseOperandList(sym_infos,
+ OpAsmParser::Delimiter::OptionalSquare) ||
+ parser.parseOptionalAttrDict(result.attributes) ||
+ parser.resolveOperands(dim_infos, indexType, result.operands) ||
+ parser.resolveOperands(sym_infos, indexType, result.operands) ||
+ parser.addTypeToList(indexType, result.types));
+}
+
+static void print(OpAsmPrinter &p, AffineMinOp op) {
+ p << op.getOperationName() << ' '
+ << op.getAttr(AffineMinOp::getMapAttrName());
+ auto begin = op.operand_begin();
+ auto end = op.operand_end();
+ unsigned numDims = op.map().getNumDims();
+ p << '(';
+ p.printOperands(begin, begin + numDims);
+ p << ')';
+
+ if (begin + numDims != end) {
+ p << '[';
+ p.printOperands(begin + numDims, end);
+ p << ']';
+ }
+ p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"map"});
+}
+
+static LogicalResult verify(AffineMinOp op) {
+ // Verify that operand count matches affine map dimension and symbol count.
+ if (op.getNumOperands() != op.map().getNumDims() + op.map().getNumSymbols())
+ return op.emitOpError(
+ "operand count and affine map dimension and symbol count must match");
+ return success();
+}
+
+OpFoldResult AffineMinOp::fold(ArrayRef<Attribute> operands) {
+ // Fold the affine map.
+ // TODO(andydavis, ntv) Fold more cases: partial static information,
+ // min(some_affine, some_affine + constant, ...).
+ SmallVector<Attribute, 2> results;
+ if (failed(map().constantFold(operands, results)))
+ return {};
+
+ // Compute and return min of folded map results.
+ int64_t min = std::numeric_limits<int64_t>::max();
+ int minIndex = -1;
+ for (unsigned i = 0, e = results.size(); i < e; ++i) {
+ auto intAttr = results[i].cast<IntegerAttr>();
+ if (intAttr.getInt() < min) {
+ min = intAttr.getInt();
+ minIndex = i;
+ }
+ }
+ if (minIndex < 0)
+ return {};
+ return results[minIndex];
+}
+
#define GET_OP_CLASSES
#include "mlir/Dialect/AffineOps/AffineOps.cpp.inc"