blob: bb23187f0461e0d19ae386572d12e948c1d31dc8 [file] [log] [blame]
//===- mlir-rewriter-gen.cpp - MLIR pattern rewriter generator ------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This is a command line utility that generates rewrite patterns from a
// declarative description.
//
//===----------------------------------------------------------------------===//
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/ManagedStatic.h"
#include "llvm/Support/PrettyStackTrace.h"
#include "llvm/Support/Signals.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Main.h"
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"
using namespace llvm;
// The generators this tool is able to run.
enum ActionType { GenRewriters };

// Selects the generator to run via a command-line flag. The only flag defined
// here is `-gen-rewriters`.
static cl::opt<ActionType> action(
    cl::desc("Action to perform:"),
    cl::values(clEnumValN(GenRewriters, "gen-rewriters",
                          "Generate rewriter definitions")));
// Emits one C++ RewritePattern subclass per TableGen `Pattern` definition,
// followed by a `populateWithGenerated` function that registers all of them.
//
// For every pattern the generated struct:
//  - is named `GeneratedConvert<N>`, where N is a running counter that keeps
//    the names unique;
//  - is constructed with the source DAG operator's `opName` as its root op
//    name;
//  - has a match() that only succeeds on ops with exactly one result;
//  - has a rewrite() that calls `replaceOpWithNewOp<ResultOp>` with the
//    matched op's result type followed by one argument per operand of the
//    result DAG. Operands are forwarded from the matched op by name. Arguments
//    the result op declares as `Attr` subclasses must be given as constant
//    defs whose `value` field is an int or a string.
//
// Only single-level source and result DAGs (and a single result DAG per
// pattern) are supported. Anything else is reported with PrintFatalError at
// the pattern's location.
static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
  emitSourceFileHeader("Rewriters", os);
  const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern");
  Record *attrClass = recordKeeper.getClass("Attr");

  // Ensure unique patterns simply by appending unique suffix.
  unsigned rewritePatternCount = 0;
  std::string baseRewriteName = "GeneratedConvert";
  for (Record *pattern : patterns) {
    DagInit *tree = pattern->getValueAsDag("PatternToMatch");

    // Map each named operand of the source pattern to its position, so that
    // result operands can refer to it by name. Unnamed operands cannot be
    // referenced, so they are skipped (otherwise they would all collide on
    // the empty-string key).
    StringMap<unsigned> nameToOrdinal;
    for (unsigned i = 0, e = tree->getNumArgs(); i != e; ++i) {
      StringRef argName = tree->getArgNameStr(i);
      if (!argName.empty())
        nameToOrdinal[argName] = i;
    }

    // TODO(jpienaar): Expand to multiple matches.
    for (auto arg : tree->getArgs()) {
      if (isa<DagInit>(arg))
        PrintFatalError(pattern->getLoc(),
                        "only single pattern inputs supported");
    }

    // Emit RewritePattern for Pattern.
    DefInit *root = cast<DefInit>(tree->getOperator());
    std::string rewriteName =
        baseRewriteName + llvm::utostr(rewritePatternCount++);
    auto *rootName = cast<StringInit>(root->getDef()->getValueInit("opName"));
    os << "struct " << rewriteName << " : public RewritePattern {\n"
       << " " << rewriteName << "(MLIRContext *context) : RewritePattern("
       << rootName->getAsString() << ", 1, context) {}\n"
       << " PatternMatchResult match(Operation *op) const override {\n"
       << " // TODO: This just handle 1 result\n"
       << " if (op->getNumResults() != 1) return matchFailure();\n"
       << " return matchSuccess();\n }\n";

    ListInit *resultOps = pattern->getValueAsListInit("ResultOps");
    if (resultOps->size() != 1)
      PrintFatalError(pattern->getLoc(), "only single result rules supported");
    DagInit *resultTree = cast<DagInit>(resultOps->getElement(0));

    // TODO(jpienaar): Expand to multiple results.
    for (auto result : resultTree->getArgs()) {
      if (isa<DagInit>(result))
        PrintFatalError(pattern->getLoc(), "only single op result supported");
    }
    DefInit *resultRoot = cast<DefInit>(resultTree->getOperator());
    std::string opName = resultRoot->getAsUnquotedString();
    auto resultOperands = resultRoot->getDef()->getValueAsDag("arguments");

    // Derive the C++ class name from the result op's def name by turning '_'
    // separators into '::' (e.g. Foo_BarOp -> Foo::BarOp).
    SmallVector<StringRef, 2> split;
    SplitString(opName, split, "_");
    auto className = join(split, "::");

    // '{{' is formatv's escape for a literal '{'; {0} is the only placeholder.
    os << formatv(R"(
void rewrite(Operation *op, PatternRewriter &rewriter) const override {{
auto* context = op->getContext(); (void)context;
rewriter.replaceOpWithNewOp<{0}>(op, op->getResult(0)->getType())",
                  className);
    if (resultOperands->getNumArgs() != resultTree->getNumArgs()) {
      PrintFatalError(pattern->getLoc(),
                      Twine("mismatch between arguments of resultant op (") +
                          Twine(resultOperands->getNumArgs()) +
                          ") and arguments provided for rewrite (" +
                          Twine(resultTree->getNumArgs()) + Twine(')'));
    }

    // Create the builder call for the result.
    for (unsigned i = 0, e = resultTree->getNumArgs(); i != e; ++i) {
      // Start each operand on its own line.
      (os << ",\n").indent(6);
      auto *arg = resultTree->getArg(i);
      // getArgNameStr yields an empty name (rather than a null pointer) for
      // unnamed DAG arguments.
      std::string name = resultTree->getArgNameStr(i).str();
      auto *defInit = dyn_cast<DefInit>(arg);

      // TODO(jpienaar): Refactor out into map to avoid recomputing these.
      auto *argument = resultOperands->getArg(i);
      auto *argumentDefInit = dyn_cast<DefInit>(argument);
      bool argumentIsAttr = false;
      if (argumentDefInit) {
        if (auto *recTy = dyn_cast<RecordRecTy>(argumentDefInit->getType()))
          argumentIsAttr = recTy->isSubClassOf(attrClass);
      }

      // Attribute arguments are materialized from the constant `value` field
      // of the def supplied in the rewrite.
      if (argumentIsAttr) {
        if (!defInit) {
          std::string argumentName = resultOperands->getArgNameStr(i).str();
          PrintFatalError(pattern->getLoc(),
                          Twine("attribute '") + argumentName +
                              "' needs to be constant initialized");
        }
        auto value = defInit->getDef()->getValue("value");
        if (!value)
          PrintFatalError(pattern->getLoc(), Twine("'value' not defined in ") +
                                                 arg->getAsString());
        switch (value->getType()->getRecTyKind()) {
        case RecTy::IntRecTyKind:
          // TODO(jpienaar): This is using 64-bits for all the bitwidth of the
          // type could instead be queried. These are expected to be mostly used
          // for enums or constant indices and so no arithmetic operations are
          // expected on these.
          os << formatv(
              "/*{0}=*/IntegerAttr::get(Type::getInteger(64, context), {1})",
              name, value->getValue()->getAsString());
          break;
        case RecTy::StringRecTyKind:
          os << formatv("/*{0}=*/StringAttr::get({1}, context)", name,
                        value->getValue()->getAsString());
          break;
        default:
          PrintFatalError(pattern->getLoc(),
                          Twine("unsupported/unimplemented value type for ") +
                              name);
        }
        continue;
      }

      // Verify the types match between the rewrite's argument and the
      // corresponding declared argument of the resultant op.
      if (defInit && argumentDefInit &&
          defInit->getType() != argumentDefInit->getType()) {
        PrintFatalError(
            pattern->getLoc(),
            "mismatch in type of operation's argument and rewrite argument " +
                Twine(i));
      }

      // Non-attribute operands are forwarded from the matched op by name, so
      // they must carry one.
      if (name.empty())
        PrintFatalError(pattern->getLoc(), Twine("operand #") + Twine(i) +
                                               " of the rewrite must be named");

      // Lookup the ordinal for the named operand.
      auto ord = nameToOrdinal.find(name);
      if (ord == nameToOrdinal.end())
        PrintFatalError(pattern->getLoc(),
                        Twine("unknown named operand '") + name + "'");
      os << "/*" << name << "=*/op->getOperand(" << ord->getValue() << ")";
    }
    os << "\n );\n }\n};\n";
  }

  // Emit function to add the generated matchers to the pattern list.
  os << "void populateWithGenerated(MLIRContext *context, "
     << "OwningRewritePatternList *patterns) {\n";
  for (unsigned i = 0; i != rewritePatternCount; ++i) {
    os << " patterns->push_back(std::make_unique<" << baseRewriteName << i
       << ">(context));\n";
  }
  os << "}\n";
}
// TableGen backend callback: runs the generator selected by the `action`
// command-line option over `records` and writes its output to `os`. Returns
// false after a successful run (TableGenMain treats `true` as failure).
static bool MlirOpTableGenMain(raw_ostream &os, RecordKeeper &records) {
  switch (action) {
  case GenRewriters:
    emitRewriters(records, os);
    return false;
  }
  // Every ActionType enumerator is handled above. This keeps control from
  // falling off the end of a non-void function if the enum ever grows.
  llvm_unreachable("unknown action");
}
// Tool entry point: sets up crash reporting, parses the command line, and then
// hands control to the generic TableGen driver with this file's backend.
int main(int argc, char **argv) {
  // On a crash, print a stack trace and the command line that was run.
  sys::PrintStackTraceOnErrorSignal(argv[0]);
  PrettyStackTraceProgram crashCommandLine(argc, argv);

  cl::ParseCommandLineOptions(argc, argv);

  // Calls llvm_shutdown() when this object goes out of scope at return.
  llvm_shutdown_obj shutdownGuard;
  return TableGenMain(argv[0], &MlirOpTableGenMain);
}