Add builderCall to Type and add constant attr class.
With the builder to construct the type on the Type, the appropriate mlir::Type can be constructed where needed. Also add a constant attr class that has the attribute and value as members.
PiperOrigin-RevId: 227564789
diff --git a/include/mlir/IR/op_base.td b/include/mlir/IR/op_base.td
index e10acfd..b3a4d99 100644
--- a/include/mlir/IR/op_base.td
+++ b/include/mlir/IR/op_base.td
@@ -24,11 +24,15 @@
//===----------------------------------------------------------------------===//
// Base class for all types.
-class Type;
+class Type {
+ // The builder call to invoke (if specified) to construct the Type.
+ code builderCall = ?;
+}
// Integer types.
class I<int width> : Type {
int bitwidth = width;
+ let builderCall = "getIntegerType(" # bitwidth # ")";
}
def I1 : I<1>;
def I32 : I<32>;
@@ -37,7 +41,9 @@
class F<int width> : Type {
int bitwidth = width;
}
-def F32 : F<32>;
+def F32 : F<32> {
+ let builderCall = "getF32Type()";
+}
// Vector types.
class Vector<Type t, list<int> dims> : Type {
@@ -108,10 +114,18 @@
code body = Body;
}
+// Derived attribute that returns a mlir::Type.
+class DerivedTypeAttr<code body> : DerivedAttr<"Type", body>;
+
+// Represents a constant attribute of specific Attr type. The leaf class that
+// derives from this should additionally include a `value` member.
+class ConstantAttr<Attr attribute> {
+ Attr attr = attribute;
+}
+
// The values for const F32 attributes are set as strings as floating point
// values can't be provided directly in TableGen.
-class ConstF32Attr<string val> {
- Type type = F32;
+class ConstF32Attr<string val> : ConstantAttr<F32Attr> {
string value = val;
}
diff --git a/tools/mlir-tblgen/RewriterGen.cpp b/tools/mlir-tblgen/RewriterGen.cpp
index 7aa2997..e4ba02b 100644
--- a/tools/mlir-tblgen/RewriterGen.cpp
+++ b/tools/mlir-tblgen/RewriterGen.cpp
@@ -85,45 +85,32 @@
} // end namespace
void Pattern::emitAttributeValue(Record *constAttr) {
- Record *type = constAttr->getValueAsDef("type");
+ Record *attr = constAttr->getValueAsDef("attr");
auto value = constAttr->getValue("value");
+ Record *type = attr->getValueAsDef("type");
+ auto storageType = attr->getValueAsString("storageType").trim();
- // Construct the attribute based on `type`.
- // TODO(jpienaar): Generalize this to avoid hardcoding here.
- if (type->isSubClassOf("F")) {
- string mlirType;
- switch (type->getValueAsInt("bitwidth")) {
- case 32:
- mlirType = "Type::getF32(context)";
- break;
- default:
- PrintFatalError("unsupported floating point width");
- }
- // TODO(jpienaar): Verify the floating point constant here.
- os << formatv("FloatAttr::get({0}, {1})", mlirType,
+ // For attributes stored as strings we do not need to query builder etc.
+ if (storageType == "StringAttr") {
+ os << formatv("rewriter.getStringAttr({0})",
+ value->getValue()->getAsString());
+ return;
+ }
+
+ // Construct the attribute based on storage type and builder.
+ if (auto b = type->getValue("builderCall")) {
+ if (isa<UnsetInit>(b->getValue()))
+ PrintFatalError(pattern->getLoc(),
+ "no builder specified for " + type->getName());
+ CodeInit *builder = cast<CodeInit>(b->getValue());
+ // TODO(jpienaar): Verify the constants here
+ os << formatv("{0}::get(rewriter.{1}, {2})", storageType,
+ builder->getValue(),
value->getValue()->getAsUnquotedString());
return;
}
- // Fallback to the type of value.
- switch (value->getType()->getRecTyKind()) {
- case RecTy::IntRecTyKind:
- // TODO(jpienaar): This is using 64-bits for all the bitwidth of the
- // type could instead be queried. These are expected to be mostly used
- // for enums or constant indices and so no arithmetic operations are
- // expected on these.
- os << formatv("IntegerAttr::get(Type::getInteger(64, context), {0})",
- value->getValue()->getAsString());
- break;
- case RecTy::StringRecTyKind:
- os << formatv("StringAttr::get({0}, context)",
- value->getValue()->getAsString());
- break;
- default:
- PrintFatalError(pattern->getLoc(),
- Twine("unsupported/unimplemented value type for ") +
- value->getName());
- }
+ PrintFatalError(pattern->getLoc(), "unable to emit attribute");
}
void Pattern::collectBoundArguments(DagInit *tree) {
@@ -237,7 +224,6 @@
os << formatv(R"(
void rewrite(OperationInst *op, std::unique_ptr<PatternState> state,
PatternRewriter &rewriter) const override {
- auto* context = op->getContext(); (void)context;
auto& s = *static_cast<MatchedState *>(state.get());
rewriter.replaceOpWithNewOp<{0}>(op, op->getResult(0)->getType())",
resultOp.cppClassName());
@@ -273,7 +259,7 @@
(os << ",\n").indent(6);
// The argument in the result DAG pattern.
- std::string name = resultTree->getArgNameStr(i);
+ auto name = resultOp.getArgName(i);
auto defInit = dyn_cast<DefInit>(resultTree->getArg(i));
auto *value = defInit ? defInit->getDef()->getValue("value") : nullptr;
if (!value)