Expand rewriter gen to handle string attributes in output.
* Extend to handle rewrite patterns with output attributes;
- Constant attributes are defined with a value and a type;
- The type of the value is mapped to the corresponding attribute type (string -> StringAttr);
* Verifies the type of operands in the resultant matches the defined op's operands;
PiperOrigin-RevId: 226468908
diff --git a/include/mlir/IR/op_base.td b/include/mlir/IR/op_base.td
index 4d75d53..5b57944 100644
--- a/include/mlir/IR/op_base.td
+++ b/include/mlir/IR/op_base.td
@@ -213,3 +213,18 @@
class TernaryOp<string mnemonic, list<OpProperty> props> :
Op<mnemonic, props>, Arguments<(ins Tensor, Tensor, Tensor)>;
+
+//===----------------------------------------------------------------------===//
+// Patterns
+//===----------------------------------------------------------------------===//
+// Base class for op+ -> op+ rewrite patterns. These allow declaratively
+// specifying rewrite patterns.
+// TODO(jpienaar): Add the constraint list along with the Pattern.
+class Pattern<dag patternToMatch, list<dag> resultOps> {
+ dag PatternToMatch = patternToMatch;
+ list<dag> ResultOps = resultOps;
+}
+
+// Form of a pattern which produces a single result.
+class Pat<dag pattern, dag result> : Pattern<pattern, [result]>;
+
diff --git a/test/mlir-rewriter-gen/one-op-one-result.td b/test/mlir-rewriter-gen/one-op-one-result.td
new file mode 100644
index 0000000..55b9e36
--- /dev/null
+++ b/test/mlir-rewriter-gen/one-op-one-result.td
@@ -0,0 +1,44 @@
+// RUN: mlir-rewriter-gen %s | FileCheck %s
+
+// Extracted & simplified from op_base.td to do more directed testing.
+class Type;
+class Pattern<dag patternToMatch, list<dag> resultOps> {
+ dag PatternToMatch = patternToMatch;
+ list<dag> ResultOps = resultOps;
+}
+class Pat<dag pattern, dag result> : Pattern<pattern, [result]>;
+def ins;
+class Op<string mnemonic> {
+ string name = mnemonic;
+ dag operands = (ins);
+}
+class Attr<Type t> {
+ Type type = t;
+}
+
+// Create a Type and Attribute.
+def YT : Type;
+def Y_Attr : Attr<YT>;
+def Y_Const_Attr {
+ Type type = YT;
+ string value = "attrValue";
+}
+
+// Define ops to rewrite.
+def T1: Type;
+def X_AddOp : Op<"x.add">;
+def Y_AddOp : Op<"y.add"> {
+ let operands = (ins T1, T1, Y_Attr:$attrName);
+}
+
+// Define rewrite pattern.
+def : Pat<(X_AddOp $lhs, $rhs), (Y_AddOp $lhs, T1:$rhs, Y_Const_Attr:$x)>;
+
+// CHECK: struct GeneratedConvert0 : public RewritePattern
+// CHECK: RewritePattern("x.add", 1, context)
+// CHECK: PatternMatchResult match(Operation *op)
+// CHECK: void rewrite(Operation *op, PatternRewriter &rewriter)
+// CHECK: rewriter.replaceOpWithNewOp<Y::AddOp>(op, op->getResult(0)->getType()
+
+// CHECK: void populateWithGenerated
+// CHECK: patterns->push_back(std::make_unique<GeneratedConvert0>(context))
\ No newline at end of file
diff --git a/tools/mlir-rewriter-gen/mlir-rewriter-gen.cpp b/tools/mlir-rewriter-gen/mlir-rewriter-gen.cpp
index 591a741..bb23187 100644
--- a/tools/mlir-rewriter-gen/mlir-rewriter-gen.cpp
+++ b/tools/mlir-rewriter-gen/mlir-rewriter-gen.cpp
@@ -42,6 +42,7 @@
static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
emitSourceFileHeader("Rewriters", os);
const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern");
+ Record *attrClass = recordKeeper.getClass("Attr");
// Ensure unique patterns simply by appending unique suffix.
unsigned rewritePatternCount = 0;
@@ -57,7 +58,7 @@
for (auto arg : tree->getArgs()) {
if (isa<DagInit>(arg))
PrintFatalError(pattern->getLoc(),
- "Only single pattern inputs supported");
+ "only single pattern inputs supported");
}
// Emit RewritePattern for Pattern.
@@ -75,31 +76,105 @@
ListInit *resultOps = pattern->getValueAsListInit("ResultOps");
if (resultOps->size() != 1)
- PrintFatalError("Can only handle single result rules");
+ PrintFatalError("only single result rules supported");
DagInit *resultTree = cast<DagInit>(resultOps->getElement(0));
// TODO(jpienaar): Expand to multiple results.
for (auto result : resultTree->getArgs()) {
if (isa<DagInit>(result))
- PrintFatalError(pattern->getLoc(), "Only single op result supported");
+ PrintFatalError(pattern->getLoc(), "only single op result supported");
}
DefInit *resultRoot = cast<DefInit>(resultTree->getOperator());
std::string opName = resultRoot->getAsUnquotedString();
+ auto resultOperands = resultRoot->getDef()->getValueAsDag("arguments");
SmallVector<StringRef, 2> split;
SplitString(opName, split, "_");
auto className = join(split, "::");
- os << " void rewrite(Operation *op, PatternRewriter &rewriter) const "
- << "override {\n rewriter.replaceOpWithNewOp<" << className
- << ">(op, op->getResult(0)->getType()";
- for (auto arg : resultTree->getArgNames()) {
- if (!arg)
- continue;
- // TODO(jpienaar): Change to /*x=*/ form once operands are named.
- os << ", /* " << arg->getAsUnquotedString() << " */op->getOperand("
- << nameToOrdinal[arg->getAsUnquotedString()] << ")";
+ os << formatv(R"(
+ void rewrite(Operation *op, PatternRewriter &rewriter) const override {
+ auto* context = op->getContext(); (void)context;
+ rewriter.replaceOpWithNewOp<{0}>(op, op->getResult(0)->getType())",
+ className);
+ if (resultOperands->getNumArgs() != resultTree->getNumArgs()) {
+ PrintFatalError(pattern->getLoc(),
+ Twine("mismatch between arguments of resultant op (") +
+ Twine(resultOperands->getNumArgs()) +
+ ") and arguments provided for rewrite (" +
+ Twine(resultTree->getNumArgs()) + Twine(')'));
}
- os << ");\n }\n};\n";
+
+ // Create the builder call for the result.
+ for (int i = 0, e = resultTree->getNumArgs(); i != e; ++i) {
+ // Start each operand on its own line.
+ (os << ",\n").indent(6);
+
+ auto *arg = resultTree->getArg(i);
+ std::string name = resultTree->getArgName(i)->getAsUnquotedString();
+ auto defInit = dyn_cast<DefInit>(arg);
+
+ // TODO(jpienaar): Refactor out into map to avoid recomputing these.
+ auto *argument = resultOperands->getArg(i);
+ auto argumentDefInit = dyn_cast<DefInit>(argument);
+ bool argumentIsAttr = false;
+ if (argumentDefInit) {
+ if (auto recTy = dyn_cast<RecordRecTy>(argumentDefInit->getType()))
+ argumentIsAttr = recTy->isSubClassOf(attrClass);
+ }
+
+ if (argumentIsAttr) {
+ if (!defInit) {
+ std::string argumentName =
+ resultOperands->getArgName(i)->getAsUnquotedString();
+ PrintFatalError(pattern->getLoc(),
+ Twine("attribute '") + argumentName +
+ "' needs to be constant initialized");
+ }
+
+ auto value = defInit->getDef()->getValue("value");
+ if (!value)
+ PrintFatalError(pattern->getLoc(), Twine("'value' not defined in ") +
+ arg->getAsString());
+
+ switch (value->getType()->getRecTyKind()) {
+ case RecTy::IntRecTyKind:
+ // TODO(jpienaar): This is using 64-bits for all the bitwidth of the
+ // type could instead be queried. These are expected to be mostly used
+ // for enums or constant indices and so no arithmetic operations are
+ // expected on these.
+ os << formatv(
+ "/*{0}=*/IntegerAttr::get(Type::getInteger(64, context), {1})",
+ name, value->getValue()->getAsString());
+ break;
+ case RecTy::StringRecTyKind:
+ os << formatv("/*{0}=*/StringAttr::get({1}, context)", name,
+ value->getValue()->getAsString());
+ break;
+ default:
+ PrintFatalError(pattern->getLoc(),
+ Twine("unsupported/unimplemented value type for ") +
+ name);
+ }
+ continue;
+ }
+
+ // Verify the types match between the rewriter's result and the
+ if (defInit && argumentDefInit &&
+ defInit->getType() != argumentDefInit->getType()) {
+ PrintFatalError(
+ pattern->getLoc(),
+ "mismatch in type of operation's argument and rewrite argument " +
+ Twine(i));
+ }
+
+ // Lookup the ordinal for the named operand.
+ auto ord = nameToOrdinal.find(name);
+ if (ord == nameToOrdinal.end())
+ PrintFatalError(pattern->getLoc(),
+ Twine("unknown named operand '") + name + "'");
+ os << "/*" << name << "=*/op->getOperand(" << ord->getValue() << ")";
+ }
+ os << "\n );\n }\n};\n";
}
// Emit function to add the generated matchers to the pattern list.