TableGen: allow mixing attributes and operands in the Arguments DAG of Op definition

The existing implementation of the Op definition generator assumes and relies
on the fact that native Op Attributes appear after its value-based operands in
the Arguments list.  Furthermore, the same order is used in the generated
`build` function for the operation.  This is not desirable for some operations
with mandatory attributes that would want the attribute to appear upfront for
better consistency with their textual representation, for example `cmpi` would
prefer the `predicate` attribute to be foremost in the argument list.

Introduce support for using attributes and operands in the Arguments DAG in no
particular order.  This is achieved by maintaining a list of Arguments that
point to either the value or the attribute and are used to generate the `build`
method.

PiperOrigin-RevId: 237002921
diff --git a/include/mlir/IR/OpBase.td b/include/mlir/IR/OpBase.td
index 7e17e55..df9688c 100644
--- a/include/mlir/IR/OpBase.td
+++ b/include/mlir/IR/OpBase.td
@@ -438,8 +438,7 @@
   // Additional, longer human-readable description of what the op does.
   string description = "";
 
-  // Dag containting the arguments of the op. Default to 0 arguments. Operands
-  // to the op need to precede attributes to ops in the argument specification.
+  // Dag containting the arguments of the op. Default to 0 arguments.
   dag arguments = (ins);
 
   // The list of results of the op. Default to 0 results.
diff --git a/include/mlir/TableGen/Operator.h b/include/mlir/TableGen/Operator.h
index 93ad645..366e83d 100644
--- a/include/mlir/TableGen/Operator.h
+++ b/include/mlir/TableGen/Operator.h
@@ -105,7 +105,7 @@
   bool hasVariadicOperand() const;
 
   // Returns the total number of arguments.
-  int getNumArgs() const { return getNumOperands() + getNumNativeAttributes(); }
+  int getNumArgs() const { return arguments.size(); }
 
   // Op argument (attribute or operand) accessors.
   Argument getArg(int index);
@@ -140,22 +140,24 @@
   // The operands of the op.
   SmallVector<Value, 4> operands;
 
-  // The attributes of the op.
+  // The attributes of the op.  Contains native attributes (corresponding to the
+  // actual stored attributed of the operation) followed by derived attributes
+  // (corresponding to dynamic properties of the operation that are computed
+  // upon request).
   SmallVector<NamedAttribute, 4> attributes;
 
+  // The arguments of the op (operands and native attributes).
+  SmallVector<Argument, 4> arguments;
+
   // The results of the op.
   SmallVector<Value, 4> results;
 
   // The traits of the op.
   SmallVector<OpTrait, 4> traits;
 
-  // The start of native attributes, which are specified when creating the op
-  // as a part of the op's definition.
-  int nativeAttrStart;
-
-  // The start of derived attributes, which are computed from properties of
-  // the op.
-  int derivedAttrStart;
+  // The number of native attributes stored in the leading positions of
+  // `attributes`.
+  int numNativeAttributes;
 
   // The TableGen definition of this op.
   const llvm::Record &def;
