blob: 575d995bf45657efbe8b97f92cc167fbddd9d993 [file] [log] [blame]
//===- Serializer.cpp - MLIR SPIR-V Serialization -------------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file defines the serialization of an MLIR SPIR-V module into a SPIR-V
// binary.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/SPIRV/Serialization.h"
#include "mlir/Dialect/SPIRV/SPIRVBinaryUtils.h"
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
#include "mlir/IR/Builders.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Support/StringExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/bit.h"
#include "llvm/Support/raw_ostream.h"
using namespace mlir;
/// Builds the leading word of a SPIR-V instruction: the total word count goes
/// into the upper 16 bits and the opcode enumerant into the lower 16 bits.
/// Asserts that `wordCount` fits in 16 bits.
static inline uint32_t getPrefixedOpcode(uint32_t wordCount,
                                         spirv::Opcode opcode) {
  assert(((wordCount >> 16) == 0) && "word count out of range!");
  const uint32_t opcodeBits = static_cast<uint32_t>(opcode);
  const uint32_t countBits = wordCount << 16;
  return countBits | opcodeBits;
}
/// Appends one SPIR-V instruction to `binary`. The instruction is a leading
/// word that packs the word count with `op`, followed by every word of
/// `operands` in order. This function always returns success.
static LogicalResult encodeInstructionInto(SmallVectorImpl<uint32_t> &binary,
                                           spirv::Opcode op,
                                           ArrayRef<uint32_t> operands) {
  // The word count includes the leading opcode word itself.
  const uint32_t totalWords = static_cast<uint32_t>(operands.size()) + 1;
  binary.push_back(getPrefixedOpcode(totalWords, op));
  // Appending an empty range is a no-op, so no emptiness check is needed.
  binary.append(operands.begin(), operands.end());
  return success();
}
/// Encodes an SPIR-V `literal` string into the given `binary` vector.
///
/// SPIR-V literal strings pack their UTF-8 octets four per word using the
/// little-endian convention: the first octet goes into the lowest-order 8 bits
/// of a word. The encoding always includes a null terminator, and any unused
/// bytes in the final word are zero. This function always returns success.
static LogicalResult encodeStringLiteralInto(SmallVectorImpl<uint32_t> &binary,
                                             StringRef literal) {
  // We need to encode the literal and the null termination:
  // `size / 4 + 1` == ceil((size + 1) / 4) words.
  auto encodingSize = literal.size() / 4 + 1;
  auto bufferStartSize = binary.size();
  // Zero-fill the new words; this supplies the terminator and the padding.
  binary.resize(bufferStartSize + encodingSize, 0);
  // Pack each byte explicitly rather than memcpy'ing, so the result does not
  // depend on the host's byte order.
  for (size_t i = 0, e = literal.size(); i < e; ++i) {
    auto byte = static_cast<uint32_t>(static_cast<unsigned char>(literal[i]));
    binary[bufferStartSize + i / 4] |= byte << (8 * (i % 4));
  }
  return success();
}
namespace {

/// A SPIR-V module serializer.
///
/// A SPIR-V binary module is a single linear stream of instructions; each
/// instruction is composed of 32-bit words with the layout:
///
///   | <word-count>|<opcode> |  <operand>   |  <operand>   | ... |
///   | <------ word -------> | <-- word --> | <-- word --> | ... |
///
/// For the first word, the 16 high-order bits are the word count of the
/// instruction, the 16 low-order bits are the opcode enumerant. The
/// instructions then belong to different sections, which must be laid out in
/// the particular order as specified in "2.4 Logical Layout of a Module" of
/// the SPIR-V spec.
class Serializer {
public:
  /// Creates a serializer for the given SPIR-V `module`.
  explicit Serializer(spirv::ModuleOp module);

  /// Serializes the remembered SPIR-V module.
  LogicalResult serialize();

  /// Collects the final SPIR-V `binary`.
  void collect(SmallVectorImpl<uint32_t> &binary);

private:
  // Note that there are two main categories of methods in this class:
  // * process*() methods are meant to fully serialize a SPIR-V module entity
  //   (header, type, op, etc.). They update internal vectors containing
  //   different binary sections. They are not meant to be called except by
  //   the top-level serialization loop.
  // * prepare*() methods are meant to be helpers that prepare for serializing
  //   a certain entity. They may or may not update internal vectors containing
  //   different binary sections. They are meant to be called among themselves
  //   or by other process*() methods for subtasks.

  //===--------------------------------------------------------------------===//
  // <id>
  //===--------------------------------------------------------------------===//

  // Note that it is illegal to use id <0> in SPIR-V binary module. Various
  // methods in this class, if using SPIR-V word (uint32_t) as interface,
  // check or return id <0> to indicate error in processing.

  /// Consumes the next unused <id>. This method will never return 0.
  uint32_t getNextID() { return nextID++; }

  //===--------------------------------------------------------------------===//
  // Module structure
  //===--------------------------------------------------------------------===//

  LogicalResult processMemoryModel();

  LogicalResult processConstantOp(spirv::ConstantOp op);

  /// Returns the <id> assigned to the function named `fnName`, or 0 if none.
  uint32_t findFunctionID(StringRef fnName) const {
    return funcIDMap.lookup(fnName);
  }

  /// Returns the <id> assigned to the global variable named `varName`, or 0
  /// if none.
  uint32_t findVariableID(StringRef varName) const {
    return globalVarIDMap.lookup(varName);
  }

  /// Emit OpName for the given `resultID`.
  LogicalResult processName(uint32_t resultID, StringRef name);

  /// Processes a SPIR-V function op.
  LogicalResult processFuncOp(FuncOp op);

  /// Processes a SPIR-V GlobalVariableOp.
  LogicalResult processGlobalVariableOp(spirv::GlobalVariableOp varOp);

  /// Processes attributes that translate to decorations on the result <id>.
  LogicalResult processDecoration(Location loc, uint32_t resultID,
                                  NamedAttribute attr);

  /// Emits decorations implied by `type` onto `resultId`. The generic version
  /// reports an error; supported types provide explicit specializations.
  template <typename DType>
  LogicalResult processTypeDecoration(Location loc, DType type,
                                      uint32_t resultId) {
    return emitError(loc, "unhandled decoration for type:") << type;
  }

