blob: 25d0adea7176216c0cbcb7a87f16ad554aee8460 [file] [log] [blame]
//===- SPIRVSerializationGen.cpp - SPIR-V serialization utility generator -===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// SPIRVSerializationGen generates common utility functions for SPIR-V
// serialization.
//
//===----------------------------------------------------------------------===//
#include "mlir/Support/StringExtras.h"
#include "mlir/TableGen/Attribute.h"
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Operator.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"
using llvm::ArrayRef;
using llvm::formatv;
using llvm::raw_ostream;
using llvm::raw_string_ostream;
using llvm::Record;
using llvm::RecordKeeper;
using llvm::SmallVector;
using llvm::SMLoc;
using llvm::StringMap;
using llvm::StringRef;
using llvm::Twine;
using mlir::tblgen::Attribute;
using mlir::tblgen::EnumAttr;
using mlir::tblgen::NamedAttribute;
using mlir::tblgen::NamedTypeConstraint;
using mlir::tblgen::Operator;
//===----------------------------------------------------------------------===//
// Serialization AutoGen
//===----------------------------------------------------------------------===//
// Writes the following function to `os`:
// inline uint32_t getOpcode(<op-class-name>) { return <opcode>; }
static void emitGetOpcodeFunction(const Record *record, Operator const &op,
raw_ostream &os) {
os << formatv("template <> constexpr inline ::mlir::spirv::Opcode "
"getOpcode<{0}>() {{\n",
op.getQualCppClassName());
os << formatv(" return ::mlir::spirv::Opcode::{0};\n",
record->getValueAsString("spirvOpName"));
os << "}\n";
}
/// Emits the primary template declaration of `getOpcode<OpClass>()`, which
/// yields the SPIR-V opcode of an op class. Explicit specializations are
/// emitted separately for every SPV_Op record whose `hasOpcode` bit is set.
static void declareOpcodeFn(raw_ostream &os) {
  const StringRef declaration =
      "template <typename OpClass> inline constexpr ::mlir::spirv::Opcode "
      "getOpcode();\n";
  os << declaration;
}
/// Generates code to serialize attributes of a SPV_Op `op` into `os`. The
/// generates code extracts the attribute with name `attrName` from
/// `operandList` of `op`.
static void emitAttributeSerialization(const Attribute &attr,
ArrayRef<SMLoc> loc, StringRef tabs,
StringRef opVar, StringRef operandList,
StringRef attrName, raw_ostream &os) {
os << tabs << formatv("auto attr = {0}.getAttr(\"{1}\");\n", opVar, attrName);
os << tabs << "if (attr) {\n";
if (attr.getAttrDefName() == "I32ArrayAttr") {
// Serialize all the elements of the array
os << tabs << " for (auto attrElem : attr.cast<ArrayAttr>()) {\n";
os << tabs
<< formatv(" {0}.push_back(static_cast<uint32_t>("
"attrElem.cast<IntegerAttr>().getValue().getZExtValue()));\n",
operandList);
os << tabs << " }\n";
} else if (attr.isEnumAttr() || attr.getAttrDefName() == "I32Attr") {
os << tabs
<< formatv(" {0}.push_back(static_cast<uint32_t>("
"attr.cast<IntegerAttr>().getValue().getZExtValue()));\n",
operandList);
} else {
PrintFatalError(
loc,
llvm::Twine(
"unhandled attribute type in SPIR-V serialization generation : '") +
attr.getAttrDefName() + llvm::Twine("'"));
}
os << tabs << "}\n";
}
/// Generates code that serializes the arguments of SPV_Op `op` (held in
/// `opVar`) into `os`, prefixing emitted lines with `tabs`. Each argument is
/// emitted inside its own `{ ... }` scope so that generated locals (e.g.
/// `attr`) do not clash between arguments.
///  - For an operand (SSA value), the generated code looks up each value's
///    <id> with `getValueID` and appends it to the vector named `operands`.
///    If a value has no <id> yet (a use before its definition), the generated
///    code returns an error from the enclosing function instead of
///    serializing a 0 <id>.
///  - For an attribute, the generated code appends its serialized words to
///    `operands` and records the attribute name in the vector named
///    `elidedAttrs`.
/// `loc` is used to report unsupported attribute kinds.
static void emitOperandSerialization(const Operator &op, ArrayRef<SMLoc> loc,
                                     StringRef tabs, StringRef opVar,
                                     StringRef operands, StringRef elidedAttrs,
                                     raw_ostream &os) {
  // Index of the ODS operand group; counts value arguments only, not
  // attributes.
  unsigned operandNum = 0;
  for (unsigned i = 0, e = op.getNumArgs(); i < e; ++i) {
    auto argument = op.getArg(i);
    os << tabs << "{\n";
    if (argument.is<NamedTypeConstraint *>()) {
      os << tabs
         << formatv(" for (auto arg : {0}.getODSOperands({1})) {{\n", opVar,
                    operandNum);
      os << tabs << " auto argID = getValueID(arg);\n";
      os << tabs << " if (!argID) {\n";
      // Abort serialization: continuing would push an invalid 0 <id>.
      os << tabs
         << formatv(
                " return emitError({0}.getLoc(), \"operand {1} has a use before "
                "def\");\n",
                opVar, operandNum);
      os << tabs << " }\n";
      os << tabs << formatv(" {0}.push_back(argID);\n", operands);
      // Close the generated `for` loop at the same indentation it was opened.
      os << tabs << " }\n";
      ++operandNum;
    } else {
      auto attr = argument.get<NamedAttribute *>();
      auto newtabs = tabs.str() + " ";
      emitAttributeSerialization(
          (attr->attr.isOptional() ? attr->attr.getBaseAttr() : attr->attr),
          loc, newtabs, opVar, operands, attr->name, os);
      // Attributes encoded as operands must not be re-emitted as decorations.
      os << newtabs
         << formatv("{0}.push_back(\"{1}\");\n", elidedAttrs, attr->name);
    }
    os << tabs << "}\n";
  }
}
/// Generates code to serializes the result of SPV_Op `op` into `os`. The
/// generated gets the ID for the type of the result (if any), the SSA-ID of
/// the result and updates `resultID` with the SSA-ID.
static void emitResultSerialization(const Operator &op, ArrayRef<SMLoc> loc,
StringRef tabs, StringRef opVar,
StringRef operands, StringRef resultID,
raw_ostream &os) {
if (op.getNumResults() == 1) {
StringRef resultTypeID("resultTypeID");
os << tabs << formatv("uint32_t {0} = 0;\n", resultTypeID);
os << tabs
<< formatv(
"if (failed(processType({0}.getLoc(), {0}.getType(), {1}))) {{\n",
opVar, resultTypeID);
os << tabs << " return failure();\n";
os << tabs << "}\n";
os << tabs << formatv("{0}.push_back({1});\n", operands, resultTypeID);
// Create an SSA result <id> for the op
os << tabs << formatv("{0} = getNextID();\n", resultID);
os << tabs
<< formatv("valueIDMap[{0}.getResult()] = {1};\n", opVar, resultID);
os << tabs << formatv("{0}.push_back({1});\n", operands, resultID);
} else if (op.getNumResults() != 0) {
PrintFatalError(loc, "SPIR-V ops can only have zero or one result");
}
}
/// Generates code to serialize attributes of SPV_Op `op` that become
/// decorations on the `resultID` of the serialized operation `opVar` in the
/// SPIR-V binary.
static void emitDecorationSerialization(const Operator &op, StringRef tabs,
StringRef opVar, StringRef elidedAttrs,
StringRef resultID, raw_ostream &os) {
if (op.getNumResults() == 1) {
// All non-argument attributes translated into OpDecorate instruction
os << tabs << formatv("for (auto attr : {0}.getAttrs()) {{\n", opVar);
os << tabs
<< formatv(" if (llvm::any_of({0}, [&](StringRef elided)", elidedAttrs);
os << " {return attr.first.is(elided);})) {\n";
os << tabs << " continue;\n";
os << tabs << " }\n";
os << tabs
<< formatv(
" if (failed(processDecoration({0}.getLoc(), {1}, attr))) {{\n",
opVar, resultID);
os << tabs << " return failure();\n";
os << tabs << " }\n";
os << tabs << "}\n";
}
}
/// Generates code to serialize an SPV_Op `op` into `os`.
static void emitSerializationFunction(const Record *attrClass,
const Record *record, const Operator &op,
raw_ostream &os) {
// If the record has 'autogenSerialization' set to 0, nothing to do
if (!record->getValueAsBit("autogenSerialization")) {
return;
}
StringRef opVar("op"), operands("operands"), elidedAttrs("elidedAttrs"),
resultID("resultID");
os << formatv(
"template <> LogicalResult\nSerializer::processOp<{0}>({0} {1}) {{\n",
op.getQualCppClassName(), opVar);
os << formatv(" SmallVector<uint32_t, 4> {0};\n", operands);
os << formatv(" SmallVector<StringRef, 2> {0};\n", elidedAttrs);
// Serialize result information.
if (op.getNumResults() == 1) {
os << formatv(" uint32_t {0} = 0;\n", resultID);
emitResultSerialization(op, record->getLoc(), " ", opVar, operands,
resultID, os);
}
// Process arguments.
emitOperandSerialization(op, record->getLoc(), " ", opVar, operands,
elidedAttrs, os);
if (record->isSubClassOf("SPV_ExtInstOp")) {
os << formatv(" encodeExtensionInstruction({0}, \"{1}\", {2}, {3});\n",
opVar, record->getValueAsString("extendedInstSetName"),
record->getValueAsInt("extendedInstOpcode"), operands);
} else {
os << formatv(" encodeInstructionInto("
"functions, spirv::getOpcode<{0}>(), {1});\n",
op.getQualCppClassName(), operands);
}
// Process decorations.
emitDecorationSerialization(op, " ", opVar, elidedAttrs, resultID, os);
os << " return success();\n";
os << "}\n\n";
}
/// Emits the opening of `Serializer::dispatchToAutogenSerialization`, naming
/// its `Operation *` parameter `opVar`. The opening brace is followed by a
/// newline and a single space. The caller then appends the dispatch chain and
/// the closing part.
static void initDispatchSerializationFn(StringRef opVar, raw_ostream &os) {
  os << "LogicalResult Serializer::dispatchToAutogenSerialization(Operation *"
     << opVar << ") {\n ";
}
/// Emits one link of the generated `if`/`else` dispatch chain. When the
/// operation named `opVar` is an instance of `op`'s class, the generated code
/// casts it and forwards it to the matching `processOp` overload. The emitted
/// text ends with a dangling `else`, which the following emission completes.
static void emitSerializationDispatch(const Operator &op, StringRef tabs,
                                      StringRef opVar, raw_ostream &os) {
  auto qualName = op.getQualCppClassName();
  os << tabs << "if (isa<" << qualName << ">(" << opVar << ")) {\n";
  os << tabs << " return processOp(cast<" << qualName << ">(" << opVar
     << "));\n";
  os << tabs << "} else";
}
/// Emits the closing part of `Serializer::dispatchToAutogenSerialization`.
/// First comes a `{ ... }` block that reports the operation named `opVar` as
/// unhandled. It completes any trailing `else` left by the dispatch chain, or
/// is a plain block if no chain was emitted. After it come
/// `return success();` and the function's closing brace.
static void finalizeDispatchSerializationFn(StringRef opVar, raw_ostream &os) {
  os << " {\n"
     << " return " << opVar
     << "->emitError(\"unhandled operation serialization\");\n"
     << " }\n"
     << " return success();\n"
     << "}\n\n";
}
/// Generates code to deserialize the attribute of a SPV_Op into `os`. The
/// generated code reads the `words` of the serialized instruction at
/// position `wordIndex` and adds the deserialized attribute into `attrList`.
static void emitAttributeDeserialization(const Attribute &attr,
ArrayRef<SMLoc> loc, StringRef tabs,
StringRef attrList, StringRef attrName,
StringRef words, StringRef wordIndex,
raw_ostream &os) {
if (attr.getAttrDefName() == "I32ArrayAttr") {
os << tabs << "SmallVector<Attribute, 4> attrListElems;\n";
os << tabs << formatv("while ({0} < {1}.size()) {{\n", wordIndex, words);
os << tabs
<< formatv(
" "
"attrListElems.push_back(opBuilder.getI32IntegerAttr({0}[{1}++]))"
";\n",
words, wordIndex);
os << tabs << "}\n";
os << tabs
<< formatv("{0}.push_back(opBuilder.getNamedAttr(\"{1}\", "
"opBuilder.getArrayAttr(attrListElems)));\n",
attrList, attrName);
} else if (attr.isEnumAttr() || attr.getAttrDefName() == "I32Attr") {
os << tabs
<< formatv("{0}.push_back(opBuilder.getNamedAttr(\"{1}\", "
"opBuilder.getI32IntegerAttr({2}[{3}++])));\n",
attrList, attrName, words, wordIndex);
} else {
PrintFatalError(
loc, llvm::Twine(
"unhandled attribute type in deserialization generation : '") +
attr.getAttrDefName() + llvm::Twine("'"));
}
}
/// Generates the code to deserialize the result of an SPV_Op `op` into
/// `os`. The generated code gets the type of the result specified at
/// `words`[`wordIndex`], the SSA ID for the result at position `wordIndex` + 1
/// and updates the `resultType` and `valueID` with the parsed type and SSA ID,
/// respectively.
static void emitResultDeserialization(const Operator &op, ArrayRef<SMLoc> loc,
StringRef tabs, StringRef words,
StringRef wordIndex,
StringRef resultTypes, StringRef valueID,
raw_ostream &os) {
// Deserialize result information if it exists
if (op.getNumResults() == 1) {
os << tabs << "{\n";
os << tabs << formatv(" if ({0} >= {1}.size()) {{\n", wordIndex, words);
os << tabs
<< formatv(
" return emitError(unknownLoc, \"expected result type <id> "
"while deserializing {0}\");\n",
op.getQualCppClassName());
os << tabs << " }\n";
os << tabs << formatv(" auto ty = getType({0}[{1}]);\n", words, wordIndex);
os << tabs << " if (!ty) {\n";
os << tabs
<< formatv(
" return emitError(unknownLoc, \"unknown type result <id> : "
"\") << {0}[{1}];\n",
words, wordIndex);
os << tabs << " }\n";
os << tabs << formatv(" {0}.push_back(ty);\n", resultTypes);
os << tabs << formatv(" {0}++;\n", wordIndex);
os << tabs << formatv(" if ({0} >= {1}.size()) {{\n", wordIndex, words);
os << tabs
<< formatv(
" return emitError(unknownLoc, \"expected result <id> while "
"deserializing {0}\");\n",
op.getQualCppClassName());
os << tabs << " }\n";
os << tabs << "}\n";
os << tabs << formatv("{0} = {1}[{2}++];\n", valueID, words, wordIndex);
} else if (op.getNumResults() != 0) {
PrintFatalError(loc, "SPIR-V ops can have only zero or one result");
}
}
/// Generates the code to deserialize the operands of an SPV_Op `op` into
/// `os`. The generated code reads the `words` of the binary instruction, from
/// position `wordIndex` to the end, and either gets the Value corresponding to
/// the ID encoded, or deserializes the attributes encoded. The parsed
/// operand(attribute) is added to the `operands` list or `attributes` list.
static void emitOperandDeserialization(const Operator &op, ArrayRef<SMLoc> loc,
StringRef tabs, StringRef words,
StringRef wordIndex, StringRef operands,
StringRef attributes, raw_ostream &os) {
// Process operands/attributes
unsigned operandNum = 0;
for (unsigned i = 0, e = op.getNumArgs(); i < e; ++i) {
auto argument = op.getArg(i);
if (auto valueArg = argument.dyn_cast<NamedTypeConstraint *>()) {
if (valueArg->isVariadic()) {
if (i != e - 1) {
PrintFatalError(loc,
"SPIR-V ops can have Variadic<..> argument only if "
"it's the last argument");
}
os << tabs
<< formatv("for (; {0} < {1}.size(); ++{0})", wordIndex, words);
} else {
os << tabs << formatv("if ({0} < {1}.size())", wordIndex, words);
}
os << " {\n";
os << tabs
<< formatv(" auto arg = getValue({0}[{1}]);\n", words, wordIndex);
os << tabs << " if (!arg) {\n";
os << tabs
<< formatv(
" return emitError(unknownLoc, \"unknown result <id> : \") "
"<< {0}[{1}];\n",
words, wordIndex);
os << tabs << " }\n";
os << tabs << formatv(" {0}.push_back(arg);\n", operands);
if (!valueArg->isVariadic()) {
os << tabs << formatv(" {0}++;\n", wordIndex);
}
operandNum++;
os << tabs << "}\n";
} else {
os << tabs << formatv("if ({0} < {1}.size()) {{\n", wordIndex, words);
auto attr = argument.get<NamedAttribute *>();
auto newtabs = tabs.str() + " ";
emitAttributeDeserialization(
(attr->attr.isOptional() ? attr->attr.getBaseAttr() : attr->attr),
loc, newtabs, attributes, attr->name, words, wordIndex, os);
os << " }\n";
}
}
os << tabs << formatv("if ({0} != {1}.size()) {{\n", wordIndex, words);
os << tabs
<< formatv(
" return emitError(unknownLoc, \"found more operands than "
"expected when deserializing {0}, only \") << {1} << \" of \" << "
"{2}.size() << \" processed\";\n",
op.getQualCppClassName(), wordIndex, words);
os << tabs << "}\n\n";
}
/// Emits code that looks up the decorations recorded in `decorations` for the
/// <id> held in the variable named `valueID`. If any exist, their attributes
/// are appended to the vector named `attributes`. Emitted lines are prefixed
/// with `tabs`. Code is generated only when `op` has exactly one result.
static void emitDecorationDeserialization(const Operator &op, StringRef tabs,
                                          StringRef valueID,
                                          StringRef attributes,
                                          raw_ostream &os) {
  if (op.getNumResults() != 1)
    return;

  os << tabs << "if (decorations.count(" << valueID << ")) {\n";
  os << tabs << " auto attrs = decorations[" << valueID << "].getAttrs();\n";
  os << tabs << " " << attributes << ".append(attrs.begin(), attrs.end());\n";
  os << tabs << "}\n";
}
/// Generates `Deserializer::processOp<OpClass>(ArrayRef<uint32_t> words)` for
/// SPV_Op `op` (described by `record`) into `os`. Nothing is emitted when the
/// record's `autogenSerialization` bit is 0. The generated function:
///   1. decodes the result type and result <id>, if there is a result;
///   2. decodes the operands and attributes;
///   3. merges in any decoration attributes recorded for the result <id>;
///   4. creates the op with `opBuilder` and maps the result <id> to the op's
///      result value.
/// `attrClass` is not used here.
static void emitDeserializationFunction(const Record *attrClass,
                                        const Record *record,
                                        const Operator &op, raw_ostream &os) {
  // If the record has 'autogenSerialization' set to 0, nothing to do.
  if (!record->getValueAsBit("autogenSerialization")) {
    return;
  }
  StringRef resultTypes("resultTypes"), valueID("valueID"), words("words"),
      wordIndex("wordIndex"), opVar("op"), operands("operands"),
      attributes("attributes");
  os << formatv("template <> "
                "LogicalResult\nDeserializer::processOp<{0}>(ArrayRef<"
                "uint32_t> {1}) {{\n",
                op.getQualCppClassName(), words);
  os << formatv(" SmallVector<Type, 1> {0};\n", resultTypes);
  os << formatv(" size_t {0} = 0; (void){0};\n", wordIndex);
  os << formatv(" uint32_t {0} = 0; (void){0};\n", valueID);
  // Deserialize result information.
  emitResultDeserialization(op, record->getLoc(), " ", words, wordIndex,
                            resultTypes, valueID, os);
  os << formatv(" SmallVector<Value *, 4> {0};\n", operands);
  os << formatv(" SmallVector<NamedAttribute, 4> {0};\n", attributes);
  // Operand deserialization.
  emitOperandDeserialization(op, record->getLoc(), " ", words, wordIndex,
                             operands, attributes, os);
  // Decorations must be merged into `attributes` BEFORE the op is created.
  // The op is built from the attribute list, so anything appended afterwards
  // would never reach it.
  emitDecorationDeserialization(op, " ", valueID, attributes, os);
  os << formatv(
      " auto {1} = opBuilder.create<{0}>(unknownLoc, {2}, {3}, {4}); "
      "(void){1};\n",
      op.getQualCppClassName(), opVar, resultTypes, operands, attributes);
  if (op.getNumResults() == 1) {
    os << formatv(" valueMap[{0}] = {1}.getResult();\n\n", valueID, opVar);
  }
  os << " return success();\n";
  os << "}\n\n";
}
/// Emits the opening of `Deserializer::dispatchToAutogenDeserialization`,
/// whose opcode and word-array parameters are named `opcode` and `words`. It
/// also opens a `switch` on the opcode, which the per-op case labels and the
/// closing part complete.
static void initDispatchDeserializationFn(StringRef opcode, StringRef words,
                                          raw_ostream &os) {
  os << "LogicalResult "
     << "Deserializer::dispatchToAutogenDeserialization(spirv::Opcode "
     << opcode << ", ArrayRef<uint32_t> " << words << ") {\n";
  os << " switch (" << opcode << ") {\n";
}
/// Generates the body of the dispatch function, by generating the case label
/// for an opcode and the call to the method to perform the deserialization.
static void emitDeserializationDispatch(const Operator &op, const Record *def,
StringRef tabs, StringRef words,
raw_ostream &os) {
os << tabs
<< formatv("case spirv::Opcode::{0}:\n",
def->getValueAsString("spirvOpName"));
os << tabs
<< formatv(" return processOp<{0}>({1});\n", op.getQualCppClassName(),
words);
}
/// Emits the closing part of `Deserializer::dispatchToAutogenDeserialization`.
/// The opcode switch gets an empty `default`, so unmatched opcodes fall
/// through to an error. That error names the opcode when
/// `spirv::stringifyOpcode` knows it, and prints the raw numeric value
/// otherwise.
static void finalizeDispatchDeserializationFn(StringRef opcode,
                                              raw_ostream &os) {
  os << " default:\n"
     << " ;\n"
     << " }\n";

  const StringRef opcodeStr = "opcodeString";
  os << " auto " << opcodeStr << " = spirv::stringifyOpcode(" << opcode
     << ");\n";
  os << " if (!" << opcodeStr << ".empty()) {\n";
  os << " return emitError(unknownLoc, \"unhandled deserialization of \") << "
     << opcodeStr << ";\n";
  os << " } else {\n";
  os << " return emitError(unknownLoc, \"unhandled opcode \") << "
        "static_cast<uint32_t>("
     << opcode << ");\n";
  os << " }\n";
  os << "}\n";
}
/// Emits the signature and opening brace of
/// `Deserializer::dispatchToExtensionSetAutogenDeserialization`. Its
/// parameters are named `extensionSetName`, `instructionID` and `words`.
static void initExtendedSetDeserializationDispatch(StringRef extensionSetName,
                                                   StringRef instructionID,
                                                   StringRef words,
                                                   raw_ostream &os) {
  os << "LogicalResult "
     << "Deserializer::dispatchToExtensionSetAutogenDeserialization("
     << "StringRef " << extensionSetName << ", uint32_t " << instructionID
     << ", ArrayRef<uint32_t> " << words << ") {\n";
}
static void
emitExtendedSetDeserializationDispatch(const RecordKeeper &recordKeeper,
raw_ostream &os) {
StringRef extensionSetName("extensionSetName"),
instructionID("instructionID"), words("words");
// First iterate over all ops derived from SPV_ExtensionSetOps to get all
// extensionSets.
// For each of the extensions a separate raw_string_ostream is used to
// generate code into. These are then concatenated at the end. Since
// raw_string_ostream needs a string&, use a vector to store all the string
// that are captured by reference within raw_string_ostream.
StringMap<raw_string_ostream> extensionSets;
SmallVector<std::string, 1> extensionSetNames;
initExtendedSetDeserializationDispatch(extensionSetName, instructionID, words,
os);
auto defs = recordKeeper.getAllDerivedDefinitions("SPV_ExtInstOp");
for (const auto *def : defs) {
if (!def->getValueAsBit("autogenSerialization")) {
continue;
}
Operator op(def);
auto setName = def->getValueAsString("extendedInstSetName");
if (!extensionSets.count(setName)) {
extensionSetNames.push_back("");
extensionSets.try_emplace(setName, extensionSetNames.back());
auto &setos = extensionSets.find(setName)->second;
setos << formatv(" if ({0} == \"{1}\") {{\n", extensionSetName, setName);
setos << formatv(" switch ({0}) {{\n", instructionID);
}
auto &setos = extensionSets.find(setName)->second;
setos << formatv(" case {0}:\n",
def->getValueAsInt("extendedInstOpcode"));
setos << formatv(" return processOp<{0}>({1});\n",
op.getQualCppClassName(), words);
}
// Append the dispatch code for all the extended sets.
for (auto &extensionSet : extensionSets) {
os << extensionSet.second.str();
os << " default:\n";
os << formatv(
" return emitError(unknownLoc, \"unhandled deserializations of "
"\") << {0} << \" from extension set \" << {1};\n",
instructionID, extensionSetName);
os << " }\n";
os << " }\n";
}
os << formatv(" return emitError(unknownLoc, \"unhandled deserialization of "
"extended instruction set {0}\");\n",
extensionSetName);
os << "}\n";
}
/// Emits all the autogenerated serialization/deserializations functions for the
/// SPV_Ops.
static bool emitSerializationFns(const RecordKeeper &recordKeeper,
raw_ostream &os) {
llvm::emitSourceFileHeader("SPIR-V Serialization Utilities/Functions", os);
std::string dSerFnString, dDesFnString, serFnString, deserFnString,
utilsString;
raw_string_ostream dSerFn(dSerFnString), dDesFn(dDesFnString),
serFn(serFnString), deserFn(deserFnString), utils(utilsString);
auto attrClass = recordKeeper.getClass("Attr");
// Emit the serialization and deserialization functions simultaneously.
declareOpcodeFn(utils);
StringRef opVar("op");
StringRef opcode("opcode"), words("words");
// Handle the SPIR-V ops.
initDispatchSerializationFn(opVar, dSerFn);
initDispatchDeserializationFn(opcode, words, dDesFn);
auto defs = recordKeeper.getAllDerivedDefinitions("SPV_Op");
for (const auto *def : defs) {
Operator op(def);
emitSerializationFunction(attrClass, def, op, serFn);
emitDeserializationFunction(attrClass, def, op, deserFn);
if (def->getValueAsBit("hasOpcode") || def->isSubClassOf("SPV_ExtInstOp")) {
emitSerializationDispatch(op, " ", opVar, dSerFn);
}
if (def->getValueAsBit("hasOpcode")) {
emitGetOpcodeFunction(def, op, utils);
emitDeserializationDispatch(op, def, " ", words, dDesFn);
}
}
finalizeDispatchSerializationFn(opVar, dSerFn);
finalizeDispatchDeserializationFn(opcode, dDesFn);
emitExtendedSetDeserializationDispatch(recordKeeper, dDesFn);
os << "#ifdef GET_SPIRV_SERIALIZATION_UTILS\n";
os << utils.str();
os << "#endif // GET_SPIRV_SERIALIZATION_UTILS\n\n";
os << "#ifdef GET_SERIALIZATION_FNS\n\n";
os << serFn.str();
os << dSerFn.str();
os << "#endif // GET_SERIALIZATION_FNS\n\n";
os << "#ifdef GET_DESERIALIZATION_FNS\n\n";
os << deserFn.str();
os << dDesFn.str();
os << "#endif // GET_DESERIALIZATION_FNS\n\n";
return false;
}
//===----------------------------------------------------------------------===//
// Op Utils AutoGen
//===----------------------------------------------------------------------===//
/// Emits the primary template declaration of `attributeName<EnumClass>()`.
/// A specialization is emitted later for each enum attribute.
static void emitEnumGetAttrNameFnDecl(raw_ostream &os) {
  os << "template <typename EnumClass> inline constexpr StringRef "
        "attributeName();\n";
}
/// Emits two declarations:
///  - the `SymbolizeFnTy<EnumClass>` alias, a pointer to a function mapping a
///    StringRef to an optional enum value;
///  - the primary template `symbolizeEnum<EnumClass>()`, which returns such a
///    pointer.
static void emitEnumGetSymbolizeFnDecl(raw_ostream &os) {
  const char *fnTypeAlias =
      "template <typename EnumClass> using SymbolizeFnTy = "
      "llvm::Optional<EnumClass> (*)(StringRef);\n";
  const char *getterDecl = "template <typename EnumClass> inline constexpr "
                           "SymbolizeFnTy<EnumClass> symbolizeEnum();\n";
  os << fnTypeAlias << getterDecl;
}
/// Emits the `attributeName<EnumClass>()` specialization for `enumAttr`. It
/// returns the snake_case form of the enum class name.
static void emitEnumGetAttrNameFnDefn(const EnumAttr &enumAttr,
                                      raw_ostream &os) {
  auto enumName = enumAttr.getEnumClassName();
  auto snakeName = mlir::convertToSnakeCase(enumName);
  os << "template <> inline StringRef attributeName<" << enumName << ">() {\n"
     << " static constexpr const char attrName[] = \"" << snakeName << "\";\n"
     << " return attrName;\n"
     << "}\n";
}
/// Emits the `symbolizeEnum<EnumClass>()` specialization for `enumAttr`. It
/// returns the enum's string-to-symbol conversion function.
static void emitEnumGetSymbolizeFnDefn(const EnumAttr &enumAttr,
                                       raw_ostream &os) {
  auto enumName = enumAttr.getEnumClassName();
  os << "template <> inline SymbolizeFnTy<" << enumName << "> symbolizeEnum<"
     << enumName << ">() {\n"
     << " return " << enumAttr.getStringToSymbolFnName() << ";\n"
     << "}\n";
}
/// Emits the SPIR-V op utility header into `os`, inside an include guard. The
/// header holds the `attributeName` and `symbolizeEnum` primary template
/// declarations, followed by their specializations for every EnumAttrInfo
/// record. Returns false unconditionally.
static bool emitOpUtils(const RecordKeeper &recordKeeper, raw_ostream &os) {
  llvm::emitSourceFileHeader("SPIR-V Op Utilities", os);
  auto defs = recordKeeper.getAllDerivedDefinitions("EnumAttrInfo");
  os << "#ifndef SPIRV_OP_UTILS_H_\n";
  os << "#define SPIRV_OP_UTILS_H_\n";
  emitEnumGetAttrNameFnDecl(os);
  emitEnumGetSymbolizeFnDecl(os);
  for (const auto *def : defs) {
    EnumAttr enumAttr(*def);
    emitEnumGetAttrNameFnDefn(enumAttr, os);
    emitEnumGetSymbolizeFnDefn(enumAttr, os);
  }
  // The trailing comment names the same macro as the #ifndef/#define above.
  os << "#endif // SPIRV_OP_UTILS_H_\n";
  return false;
}
//===----------------------------------------------------------------------===//
// Hook Registration
//===----------------------------------------------------------------------===//
// Registers the (de)serialization generator under `gen-spirv-serialization`.
static mlir::GenRegistration genSerialization(
    "gen-spirv-serialization",
    "Generate SPIR-V (de)serialization utilities and functions",
    [](const RecordKeeper &recordKeeper, raw_ostream &out) {
      return emitSerializationFns(recordKeeper, out);
    });
// Registers the op-utility generator under `gen-spirv-op-utils`.
static mlir::GenRegistration genOpUtils(
    "gen-spirv-op-utils", "Generate SPIR-V operation utility definitions",
    [](const RecordKeeper &recordKeeper, raw_ostream &out) {
      return emitOpUtils(recordKeeper, out);
    });