| //===- RewriterGen.cpp - MLIR pattern rewriter generator ------------===// |
| // |
| // Copyright 2019 The MLIR Authors. |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| // ============================================================================= |
| // |
| // RewriterGen uses pattern rewrite definitions to generate rewriter matchers. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/TableGen/Attribute.h" |
| #include "mlir/TableGen/GenInfo.h" |
| #include "mlir/TableGen/Operator.h" |
| #include "mlir/TableGen/Pattern.h" |
| #include "mlir/TableGen/Predicate.h" |
| #include "mlir/TableGen/Type.h" |
| #include "llvm/ADT/StringExtras.h" |
| #include "llvm/ADT/StringSet.h" |
| #include "llvm/Support/CommandLine.h" |
| #include "llvm/Support/FormatVariadic.h" |
| #include "llvm/Support/PrettyStackTrace.h" |
| #include "llvm/Support/Signals.h" |
| #include "llvm/TableGen/Error.h" |
| #include "llvm/TableGen/Main.h" |
| #include "llvm/TableGen/Record.h" |
| #include "llvm/TableGen/TableGenBackend.h" |
| |
| using namespace llvm; |
| using namespace mlir; |
| |
| using mlir::tblgen::DagNode; |
| using mlir::tblgen::NamedAttribute; |
| using mlir::tblgen::Operand; |
| using mlir::tblgen::Operator; |
| using mlir::tblgen::RecordOperatorMap; |
| |
| namespace { |
| class PatternEmitter { |
| public: |
| static void emit(StringRef rewriteName, Record *p, RecordOperatorMap *mapper, |
| raw_ostream &os); |
| |
| private: |
| PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os) |
| : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper), os(os) {} |
| |
| // Emits the mlir::RewritePattern struct named `rewriteName`. |
| void emit(StringRef rewriteName); |
| |
| // Emits the match() method. |
| void emitMatchMethod(DagNode tree); |
| |
| // Emits the rewrite() method. |
| void emitRewriteMethod(); |
| |
| // Emits the C++ statement to replace the matched DAG with a new op. |
| void emitReplaceOpWithNewOp(DagNode resultTree); |
| // Emits the C++ statement to replace the matched DAG with an existing value. |
| void emitReplaceWithExistingValue(DagNode resultTree); |
| // Emits the C++ statement to replace the matched DAG with a native C++ built |
| // value. |
| void emitReplaceWithNativeBuilder(DagNode resultTree); |
| |
| // Emits the value of constant attribute to `os`. |
| void emitConstantAttr(tblgen::ConstantAttr constAttr); |
| |
| // Emits C++ statements for matching the op constrained by the given DAG |
| // `tree`. |
| void emitOpMatch(DagNode tree, int depth); |
| |
| // Emits C++ statements for matching the `index`-th argument of the given DAG |
| // `tree` as an operand. |
| void emitOperandMatch(DagNode tree, int index, int depth, int indent); |
| // Emits C++ statements for matching the `index`-th argument of the given DAG |
| // `tree` as an attribute. |
| void emitAttributeMatch(DagNode tree, int index, int depth, int indent); |
| |
| private: |
| // Pattern instantiation location followed by the location of multiclass |
| // prototypes used. This is intended to be used as a whole to |
| // PrintFatalError() on errors. |
| ArrayRef<llvm::SMLoc> loc; |
| // Op's TableGen Record to wrapper object |
| RecordOperatorMap *opMap; |
| // Handy wrapper for pattern being emitted |
| tblgen::Pattern pattern; |
| raw_ostream &os; |
| }; |
| } // end namespace |
| |
| void PatternEmitter::emitConstantAttr(tblgen::ConstantAttr constAttr) { |
| auto attr = constAttr.getAttribute(); |
| |
| if (!attr.isConstBuildable()) |
| PrintFatalError(loc, "Attribute " + attr.getTableGenDefName() + |
| " does not have the 'constBuilderCall' field"); |
| |
| // TODO(jpienaar): Verify the constants here |
| os << formatv(attr.getConstBuilderTemplate().str().c_str(), "rewriter", |
| constAttr.getConstantValue()); |
| } |
| |
| // Helper function to match patterns. |
| void PatternEmitter::emitOpMatch(DagNode tree, int depth) { |
| Operator &op = tree.getDialectOp(opMap); |
| int indent = 4 + 2 * depth; |
| // Skip the operand matching at depth 0 as the pattern rewriter already does. |
| if (depth != 0) { |
| // Skip if there is no defining instruction (e.g., arguments to function). |
| os.indent(indent) << formatv("if (!op{0}) return matchFailure();\n", depth); |
| os.indent(indent) << formatv( |
| "if (!op{0}->isa<{1}>()) return matchFailure();\n", depth, |
| op.getQualCppClassName()); |
| } |
| if (tree.getNumArgs() != op.getNumArgs()) { |
| PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in " |
| "pattern vs. {2} in definition", |
| op.getOperationName(), tree.getNumArgs(), |
| op.getNumArgs())); |
| } |
| |
| for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { |
| auto opArg = op.getArg(i); |
| |
| // Handle nested DAG construct first |
| if (DagNode argTree = tree.getArgAsNestedDag(i)) { |
| os.indent(indent) << "{\n"; |
| os.indent(indent + 2) << formatv( |
| "auto op{0} = op{1}->getOperand({2})->getDefiningInst();\n", |
| depth + 1, depth, i); |
| emitOpMatch(argTree, depth + 1); |
| os.indent(indent) << "}\n"; |
| continue; |
| } |
| |
| // Next handle DAG leaf: operand or attribute |
| if (auto *operand = opArg.dyn_cast<Operand *>()) { |
| emitOperandMatch(tree, i, depth, indent); |
| } else if (auto *namedAttr = opArg.dyn_cast<NamedAttribute *>()) { |
| emitAttributeMatch(tree, i, depth, indent); |
| } else { |
| PrintFatalError(loc, "unhandled case when matching op"); |
| } |
| } |
| } |
| |
| void PatternEmitter::emitOperandMatch(DagNode tree, int index, int depth, |
| int indent) { |
| Operator &op = tree.getDialectOp(opMap); |
| auto *operand = op.getArg(index).get<Operand *>(); |
| auto matcher = tree.getArgAsLeaf(index); |
| |
| // If a constraint is specified, we need to generate C++ statements to |
| // check the constraint. |
| if (!matcher.isUnspecified()) { |
| if (!matcher.isOperandMatcher()) { |
| PrintFatalError( |
| loc, formatv("the {1}-th argument of op '{0}' should be an operand", |
| op.getOperationName(), index + 1)); |
| } |
| |
| // Only need to verify if the matcher's type is different from the one |
| // of op definition. |
| if (static_cast<tblgen::TypeConstraint>(operand->type) != |
| matcher.getAsTypeConstraint()) { |
| os.indent(indent) << "if (!(" |
| << formatv(matcher.getConditionTemplate().c_str(), |
| formatv("op{0}->getOperand({1})->getType()", |
| depth, index)) |
| << ")) return matchFailure();\n"; |
| } |
| } |
| |
| // Capture the value |
| auto name = tree.getArgName(index); |
| if (!name.empty()) { |
| os.indent(indent) << "state->" << name << " = op" << depth |
| << "->getOperand(" << index << ");\n"; |
| } |
| } |
| |
| void PatternEmitter::emitAttributeMatch(DagNode tree, int index, int depth, |
| int indent) { |
| Operator &op = tree.getDialectOp(opMap); |
| auto *namedAttr = op.getArg(index).get<NamedAttribute *>(); |
| auto matcher = tree.getArgAsLeaf(index); |
| |
| if (!matcher.isUnspecified() && !matcher.isAttrMatcher()) { |
| PrintFatalError( |
| loc, formatv("the {1}-th argument of op '{0}' should be an attribute", |
| op.getOperationName(), index + 1)); |
| } |
| |
| // If a constraint is specified, we need to generate C++ statements to |
| // check the constraint. |
| std::string condition = |
| formatv(matcher.getConditionTemplate().c_str(), |
| formatv("op{0}->getAttrOfType<{1}>(\"{2}\")", depth, |
| namedAttr->attr.getStorageType(), namedAttr->getName())); |
| os.indent(indent) << "if (!(" << condition << ")) return matchFailure();\n"; |
| |
| // Capture the value |
| auto name = tree.getArgName(index); |
| if (!name.empty()) { |
| os.indent(indent) << "state->" << name << " = op" << depth |
| << "->getAttrOfType<" << namedAttr->attr.getStorageType() |
| << ">(\"" << namedAttr->getName() << "\");\n"; |
| } |
| } |
| |
| void PatternEmitter::emitMatchMethod(DagNode tree) { |
| // Emit the heading. |
| os << R"( |
| PatternMatchResult match(Instruction *op0) const override { |
| // TODO: This just handle 1 result |
| if (op0->getNumResults() != 1) return matchFailure(); |
| auto ctx = op0->getContext(); (void)ctx; |
| auto state = std::make_unique<MatchedState>();)" |
| << "\n"; |
| emitOpMatch(tree, 0); |
| os.indent(4) << "return matchSuccess(std::move(state));\n }\n"; |
| } |
| |
| void PatternEmitter::emit(StringRef rewriteName) { |
| // Get the DAG tree for the source pattern |
| DagNode tree = pattern.getSourcePattern(); |
| |
| // TODO(jpienaar): the benefit metric is simply number of ops matched at the |
| // moment, revise. |
| unsigned benefit = tree.getNumOps(); |
| |
| const Operator &rootOp = pattern.getSourceRootOp(); |
| auto rootName = rootOp.getOperationName(); |
| |
| // Emit RewritePattern for Pattern. |
| os << formatv(R"(struct {0} : public RewritePattern { |
| {0}(MLIRContext *context) : RewritePattern("{1}", {2}, context) {{})", |
| rewriteName, rootName, benefit) |
| << "\n"; |
| |
| // Emit matched state. |
| os << " struct MatchedState : public PatternState {\n"; |
| for (const auto &arg : pattern.getSourcePatternBoundArgs()) { |
| auto fieldName = arg.first(); |
| if (auto namedAttr = arg.second.dyn_cast<NamedAttribute *>()) { |
| os.indent(4) << namedAttr->attr.getStorageType() << " " << fieldName |
| << ";\n"; |
| } else { |
| os.indent(4) << "Value* " << fieldName << ";\n"; |
| } |
| } |
| os << " };\n"; |
| |
| emitMatchMethod(tree); |
| emitRewriteMethod(); |
| |
| os << "};\n"; |
| } |
| |
| void PatternEmitter::emitRewriteMethod() { |
| if (pattern.getNumResults() != 1) |
| PrintFatalError("only single result rules supported"); |
| |
| DagNode resultTree = pattern.getResultPattern(0); |
| |
| // TODO(jpienaar): Expand to multiple results. |
| for (unsigned i = 0, e = resultTree.getNumArgs(); i != e; ++i) |
| if (resultTree.getArgAsNestedDag(i)) |
| PrintFatalError(loc, "only single op result supported"); |
| |
| os << R"( |
| void rewrite(Instruction *op, std::unique_ptr<PatternState> state, |
| PatternRewriter &rewriter) const override { |
| auto& s = *static_cast<MatchedState *>(state.get()); |
| )"; |
| |
| if (resultTree.isNativeCodeBuilder()) |
| emitReplaceWithNativeBuilder(resultTree); |
| else if (resultTree.isReplaceWithValue()) |
| emitReplaceWithExistingValue(resultTree); |
| else |
| emitReplaceOpWithNewOp(resultTree); |
| |
| os << " }\n"; |
| } |
| |
| void PatternEmitter::emitReplaceOpWithNewOp(DagNode resultTree) { |
| Operator &resultOp = resultTree.getDialectOp(opMap); |
| auto numOpArgs = |
| resultOp.getNumOperands() + resultOp.getNumNativeAttributes(); |
| |
| os << formatv(R"( |
| rewriter.replaceOpWithNewOp<{0}>(op, op->getResult(0)->getType())", |
| resultOp.getCppClassName()); |
| if (numOpArgs != resultTree.getNumArgs()) { |
| PrintFatalError(loc, formatv("resultant op '{0}' argument number mismatch: " |
| "{1} in pattern vs. {2} in definition", |
| resultOp.getOperationName(), |
| resultTree.getNumArgs(), numOpArgs)); |
| } |
| |
| // Create the builder call for the result. |
| // Add operands. |
| int i = 0; |
| for (auto operand : resultOp.getOperands()) { |
| // Start each operand on its own line. |
| (os << ",\n").indent(6); |
| |
| auto name = resultTree.getArgName(i); |
| pattern.ensureArgBoundInSourcePattern(name); |
| if (!operand.name.empty()) |
| os << "/*" << operand.name << "=*/"; |
| os << "s." << name; |
| // TODO(jpienaar): verify types |
| ++i; |
| } |
| |
| // Add attributes. |
| for (int e = resultTree.getNumArgs(); i != e; ++i) { |
| // Start each attribute on its own line. |
| (os << ",\n").indent(6); |
| |
| auto leaf = resultTree.getArgAsLeaf(i); |
| // The argument in the result DAG pattern. |
| auto patArgName = resultTree.getArgName(i); |
| // The argument in the op definition. |
| auto opArgName = resultOp.getArgName(i); |
| |
| if (leaf.isUnspecified() || leaf.isOperandMatcher()) { |
| pattern.ensureArgBoundInSourcePattern(patArgName); |
| os << formatv("/*{0}=*/s.{1}", opArgName, patArgName); |
| } else if (leaf.isAttrTransformer()) { |
| pattern.ensureArgBoundInSourcePattern(patArgName); |
| std::string result = std::string("s.") + patArgName.str(); |
| result = formatv(leaf.getTransformationTemplate().c_str(), result); |
| os << formatv("/*{0}=*/{1}", opArgName, result); |
| } else if (leaf.isConstantAttr()) { |
| // TODO(jpienaar): Refactor out into map to avoid recomputing these. |
| auto argument = resultOp.getArg(i); |
| if (!argument.is<NamedAttribute *>()) |
| PrintFatalError(loc, Twine("expected attribute ") + Twine(i)); |
| |
| if (!patArgName.empty()) |
| os << "/*" << patArgName << "=*/"; |
| emitConstantAttr(leaf.getAsConstantAttr()); |
| // TODO(jpienaar): verify types |
| } else { |
| PrintFatalError(loc, "unhandled case when rewriting op"); |
| } |
| } |
| os << "\n );\n"; |
| } |
| |
| void PatternEmitter::emitReplaceWithExistingValue(DagNode resultTree) { |
| if (resultTree.getNumArgs() != 1) { |
| PrintFatalError(loc, "exactly one argument needed in the result pattern"); |
| } |
| |
| auto name = resultTree.getArgName(0); |
| pattern.ensureArgBoundInSourcePattern(name); |
| os.indent(4) << "rewriter.replaceOp(op, {s." << name << "});\n"; |
| } |
| |
| void PatternEmitter::emitReplaceWithNativeBuilder(DagNode resultTree) { |
| os.indent(4) << resultTree.getNativeCodeBuilder() << "(op, {"; |
| const auto &boundedValues = pattern.getSourcePatternBoundArgs(); |
| bool first = true; |
| bool printingAttr = false; |
| for (int i = 0, e = resultTree.getNumArgs(); i != e; ++i) { |
| auto name = resultTree.getArgName(i); |
| pattern.ensureArgBoundInSourcePattern(name); |
| const auto &val = boundedValues.find(name); |
| if (val->second.dyn_cast<NamedAttribute *>() && !printingAttr) { |
| os << "}, {"; |
| first = true; |
| printingAttr = true; |
| } |
| if (!first) |
| os << ","; |
| os << "s." << name; |
| first = false; |
| } |
| if (!printingAttr) |
| os << "},{"; |
| os << "}, rewriter);\n"; |
| } |
| |
| void PatternEmitter::emit(StringRef rewriteName, Record *p, |
| RecordOperatorMap *mapper, raw_ostream &os) { |
| PatternEmitter(p, mapper, os).emit(rewriteName); |
| } |
| |
| static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) { |
| emitSourceFileHeader("Rewriters", os); |
| const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern"); |
| |
| // We put the map here because it can be shared among multiple patterns. |
| RecordOperatorMap recordOpMap; |
| |
| // Ensure unique patterns simply by appending unique suffix. |
| std::string baseRewriteName = "GeneratedConvert"; |
| int rewritePatternCount = 0; |
| for (Record *p : patterns) { |
| PatternEmitter::emit(baseRewriteName + llvm::utostr(rewritePatternCount++), |
| p, &recordOpMap, os); |
| } |
| |
| // Emit function to add the generated matchers to the pattern list. |
| os << "void populateWithGenerated(MLIRContext *context, " |
| << "OwningRewritePatternList *patterns) {\n"; |
| for (unsigned i = 0; i != rewritePatternCount; ++i) { |
| os.indent(2) << "patterns->push_back(std::make_unique<" << baseRewriteName |
| << i << ">(context));\n"; |
| } |
| os << "}\n"; |
| } |
| |
| static mlir::GenRegistration |
| genRewriters("gen-rewriters", "Generate pattern rewriters", |
| [](const RecordKeeper &records, raw_ostream &os) { |
| emitRewriters(records, os); |
| return false; |
| }); |