Add getMemRefType() accessors to LoadOp/StoreOp.

- There are several places where we are casting the type of the memref obtained
  from the load/store op to a memref type, and this will become even more
  common (some upcoming CLs this week). Add a getMemRefType and use it at
  several places where the cast was being used.

PiperOrigin-RevId: 219164326
diff --git a/include/mlir/StandardOps/StandardOps.h b/include/mlir/StandardOps/StandardOps.h
index 87f06ef..b733bad 100644
--- a/include/mlir/StandardOps/StandardOps.h
+++ b/include/mlir/StandardOps/StandardOps.h
@@ -460,6 +460,9 @@
   SSAValue *getMemRef() { return getOperand(0); }
   const SSAValue *getMemRef() const { return getOperand(0); }
   void setMemRef(SSAValue *value) { setOperand(0, value); }
+  MemRefType *getMemRefType() const {
+    return cast<MemRefType>(getMemRef()->getType());
+  }
 
   llvm::iterator_range<Operation::operand_iterator> getIndices() {
     return {getOperation()->operand_begin() + 1, getOperation()->operand_end()};
@@ -580,6 +583,9 @@
   SSAValue *getMemRef() { return getOperand(1); }
   const SSAValue *getMemRef() const { return getOperand(1); }
   void setMemRef(SSAValue *value) { setOperand(1, value); }
+  MemRefType *getMemRefType() const {
+    return cast<MemRefType>(getMemRef()->getType());
+  }
 
   llvm::iterator_range<Operation::operand_iterator> getIndices() {
     return {getOperation()->operand_begin() + 2, getOperation()->operand_end()};
diff --git a/lib/Analysis/LoopAnalysis.cpp b/lib/Analysis/LoopAnalysis.cpp
index 5f1b7f2..3b427d6 100644
--- a/lib/Analysis/LoopAnalysis.cpp
+++ b/lib/Analysis/LoopAnalysis.cpp
@@ -157,7 +157,7 @@
 template <typename LoadOrStoreOpPointer>
 static bool isContiguousAccess(MLValue *input, LoadOrStoreOpPointer memoryOp) {
   auto indicesAsOperandIterators = memoryOp->getIndices();
-  auto *memRefType = cast<MemRefType>(memoryOp->getMemRef()->getType());
+  auto *memRefType = memoryOp->getMemRefType();
   SmallVector<MLValue *, 4> indices;
   for (auto *it : indicesAsOperandIterators) {
     indices.push_back(cast<MLValue>(it));
diff --git a/lib/StandardOps/StandardOps.cpp b/lib/StandardOps/StandardOps.cpp
index 30b2bb8..02db875 100644
--- a/lib/StandardOps/StandardOps.cpp
+++ b/lib/StandardOps/StandardOps.cpp
@@ -753,7 +753,7 @@
   p->printOperands(getIndices());
   *p << ']';
   p->printOptionalAttrDict(getAttrs());
-  *p << " : " << *getMemRef()->getType();
+  *p << " : " << *getMemRefType();
 }
 
 bool LoadOp::parse(OpAsmParser *parser, OperationState *result) {
@@ -928,7 +928,7 @@
   p->printOperands(getIndices());
   *p << ']';
   p->printOptionalAttrDict(getAttrs());
-  *p << " : " << *getMemRef()->getType();
+  *p << " : " << *getMemRefType();
 }
 
 bool StoreOp::parse(OpAsmParser *parser, OperationState *result) {