Expand rewriter gen to handle string attributes in output.

* Extend to handle rewrite patterns with output attributes;
  - Constant attributes are defined with a value and a type;
  - The type of the value is mapped to the corresponding attribute type (string -> StringAttr);
* Verifies the type of operands in the resultant matches the defined op's operands;

PiperOrigin-RevId: 226468908
diff --git a/include/mlir/IR/op_base.td b/include/mlir/IR/op_base.td
index 4d75d53..5b57944 100644
--- a/include/mlir/IR/op_base.td
+++ b/include/mlir/IR/op_base.td
@@ -213,3 +213,18 @@
 
 class TernaryOp<string mnemonic, list<OpProperty> props> :
     Op<mnemonic, props>, Arguments<(ins Tensor, Tensor, Tensor)>;
+
+//===----------------------------------------------------------------------===//
+// Patterns
+//===----------------------------------------------------------------------===//
+// Base class for op+ -> op+ rewrite patterns. These allow declaratively
+// specifying rewrite patterns.
+// TODO(jpienaar): Add the constraint list along with the Pattern.
+class Pattern<dag patternToMatch, list<dag> resultOps> {
+  dag PatternToMatch = patternToMatch;
+  list<dag> ResultOps = resultOps;
+}
+
+// Form of a pattern which produces a single result.
+class Pat<dag pattern, dag result> : Pattern<pattern, [result]>;
+
diff --git a/test/mlir-rewriter-gen/one-op-one-result.td b/test/mlir-rewriter-gen/one-op-one-result.td
new file mode 100644
index 0000000..55b9e36
--- /dev/null
+++ b/test/mlir-rewriter-gen/one-op-one-result.td
@@ -0,0 +1,44 @@
+// RUN: mlir-rewriter-gen %s | FileCheck %s
+
+// Extracted & simplified from op_base.td to do more directed testing.
+class Type;
+class Pattern<dag patternToMatch, list<dag> resultOps> {
+  dag PatternToMatch = patternToMatch;
+  list<dag> ResultOps = resultOps;
+}
+class Pat<dag pattern, dag result> : Pattern<pattern, [result]>;
+def ins;
+class Op<string mnemonic> {
+  string name = mnemonic;
+  dag operands = (ins);
+}
+class Attr<Type t> {
+  Type type = t;
+}
+
+// Create a Type and Attribute.
+def YT : Type;
+def Y_Attr : Attr<YT>;
+def Y_Const_Attr {
+  Type type = YT;
+  string value = "attrValue";
+}
+
+// Define ops to rewrite.
+def T1: Type;
+def X_AddOp : Op<"x.add">;
+def Y_AddOp : Op<"y.add"> {
+  let operands = (ins T1, T1, Y_Attr:$attrName);
+}
+
+// Define rewrite pattern.
+def : Pat<(X_AddOp $lhs, $rhs), (Y_AddOp $lhs, T1:$rhs, Y_Const_Attr:$x)>;
+
+// CHECK: struct GeneratedConvert0 : public RewritePattern
+// CHECK: RewritePattern("x.add", 1, context)
+// CHECK: PatternMatchResult match(Operation *op)
+// CHECK: void rewrite(Operation *op, PatternRewriter &rewriter)
+// CHECK: rewriter.replaceOpWithNewOp<Y::AddOp>(op, op->getResult(0)->getType()
+
+// CHECK: void populateWithGenerated
+// CHECK: patterns->push_back(std::make_unique<GeneratedConvert0>(context))
\ No newline at end of file
diff --git a/tools/mlir-rewriter-gen/mlir-rewriter-gen.cpp b/tools/mlir-rewriter-gen/mlir-rewriter-gen.cpp
index 591a741..bb23187 100644
--- a/tools/mlir-rewriter-gen/mlir-rewriter-gen.cpp
+++ b/tools/mlir-rewriter-gen/mlir-rewriter-gen.cpp
@@ -42,6 +42,7 @@
 static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
   emitSourceFileHeader("Rewriters", os);
   const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern");
+  Record *attrClass = recordKeeper.getClass("Attr");
 
   // Ensure unique patterns simply by appending unique suffix.
   unsigned rewritePatternCount = 0;
@@ -57,7 +58,7 @@
     for (auto arg : tree->getArgs()) {
       if (isa<DagInit>(arg))
         PrintFatalError(pattern->getLoc(),
-                        "Only single pattern inputs supported");
+                        "only single pattern inputs supported");
     }
 
     // Emit RewritePattern for Pattern.
@@ -75,31 +76,105 @@
 
     ListInit *resultOps = pattern->getValueAsListInit("ResultOps");
     if (resultOps->size() != 1)
