blob: f55805dfd23abf67794fb1848f042730390c35b5 [file] [log] [blame]
//===- Deserializer.cpp - MLIR SPIR-V Deserialization ---------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file defines the SPIR-V binary to MLIR SPIR-V module deseralization.
//
//===----------------------------------------------------------------------===//
#include "mlir/SPIRV/Serialization.h"
#include "SPIRVBinaryUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Location.h"
#include "mlir/SPIRV/SPIRVOps.h"
#include "mlir/SPIRV/SPIRVTypes.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/SmallVector.h"
using namespace mlir;
namespace {
/// A SPIR-V module serializer.
///
/// A SPIR-V binary module is a single linear stream of instructions; each
/// instruction is composed of 32-bit words. The first word of an instruction
/// records the total number of words of that instruction using the 16
/// higher-order bits. So this deserializer uses that to get instruction
/// boundary and parse instructions and build a SPIR-V ModuleOp gradually.
///
// TODO(antiagainst): clean up created ops on errors
class Deserializer {
public:
/// Creates a deserializer for the given SPIR-V `binary` module.
/// The SPIR-V ModuleOp will be created into `context.
explicit Deserializer(ArrayRef<uint32_t> binary, MLIRContext *context);
/// Deserializes the remembered SPIR-V binary module.
LogicalResult deserialize();
/// Collects the final SPIR-V ModuleOp.
Optional<spirv::ModuleOp> collect();
private:
/// Processes SPIR-V module header.
LogicalResult processHeader();
/// Processes a SPIR-V instruction with the given `opcode` and `operands`.
LogicalResult processInstruction(uint32_t opcode,
ArrayRef<uint32_t> operands);
LogicalResult processMemoryModel(ArrayRef<uint32_t> operands);
/// Initializes the `module` ModuleOp in this deserializer instance.
spirv::ModuleOp createModuleOp();
private:
/// The SPIR-V binary module.
ArrayRef<uint32_t> binary;
/// The current word offset into the binary module.
unsigned curOffset = 0;
/// MLIRContext to create SPIR-V ModuleOp into.
MLIRContext *context;
// TODO(antiagainst): create Location subclass for binary blob
Location unknownLoc;
/// The SPIR-V ModuleOp.
Optional<spirv::ModuleOp> module;
OpBuilder opBuilder;
};
} // namespace
Deserializer::Deserializer(ArrayRef<uint32_t> binary, MLIRContext *context)
: binary(binary), context(context), unknownLoc(UnknownLoc::get(context)),
module(createModuleOp()),
opBuilder(module->getOperation()->getRegion(0)) {}
LogicalResult Deserializer::deserialize() {
if (failed(processHeader()))
return failure();
auto binarySize = binary.size();
curOffset = spirv::kHeaderWordCount;
while (curOffset < binarySize) {
// For each instruction, get its word count from the first word to slice it
// from the stream properly, and then dispatch to the instruction handler.
uint32_t wordCount = binary[curOffset] >> 16;
uint32_t opcode = binary[curOffset] & 0xffff;
if (wordCount == 0)
return context->emitError(unknownLoc, "word count cannot be zero");
uint32_t nextOffset = curOffset + wordCount;
if (nextOffset > binarySize)
return context->emitError(unknownLoc,
"insufficient words for the last instruction");
auto operands = binary.slice(curOffset + 1, wordCount - 1);
if (failed(processInstruction(opcode, operands)))
return failure();
curOffset = nextOffset;
}
return success();
}
Optional<spirv::ModuleOp> Deserializer::collect() { return module; }
LogicalResult Deserializer::processHeader() {
if (binary.size() < spirv::kHeaderWordCount)
return context->emitError(unknownLoc,
"SPIR-V binary module must have a 5-word header");
if (binary[0] != spirv::kMagicNumber)
return context->emitError(unknownLoc, "incorrect magic number");
// TODO(antiagainst): generator number, bound, schema
return success();
}
LogicalResult Deserializer::processInstruction(uint32_t opcode,
ArrayRef<uint32_t> operands) {
switch (opcode) {
case spirv::kOpMemoryModelOpcode:
return processMemoryModel(operands);
default:
break;
}
return context->emitError(unknownLoc, "NYI: opcode ") << opcode;
}
LogicalResult Deserializer::processMemoryModel(ArrayRef<uint32_t> operands) {
if (operands.size() != 2)
return context->emitError(unknownLoc,
"OpMemoryModel must have two operands");
// TODO(antiagainst): use IntegerAttr-backed enum attributes to avoid the
// excessive string conversions here.
auto am = spirv::symbolizeAddressingModel(operands.front());
if (!am)
return context->emitError(unknownLoc,
"unknown addressing model for OpMemoryModel");
auto mm = spirv::symbolizeMemoryModel(operands.back());
if (!mm)
return context->emitError(unknownLoc,
"unknown memory model for OpMemoryModel");
module->setAttr(
"addressing_model",
opBuilder.getStringAttr(spirv::stringifyAddressingModel(*am)));
module->setAttr("memory_model",
opBuilder.getStringAttr(spirv::stringifyMemoryModel(*mm)));
return success();
}
spirv::ModuleOp Deserializer::createModuleOp() {
Builder builder(context);
OperationState state(unknownLoc, spirv::ModuleOp::getOperationName());
// TODO(antiagainst): use target environment to select the version
state.addAttribute("major_version", builder.getI32IntegerAttr(1));
state.addAttribute("minor_version", builder.getI32IntegerAttr(0));
spirv::ModuleOp::build(&builder, &state);
return llvm::cast<spirv::ModuleOp>(Operation::create(state));
}
Optional<spirv::ModuleOp> spirv::deserialize(ArrayRef<uint32_t> binary,
MLIRContext *context) {
Deserializer deserializer(binary, context);
if (failed(deserializer.deserialize()))
return llvm::None;
return deserializer.collect();
}