blob: d282dfc5058dd23912d4f1f3ea78393802e3b545 [file] [log] [blame]
//===- OpDefinitionsGen.cpp - MLIR op definitions generator ---------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// OpDefinitionsGen uses the description of operations to generate C++
// definitions for ops.
//
//===----------------------------------------------------------------------===//
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Operator.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/Signals.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"
using namespace llvm;
using namespace mlir;
using mlir::tblgen::Operator;
// Prefix used when synthesizing a name for an operand that has none;
// getArgumentName() appends "_<index>" to it.
static const char *const generatedArgName = "_arg";
// Helper macro that returns indented os. It expands to a call to indent(X) on
// whatever `os` is visible at the point of use, so it can only be used where a
// stream named `os` is in scope.
#define OUT(X) os.indent((X))
// TODO(jpienaar): The builder body should probably be separate from the header.
// Variation of method in FormatVariadic.h which takes a StringRef as input
// instead.
//
// Every value is wrapped with detail::build_format_adapter, the adapters are
// packed into a tuple, and that tuple is handed together with `fmt` to the
// returned formatv_object.
// NOTE(review): the returned object presumably refers to `fmt` rather than
// copying it, so `fmt` should stay alive until the result has been consumed --
// confirm against FormatVariadic.h.
template <typename... Ts>
inline auto formatv(StringRef fmt, Ts &&... vals) -> formatv_object<decltype(
std::make_tuple(detail::build_format_adapter(std::forward<Ts>(vals))...))> {
// Tuple type holding one format adapter per forwarded value.
using ParamTuple = decltype(
std::make_tuple(detail::build_format_adapter(std::forward<Ts>(vals))...));
return llvm::formatv_object<ParamTuple>(
fmt,
std::make_tuple(detail::build_format_adapter(std::forward<Ts>(vals))...));
}
// Returns true when the field `fieldName` of `record` is initialized with
// either a code block or a string, i.e. when reading it through
// getValueAsString is expected to work.
static inline bool hasStringAttribute(const Record &record,
                                      StringRef fieldName) {
  auto *fieldInit = record.getValueInit(fieldName);
  if (isa<CodeInit>(fieldInit))
    return true;
  return isa<StringInit>(fieldInit);
}
// Returns the name to use for operand `index` of `op`: the operand's own name
// when it has one, otherwise a synthesized "<generatedArgName>_<index>".
static std::string getArgumentName(const Operator &op, int index) {
  const auto &arg = op.getOperand(index);
  // Unnamed operands get a generated, index-based name.
  if (arg.name.empty())
    return formatv("{0}_{1}", generatedArgName, index);
  return arg.name;
}
namespace {
// Simple RAII helper for defining ifdef-undef-endif scopes.
class IfDefScope {
public:
IfDefScope(StringRef name, raw_ostream &os) : name(name), os(os) {
os << "#ifdef " << name << "\n"
<< "#undef " << name << "\n";
}
~IfDefScope() { os << "\n#endif // " << name << "\n\n"; }
private:
StringRef name;
raw_ostream &os;
};
} // end anonymous namespace
namespace {
// Helper class to emit a record into the given output stream.
//
// The only public entry point that creates an emitter is the static emit();
// the constructor is private. The individual emit* methods each write one
// piece of the generated op class to `os`.
class OpEmitter {
public:
// Emits the complete C++ class for the op described by `def` into `os`,
// wrapped in the namespaces derived from the op's def name.
static void emit(const Record &def, raw_ostream &os);
// Emit getters for the attributes of the operation.
void emitAttrGetters();
// Emit query methods for the named operands.
void emitNamedOperands();
// Emit builder method for the operation.
void emitBuilder();
// Emit method declaration for the getCanonicalizationPatterns() interface.
void emitCanonicalizationPatterns();
// Emit the folder methods for the operation.
void emitFolders();
// Emit the parser for the operation.
void emitParser();
// Emit the printer for the operation.
void emitPrinter();
// Emit verify method for the operation.
void emitVerifier();
// Emit the traits used by the object.
void emitTraits();
private:
// Binds the emitter to the record and the stream it writes to.
OpEmitter(const Record &def, raw_ostream &os);
// Invokes the given function over all the namespaces of the class, i.e. over
// every component of the split def name except the last one.
void mapOverClassNamespaces(function_ref<void(StringRef)> fn);
// The record corresponding to the op.
const Record &def;
// The operator being emitted.
Operator op;
// The stream all generated code is written to.
raw_ostream &os;
};
} // end anonymous namespace
// Stores references to the record and the output stream and constructs the
// Operator wrapper from the same record.
OpEmitter::OpEmitter(const Record &def, raw_ostream &os)
: def(def), op(def), os(os) {}
// Calls `fn` once, in order, for every component of the op's split def name
// except the last one (presumably the class name itself, with the preceding
// components acting as its namespaces).
void OpEmitter::mapOverClassNamespaces(function_ref<void(StringRef)> fn) {
  auto &nameParts = op.getSplitDefName();
  auto current = nameParts.begin();
  // Stop right before the final component.
  const auto lastPart = std::prev(nameParts.end());
  while (current != lastPart) {
    fn(*current);
    ++current;
  }
}
// Writes the full generated class for `def` to `os`: opening namespaces, the
// class header with its trait list, every generated member, the private
// constructor, and finally the closing namespaces.
void OpEmitter::emit(const Record &def, raw_ostream &os) {
  OpEmitter emitter(def, os);

  // Open one namespace per namespace component of the def name.
  auto openNamespace = [&os](StringRef ns) {
    os << "\nnamespace " << ns << "{\n";
  };
  emitter.mapOverClassNamespaces(openNamespace);

  // Class header; the traits become the remaining template arguments of Op<>.
  os << formatv("class {0} : public Op<{0}", emitter.op.getCppClassName());
  emitter.emitTraits();
  os << "> {\npublic:\n";

  // Accessor returning the operation name.
  OUT(2) << "static StringRef getOperationName() { return \"";
  os << emitter.op.getOperationName() << "\"; };\n";

  // Public members, in the order they appear in the generated class.
  emitter.emitNamedOperands();
  emitter.emitBuilder();
  emitter.emitParser();
  emitter.emitPrinter();
  emitter.emitVerifier();
  emitter.emitAttrGetters();
  emitter.emitCanonicalizationPatterns();
  emitter.emitFolders();

  // Private section: friend declaration and the constructor from Instruction.
  os << "private:\n friend class ::mlir::Instruction;\n";
  os << " explicit " << emitter.op.getCppClassName();
  os << "(const Instruction* state) : Op(state) {}\n};\n";

  // Close the namespaces opened above.
  auto closeNamespace = [&os](StringRef ns) {
    os << "} // end namespace " << ns << "\n";
  };
  emitter.mapOverClassNamespaces(closeNamespace);
}
// Emits one const getter per attribute of the op.
//
// Derived attributes get a getter whose body is the user-provided code.
// Every other attribute gets a getter that looks the attribute up by name,
// casts it to its storage type and converts it to the return type; when the
// attribute declares a default value, the getter returns that default if the
// attribute is not set.
void OpEmitter::emitAttrGetters() {
  for (auto &namedAttr : op.getAttributes()) {
    auto name = namedAttr.getName();
    const auto &attr = namedAttr.attr;

    // Determine the name of the attribute getter. The name matches the
    // attribute name excluding dialect prefix (everything up to the last '$').
    StringRef getter = name;
    auto it = getter.rfind('$');
    if (it != StringRef::npos)
      getter = getter.substr(it + 1);

    // Emit the derived attribute body.
    if (attr.isDerivedAttr()) {
      OUT(2) << attr.getReturnType() << ' ' << getter << "() const {"
             << attr.getDerivedCodeBody() << " }\n";
      continue;
    }

    // Emit normal getter.
    OUT(2) << attr.getReturnType() << ' ' << getter << "() const {\n";

    // Return the queried attribute with the correct return type. The cast is
    // dyn_cast_or_null (as in the generated verifier) because the attribute
    // may legitimately be absent: the generated code below tests `!attr` to
    // fall back to the default value, so the lookup result must be allowed to
    // be null.
    std::string attrVal =
        formatv("this->getAttr(\"{1}\").dyn_cast_or_null<{0}>()",
                attr.getStorageType(), name);
    OUT(4) << "auto attr = " << attrVal << ";\n";
    if (attr.hasDefaultValue()) {
      // Returns the default value if not set.
      // TODO: this is inefficient, we are recreating the attribute for every
      // call. This should be set instead.
      OUT(4) << "if (!attr)\n";
      OUT(6) << "return "
             << formatv(
                    attr.getConvertFromStorageCall(),
                    formatv(
                        attr.getDefaultValueTemplate(),
                        "mlir::Builder(this->getInstruction()->getContext())"))
             << ";\n";
    }
    OUT(4) << "return " << formatv(attr.getConvertFromStorageCall(), "attr")
           << ";\n }\n";
  }
}
// Emits a mutable and a const accessor for each operand that has a name and
// is not variadic. Each accessor returns the operand at the operand's index.
void OpEmitter::emitNamedOperands() {
  // {0} is the operand name, {1} its index.
  const auto accessorTemplate = R"( Value *{0}() {
return this->getInstruction()->getOperand({1});
}
const Value *{0}() const {
return this->getInstruction()->getOperand({1});
}
)";

  const int operandCount = op.getNumOperands();
  for (int index = 0; index < operandCount; ++index) {
    const auto &operand = op.getOperand(index);
    // Variadic and unnamed operands get no accessor.
    if (operand.type.isVariadic() || operand.name.empty())
      continue;
    os << formatv(accessorTemplate, operand.name, index);
  }
}
// Emits the static `build` methods of the op class.
//
// A custom builder given in the record's "builder" field is written out
// verbatim first. After that two default builders are always generated:
//   1. one taking a stand-alone parameter per result type, operand and
//      (non-derived) attribute;
//   2. one taking all result types, operands and attributes as aggregated
//      ArrayRef parameters.
// The variadic operand, if any, is assumed to be the last operand.
void OpEmitter::emitBuilder() {
  if (hasStringAttribute(def, "builder")) {
    // A custom builder, when given, is emitted ahead of the default ones.
    auto builder = def.getValueAsString("builder");
    if (!builder.empty())
      os << builder << '\n';
  }

  // Generate default builders that requires all result type, operands, and
  // attributes as parameters.

  // 1. Stand-alone parameters
  OUT(2) << "static void build(Builder* builder, OperationState* result";

  auto numResults = op.getNumResults();

  // Emit parameters for all return types
  for (unsigned i = 0, e = numResults; i != e; ++i)
    os << ", Type returnType" << i;

  // Emit parameters for all operands
  for (int i = 0, e = op.getNumOperands(); i != e; ++i) {
    auto &operand = op.getOperand(i);
    os << (operand.type.isVariadic() ? ", ArrayRef<Value*> " : ", Value* ")
       << getArgumentName(op, i);
  }

  // Emit parameters for all attributes. Derived attributes are skipped (not
  // treated as a terminator) so that the parameter list always matches the
  // addAttribute calls emitted into the body below.
  // TODO(antiagainst): Support default initializer for attributes
  for (const auto &namedAttr : op.getAttributes()) {
    const auto &attr = namedAttr.attr;
    if (attr.isDerivedAttr())
      continue;
    os << ", " << attr.getStorageType() << ' ' << namedAttr.getName();
  }

  os << ") {\n";

  // Push all result types to the result
  if (numResults > 0) {
    OUT(4) << "result->addTypes({returnType0";
    for (unsigned i = 1; i != numResults; ++i)
      os << ", returnType" << i;
    os << "});\n\n";
  }

  // Push all operands to the result
  auto numOperands = op.getNumOperands();
  bool hasVariadicOperand = op.hasVariadicOperand();
  int numNonVariadicOperands = numOperands - int(hasVariadicOperand);
  if (numNonVariadicOperands > 0) {
    OUT(4) << "result->addOperands({" << getArgumentName(op, 0);
    for (int i = 1, e = numNonVariadicOperands; i < e; ++i) {
      os << ", " << getArgumentName(op, i);
    }
    os << "});\n";
  }
  if (hasVariadicOperand) {
    OUT(4) << "result->addOperands(" << getArgumentName(op, numOperands - 1)
           << ");\n";
  }

  // Push all attributes to the result
  for (const auto &namedAttr : op.getAttributes()) {
    if (!namedAttr.attr.isDerivedAttr()) {
      OUT(4) << formatv("result->addAttribute(\"{0}\", {0});\n",
                        namedAttr.getName());
    }
  }

  OUT(2) << "}\n";

  // 2. Aggregated parameters

  // Signature
  OUT(2) << "static void build(Builder* builder, OperationState* result, "
         << "ArrayRef<Type> resultTypes, ArrayRef<Value*> args, "
            "ArrayRef<NamedAttribute> attributes) {\n";

  // Result types
  OUT(4) << "assert(resultTypes.size() == " << numResults
         << "u && \"mismatched number of return types\");\n"
         << " result->addTypes(resultTypes);\n";

  // Operands. Without a variadic operand the count must match exactly; with
  // one, `args` also carries the variadic values, so the fixed operand count
  // is only a lower bound (and no check is emitted when that bound is zero).
  if (!hasVariadicOperand) {
    OUT(4) << "assert(args.size() == " << numNonVariadicOperands
           << "u && \"mismatched number of parameters\");\n";
  } else if (numNonVariadicOperands > 0) {
    OUT(4) << "assert(args.size() >= " << numNonVariadicOperands
           << "u && \"mismatched number of parameters\");\n";
  }
  OUT(4) << "result->addOperands(args);\n\n";

  // Attributes. An op without attributes must not be handed any. Otherwise
  // every given attribute is forwarded, and at least the attributes that are
  // neither derived nor defaulted have to be present.
  if (op.getNumAttributes() == 0) {
    OUT(4) << "assert(!attributes.size() && \"no attributes expected\");\n";
  } else {
    unsigned numRequiredAttrs = 0;
    for (const auto &namedAttr : op.getAttributes()) {
      const auto &attr = namedAttr.attr;
      if (!attr.isDerivedAttr() && !attr.hasDefaultValue())
        ++numRequiredAttrs;
    }
    if (numRequiredAttrs > 0) {
      OUT(4) << "assert(attributes.size() >= " << numRequiredAttrs
             << "u && \"not enough attributes\");\n";
    }
    OUT(4) << "for (const auto& pair : attributes)\n";
    OUT(6) << "result->addAttribute(pair.first, pair.second);\n";
  }
  OUT(2) << "}\n";
}
// Declares the static getCanonicalizationPatterns() hook in the generated
// class, but only for ops whose record sets the "hasCanonicalizer" bit.
void OpEmitter::emitCanonicalizationPatterns() {
  const bool wantsCanonicalizer = def.getValueAsBit("hasCanonicalizer");
  if (wantsCanonicalizer) {
    OUT(2) << "static void getCanonicalizationPatterns("
           << "OwningRewritePatternList &results, MLIRContext* context);\n";
  }
}
// Declares the folding hooks requested by the record.
//
// "hasConstantFolder" adds a constantFold() declaration and "hasFolder" adds a
// fold() declaration. Ops with exactly one result get the variant returning
// the single folded value; all other ops get the variant that fills a result
// vector and returns bool.
void OpEmitter::emitFolders() {
  const bool singleResult = op.getNumResults() == 1;

  if (def.getValueAsBit("hasConstantFolder")) {
    if (!singleResult) {
      os << " bool constantFold(ArrayRef<Attribute> operands,\n"
         << " SmallVectorImpl<Attribute> &results,\n"
         << " MLIRContext *context) const;\n";
    } else {
      os << " Attribute constantFold(ArrayRef<Attribute> operands,\n"
            " MLIRContext *context) const;\n";
    }
  }

  if (def.getValueAsBit("hasFolder")) {
    os << (singleResult ? " Value *fold();\n"
                        : " bool fold(SmallVectorImpl<Value *> &results);\n");
  }
}
// Emits the static parse() method, using the record's "parser" field as the
// method body. Nothing is emitted when that field holds neither code nor a
// string.
void OpEmitter::emitParser() {
  if (hasStringAttribute(def, "parser")) {
    os << " static bool parse(OpAsmParser *parser, OperationState *result) {"
       << "\n " << def.getValueAsString("parser") << "\n }\n";
  }
}
// Emits the print() method, using the record's "printer" field as the method
// body. Nothing is emitted unless that field is initialized with a code block.
void OpEmitter::emitPrinter() {
  auto *printerCode = dyn_cast<CodeInit>(def.getValueInit("printer"));
  if (printerCode) {
    os << " void print(OpAsmPrinter *p) const {\n";
    os << " " << printerCode->getValue() << "\n }\n";
  }
}
// Emits the `verify()` method of the op class.
//
// No method is emitted when the record has no custom verifier code and the op
// has no arguments. Otherwise the generated method checks, in order:
//   - every non-derived attribute (presence, storage type and predicate);
//   - the type constraint of every non-variadic operand that has a matcher;
// and then either runs the custom verifier code or returns false.
void OpEmitter::emitVerifier() {
  auto valueInit = def.getValueInit("verifier");
  CodeInit *codeInit = dyn_cast<CodeInit>(valueInit);
  bool hasCustomVerify = codeInit && !codeInit->getValue().empty();
  if (!hasCustomVerify && op.getNumArgs() == 0)
    return;

  OUT(2) << "bool verify() const {\n";

  // Verify the attributes have the correct type.
  for (const auto &namedAttr : op.getAttributes()) {
    const auto &attr = namedAttr.attr;
    if (attr.isDerivedAttr())
      continue;

    auto name = namedAttr.getName();
    if (!attr.hasStorageType() && !attr.hasDefaultValue()) {
      // Without a storage type only the presence of the attribute is checked.
      // TODO: Some verification can be done even without storage type.
      OUT(4) << "if (!this->getAttr(\"" << name
             << "\")) return emitOpError(\"requires attribute '" << name
             << "'\");\n";
      continue;
    }

    if (attr.hasDefaultValue()) {
      // If the attribute has a default value, then only verify the predicate if
      // set. This does effectively assume that the default value is valid.
      // TODO: verify the default value is valid (perhaps in debug mode only).
      OUT(4) << "if (this->getAttr(\"" << name << "\")) {\n";
    }

    OUT(6) << "if (!this->getAttr(\"" << name << "\").dyn_cast_or_null<"
           << attr.getStorageType() << ">()) return emitOpError(\"requires "
           << attr.getReturnType() << " attribute '" << name << "'\");\n";

    auto attrPred = attr.getPredicate();
    if (!attrPred.isNull()) {
      OUT(6) << formatv("if (!({0})) return emitOpError(\"attribute '{1}' "
                        "failed to satisfy constraint of {2}\");\n",
                        formatv(attrPred.getCondition(),
                                formatv("this->getAttr(\"{0}\")", name)),
                        name, attr.getTableGenDefName());
    }

    // Close the `if (attribute is set)` scope opened above.
    if (attr.hasDefaultValue())
      OUT(4) << "}\n";
  }

  // Verify the operand types. opIndex is only advanced for non-variadic
  // operands, which is consistent as long as a variadic operand comes last.
  int opIndex = 0;
  for (const auto &operand : op.getOperands()) {
    // TODO: Handle variadic operand verification.
    if (operand.type.isVariadic())
      continue;

    // TODO: Commonality between matchers could be extracted to have a more
    // concise code.
    if (operand.hasMatcher()) {
      auto description = operand.type.getDescription();
      OUT(4) << "if (!("
             << formatv(operand.type.getConditionTemplate(),
                        "this->getInstruction()->getOperand(" + Twine(opIndex) +
                            ")->getType()")
             << ")) {\n";
      // The statement is terminated with a newline so the closing brace of the
      // generated `if` ends up on its own line.
      OUT(6) << "return emitOpError(\"operand #" + Twine(opIndex)
             << (description.empty() ? " type precondition failed"
                                     : " must be " + Twine(description))
             << "\");\n";
      OUT(4) << "}\n";
    }
    ++opIndex;
  }

  // The custom verifier code, when present, provides the return statement;
  // otherwise the generated method returns false.
  if (hasCustomVerify)
    OUT(4) << codeInit->getValue() << "\n";
  else
    OUT(4) << "return false;\n";
  OUT(2) << "}\n";
}
void OpEmitter::emitTraits() {
auto numResults = op.getNumResults();
// Add return size trait.
switch (numResults) {
case 0:
os << ", OpTrait::ZeroResult";
break;
case 1:
os << ", OpTrait::OneResult";
break;
default:
os << ", OpTrait::NResults<" << numResults << ">::Impl";
break;
}
// Add variadic size trait and normal op traits.
for (StringRef trait : def.getValueAsListOfStrings("traits")) {
os << ", OpTrait::" << trait;
}
auto numOperands = op.getNumOperands();
bool hasVariadicOperand = op.hasVariadicOperand();
// Add operand size trait.
os << ", OpTrait::";
if (hasVariadicOperand) {
if (numOperands == 1)
os << "VariadicOperands";
else
os << "AtLeastNOperands<" << (numOperands - 1) << ">::Impl";
} else {
switch (op.getNumOperands()) {
case 0:
os << "ZeroOperands";
break;
case 1:
os << "OneOperand";
break;
default:
os << "NOperands<" << numOperands << ">::Impl";
break;
}
}
}
// Emits the class of every op in `defs`, with the whole output wrapped in a
// GET_OP_CLASSES ifdef/undef/endif region.
static void emitOpClasses(const std::vector<Record *> &defs, raw_ostream &os) {
  IfDefScope scope("GET_OP_CLASSES", os);
  for (auto it = defs.begin(), end = defs.end(); it != end; ++it) {
    const Record &opDef = **it;
    OpEmitter::emit(opDef, os);
  }
}
// Emits the qualified C++ class names of all ops in `defs`, separated by
// commas, inside a GET_OP_LIST ifdef/undef/endif region.
static void emitOpList(const std::vector<Record *> &defs, raw_ostream &os) {
  IfDefScope scope("GET_OP_LIST", os);

  // Empty before the first entry, a comma before every later one.
  StringRef separator = "";
  for (auto &def : defs) {
    os << separator << Operator(def).getQualCppClassName();
    separator = ",";
  }
}
static void emitOpDefinitions(const RecordKeeper &recordKeeper,
raw_ostream &os) {
emitSourceFileHeader("List of ops", os);
const auto &defs = recordKeeper.getAllDerivedDefinitions("Op");
emitOpList(defs, os);
emitOpClasses(defs, os);
}
static void emitOpDefFile(const RecordKeeper &recordKeeper, raw_ostream &os) {
emitSourceFileHeader("Op def file", os);
const auto &defs = recordKeeper.getAllDerivedDefinitions("Op");
os << "#ifndef ALL_OPS\n#define ALL_OPS(OP, NAME)\n#endif\n";
for (const auto *def : defs) {
os << "ALL_OPS(" << def->getName() << ", \""
<< def->getValueAsString("opName") << "\")\n";
}
os << "#undef ALL_OPS";
}
// Registers the op definitions generator under the name "gen-op-definitions".
// The callback forwards to emitOpDefinitions and then returns false.
// NOTE(review): false presumably signals success to the generator driver --
// confirm against mlir/TableGen/GenInfo.h.
static mlir::GenRegistration
genOpDefinitions("gen-op-definitions", "Generate op definitions",
[](const RecordKeeper &records, raw_ostream &os) {
emitOpDefinitions(records, os);
return false;
});