blob: b00be1c8c95c9bda93b304c8eafc8e2ed7ad23a9 [file] [log] [blame]
//===- RewriterGen.cpp - MLIR pattern rewriter generator ------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// RewriterGen uses pattern rewrite definitions to generate rewriter matchers.
//
//===----------------------------------------------------------------------===//
#include "mlir/Support/STLExtras.h"
#include "mlir/TableGen/Attribute.h"
#include "mlir/TableGen/Format.h"
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Operator.h"
#include "mlir/TableGen/Pattern.h"
#include "mlir/TableGen/Predicate.h"
#include "mlir/TableGen/Type.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatAdapters.h"
#include "llvm/Support/PrettyStackTrace.h"
#include "llvm/Support/Signals.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Main.h"
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"
using namespace llvm;
using namespace mlir;
using namespace mlir::tblgen;
namespace llvm {
// Teaches formatv() how to print a pattern's identifier/line pair: the two
// halves of the pair are written separated by a colon. The `style` argument is
// accepted for interface compatibility but is not consulted.
template <> struct format_provider<mlir::tblgen::Pattern::IdentifierLine> {
  static void format(const mlir::tblgen::Pattern::IdentifierLine &v,
                     raw_ostream &os, StringRef style) {
    os << v.first;
    os << ":" << v.second;
  }
};
} // end namespace llvm
// Maps `symbol`, the name of an op argument or op bound in the source pattern,
// to the C++ expression that reads it back in generated code.
//
// Everything bound in the source pattern is stored as a field of a transient
// `PatternState` struct, which both `match()` and `rewrite()` expose through a
// local variable named `s`; the result is therefore `s.<symbol>`.
//
// NOTE(review): the returned Twine presumably refers back to `symbol` rather
// than copying it, so callers should consume it within the same expression --
// confirm against the Twine documentation.
static Twine getBoundSymbol(const StringRef &symbol) {
  return "s." + symbol;
}
//===----------------------------------------------------------------------===//
// PatternSymbolResolver
//===----------------------------------------------------------------------===//
namespace {
// A class for resolving symbols bound in patterns.
//
// Symbols can be bound to op arguments and ops in the source pattern and ops
// in result patterns. For example, in
//
// ```
// def : Pattern<(SrcOp:$op1 $arg0, $arg1),
// [(ResOp1:$op2), (ResOp2 $op2 (ResOp3))]>;
// ```
//
// `$argN` is bound to the `SrcOp`'s N-th argument. `$op1` is bound to `SrcOp`.
// `$op2` is bound to `ResOp1`.
//
// This class keeps track of such symbols and translates them into their bound
// values.
//
// Note that we also generate local variables for unnamed DAG nodes, like
// `(ResOp3)` in the above. Since we don't bind a symbol to the op, the
// generated local variable will be implicitly named. Those implicit names are
// not tracked in this class.
class PatternSymbolResolver {
public:
// Creates a resolver over the symbols bound in the source pattern. Both
// containers are held by reference (see the members below), so they must
// outlive this resolver.
PatternSymbolResolver(const StringMap<Argument> &srcArgs,
const StringSet<> &srcOperations);
// Marks the given `symbol` as bound to an op in a result pattern. Returns
// false if the `symbol` was already added this way.
bool add(StringRef symbol);
// Queries the substitution for the given `symbol`. Result-pattern symbols
// are looked up first and resolve to the symbol name itself; source-pattern
// arguments and ops resolve to the `s.<symbol>` form. Returns an empty string
// if the symbol is not bound anywhere.
std::string query(StringRef symbol) const;
private:
// Symbols bound to arguments in source pattern.
const StringMap<Argument> &sourceArguments;
// Symbols bound to ops (for their results) in source pattern.
const StringSet<> &sourceOps;
// Symbols bound to ops (for their results) in result patterns.
StringSet<> resultOps;
};
} // end anonymous namespace
// Binds the resolver to the source pattern's bound arguments and bound ops.
// Only references are stored; the set of result-pattern symbols is
// default-constructed and filled in later through add().
PatternSymbolResolver::PatternSymbolResolver(const StringMap<Argument> &srcArgs,
const StringSet<> &srcOperations)
: sourceArguments(srcArgs), sourceOps(srcOperations) {}
// Records `symbol` as bound to an op in a result pattern. Returns true when
// the symbol is newly recorded and false when it had been recorded before.
bool PatternSymbolResolver::add(StringRef symbol) {
  auto insertion = resultOps.insert(symbol);
  return insertion.second;
}
// Translates `symbol` into the C++ expression it stands for in generated
// code. Returns an empty string when the symbol is not bound anywhere.
std::string PatternSymbolResolver::query(StringRef symbol) const {
  // Symbols bound to ops in result patterns take precedence and are
  // referenced by their own name.
  auto resultIt = resultOps.find(symbol);
  if (resultIt != resultOps.end())
    return resultIt->getKey();

  // Symbols bound in the source pattern (arguments first, then ops) are read
  // back through the bound-symbol expression.
  const bool boundInSource =
      sourceArguments.find(symbol) != sourceArguments.end() ||
      sourceOps.find(symbol) != sourceOps.end();
  if (boundInSource)
    return getBoundSymbol(symbol).str();

  return {};
}
//===----------------------------------------------------------------------===//
// PatternEmitter
//===----------------------------------------------------------------------===//
namespace {
// Generates the C++ `RewritePattern` struct (matched-state struct, match() and
// rewrite() methods) for a single TableGen pattern record and writes it to an
// output stream. Instances are only created through the static emit() entry
// point.
class PatternEmitter {
public:
// Emits the rewrite pattern struct named `rewriteName` for the pattern record
// `p` to `os`. `mapper` caches the Operator wrappers built for op records.
static void emit(StringRef rewriteName, Record *p, RecordOperatorMap *mapper,
raw_ostream &os);
private:
PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os);
// Emits the mlir::RewritePattern struct named `rewriteName`.
void emit(StringRef rewriteName);
// Emits the match() method.
void emitMatchMethod(DagNode tree);
// Emits the rewrite() method.
void emitRewriteMethod();
// Emits C++ statements for matching the op constrained by the given DAG
// `tree`. `depth` is the nesting level of `tree` in the source pattern; the
// op being matched is referred to as `op<depth>` in the generated code.
void emitOpMatch(DagNode tree, int depth);
// Emits C++ statements for matching the `index`-th argument of the given DAG
// `tree` as an operand.
void emitOperandMatch(DagNode tree, int index, int depth, int indent);
// Emits C++ statements for matching the `index`-th argument of the given DAG
// `tree` as an attribute.
void emitAttributeMatch(DagNode tree, int index, int depth, int indent);
// Returns a unique name for a value of the given `op`.
std::string getUniqueValueName(const Operator *op);
// Entry point for handling a rewrite pattern rooted at `resultTree` and
// dispatches to concrete handlers. The given tree is the `resultIndex`-th
// argument of the enclosing DAG.
std::string handleRewritePattern(DagNode resultTree, int resultIndex,
int depth);
// Emits the C++ statement to replace the matched DAG with a value built via
// calling native C++ code.
std::string emitReplaceWithNativeCodeCall(DagNode resultTree);
// Returns the C++ expression referencing the old value serving as the
// replacement.
std::string handleReplaceWithValue(DagNode tree);
// Handles the `verifyUnusedValue` directive: emitting C++ statements to check
// the `index`-th result of the source op is not used.
void handleVerifyUnusedValue(DagNode tree, int index);
// Emits the C++ statement to build a new op out of the given DAG `tree` and
// returns the variable name that this op is assigned to. If the root op in
// DAG `tree` has a specified name, the created op will be assigned to a
// variable of the given name. Otherwise, a unique name will be used as the
// result value name.
std::string emitOpCreate(DagNode tree, int resultIndex, int depth);
// Returns the C++ expression to construct a constant attribute of the given
// `value` for the given attribute kind `attr`.
std::string handleConstantAttr(Attribute attr, StringRef value);
// Returns the C++ expression to build an argument from the given DAG `leaf`.
// `patArgName` is the name used to bind the argument in the source pattern.
std::string handleOpArgument(DagLeaf leaf, llvm::StringRef patArgName);
// Marks the symbol attached to DagNode `node` as bound. Aborts if the symbol
// is already bound.
void addSymbol(DagNode node);
// Gets the substitution for `symbol`. Aborts if `symbol` is not bound.
std::string resolveSymbol(StringRef symbol);
private:
// Pattern instantiation location followed by the location of multiclass
// prototypes used. This is intended to be used as a whole to
// PrintFatalError() on errors.
ArrayRef<llvm::SMLoc> loc;
// Op's TableGen Record to wrapper object
RecordOperatorMap *opMap;
// Handy wrapper for pattern being emitted
Pattern pattern;
// Tracks bound symbols. Declared after `pattern` because the constructor
// initializes it from `pattern`'s bound arguments and ops.
PatternSymbolResolver symbolResolver;
// The next unused ID for newly created values
unsigned nextValueId;
// Stream receiving the generated C++ code.
raw_ostream &os;
// Format contexts containing placeholder substitutations for match().
FmtContext matchCtx;
// Format contexts containing placeholder substitutations for rewrite().
FmtContext rewriteCtx;
};
} // end anonymous namespace
// Wraps the pattern record `pat` and prepares the format contexts used when
// generating code. `symbolResolver` is initialized from `pattern`, so this
// relies on `pattern` being declared (and hence constructed) before it.
PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper,
raw_ostream &os)
: loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper),
symbolResolver(pattern.getSourcePatternBoundArgs(),
pattern.getSourcePatternBoundOps()),
nextValueId(0), os(os) {
// The builder placeholder expands to a freshly constructed builder in
// match() and to the `rewriter` variable in rewrite().
matchCtx.withBuilder("mlir::Builder(ctx)");
rewriteCtx.withBuilder("rewriter");
}
// Produces the C++ expression that builds a constant attribute of kind `attr`
// holding `value`, by expanding the attribute's constant-builder template in
// the rewrite() context. Aborts if the attribute cannot be built this way.
std::string PatternEmitter::handleConstantAttr(Attribute attr,
                                               StringRef value) {
  if (attr.isConstBuildable()) {
    // TODO(jpienaar): Verify the constants here
    FmtContext &builderCtx = rewriteCtx.withBuilder("rewriter");
    return tgfmt(attr.getConstBuilderTemplate(), &builderCtx, value);
  }
  PrintFatalError(loc, "Attribute " + attr.getAttrDefName() +
                           " does not have the 'constBuilderCall' field");
}
// Emits the C++ statements that match the op described by the source-pattern
// DAG node `tree`. `depth` is the nesting level of `tree`; the generated code
// refers to the op under inspection as `op<depth>`, and nested DAG arguments
// are matched recursively inside their own C++ scope.
void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
  Operator &op = tree.getDialectOp(opMap);
  const int indent = 4 + 2 * depth;

  // The op kind at depth 0 is already checked by the pattern rewriter, so
  // only nested ops need the explicit checks.
  if (depth != 0) {
    // A null op means the value has no defining operation (e.g., it is an
    // argument to a function).
    os.indent(indent) << formatv("if (!op{0}) return matchFailure();\n", depth);
    os.indent(indent) << formatv(
        "if (!isa<{1}>(op{0})) return matchFailure();\n", depth,
        op.getQualCppClassName());
  }

  if (tree.getNumArgs() != op.getNumArgs()) {
    PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in "
                                 "pattern vs. {2} in definition",
                                 op.getOperationName(), tree.getNumArgs(),
                                 op.getNumArgs()));
  }

  // Capture the op itself if the pattern binds a symbol to it.
  auto boundName = tree.getOpName();
  if (!boundName.empty())
    os.indent(indent) << formatv("{0} = op{1};\n", getBoundSymbol(boundName),
                                 depth);

  for (int argIdx = 0, numArgs = tree.getNumArgs(); argIdx != numArgs;
       ++argIdx) {
    auto opArg = op.getArg(argIdx);

    // Nested DAG constructs are handled first, by recursing on the op that
    // defines the corresponding operand.
    DagNode nestedTree = tree.getArgAsNestedDag(argIdx);
    if (nestedTree) {
      os.indent(indent) << "{\n";
      os.indent(indent + 2)
          << formatv("auto op{0} = op{1}->getOperand({2})->getDefiningOp();\n",
                     depth + 1, depth, argIdx);
      emitOpMatch(nestedTree, depth + 1);
      os.indent(indent) << "}\n";
      continue;
    }

    // Otherwise this is a DAG leaf: either an operand or an attribute.
    if (opArg.is<NamedTypeConstraint *>())
      emitOperandMatch(tree, argIdx, depth, indent);
    else if (opArg.is<NamedAttribute *>())
      emitAttributeMatch(tree, argIdx, depth, indent);
    else
      PrintFatalError(loc, "unhandled case when matching op");
  }
}
// Emits the C++ statements that match the `index`-th argument of `tree` as an
// operand of `op<depth>`: an optional type-constraint check followed by an
// optional capture of the operand into the matched state. `indent` is the
// number of spaces to put in front of each generated statement.
void PatternEmitter::emitOperandMatch(DagNode tree, int index, int depth,
                                      int indent) {
  Operator &op = tree.getDialectOp(opMap);
  auto *operand = op.getArg(index).get<NamedTypeConstraint *>();
  auto matcher = tree.getArgAsLeaf(index);

  const bool hasMatcher = !matcher.isUnspecified();
  if (hasMatcher && !matcher.isOperandMatcher()) {
    PrintFatalError(
        loc, formatv("the {1}-th argument of op '{0}' should be an operand",
                     op.getOperationName(), index + 1));
  }

  // A constraint identical to the one in the op definition needs no extra
  // check; only a differing one has to be verified in the generated code.
  if (hasMatcher && operand->constraint != matcher.getAsConstraint()) {
    auto typeExpr =
        formatv("op{0}->getOperand({1})->getType()", depth, index);
    os.indent(indent) << "if (!("
                      << tgfmt(matcher.getConditionTemplate(),
                               &matchCtx.withSelf(typeExpr))
                      << ")) return matchFailure();\n";
  }

  // Capture the operand when the pattern binds a name to it.
  auto boundName = tree.getArgName(index);
  if (boundName.empty())
    return;
  os.indent(indent) << getBoundSymbol(boundName) << " = op" << depth
                    << "->getOperand(" << index << ");\n";
}
// Emits the C++ statements that match the `index`-th argument of `tree` as an
// attribute of `op<depth>`. The statements are wrapped in their own C++ scope
// so that the local `attr` variable does not clash across arguments.
void PatternEmitter::emitAttributeMatch(DagNode tree, int index, int depth,
                                        int indent) {
  Operator &op = tree.getDialectOp(opMap);
  auto *namedAttr = op.getArg(index).get<NamedAttribute *>();
  const auto &attr = namedAttr->attr;

  const int bodyIndent = indent + 2;
  os.indent(indent) << "{\n";
  os.indent(bodyIndent) << formatv(
      "auto attr = op{0}->getAttrOfType<{1}>(\"{2}\");\n", depth,
      attr.getStorageType(), namedAttr->name);

  // TODO(antiagainst): This should use getter method to avoid duplication.
  if (attr.hasDefaultValueInitializer()) {
    os.indent(bodyIndent) << "if (!attr) attr = "
                          << tgfmt(attr.getConstBuilderTemplate(), &matchCtx,
                                   attr.getDefaultValueInitializer())
                          << ";\n";
  } else if (!attr.isOptional()) {
    os.indent(bodyIndent) << "if (!attr) return matchFailure();\n";
  }
  // For a missing attribute that is optional according to the definition,
  // nothing is emitted: we just capture a mlir::Attribute() to signal the
  // missing state, which is precisely what getAttr() returns on missing
  // attributes.

  auto matcher = tree.getArgAsLeaf(index);
  if (!matcher.isUnspecified()) {
    if (!matcher.isAttrMatcher()) {
      PrintFatalError(
          loc, formatv("the {1}-th argument of op '{0}' should be an attribute",
                       op.getOperationName(), index + 1));
    }
    // A constraint was specified, so generate the statement checking it.
    os.indent(bodyIndent) << "if (!("
                          << tgfmt(matcher.getConditionTemplate(),
                                   &matchCtx.withSelf("attr"))
                          << ")) return matchFailure();\n";
  }

  // Capture the attribute when the pattern binds a name to it.
  auto boundName = tree.getArgName(index);
  if (!boundName.empty())
    os.indent(bodyIndent) << getBoundSymbol(boundName) << " = attr;\n";

  os.indent(indent) << "}\n";
}
// Emits the generated pattern's match() method: the method heading, the
// unused-result checks requested by `verifyUnusedValue` result patterns, the
// structural match of the source pattern `tree`, and finally the pattern's
// extra constraints.
void PatternEmitter::emitMatchMethod(DagNode tree) {
  // Method heading and creation of the matched state.
  os << R"(
PatternMatchResult match(Operation *op0) const override {
auto ctx = op0->getContext(); (void)ctx;
auto state = llvm::make_unique<MatchedState>();
auto &s = *state;
)";

  // Result patterns may demand that certain results of the source op are
  // unused in the IR; those checks belong to matching.
  const int numResults = pattern.getNumResults();
  for (int resultIdx = 0; resultIdx < numResults; ++resultIdx) {
    DagNode resultTree = pattern.getResultPattern(resultIdx);
    if (resultTree.isVerifyUnusedValue())
      handleVerifyUnusedValue(resultTree, resultIdx);
  }

  emitOpMatch(tree, 0);

  const char *failUnless = "if (!{0}) return matchFailure();\n";
  for (auto &applied : pattern.getConstraints()) {
    auto &constraint = applied.constraint;
    auto &entities = applied.entities;
    auto condition = constraint.getConditionTemplate();

    if (isa<TypeConstraint>(constraint)) {
      auto self = formatv("(*{0}->result_type_begin())",
                          resolveSymbol(entities.front()));
      // TODO(jpienaar): Verify op only has one result.
      os.indent(4) << formatv(failUnless,
                              tgfmt(condition, &matchCtx.withSelf(self.str())));
      continue;
    }

    if (isa<AttrConstraint>(constraint)) {
      PrintFatalError(
          loc, "cannot use AttrConstraint in Pattern multi-entity constraints");
    }

    // TODO(fengliuai): replace formatv arguments with the exact specified
    // args.
    if (entities.size() > 4) {
      PrintFatalError(loc, "only support up to 4-entity constraints now");
    }
    // Resolve every entity and pad the list up to four placeholders.
    SmallVector<std::string, 4> names;
    for (const auto &entity : entities)
      names.push_back(resolveSymbol(entity));
    while (names.size() < 4)
      names.push_back("<unused>");
    os.indent(4) << formatv(failUnless,
                            tgfmt(condition, &matchCtx, names[0], names[1],
                                  names[2], names[3]));
  }

  os.indent(4) << "return matchSuccess(std::move(state));\n }\n";
}
// Emits the complete RewritePattern struct named `rewriteName`: a comment
// recording where the pattern came from, the constructor, the MatchedState
// struct holding everything bound in the source pattern, and the match() and
// rewrite() methods.
void PatternEmitter::emit(StringRef rewriteName) {
  DagNode sourceTree = pattern.getSourcePattern();
  const Operator &rootOp = pattern.getSourceRootOp();
  auto rootName = rootOp.getOperationName();

  if (rootOp.getNumVariadicResults() != 0)
    PrintFatalError(
        loc, "replacing op with variadic results not supported right now");

  // Struct heading, preceded by the pattern's instantiation locations.
  auto locs = pattern.getLocation();
  os << formatv("/* Generated from:\n\t{0:$[ instantiating\n\t]}\n*/\n",
                make_range(locs.rbegin(), locs.rend()));
  os << formatv(R"(struct {0} : public RewritePattern {
{0}(MLIRContext *context) : RewritePattern("{1}", {2}, context) {{})",
                rewriteName, rootName, pattern.getBenefit())
     << "\n";

  // Matched state: one field per bound argument, then one per bound op.
  os << " struct MatchedState : public PatternState {\n";
  for (const auto &boundArg : pattern.getSourcePatternBoundArgs()) {
    auto fieldName = boundArg.first();
    auto namedAttr = boundArg.second.dyn_cast<NamedAttribute *>();
    if (!namedAttr) {
      // Anything that is not an attribute is stored as a value.
      os.indent(4) << "Value *" << fieldName << ";\n";
      continue;
    }
    os.indent(4) << namedAttr->attr.getStorageType() << " " << fieldName
                 << ";\n";
  }
  for (const auto &boundOp : pattern.getSourcePatternBoundOps())
    os.indent(4) << "Operation *" << boundOp.getKey() << ";\n";
  os << " };\n";

  emitMatchMethod(sourceTree);
  emitRewriteMethod();
  os << "};\n";
}
// Emits the generated pattern's rewrite() method: the method heading, the
// code for every result pattern, and the closing replaceOp() call. Aborts if
// the pattern provides fewer result patterns than the root op has results.
void PatternEmitter::emitRewriteMethod() {
  const Operator &rootOp = pattern.getSourceRootOp();
  const int numExpectedResults = rootOp.getNumResults();
  const int numProvidedResults = pattern.getNumResults();
  if (numExpectedResults > numProvidedResults)
    PrintFatalError(
        loc, "no enough result patterns to replace root op in source pattern");

  os << R"(
void rewrite(Operation *op, std::unique_ptr<PatternState> state,
PatternRewriter &rewriter) const override {
auto& s = *static_cast<MatchedState *>(state.get());
auto loc = op->getLoc(); (void)loc;
)";

  // Generate each result pattern in order, remembering the expression that
  // stands for its value and any symbol bound at the top-level DAG node.
  llvm::SmallVector<std::string, 2> resultValues;
  for (int resultIdx = 0; resultIdx < numProvidedResults; ++resultIdx) {
    DagNode resultTree = pattern.getResultPattern(resultIdx);
    resultValues.push_back(handleRewritePattern(resultTree, resultIdx, 0));
    addSymbol(resultTree);
  }

  // Only the last numExpectedResults values replace the root op.
  auto replacements =
      ArrayRef<std::string>(resultValues).take_back(numExpectedResults);
  os.indent(4) << "rewriter.replaceOp(op, {";
  bool needSeparator = false;
  for (const std::string &value : replacements) {
    if (needSeparator)
      os << ", ";
    os << value;
    needSeparator = true;
  }
  os << "});\n }\n";
}
// Builds a fresh variable name of the form `v<OpCppClassName><id>` for a value
// produced by `op`, consuming the next unused value id.
std::string PatternEmitter::getUniqueValueName(const Operator *op) {
  const unsigned valueId = nextValueId++;
  return formatv("v{0}{1}", op->getCppClassName(), valueId);
}
// Dispatches the result-pattern DAG node `resultTree` to the handler matching
// its kind (native code call, `verifyUnusedValue`, `replaceWithValue`, or op
// creation) and returns the C++ expression standing for the resulting value.
std::string PatternEmitter::handleRewritePattern(DagNode resultTree,
                                                 int resultIndex, int depth) {
  if (resultTree.isNativeCodeCall())
    return emitReplaceWithNativeCodeCall(resultTree);

  if (!resultTree.isVerifyUnusedValue()) {
    if (resultTree.isReplaceWithValue())
      return handleReplaceWithValue(resultTree);
    return emitOpCreate(resultTree, resultIndex, depth);
  }

  // From here on we are dealing with the verifyUnusedValue directive.
  if (depth > 0) {
    // TODO: Revisit this when we have use cases of matching an intermediate
    // multi-result op with no uses of its certain results.
    PrintFatalError(loc, "verifyUnusedValue directive can only be used to "
                         "verify top-level result");
  }
  if (!resultTree.getOpName().empty()) {
    PrintFatalError(loc, "cannot bind symbol to verifyUnusedValue");
  }
  // The statements checking that this result is unused were already emitted
  // in match(). Returning a nullptr here should be safe because the C++
  // RewritePattern harness will use it to replace nothing.
  return "nullptr";
}
// Handles the `replaceWithValue` directive and returns the C++ expression
// referencing the existing value that serves as the replacement.
//
// The directive must take exactly one argument and must not have a symbol
// bound to it; its argument must name something bound in the source pattern.
// Any violation aborts with a fatal error at the pattern's location.
std::string PatternEmitter::handleReplaceWithValue(DagNode tree) {
  assert(tree.isReplaceWithValue());
  if (tree.getNumArgs() != 1) {
    PrintFatalError(
        loc, "replaceWithValue directive must take exactly one argument");
  }
  if (!tree.getOpName().empty()) {
    // The diagnostic names the directive actually being processed; it used to
    // mention verifyUnusedValue, which pointed users at the wrong construct.
    PrintFatalError(loc, "cannot bind symbol to replaceWithValue");
  }
  auto name = tree.getArgName(0);
  pattern.ensureBoundInSourcePattern(name);
  return getBoundSymbol(name).str();
}
// Handles the `verifyUnusedValue` directive by emitting, into match(), a
// check that fails the match when the `index`-th result of the root op has
// any use.
void PatternEmitter::handleVerifyUnusedValue(DagNode tree, int index) {
  assert(tree.isVerifyUnusedValue());
  raw_ostream &line = os.indent(4);
  line << "if (!op0->getResult(" << index;
  line << ")->use_empty()) return matchFailure();\n";
}
// Returns the C++ expression that supplies an op argument described by the
// result-pattern DAG `leaf`. Constant attributes and enum attribute cases are
// built from scratch; everything else refers to `argName`, which must be bound
// in the source pattern, either directly or wrapped in a native code call.
std::string PatternEmitter::handleOpArgument(DagLeaf leaf,
                                             llvm::StringRef argName) {
  if (leaf.isConstantAttr()) {
    auto constAttr = leaf.getAsConstantAttr();
    return handleConstantAttr(constAttr.getAttribute(),
                              constAttr.getConstantValue());
  }
  if (leaf.isEnumAttrCase()) {
    auto enumCase = leaf.getAsEnumAttrCase();
    return handleConstantAttr(enumCase, enumCase.getSymbol());
  }

  pattern.ensureBoundInSourcePattern(argName);
  std::string boundExpr = getBoundSymbol(argName).str();

  // Unconstrained leaves and operand matchers use the bound value as is.
  const bool usedAsIs = leaf.isUnspecified() || leaf.isOperandMatcher();
  if (usedAsIs)
    return boundExpr;

  if (!leaf.isNativeCodeCall())
    PrintFatalError(loc, "unhandled case when rewriting op");
  return tgfmt(leaf.getNativeCodeTemplate(), &rewriteCtx.withSelf(boundExpr));
}
// Returns the C++ expression obtained by expanding the native code template
// of `tree` with its arguments. At most eight arguments are supported; unused
// placeholder slots are filled with empty strings.
std::string PatternEmitter::emitReplaceWithNativeCodeCall(DagNode tree) {
  auto codeTemplate = tree.getNativeCodeTemplate();
  // TODO(fengliuai): replace formatv arguments with the exact specified args.
  SmallVector<std::string, 8> args(8);
  if (tree.getNumArgs() > 8) {
    PrintFatalError(loc, "unsupported NativeCodeCall argument numbers: " +
                             Twine(tree.getNumArgs()));
  }

  const int numArgs = tree.getNumArgs();
  for (int argIdx = 0; argIdx != numArgs; ++argIdx) {
    args[argIdx] =
        handleOpArgument(tree.getArgAsLeaf(argIdx), tree.getArgName(argIdx));
  }

  return tgfmt(codeTemplate, &rewriteCtx, args[0], args[1], args[2], args[3],
               args[4], args[5], args[6], args[7]);
}
// Registers the symbol bound to the result-pattern DAG `node`, if any, with
// the symbol resolver. Aborts when the same symbol is bound a second time.
void PatternEmitter::addSymbol(DagNode node) {
  StringRef symbol = node.getOpName();
  // Unbound ops in result patterns carry an empty name; there is nothing to
  // register for them.
  const bool isBound = !symbol.empty();
  if (isBound && !symbolResolver.add(symbol))
    PrintFatalError(loc, formatv("symbol '{0}' bound more than once", symbol));
}
// Looks up the C++ expression that `symbol` stands for. Aborts with a fatal
// error when the symbol is not bound (the resolver reports that case with an
// empty string).
std::string PatternEmitter::resolveSymbol(StringRef symbol) {
  std::string substitution = symbolResolver.query(symbol);
  if (!substitution.empty())
    return substitution;
  PrintFatalError(loc, formatv("referencing unbound symbol '{0}'", symbol));
  return substitution;
}
// Emits the `rewriter.create<...>(...)` statement that builds the op described
// by the result-pattern DAG `tree`, after first emitting the creation of any
// ops nested in its operands. Returns the name of the variable the new op is
// assigned to: the symbol bound to `tree` if there is one, otherwise a
// generated unique name. `resultIndex` selects which result of the matched
// root op supplies the result type when the type is passed explicitly;
// `depth` is the nesting level of `tree` in the result pattern.
std::string PatternEmitter::emitOpCreate(DagNode tree, int resultIndex,
int depth) {
Operator &resultOp = tree.getDialectOp(opMap);
auto numOpArgs = resultOp.getNumArgs();
// The pattern must supply exactly as many arguments as the op definition.
if (numOpArgs != tree.getNumArgs()) {
PrintFatalError(loc, formatv("resultant op '{0}' argument number mismatch: "
"{1} in pattern vs. {2} in definition",
resultOp.getOperationName(), tree.getNumArgs(),
numOpArgs));
}
if (resultOp.getNumResults() > 1) {
PrintFatalError(
loc, formatv("generating multiple-result op '{0}' is unsupported now",
resultOp.getOperationName()));
}
// A map to collect all nested DAG child nodes' names, with operand index as
// the key. This includes both bound and unbound child nodes. Bound child
// nodes will additionally be tracked in `symbolResolver` so they can be
// referenced by other patterns. Unbound child nodes will only be used once
// to build this op.
llvm::DenseMap<unsigned, std::string> childNodeNames;
// First go through all the child nodes who are nested DAG constructs to
// create ops for them, so that we can use the results in the current node.
// This happens in a recursive manner.
for (int i = 0, e = resultOp.getNumOperands(); i != e; ++i) {
if (auto child = tree.getArgAsNestedDag(i)) {
childNodeNames[i] = handleRewritePattern(child, i, depth + 1);
// Keep track of bound symbols at the middle-level DAG nodes
addSymbol(child);
}
}
// Use the specified name for this op if available. Generate one otherwise.
std::string resultValue = tree.getOpName();
if (resultValue.empty())
resultValue = getUniqueValueName(&resultOp);
// Then we build the new op corresponding to this DAG node.
// TODO: this is a hack to support various constant ops. We are assuming
// all of them have no operands and one attribute here. Figure out a better
// way to do this.
bool isConstOp =
resultOp.getNumOperands() == 0 && resultOp.getNumNativeAttributes() == 1;
bool isSameValueType = resultOp.hasTrait("SameOperandsAndResultType");
bool isBroadcastable = resultOp.hasTrait("BroadcastableTwoOperandsOneResult");
bool useFirstAttr = resultOp.hasTrait("FirstAttrDerivedResultType");
// In any of these cases the create call is emitted without an explicit
// result type; otherwise the type of the matched root op's `resultIndex`-th
// result is passed as the result type.
if (isConstOp || isSameValueType || isBroadcastable || useFirstAttr) {
os.indent(4) << formatv("auto {0} = rewriter.create<{1}>(loc", resultValue,
resultOp.getQualCppClassName());
} else {
std::string resultType = formatv("op->getResult({0})", resultIndex).str();
os.indent(4) << formatv(
"auto {0} = rewriter.create<{1}>(loc, {2}->getType()", resultValue,
resultOp.getQualCppClassName(), resultType);
}
// Create the builder call for the result.
// Add operands.
// `i` is shared with the attribute loop below: operands occupy argument
// indices [0, getNumOperands()) and attributes take the remaining ones.
int i = 0;
for (int e = resultOp.getNumOperands(); i < e; ++i) {
const auto &operand = resultOp.getOperand(i);
// Start each operand on its own line.
(os << ",\n").indent(6);
if (!operand.name.empty())
os << "/*" << operand.name << "=*/";
if (tree.isNestedDagArg(i)) {
// Nested ops were created above; refer to them by variable name.
os << childNodeNames[i];
} else {
DagLeaf leaf = tree.getArgAsLeaf(i);
auto symbol = resolveSymbol(tree.getArgName(i));
if (leaf.isNativeCodeCall()) {
os << tgfmt(leaf.getNativeCodeTemplate(), &rewriteCtx.withSelf(symbol));
} else {
os << symbol;
}
}
// TODO(jpienaar): verify types
}
// Add attributes.
for (int e = tree.getNumArgs(); i != e; ++i) {
// Start each attribute on its own line.
(os << ",\n").indent(6);
// The argument in the op definition.
auto opArgName = resultOp.getArgName(i);
if (auto subTree = tree.getArgAsNestedDag(i)) {
if (!subTree.isNativeCodeCall())
PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
"for creating attribute");
os << formatv("/*{0}=*/{1}", opArgName,
emitReplaceWithNativeCodeCall(subTree));
} else {
auto leaf = tree.getArgAsLeaf(i);
// The argument in the result DAG pattern.
auto patArgName = tree.getArgName(i);
if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) {
// TODO(jpienaar): Refactor out into map to avoid recomputing these.
auto argument = resultOp.getArg(i);
if (!argument.is<NamedAttribute *>())
PrintFatalError(loc, Twine("expected attribute ") + Twine(i));
// Constant attributes are annotated with the pattern's argument name,
// and only when one was given.
if (!patArgName.empty())
os << "/*" << patArgName << "=*/";
} else {
os << "/*" << opArgName << "=*/";
}
os << handleOpArgument(leaf, patArgName);
}
}
// Close the create call.
os << "\n );\n";
return resultValue;
}
// Static entry point: builds a one-off emitter for the pattern record `p` and
// has it write the rewrite pattern struct named `rewriteName` to `os`.
void PatternEmitter::emit(StringRef rewriteName, Record *p,
                          RecordOperatorMap *mapper, raw_ostream &os) {
  PatternEmitter emitter(p, mapper, os);
  emitter.emit(rewriteName);
}
static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
emitSourceFileHeader("Rewriters", os);
const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern");
auto numPatterns = patterns.size();
// We put the map here because it can be shared among multiple patterns.
RecordOperatorMap recordOpMap;
std::vector<std::string> rewriterNames;
rewriterNames.reserve(numPatterns);
std::string baseRewriterName = "GeneratedConvert";
int rewriterIndex = 0;
for (Record *p : patterns) {
std::string name;
if (p->isAnonymous()) {
// If no name is provided, ensure unique rewriter names simply by
// appending unique suffix.
name = baseRewriterName + llvm::utostr(rewriterIndex++);
} else {
name = p->getName();
}
PatternEmitter::emit(name, p, &recordOpMap, os);
rewriterNames.push_back(std::move(name));
}
// Emit function to add the generated matchers to the pattern list.
os << "void populateWithGenerated(MLIRContext *context, "
<< "OwningRewritePatternList *patterns) {\n";
for (const auto &name : rewriterNames) {
os << " patterns->push_back(llvm::make_unique<" << name
<< ">(context));\n";
}
os << "}\n";
}
// Registers the rewriter generator under the `gen-rewriters` option. The
// callback always returns false after emitting.
static mlir::GenRegistration
    genRewriters("gen-rewriters", "Generate pattern rewriters",
                 [](const RecordKeeper &recordKeeper, raw_ostream &output) {
                   emitRewriters(recordKeeper, output);
                   return false;
                 });