[mlir][NFC] Use `getDefiningOp<OpTy>()` instead of `dyn_cast<OpTy>(getDefiningOp())` (#150428)
This PR uses `val.getDefiningOp<OpTy>()` to replace `dyn_cast<OpTy>(val.getDefiningOp())` , `dyn_cast_or_null<OpTy>(val.getDefiningOp())` and `dyn_cast_if_present<OpTy>(val.getDefiningOp())`.
diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
index 3e434ea..5bd1d49 100644
--- a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
+++ b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
@@ -49,7 +49,7 @@
assert(isa<pdl::AttributeType>(val.getType()) && "expected attribute type");
predList.emplace_back(pos, builder.getIsNotNull());
- if (auto attr = dyn_cast<pdl::AttributeOp>(val.getDefiningOp())) {
+ if (auto attr = val.getDefiningOp<pdl::AttributeOp>()) {
// If the attribute has a type or value, add a constraint.
if (Value type = attr.getValueType())
getTreePredicates(predList, type, builder, inputs, builder.getType(pos));
diff --git a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
index 10da907..50a0f3d 100644
--- a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
@@ -1322,7 +1322,7 @@
return false;
Attribute valueAttr = getIdentityValueAttr(reductionKind, scalarTy,
state.builder, value.getLoc());
- if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(value.getDefiningOp()))
+ if (auto constOp = value.getDefiningOp<arith::ConstantOp>())
return constOp.getValue() == valueAttr;
return false;
}
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 910334b..488c3c3 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -2498,7 +2498,7 @@
matchPattern(adaptor.getFalseValue(), m_Zero()))
return condition;
- if (auto cmp = dyn_cast_or_null<arith::CmpIOp>(condition.getDefiningOp())) {
+ if (auto cmp = condition.getDefiningOp<arith::CmpIOp>()) {
auto pred = cmp.getPredicate();
if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
auto cmpLhs = cmp.getLhs();
diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractToNeonPatterns.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractToNeonPatterns.cpp
index 5aadaec..45b896d 100644
--- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractToNeonPatterns.cpp
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractToNeonPatterns.cpp
@@ -49,7 +49,7 @@
// If the operand is not defined by an explicit extend operation of the
// accepted operation type allow for an implicit sign-extension.
- auto extOp = dyn_cast_or_null<Op>(v.getDefiningOp());
+ auto extOp = v.getDefiningOp<Op>();
if (!extOp) {
if constexpr (std::is_same<Op, arith::ExtSIOp>::value) {
auto eltTy = cast<VectorType>(v.getType()).getElementType();
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp
index ac1df38..fcfeb9c 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp
@@ -50,7 +50,7 @@
// If the operand is not defined by an explicit extend operation of the
// accepted operation type allow for an implicit sign-extension.
- auto extOp = dyn_cast_or_null<Op>(v.getDefiningOp());
+ auto extOp = v.getDefiningOp<Op>();
if (!extOp) {
if constexpr (std::is_same<Op, arith::ExtSIOp>::value) {
auto vTy = cast<VectorType>(v.getType());
diff --git a/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp b/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
index d5fe3b4..3f0690c 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
@@ -62,9 +62,7 @@
continue;
for (Value operand : op.getOperands()) {
- auto usedExpression =
- dyn_cast_if_present<ExpressionOp>(operand.getDefiningOp());
-
+ auto usedExpression = operand.getDefiningOp<ExpressionOp>();
if (!usedExpression)
continue;
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 5b01596..d42ce96 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -2707,7 +2707,7 @@
while (alias) {
Block &initBlock = alias.getInitializerBlock();
auto returnOp = cast<ReturnOp>(initBlock.getTerminator());
- auto addrOp = dyn_cast<AddressOfOp>(returnOp.getArg().getDefiningOp());
+ auto addrOp = returnOp.getArg().getDefiningOp<AddressOfOp>();
// FIXME: This is a best effort solution. The AliasOp body might be more
// complex and in that case we bail out with success. To completely match
// the LLVM IR logic it would be necessary to implement proper alias and
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index e3ce0e1..9f523e9 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1852,7 +1852,7 @@
assert(!packOp && "packOp must be null on entry when unPackOp is not null");
OpOperand *packUse = linalgOp.getDpsInitOperand(
cast<OpResult>(unPackOp.getSource()).getResultNumber());
- packOp = dyn_cast_or_null<linalg::PackOp>(packUse->get().getDefiningOp());
+ packOp = packUse->get().getDefiningOp<linalg::PackOp>();
if (!packOp || !packOp.getResult().hasOneUse())
return emitSilenceableError() << "could not find matching pack op";
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
index 19729af..fd530f2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
@@ -757,8 +757,7 @@
Value source = extractSliceOp.getSource();
LLVM_DEBUG(DBGS() << "--with starting source: " << source << "\n");
while (source && source != expectedSource) {
- auto destOp =
- dyn_cast_or_null<DestinationStyleOpInterface>(source.getDefiningOp());
+ auto destOp = source.getDefiningOp<DestinationStyleOpInterface>();
if (!destOp)
break;
LLVM_DEBUG(DBGS() << "--step dest op: " << destOp << "\n");
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index 70bc7b6..58986a6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -165,8 +165,7 @@
Value source = transferRead.getBase();
// Skip view-like Ops and retrive the actual soruce Operation
- while (auto srcOp =
- dyn_cast_or_null<ViewLikeOpInterface>(source.getDefiningOp()))
+ while (auto srcOp = source.getDefiningOp<ViewLikeOpInterface>())
source = srcOp.getViewSource();
llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 28608cb..01426e4 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -755,7 +755,7 @@
MeshSharding::MeshSharding(::mlir::FlatSymbolRefAttr mesh_) : mesh(mesh_) {}
MeshSharding::MeshSharding(Value rhs) {
- auto shardingOp = mlir::dyn_cast<ShardingOp>(rhs.getDefiningOp());
+ auto shardingOp = rhs.getDefiningOp<ShardingOp>();
assert(shardingOp && "expected sharding op");
auto splitAxes = shardingOp.getSplitAxes().getAxes();
// If splitAxes are empty, use "empty" constructor.
diff --git a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
index 09c754d..0a68376 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
@@ -167,7 +167,7 @@
for (auto [operand, sharding] :
llvm::zip_equal(op->getOperands(), operandShardings)) {
- ShardOp shardOp = llvm::dyn_cast_or_null<ShardOp>(operand.getDefiningOp());
+ ShardOp shardOp = operand.getDefiningOp<ShardOp>();
if (!shardOp) {
continue;
}
@@ -376,8 +376,7 @@
LLVM_DEBUG(
DBGS() << "print all the ops' iterator types and indexing maps in the "
"block.\n";
- for (Operation &op
- : block.getOperations()) {
+ for (Operation &op : block.getOperations()) {
if (auto shardingOp = llvm::dyn_cast<ShardingInterface>(&op))
shardingOp.printLoopTypesAndIndexingMaps(llvm::dbgs());
});
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index c6e76ec..5dd744d 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -660,8 +660,7 @@
// Check if 2 shard ops are chained. If not there is no need for resharding
// as the source and target shared the same sharding.
- ShardOp srcShardOp =
- dyn_cast_or_null<ShardOp>(shardOp.getSrc().getDefiningOp());
+ ShardOp srcShardOp = shardOp.getSrc().getDefiningOp<ShardOp>();
if (!srcShardOp) {
targetSpmdValue = spmdizationMap.lookup(shardOp.getSrc());
} else {
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 5d6c5499..c1c1767 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -1730,8 +1730,7 @@
if (!mapOp.getDefiningOp())
return emitError(op->getLoc(), "missing map operation");
- if (auto mapInfoOp =
- mlir::dyn_cast<mlir::omp::MapInfoOp>(mapOp.getDefiningOp())) {
+ if (auto mapInfoOp = mapOp.getDefiningOp<mlir::omp::MapInfoOp>()) {
uint64_t mapTypeBits = mapInfoOp.getMapType();
bool to = mapTypeToBitFlag(
diff --git a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
index b44dbfd..c5ec0ca 100644
--- a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
+++ b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
@@ -53,7 +53,7 @@
Value ptrLike;
FromPtrOp fromPtr = *this;
while (fromPtr != nullptr) {
- auto toPtr = dyn_cast_or_null<ToPtrOp>(fromPtr.getPtr().getDefiningOp());
+ auto toPtr = fromPtr.getPtr().getDefiningOp<ToPtrOp>();
// Cannot fold if it's not a `to_ptr` op or the initial and final types are
// different.
if (!toPtr || toPtr.getPtr().getType() != fromPtr.getType())
@@ -64,13 +64,12 @@
ptrLike = toPtr.getPtr();
} else if (md) {
// Fold if the metadata can be verified to be equal.
- if (auto mdOp = dyn_cast_or_null<GetMetadataOp>(md.getDefiningOp());
+ if (auto mdOp = md.getDefiningOp<GetMetadataOp>();
mdOp && mdOp.getPtr() == toPtr.getPtr())
ptrLike = toPtr.getPtr();
}
// Check for a sequence of casts.
- fromPtr = dyn_cast_or_null<FromPtrOp>(ptrLike ? ptrLike.getDefiningOp()
- : nullptr);
+ fromPtr = ptrLike ? ptrLike.getDefiningOp<FromPtrOp>() : nullptr;
}
return ptrLike;
}
@@ -112,13 +111,13 @@
Value ptr;
ToPtrOp toPtr = *this;
while (toPtr != nullptr) {
- auto fromPtr = dyn_cast_or_null<FromPtrOp>(toPtr.getPtr().getDefiningOp());
+ auto fromPtr = toPtr.getPtr().getDefiningOp<FromPtrOp>();
// Cannot fold if it's not a `from_ptr` op.
if (!fromPtr)
return ptr;
ptr = fromPtr.getPtr();
// Check for chains of casts.
- toPtr = dyn_cast_or_null<ToPtrOp>(ptr.getDefiningOp());
+ toPtr = ptr.getDefiningOp<ToPtrOp>();
}
return ptr;
}
diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp
index 84a779b..081f5fb 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp
@@ -100,11 +100,10 @@
op.getStep(), tileSizeConstants)) {
// Collect the statically known loop bounds
auto lowerBoundConstant =
- dyn_cast_or_null<arith::ConstantIndexOp>(lowerBound.getDefiningOp());
+ lowerBound.getDefiningOp<arith::ConstantIndexOp>();
auto upperBoundConstant =
- dyn_cast_or_null<arith::ConstantIndexOp>(upperBound.getDefiningOp());
- auto stepConstant =
- dyn_cast_or_null<arith::ConstantIndexOp>(step.getDefiningOp());
+ upperBound.getDefiningOp<arith::ConstantIndexOp>();
+ auto stepConstant = step.getDefiningOp<arith::ConstantIndexOp>();
auto tileSize =
cast<arith::ConstantIndexOp>(tileSizeConstant.getDefiningOp()).value();
// If the loop bounds and the loop step are constant and if the number of
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
index 0262319..a52872d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
@@ -1317,7 +1317,7 @@
Value n = op.getN() ? op.getN() : constantIndex(rewriter, loc, 1);
Value newSize = arith::AddIOp::create(rewriter, loc, size, n);
- auto nValue = dyn_cast_or_null<arith::ConstantIndexOp>(n.getDefiningOp());
+ auto nValue = n.getDefiningOp<arith::ConstantIndexOp>();
bool nIsOne = (nValue && nValue.value() == 1);
if (!op.getInbounds()) {
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 34e7e42..1ad2c80 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -554,7 +554,7 @@
Value input = op.getInput();
// Check the input to the CLAMP op is itself a CLAMP.
- auto clampOp = dyn_cast_if_present<tosa::ClampOp>(input.getDefiningOp());
+ auto clampOp = input.getDefiningOp<tosa::ClampOp>();
if (!clampOp)
return failure();
@@ -1636,7 +1636,7 @@
for (Value operand : getOperands()) {
concatOperands.emplace_back(operand);
- auto producer = dyn_cast_or_null<ConcatOp>(operand.getDefiningOp());
+ auto producer = operand.getDefiningOp<ConcatOp>();
if (!producer)
continue;
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 8c97aed..4e9f93b 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2591,8 +2591,7 @@
llvm::enumerate(fromElements.getElements())) {
// Check that the element is from a vector.extract operation.
- auto extractOp =
- dyn_cast_if_present<vector::ExtractOp>(element.getDefiningOp());
+ auto extractOp = element.getDefiningOp<vector::ExtractOp>();
if (!extractOp) {
return rewriter.notifyMatchFailure(fromElements,
"element not from vector.extract");
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index a36c6ac..dcd2e11 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -900,8 +900,7 @@
// inlined, and as such should be wrapped in parentheses in order to guarantee
// its precedence and associativity.
auto requiresParentheses = [&](Value value) {
- auto expressionOp =
- dyn_cast_if_present<ExpressionOp>(value.getDefiningOp());
+ auto expressionOp = value.getDefiningOp<ExpressionOp>();
if (!expressionOp)
return false;
return shouldBeInlined(expressionOp);
@@ -1542,7 +1541,7 @@
return success();
}
- auto expressionOp = dyn_cast_if_present<ExpressionOp>(value.getDefiningOp());
+ auto expressionOp = value.getDefiningOp<ExpressionOp>();
if (expressionOp && shouldBeInlined(expressionOp))
return emitExpression(expressionOp);
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp
index d162afd..97c6b4e 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp
@@ -151,8 +151,7 @@
// Copyin operands are handled as `to` call.
llvm::SmallVector<mlir::Value> create, copyin;
for (mlir::Value dataOp : op.getDataClauseOperands()) {
- if (auto createOp =
- mlir::dyn_cast_or_null<acc::CreateOp>(dataOp.getDefiningOp())) {
+ if (auto createOp = dataOp.getDefiningOp<acc::CreateOp>()) {
create.push_back(createOp.getVarPtr());
} else if (auto copyinOp = mlir::dyn_cast_or_null<acc::CopyinOp>(
dataOp.getDefiningOp())) {
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index d9eb6ae..9f18199 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -3537,8 +3537,7 @@
}
static bool isDeclareTargetLink(mlir::Value value) {
- if (auto addressOfOp =
- llvm::dyn_cast_if_present<LLVM::AddressOfOp>(value.getDefiningOp())) {
+ if (auto addressOfOp = value.getDefiningOp<LLVM::AddressOfOp>()) {
auto modOp = addressOfOp->getParentOfType<mlir::ModuleOp>();
Operation *gOp = modOp.lookupSymbol(addressOfOp.getGlobalName());
if (auto declareTargetGlobal =
@@ -4498,8 +4497,7 @@
ifCond = moduleTranslation.lookupValue(ifVar);
if (auto devId = dataOp.getDevice())
- if (auto constOp =
- dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp()))
+ if (auto constOp = devId.getDefiningOp<LLVM::ConstantOp>())
if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
deviceID = intAttr.getInt();
@@ -4516,8 +4514,7 @@
ifCond = moduleTranslation.lookupValue(ifVar);
if (auto devId = enterDataOp.getDevice())
- if (auto constOp =
- dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp()))
+ if (auto constOp = devId.getDefiningOp<LLVM::ConstantOp>())
if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
deviceID = intAttr.getInt();
RTLFn =
@@ -4536,8 +4533,7 @@
ifCond = moduleTranslation.lookupValue(ifVar);
if (auto devId = exitDataOp.getDevice())
- if (auto constOp =
- dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp()))
+ if (auto constOp = devId.getDefiningOp<LLVM::ConstantOp>())
if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
deviceID = intAttr.getInt();
@@ -4556,8 +4552,7 @@
ifCond = moduleTranslation.lookupValue(ifVar);
if (auto devId = updateDataOp.getDevice())
- if (auto constOp =
- dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp()))
+ if (auto constOp = devId.getDefiningOp<LLVM::ConstantOp>())
if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
deviceID = intAttr.getInt();
@@ -5198,8 +5193,7 @@
if (!value)
return std::nullopt;
- if (auto constOp =
- dyn_cast_if_present<LLVM::ConstantOp>(value.getDefiningOp()))
+ if (auto constOp = value.getDefiningOp<LLVM::ConstantOp>())
if (auto constAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
return constAttr.getInt();
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index b4aeccf..1fff57e 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -139,8 +139,7 @@
LogicalResult matchAndRewrite(TestCommutative2Op op,
PatternRewriter &rewriter) const override {
- auto operand =
- dyn_cast_or_null<TestCommutative2Op>(op->getOperand(0).getDefiningOp());
+ auto operand = op->getOperand(0).getDefiningOp<TestCommutative2Op>();
if (!operand)
return failure();
Attribute constInput;