  //===--------------------------------------------------------------------===//
  // Types
  //===--------------------------------------------------------------------===//

  /// Returns the <id> already assigned to `type`, or 0 if none.
  uint32_t findTypeID(Type type) const { return typeIDMap.lookup(type); }

  /// SPIR-V `void` is modelled as MLIR's NoneType.
  Type getVoidType() { return mlirBuilder.getNoneType(); }

  bool isVoidType(Type type) const { return type.isa<NoneType>(); }

  /// Main dispatch method for serializing a type. The result <id> of the
  /// serialized type will be returned as `typeID`.
  LogicalResult processType(Location loc, Type type, uint32_t &typeID);

  /// Method for preparing basic SPIR-V type serialization. Returns the type's
  /// opcode and operands for the instruction via `typeEnum` and `operands`.
  LogicalResult prepareBasicType(Location loc, Type type, uint32_t resultID,
                                 spirv::Opcode &typeEnum,
                                 SmallVectorImpl<uint32_t> &operands);

  LogicalResult prepareFunctionType(Location loc, FunctionType type,
                                    spirv::Opcode &typeEnum,
                                    SmallVectorImpl<uint32_t> &operands);

  //===--------------------------------------------------------------------===//
  // Constant
  //===--------------------------------------------------------------------===//

  /// Returns the <id> already assigned to constant `value`, or 0 if none.
  uint32_t findConstantID(Attribute value) const {
    return constIDMap.lookup(value);
  }

  /// Main dispatch method for processing a constant with the given `constType`
  /// and `valueAttr`. `constType` is needed here because we can interpret the
  /// `valueAttr` as a different type than the type of `valueAttr` itself; for
  /// example, ArrayAttr, whose type is NoneType, is used for spirv::ArrayType
  /// constants. If `isSpec` is true, then the constant will be serialized as
  /// a specialization constant.
  uint32_t prepareConstant(Location loc, Type constType, Attribute valueAttr,
                           bool isSpec);

  /// Prepares bool ElementsAttr serialization. This method updates `opcode`
  /// with a proper OpConstant* instruction and pushes literal values for the
  /// constant to `operands`.
  LogicalResult prepareBoolVectorConstant(Location loc,
                                          DenseIntElementsAttr elementsAttr,
                                          bool isSpec, spirv::Opcode &opcode,
                                          SmallVectorImpl<uint32_t> &operands);

  /// Prepares int ElementsAttr serialization. This method updates `opcode` with
  /// a proper OpConstant* instruction and pushes literal values for the
  /// constant to `operands`.
  LogicalResult prepareIntVectorConstant(Location loc,
                                         DenseIntElementsAttr elementsAttr,
                                         bool isSpec, spirv::Opcode &opcode,
                                         SmallVectorImpl<uint32_t> &operands);

  /// Prepares float ElementsAttr serialization. This method updates `opcode`
  /// with a proper OpConstant* instruction and pushes literal values for the
  /// constant to `operands`.
  LogicalResult prepareFloatVectorConstant(Location loc,
                                           DenseFPElementsAttr elementsAttr,
                                           bool isSpec, spirv::Opcode &opcode,
                                           SmallVectorImpl<uint32_t> &operands);

  uint32_t prepareConstantBool(Location loc, BoolAttr boolAttr, bool isSpec);

  uint32_t prepareConstantInt(Location loc, IntegerAttr intAttr, bool isSpec);

  uint32_t prepareConstantFp(Location loc, FloatAttr floatAttr, bool isSpec);

  //===--------------------------------------------------------------------===//
  // Operations
  //===--------------------------------------------------------------------===//

  /// Returns the <id> assigned to SSA value `val`, or 0 if none.
  uint32_t findValueID(Value *val) const { return valueIDMap.lookup(val); }

  /// Process spv.addressOf operations.
  LogicalResult processAddressOfOp(spirv::AddressOfOp addressOfOp);

  /// Main dispatch method for serializing an operation.
  LogicalResult processOperation(Operation *op);

  /// Method to dispatch to the serialization function for an operation in
  /// SPIR-V dialect that is a mirror of an instruction in the SPIR-V spec.
  /// This is auto-generated from ODS. Dispatch is handled for all operations
  /// in SPIR-V dialect that have hasOpcode == 1.
  LogicalResult dispatchToAutogenSerialization(Operation *op);

  /// Method to serialize an operation in the SPIR-V dialect that is a mirror of
  /// an instruction in the SPIR-V spec. This is auto generated if hasOpcode ==
  /// 1 and autogenSerialization == 1 in ODS.
  template <typename OpTy> LogicalResult processOp(OpTy op) {
    return op.emitError("unsupported op serialization");
  }

private:
  /// The SPIR-V module to be serialized.
  spirv::ModuleOp module;

  /// An MLIR builder for getting MLIR constructs.
  mlir::Builder mlirBuilder;

  /// The next available result <id>.
  uint32_t nextID = 1;

  // The following are for different SPIR-V instruction sections. They follow
  // the logical layout of a SPIR-V module.

  SmallVector<uint32_t, 4> capabilities;
  SmallVector<uint32_t, 0> extensions;
  SmallVector<uint32_t, 0> extendedSets;
  SmallVector<uint32_t, 3> memoryModel;
  SmallVector<uint32_t, 0> entryPoints;
  SmallVector<uint32_t, 4> executionModes;
  // TODO(antiagainst): debug instructions
  SmallVector<uint32_t, 0> names;
  SmallVector<uint32_t, 0> decorations;
  SmallVector<uint32_t, 0> typesGlobalValues;
  SmallVector<uint32_t, 0> functions;

  /// Map from types used in the SPIR-V module to their <id>s.
  DenseMap<Type, uint32_t> typeIDMap;

  /// Map from constant values to their <id>s.
  DenseMap<Attribute, uint32_t> constIDMap;

  /// Map from FuncOp names to <id>s.
  llvm::StringMap<uint32_t> funcIDMap;

  /// Map from GlobalVariableOp names to <id>s.
  llvm::StringMap<uint32_t> globalVarIDMap;

