| // Copyright (c) 2017 Google Inc. |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| // Contains |
| // - SPIR-V to MARK-V encoder |
| // - MARK-V to SPIR-V decoder |
| // |
| // MARK-V is a compression format for SPIR-V binaries. It strips away |
| // non-essential information (such as result ids which can be regenerated) and |
| // uses various bit reduction techniques to reduce the size of the binary. |
| // |
| // MarkvModel is a flatbuffers object containing a set of rules defining how |
| // compression/decompression is done (coding schemes, dictionaries). |
| |
| #include <algorithm> |
| #include <cassert> |
| #include <cstring> |
| #include <functional> |
| #include <iostream> |
| #include <list> |
| #include <memory> |
| #include <numeric> |
| #include <string> |
| #include <vector> |
| |
| #include "binary.h" |
| #include "diagnostic.h" |
| #include "enum_string_mapping.h" |
| #include "extensions.h" |
| #include "ext_inst.h" |
| #include "instruction.h" |
| #include "opcode.h" |
| #include "operand.h" |
| #include "spirv-tools/libspirv.h" |
| #include "spirv-tools/markv.h" |
| #include "spirv_endian.h" |
| #include "spirv_validator_options.h" |
| #include "util/bit_stream.h" |
| #include "util/parse_number.h" |
| #include "validate.h" |
| #include "val/instruction.h" |
| #include "val/validation_state.h" |
| |
| using libspirv::Instruction; |
| using libspirv::ValidationState_t; |
| using spvtools::ValidateInstructionAndUpdateValidationState; |
| using spvutils::BitReaderWord64; |
| using spvutils::BitWriterWord64; |
| |
// Options for the SPIR-V to MARK-V encoder. Declares no fields at present.
struct spv_markv_encoder_options_t {};
| |
// Options for the MARK-V to SPIR-V decoder. Declares no fields at present.
struct spv_markv_decoder_options_t {};
| |
| namespace { |
| |
| const uint32_t kSpirvMagicNumber = SpvMagicNumber; |
| const uint32_t kMarkvMagicNumber = 0x07230303; |
| |
| enum { |
| kMarkvFirstOpcode = 65536, |
| kMarkvOpNextInstructionEncodesResultId = 65536, |
| }; |
| |
| const size_t kCommentNumWhitespaces = 2; |
| |
// TODO(atgoo@github.com): This is a placeholder for an autogenerated flatbuffer
// containing MARK-V model for a specific dataset.
//
// Supplies the coding parameters (chunk lengths and block exponents) consumed
// by the encoder and the decoder. Every accessor currently returns a fixed
// value.
class MarkvModel {
 public:
  // Chunk lengths for opcodes, operand counts and move-to-front id indices.
  size_t opcode_chunk_length() const { return kOpcodeChunkLength; }
  size_t num_operands_chunk_length() const { return kNumOperandsChunkLength; }
  size_t id_index_chunk_length() const { return kIdIndexChunkLength; }

  // Parameters for 16-bit literal numbers.
  size_t u16_chunk_length() const { return kChunkLength16; }
  size_t s16_chunk_length() const { return kChunkLength16; }
  size_t s16_block_exponent() const { return kBlockExponent16; }

  // Parameters for 32-bit literal numbers.
  size_t u32_chunk_length() const { return kChunkLengthWide; }
  size_t s32_chunk_length() const { return kChunkLengthWide; }
  size_t s32_block_exponent() const { return kBlockExponentWide; }

  // Parameters for 64-bit literal numbers.
  size_t u64_chunk_length() const { return kChunkLengthWide; }
  size_t s64_chunk_length() const { return kChunkLengthWide; }
  size_t s64_block_exponent() const { return kBlockExponentWide; }

 private:
  static constexpr size_t kOpcodeChunkLength = 7;
  static constexpr size_t kNumOperandsChunkLength = 3;
  static constexpr size_t kIdIndexChunkLength = 3;

  static constexpr size_t kChunkLength16 = 4;
  static constexpr size_t kBlockExponent16 = 6;

  // Shared by the 32-bit and the 64-bit accessors.
  static constexpr size_t kChunkLengthWide = 8;
  static constexpr size_t kBlockExponentWide = 10;
};
| |
| const MarkvModel* GetDefaultModel() { |
| static MarkvModel model; |
| return &model; |
| } |
| |
| // Returns chunk length used for variable length encoding of spirv operand |
| // words. Returns zero if operand type corresponds to potentially multiple |
| // words or a word which is not expected to profit from variable width encoding. |
| // Chunk length is selected based on the size of expected value. |
| // Most of these values will later be encoded with probability-based coding, |
| // but variable width integer coding is a good quick solution. |
| // TODO(atgoo@github.com): Put this in MarkvModel flatbuffer. |
| size_t GetOperandVariableWidthChunkLength(spv_operand_type_t type) { |
| switch (type) { |
| case SPV_OPERAND_TYPE_TYPE_ID: |
| return 4; |
| case SPV_OPERAND_TYPE_RESULT_ID: |
| case SPV_OPERAND_TYPE_ID: |
| case SPV_OPERAND_TYPE_SCOPE_ID: |
| case SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID: |
| return 8; |
| case SPV_OPERAND_TYPE_LITERAL_INTEGER: |
| case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_INTEGER: |
| return 6; |
| case SPV_OPERAND_TYPE_CAPABILITY: |
| return 6; |
| case SPV_OPERAND_TYPE_SOURCE_LANGUAGE: |
| case SPV_OPERAND_TYPE_EXECUTION_MODEL: |
| return 3; |
| case SPV_OPERAND_TYPE_ADDRESSING_MODEL: |
| case SPV_OPERAND_TYPE_MEMORY_MODEL: |
| return 2; |
| case SPV_OPERAND_TYPE_EXECUTION_MODE: |
| return 6; |
| case SPV_OPERAND_TYPE_STORAGE_CLASS: |
| return 4; |
| case SPV_OPERAND_TYPE_DIMENSIONALITY: |
| case SPV_OPERAND_TYPE_SAMPLER_ADDRESSING_MODE: |
| return 3; |
| case SPV_OPERAND_TYPE_SAMPLER_FILTER_MODE: |
| return 2; |
| case SPV_OPERAND_TYPE_SAMPLER_IMAGE_FORMAT: |
| return 6; |
| case SPV_OPERAND_TYPE_FP_ROUNDING_MODE: |
| case SPV_OPERAND_TYPE_LINKAGE_TYPE: |
| case SPV_OPERAND_TYPE_ACCESS_QUALIFIER: |
| case SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER: |
| return 2; |
| case SPV_OPERAND_TYPE_FUNCTION_PARAMETER_ATTRIBUTE: |
| return 3; |
| case SPV_OPERAND_TYPE_DECORATION: |
| case SPV_OPERAND_TYPE_BUILT_IN: |
| return 6; |
| case SPV_OPERAND_TYPE_GROUP_OPERATION: |
| case SPV_OPERAND_TYPE_KERNEL_ENQ_FLAGS: |
| case SPV_OPERAND_TYPE_KERNEL_PROFILING_INFO: |
| return 2; |
| case SPV_OPERAND_TYPE_FP_FAST_MATH_MODE: |
| case SPV_OPERAND_TYPE_FUNCTION_CONTROL: |
| case SPV_OPERAND_TYPE_LOOP_CONTROL: |
| case SPV_OPERAND_TYPE_IMAGE: |
| case SPV_OPERAND_TYPE_OPTIONAL_IMAGE: |
| case SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS: |
| case SPV_OPERAND_TYPE_SELECTION_CONTROL: |
| return 4; |
| case SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER: |
| return 6; |
| default: |
| return 0; |
| } |
| return 0; |
| } |
| |
| // Returns true if the opcode has a fixed number of operands. May return a |
| // false negative. |
| bool OpcodeHasFixedNumberOfOperands(SpvOp opcode) { |
| switch (opcode) { |
| // TODO(atgoo@github.com) This is not a complete list. |
| case SpvOpNop: |
| case SpvOpName: |
| case SpvOpUndef: |
| case SpvOpSizeOf: |
| case SpvOpLine: |
| case SpvOpNoLine: |
| case SpvOpDecorationGroup: |
| case SpvOpExtension: |
| case SpvOpExtInstImport: |
| case SpvOpMemoryModel: |
| case SpvOpCapability: |
| case SpvOpTypeVoid: |
| case SpvOpTypeBool: |
| case SpvOpTypeInt: |
| case SpvOpTypeFloat: |
| case SpvOpTypeVector: |
| case SpvOpTypeMatrix: |
| case SpvOpTypeSampler: |
| case SpvOpTypeSampledImage: |
| case SpvOpTypeArray: |
| case SpvOpTypePointer: |
| case SpvOpConstantTrue: |
| case SpvOpConstantFalse: |
| case SpvOpLabel: |
| case SpvOpBranch: |
| case SpvOpFunction: |
| case SpvOpFunctionParameter: |
| case SpvOpFunctionEnd: |
| case SpvOpBitcast: |
| case SpvOpCopyObject: |
| case SpvOpTranspose: |
| case SpvOpSNegate: |
| case SpvOpFNegate: |
| case SpvOpIAdd: |
| case SpvOpFAdd: |
| case SpvOpISub: |
| case SpvOpFSub: |
| case SpvOpIMul: |
| case SpvOpFMul: |
| case SpvOpUDiv: |
| case SpvOpSDiv: |
| case SpvOpFDiv: |
| case SpvOpUMod: |
| case SpvOpSRem: |
| case SpvOpSMod: |
| case SpvOpFRem: |
| case SpvOpFMod: |
| case SpvOpVectorTimesScalar: |
| case SpvOpMatrixTimesScalar: |
| case SpvOpVectorTimesMatrix: |
| case SpvOpMatrixTimesVector: |
| case SpvOpMatrixTimesMatrix: |
| case SpvOpOuterProduct: |
| case SpvOpDot: |
| return true; |
| default: |
| break; |
| } |
| return false; |
| } |
| |
// Returns how many bits lie between |bit_pos| and the next byte boundary.
// The result is zero when |bit_pos| is already a multiple of eight.
size_t GetNumBitsToNextByte(size_t bit_pos) {
  const size_t bits_into_byte = bit_pos % 8;
  if (bits_into_byte == 0) return 0;
  return 8 - bits_into_byte;
}
| |
| bool ShouldByteBreak(size_t bit_pos) { |
| const size_t num_bits_to_next_byte = GetNumBitsToNextByte(bit_pos); |
| return num_bits_to_next_byte > 0; // && num_bits_to_next_byte <= 2; |
| } |
| |
// Defines and returns the current MARK-V version. The major version occupies
// the bits above bit 15, the minor version the low bits.
uint32_t GetMarkvVersion() {
  const uint32_t major = 1;
  const uint32_t minor = 0;
  const uint32_t packed = (major << 16) | minor;
  return packed;
}
| |
// Collects text describing the encoding process. Bit sequences appended back
// to back are joined with '-'; any other kind of append interrupts the run.
class CommentLogger {
 public:
  // Appends |str| as is.
  void AppendText(const std::string& str) {
    ss_ << str;
    use_delimiter_ = false;
  }

  // Appends |str| followed by a line break.
  void AppendTextNewLine(const std::string& str) {
    ss_ << str << "\n";
    use_delimiter_ = false;
  }

  // Appends the bit sequence |str|. A '-' is put in front of it if the
  // previous append was a bit sequence too.
  void AppendBitSequence(const std::string& str) {
    ss_ << (use_delimiter_ ? "-" : "") << str;
    use_delimiter_ = true;
  }

  // Appends |num| space characters.
  void AppendWhitespaces(size_t num) {
    ss_ << std::string(num, ' ');
    use_delimiter_ = false;
  }

  // Appends a line break.
  void NewLine() {
    ss_ << "\n";
    use_delimiter_ = false;
  }

  // Returns everything which has been appended so far.
  std::string GetText() const { return ss_.str(); }

 private:
  // Accumulated comment text.
  std::stringstream ss_;

  // If true a delimiter will be appended before the next bit sequence.
  // Used to generate outputs like: 1100-0 1110-1-1100-1-1111-0 110-0.
  bool use_delimiter_ = false;
};
| |
| // Creates spv_text object containing text from |str|. |
| // The returned value is owned by the caller and needs to be destroyed with |
| // spvTextDestroy. |
| spv_text CreateSpvText(const std::string& str) { |
| spv_text out = new spv_text_t(); |
| assert(out); |
| char* cstr = new char[str.length() + 1]; |
| assert(cstr); |
| std::strncpy(cstr, str.c_str(), str.length()); |
| cstr[str.length()] = '\0'; |
| out->str = cstr; |
| out->length = str.length(); |
| return out; |
| } |
| |
| // Base class for MARK-V encoder and decoder. Contains common functionality |
| // such as: |
| // - Validator connection and validation state. |
| // - SPIR-V grammar and helper functions. |
| class MarkvCodecBase { |
| public: |
| virtual ~MarkvCodecBase() { |
| spvValidatorOptionsDestroy(validator_options_); |
| } |
| |
| MarkvCodecBase() = delete; |
| |
| void SetModel(const MarkvModel* model) { |
| model_ = model; |
| } |
| |
| protected: |
| struct MarkvHeader { |
| MarkvHeader() { |
| magic_number = kMarkvMagicNumber; |
| markv_version = GetMarkvVersion(); |
| markv_model = 0; |
| markv_length_in_bits = 0; |
| spirv_version = 0; |
| spirv_generator = 0; |
| } |
| |
| uint32_t magic_number; |
| uint32_t markv_version; |
| // Magic number to identify or verify MarkvModel used for encoding. |
| uint32_t markv_model; |
| uint32_t markv_length_in_bits; |
| uint32_t spirv_version; |
| uint32_t spirv_generator; |
| }; |
| |
| explicit MarkvCodecBase(spv_const_context context, |
| spv_validator_options validator_options) |
| : validator_options_(validator_options), |
| vstate_(context, validator_options_), grammar_(context), |
| model_(GetDefaultModel()) {} |
| |
| // Validates a single instruction and updates validation state of the module. |
| spv_result_t UpdateValidationState(const spv_parsed_instruction_t& inst) { |
| return ValidateInstructionAndUpdateValidationState(&vstate_, &inst); |
| } |
| |
| // Returns the current instruction (the one last processed by the validator). |
| const Instruction& GetCurrentInstruction() const { |
| return vstate_.ordered_instructions().back(); |
| } |
| |
| spv_validator_options validator_options_; |
| ValidationState_t vstate_; |
| const libspirv::AssemblyGrammar grammar_; |
| MarkvHeader header_; |
| const MarkvModel* model_; |
| |
| // Move-to-front list of all ids. |
| // TODO(atgoo@github.com) Consider a better move-to-front implementation. |
| std::list<uint32_t> move_to_front_ids_; |
| }; |
| |
| // SPIR-V to MARK-V encoder. Exposes functions EncodeHeader and |
| // EncodeInstruction which can be used as callback by spvBinaryParse. |
| // Encoded binary is written to an internally maintained bitstream. |
| // After the last instruction is encoded, the resulting MARK-V binary can be |
| // acquired by calling GetMarkvBinary(). |
| // The encoder uses SPIR-V validator to keep internal state, therefore |
| // SPIR-V binary needs to be able to pass validator checks. |
| // CreateCommentsLogger() can be used to enable the encoder to write comments |
| // on how encoding was done, which can later be accessed with GetComments(). |
| class MarkvEncoder : public MarkvCodecBase { |
| public: |
| MarkvEncoder(spv_const_context context, |
| spv_const_markv_encoder_options options) |
| : MarkvCodecBase(context, GetValidatorOptions(options)), |
| options_(options) { |
| (void) options_; |
| } |
| |
| // Writes data from SPIR-V header to MARK-V header. |
| spv_result_t EncodeHeader( |
| spv_endianness_t /* endian */, uint32_t /* magic */, |
| uint32_t version, uint32_t generator, uint32_t id_bound, |
| uint32_t /* schema */) { |
| vstate_.setIdBound(id_bound); |
| header_.spirv_version = version; |
| header_.spirv_generator = generator; |
| return SPV_SUCCESS; |
| } |
| |
| // Encodes SPIR-V instruction to MARK-V and writes to bit stream. |
| // Operation can fail if the instruction fails to pass the validator or if |
| // the encoder stubmles on something unexpected. |
| spv_result_t EncodeInstruction(const spv_parsed_instruction_t& inst); |
| |
| // Concatenates MARK-V header and the bit stream with encoded instructions |
| // into a single buffer and returns it as spv_markv_binary. The returned |
| // value is owned by the caller and needs to be destroyed with |
| // spvMarkvBinaryDestroy(). |
| spv_markv_binary GetMarkvBinary() { |
| header_.markv_length_in_bits = |
| static_cast<uint32_t>(sizeof(header_) * 8 + writer_.GetNumBits()); |
| const size_t num_bytes = sizeof(header_) + writer_.GetDataSizeBytes(); |
| |
| spv_markv_binary markv_binary = new spv_markv_binary_t(); |
| markv_binary->data = new uint8_t[num_bytes]; |
| markv_binary->length = num_bytes; |
| assert(writer_.GetData()); |
| std::memcpy(markv_binary->data, &header_, sizeof(header_)); |
| std::memcpy(markv_binary->data + sizeof(header_), |
| writer_.GetData(), writer_.GetDataSizeBytes()); |
| return markv_binary; |
| } |
| |
| // Creates an internal logger which writes comments on the encoding process. |
| // Output can later be accessed with GetComments(). |
| void CreateCommentsLogger() { |
| logger_.reset(new CommentLogger()); |
| writer_.SetCallback([this](const std::string& str){ |
| logger_->AppendBitSequence(str); |
| }); |
| } |
| |
| // Optionally adds disassembly to the comments. |
| // Disassembly should contain all instructions in the module separated by |
| // \n, and no header. |
| void SetDisassembly(std::string&& disassembly) { |
| disassembly_.reset(new std::stringstream(std::move(disassembly))); |
| } |
| |
| // Extracts the next instruction line from the disassembly and logs it. |
| void LogDisassemblyInstruction() { |
| if (logger_ && disassembly_) { |
| std::string line; |
| std::getline(*disassembly_, line, '\n'); |
| logger_->AppendTextNewLine(line); |
| } |
| } |
| |
| // Extracts the text from the comment logger. |
| std::string GetComments() const { |
| if (!logger_) |
| return ""; |
| return logger_->GetText(); |
| } |
| |
| private: |
| // Creates and returns validator options. Return value owned by the caller. |
| static spv_validator_options GetValidatorOptions( |
| spv_const_markv_encoder_options) { |
| return spvValidatorOptionsCreate(); |
| } |
| |
| // Writes a single word to bit stream. |type| determines if the word is |
| // encoded and how. |
| void EncodeOperandWord(spv_operand_type_t type, uint32_t word) { |
| const size_t chunk_length = |
| GetOperandVariableWidthChunkLength(type); |
| if (chunk_length) { |
| writer_.WriteVariableWidthU32(word, chunk_length); |
| } else { |
| writer_.WriteUnencoded(word); |
| } |
| } |
| |
| // Returns id index and updates move-to-front. |
| // Index is uint16 as SPIR-V module is guaranteed to have no more than 65535 |
| // instructions. |
| uint16_t GetIdIndex(uint32_t id) { |
| if (all_known_ids_.count(id)) { |
| uint16_t index = 0; |
| for (auto it = move_to_front_ids_.begin(); |
| it != move_to_front_ids_.end(); ++it) { |
| if (*it == id) { |
| if (index != 0) { |
| move_to_front_ids_.erase(it); |
| move_to_front_ids_.push_front(id); |
| } |
| return index; |
| } |
| ++index; |
| } |
| assert(0 && "Id not found in move_to_front_ids_"); |
| return 0; |
| } else { |
| all_known_ids_.insert(id); |
| move_to_front_ids_.push_front(id); |
| return static_cast<uint16_t>(move_to_front_ids_.size() - 1); |
| } |
| } |
| |
| void AddByteBreakIfAgreed() { |
| if (!ShouldByteBreak(writer_.GetNumBits())) |
| return; |
| |
| if (logger_) { |
| logger_->AppendWhitespaces(kCommentNumWhitespaces); |
| logger_->AppendText("ByteBreak:"); |
| } |
| |
| writer_.WriteBits(0, GetNumBitsToNextByte(writer_.GetNumBits())); |
| } |
| |
| // Encodes a literal number operand and writes it to the bit stream. |
| void EncodeLiteralNumber(const Instruction& instruction, |
| const spv_parsed_operand_t& operand); |
| |
| spv_const_markv_encoder_options options_; |
| |
| // Bit stream where encoded instructions are written. |
| BitWriterWord64 writer_; |
| |
| // If not nullptr, encoder will write comments. |
| std::unique_ptr<CommentLogger> logger_; |
| |
| // If not nullptr, disassembled instruction lines will be written to comments. |
| // Format: \n separated instruction lines, no header. |
| std::unique_ptr<std::stringstream> disassembly_; |
| |
| // All ids which were previosly encountered in the module. |
| std::unordered_set<uint32_t> all_known_ids_; |
| }; |
| |
| // Decodes MARK-V buffers written by MarkvEncoder. |
| class MarkvDecoder : public MarkvCodecBase { |
| public: |
| MarkvDecoder(spv_const_context context, |
| const uint8_t* markv_data, |
| size_t markv_size_bytes, |
| spv_const_markv_decoder_options options) |
| : MarkvCodecBase(context, GetValidatorOptions(options)), |
| options_(options), reader_(markv_data, markv_size_bytes) { |
| (void) options_; |
| vstate_.setIdBound(1); |
| parsed_operands_.reserve(25); |
| } |
| |
| // Decodes SPIR-V from MARK-V and stores the words in |spirv_binary|. |
| // Can be called only once. Fails if data of wrong format or ends prematurely, |
| // of if validation fails. |
| spv_result_t DecodeModule(std::vector<uint32_t>* spirv_binary); |
| |
| private: |
| // Describes the format of a typed literal number. |
| struct NumberType { |
| spv_number_kind_t type; |
| uint32_t bit_width; |
| }; |
| |
| // Creates and returns validator options. Return value owned by the caller. |
| static spv_validator_options GetValidatorOptions( |
| spv_const_markv_decoder_options) { |
| return spvValidatorOptionsCreate(); |
| } |
| |
| // Reads a single word from bit stream. |type| determines if the word needs |
| // to be decoded and how. Returns false if read fails. |
| bool DecodeOperandWord(spv_operand_type_t type, uint32_t* word) { |
| const size_t chunk_length = GetOperandVariableWidthChunkLength(type); |
| if (chunk_length) { |
| return reader_.ReadVariableWidthU32(word, chunk_length); |
| } else { |
| return reader_.ReadUnencoded(word); |
| } |
| } |
| |
| // Fetches the id from the move-to-front list and moves it to front. |
| uint32_t GetIdAndMoveToFront(uint16_t index) { |
| if (index >= move_to_front_ids_.size()) { |
| // Issue new id. |
| const uint32_t id = vstate_.getIdBound(); |
| move_to_front_ids_.push_front(id); |
| vstate_.setIdBound(id + 1); |
| return id; |
| } else { |
| if (index == 0) |
| return move_to_front_ids_.front(); |
| |
| // Iterate to index. |
| auto it = move_to_front_ids_.begin(); |
| for (size_t i = 0; i < index; ++i) |
| ++it; |
| const uint32_t id = *it; |
| move_to_front_ids_.erase(it); |
| move_to_front_ids_.push_front(id); |
| return id; |
| } |
| } |
| |
| // Decodes id index and fetches the id from move-to-front list. |
| bool DecodeId(uint32_t* id) { |
| uint16_t index = 0; |
| if (!reader_.ReadVariableWidthU16(&index, model_->id_index_chunk_length())) |
| return false; |
| |
| *id = GetIdAndMoveToFront(index); |
| return true; |
| } |
| |
| bool ReadToByteBreakIfAgreed() { |
| if (!ShouldByteBreak(reader_.GetNumReadBits())) |
| return true; |
| |
| uint64_t bits = 0; |
| if (!reader_.ReadBits(&bits, |
| GetNumBitsToNextByte(reader_.GetNumReadBits()))) |
| return false; |
| |
| if (bits != 0) |
| return false; |
| |
| return true; |
| } |
| |
| // Reads a literal number as it is described in |operand| from the bit stream, |
| // decodes and writes it to spirv_. |
| spv_result_t DecodeLiteralNumber(const spv_parsed_operand_t& operand); |
| |
| // Reads instruction from bit stream, decodes and validates it. |
| // Decoded instruction is valid until the next call of DecodeInstruction(). |
| spv_result_t DecodeInstruction(spv_parsed_instruction_t* inst); |
| |
| // Read operand from the stream decodes and validates it. |
| spv_result_t DecodeOperand(size_t instruction_offset, size_t operand_offset, |
| spv_parsed_instruction_t* inst, |
| const spv_operand_type_t type, |
| spv_operand_pattern_t* expected_operands, |
| bool read_result_id); |
| |
| // Records the numeric type for an operand according to the type information |
| // associated with the given non-zero type Id. This can fail if the type Id |
| // is not a type Id, or if the type Id does not reference a scalar numeric |
| // type. On success, return SPV_SUCCESS and populates the num_words, |
| // number_kind, and number_bit_width fields of parsed_operand. |
| spv_result_t SetNumericTypeInfoForType(spv_parsed_operand_t* parsed_operand, |
| uint32_t type_id); |
| |
| // Records the number type for the given instruction, if that |
| // instruction generates a type. For types that aren't scalar numbers, |
| // record something with number kind SPV_NUMBER_NONE. |
| void RecordNumberType(const spv_parsed_instruction_t& inst); |
| |
| spv_const_markv_decoder_options options_; |
| |
| // Temporary sink where decoded SPIR-V words are written. Once it contains the |
| // entire module, the container is moved and returned. |
| std::vector<uint32_t> spirv_; |
| |
| // Bit stream containing encoded data. |
| BitReaderWord64 reader_; |
| |
| // Temporary storage for operands of the currently parsed instruction. |
| // Valid until next DecodeInstruction call. |
| std::vector<spv_parsed_operand_t> parsed_operands_; |
| |
| // Maps a result ID to its type ID. By convention: |
| // - a result ID that is a type definition maps to itself. |
| // - a result ID without a type maps to 0. (E.g. for OpLabel) |
| std::unordered_map<uint32_t, uint32_t> id_to_type_id_; |
| // Maps a type ID to its number type description. |
| std::unordered_map<uint32_t, NumberType> type_id_to_number_type_info_; |
| // Maps an ExtInstImport id to the extended instruction type. |
| std::unordered_map<uint32_t, spv_ext_inst_type_t> import_id_to_ext_inst_type_; |
| }; |
| |
| void MarkvEncoder::EncodeLiteralNumber(const Instruction& instruction, |
| const spv_parsed_operand_t& operand) { |
| if (operand.number_bit_width == 32) { |
| const uint32_t word = instruction.word(operand.offset); |
| if (operand.number_kind == SPV_NUMBER_UNSIGNED_INT) { |
| writer_.WriteVariableWidthU32(word, model_->u32_chunk_length()); |
| } else if (operand.number_kind == SPV_NUMBER_SIGNED_INT) { |
| int32_t val = 0; |
| std::memcpy(&val, &word, 4); |
| writer_.WriteVariableWidthS32(val, model_->s32_chunk_length(), |
| model_->s32_block_exponent()); |
| } else if (operand.number_kind == SPV_NUMBER_FLOATING) { |
| writer_.WriteUnencoded(word); |
| } else { |
| assert(0); |
| } |
| } else if (operand.number_bit_width == 16) { |
| const uint16_t word = |
| static_cast<uint16_t>(instruction.word(operand.offset)); |
| if (operand.number_kind == SPV_NUMBER_UNSIGNED_INT) { |
| writer_.WriteVariableWidthU16(word, model_->u16_chunk_length()); |
| } else if (operand.number_kind == SPV_NUMBER_SIGNED_INT) { |
| int16_t val = 0; |
| std::memcpy(&val, &word, 2); |
| writer_.WriteVariableWidthS16(val, model_->s16_chunk_length(), |
| model_->s16_block_exponent()); |
| } else if (operand.number_kind == SPV_NUMBER_FLOATING) { |
| // TODO(atgoo@github.com) Write only 16 bits. |
| writer_.WriteUnencoded(word); |
| } else { |
| assert(0); |
| } |
| } else { |
| assert(operand.number_bit_width == 64); |
| const uint64_t word = |
| uint64_t(instruction.word(operand.offset)) | |
| (uint64_t(instruction.word(operand.offset + 1)) << 32); |
| if (operand.number_kind == SPV_NUMBER_UNSIGNED_INT) { |
| writer_.WriteVariableWidthU64(word, model_->u64_chunk_length()); |
| } else if (operand.number_kind == SPV_NUMBER_SIGNED_INT) { |
| int64_t val = 0; |
| std::memcpy(&val, &word, 8); |
| writer_.WriteVariableWidthS64(val, model_->s64_chunk_length(), |
| model_->s64_block_exponent()); |
| } else if (operand.number_kind == SPV_NUMBER_FLOATING) { |
| writer_.WriteUnencoded(word); |
| } else { |
| assert(0); |
| } |
| } |
| } |
| |
| spv_result_t MarkvEncoder::EncodeInstruction( |
| const spv_parsed_instruction_t& inst) { |
| const spv_result_t validation_result = UpdateValidationState(inst); |
| if (validation_result != SPV_SUCCESS) |
| return validation_result; |
| |
| bool result_id_was_forward_declared = false; |
| if (all_known_ids_.count(inst.result_id)) { |
| // Result id of the instruction was forward declared. |
| // Write a service opcode to signal this to the decoder. |
| writer_.WriteVariableWidthU32(kMarkvOpNextInstructionEncodesResultId, |
| model_->opcode_chunk_length()); |
| result_id_was_forward_declared = true; |
| } |
| |
| const Instruction& instruction = GetCurrentInstruction(); |
| const auto& operands = instruction.operands(); |
| |
| LogDisassemblyInstruction(); |
| |
| // Write opcode. |
| writer_.WriteVariableWidthU32(inst.opcode, model_->opcode_chunk_length()); |
| |
| if (!OpcodeHasFixedNumberOfOperands(SpvOp(inst.opcode))) { |
| // If the opcode has a variable number of operands, encode the number of |
| // operands with the instruction. |
| |
| if (logger_) |
| logger_->AppendWhitespaces(kCommentNumWhitespaces); |
| |
| writer_.WriteVariableWidthU16(inst.num_operands, |
| model_->num_operands_chunk_length()); |
| } |
| |
| // Write operands. |
| for (const auto& operand : operands) { |
| if (operand.type == SPV_OPERAND_TYPE_RESULT_ID && |
| !result_id_was_forward_declared) { |
| // Register the id, but don't encode it. |
| GetIdIndex(instruction.word(operand.offset)); |
| continue; |
| } |
| |
| if (logger_) |
| logger_->AppendWhitespaces(kCommentNumWhitespaces); |
| |
| if (operand.type == SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER) { |
| EncodeLiteralNumber(instruction, operand); |
| } else if (operand.type == SPV_OPERAND_TYPE_LITERAL_STRING) { |
| const char* src = |
| reinterpret_cast<const char*>(&instruction.words()[operand.offset]); |
| const size_t length = spv_strnlen_s(src, operand.num_words * 4); |
| if (length == operand.num_words * 4) |
| return vstate_.diag(SPV_ERROR_INVALID_BINARY) |
| << "Failed to find terminal character of literal string"; |
| for (size_t i = 0; i < length + 1; ++i) |
| writer_.WriteUnencoded(src[i]); |
| } else if (spvIsIdType(operand.type)) { |
| const uint16_t id_index = GetIdIndex(instruction.word(operand.offset)); |
| writer_.WriteVariableWidthU16(id_index, model_->id_index_chunk_length()); |
| } else { |
| for (int i = 0; i < operand.num_words; ++i) { |
| const uint32_t word = instruction.word(operand.offset + i); |
| EncodeOperandWord(operand.type, word); |
| } |
| } |
| } |
| |
| AddByteBreakIfAgreed(); |
| |
| if (logger_) { |
| logger_->NewLine(); |
| logger_->NewLine(); |
| } |
| |
| return SPV_SUCCESS; |
| } |
| |
| spv_result_t MarkvDecoder::DecodeLiteralNumber( |
| const spv_parsed_operand_t& operand) { |
| if (operand.number_bit_width == 32) { |
| uint32_t word = 0; |
| if (operand.number_kind == SPV_NUMBER_UNSIGNED_INT) { |
| if (!reader_.ReadVariableWidthU32(&word, model_->u32_chunk_length())) |
| return vstate_.diag(SPV_ERROR_INVALID_BINARY) |
| << "Failed to read literal U32"; |
| } else if (operand.number_kind == SPV_NUMBER_SIGNED_INT) { |
| int32_t val = 0; |
| if (!reader_.ReadVariableWidthS32(&val, model_->s32_chunk_length(), |
| model_->s32_block_exponent())) |
| return vstate_.diag(SPV_ERROR_INVALID_BINARY) |
| << "Failed to read literal S32"; |
| std::memcpy(&word, &val, 4); |
| } else if (operand.number_kind == SPV_NUMBER_FLOATING) { |
| if (!reader_.ReadUnencoded(&word)) |
| return vstate_.diag(SPV_ERROR_INVALID_BINARY) |
| << "Failed to read literal F32"; |
| } else { |
| assert(0); |
| } |
| spirv_.push_back(word); |
| } else if (operand.number_bit_width == 16) { |
| uint32_t word = 0; |
| if (operand.number_kind == SPV_NUMBER_UNSIGNED_INT) { |
| uint16_t val = 0; |
| if (!reader_.ReadVariableWidthU16(&val, model_->u16_chunk_length())) |
| return vstate_.diag(SPV_ERROR_INVALID_BINARY) |
| << "Failed to read literal U16"; |
| word = val; |
| } else if (operand.number_kind == SPV_NUMBER_SIGNED_INT) { |
| int16_t val = 0; |
| if (!reader_.ReadVariableWidthS16(&val, model_->s16_chunk_length(), |
| model_->s16_block_exponent())) |
| return vstate_.diag(SPV_ERROR_INVALID_BINARY) |
| << "Failed to read literal S16"; |
| // Int16 is stored as int32 in SPIR-V, not as bits. |
| int32_t val32 = val; |
| std::memcpy(&word, &val32, 4); |
| } else if (operand.number_kind == SPV_NUMBER_FLOATING) { |
| uint16_t word16 = 0; |
| if (!reader_.ReadUnencoded(&word16)) |
| return vstate_.diag(SPV_ERROR_INVALID_BINARY) |
| << "Failed to read literal F16"; |
| word = word16; |
| } else { |
| assert(0); |
| } |
| spirv_.push_back(word); |
| } else { |
| assert(operand.number_bit_width == 64); |
| uint64_t word = 0; |
| if (operand.number_kind == SPV_NUMBER_UNSIGNED_INT) { |
| if (!reader_.ReadVariableWidthU64(&word, model_->u64_chunk_length())) |
| return vstate_.diag(SPV_ERROR_INVALID_BINARY) |
| << "Failed to read literal U64"; |
| } else if (operand.number_kind == SPV_NUMBER_SIGNED_INT) { |
| int64_t val = 0; |
| if (!reader_.ReadVariableWidthS64(&val, model_->s64_chunk_length(), |
| model_->s64_block_exponent())) |
| return vstate_.diag(SPV_ERROR_INVALID_BINARY) |
| << "Failed to read literal S64"; |
| std::memcpy(&word, &val, 8); |
| } else if (operand.number_kind == SPV_NUMBER_FLOATING) { |
| if (!reader_.ReadUnencoded(&word)) |
| return vstate_.diag(SPV_ERROR_INVALID_BINARY) |
| << "Failed to read literal F64"; |
| } else { |
| assert(0); |
| } |
| spirv_.push_back(static_cast<uint32_t>(word)); |
| spirv_.push_back(static_cast<uint32_t>(word >> 32)); |
| } |
| return SPV_SUCCESS; |
| } |
| |
| spv_result_t MarkvDecoder::DecodeModule(std::vector<uint32_t>* spirv_binary) { |
| const bool header_read_success = |
| reader_.ReadUnencoded(&header_.magic_number) && |
| reader_.ReadUnencoded(&header_.markv_version) && |
| reader_.ReadUnencoded(&header_.markv_model) && |
| reader_.ReadUnencoded(&header_.markv_length_in_bits) && |
| reader_.ReadUnencoded(&header_.spirv_version) && |
| reader_.ReadUnencoded(&header_.spirv_generator); |
| |
| if (!header_read_success) |
| return vstate_.diag(SPV_ERROR_INVALID_BINARY) |
| << "Unable to read MARK-V header"; |
| |
| assert(header_.magic_number == kMarkvMagicNumber); |
| assert(header_.markv_length_in_bits > 0); |
| |
| if (header_.magic_number != kMarkvMagicNumber) |
| return vstate_.diag(SPV_ERROR_INVALID_BINARY) |
| << "MARK-V binary has incorrect magic number"; |
| |
| // TODO(atgoo@github.com): Print version strings. |
| if (header_.markv_version != GetMarkvVersion()) |
| return vstate_.diag(SPV_ERROR_INVALID_BINARY) |
| << "MARK-V binary and the codec have different versions"; |
| |
| spirv_.reserve(header_.markv_length_in_bits / 2); // Heuristic. |
| spirv_.resize(5, 0); |
| spirv_[0] = kSpirvMagicNumber; |
| spirv_[1] = header_.spirv_version; |
| spirv_[2] = header_.spirv_generator; |
| |
| while (reader_.GetNumReadBits() < header_.markv_length_in_bits) { |
| spv_parsed_instruction_t inst = {}; |
| const spv_result_t decode_result = DecodeInstruction(&inst); |
| if (decode_result != SPV_SUCCESS) |
| return decode_result; |
| |
| const spv_result_t validation_result = UpdateValidationState(inst); |
| if (validation_result != SPV_SUCCESS) |
| return validation_result; |
| } |
| |
| |
| if (reader_.GetNumReadBits() != header_.markv_length_in_bits || |
| !reader_.OnlyZeroesLeft()) { |
| return vstate_.diag(SPV_ERROR_INVALID_BINARY) |
| << "MARK-V binary has wrong stated bit length " |
| << reader_.GetNumReadBits() << " " << header_.markv_length_in_bits; |
| } |
| |
| // Decoding of the module is finished, validation state should have correct |
| // id bound. |
| spirv_[3] = vstate_.getIdBound(); |
| |
| *spirv_binary = std::move(spirv_); |
| return SPV_SUCCESS; |
| } |
| |
| // TODO(atgoo@github.com): The implementation borrows heavily from |
| // Parser::parseOperand. |
| // Consider coupling them together in some way once MARK-V codec is more mature. |
| // For now it's better to keep the code independent for experimentation |
| // purposes. |
| spv_result_t MarkvDecoder::DecodeOperand( |
| size_t instruction_offset, size_t operand_offset, |
| spv_parsed_instruction_t* inst, const spv_operand_type_t type, |
| spv_operand_pattern_t* expected_operands, |
| bool read_result_id) { |
| const SpvOp opcode = static_cast<SpvOp>(inst->opcode); |
| |
| spv_parsed_operand_t parsed_operand; |
| memset(&parsed_operand, 0, sizeof(parsed_operand)); |
| |
| assert((operand_offset >> 16) == 0); |
| parsed_operand.offset = static_cast<uint16_t>(operand_offset); |
| parsed_operand.type = type; |
| |
| // Set default values, may be updated later. |
| parsed_operand.number_kind = SPV_NUMBER_NONE; |
| parsed_operand.number_bit_width = 0; |
| |
| const size_t first_word_index = spirv_.size(); |
| |
| switch (type) { |
| case SPV_OPERAND_TYPE_TYPE_ID: { |
| if (!DecodeId(&inst->type_id)) { |
| return vstate_.diag(SPV_ERROR_INVALID_BINARY) |
| << "Failed to read type_id"; |
| } |
| |
| if (inst->type_id == 0) |
| return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Decoded type_id is 0"; |
| |
| spirv_.push_back(inst->type_id); |
| vstate_.setIdBound(std::max(vstate_.getIdBound(), inst->type_id + 1)); |
| break; |
| } |
| |
| case SPV_OPERAND_TYPE_RESULT_ID: { |
| if (read_result_id) { |
| if (!DecodeId(&inst->result_id)) |
| return vstate_.diag(SPV_ERROR_INVALID_BINARY) |
| << "Failed to read result_id"; |
| } else { |
| inst->result_id = vstate_.getIdBound(); |
| vstate_.setIdBound(inst->result_id + 1); |
| move_to_front_ids_.push_front(inst->result_id); |
| } |
| |
| spirv_.push_back(inst->result_id); |
| |
| // Save the result ID to type ID mapping. |
| // In the grammar, type ID always appears before result ID. |
| // A regular value maps to its type. Some instructions (e.g. OpLabel) |
| // have no type Id, and will map to 0. The result Id for a |
| // type-generating instruction (e.g. OpTypeInt) maps to itself. |
| auto insertion_result = id_to_type_id_.emplace( |
| inst->result_id, |
| spvOpcodeGeneratesType(opcode) ? inst->result_id : inst->type_id); |
| if(!insertion_result.second) { |
| return vstate_.diag(SPV_ERROR_INVALID_ID) |
| << "Unexpected behavior: id->type_id pair was already registered"; |
| } |
| break; |
| } |
| |
| case SPV_OPERAND_TYPE_ID: |
| case SPV_OPERAND_TYPE_OPTIONAL_ID: |
| case SPV_OPERAND_TYPE_SCOPE_ID: |
| case SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID: { |
| uint32_t id = 0; |
| if (!DecodeId(&id)) |
| return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Failed to read id"; |
| |
| if (id == 0) |
| return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Decoded id is 0"; |
| |
| spirv_.push_back(id); |
| vstate_.setIdBound(std::max(vstate_.getIdBound(), id + 1)); |
| |
| if (type == SPV_OPERAND_TYPE_ID || type == SPV_OPERAND_TYPE_OPTIONAL_ID) { |
| |
| parsed_operand.type = SPV_OPERAND_TYPE_ID; |
| |
| if (opcode == SpvOpExtInst && parsed_operand.offset == 3) { |
| // The current word is the extended instruction set id. |
| // Set the extended instruction set type for the current instruction. |
| auto ext_inst_type_iter = import_id_to_ext_inst_type_.find(id); |
| if (ext_inst_type_iter == import_id_to_ext_inst_type_.end()) { |
| return vstate_.diag(SPV_ERROR_INVALID_ID) |
| << "OpExtInst set id " << id |
| << " does not reference an OpExtInstImport result Id"; |
| } |
| inst->ext_inst_type = ext_inst_type_iter->second; |
| } |
| } |
| break; |
| } |
| |
| case SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER: { |
| uint32_t word = 0; |
| if (!DecodeOperandWord(type, &word)) |
| return vstate_.diag(SPV_ERROR_INVALID_BINARY) |
| << "Failed to read enum"; |
| |
| spirv_.push_back(word); |
| |
| assert(SpvOpExtInst == opcode); |
| assert(inst->ext_inst_type != SPV_EXT_INST_TYPE_NONE); |
| spv_ext_inst_desc ext_inst; |
| if (grammar_.lookupExtInst(inst->ext_inst_type, word, &ext_inst)) |
| return vstate_.diag(SPV_ERROR_INVALID_BINARY) |
| << "Invalid extended instruction number: " << word; |
| spvPushOperandTypes(ext_inst->operandTypes, expected_operands); |
| break; |
| } |
| |
| case SPV_OPERAND_TYPE_LITERAL_INTEGER: |
| case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_INTEGER: { |
| // These are regular single-word literal integer operands. |
| // Post-parsing validation should check the range of the parsed value. |
| parsed_operand.type = SPV_OPERAND_TYPE_LITERAL_INTEGER; |
| // It turns out they are always unsigned integers! |
| parsed_operand.number_kind = SPV_NUMBER_UNSIGNED_INT; |
| parsed_operand.number_bit_width = 32; |
| |
| uint32_t word = 0; |
| if (!DecodeOperandWord(type, &word)) |
| return vstate_.diag(SPV_ERROR_INVALID_BINARY) |
| << "Failed to read literal integer"; |
| |
| spirv_.push_back(word); |
| break; |
| } |
| |
| case SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER: |
| case SPV_OPERAND_TYPE_OPTIONAL_TYPED_LITERAL_INTEGER: |
| parsed_operand.type = SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER; |
| if (opcode == SpvOpSwitch) { |
| // The literal operands have the same type as the value |
| // referenced by the selector Id. |
| const uint32_t selector_id = spirv_.at(instruction_offset + 1); |
| const auto type_id_iter = id_to_type_id_.find(selector_id); |
| if (type_id_iter == id_to_type_id_.end() || |
| type_id_iter->second == 0) { |
| return vstate_.diag(SPV_ERROR_INVALID_BINARY) |
| << "Invalid OpSwitch: selector id " << selector_id |
| << " has no type"; |
| } |
| uint32_t type_id = type_id_iter->second; |
| |
| if (selector_id == type_id) { |
| // Recall that by convention, a result ID that is a type definition |
| // maps to itself. |
| return vstate_.diag(SPV_ERROR_INVALID_BINARY) |
| << "Invalid OpSwitch: selector id " << selector_id |
| << " is a type, not a value"; |
| } |
| if (auto error = SetNumericTypeInfoForType(&parsed_operand, type_id)) |
| return error; |
| if (parsed_operand.number_kind != SPV_NUMBER_UNSIGNED_INT && |
| parsed_operand.number_kind != SPV_NUMBER_SIGNED_INT) { |
| return vstate_.diag(SPV_ERROR_INVALID_BINARY) |
| << "Invalid OpSwitch: selector id " << selector_id |
| << " is not a scalar integer"; |
| } |
| } else { |
| assert(opcode == SpvOpConstant || opcode == SpvOpSpecConstant); |
| // The literal number type is determined by the type Id for the |
| // constant. |
| assert(inst->type_id); |
| if (auto error = |
| SetNumericTypeInfoForType(&parsed_operand, inst->type_id)) |
| return error; |
| } |
| |
| if (auto error = DecodeLiteralNumber(parsed_operand)) |
| return error; |
| |
| break; |
| |
| case SPV_OPERAND_TYPE_LITERAL_STRING: |
| case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_STRING: { |
| parsed_operand.type = SPV_OPERAND_TYPE_LITERAL_STRING; |
| std::vector<char> str; |
| // The loop is expected to terminate once we encounter '\0' or exhaust |
| // the bit stream. |
| while (true) { |
| char ch = 0; |
| if (!reader_.ReadUnencoded(&ch)) |
| return vstate_.diag(SPV_ERROR_INVALID_BINARY) |
| << "Failed to read literal string"; |
| |
| str.push_back(ch); |
| |
| if (ch == '\0') |
| break; |
| } |
| |
| while (str.size() % 4 != 0) |
| str.push_back('\0'); |
| |
| spirv_.resize(spirv_.size() + str.size() / 4); |
| std::memcpy(&spirv_[first_word_index], str.data(), str.size()); |
| |
| if (SpvOpExtInstImport == opcode) { |
| // Record the extended instruction type for the ID for this import. |
| // There is only one string literal argument to OpExtInstImport, |
| // so it's sufficient to guard this just on the opcode. |
| const spv_ext_inst_type_t ext_inst_type = |
| spvExtInstImportTypeGet(str.data()); |
| if (SPV_EXT_INST_TYPE_NONE == ext_inst_type) { |
| return vstate_.diag(SPV_ERROR_INVALID_BINARY) |
| << "Invalid extended instruction import '" << str.data() << "'"; |
| } |
| // We must have parsed a valid result ID. It's a condition |
| // of the grammar, and we only accept non-zero result Ids. |
| assert(inst->result_id); |
| const bool inserted = import_id_to_ext_inst_type_.emplace( |
| inst->result_id, ext_inst_type).second; |
| (void)inserted; |
| assert(inserted); |
| } |
| break; |
| } |
| |
| case SPV_OPERAND_TYPE_CAPABILITY: |
| case SPV_OPERAND_TYPE_SOURCE_LANGUAGE: |
| case SPV_OPERAND_TYPE_EXECUTION_MODEL: |
| case SPV_OPERAND_TYPE_ADDRESSING_MODEL: |
| case SPV_OPERAND_TYPE_MEMORY_MODEL: |
| case SPV_OPERAND_TYPE_EXECUTION_MODE: |
| case SPV_OPERAND_TYPE_STORAGE_CLASS: |
| case SPV_OPERAND_TYPE_DIMENSIONALITY: |
| case SPV_OPERAND_TYPE_SAMPLER_ADDRESSING_MODE: |
| case SPV_OPERAND_TYPE_SAMPLER_FILTER_MODE: |
| case SPV_OPERAND_TYPE_SAMPLER_IMAGE_FORMAT: |
| case SPV_OPERAND_TYPE_FP_ROUNDING_MODE: |
| case SPV_OPERAND_TYPE_LINKAGE_TYPE: |
| case SPV_OPERAND_TYPE_ACCESS_QUALIFIER: |
| case SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER: |
| case SPV_OPERAND_TYPE_FUNCTION_PARAMETER_ATTRIBUTE: |
| case SPV_OPERAND_TYPE_DECORATION: |
| case SPV_OPERAND_TYPE_BUILT_IN: |
| case SPV_OPERAND_TYPE_GROUP_OPERATION: |
| case SPV_OPERAND_TYPE_KERNEL_ENQ_FLAGS: |
| case SPV_OPERAND_TYPE_KERNEL_PROFILING_INFO: { |
| // A single word that is a plain enum value. |
| uint32_t word = 0; |
| if (!DecodeOperandWord(type, &word)) |
| return vstate_.diag(SPV_ERROR_INVALID_BINARY) |
| << "Failed to read enum"; |
| |
| spirv_.push_back(word); |
| |
| // Map an optional operand type to its corresponding concrete type. |
| if (type == SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER) |
| parsed_operand.type = SPV_OPERAND_TYPE_ACCESS_QUALIFIER; |
| |
| spv_operand_desc entry; |
| if (grammar_.lookupOperand(type, word, &entry)) { |
| return vstate_.diag(SPV_ERROR_INVALID_BINARY) |
| << "Invalid " |
| << spvOperandTypeStr(parsed_operand.type) |
| << " operand: " << word; |
| } |
| |
| // Prepare to accept operands to this operand, if needed. |
| spvPushOperandTypes(entry->operandTypes, expected_operands); |
| break; |
| } |
| |
| case SPV_OPERAND_TYPE_FP_FAST_MATH_MODE: |
| case SPV_OPERAND_TYPE_FUNCTION_CONTROL: |
| case SPV_OPERAND_TYPE_LOOP_CONTROL: |
| case SPV_OPERAND_TYPE_IMAGE: |
| case SPV_OPERAND_TYPE_OPTIONAL_IMAGE: |
| case SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS: |
| case SPV_OPERAND_TYPE_SELECTION_CONTROL: { |
| // This operand is a mask. |
| uint32_t word = 0; |
| if (!DecodeOperandWord(type, &word)) |
| return vstate_.diag(SPV_ERROR_INVALID_BINARY) |
| << "Failed to read " << spvOperandTypeStr(type) |
| << " for " << spvOpcodeString(SpvOp(inst->opcode)); |
| |
| spirv_.push_back(word); |
| |
| // Map an optional operand type to its corresponding concrete type. |
| if (type == SPV_OPERAND_TYPE_OPTIONAL_IMAGE) |
| parsed_operand.type = SPV_OPERAND_TYPE_IMAGE; |
| else if (type == SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS) |
| parsed_operand.type = SPV_OPERAND_TYPE_MEMORY_ACCESS; |
| |
| // Check validity of set mask bits. Also prepare for operands for those |
| // masks if they have any. To get operand order correct, scan from |
| // MSB to LSB since we can only prepend operands to a pattern. |
| // The only case in the grammar where you have more than one mask bit |
| // having an operand is for image operands. See SPIR-V 3.14 Image |
| // Operands. |
| uint32_t remaining_word = word; |
| for (uint32_t mask = (1u << 31); remaining_word; mask >>= 1) { |
| if (remaining_word & mask) { |
| spv_operand_desc entry; |
| if (grammar_.lookupOperand(type, mask, &entry)) { |
| return vstate_.diag(SPV_ERROR_INVALID_BINARY) |
| << "Invalid " << spvOperandTypeStr(parsed_operand.type) |
| << " operand: " << word << " has invalid mask component " |
| << mask; |
| } |
| remaining_word ^= mask; |
| spvPushOperandTypes(entry->operandTypes, expected_operands); |
| } |
| } |
| if (word == 0) { |
| // An all-zeroes mask *might* also be valid. |
| spv_operand_desc entry; |
| if (SPV_SUCCESS == grammar_.lookupOperand(type, 0, &entry)) { |
| // Prepare for its operands, if any. |
| spvPushOperandTypes(entry->operandTypes, expected_operands); |
| } |
| } |
| break; |
| } |
| default: |
| return vstate_.diag(SPV_ERROR_INVALID_BINARY) |
| << "Internal error: Unhandled operand type: " << type; |
| } |
| |
| parsed_operand.num_words = uint16_t(spirv_.size() - first_word_index); |
| |
| assert(int(SPV_OPERAND_TYPE_FIRST_CONCRETE_TYPE) <= int(parsed_operand.type)); |
| assert(int(SPV_OPERAND_TYPE_LAST_CONCRETE_TYPE) >= int(parsed_operand.type)); |
| |
| parsed_operands_.push_back(parsed_operand); |
| |
| return SPV_SUCCESS; |
| } |
| |
| spv_result_t MarkvDecoder::DecodeInstruction(spv_parsed_instruction_t* inst) { |
| parsed_operands_.clear(); |
| const size_t instruction_offset = spirv_.size(); |
| |
| bool read_result_id = false; |
| |
| while (true) { |
| uint32_t word = 0; |
| if (!reader_.ReadVariableWidthU32(&word, |
| model_->opcode_chunk_length())) { |
| return vstate_.diag(SPV_ERROR_INVALID_BINARY) |
| << "Failed to read opcode of instruction"; |
| } |
| |
| if (word >= kMarkvFirstOpcode) { |
| if (word == kMarkvOpNextInstructionEncodesResultId) { |
| read_result_id = true; |
| } else { |
| return vstate_.diag(SPV_ERROR_INVALID_BINARY) |
| << "Encountered unknown MARK-V opcode"; |
| } |
| } else { |
| inst->opcode = static_cast<uint16_t>(word); |
| break; |
| } |
| } |
| |
| const SpvOp opcode = static_cast<SpvOp>(inst->opcode); |
| |
| // Opcode/num_words placeholder, the word will be filled in later. |
| spirv_.push_back(0); |
| |
| spv_opcode_desc opcode_desc; |
| if (grammar_.lookupOpcode(opcode, &opcode_desc) |
| != SPV_SUCCESS) { |
| return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Invalid opcode"; |
| } |
| |
| spv_operand_pattern_t expected_operands; |
| expected_operands.reserve(opcode_desc->numTypes); |
| for (auto i = 0; i < opcode_desc->numTypes; i++) |
| expected_operands.push_back(opcode_desc->operandTypes[opcode_desc->numTypes - i - 1]); |
| |
| if (!OpcodeHasFixedNumberOfOperands(opcode)) { |
| if (!reader_.ReadVariableWidthU16(&inst->num_operands, |
| model_->num_operands_chunk_length())) |
| return vstate_.diag(SPV_ERROR_INVALID_BINARY) |
| << "Failed to read num_operands of instruction"; |
| } else { |
| inst->num_operands = static_cast<uint16_t>(expected_operands.size()); |
| } |
| |
| for (size_t operand_index = 0; |
| operand_index < static_cast<size_t>(inst->num_operands); |
| ++operand_index) { |
| assert(!expected_operands.empty()); |
| const spv_operand_type_t type = |
| spvTakeFirstMatchableOperand(&expected_operands); |
| |
| const size_t operand_offset = spirv_.size() - instruction_offset; |
| |
| const spv_result_t decode_result = |
| DecodeOperand(instruction_offset, operand_offset, inst, type, |
| &expected_operands, read_result_id); |
| |
| if (decode_result != SPV_SUCCESS) |
| return decode_result; |
| } |
| |
| assert(inst->num_operands == parsed_operands_.size()); |
| |
| // Only valid while spirv_ and parsed_operands_ remain unchanged. |
| inst->words = &spirv_[instruction_offset]; |
| inst->operands = parsed_operands_.empty() ? nullptr : parsed_operands_.data(); |
| inst->num_words = static_cast<uint16_t>(spirv_.size() - instruction_offset); |
| spirv_[instruction_offset] = |
| spvOpcodeMake(inst->num_words, SpvOp(inst->opcode)); |
| |
| assert(inst->num_words == std::accumulate( |
| parsed_operands_.begin(), parsed_operands_.end(), 1, |
| [](int num_words, const spv_parsed_operand_t& operand) { |
| return num_words += operand.num_words; |
| }) && "num_words in instruction doesn't correspond to the sum of num_words" |
| "in the operands"); |
| |
| RecordNumberType(*inst); |
| |
| if (!ReadToByteBreakIfAgreed()) |
| return vstate_.diag(SPV_ERROR_INVALID_BINARY) |
| << "Failed to read to byte break"; |
| |
| return SPV_SUCCESS; |
| } |
| |
| spv_result_t MarkvDecoder::SetNumericTypeInfoForType( |
| spv_parsed_operand_t* parsed_operand, uint32_t type_id) { |
| assert(type_id != 0); |
| auto type_info_iter = type_id_to_number_type_info_.find(type_id); |
| if (type_info_iter == type_id_to_number_type_info_.end()) { |
| return vstate_.diag(SPV_ERROR_INVALID_BINARY) |
| << "Type Id " << type_id << " is not a type"; |
| } |
| |
| const NumberType& info = type_info_iter->second; |
| if (info.type == SPV_NUMBER_NONE) { |
| // This is a valid type, but for something other than a scalar number. |
| return vstate_.diag(SPV_ERROR_INVALID_BINARY) |
| << "Type Id " << type_id << " is not a scalar numeric type"; |
| } |
| |
| parsed_operand->number_kind = info.type; |
| parsed_operand->number_bit_width = info.bit_width; |
| // Round up the word count. |
| parsed_operand->num_words = static_cast<uint16_t>((info.bit_width + 31) / 32); |
| return SPV_SUCCESS; |
| } |
| |
| void MarkvDecoder::RecordNumberType(const spv_parsed_instruction_t& inst) { |
| const SpvOp opcode = static_cast<SpvOp>(inst.opcode); |
| if (spvOpcodeGeneratesType(opcode)) { |
| NumberType info = {SPV_NUMBER_NONE, 0}; |
| if (SpvOpTypeInt == opcode) { |
| info.bit_width = inst.words[inst.operands[1].offset]; |
| info.type = inst.words[inst.operands[2].offset] ? |
| SPV_NUMBER_SIGNED_INT : SPV_NUMBER_UNSIGNED_INT; |
| } else if (SpvOpTypeFloat == opcode) { |
| info.bit_width = inst.words[inst.operands[1].offset]; |
| info.type = SPV_NUMBER_FLOATING; |
| } |
| // The *result* Id of a type generating instruction is the type Id. |
| type_id_to_number_type_info_[inst.result_id] = info; |
| } |
| } |
| |
| spv_result_t EncodeHeader( |
| void* user_data, spv_endianness_t endian, uint32_t magic, |
| uint32_t version, uint32_t generator, uint32_t id_bound, |
| uint32_t schema) { |
| MarkvEncoder* encoder = reinterpret_cast<MarkvEncoder*>(user_data); |
| return encoder->EncodeHeader( |
| endian, magic, version, generator, id_bound, schema); |
| } |
| |
| spv_result_t EncodeInstruction( |
| void* user_data, const spv_parsed_instruction_t* inst) { |
| MarkvEncoder* encoder = reinterpret_cast<MarkvEncoder*>(user_data); |
| return encoder->EncodeInstruction(*inst); |
| } |
| |
| } // namespace |
| |
| spv_result_t spvSpirvToMarkv(spv_const_context context, |
| const uint32_t* spirv_words, |
| const size_t spirv_num_words, |
| spv_const_markv_encoder_options options, |
| spv_markv_binary* markv_binary, |
| spv_text* comments, spv_diagnostic* diagnostic) { |
| spv_context_t hijack_context = *context; |
| if (diagnostic) { |
| *diagnostic = nullptr; |
| libspirv::UseDiagnosticAsMessageConsumer(&hijack_context, diagnostic); |
| } |
| |
| spv_const_binary_t spirv_binary = {spirv_words, spirv_num_words}; |
| |
| spv_endianness_t endian; |
| spv_position_t position = {}; |
| if (spvBinaryEndianness(&spirv_binary, &endian)) { |
| return libspirv::DiagnosticStream(position, hijack_context.consumer, |
| SPV_ERROR_INVALID_BINARY) |
| << "Invalid SPIR-V magic number."; |
| } |
| |
| spv_header_t header; |
| if (spvBinaryHeaderGet(&spirv_binary, endian, &header)) { |
| return libspirv::DiagnosticStream(position, hijack_context.consumer, |
| SPV_ERROR_INVALID_BINARY) |
| << "Invalid SPIR-V header."; |
| } |
| |
| MarkvEncoder encoder(&hijack_context, options); |
| |
| if (comments) { |
| encoder.CreateCommentsLogger(); |
| |
| spv_text text = nullptr; |
| if (spvBinaryToText(&hijack_context, spirv_words, spirv_num_words, |
| SPV_BINARY_TO_TEXT_OPTION_NO_HEADER, &text, nullptr) |
| != SPV_SUCCESS) { |
| return libspirv::DiagnosticStream(position, hijack_context.consumer, |
| SPV_ERROR_INVALID_BINARY) |
| << "Failed to disassemble SPIR-V binary."; |
| } |
| assert(text); |
| encoder.SetDisassembly(std::string(text->str, text->length)); |
| spvTextDestroy(text); |
| } |
| |
| if (spvBinaryParse( |
| &hijack_context, &encoder, spirv_words, spirv_num_words, EncodeHeader, |
| EncodeInstruction, diagnostic) != SPV_SUCCESS) { |
| return libspirv::DiagnosticStream(position, hijack_context.consumer, |
| SPV_ERROR_INVALID_BINARY) |
| << "Unable to encode to MARK-V."; |
| } |
| |
| if (comments) |
| *comments = CreateSpvText(encoder.GetComments()); |
| |
| *markv_binary = encoder.GetMarkvBinary(); |
| return SPV_SUCCESS; |
| } |
| |
| spv_result_t spvMarkvToSpirv(spv_const_context context, |
| const uint8_t* markv_data, |
| size_t markv_size_bytes, |
| spv_const_markv_decoder_options options, |
| spv_binary* spirv_binary, |
| spv_text* /* comments */, spv_diagnostic* diagnostic) { |
| spv_position_t position = {}; |
| spv_context_t hijack_context = *context; |
| if (diagnostic) { |
| *diagnostic = nullptr; |
| libspirv::UseDiagnosticAsMessageConsumer(&hijack_context, diagnostic); |
| } |
| |
| MarkvDecoder decoder(&hijack_context, markv_data, markv_size_bytes, options); |
| |
| std::vector<uint32_t> words; |
| |
| if (decoder.DecodeModule(&words) != SPV_SUCCESS) { |
| return libspirv::DiagnosticStream(position, hijack_context.consumer, |
| SPV_ERROR_INVALID_BINARY) |
| << "Unable to decode MARK-V."; |
| } |
| |
| assert(!words.empty()); |
| |
| *spirv_binary = new spv_binary_t(); |
| (*spirv_binary)->code = new uint32_t[words.size()]; |
| (*spirv_binary)->wordCount = words.size(); |
| std::memcpy((*spirv_binary)->code, words.data(), 4 * words.size()); |
| |
| return SPV_SUCCESS; |
| } |
| |
| void spvMarkvBinaryDestroy(spv_markv_binary binary) { |
| if (!binary) return; |
| delete[] binary->data; |
| delete binary; |
| } |
| |
| spv_markv_encoder_options spvMarkvEncoderOptionsCreate() { |
| return new spv_markv_encoder_options_t; |
| } |
| |
| void spvMarkvEncoderOptionsDestroy(spv_markv_encoder_options options) { |
| delete options; |
| } |
| |
| spv_markv_decoder_options spvMarkvDecoderOptionsCreate() { |
| return new spv_markv_decoder_options_t; |
| } |
| |
| void spvMarkvDecoderOptionsDestroy(spv_markv_decoder_options options) { |
| delete options; |
| } |