Introduce a new extract_element operation that does what it says.  Introduce a
new VectorOrTensorType class that provides a common interface between vector
and tensor since a number of operations will be uniform across them (including
extract_element).  Improve the LoadOp verifier.

I also updated the MLIR spec doc.

PiperOrigin-RevId: 209953189
diff --git a/include/mlir/IR/StandardOps.h b/include/mlir/IR/StandardOps.h
index a7b671c..e688ea2 100644
--- a/include/mlir/IR/StandardOps.h
+++ b/include/mlir/IR/StandardOps.h
@@ -327,6 +327,47 @@
   explicit DimOp(const Operation *state) : OpBase(state) {}
 };
 
+/// The "extract_element" op reads a tensor or vector and returns one element
+/// from it specified by an index list. The output of extract is a new value
+/// with the same type as the elements of the tensor or vector. The arity of
+/// indices matches the rank of the accessed value (i.e., if a tensor is of rank
+/// 3, then 3 indices are required for the extract).  The indices should all be
+/// of affine_int type.
+///
+/// For example:
+///
+///   %3 = extract_element %0[%1, %2] : vector<4x4xi32>
+///
+class ExtractElementOp
+    : public OpBase<ExtractElementOp, OpTrait::VariadicOperands,
+                    OpTrait::OneResult> {
+public:
+  static void build(Builder *builder, OperationState *result,
+                    SSAValue *aggregate, ArrayRef<SSAValue *> indices = {});
+
+  SSAValue *getAggregate() { return getOperand(0); }
+  const SSAValue *getAggregate() const { return getOperand(0); }
+
+  llvm::iterator_range<Operation::operand_iterator> getIndices() {
+    return {getOperation()->operand_begin() + 1, getOperation()->operand_end()};
+  }
+
+  llvm::iterator_range<Operation::const_operand_iterator> getIndices() const {
+    return {getOperation()->operand_begin() + 1, getOperation()->operand_end()};
+  }
+
+  static StringRef getOperationName() { return "extract_element"; }
+
+  // Hooks to customize behavior of this op.
+  const char *verify() const;
+  static bool parse(OpAsmParser *parser, OperationState *result);
+  void print(OpAsmPrinter *p) const;
+
+private:
+  friend class Operation;
+  explicit ExtractElementOp(const Operation *state) : OpBase(state) {}
+};
+
 /// The "load" op reads an element from a memref specified by an index list. The
 /// output of load is a new value with the same type as the elements of the
 /// memref. The arity of indices is the rank of the memref (i.e., if the memref
@@ -352,6 +393,8 @@
   static StringRef getOperationName() { return "load"; }
 
   // Hooks to customize behavior of this op.
+  static void build(Builder *builder, OperationState *result, SSAValue *memref,
+                    ArrayRef<SSAValue *> indices = {});
   const char *verify() const;
   static bool parse(OpAsmParser *parser, OperationState *result);
   void print(OpAsmPrinter *p) const;
diff --git a/include/mlir/IR/Types.h b/include/mlir/IR/Types.h
index 4bfdba1..62f146a 100644
--- a/include/mlir/IR/Types.h
+++ b/include/mlir/IR/Types.h
@@ -256,9 +256,33 @@
   ~FunctionType() = delete;
 };
 
+/// This is a common base class between Vector, UnrankedTensor, and RankedTensor
+/// types, because many operations work on values of these aggregate types.
+class VectorOrTensorType : public Type {
+public:
+  Type *getElementType() const { return elementType; }
+
+  /// If this is ranked tensor or vector type, return the rank.  If it is an
+  /// unranked tensor, return -1.
+  int getRankIfPresent() const;
+
+  /// Methods for support type inquiry through isa, cast, and dyn_cast.
+  static bool classof(const Type *type) {
+    return type->getKind() == Kind::Vector ||
+           type->getKind() == Kind::RankedTensor ||
+           type->getKind() == Kind::UnrankedTensor;
+  }
+
+public:
+  Type *elementType;
+
+  VectorOrTensorType(Kind kind, MLIRContext *context, Type *elementType,
+                     unsigned subClassData = 0);
+};
+
 /// Vector types represent multi-dimensional SIMD vectors, and have a fixed
 /// known constant shape with one or more dimension.