-      PrintFatalError("Can only handle single result rules");
+      PrintFatalError("only single result rules supported");
     DagInit *resultTree = cast<DagInit>(resultOps->getElement(0));
 
     // TODO(jpienaar): Expand to multiple results.
     for (auto result : resultTree->getArgs()) {
       if (isa<DagInit>(result))
-        PrintFatalError(pattern->getLoc(), "Only single op result supported");
+        PrintFatalError(pattern->getLoc(), "only single op result supported");
     }
     DefInit *resultRoot = cast<DefInit>(resultTree->getOperator());
     std::string opName = resultRoot->getAsUnquotedString();
+    auto resultOperands = resultRoot->getDef()->getValueAsDag("arguments");
 
     SmallVector<StringRef, 2> split;
     SplitString(opName, split, "_");
     auto className = join(split, "::");
-    os << "  void rewrite(Operation *op, PatternRewriter &rewriter) const "
-       << "override {\n   rewriter.replaceOpWithNewOp<" << className
-       << ">(op, op->getResult(0)->getType()";
-    for (auto arg : resultTree->getArgNames()) {
-      if (!arg)
-        continue;
-      // TODO(jpienaar): Change to /*x=*/ form once operands are named.
-      os << ", /* " << arg->getAsUnquotedString() << " */op->getOperand("
-         << nameToOrdinal[arg->getAsUnquotedString()] << ")";
+    os << formatv(R"(
+  void rewrite(Operation *op, PatternRewriter &rewriter) const override {
+    auto* context = op->getContext(); (void)context;
+    rewriter.replaceOpWithNewOp<{0}>(op, op->getResult(0)->getType())",
+                  className);
+    if (resultOperands->getNumArgs() != resultTree->getNumArgs()) {
+      PrintFatalError(pattern->getLoc(),
+                      Twine("mismatch between arguments of resultant op (") +
+                          Twine(resultOperands->getNumArgs()) +
+                          ") and arguments provided for rewrite (" +
+                          Twine(resultTree->getNumArgs()) + Twine(')'));
     }
-    os << ");\n  }\n};\n";
+
+    // Create the builder call for the result.
+    for (int i = 0, e = resultTree->getNumArgs(); i != e; ++i) {
+      // Start each operand on its own line.
+      (os << ",\n").indent(6);
+
+      auto *arg = resultTree->getArg(i);
+      std::string name = resultTree->getArgName(i)->getAsUnquotedString();
+      auto defInit = dyn_cast<DefInit>(arg);
+
+      // TODO(jpienaar): Refactor out into map to avoid recomputing these.
+      auto *argument = resultOperands->getArg(i);
+      auto argumentDefInit = dyn_cast<DefInit>(argument);
+      bool argumentIsAttr = false;
+      if (argumentDefInit) {
+        if (auto recTy = dyn_cast<RecordRecTy>(argumentDefInit->getType()))
+          argumentIsAttr = recTy->isSubClassOf(attrClass);
+      }
+
+      if (argumentIsAttr) {
+        if (!defInit) {
+          std::string argumentName =
+              resultOperands->getArgName(i)->getAsUnquotedString();
+          PrintFatalError(pattern->getLoc(),
+                          Twine("attribute '") + argumentName +
+                              "' needs to be constant initialized");
+        }
+
+        auto value = defInit->getDef()->getValue("value");
+        if (!value)
+          PrintFatalError(pattern->getLoc(), Twine("'value' not defined in ") +
+                                                 arg->getAsString());
+
+        switch (value->getType()->getRecTyKind()) {
+        case RecTy::IntRecTyKind:
+          // TODO(jpienaar): This is using 64-bits for all the bitwidth of the
+          // type could instead be queried. These are expected to be mostly used
+          // for enums or constant indices and so no arithmetic operations are
+          // expected on these.
+          os << formatv(
+              "/*{0}=*/IntegerAttr::get(Type::getInteger(64, context), {1})",
+              name, value->getValue()->getAsString());
+          break;
+        case RecTy::StringRecTyKind:
+          os << formatv("/*{0}=*/StringAttr::get({1}, context)", name,
+                        value->getValue()->getAsString());
+          break;
+        default:
+          PrintFatalError(pattern->getLoc(),
+                          Twine("unsupported/unimplemented value type for ") +
+                              name);
+        }
+        continue;
+      }
+
+      // Verify the types match between the rewriter's result and the
+      if (defInit && argumentDefInit &&
+          defInit->getType() != argumentDefInit->getType()) {
+        PrintFatalError(
+            pattern->getLoc(),
+            "mismatch in type of operation's argument and rewrite argument " +
+                Twine(i));
+      }
+
+      // Lookup the ordinal for the named operand.
+      auto ord = nameToOrdinal.find(name);
+      if (ord == nameToOrdinal.end())
+        PrintFatalError(pattern->getLoc(),
+                        Twine("unknown named operand '") + name + "'");
+      os << "/*" << name << "=*/op->getOperand(" << ord->getValue() << ")";
+    }
+    os << "\n    );\n  }\n};\n";
   }
 
   // Emit function to add the generated matchers to the pattern list.