  /// Map from results of normal operations to their <id>s.
  DenseMap<Value *, uint32_t> valueIDMap;
};
} // namespace
Serializer::Serializer(spirv::ModuleOp module)
: module(module), mlirBuilder(module.getContext()) {}
/// Verifies the module and serializes it into the internal per-section
/// buffers. The final binary is assembled separately by collect(). Returns
/// failure as soon as verification or any entity's serialization fails.
LogicalResult Serializer::serialize() {
  if (failed(module.verify()))
    return failure();

  // TODO(antiagainst): handle the other sections
  if (failed(processMemoryModel()))
    return failure();

  // Iterate over the module body to serialize it. Assumptions are that there
  // is only one basic block in the moduleOp.
  for (auto &op : module.getBlock()) {
    if (failed(processOperation(&op))) {
      return failure();
    }
  }
  return success();
}
/// Assembles the final SPIR-V binary into `binary`, replacing its previous
/// contents. The header comes first, followed by every section in logical
/// layout order. `nextID` is passed to the header writer; it is one past the
/// largest <id> handed out by getNextID().
void Serializer::collect(SmallVectorImpl<uint32_t> &binary) {
  // Every section appended below must be counted here, so that the single
  // reserve() covers the whole module.
  auto moduleSize = spirv::kHeaderWordCount + capabilities.size() +
                    extensions.size() + extendedSets.size() +
                    memoryModel.size() + entryPoints.size() +
                    executionModes.size() + names.size() + decorations.size() +
                    typesGlobalValues.size() + functions.size();

  binary.clear();
  binary.reserve(moduleSize);

  spirv::appendModuleHeader(binary, nextID);
  binary.append(capabilities.begin(), capabilities.end());
  binary.append(extensions.begin(), extensions.end());
  binary.append(extendedSets.begin(), extendedSets.end());
  binary.append(memoryModel.begin(), memoryModel.end());
  binary.append(entryPoints.begin(), entryPoints.end());
  binary.append(executionModes.begin(), executionModes.end());
  binary.append(names.begin(), names.end());
  binary.append(decorations.begin(), decorations.end());
  binary.append(typesGlobalValues.begin(), typesGlobalValues.end());
  binary.append(functions.begin(), functions.end());
}
//===----------------------------------------------------------------------===//
// Module structure
//===----------------------------------------------------------------------===//
/// Emits OpMemoryModel using the module's `addressing_model` and
/// `memory_model` integer attributes.
LogicalResult Serializer::processMemoryModel() {
  auto memoryModelAttr = module.getAttrOfType<IntegerAttr>("memory_model");
  auto addressingModelAttr =
      module.getAttrOfType<IntegerAttr>("addressing_model");
  // OpMemoryModel operand order: AddressingModel first, then MemoryModel.
  SmallVector<uint32_t, 2> modelOperands = {
      static_cast<uint32_t>(addressingModelAttr.getInt()),
      static_cast<uint32_t>(memoryModelAttr.getInt())};
  return encodeInstructionInto(memoryModel, spirv::Opcode::OpMemoryModel,
                               modelOperands);
}
/// Serializes a spv.constant op and records the resulting <id> as the <id> of
/// the op's result value. prepareConstant() signals failure by returning 0.
LogicalResult Serializer::processConstantOp(spirv::ConstantOp op) {
  uint32_t constID = prepareConstant(op.getLoc(), op.getType(), op.value(),
                                     op.is_spec_const());
  if (!constID)
    return failure();
  valueIDMap[op.getResult()] = constID;
  return success();
}
/// Translates the named attribute `attr` into an OpDecorate on `resultID`.
///
/// The attribute name is converted from snake_case to CamelCase and must
/// name a SPIR-V Decoration. Only DescriptorSet/Binding (which need an
/// integer attribute) and BuiltIn (which needs a string naming a BuiltIn) are
/// supported; anything else is reported as an error at `loc`.
LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
                                            NamedAttribute attr) {
  auto attrName = attr.first.strref();
  auto decorationName = mlir::convertToCamelCase(attrName, true);
  auto decoration = spirv::symbolizeDecoration(decorationName);
  if (!decoration) {
    return emitError(
               loc, "non-argument attributes expected to have snake-case-ified "
                    "decoration name, unhandled attribute with name : ")
           << attrName;
  }

  // OpDecorate operands: target <id>, decoration enumerant, extra literal.
  auto kind = decoration.getValue();
  SmallVector<uint32_t, 3> decorateArgs = {resultID,
                                           static_cast<uint32_t>(kind)};

  if (kind == spirv::Decoration::DescriptorSet ||
      kind == spirv::Decoration::Binding) {
    auto intAttr = attr.second.dyn_cast<IntegerAttr>();
    if (!intAttr)
      return emitError(loc, "expected integer attribute for ") << attrName;
    decorateArgs.push_back(intAttr.getValue().getZExtValue());
  } else if (kind == spirv::Decoration::BuiltIn) {
    auto strAttr = attr.second.dyn_cast<StringAttr>();
    if (!strAttr)
      return emitError(loc, "expected string attribute for ") << attrName;
    auto builtIn = spirv::symbolizeBuiltIn(strAttr.getValue());
    if (!builtIn)
      return emitError(loc, "invalid ")
             << attrName << " attribute " << strAttr.getValue();
    decorateArgs.push_back(static_cast<uint32_t>(builtIn.getValue()));
  } else {
    return emitError(loc, "unhandled decoration ") << decorationName;
  }

  return encodeInstructionInto(decorations, spirv::Opcode::OpDecorate,
                               decorateArgs);
}
/// Emits `OpName %resultID "name"` into the debug-names section.
LogicalResult Serializer::processName(uint32_t resultID, StringRef name) {
  SmallVector<uint32_t, 4> operands;
  operands.push_back(resultID);
  if (succeeded(encodeStringLiteralInto(operands, name)))
    return encodeInstructionInto(names, spirv::Opcode::OpName, operands);
  return failure();
}
namespace {
/// Emits the ArrayStride decoration for an array type that carries an explicit
/// layout. Arrays without a layout need no decoration and succeed trivially.
template <>
LogicalResult Serializer::processTypeDecoration<spirv::ArrayType>(
    Location loc, spirv::ArrayType type, uint32_t resultID) {
  if (!type.hasLayout())
    return success();

  // OpDecorate %arrayTypeSSA ArrayStride strideLiteral
  uint32_t strideDecoration[] = {
      resultID, static_cast<uint32_t>(spirv::Decoration::ArrayStride),
      static_cast<uint32_t>(type.getArrayStride())};
  return encodeInstructionInto(decorations, spirv::Opcode::OpDecorate,
                               strideDecoration);
}
} // namespace
/// Serializes a function definition. The output is OpFunction, the function's
/// OpName, one OpFunctionParameter per argument, every body operation, and
/// finally OpFunctionEnd. External functions and functions with more than one
/// result are rejected with an error.
LogicalResult Serializer::processFuncOp(FuncOp op) {
  // Reject unsupported functions before emitting or inspecting anything.
  // prepareFunctionType() asserts on multiple results, so that check must come
  // before the function type is processed. Reading arguments presumably
  // requires a body that external functions do not have.
  if (op.isExternal()) {
    return emitError(op.getLoc(), "external function is unhandled");
  }
  auto resultTypes = op.getType().getResults();
  if (resultTypes.size() > 1) {
    return emitError(op.getLoc(),
                     "cannot serialize function with multiple return types");
  }

  // Generate type of the function.
  uint32_t fnTypeID = 0;
  if (failed(processType(op.getLoc(), op.getType(), fnTypeID))) {
    return failure();
  }

  // Add the function definition. A function without results returns void.
  uint32_t resTypeID = 0;
  if (failed(processType(op.getLoc(),
                         (resultTypes.empty() ? getVoidType() : resultTypes[0]),
                         resTypeID))) {
    return failure();
  }
  SmallVector<uint32_t, 4> operands;
  operands.push_back(resTypeID);
  auto funcID = getNextID();
  funcIDMap[op.getName()] = funcID;
  operands.push_back(funcID);
  // TODO : Support other function control options.
  operands.push_back(static_cast<uint32_t>(spirv::FunctionControl::None));
  operands.push_back(fnTypeID);
  if (failed(encodeInstructionInto(functions, spirv::Opcode::OpFunction,
                                   operands))) {
    return failure();
  }

  // Add function name.
  if (failed(processName(funcID, op.getName()))) {
    return failure();
  }

  // Declare the parameters.
  for (auto arg : op.getArguments()) {
    uint32_t argTypeID = 0;
    if (failed(processType(op.getLoc(), arg->getType(), argTypeID))) {
      return failure();
    }
    auto argValueID = getNextID();
    valueIDMap[arg] = argValueID;
    if (failed(encodeInstructionInto(functions,
                                     spirv::Opcode::OpFunctionParameter,
                                     {argTypeID, argValueID}))) {
      return failure();
    }
  }

  // Process the body.
  for (auto &block : op) {
    for (auto &nestedOp : block) {
      if (failed(processOperation(&nestedOp))) {
        return failure();
      }
    }
  }

  // Insert Function End.
  return encodeInstructionInto(functions, spirv::Opcode::OpFunctionEnd, {});
}
/// Serializes a module-scope spv.globalVariable.
///
/// The variable is emitted as OpVariable together with its OpName, and its
/// name-to-<id> mapping is recorded for later spv._address_of /
/// spv.EntryPoint uses. Every attribute not consumed by the OpVariable itself
/// (type, symbol name, initializer) is translated into an OpDecorate.
LogicalResult
Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) {
  // Attributes encoded directly into OpVariable/OpName; these must not be
  // re-emitted as decorations.
  SmallVector<StringRef, 4> elidedAttrs;

  // Get TypeID.
  uint32_t resultTypeID = 0;
  if (failed(processType(varOp.getLoc(), varOp.type(), resultTypeID))) {
    return failure();
  }
  elidedAttrs.push_back("type");
  SmallVector<uint32_t, 4> operands;
  operands.push_back(resultTypeID);
  auto resultID = getNextID();

  // Encode the name.
  auto varName = varOp.sym_name();
  elidedAttrs.push_back(SymbolTable::getSymbolAttrName());
  if (failed(processName(resultID, varName))) {
    return failure();
  }
  globalVarIDMap[varName] = resultID;
  operands.push_back(resultID);

  // Encode StorageClass.
  operands.push_back(static_cast<uint32_t>(varOp.storageClass()));

  // Encode initialization.
  if (auto initializer = varOp.initializer()) {
    auto initializerID = findVariableID(initializer.getValue());
    if (!initializerID) {
      return emitError(varOp.getLoc(),
                       "invalid usage of undefined variable as initializer");
    }
    operands.push_back(initializerID);
    elidedAttrs.push_back("initializer");
  }

  // Module-scope variables belong to the "types, constants, and global
  // variables" section of the logical layout, not the function section.
  // Emitting them among functions would produce an invalid module whenever
  // the variable follows a function in the module body.
  if (failed(encodeInstructionInto(typesGlobalValues,
                                   spirv::Opcode::OpVariable, operands))) {
    return failure();
  }

  // Encode decorations.
  for (auto attr : varOp.getAttrs()) {
    if (llvm::any_of(elidedAttrs,
                     [&](StringRef elided) { return attr.first.is(elided); })) {
      continue;
    }
    if (failed(processDecoration(varOp.getLoc(), resultID, attr))) {
      return failure();
    }
  }
  return success();
}
//===----------------------------------------------------------------------===//
// Type
//===----------------------------------------------------------------------===//
/// Serializes `type`, or reuses its existing <id>, and returns the <id> via
/// `typeID`. Component types are serialized first by the prepare*() helpers,
/// and then the OpType* instruction for `type` is appended to the
/// types/global-values section.
LogicalResult Serializer::processType(Location loc, Type type,
                                      uint32_t &typeID) {
  // Types are de-duplicated: each distinct Type is emitted once.
  typeID = findTypeID(type);
  if (typeID) {
    return success();
  }

  typeID = getNextID();
  SmallVector<uint32_t, 4> operands;
  operands.push_back(typeID);
  auto typeEnum = spirv::Opcode::OpTypeVoid;

  // Dispatch to exactly one preparation routine. A failed function type must
  // not fall through to prepareBasicType(), which would emit a second,
  // misleading "unhandled type" error.
  if (auto fnType = type.dyn_cast<FunctionType>()) {
    if (failed(prepareFunctionType(loc, fnType, typeEnum, operands))) {
      return failure();
    }
  } else if (failed(prepareBasicType(loc, type, typeID, typeEnum, operands))) {
    return failure();
  }

  typeIDMap[type] = typeID;
  return encodeInstructionInto(typesGlobalValues, typeEnum, operands);
}
/// Prepares the serialization of a non-function type.
///
/// On success, `typeEnum` holds the OpType* opcode and the type-specific
/// operands are appended to `operands`. Component types (vector/array
/// elements and pointees) are serialized first via processType(). `resultID`
/// is the <id> reserved for this type; here it is only used to attach type
/// decorations.
LogicalResult
Serializer::prepareBasicType(Location loc, Type type, uint32_t resultID,
                             spirv::Opcode &typeEnum,
                             SmallVectorImpl<uint32_t> &operands) {
  if (isVoidType(type)) {
    typeEnum = spirv::Opcode::OpTypeVoid;
    return success();
  }

  if (auto intType = type.dyn_cast<IntegerType>()) {
    // 1-bit integers are serialized as SPIR-V booleans.
    if (intType.getWidth() == 1) {
      typeEnum = spirv::Opcode::OpTypeBool;
      return success();
    }

    typeEnum = spirv::Opcode::OpTypeInt;
    operands.push_back(intType.getWidth());
    // Signedness operand: always 1 (signed) for now.
    // TODO(antiagainst): support unsigned integers
    operands.push_back(1);
    return success();
  }

  if (auto floatType = type.dyn_cast<FloatType>()) {
    typeEnum = spirv::Opcode::OpTypeFloat;
    operands.push_back(floatType.getWidth());
    return success();
  }

  if (auto vectorType = type.dyn_cast<VectorType>()) {
    uint32_t elementTypeID = 0;
    if (failed(processType(loc, vectorType.getElementType(), elementTypeID))) {
      return failure();
    }
    typeEnum = spirv::Opcode::OpTypeVector;
    operands.push_back(elementTypeID);
    operands.push_back(vectorType.getNumElements());
    return success();
  }

  if (auto arrayType = type.dyn_cast<spirv::ArrayType>()) {
    typeEnum = spirv::Opcode::OpTypeArray;
    uint32_t elementTypeID = 0;
    if (failed(processType(loc, arrayType.getElementType(), elementTypeID))) {
      return failure();
    }
    operands.push_back(elementTypeID);
    // OpTypeArray takes its length as the <id> of a 32-bit integer constant.
    // prepareConstantInt() reports its own error and returns 0 on failure;
    // dropping the operand would produce a malformed OpTypeArray, so fail here.
    auto elementCountID = prepareConstantInt(
        loc, mlirBuilder.getI32IntegerAttr(arrayType.getNumElements()),
        /*isSpec=*/false);
    if (!elementCountID) {
      return failure();
    }
    operands.push_back(elementCountID);
    return processTypeDecoration(loc, arrayType, resultID);
  }

  if (auto ptrType = type.dyn_cast<spirv::PointerType>()) {
    uint32_t pointeeTypeID = 0;
    if (failed(processType(loc, ptrType.getPointeeType(), pointeeTypeID))) {
      return failure();
    }
    typeEnum = spirv::Opcode::OpTypePointer;
    operands.push_back(static_cast<uint32_t>(ptrType.getStorageClass()));
    operands.push_back(pointeeTypeID);
    return success();
  }

  // TODO(ravishankarm) : Handle other types.
  return emitError(loc, "unhandled type in serialization: ") << type;
}
/// Prepares an OpTypeFunction. The operands are the return type <id> (void
/// when the function has no result) followed by one <id> per input type.
/// At most one result is supported; this is asserted.
LogicalResult
Serializer::prepareFunctionType(Location loc, FunctionType type,
                                spirv::Opcode &typeEnum,
                                SmallVectorImpl<uint32_t> &operands) {
  typeEnum = spirv::Opcode::OpTypeFunction;
  assert(type.getNumResults() <= 1 &&
         "Serialization supports only a single return value");

  Type returnType =
      type.getNumResults() == 1 ? type.getResult(0) : getVoidType();
  uint32_t returnTypeID = 0;
  if (failed(processType(loc, returnType, returnTypeID)))
    return failure();
  operands.push_back(returnTypeID);

  for (Type inputType : type.getInputs()) {
    uint32_t inputTypeID = 0;
    if (failed(processType(loc, inputType, inputTypeID)))
      return failure();
    operands.push_back(inputTypeID);
  }
  return success();
}
//===----------------------------------------------------------------------===//
// Constant
//===----------------------------------------------------------------------===//
/// Serializes the constant `valueAttr`, interpreted as `constType`, and
/// returns its <id>, or 0 on failure (an error has then been emitted at
/// `loc`).
///
/// Scalars are delegated to the prepareConstant{Fp,Int,Bool} helpers.
/// Dense int/float ElementsAttrs and ArrayAttrs become composite constants,
/// or OpConstantNull for eligible zero splats. Normal constants are
/// de-duplicated by value; specialization constants (`isSpec`) are not,
/// because each must stay independently specializable.
uint32_t Serializer::prepareConstant(Location loc, Type constType,
                                     Attribute valueAttr, bool isSpec) {
  if (auto floatAttr = valueAttr.dyn_cast<FloatAttr>()) {
    return prepareConstantFp(loc, floatAttr, isSpec);
  }
  if (auto intAttr = valueAttr.dyn_cast<IntegerAttr>()) {
    return prepareConstantInt(loc, intAttr, isSpec);
  }
  if (auto boolAttr = valueAttr.dyn_cast<BoolAttr>()) {
    return prepareConstantBool(loc, boolAttr, isSpec);
  }

  // This is a composite literal. We need to handle each component separately
  // and then emit an OpConstantComposite for the whole.

  // We can de-duplicate normal constants, but not specialization constants.
  if (!isSpec) {
    if (auto id = findConstantID(valueAttr)) {
      return id;
    }
  }

  uint32_t typeID = 0;
  if (failed(processType(loc, constType, typeID))) {
    return 0;
  }

  // The composite's <id> is reserved before its elements are serialized; the
  // element constants are emitted first, so they still precede their use.
  auto resultID = getNextID();
  spirv::Opcode opcode = spirv::Opcode::OpNop;
  SmallVector<uint32_t, 4> operands;
  operands.push_back(typeID);
  operands.push_back(resultID);

  if (auto vectorAttr = valueAttr.dyn_cast<DenseIntElementsAttr>()) {
    if (vectorAttr.getType().getElementType().isInteger(1)) {
      if (failed(prepareBoolVectorConstant(loc, vectorAttr, isSpec, opcode,
                                           operands)))
        return 0;
    } else if (failed(prepareIntVectorConstant(loc, vectorAttr, isSpec, opcode,
                                               operands)))
      return 0;
  } else if (auto vectorAttr = valueAttr.dyn_cast<DenseFPElementsAttr>()) {
    if (failed(prepareFloatVectorConstant(loc, vectorAttr, isSpec, opcode,
                                          operands)))
      return 0;
  } else if (auto arrayAttr = valueAttr.dyn_cast<ArrayAttr>()) {
    opcode = isSpec ? spirv::Opcode::OpSpecConstantComposite
                    : spirv::Opcode::OpConstantComposite;
    operands.reserve(arrayAttr.size() + 2);

    // Each element is serialized against the array's element type.
    auto elementType = constType.cast<spirv::ArrayType>().getElementType();
    for (Attribute elementAttr : arrayAttr)
      if (auto elementID =
              prepareConstant(loc, elementType, elementAttr, isSpec)) {
        operands.push_back(elementID);
      } else {
        return 0;
      }
  } else {
    emitError(loc, "cannot serialize attribute: ") << valueAttr;
    return 0;
  }

  encodeInstructionInto(typesGlobalValues, opcode, operands);
  if (!isSpec) {
    constIDMap[valueAttr] = resultID;
  }
  return resultID;
}
/// Fills `opcode` and `operands` for a bool vector constant.
///
/// A non-spec splat of `false` becomes OpConstantNull. Any other splat
/// repeats a single element <id>. A non-splat vector gets one element <id>
/// per value. Returns failure if an element constant cannot be prepared.
LogicalResult Serializer::prepareBoolVectorConstant(
    Location loc, DenseIntElementsAttr elementsAttr, bool isSpec,
    spirv::Opcode &opcode, SmallVectorImpl<uint32_t> &operands) {
  auto type = elementsAttr.getType();
  assert(type.hasRank() && type.getRank() == 1 &&
         "spv.constant should have verified only vector literal uses "
         "ElementsAttr");
  assert(type.getElementType().isInteger(1) && "must be bool ElementsAttr");
  auto count = type.getNumElements();

  // Room for every element <id> plus the result type and result <id>s.
  operands.reserve(count + 2);

  const auto compositeOpcode = isSpec ? spirv::Opcode::OpSpecConstantComposite
                                      : spirv::Opcode::OpConstantComposite;

  if (!elementsAttr.isSplat()) {
    opcode = compositeOpcode;
    // One BoolAttr is materialized per element; this is cheap for the short
    // vectors that use ElementsAttr.
    for (auto boolAttr : elementsAttr.getValues<BoolAttr>()) {
      uint32_t elementID = prepareConstantBool(loc, boolAttr, isSpec);
      if (!elementID)
        return failure();
      operands.push_back(elementID);
    }
    return success();
  }

  // Splat: no need to visit every element.
  if (!isSpec && !elementsAttr.getSplatValue<bool>()) {
    opcode = spirv::Opcode::OpConstantNull;
    return success();
  }
  uint32_t splatID = prepareConstantBool(
      loc, elementsAttr.getSplatValue<BoolAttr>(), isSpec);
  if (!splatID)
    return failure();
  opcode = compositeOpcode;
  operands.append(count, splatID);
  return success();
}
/// Fills `opcode` and `operands` for a non-bool integer vector constant.
///
/// A non-spec splat of zero becomes OpConstantNull. Any other splat repeats
/// a single element <id>. A non-splat vector gets one element <id> per
/// value. Returns failure if an element constant cannot be prepared.
LogicalResult Serializer::prepareIntVectorConstant(
    Location loc, DenseIntElementsAttr elementsAttr, bool isSpec,
    spirv::Opcode &opcode, SmallVectorImpl<uint32_t> &operands) {
  auto type = elementsAttr.getType();
  assert(type.hasRank() && type.getRank() == 1 &&
         "spv.constant should have verified only vector literal uses "
         "ElementsAttr");
  assert(!type.getElementType().isInteger(1) &&
         "must be non-bool ElementsAttr");
  auto count = type.getNumElements();

  // Room for every element <id> plus the result type and result <id>s.
  operands.reserve(count + 2);

  const auto compositeOpcode = isSpec ? spirv::Opcode::OpSpecConstantComposite
                                      : spirv::Opcode::OpConstantComposite;

  if (!elementsAttr.isSplat()) {
    opcode = compositeOpcode;
    // One IntegerAttr is materialized per element; this is cheap for the
    // short vectors that use ElementsAttr.
    // TODO(antiagainst): revisit this if special extensions enabling large
    // vectors are supported.
    for (auto intAttr : elementsAttr.getValues<IntegerAttr>()) {
      uint32_t elementID = prepareConstantInt(loc, intAttr, isSpec);
      if (!elementID)
        return failure();
      operands.push_back(elementID);
    }
    return success();
  }

  // Splat: no need to visit every element.
  auto splatAttr = elementsAttr.getSplatValue<IntegerAttr>();
  if (!isSpec && splatAttr.getValue().isNullValue()) {
    opcode = spirv::Opcode::OpConstantNull;
    return success();
  }
  uint32_t splatID = prepareConstantInt(loc, splatAttr, isSpec);
  if (!splatID)
    return failure();
  opcode = compositeOpcode;
  operands.append(count, splatID);
  return success();
}
/// Fills `opcode` and `operands` for a float vector constant.
///
/// A non-spec splat of +0.0 becomes OpConstantNull. Any other splat repeats
/// a single element <id>. A non-splat vector gets one element <id> per
/// value. Returns failure if an element constant cannot be prepared.
LogicalResult Serializer::prepareFloatVectorConstant(
    Location loc, DenseFPElementsAttr elementsAttr, bool isSpec,
    spirv::Opcode &opcode, SmallVectorImpl<uint32_t> &operands) {
  auto type = elementsAttr.getType();
  assert(type.hasRank() && type.getRank() == 1 &&
         "spv.constant should have verified only vector literal uses "
         "ElementsAttr");
  auto count = type.getNumElements();

  // Room for every element <id> plus the result type and result <id>s.
  operands.reserve(count + 2);

  if (elementsAttr.isSplat()) {
    FloatAttr splatAttr = elementsAttr.getSplatValue<FloatAttr>();
    // The null value of a float vector is +0.0 in every lane.
    // APFloat::isZero() is also true for -0.0, whose sign bit must be
    // preserved, so only a positive-zero splat may use OpConstantNull.
    if (!isSpec && splatAttr.getValue().isPosZero()) {
      opcode = spirv::Opcode::OpConstantNull;
      return success();
    }
    if (auto id = prepareConstantFp(loc, splatAttr, isSpec)) {
      opcode = isSpec ? spirv::Opcode::OpSpecConstantComposite
                      : spirv::Opcode::OpConstantComposite;
      operands.append(count, id);
      return success();
    }
    return failure();
  }

  opcode = isSpec ? spirv::Opcode::OpSpecConstantComposite
                  : spirv::Opcode::OpConstantComposite;
  for (auto fpAttr : elementsAttr.getValues<FloatAttr>()) {
    if (auto elementID = prepareConstantFp(loc, fpAttr, isSpec)) {
      operands.push_back(elementID);
    } else {
      return failure();
    }
  }
  return success();
}
/// Emits OpConstantTrue/False, or OpSpecConstantTrue/False when `isSpec` is
/// set, for `boolAttr`. Returns the constant's <id>, or 0 on failure.
/// Normal constants are de-duplicated by value. Specialization constants
/// always get a fresh <id> so that each can be specialized independently and
/// never aliases a normal constant of the same value.
uint32_t Serializer::prepareConstantBool(Location loc, BoolAttr boolAttr,
                                         bool isSpec) {
  if (!isSpec) {
    // We can de-duplicate normal constants, but not specialization constants.
    if (auto id = findConstantID(boolAttr)) {
      return id;
    }
  }

  // Process the type for this bool literal
  uint32_t typeID = 0;
  if (failed(processType(loc, boolAttr.getType(), typeID))) {
    return 0;
  }

  auto resultID = getNextID();
  auto opcode = boolAttr.getValue()
                    ? (isSpec ? spirv::Opcode::OpSpecConstantTrue
                              : spirv::Opcode::OpConstantTrue)
                    : (isSpec ? spirv::Opcode::OpSpecConstantFalse
                              : spirv::Opcode::OpConstantFalse);
  encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID});

  if (!isSpec) {
    constIDMap[boolAttr] = resultID;
  }
  return resultID;
}
/// Emits OpConstant, or OpSpecConstant when `isSpec` is set, for `intAttr`.
/// Returns the constant's <id>, or 0 on failure (an error has then been
/// emitted at `loc`). Only 16-, 32- and 64-bit integers are supported.
/// Normal constants are de-duplicated by value; specialization constants
/// always get a fresh <id>.
uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr,
                                        bool isSpec) {
  if (!isSpec) {
    // We can de-duplicate normal constants, but not specialization constants.
    if (auto id = findConstantID(intAttr)) {
      return id;
    }
  }

  // Process the type for this integer literal
  uint32_t typeID = 0;
  if (failed(processType(loc, intAttr.getType(), typeID))) {
    return 0;
  }

  auto resultID = getNextID();
  APInt value = intAttr.getValue();
  unsigned bitwidth = value.getBitWidth();
  auto opcode =
      isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;

  // According to SPIR-V spec, "When the type's bit width is less than 32-bits,
  // the literal's value appears in the low-order bits of the word, and the
  // high-order bits must be 0 for a floating-point type, or 0 for an integer
  // type with Signedness of 0, or sign extended when Signedness is 1."
  // Integer types are currently always emitted with Signedness 1, so sign
  // extension is required. For full 32- and 64-bit values it produces the
  // same bits as zero extension.
  if (bitwidth == 32 || bitwidth == 16) {
    uint32_t word = static_cast<uint32_t>(value.getSExtValue());
    encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
  }
  // According to SPIR-V spec: "When the type's bit width is larger than one
  // word, the literal’s low-order words appear first." Split the value
  // explicitly so the word order does not depend on host endianness.
  else if (bitwidth == 64) {
    uint64_t bits = value.getZExtValue();
    uint32_t lowWord = static_cast<uint32_t>(bits);
    uint32_t highWord = static_cast<uint32_t>(bits >> 32);
    encodeInstructionInto(typesGlobalValues, opcode,
                          {typeID, resultID, lowWord, highWord});
  } else {
    std::string valueStr;
    llvm::raw_string_ostream rss(valueStr);
    value.print(rss, /*isSigned*/ false);

    emitError(loc, "cannot serialize ")
        << bitwidth << "-bit integer literal: " << rss.str();
    return 0;
  }

  if (!isSpec) {
    constIDMap[intAttr] = resultID;
  }
  return resultID;
}
/// Emits OpConstant, or OpSpecConstant when `isSpec` is set, for `floatAttr`.
/// Returns the constant's <id>, or 0 on failure (an error has then been
/// emitted at `loc`). Only IEEE half, single and double precision are
/// supported. Normal constants are de-duplicated by value; specialization
/// constants always get a fresh <id>.
uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr,
                                       bool isSpec) {
  if (!isSpec) {
    // We can de-duplicate normal constants, but not specialization constants.
    if (auto id = findConstantID(floatAttr)) {
      return id;
    }
  }

  // Process the type for this float literal
  uint32_t typeID = 0;
  if (failed(processType(loc, floatAttr.getType(), typeID))) {
    return 0;
  }

  auto resultID = getNextID();
  APFloat value = floatAttr.getValue();
  // Raw IEEE-754 bit pattern of the value, as wide as its semantics.
  APInt bits = value.bitcastToAPInt();
  auto opcode =
      isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
  const auto &semantics = value.getSemantics();

  if (&semantics == &APFloat::IEEEsingle() ||
      &semantics == &APFloat::IEEEhalf()) {
    // Sub-word floats must have their high-order bits zeroed; zero extension
    // of the raw bit pattern provides exactly that.
    uint32_t word = static_cast<uint32_t>(bits.getZExtValue());
    encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
  } else if (&semantics == &APFloat::IEEEdouble()) {
    // Low-order word first, independent of host endianness.
    uint64_t raw = bits.getZExtValue();
    uint32_t lowWord = static_cast<uint32_t>(raw);
    uint32_t highWord = static_cast<uint32_t>(raw >> 32);
    encodeInstructionInto(typesGlobalValues, opcode,
                          {typeID, resultID, lowWord, highWord});
  } else {
    std::string valueStr;
    llvm::raw_string_ostream rss(valueStr);
    value.print(rss);

    emitError(loc, "cannot serialize ")
        << floatAttr.getType() << "-typed float literal: " << rss.str();
    return 0;
  }

  if (!isSpec) {
    constIDMap[floatAttr] = resultID;
  }
  return resultID;
}
//===----------------------------------------------------------------------===//
// Operation
//===----------------------------------------------------------------------===//
/// Maps the result of spv._address_of directly to the referenced global
/// variable's <id>; no instruction is emitted. The variable must already have
/// been serialized.
LogicalResult Serializer::processAddressOfOp(spirv::AddressOfOp addressOfOp) {
  auto varName = addressOfOp.variable();
  if (uint32_t varID = findVariableID(varName)) {
    valueIDMap[addressOfOp.pointer()] = varID;
    return success();
  }
  return addressOfOp.emitError("unknown result <id> for variable ")
         << varName;
}
/// Serializes one operation. Ops that need bespoke handling are dispatched by
/// hand: constants, functions, global variables, address-of and the module
/// terminator, which emits nothing. Every other op goes through the
/// ODS-generated dispatcher.
LogicalResult Serializer::processOperation(Operation *op) {
  if (isa<spirv::ModuleEndOp>(op))
    return success();

  if (auto constOp = dyn_cast<spirv::ConstantOp>(op))
    return processConstantOp(constOp);

  if (auto fnOp = dyn_cast<FuncOp>(op))
    return processFuncOp(fnOp);

  if (auto varOp = dyn_cast<spirv::GlobalVariableOp>(op))
    return processGlobalVariableOp(varOp);

  if (auto addressOfOp = dyn_cast<spirv::AddressOfOp>(op))
    return processAddressOfOp(addressOfOp);

  return dispatchToAutogenSerialization(op);
}
namespace {
/// Serializes spv.EntryPoint as OpEntryPoint. The operands are the execution
/// model, the function <id>, the function name as a literal string, and the
/// <id>s of the interface variables. The function and all interface
/// variables must already have been serialized.
template <>
LogicalResult
Serializer::processOp<spirv::EntryPointOp>(spirv::EntryPointOp op) {
  SmallVector<uint32_t, 4> operands;
  // Add the ExecutionModel.
  operands.push_back(static_cast<uint32_t>(op.execution_model()));
  // Add the function <id>.
  auto funcID = findFunctionID(op.fn());
  if (!funcID) {
    return op.emitError("missing <id> for function ")
           << op.fn()
           << "; function needs to be defined before spv.EntryPoint is "
              "serialized";
  }
  operands.push_back(funcID);
  // Add the name of the function.
  if (failed(encodeStringLiteralInto(operands, op.fn()))) {
    return failure();
  }

  // Add the interface values.
  if (auto interfaceVars = op.interface()) {
    for (auto var : interfaceVars.getValue()) {
      auto id = findVariableID(var.cast<SymbolRefAttr>().getValue());
      if (!id) {
        return op.emitError("referencing undefined global variable. "
                            "spv.EntryPoint is at the end of spv.module. All "
                            "referenced variables should already be defined");
      }
      operands.push_back(id);
    }
  }
  return encodeInstructionInto(entryPoints, spirv::Opcode::OpEntryPoint,
                               operands);
}
/// Serializes spv.ExecutionMode as OpExecutionMode. The operands are the
/// entry function <id>, the execution mode enumerant, and any literal mode
/// parameters, each zero-extended and truncated to one word. The function
/// must already have been serialized.
template <>
LogicalResult
Serializer::processOp<spirv::ExecutionModeOp>(spirv::ExecutionModeOp op) {
  auto funcID = findFunctionID(op.fn());
  if (!funcID) {
    return op.emitError("missing <id> for function ")
           << op.fn()
           << "; function needs to be serialized before ExecutionModeOp is "
              "serialized";
  }

  SmallVector<uint32_t, 4> modeOperands = {
      funcID, static_cast<uint32_t>(op.execution_mode())};
  if (auto modeValues = op.values()) {
    for (Attribute literal : modeValues.getValue())
      modeOperands.push_back(static_cast<uint32_t>(
          literal.cast<IntegerAttr>().getValue().getZExtValue()));
  }
  return encodeInstructionInto(executionModes, spirv::Opcode::OpExecutionMode,
                               modeOperands);
}
// Pull in auto-generated Serializer::dispatchToAutogenSerialization() and
// various Serializer::processOp<...>() specializations.
#define GET_SERIALIZATION_FNS
#include "mlir/Dialect/SPIRV/SPIRVSerialization.inc"
} // namespace
/// Public entry point: serializes `module` and, only on success, writes the
/// complete SPIR-V binary into `binary`. On failure, `binary` is untouched.
LogicalResult spirv::serialize(spirv::ModuleOp module,
                               SmallVectorImpl<uint32_t> &binary) {
  Serializer serializer(module);
  LogicalResult status = serializer.serialize();
  if (succeeded(status))
    serializer.collect(binary);
  return status;
}