blob: ab86c9dd8cc448e89d460c874cb9597f204e387f [file] [log] [blame]
//===- EnumsGen.cpp - MLIR enum utility generator -------------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// EnumsGen generates common utility functions for enums.
//
//===----------------------------------------------------------------------===//
#include "mlir/TableGen/Attribute.h"
#include "mlir/TableGen/GenInfo.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"
using llvm::formatv;
using llvm::raw_ostream;
using llvm::Record;
using llvm::RecordKeeper;
using llvm::StringRef;
using mlir::tblgen::EnumAttr;
using mlir::tblgen::EnumAttrCase;
/// Writes a C++ `enum class` definition to `os`.
///
/// The output consists of a `// <description>` comment line, the enum head
/// (with ` : <underlyingType>` appended only when `underlyingType` is
/// non-empty), one `<symbol> = <value>,` line per entry of `enumerants`, and
/// the closing `};` followed by a blank line.
///
/// A fatal error is reported at `enumDef`'s source location as soon as an
/// enumerant with a negative value is encountered.
static void emitEnumClass(const Record &enumDef, StringRef enumName,
                          StringRef underlyingType, StringRef description,
                          const std::vector<EnumAttrCase> &enumerants,
                          raw_ostream &os) {
  // Leading comment and enum head.
  os << "// " << description << "\n"
     << "enum class " << enumName;

  // The storage type is optional; leave it off entirely when unspecified.
  const bool hasExplicitStorage = !underlyingType.empty();
  if (hasExplicitStorage)
    os << " : " << underlyingType;
  os << " {\n";

  // One enumerator per case, each with its explicit value.
  for (const auto &enumCase : enumerants) {
    auto caseSymbol = enumCase.getSymbol();
    auto caseValue = enumCase.getValue();
    if (caseValue < 0)
      llvm::PrintFatalError(enumDef.getLoc(),
                            "all enumerants must have a non-negative value");
    os << formatv(" {0} = {1},\n", caseSymbol, caseValue);
  }

  os << "};\n\n";
}
/// Writes an `llvm::DenseMapInfo` specialization for the enum
/// `<cppNamespace>::<enumName>` to `os`, followed by a blank line.
///
/// The emitted specialization forwards to `llvm::DenseMapInfo` of the enum's
/// storage type for the empty key, tombstone key and hash value, casting
/// between the enum and the storage type as needed. When `underlyingType` is
/// empty, `std::underlying_type<...>::type` of the qualified enum name is used
/// as the storage type instead.
static void emitDenseMapInfo(StringRef enumName, std::string underlyingType,
                             StringRef cppNamespace, raw_ostream &os) {
  std::string qualifiedEnum = formatv("{0}::{1}", cppNamespace, enumName);

  // No explicit storage type given: let the generated code query it.
  const bool needsDeducedStorage = underlyingType.empty();
  if (needsDeducedStorage)
    underlyingType = formatv("std::underlying_type<{0}>::type", qualifiedEnum);

  // `{0}` is the qualified enum name and `{1}` the storage type. `{{` is the
  // formatv escape for a literal opening brace.
  const char *const specializationTemplate = R"(
namespace llvm {
template<> struct DenseMapInfo<{0}> {{
using StorageInfo = llvm::DenseMapInfo<{1}>;
static inline {0} getEmptyKey() {{
return static_cast<{0}>(StorageInfo::getEmptyKey());
}
static inline {0} getTombstoneKey() {{
return static_cast<{0}>(StorageInfo::getTombstoneKey());
}
static unsigned getHashValue(const {0} &val) {{
return StorageInfo::getHashValue(static_cast<{1}>(val));
}
static bool isEqual(const {0} &lhs, const {0} &rhs) {{
return lhs == rhs;
}
};
})";

  os << formatv(specializationTemplate, qualifiedEnum, underlyingType)
     << "\n\n";
}
static void emitEnumDecl(const Record &enumDef, raw_ostream &os) {
EnumAttr enumAttr(enumDef);
StringRef enumName = enumAttr.getEnumClassName();
StringRef cppNamespace = enumAttr.getCppNamespace();
std::string underlyingType = enumAttr.getUnderlyingType();
StringRef description = enumAttr.getDescription();
StringRef strToSymFnName = enumAttr.getStringToSymbolFnName();
StringRef symToStrFnName = enumAttr.getSymbolToStringFnName();
auto enumerants = enumAttr.getAllCases();
llvm::SmallVector<StringRef, 2> namespaces;
llvm::SplitString(cppNamespace, namespaces, "::");
for (auto ns : namespaces)
os << "namespace " << ns << " {\n";
// Emit the enum class definition
emitEnumClass(enumDef, enumName, underlyingType, description, enumerants, os);
// Emit coversion function declarations
os << formatv("llvm::StringRef {1}({0});\n", enumName, symToStrFnName);
os << formatv("llvm::Optional<{0}> {1}(llvm::StringRef);\n", enumName,
strToSymFnName);
for (auto ns : llvm::reverse(namespaces))
os << "} // namespace " << ns << "\n";
// Emit DenseMapInfo for this enum class
emitDenseMapInfo(enumName, underlyingType, cppNamespace, os);
}
static bool emitEnumDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
llvm::emitSourceFileHeader("Enum Utility Declarations", os);
auto defs = recordKeeper.getAllDerivedDefinitions("EnumAttr");
for (const auto *def : defs)
emitEnumDecl(*def, os);
return false;
}
static void emitEnumDef(const Record &enumDef, raw_ostream &os) {
EnumAttr enumAttr(enumDef);
StringRef enumName = enumAttr.getEnumClassName();
StringRef cppNamespace = enumAttr.getCppNamespace();
StringRef strToSymFnName = enumAttr.getStringToSymbolFnName();
StringRef symToStrFnName = enumAttr.getSymbolToStringFnName();
auto enumerants = enumAttr.getAllCases();
llvm::SmallVector<StringRef, 2> namespaces;
llvm::SplitString(cppNamespace, namespaces, "::");
for (auto ns : namespaces)
os << "namespace " << ns << " {\n";
os << formatv("llvm::StringRef {1}({0} val) {{\n", enumName, symToStrFnName);
os << " switch (val) {\n";
for (const auto &enumerant : enumerants) {
auto symbol = enumerant.getSymbol();
os << formatv(" case {0}::{1}: return \"{1}\";\n", enumName, symbol);
}
os << " }\n";
os << " return \"\";\n";
os << "}\n\n";
os << formatv("llvm::Optional<{0}> {1}(llvm::StringRef str) {{\n", enumName,
strToSymFnName);
os << formatv(" return llvm::StringSwitch<llvm::Optional<{0}>>(str)\n",
enumName);
for (const auto &enumerant : enumerants) {
auto symbol = enumerant.getSymbol();
os << formatv(" .Case(\"{1}\", {0}::{1})\n", enumName, symbol);
}
os << " .Default(llvm::None);\n";
os << "}\n";
for (auto ns : llvm::reverse(namespaces))
os << "} // namespace " << ns << "\n";
os << "\n";
}
static bool emitEnumDefs(const RecordKeeper &recordKeeper, raw_ostream &os) {
llvm::emitSourceFileHeader("Enum Utility Definitions", os);
auto defs = recordKeeper.getAllDerivedDefinitions("EnumAttr");
for (const auto *def : defs)
emitEnumDef(*def, os);
return false;
}
// Registers the enum utility *declaration* generator with mlir-tblgen under
// the name "gen-enum-decls"; the callback forwards to emitEnumDecls.
static mlir::GenRegistration
    genEnumDecls("gen-enum-decls", "Generate enum utility declarations",
                 [](const RecordKeeper &keeper, raw_ostream &out) -> bool {
                   return emitEnumDecls(keeper, out);
                 });
// Registers the enum utility *definition* generator with mlir-tblgen under
// the name "gen-enum-defs"; the callback forwards to emitEnumDefs.
static mlir::GenRegistration
    genEnumDefs("gen-enum-defs", "Generate enum utility definitions",
                [](const RecordKeeper &keeper, raw_ostream &out) -> bool {
                  return emitEnumDefs(keeper, out);
                });