diff --git a/lib/TableGen/Operator.cpp b/lib/TableGen/Operator.cpp
index 87dfcad..0979534 100644
--- a/lib/TableGen/Operator.cpp
+++ b/lib/TableGen/Operator.cpp
@@ -78,7 +78,7 @@
 }
 
 int tblgen::Operator::getNumNativeAttributes() const {
-  return derivedAttrStart - nativeAttrStart;
+  return numNativeAttributes;
 }
 
 int tblgen::Operator::getNumDerivedAttributes() const {
@@ -140,22 +140,21 @@
 }
 
 auto tblgen::Operator::getArg(int index) -> Argument {
-  if (index < nativeAttrStart)
-    return {&operands[index]};
-  return {&attributes[index - nativeAttrStart]};
+  return arguments[index];
 }
 
 void tblgen::Operator::populateOpStructure() {
   auto &recordKeeper = def.getRecords();
+  auto typeConstraintClass = recordKeeper.getClass("TypeConstraint");
   auto attrClass = recordKeeper.getClass("Attr");
   auto derivedAttrClass = recordKeeper.getClass("DerivedAttr");
-  derivedAttrStart = -1;
+  numNativeAttributes = 0;
 
   // The argument ordering is operands, native attributes, derived
   // attributes.
   DagInit *argumentValues = def.getValueAsDag("arguments");
   unsigned i = 0;
-  // Handle operands.
+  // Handle operands and native attributes.
   for (unsigned e = argumentValues->getNumArgs(); i != e; ++i) {
     auto arg = argumentValues->getArg(i);
     auto givenName = argumentValues->getArgNameStr(i);
@@ -164,32 +163,26 @@
       PrintFatalError(def.getLoc(),
                       Twine("undefined type for argument #") + Twine(i));
     Record *argDef = argDefInit->getDef();
-    if (argDef->isSubClassOf(attrClass))
-      break;
-    operands.push_back(Value{givenName, Type(argDefInit)});
-  }
 
-  // Handle native attributes.
-  nativeAttrStart = i;
-  for (unsigned e = argumentValues->getNumArgs(); i != e; ++i) {
-    auto arg = argumentValues->getArg(i);
-    auto givenName = argumentValues->getArgNameStr(i);
-    Record *argDef = cast<DefInit>(arg)->getDef();
-    if (!argDef->isSubClassOf(attrClass))
-      PrintFatalError(def.getLoc(),
-                      Twine("expected attribute as argument ") + Twine(i));
-
-    if (givenName.empty())
-      PrintFatalError(argDef->getLoc(), "attributes must be named");
-    bool isDerived = argDef->isSubClassOf(derivedAttrClass);
-    if (isDerived)
-      PrintFatalError(def.getLoc(),
-                      "derived attributes not allowed in argument list");
-    attributes.push_back({givenName, Attribute(argDef)});
+    if (argDef->isSubClassOf(typeConstraintClass)) {
+      operands.push_back(Value{givenName, Type(argDefInit)});
+      arguments.emplace_back(&operands.back());
+    } else if (argDef->isSubClassOf(attrClass)) {
+      if (givenName.empty())
+        PrintFatalError(argDef->getLoc(), "attributes must be named");
+      if (argDef->isSubClassOf(derivedAttrClass))
+        PrintFatalError(argDef->getLoc(),
+                        "derived attributes not allowed in argument list");
+      attributes.push_back({givenName, Attribute(argDef)});
+      arguments.emplace_back(&attributes.back());
+      ++numNativeAttributes;
+    } else {
+      PrintFatalError(def.getLoc(), "unexpected def type; only defs deriving "
+                                    "from TypeConstraint or Attr are allowed");
+    }
   }
 
   // Handle derived attributes.
-  derivedAttrStart = i;
   for (const auto &val : def.getValues()) {
     if (auto *record = dyn_cast<llvm::RecordRecTy>(val.getType())) {
       if (!record->isSubClassOf(attrClass))
diff --git a/test/mlir-tblgen/op-attribute.td b/test/mlir-tblgen/op-attribute.td
index 29a4566..deecedd 100644
--- a/test/mlir-tblgen/op-attribute.td
+++ b/test/mlir-tblgen/op-attribute.td
@@ -2,11 +2,22 @@
 
 include "mlir/IR/OpBase.td"
 
+def MixOperandsAndAttrs : Op<"mix_operands_and_attrs", []> {
+  let arguments = (ins F32Attr:$attr, F32:$operand, F32Attr:$otherAttr, F32:$otherArg);
+}
+
+// CHECK-LABEL: class MixOperandsAndAttrs
+// CHECK-DAG: Value *operand()
+// CHECK-DAG: Value *otherArg()
+// CHECK-DAG: void build(Builder *builder, OperationState *result, FloatAttr attr, Value *operand, FloatAttr otherAttr, Value *otherArg)
+// CHECK-DAG: APFloat attr()
+// CHECK-DAG: APFloat otherAttr()
+
 def OpWithArgs : Op<"op_with_args", []> {
   let arguments = (ins I32:$x, F32Attr:$attr, OptionalAttr<F32Attr>:$optAttr);
 }
 
-// CHECK-LABEL: OpWithArgs
+// CHECK-LABEL: class OpWithArgs
 // CHECK: void build(Builder *builder, OperationState *result, Value *x, FloatAttr attr, /*optional*/FloatAttr optAttr)
 // CHECK: APFloat attr()
 // CHECK: Optional< APFloat > optAttr()
diff --git a/tools/mlir-tblgen/OpDefinitionsGen.cpp b/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 86c219e..30773dc 100644
--- a/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -244,7 +244,6 @@
   OUT(2) << "static void build(Builder *builder, OperationState *result";
 
   auto numResults = op.getNumResults();
-  auto numOperands = op.getNumOperands();
 
   llvm::SmallVector<std::string, 4> resultNames;
   resultNames.reserve(numResults);
@@ -263,24 +262,30 @@
     }
   }
 
-  // Emit parameters for all operands
-  for (int i = 0; i != numOperands; ++i) {
-    auto &operand = op.getOperand(i);
-    os << (operand.type.isVariadic() ? ", ArrayRef<Value *> " : ", Value *")
-       << getArgumentName(op, i);
+  // Emit parameters for all arguments (operands and attributes).
+  int numOperands = 0;
+  int numAttrs = 0;
+  for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
+    auto argument = op.getArg(i);
+    if (argument.is<tblgen::Value *>()) {
+      auto &operand = op.getOperand(numOperands);
+      os << (operand.type.isVariadic() ? ", ArrayRef<Value *> " : ", Value *")
+         << getArgumentName(op, numOperands);
+      ++numOperands;
+    } else {
+      // TODO(antiagainst): Support default initializer for attributes
+      const auto &namedAttr = op.getAttribute(numAttrs);
+      const auto &attr = namedAttr.attr;
+      os << ", ";
+      if (attr.isOptional())
+        os << "/*optional*/";
+      os << attr.getStorageType() << ' ' << namedAttr.name;
+      ++numAttrs;
+    }
   }
-
-  // Emit parameters for all attributes
-  // TODO(antiagainst): Support default initializer for attributes
-  for (const auto &namedAttr : op.getAttributes()) {
-    const auto &attr = namedAttr.attr;
-    if (attr.isDerivedAttr())
-      break;
-    os << ", ";
-    if (attr.isOptional())
-      os << "/*optional*/";
-    os << attr.getStorageType() << ' ' << namedAttr.name;
-  }
+  if (numOperands + numAttrs != op.getNumArgs())
+    return PrintFatalError(
+        "op arguments must be either operands or attributes");
 
   os << ") {\n";