Add builderCall to Type and add constant attr class.

With the builder to construct the type on the Type, the appropriate mlir::Type can be constructed where needed. Also add a constant attr class that has the attribute and value as members.

PiperOrigin-RevId: 227564789
diff --git a/include/mlir/IR/op_base.td b/include/mlir/IR/op_base.td
index e10acfd..b3a4d99 100644
--- a/include/mlir/IR/op_base.td
+++ b/include/mlir/IR/op_base.td
@@ -24,11 +24,15 @@
 //===----------------------------------------------------------------------===//
 
 // Base class for all types.
-class Type;
+class Type {
+  // The builder call to invoke (if specified) to construct the Type.
+  code builderCall = ?;
+}
 
 // Integer types.
 class I<int width> : Type {
   int bitwidth = width;
+  let builderCall = "getIntegerType(" # bitwidth # ")";
 }
 def I1  : I<1>;
 def I32 : I<32>;
@@ -37,7 +41,9 @@
 class F<int width> : Type {
   int bitwidth = width;
 }
-def F32 : F<32>;
+def F32 : F<32> {
+  let builderCall = "getF32Type()";
+}
 
 // Vector types.
 class Vector<Type t, list<int> dims> : Type {
@@ -108,10 +114,18 @@
   code body = Body;
 }
 
+// Derived attribute that returns a mlir::Type.
+class DerivedTypeAttr<code body> : DerivedAttr<"Type", body>;
+
+// Represents a constant attribute of specific Attr type. The leaf class that
+// derives from this should additionally include a `value` member.
+class ConstantAttr<Attr attribute> {
+  Attr attr = attribute;
+}
+
 // The values for const F32 attributes are set as strings as floating point
 // values can't be provided directly in TableGen.
-class ConstF32Attr<string val> {
-  Type type = F32;
+class ConstF32Attr<string val> : ConstantAttr<F32Attr> {
   string value = val;
 }
 
diff --git a/tools/mlir-tblgen/RewriterGen.cpp b/tools/mlir-tblgen/RewriterGen.cpp
index 7aa2997..e4ba02b 100644
--- a/tools/mlir-tblgen/RewriterGen.cpp
+++ b/tools/mlir-tblgen/RewriterGen.cpp
@@ -85,45 +85,32 @@
 } // end namespace
 
 void Pattern::emitAttributeValue(Record *constAttr) {
-  Record *type = constAttr->getValueAsDef("type");
+  Record *attr = constAttr->getValueAsDef("attr");
   auto value = constAttr->getValue("value");
+  Record *type = attr->getValueAsDef("type");
+  auto storageType = attr->getValueAsString("storageType").trim();
 
-  // Construct the attribute based on `type`.
-  // TODO(jpienaar): Generalize this to avoid hardcoding here.
-  if (type->isSubClassOf("F")) {
-    string mlirType;
-    switch (type->getValueAsInt("bitwidth")) {
-    case 32:
-      mlirType = "Type::getF32(context)";
-      break;
-    default:
-      PrintFatalError("unsupported floating point width");
-    }
-    // TODO(jpienaar): Verify the floating point constant here.
-    os << formatv("FloatAttr::get({0}, {1})", mlirType,
+  // For attributes stored as strings we do not need to query builder etc.
+  if (storageType == "StringAttr") {
+    os << formatv("rewriter.getStringAttr({0})",
+                  value->getValue()->getAsString());
+    return;
+  }
+
+  // Construct the attribute based on storage type and builder.
+  if (auto b = type->getValue("builderCall")) {
+    if (isa<UnsetInit>(b->getValue()))
+      PrintFatalError(pattern->getLoc(),
+                      "no builder specified for " + type->getName());
+    CodeInit *builder = cast<CodeInit>(b->getValue());
+    // TODO(jpienaar): Verify the constants here
+    os << formatv("{0}::get(rewriter.{1}, {2})", storageType,
+                  builder->getValue(),
                   value->getValue()->getAsUnquotedString());
     return;
   }
 
-  // Fallback to the type of value.
-  switch (value->getType()->getRecTyKind()) {
-  case RecTy::IntRecTyKind:
-    // TODO(jpienaar): This is using 64-bits for all the bitwidth of the
-    // type could instead be queried. These are expected to be mostly used
-    // for enums or constant indices and so no arithmetic operations are
-    // expected on these.
-    os << formatv("IntegerAttr::get(Type::getInteger(64, context), {0})",
-                  value->getValue()->getAsString());
-    break;
-  case RecTy::StringRecTyKind:
-    os << formatv("StringAttr::get({0}, context)",
-                  value->getValue()->getAsString());
-    break;
-  default:
-    PrintFatalError(pattern->getLoc(),
-                    Twine("unsupported/unimplemented value type for ") +
-                        value->getName());
-  }
+  PrintFatalError(pattern->getLoc(), "unable to emit attribute");
 }
 
 void Pattern::collectBoundArguments(DagInit *tree) {
@@ -237,7 +224,6 @@
   os << formatv(R"(
   void rewrite(OperationInst *op, std::unique_ptr<PatternState> state,
                PatternRewriter &rewriter) const override {
-    auto* context = op->getContext(); (void)context;
     auto& s = *static_cast<MatchedState *>(state.get());
     rewriter.replaceOpWithNewOp<{0}>(op, op->getResult(0)->getType())",
                 resultOp.cppClassName());
@@ -273,7 +259,7 @@
     (os << ",\n").indent(6);
 
     // The argument in the result DAG pattern.
-    std::string name = resultTree->getArgNameStr(i);
+    auto name = resultOp.getArgName(i);
     auto defInit = dyn_cast<DefInit>(resultTree->getArg(i));
     auto *value = defInit ? defInit->getDef()->getValue("value") : nullptr;
     if (!value)