Add getMemRefType() accessors to LoadOp/StoreOp.
- There are several places where we are casting the type of the memref obtained
from the load/store op to a memref type, and this will become even more
common (some upcoming CLs this week). Add a getMemRefType and use it at
several places where the cast was being used.
PiperOrigin-RevId: 219164326
diff --git a/include/mlir/StandardOps/StandardOps.h b/include/mlir/StandardOps/StandardOps.h
index 87f06ef..b733bad 100644
--- a/include/mlir/StandardOps/StandardOps.h
+++ b/include/mlir/StandardOps/StandardOps.h
@@ -460,6 +460,9 @@
SSAValue *getMemRef() { return getOperand(0); }
const SSAValue *getMemRef() const { return getOperand(0); }
void setMemRef(SSAValue *value) { setOperand(0, value); }
+ MemRefType *getMemRefType() const {
+ return cast<MemRefType>(getMemRef()->getType());
+ }
llvm::iterator_range<Operation::operand_iterator> getIndices() {
return {getOperation()->operand_begin() + 1, getOperation()->operand_end()};
@@ -580,6 +583,9 @@
SSAValue *getMemRef() { return getOperand(1); }
const SSAValue *getMemRef() const { return getOperand(1); }
void setMemRef(SSAValue *value) { setOperand(1, value); }
+ MemRefType *getMemRefType() const {
+ return cast<MemRefType>(getMemRef()->getType());
+ }
llvm::iterator_range<Operation::operand_iterator> getIndices() {
return {getOperation()->operand_begin() + 2, getOperation()->operand_end()};
diff --git a/lib/Analysis/LoopAnalysis.cpp b/lib/Analysis/LoopAnalysis.cpp
index 5f1b7f2..3b427d6 100644
--- a/lib/Analysis/LoopAnalysis.cpp
+++ b/lib/Analysis/LoopAnalysis.cpp
@@ -157,7 +157,7 @@
template <typename LoadOrStoreOpPointer>
static bool isContiguousAccess(MLValue *input, LoadOrStoreOpPointer memoryOp) {
auto indicesAsOperandIterators = memoryOp->getIndices();
- auto *memRefType = cast<MemRefType>(memoryOp->getMemRef()->getType());
+ auto *memRefType = memoryOp->getMemRefType();
SmallVector<MLValue *, 4> indices;
for (auto *it : indicesAsOperandIterators) {
indices.push_back(cast<MLValue>(it));
diff --git a/lib/StandardOps/StandardOps.cpp b/lib/StandardOps/StandardOps.cpp
index 30b2bb8..02db875 100644
--- a/lib/StandardOps/StandardOps.cpp
+++ b/lib/StandardOps/StandardOps.cpp
@@ -753,7 +753,7 @@
p->printOperands(getIndices());
*p << ']';
p->printOptionalAttrDict(getAttrs());
- *p << " : " << *getMemRef()->getType();
+ *p << " : " << *getMemRefType();
}
bool LoadOp::parse(OpAsmParser *parser, OperationState *result) {
@@ -928,7 +928,7 @@
p->printOperands(getIndices());
*p << ']';
p->printOptionalAttrDict(getAttrs());
- *p << " : " << *getMemRef()->getType();
+ *p << " : " << *getMemRefType();
}
bool StoreOp::parse(OpAsmParser *parser, OperationState *result) {