blob: 1e6eecb3ed7bcc62295762526f7827909d693db2 [file] [log] [blame]
//===- RewriterGen.cpp - MLIR pattern rewriter generator ------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// RewriterGen uses pattern rewrite definitions to generate rewriter matchers.
//
//===----------------------------------------------------------------------===//
#include "mlir/TableGen/Attribute.h"
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Operator.h"
#include "mlir/TableGen/Pattern.h"
#include "mlir/TableGen/Predicate.h"
#include "mlir/TableGen/Type.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/PrettyStackTrace.h"
#include "llvm/Support/Signals.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Main.h"
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"
using namespace llvm;
using namespace mlir;
using mlir::tblgen::DagNode;
using mlir::tblgen::NamedAttribute;
using mlir::tblgen::Operand;
using mlir::tblgen::Operator;
using mlir::tblgen::RecordOperatorMap;
namespace {
/// Generates one mlir::RewritePattern subclass (match() + rewrite()) from a
/// single TableGen `Pattern` record.
class PatternEmitter {
public:
  /// Emits the RewritePattern struct named `rewriteName` for pattern record
  /// `p` to `os`, resolving op records through `mapper`.
  static void emit(StringRef rewriteName, Record *p, RecordOperatorMap *mapper,
                   raw_ostream &os);

private:
  PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os)
      : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper), os(os) {}

  /// Emits the whole mlir::RewritePattern struct named `rewriteName`.
  void emit(StringRef rewriteName);

  //===--------------------------------------------------------------------===//
  // Match side
  //===--------------------------------------------------------------------===//

  /// Emits the match() method for the source DAG `tree`.
  void emitMatchMethod(DagNode tree);
  /// Emits C++ statements matching the op constrained by DAG `tree`.
  void emitOpMatch(DagNode tree, int depth);
  /// Emits C++ statements matching the `index`-th argument of `tree` as an
  /// operand.
  void emitOperandMatch(DagNode tree, int index, int depth, int indent);
  /// Emits C++ statements matching the `index`-th argument of `tree` as an
  /// attribute.
  void emitAttributeMatch(DagNode tree, int index, int depth, int indent);

  //===--------------------------------------------------------------------===//
  // Rewrite side
  //===--------------------------------------------------------------------===//

  /// Emits the rewrite() method.
  void emitRewriteMethod();
  /// Emits the statement replacing the matched DAG with a newly built op.
  void emitReplaceOpWithNewOp(DagNode resultTree);
  /// Emits the statement replacing the matched DAG with an existing value.
  void emitReplaceWithExistingValue(DagNode resultTree);
  /// Emits the statement replacing the matched DAG with a value produced by a
  /// native C++ builder.
  void emitReplaceWithNativeBuilder(DagNode resultTree);
  /// Emits the C++ expression for a constant attribute to `os`.
  void emitConstantAttr(tblgen::ConstantAttr constAttr);

  // Data members, declared in the same order as the constructor's
  // initializer list.

  /// Pattern instantiation location followed by the locations of any
  /// multiclass prototypes used; passed as a whole to PrintFatalError().
  ArrayRef<llvm::SMLoc> loc;
  /// Maps an op's TableGen record to its wrapper object.
  RecordOperatorMap *opMap;
  /// Wrapper around the pattern currently being emitted.
  tblgen::Pattern pattern;
  /// Destination stream for the generated C++.
  raw_ostream &os;
};
} // end namespace
// Writes to `os` the C++ expression that builds the constant attribute
// `constAttr`. The attribute's constBuilderCall template is instantiated with
// `{0}` bound to "rewriter" and `{1}` bound to the constant's value. Aborts
// with a fatal TableGen error if the attribute has no constant builder.
void PatternEmitter::emitConstantAttr(tblgen::ConstantAttr constAttr) {
  auto attribute = constAttr.getAttribute();
  if (attribute.isConstBuildable()) {
    // TODO(jpienaar): Verify the constants here
    os << formatv(attribute.getConstBuilderTemplate().str().c_str(),
                  "rewriter", constAttr.getConstantValue());
    return;
  }
  PrintFatalError(loc, "Attribute " + attribute.getTableGenDefName() +
                           " does not have the 'constBuilderCall' field");
}
// Emits C++ statements that check the op described by DAG `tree`. In the
// generated code that op is named `op<depth>`. Nested DAG arguments are bound
// to `op<depth + 1>` (the defining instruction of the corresponding operand)
// inside a fresh scope and matched recursively. Leaf arguments are dispatched
// to emitOperandMatch() or emitAttributeMatch() according to the op
// definition.
void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
  Operator &op = tree.getDialectOp(opMap);
  int indent = 4 + 2 * depth;

  // Skip the op checks at depth 0: the generated RewritePattern is constructed
  // with the root op name, so the rewriter already selects it.
  if (depth != 0) {
    // Skip if there is no defining instruction (e.g., arguments to function).
    os.indent(indent) << formatv("if (!op{0}) return matchFailure();\n", depth);
    os.indent(indent) << formatv(
        "if (!op{0}->isa<{1}>()) return matchFailure();\n", depth,
        op.getQualCppClassName());
  }

  if (tree.getNumArgs() != op.getNumArgs()) {
    PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in "
                                 "pattern vs. {2} in definition",
                                 op.getOperationName(), tree.getNumArgs(),
                                 op.getNumArgs()));
  }

  for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
    auto opArg = op.getArg(i);
    // dyn_cast in boolean context: true only for a non-null operand argument.
    bool isOperand = static_cast<bool>(opArg.dyn_cast<Operand *>());

    // Handle nested DAG construct first.
    if (DagNode argTree = tree.getArgAsNestedDag(i)) {
      // Only operands have a defining instruction to match against. Without
      // this check a nested DAG in an attribute slot would generate
      // getOperand() with an attribute's argument index.
      if (!isOperand) {
        PrintFatalError(
            loc, formatv("the {1}-th argument of op '{0}' should be an operand",
                         op.getOperationName(), i + 1));
      }
      os.indent(indent) << "{\n";
      // NOTE(review): uses the argument index as the operand index, which
      // assumes operands precede attributes in the op's argument list.
      os.indent(indent + 2) << formatv(
          "auto op{0} = op{1}->getOperand({2})->getDefiningInst();\n",
          depth + 1, depth, i);
      emitOpMatch(argTree, depth + 1);
      os.indent(indent) << "}\n";
      continue;
    }

    // Next handle DAG leaf: operand or attribute.
    if (isOperand)
      emitOperandMatch(tree, i, depth, indent);
    else if (opArg.dyn_cast<NamedAttribute *>())
      emitAttributeMatch(tree, i, depth, indent);
    else
      PrintFatalError(loc, "unhandled case when matching op");
  }
}
// Emits C++ statements for the `index`-th argument of DAG `tree`, which the op
// definition declares as an operand.
// - If the pattern leaf carries a constraint, the leaf must be an operand
//   matcher (otherwise fatal). A type check on
//   `op<depth>->getOperand(index)->getType()` is emitted only when that
//   constraint differs from the op definition's own operand type.
// - If the argument is named in the pattern, the operand is stored into
//   `state-><name>`.
void PatternEmitter::emitOperandMatch(DagNode tree, int index, int depth,
                                      int indent) {
  Operator &op = tree.getDialectOp(opMap);
  auto *operand = op.getArg(index).get<Operand *>();
  auto matcher = tree.getArgAsLeaf(index);
  bool hasConstraint = !matcher.isUnspecified();

  if (hasConstraint && !matcher.isOperandMatcher()) {
    PrintFatalError(
        loc, formatv("the {1}-th argument of op '{0}' should be an operand",
                     op.getOperationName(), index + 1));
  }

  // A constraint identical to the definition's type needs no extra check.
  if (hasConstraint && static_cast<tblgen::TypeConstraint>(operand->type) !=
                           matcher.getAsTypeConstraint()) {
    std::string operandType =
        formatv("op{0}->getOperand({1})->getType()", depth, index);
    os.indent(indent) << "if (!("
                      << formatv(matcher.getConditionTemplate().c_str(),
                                 operandType)
                      << ")) return matchFailure();\n";
  }

  // Capture the value.
  auto boundName = tree.getArgName(index);
  if (boundName.empty())
    return;
  os.indent(indent) << formatv("state->{0} = op{1}->getOperand({2});\n",
                               boundName, depth, index);
}
// Emits C++ statements for the `index`-th argument of DAG `tree`, which the op
// definition declares as an attribute. If the pattern leaf carries a
// constraint, it must be an attribute matcher (otherwise fatal), and its
// condition is emitted against `op<depth>->getAttrOfType<StorageType>("name")`.
// If the argument is named in the pattern, that attribute is stored into
// `state-><name>`.
void PatternEmitter::emitAttributeMatch(DagNode tree, int index, int depth,
                                        int indent) {
  Operator &op = tree.getDialectOp(opMap);
  auto *namedAttr = op.getArg(index).get<NamedAttribute *>();
  auto matcher = tree.getArgAsLeaf(index);

  // Generated-code expression fetching this attribute from op<depth>. It is
  // shared by the constraint check and the capture below.
  std::string attrExpr =
      formatv("op{0}->getAttrOfType<{1}>(\"{2}\")", depth,
              namedAttr->attr.getStorageType(), namedAttr->getName());

  // Only a specified leaf carries a constraint. This mirrors
  // emitOperandMatch(): an unspecified leaf has no condition template to
  // instantiate.
  if (!matcher.isUnspecified()) {
    if (!matcher.isAttrMatcher()) {
      PrintFatalError(
          loc, formatv("the {1}-th argument of op '{0}' should be an attribute",
                       op.getOperationName(), index + 1));
    }
    std::string condition =
        formatv(matcher.getConditionTemplate().c_str(), attrExpr);
    os.indent(indent) << "if (!(" << condition << ")) return matchFailure();\n";
  }

  // Capture the value.
  auto name = tree.getArgName(index);
  if (!name.empty())
    os.indent(indent) << "state->" << name << " = " << attrExpr << ";\n";
}
// Emits the generated `match(Instruction *op0)` override.
// - The prologue rejects ops whose result count is not 1 and allocates a
//   MatchedState.
// - emitOpMatch() then emits the per-op checks for the source DAG `tree`.
// - The method ends by returning matchSuccess with the populated state.
void PatternEmitter::emitMatchMethod(DagNode tree) {
  // Method header plus prologue; the literal text is emitted verbatim.
  static const char *const kMatchPrologue = R"(
PatternMatchResult match(Instruction *op0) const override {
// TODO: This just handle 1 result
if (op0->getNumResults() != 1) return matchFailure();
auto ctx = op0->getContext(); (void)ctx;
auto state = std::make_unique<MatchedState>();)";

  os << kMatchPrologue << "\n";
  emitOpMatch(tree, /*depth=*/0);
  os.indent(4) << "return matchSuccess(std::move(state));\n }\n";
}
// Emits the complete `struct <rewriteName> : public RewritePattern` for the
// current pattern, in this order:
// 1. The constructor, keyed on the root op name, with a benefit equal to the
//    number of ops in the source DAG.
// 2. A nested MatchedState holding one field per bound source argument:
//    attribute storage type for attributes, `Value*` otherwise.
// 3. The match() and rewrite() methods.
void PatternEmitter::emit(StringRef rewriteName) {
  DagNode sourceTree = pattern.getSourcePattern();
  // TODO(jpienaar): the benefit metric is simply number of ops matched at the
  // moment, revise.
  unsigned benefit = sourceTree.getNumOps();
  const Operator &rootOp = pattern.getSourceRootOp();
  auto rootName = rootOp.getOperationName();

  // Struct header and constructor.
  os << formatv(R"(struct {0} : public RewritePattern {
{0}(MLIRContext *context) : RewritePattern("{1}", {2}, context) {{})",
                rewriteName, rootName, benefit)
     << "\n";

  // State captured during match() and consumed by rewrite().
  os << " struct MatchedState : public PatternState {\n";
  for (const auto &boundArg : pattern.getSourcePatternBoundArgs()) {
    os.indent(4);
    if (auto *attr = boundArg.second.dyn_cast<NamedAttribute *>())
      os << attr->attr.getStorageType() << " ";
    else
      os << "Value* ";
    os << boundArg.first() << ";\n";
  }
  os << " };\n";

  emitMatchMethod(sourceTree);
  emitRewriteMethod();
  os << "};\n";
}
// Emits the generated `rewrite(Instruction *op, ...)` override. The method
// casts the incoming PatternState to MatchedState (bound as `s`) and then
// emits exactly one replacement statement, chosen by the result DAG's kind:
// native C++ builder, existing value, or a newly built op.
// Only patterns with a single, non-nested result DAG are supported. Anything
// else is a fatal TableGen error reported at the pattern's location.
void PatternEmitter::emitRewriteMethod() {
  // Report at `loc`, like every other diagnostic in this emitter, so the user
  // can find the offending pattern.
  if (pattern.getNumResults() != 1)
    PrintFatalError(loc, "only single result rules supported");

  DagNode resultTree = pattern.getResultPattern(0);

  // TODO(jpienaar): Expand to multiple results.
  for (unsigned i = 0, e = resultTree.getNumArgs(); i != e; ++i)
    if (resultTree.getArgAsNestedDag(i))
      PrintFatalError(loc, "only single op result supported");

  os << R"(
void rewrite(Instruction *op, std::unique_ptr<PatternState> state,
PatternRewriter &rewriter) const override {
auto& s = *static_cast<MatchedState *>(state.get());
)";

  if (resultTree.isNativeCodeBuilder())
    emitReplaceWithNativeBuilder(resultTree);
  else if (resultTree.isReplaceWithValue())
    emitReplaceWithExistingValue(resultTree);
  else
    emitReplaceOpWithNewOp(resultTree);

  os << " }\n";
}
// Emits `rewriter.replaceOpWithNewOp<ResultOp>(op, <result type>, args...)`.
// The result type is `op->getResult(0)->getType()`. Each argument goes on its
// own line, annotated with the op definition's parameter name:
// - Operands are taken from the matched state (`s.<name>`).
// - Attributes are one of:
//   - a bound value (`s.<name>`);
//   - a transformed bound value, via the leaf's transformation template;
//   - a constant built by emitConstantAttr().
// Any unbound name or arity mismatch is a fatal TableGen error.
void PatternEmitter::emitReplaceOpWithNewOp(DagNode resultTree) {
  Operator &resultOp = resultTree.getDialectOp(opMap);
  auto numOpArgs =
      resultOp.getNumOperands() + resultOp.getNumNativeAttributes();

  // Validate arity before emitting anything, so no half-written builder call
  // precedes the diagnostic.
  if (numOpArgs != resultTree.getNumArgs()) {
    PrintFatalError(loc, formatv("resultant op '{0}' argument number mismatch: "
                                 "{1} in pattern vs. {2} in definition",
                                 resultOp.getOperationName(),
                                 resultTree.getNumArgs(), numOpArgs));
  }

  os << formatv(R"(
rewriter.replaceOpWithNewOp<{0}>(op, op->getResult(0)->getType())",
                resultOp.getCppClassName());

  // Create the builder call for the result.
  // Add operands. They occupy the leading argument positions of the result DAG.
  int i = 0;
  for (const auto &operand : resultOp.getOperands()) {
    // Start each operand on its own line.
    (os << ",\n").indent(6);
    auto name = resultTree.getArgName(i);
    pattern.ensureArgBoundInSourcePattern(name);
    if (!operand.name.empty())
      os << "/*" << operand.name << "=*/";
    os << "s." << name;
    // TODO(jpienaar): verify types
    ++i;
  }

  // Add attributes.
  for (int e = resultTree.getNumArgs(); i != e; ++i) {
    // Start each attribute on its own line.
    (os << ",\n").indent(6);
    auto leaf = resultTree.getArgAsLeaf(i);
    // The argument in the result DAG pattern.
    auto patArgName = resultTree.getArgName(i);
    // The argument in the op definition.
    auto opArgName = resultOp.getArgName(i);
    if (leaf.isUnspecified() || leaf.isOperandMatcher()) {
      pattern.ensureArgBoundInSourcePattern(patArgName);
      os << formatv("/*{0}=*/s.{1}", opArgName, patArgName);
    } else if (leaf.isAttrTransformer()) {
      pattern.ensureArgBoundInSourcePattern(patArgName);
      std::string result = std::string("s.") + patArgName.str();
      result = formatv(leaf.getTransformationTemplate().c_str(), result);
      os << formatv("/*{0}=*/{1}", opArgName, result);
    } else if (leaf.isConstantAttr()) {
      // TODO(jpienaar): Refactor out into map to avoid recomputing these.
      auto argument = resultOp.getArg(i);
      if (!argument.is<NamedAttribute *>())
        PrintFatalError(loc, Twine("expected attribute ") + Twine(i));
      // Annotate with the op definition's parameter name, like the other
      // branches. A pattern-side name is meaningless in the builder call.
      if (!opArgName.empty())
        os << "/*" << opArgName << "=*/";
      emitConstantAttr(leaf.getAsConstantAttr());
      // TODO(jpienaar): verify types
    } else {
      PrintFatalError(loc, "unhandled case when rewriting op");
    }
  }
  os << "\n );\n";
}
// Emits `rewriter.replaceOp(op, {s.<name>});`, which replaces the matched op
// with a value already bound in the source pattern. The result DAG must have
// exactly one argument, and that argument must be bound in the source pattern.
void PatternEmitter::emitReplaceWithExistingValue(DagNode resultTree) {
  if (resultTree.getNumArgs() != 1)
    PrintFatalError(loc, "exactly one argument needed in the result pattern");

  auto replacementName = resultTree.getArgName(0);
  pattern.ensureArgBoundInSourcePattern(replacementName);
  // "{{" is formatv's escape for a literal '{'.
  os.indent(4) << formatv("rewriter.replaceOp(op, {{s.{0}});\n",
                          replacementName);
}
// Emits a call to the result DAG's native C++ builder:
//   <builder>(op, {<values>}, {<attributes>}, rewriter);
// Every argument must be bound in the source pattern. Arguments are emitted
// as `s.<name>` in order. The first attribute argument closes the value list
// and opens the attribute list, and every later argument goes into the
// attribute list. With no attribute arguments the second list is emitted
// empty.
void PatternEmitter::emitReplaceWithNativeBuilder(DagNode resultTree) {
  os.indent(4) << resultTree.getNativeCodeBuilder() << "(op, {";
  const auto &boundArgs = pattern.getSourcePatternBoundArgs();

  // Text to print before the next element; empty at the start of each list.
  StringRef separator = "";
  bool inAttrList = false;
  for (int idx = 0, numArgs = resultTree.getNumArgs(); idx < numArgs; ++idx) {
    auto argName = resultTree.getArgName(idx);
    pattern.ensureArgBoundInSourcePattern(argName);
    auto entry = boundArgs.find(argName);
    bool isAttr = entry->second.dyn_cast<NamedAttribute *>() != nullptr;
    if (isAttr && !inAttrList) {
      os << "}, {";
      separator = "";
      inAttrList = true;
    }
    os << separator << "s." << argName;
    separator = ",";
  }

  // Close the open list. If no attribute was seen, add an empty attribute list.
  os << (inAttrList ? "}" : "},{}") << ", rewriter);\n";
}
// Public entry point: builds a PatternEmitter for pattern record `p` (with
// op records resolved through `mapper`) and writes the RewritePattern struct
// named `rewriteName` to `os`.
void PatternEmitter::emit(StringRef rewriteName, Record *p,
                          RecordOperatorMap *mapper, raw_ostream &os) {
  PatternEmitter emitter(p, mapper, os);
  emitter.emit(rewriteName);
}
// Emits the generated rewriters file:
// - the standard source-file header;
// - one RewritePattern struct per `Pattern` record, named GeneratedConvert0,
//   GeneratedConvert1, ... in record order;
// - `populateWithGenerated()`, which pushes an instance of each struct onto an
//   OwningRewritePatternList.
static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
  emitSourceFileHeader("Rewriters", os);
  const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern");

  // We put the map here because it can be shared among multiple patterns.
  RecordOperatorMap recordOpMap;

  // Ensure unique patterns simply by appending unique suffix.
  std::string baseRewriteName = "GeneratedConvert";
  int rewritePatternCount = 0;
  for (Record *p : patterns) {
    PatternEmitter::emit(baseRewriteName + llvm::utostr(rewritePatternCount++),
                         p, &recordOpMap, os);
  }

  // Emit function to add the generated matchers to the pattern list.
  os << "void populateWithGenerated(MLIRContext *context, "
     << "OwningRewritePatternList *patterns) {\n";
  // Same (signed) type as the counter, avoiding a signed/unsigned comparison.
  for (int i = 0; i != rewritePatternCount; ++i) {
    os.indent(2) << "patterns->push_back(std::make_unique<" << baseRewriteName
                 << i << ">(context));\n";
  }
  os << "}\n";
}
// Registers the rewriter generator with mlir-tblgen under the name
// "gen-rewriters". When selected, it runs emitRewriters() over the parsed
// records.
static mlir::GenRegistration
    genRewriters("gen-rewriters", "Generate pattern rewriters",
                 [](const RecordKeeper &recordKeeper, raw_ostream &output) {
                   emitRewriters(recordKeeper, output);
                   // Presumably `false` reports "no error" to the TableGen
                   // driver (TableGen backend convention) - confirm.
                   return false;
                 });