-class VectorType : public Type {
+class VectorType : public VectorOrTensorType {
 public:
   static VectorType *get(ArrayRef<unsigned> shape, Type *elementType);
 
@@ -266,7 +290,7 @@
     return ArrayRef<unsigned>(shapeElements, getSubclassData());
   }
 
-  Type *getElementType() const { return elementType; }
+  unsigned getRank() const { return getSubclassData(); }
 
   /// Methods for support type inquiry through isa, cast, and dyn_cast.
   static bool classof(const Type *type) {
@@ -283,9 +307,8 @@
 
 /// Tensor types represent multi-dimensional arrays, and have two variants:
 /// RankedTensorType and UnrankedTensorType.
-class TensorType : public Type {
+class TensorType : public VectorOrTensorType {
 public:
-  Type *getElementType() const { return elementType; }
 
   /// Methods for support type inquiry through isa, cast, and dyn_cast.
   static bool classof(const Type *type) {
@@ -294,9 +317,6 @@
   }
 
 protected:
-  /// The type of each scalar element of the tensor.
-  Type *elementType;
-
   TensorType(Kind kind, Type *elementType, MLIRContext *context);
   ~TensorType() {}
 };
@@ -313,7 +333,7 @@
     return ArrayRef<int>(shapeElements, getSubclassData());
   }
 
-  unsigned getRank() const { return getShape().size(); }
+  unsigned getRank() const { return getSubclassData(); }
 
   static bool classof(const Type *type) {
     return type->getKind() == Kind::RankedTensor;
@@ -346,7 +366,6 @@
 /// number of dimensions. Each shape element can be a positive integer or
 /// unknown (represented by any negative integer). MemRef types also have an
 /// affine map composition, represented as an array AffineMap pointers.
-// TODO: Use -1 for unknown dimensions (rather than arbitrary negative numbers).
 class MemRefType : public Type {
 public:
   /// Get or create a new MemRefType based on shape, element type, affine
diff --git a/lib/IR/MLIRContext.cpp b/lib/IR/MLIRContext.cpp
index e59b59c..024e1b6 100644
--- a/lib/IR/MLIRContext.cpp
+++ b/lib/IR/MLIRContext.cpp
@@ -454,17 +454,6 @@
   return *existing.first = result;
 }
 
-static bool isValidTensorElementType(Type *type, MLIRContext *context) {
-  return isa<FloatType>(type) || isa<VectorType>(type) ||
-         isa<IntegerType>(type) || type == Type::getTFString(context);
-}
-
-TensorType::TensorType(Kind kind, Type *elementType, MLIRContext *context)
-    : Type(kind, context), elementType(elementType) {
-  assert(isValidTensorElementType(elementType, context));
-  assert(isa<TensorType>(this));
-}
-
 RankedTensorType *RankedTensorType::get(ArrayRef<int> shape,
                                         Type *elementType) {
   auto *context = elementType->getContext();
diff --git a/lib/IR/StandardOps.cpp b/lib/IR/StandardOps.cpp
index 9f01c13..18507b0 100644
--- a/lib/IR/StandardOps.cpp
+++ b/lib/IR/StandardOps.cpp
@@ -520,9 +520,77 @@
 }
 
 //===----------------------------------------------------------------------===//
+// ExtractElementOp
+//===----------------------------------------------------------------------===//
+
+void ExtractElementOp::build(Builder *builder, OperationState *result,
+                             SSAValue *aggregate,
+                             ArrayRef<SSAValue *> indices) {
+  auto *aggregateType = cast<VectorOrTensorType>(aggregate->getType());
+  result->addOperands(aggregate);
+  result->addOperands(indices);
+  result->types.push_back(aggregateType->getElementType());
+}
+
+void ExtractElementOp::print(OpAsmPrinter *p) const {
+  *p << "extract_element " << *getAggregate() << '[';
+  p->printOperands(getIndices());
+  *p << ']';
+  p->printOptionalAttrDict(getAttrs());
+  *p << " : " << *getAggregate()->getType();
+}
+
+bool ExtractElementOp::parse(OpAsmParser *parser, OperationState *result) {
+  OpAsmParser::OperandType aggregateInfo;
+  SmallVector<OpAsmParser::OperandType, 4> indexInfo;
+  VectorOrTensorType *type;
+
+  auto affineIntTy = parser->getBuilder().getAffineIntType();
+  return parser->parseOperand(aggregateInfo) ||
+         parser->parseOperandList(indexInfo, -1,
+                                  OpAsmParser::Delimiter::Square) ||
+         parser->parseOptionalAttributeDict(result->attributes) ||
+         parser->parseColonType(type) ||
+         parser->resolveOperand(aggregateInfo, type, result->operands) ||
+         parser->resolveOperands(indexInfo, affineIntTy, result->operands) ||
+         parser->addTypeToList(type->getElementType(), result->types);
+}
+
+const char *ExtractElementOp::verify() const {
+  if (getNumOperands() == 0)
+    return "expected an aggregate to index into";
+
+  auto *aggregateType = dyn_cast<VectorOrTensorType>(getAggregate()->getType());
+  if (!aggregateType)
+    return "first operand must be a vector or tensor";
+
+  if (getResult()->getType() != aggregateType->getElementType())
+    return "result type must match element type of aggregate";
+
+  for (auto *idx : getIndices())
+    if (!idx->getType()->isAffineInt())
+      return "index to extract_element must have 'affineint' type";
+
+  // Verify the # indices match if we have a ranked type.
+  auto aggregateRank = aggregateType->getRankIfPresent();
+  if (aggregateRank != -1 && aggregateRank != getNumOperands() - 1)
+    return "incorrect number of indices for extract_element";
+
+  return nullptr;
+}
+
+//===----------------------------------------------------------------------===//
 // LoadOp
 //===----------------------------------------------------------------------===//
 
+void LoadOp::build(Builder *builder, OperationState *result, SSAValue *memref,
+                   ArrayRef<SSAValue *> indices) {
+  auto *memrefType = cast<MemRefType>(memref->getType());
+  result->addOperands(memref);
+  result->addOperands(indices);
+  result->types.push_back(memrefType->getElementType());
+}
+
 void LoadOp::print(OpAsmPrinter *p) const {
   *p << "load " << *getMemRef() << '[';
   p->printOperands(getIndices());
@@ -555,6 +623,12 @@
   if (!memRefType)
     return "first operand must be a memref";
 
+  if (getResult()->getType() != memRefType->getElementType())
+    return "result type must match element type of memref";
+
+  if (memRefType->getRank() != getNumOperands() - 1)
+    return "incorrect number of indices for load";
+
   for (auto *idx : getIndices())
     if (!idx->getType()->isAffineInt())
       return "index to load must have 'affineint' type";
@@ -671,6 +745,7 @@
 /// Install the standard operations in the specified operation set.
 void mlir::registerStandardOperations(OperationSet &opSet) {
   opSet.addOperations<AddFOp, AffineApplyOp, AllocOp, CallOp, CallIndirectOp,
-                      ConstantOp, DeallocOp, DimOp, LoadOp, ReturnOp, StoreOp>(
+                      ConstantOp, DeallocOp, DimOp, ExtractElementOp, LoadOp,
+                      ReturnOp, StoreOp>(
       /*prefix=*/"");
 }
diff --git a/lib/IR/Types.cpp b/lib/IR/Types.cpp
index 7dfad79..d32fae3 100644
--- a/lib/IR/Types.cpp
+++ b/lib/IR/Types.cpp
@@ -35,10 +35,40 @@
     numResults(numResults), inputsAndResults(inputsAndResults) {
 }
 
+VectorOrTensorType::VectorOrTensorType(Kind kind, MLIRContext *context,
+                                       Type *elementType, unsigned subClassData)
+    : Type(kind, context, subClassData), elementType(elementType) {}
+
+/// If this is ranked tensor or vector type, return the rank.  If it is an
+/// unranked tensor, return -1.
+int VectorOrTensorType::getRankIfPresent() const {
+  switch (getKind()) {
+  default:
+    llvm_unreachable("not a VectorOrTensorType");
+  case Kind::Vector:
+    return cast<VectorType>(this)->getRank();
+  case Kind::RankedTensor:
+    return cast<RankedTensorType>(this)->getRank();
+  case Kind::UnrankedTensor:
+    return -1;
+  }
+}
+
 VectorType::VectorType(ArrayRef<unsigned> shape, Type *elementType,
                        MLIRContext *context)
-    : Type(Kind::Vector, context, shape.size()), shapeElements(shape.data()),
-      elementType(elementType) {}
+    : VectorOrTensorType(Kind::Vector, context, elementType, shape.size()),
+      shapeElements(shape.data()) {}
+
+/// Return true if the specified element type is ok in a tensor.
+static bool isValidTensorElementType(Type *type, MLIRContext *context) {
+  return isa<FloatType>(type) || isa<VectorType>(type) ||
+         isa<IntegerType>(type) || type == Type::getTFString(context);
+}
+
+TensorType::TensorType(Kind kind, Type *elementType, MLIRContext *context)
+    : VectorOrTensorType(kind, context, elementType) {
+  assert(isValidTensorElementType(elementType, context));
+}
 
 RankedTensorType::RankedTensorType(ArrayRef<int> shape, Type *elementType,
                                    MLIRContext *context)
@@ -65,7 +95,7 @@
 unsigned MemRefType::getNumDynamicDims() const {
   unsigned numDynamicDims = 0;
   for (int dimSize : getShape()) {
-    if (dimSize < 0)
+    if (dimSize == -1)
       ++numDynamicDims;
   }
   return numDynamicDims;
diff --git a/test/IR/core-ops.mlir b/test/IR/core-ops.mlir
index e7fba92..ca94036 100644
--- a/test/IR/core-ops.mlir
+++ b/test/IR/core-ops.mlir
@@ -132,4 +132,16 @@
   return
 }
 
+// CHECK-LABEL: mlfunc @extract_element(%arg0 : tensor<??i32>, %arg1 : tensor<4x4xf32>) -> i32 {
+mlfunc @extract_element(%arg0 : tensor<??i32>, %arg1 : tensor<4x4xf32>) -> i32 {
+  %c0 = "constant"() {value: 0} : () -> affineint
+
+  // CHECK: %0 = extract_element %arg0[%c0, %c0, %c0, %c0] : tensor<??i32>
+  %0 = extract_element %arg0[%c0, %c0, %c0, %c0] : tensor<??i32>
+
+  // CHECK: %1 = extract_element %arg1[%c0, %c0] : tensor<4x4xf32>
+  %1 = extract_element %arg1[%c0, %c0] : tensor<4x4xf32>
+
+  return %0 : i32
+}