Update the constantFold/fold API to use LogicalResult instead of bool.
PiperOrigin-RevId: 237719658
diff --git a/include/mlir/IR/AffineMap.h b/include/mlir/IR/AffineMap.h
index 87411cf..41aefba 100644
--- a/include/mlir/IR/AffineMap.h
+++ b/include/mlir/IR/AffineMap.h
@@ -35,6 +35,7 @@
class AffineExpr;
class Attribute;
+struct LogicalResult;
class MLIRContext;
/// A multi-dimensional affine map
@@ -115,10 +116,9 @@
unsigned numResultSyms);
/// Folds the results of the application of an affine map on the provided
- /// operands to a constant if possible. Returns false if the folding happens,
- /// true otherwise.
- bool constantFold(ArrayRef<Attribute> operandConstants,
- SmallVectorImpl<Attribute> &results) const;
+ /// operands to a constant if possible.
+ LogicalResult constantFold(ArrayRef<Attribute> operandConstants,
+ SmallVectorImpl<Attribute> &results) const;
/// Returns the AffineMap resulting from composing `this` with `map`.
/// The resulting AffineMap has as many AffineDimExpr as `map` and as many
diff --git a/include/mlir/IR/Dialect.h b/include/mlir/IR/Dialect.h
index 7577136..1994c11 100644
--- a/include/mlir/IR/Dialect.h
+++ b/include/mlir/IR/Dialect.h
@@ -31,7 +31,7 @@
using DialectConstantDecodeHook =
std::function<bool(const OpaqueElementsAttr, ElementsAttr &)>;
-using DialectConstantFoldHook = std::function<bool(
+using DialectConstantFoldHook = std::function<LogicalResult(
const Instruction *, ArrayRef<Attribute>, SmallVectorImpl<Attribute> &)>;
using DialectExtractElementHook =
std::function<Attribute(const OpaqueElementsAttr, ArrayRef<uint64_t>)>;
@@ -53,12 +53,12 @@
/// fold hook of each operation, it attempts to constant fold the operation
/// with the specified constant operand values - the elements in "operands"
/// will correspond directly to the operands of the operation, but may be null
- /// if non-constant. If constant folding is successful, this returns false
- /// and fills in the `results` vector. If not, this returns true and
- /// `results` is unspecified.
+ /// if non-constant. If constant folding is successful, this fills in the
+ /// `results` vector. If not, this returns failure and `results` is
+ /// unspecified.
DialectConstantFoldHook constantFoldHook =
[](const Instruction *op, ArrayRef<Attribute> operands,
- SmallVectorImpl<Attribute> &results) { return true; };
+ SmallVectorImpl<Attribute> &results) { return failure(); };
/// Registered hook to decode opaque constants associated with this
/// dialect. The hook function attempts to decode an opaque constant tensor
diff --git a/include/mlir/IR/Instruction.h b/include/mlir/IR/Instruction.h
index 1f3c881..f9a3ac0 100644
--- a/include/mlir/IR/Instruction.h
+++ b/include/mlir/IR/Instruction.h
@@ -415,13 +415,13 @@
/// Attempt to constant fold this operation with the specified constant
/// operand values - the elements in "operands" will correspond directly to
/// the operands of the operation, but may be null if non-constant. If
- /// constant folding is successful, this returns false and fills in the
- /// `results` vector. If not, this returns true and `results` is unspecified.
- bool constantFold(ArrayRef<Attribute> operands,
- SmallVectorImpl<Attribute> &results) const;
+ /// constant folding is successful, this fills in the `results` vector. If
+ /// not, `results` is unspecified.
+ LogicalResult constantFold(ArrayRef<Attribute> operands,
+ SmallVectorImpl<Attribute> &results) const;
/// Attempt to fold this operation using the Op's registered foldHook.
- bool fold(SmallVectorImpl<Value *> &results);
+ LogicalResult fold(SmallVectorImpl<Value *> &results);
//===--------------------------------------------------------------------===//
// Conversions to declared operations like DimOp
diff --git a/include/mlir/IR/Matchers.h b/include/mlir/IR/Matchers.h
index b6105e3..d8b913b 100644
--- a/include/mlir/IR/Matchers.h
+++ b/include/mlir/IR/Matchers.h
@@ -73,7 +73,7 @@
if (op->getNumOperands() > 0 || op->getNumResults() != 1)
return false;
SmallVector<Attribute, 1> foldedAttr;
- if (!op->constantFold(/*operands=*/llvm::None, foldedAttr)) {
+ if (succeeded(op->constantFold(/*operands=*/llvm::None, foldedAttr))) {
*bind_value = foldedAttr.front();
return true;
}
diff --git a/include/mlir/IR/OpDefinition.h b/include/mlir/IR/OpDefinition.h
index 703af97..4c5754f 100644
--- a/include/mlir/IR/OpDefinition.h
+++ b/include/mlir/IR/OpDefinition.h
@@ -241,9 +241,9 @@
public:
/// This is an implementation detail of the constant folder hook for
/// AbstractOperation.
- static bool constantFoldHook(const Instruction *op,
- ArrayRef<Attribute> operands,
- SmallVectorImpl<Attribute> &results) {
+ static LogicalResult constantFoldHook(const Instruction *op,
+ ArrayRef<Attribute> operands,
+ SmallVectorImpl<Attribute> &results) {
return op->cast<ConcreteType>()->constantFold(operands, results,
op->getContext());
}
@@ -252,19 +252,20 @@
/// fold this operation with the specified constant operand values - the
/// elements in "operands" will correspond directly to the operands of the
/// operation, but may be null if non-constant. If constant folding is
- /// successful, this returns false and fills in the `results` vector. If not,
- /// this returns true and `results` is unspecified.
+ /// successful, this fills in the `results` vector. If not, `results` is
+ /// unspecified.
///
/// If not overridden, this fallback implementation always fails to fold.
///
- bool constantFold(ArrayRef<Attribute> operands,
- SmallVectorImpl<Attribute> &results,
- MLIRContext *context) const {
- return true;
+ LogicalResult constantFold(ArrayRef<Attribute> operands,
+ SmallVectorImpl<Attribute> &results,
+ MLIRContext *context) const {
+ return failure();
}
/// This is an implementation detail of the folder hook for AbstractOperation.
- static bool foldHook(Instruction *op, SmallVectorImpl<Value *> &results) {
+ static LogicalResult foldHook(Instruction *op,
+ SmallVectorImpl<Value *> &results) {
return op->cast<ConcreteType>()->fold(results);
}
@@ -276,12 +277,12 @@
/// can only perform the following changes to the operation:
///
/// 1. They can leave the operation alone and without changing the IR, and
- /// return true.
+ /// return failure.
/// 2. They can mutate the operation in place, without changing anything else
- /// in the IR. In this case, return false.
+ /// in the IR. In this case, return success.
/// 3. They can return a list of existing values that can be used instead of
/// the operation. In this case, fill in the results list and return
- /// false. The caller will remove the operation and use those results
+ /// success. The caller will remove the operation and use those results
/// instead.
///
/// This allows expression of some simple in-place canonicalizations (e.g.
@@ -291,7 +292,7 @@
///
/// If not overridden, this fallback implementation always fails to fold.
///
- bool fold(SmallVectorImpl<Value *> &results) { return true; }
+ LogicalResult fold(SmallVectorImpl<Value *> &results) { return failure(); }
};
/// This template specialization defines the constantFoldHook and foldHook as
@@ -303,16 +304,16 @@
public:
/// This is an implementation detail of the constant folder hook for
/// AbstractOperation.
- static bool constantFoldHook(const Instruction *op,
- ArrayRef<Attribute> operands,
- SmallVectorImpl<Attribute> &results) {
+ static LogicalResult constantFoldHook(const Instruction *op,
+ ArrayRef<Attribute> operands,
+ SmallVectorImpl<Attribute> &results) {
auto result =
op->cast<ConcreteType>()->constantFold(operands, op->getContext());
if (!result)
- return true;
+ return failure();
results.push_back(result);
- return false;
+ return success();
}
/// Op implementations can implement this hook. It should attempt to constant
@@ -330,13 +331,14 @@
}
/// This is an implementation detail of the folder hook for AbstractOperation.
- static bool foldHook(Instruction *op, SmallVectorImpl<Value *> &results) {
+ static LogicalResult foldHook(Instruction *op,
+ SmallVectorImpl<Value *> &results) {
auto *result = op->cast<ConcreteType>()->fold();
if (!result)
- return true;
+ return failure();
if (result != op->getResult(0))
results.push_back(result);
- return false;
+ return success();
}
/// This hook implements a generalized folder for this operation. Operations
diff --git a/include/mlir/IR/OperationSupport.h b/include/mlir/IR/OperationSupport.h
index c9122a3..a3e2911 100644
--- a/include/mlir/IR/OperationSupport.h
+++ b/include/mlir/IR/OperationSupport.h
@@ -27,6 +27,7 @@
#include "mlir/IR/Identifier.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Types.h"
+#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/PointerUnion.h"
#include <memory>
@@ -90,11 +91,11 @@
/// if everything is ok.
bool (&verifyInvariants)(const Instruction *op);
- /// This hook implements a constant folder for this operation. It returns
- /// true if folding failed, or returns false and fills in `results` on
- /// success.
- bool (&constantFoldHook)(const Instruction *op, ArrayRef<Attribute> operands,
- SmallVectorImpl<Attribute> &results);
+ /// This hook implements a constant folder for this operation. It fills in
+ /// `results` on success.
+ LogicalResult (&constantFoldHook)(const Instruction *op,
+ ArrayRef<Attribute> operands,
+ SmallVectorImpl<Attribute> &results);
/// This hook implements a generalized folder for this operation. Operations
/// can implement this to provide simplifications rules that are applied by
@@ -104,19 +105,19 @@
/// can only perform the following changes to the operation:
///
/// 1. They can leave the operation alone and without changing the IR, and
- /// return true.
+ /// return failure.
/// 2. They can mutate the operation in place, without changing anything else
- /// in the IR. In this case, return false.
+ /// in the IR. In this case, return success.
/// 3. They can return a list of existing values that can be used instead of
/// the operation. In this case, fill in the results list and return
- /// false. The caller will remove the operation and use those results
+ /// success. The caller will remove the operation and use those results
/// instead.
///
/// This allows expression of some simple in-place canonicalizations (e.g.
/// "x+0 -> x", "min(x,y,x,z) -> min(x,y,z)", "x+y-x -> y", etc), but does
/// not allow for canonicalizations that need to introduce new operations, not
/// even constants (e.g. "x-x -> 0" cannot be expressed).
- bool (&foldHook)(Instruction *op, SmallVectorImpl<Value *> &results);
+ LogicalResult (&foldHook)(Instruction *op, SmallVectorImpl<Value *> &results);
/// This hook returns any canonicalization pattern rewrites that the operation
/// supports, for use by the canonicalization pass.
@@ -149,10 +150,11 @@
bool (&parseAssembly)(OpAsmParser *parser, OperationState *result),
void (&printAssembly)(const Instruction *op, OpAsmPrinter *p),
bool (&verifyInvariants)(const Instruction *op),
- bool (&constantFoldHook)(const Instruction *op,
- ArrayRef<Attribute> operands,
- SmallVectorImpl<Attribute> &results),
- bool (&foldHook)(Instruction *op, SmallVectorImpl<Value *> &results),
+ LogicalResult (&constantFoldHook)(const Instruction *op,
+ ArrayRef<Attribute> operands,
+ SmallVectorImpl<Attribute> &results),
+ LogicalResult (&foldHook)(Instruction *op,
+ SmallVectorImpl<Value *> &results),
void (&getCanonicalizationPatterns)(OwningRewritePatternList &results,
MLIRContext *context))
: name(name), dialect(dialect), isClassFor(isClassFor),
diff --git a/lib/AffineOps/AffineOps.cpp b/lib/AffineOps/AffineOps.cpp
index f181026..c19565c 100644
--- a/lib/AffineOps/AffineOps.cpp
+++ b/lib/AffineOps/AffineOps.cpp
@@ -204,7 +204,7 @@
MLIRContext *context) const {
auto map = getAffineMap();
SmallVector<Attribute, 1> result;
- if (map.constantFold(operands, result))
+ if (failed(map.constantFold(operands, result)))
return Attribute();
return result[0];
}
@@ -837,7 +837,7 @@
assert(boundMap.getNumResults() >= 1 &&
"bound maps should have at least one result");
SmallVector<Attribute, 4> foldedResults;
- if (boundMap.constantFold(operandConstants, foldedResults))
+ if (failed(boundMap.constantFold(operandConstants, foldedResults)))
return;
// Compute the max or min as applicable over the results.
diff --git a/lib/IR/AffineMap.cpp b/lib/IR/AffineMap.cpp
index f14f690..2d091ba 100644
--- a/lib/IR/AffineMap.cpp
+++ b/lib/IR/AffineMap.cpp
@@ -21,6 +21,7 @@
#include "mlir/IR/Attributes.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Support/Functional.h"
+#include "mlir/Support/LogicalResult.h"
#include "mlir/Support/MathExtras.h"
#include "llvm/ADT/StringRef.h"
@@ -166,8 +167,9 @@
/// Folds the results of the application of an affine map on the provided
/// operands to a constant if possible. Returns false if the folding happens,
/// true otherwise.
-bool AffineMap::constantFold(ArrayRef<Attribute> operandConstants,
- SmallVectorImpl<Attribute> &results) const {
+LogicalResult
+AffineMap::constantFold(ArrayRef<Attribute> operandConstants,
+ SmallVectorImpl<Attribute> &results) const {
assert(getNumInputs() == operandConstants.size());
// Fold each of the result expressions.
@@ -177,13 +179,13 @@
auto folded = exprFolder.constantFold(expr);
// If we didn't fold to a constant, then folding fails.
if (!folded)
- return true;
+ return failure();
results.push_back(folded);
}
assert(results.size() == getNumResults() &&
"constant folding produced the wrong number of results");
- return false;
+ return success();
}
/// Walk all of the AffineExpr's in this mapping. The results are visited
diff --git a/lib/IR/Instruction.cpp b/lib/IR/Instruction.cpp
index 64c9ac2..36a0449 100644
--- a/lib/IR/Instruction.cpp
+++ b/lib/IR/Instruction.cpp
@@ -510,15 +510,16 @@
}
/// Attempt to constant fold this operation with the specified constant
-/// operand values. If successful, this returns false and fills in the
-/// results vector. If not, this returns true and results is unspecified.
-bool Instruction::constantFold(ArrayRef<Attribute> operands,
- SmallVectorImpl<Attribute> &results) const {
+/// operand values. If successful, this fills in the results vector. If not,
+/// results is unspecified.
+LogicalResult
+Instruction::constantFold(ArrayRef<Attribute> operands,
+ SmallVectorImpl<Attribute> &results) const {
if (auto *abstractOp = getAbstractOperation()) {
// If we have a registered operation definition matching this one, use it to
// try to constant fold the operation.
- if (!abstractOp->constantFoldHook(this, operands, results))
- return false;
+ if (succeeded(abstractOp->constantFoldHook(this, operands, results)))
+ return success();
// Otherwise, fall back on the dialect hook to handle it.
return abstractOp->dialect.constantFoldHook(this, operands, results);
@@ -528,22 +529,21 @@
// operation, fall back to a dialect which matches the prefix.
auto opName = getName().getStringRef();
auto dialectPrefix = opName.split('.').first;
- if (auto *dialect = getContext()->getRegisteredDialect(dialectPrefix)) {
+ if (auto *dialect = getContext()->getRegisteredDialect(dialectPrefix))
return dialect->constantFoldHook(this, operands, results);
- }
- return true;
+ return failure();
}
/// Attempt to fold this operation using the Op's registered foldHook.
-bool Instruction::fold(SmallVectorImpl<Value *> &results) {
+LogicalResult Instruction::fold(SmallVectorImpl<Value *> &results) {
if (auto *abstractOp = getAbstractOperation()) {
// If we have a registered operation definition matching this one, use it to
// try to constant fold the operation.
- if (!abstractOp->foldHook(this, results))
- return false;
+ if (succeeded(abstractOp->foldHook(this, results)))
+ return success();
}
- return true;
+ return failure();
}
/// Emit an error with the op name prefixed, like "'dim' op " which is
diff --git a/lib/Transforms/ConstantFold.cpp b/lib/Transforms/ConstantFold.cpp
index 6bdb1bf..ef063d0 100644
--- a/lib/Transforms/ConstantFold.cpp
+++ b/lib/Transforms/ConstantFold.cpp
@@ -63,7 +63,7 @@
// Attempt to constant fold the operation.
SmallVector<Attribute, 8> resultConstants;
- if (op->constantFold(operandConstants, resultConstants))
+ if (failed(op->constantFold(operandConstants, resultConstants)))
return;
// Ok, if everything succeeded, then we can create constants corresponding
diff --git a/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 39e6c5f..2a0238c 100644
--- a/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -216,7 +216,7 @@
// If constant folding was successful, create the result constants, RAUW the
// operation and remove it.
resultConstants.clear();
- if (!op->constantFold(operandConstants, resultConstants)) {
+ if (succeeded(op->constantFold(operandConstants, resultConstants))) {
builder.setInsertionPoint(op);
// Add the operands to the worklist for visitation.
@@ -256,7 +256,7 @@
// operation.
originalOperands.assign(op->operand_begin(), op->operand_end());
resultValues.clear();
- if (!op->fold(resultValues)) {
+ if (succeeded(op->fold(resultValues))) {
// If the result was an in-place simplification (e.g. max(x,x,y) ->
// max(x,y)) then add the original operands to the worklist so we can make
// sure to revisit them.
diff --git a/tools/mlir-tblgen/OpDefinitionsGen.cpp b/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 08ae721..ff9587c 100644
--- a/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -475,9 +475,9 @@
os << " Attribute constantFold(ArrayRef<Attribute> operands,\n"
" MLIRContext *context) const;\n";
} else {
- os << " bool constantFold(ArrayRef<Attribute> operands,\n"
- << " SmallVectorImpl<Attribute> &results,\n"
- << " MLIRContext *context) const;\n";
+ os << " LogicalResult constantFold(ArrayRef<Attribute> operands,\n"
+ << " SmallVectorImpl<Attribute> &results,"
+ << "\n MLIRContext *context) const;\n";
}
}