Unify memory instruction validation style (#1934)
* Rename ValidateMemoryInstructions to MemoryPass
* Changed functions to take pointer to an instruction instead of
reference
diff --git a/source/val/validate.cpp b/source/val/validate.cpp
index 84ec193..012e9d7 100644
--- a/source/val/validate.cpp
+++ b/source/val/validate.cpp
@@ -313,8 +313,7 @@
if (auto error = ModeSettingPass(*vstate, &instruction)) return error;
if (auto error = TypePass(*vstate, &instruction)) return error;
if (auto error = ConstantPass(*vstate, &instruction)) return error;
- if (auto error = ValidateMemoryInstructions(*vstate, &instruction))
- return error;
+ if (auto error = MemoryPass(*vstate, &instruction)) return error;
if (auto error = FunctionPass(*vstate, &instruction)) return error;
if (auto error = ImagePass(*vstate, &instruction)) return error;
if (auto error = ConversionPass(*vstate, &instruction)) return error;
diff --git a/source/val/validate.h b/source/val/validate.h
index 4599c4a..518547f 100644
--- a/source/val/validate.h
+++ b/source/val/validate.h
@@ -91,8 +91,7 @@
///
/// @param[in] _ the validation state of the module
/// @return SPV_SUCCESS if no errors are found.
-spv_result_t ValidateMemoryInstructions(ValidationState_t& _,
- const Instruction* inst);
+spv_result_t MemoryPass(ValidationState_t& _, const Instruction* inst);
/// @brief Updates the immediate dominator for each of the block edges
///
diff --git a/source/val/validate_memory.cpp b/source/val/validate_memory.cpp
index f1e10b3..3a9e1ae 100644
--- a/source/val/validate_memory.cpp
+++ b/source/val/validate_memory.cpp
@@ -180,18 +180,18 @@
return false;
}
-spv_result_t ValidateVariable(ValidationState_t& _, const Instruction& inst) {
- auto result_type = _.FindDef(inst.type_id());
+spv_result_t ValidateVariable(ValidationState_t& _, const Instruction* inst) {
+ auto result_type = _.FindDef(inst->type_id());
if (!result_type || result_type->opcode() != SpvOpTypePointer) {
- return _.diag(SPV_ERROR_INVALID_ID, &inst)
- << "OpVariable Result Type <id> '" << _.getIdName(inst.type_id())
+ return _.diag(SPV_ERROR_INVALID_ID, inst)
+ << "OpVariable Result Type <id> '" << _.getIdName(inst->type_id())
<< "' is not a pointer type.";
}
const auto initializer_index = 3;
const auto storage_class_index = 2;
- if (initializer_index < inst.operands().size()) {
- const auto initializer_id = inst.GetOperandAs<uint32_t>(initializer_index);
+ if (initializer_index < inst->operands().size()) {
+ const auto initializer_id = inst->GetOperandAs<uint32_t>(initializer_index);
const auto initializer = _.FindDef(initializer_id);
const auto is_module_scope_var =
initializer && (initializer->opcode() == SpvOpVariable) &&
@@ -200,13 +200,13 @@
const auto is_constant =
initializer && spvOpcodeIsConstant(initializer->opcode());
if (!initializer || !(is_constant || is_module_scope_var)) {
- return _.diag(SPV_ERROR_INVALID_ID, &inst)
+ return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "OpVariable Initializer <id> '" << _.getIdName(initializer_id)
<< "' is not a constant or module-scope variable.";
}
}
- const auto storage_class = inst.GetOperandAs<uint32_t>(storage_class_index);
+ const auto storage_class = inst->GetOperandAs<uint32_t>(storage_class_index);
if (storage_class != SpvStorageClassWorkgroup &&
storage_class != SpvStorageClassCrossWorkgroup &&
storage_class != SpvStorageClassPrivate &&
@@ -218,7 +218,7 @@
storage_class == SpvStorageClassOutput;
bool builtin = false;
if (storage_input_or_output) {
- for (const Decoration& decoration : _.id_decorations(inst.id())) {
+ for (const Decoration& decoration : _.id_decorations(inst->id())) {
if (decoration.dec_type() == SpvDecorationBuiltIn) {
builtin = true;
break;
@@ -227,7 +227,7 @@
}
if (!(storage_input_or_output && builtin) &&
ContainsInvalidBool(_, storage, storage_input_or_output)) {
- return _.diag(SPV_ERROR_INVALID_ID, &inst)
+ return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "If OpTypeBool is stored in conjunction with OpVariable, it "
<< "can only be used with non-externally visible shader Storage "
<< "Classes: Workgroup, CrossWorkgroup, Private, and Function";
@@ -236,11 +236,11 @@
return SPV_SUCCESS;
}
-spv_result_t ValidateLoad(ValidationState_t& _, const Instruction& inst) {
- const auto result_type = _.FindDef(inst.type_id());
+spv_result_t ValidateLoad(ValidationState_t& _, const Instruction* inst) {
+ const auto result_type = _.FindDef(inst->type_id());
if (!result_type) {
- return _.diag(SPV_ERROR_INVALID_ID, &inst)
- << "OpLoad Result Type <id> '" << _.getIdName(inst.type_id())
+ return _.diag(SPV_ERROR_INVALID_ID, inst)
+ << "OpLoad Result Type <id> '" << _.getIdName(inst->type_id())
<< "' is not defined.";
}
@@ -248,7 +248,7 @@
_.features().variable_pointers ||
_.features().variable_pointers_storage_buffer;
const auto pointer_index = 2;
- const auto pointer_id = inst.GetOperandAs<uint32_t>(pointer_index);
+ const auto pointer_id = inst->GetOperandAs<uint32_t>(pointer_index);
const auto pointer = _.FindDef(pointer_id);
if (!pointer ||
((_.addressing_model() == SpvAddressingModelLogical) &&
@@ -256,22 +256,22 @@
!spvOpcodeReturnsLogicalPointer(pointer->opcode())) ||
(uses_variable_pointers &&
!spvOpcodeReturnsLogicalVariablePointer(pointer->opcode()))))) {
- return _.diag(SPV_ERROR_INVALID_ID, &inst)
+ return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "OpLoad Pointer <id> '" << _.getIdName(pointer_id)
<< "' is not a logical pointer.";
}
const auto pointer_type = _.FindDef(pointer->type_id());
if (!pointer_type || pointer_type->opcode() != SpvOpTypePointer) {
- return _.diag(SPV_ERROR_INVALID_ID, &inst)
+ return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "OpLoad type for pointer <id> '" << _.getIdName(pointer_id)
<< "' is not a pointer type.";
}
const auto pointee_type = _.FindDef(pointer_type->GetOperandAs<uint32_t>(2));
if (!pointee_type || result_type->id() != pointee_type->id()) {
- return _.diag(SPV_ERROR_INVALID_ID, &inst)
- << "OpLoad Result Type <id> '" << _.getIdName(inst.type_id())
+ return _.diag(SPV_ERROR_INVALID_ID, inst)
+ << "OpLoad Result Type <id> '" << _.getIdName(inst->type_id())
<< "' does not match Pointer <id> '" << _.getIdName(pointer->id())
<< "'s type.";
}
@@ -279,12 +279,12 @@
return SPV_SUCCESS;
}
-spv_result_t ValidateStore(ValidationState_t& _, const Instruction& inst) {
+spv_result_t ValidateStore(ValidationState_t& _, const Instruction* inst) {
const bool uses_variable_pointer =
_.features().variable_pointers ||
_.features().variable_pointers_storage_buffer;
const auto pointer_index = 0;
- const auto pointer_id = inst.GetOperandAs<uint32_t>(pointer_index);
+ const auto pointer_id = inst->GetOperandAs<uint32_t>(pointer_index);
const auto pointer = _.FindDef(pointer_id);
if (!pointer ||
(_.addressing_model() == SpvAddressingModelLogical &&
@@ -292,20 +292,20 @@
!spvOpcodeReturnsLogicalPointer(pointer->opcode())) ||
(uses_variable_pointer &&
!spvOpcodeReturnsLogicalVariablePointer(pointer->opcode()))))) {
- return _.diag(SPV_ERROR_INVALID_ID, &inst)
+ return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "OpStore Pointer <id> '" << _.getIdName(pointer_id)
<< "' is not a logical pointer.";
}
const auto pointer_type = _.FindDef(pointer->type_id());
if (!pointer_type || pointer_type->opcode() != SpvOpTypePointer) {
- return _.diag(SPV_ERROR_INVALID_ID, &inst)
+ return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "OpStore type for pointer <id> '" << _.getIdName(pointer_id)
<< "' is not a pointer type.";
}
const auto type_id = pointer_type->GetOperandAs<uint32_t>(2);
const auto type = _.FindDef(type_id);
if (!type || SpvOpTypeVoid == type->opcode()) {
- return _.diag(SPV_ERROR_INVALID_ID, &inst)
+ return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "OpStore Pointer <id> '" << _.getIdName(pointer_id)
<< "'s type is void.";
}
@@ -315,7 +315,7 @@
uint32_t data_type;
uint32_t storage_class;
if (!_.GetPointerTypeInfo(pointer_type->id(), &data_type, &storage_class)) {
- return _.diag(SPV_ERROR_INVALID_ID, &inst)
+ return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "OpStore Pointer <id> '" << _.getIdName(pointer_id)
<< "' is not pointer type";
}
@@ -323,23 +323,23 @@
if (storage_class == SpvStorageClassUniformConstant ||
storage_class == SpvStorageClassInput ||
storage_class == SpvStorageClassPushConstant) {
- return _.diag(SPV_ERROR_INVALID_ID, &inst)
+ return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "OpStore Pointer <id> '" << _.getIdName(pointer_id)
<< "' storage class is read-only";
}
}
const auto object_index = 1;
- const auto object_id = inst.GetOperandAs<uint32_t>(object_index);
+ const auto object_id = inst->GetOperandAs<uint32_t>(object_index);
const auto object = _.FindDef(object_id);
if (!object || !object->type_id()) {
- return _.diag(SPV_ERROR_INVALID_ID, &inst)
+ return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "OpStore Object <id> '" << _.getIdName(object_id)
<< "' is not an object.";
}
const auto object_type = _.FindDef(object->type_id());
if (!object_type || SpvOpTypeVoid == object_type->opcode()) {
- return _.diag(SPV_ERROR_INVALID_ID, &inst)
+ return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "OpStore Object <id> '" << _.getIdName(object_id)
<< "'s type is void.";
}
@@ -347,7 +347,7 @@
if (type->id() != object_type->id()) {
if (!_.options()->relax_struct_store || type->opcode() != SpvOpTypeStruct ||
object_type->opcode() != SpvOpTypeStruct) {
- return _.diag(SPV_ERROR_INVALID_ID, &inst)
+ return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "OpStore Pointer <id> '" << _.getIdName(pointer_id)
<< "'s type does not match Object <id> '"
<< _.getIdName(object->id()) << "'s type.";
@@ -355,7 +355,7 @@
// TODO: Check for layout compatible matricies and arrays as well.
if (!AreLayoutCompatibleStructs(_, type, object_type)) {
- return _.diag(SPV_ERROR_INVALID_ID, &inst)
+ return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "OpStore Pointer <id> '" << _.getIdName(pointer_id)
<< "'s layout does not match Object <id> '"
<< _.getIdName(object->id()) << "'s layout.";
@@ -364,21 +364,21 @@
return SPV_SUCCESS;
}
-spv_result_t ValidateCopyMemory(ValidationState_t& _, const Instruction& inst) {
+spv_result_t ValidateCopyMemory(ValidationState_t& _, const Instruction* inst) {
const auto target_index = 0;
- const auto target_id = inst.GetOperandAs<uint32_t>(target_index);
+ const auto target_id = inst->GetOperandAs<uint32_t>(target_index);
const auto target = _.FindDef(target_id);
if (!target) {
- return _.diag(SPV_ERROR_INVALID_ID, &inst)
+ return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "Target operand <id> '" << _.getIdName(target_id)
<< "' is not defined.";
}
const auto source_index = 1;
- const auto source_id = inst.GetOperandAs<uint32_t>(source_index);
+ const auto source_id = inst->GetOperandAs<uint32_t>(source_index);
const auto source = _.FindDef(source_id);
if (!source) {
- return _.diag(SPV_ERROR_INVALID_ID, &inst)
+ return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "Source operand <id> '" << _.getIdName(source_id)
<< "' is not defined.";
}
@@ -386,7 +386,7 @@
const auto target_pointer_type = _.FindDef(target->type_id());
if (!target_pointer_type ||
target_pointer_type->opcode() != SpvOpTypePointer) {
- return _.diag(SPV_ERROR_INVALID_ID, &inst)
+ return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "Target operand <id> '" << _.getIdName(target_id)
<< "' is not a pointer.";
}
@@ -394,16 +394,16 @@
const auto source_pointer_type = _.FindDef(source->type_id());
if (!source_pointer_type ||
source_pointer_type->opcode() != SpvOpTypePointer) {
- return _.diag(SPV_ERROR_INVALID_ID, &inst)
+ return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "Source operand <id> '" << _.getIdName(source_id)
<< "' is not a pointer.";
}
- if (inst.opcode() == SpvOpCopyMemory) {
+ if (inst->opcode() == SpvOpCopyMemory) {
const auto target_type =
_.FindDef(target_pointer_type->GetOperandAs<uint32_t>(2));
if (!target_type || target_type->opcode() == SpvOpTypeVoid) {
- return _.diag(SPV_ERROR_INVALID_ID, &inst)
+ return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "Target operand <id> '" << _.getIdName(target_id)
<< "' cannot be a void pointer.";
}
@@ -411,29 +411,29 @@
const auto source_type =
_.FindDef(source_pointer_type->GetOperandAs<uint32_t>(2));
if (!source_type || source_type->opcode() == SpvOpTypeVoid) {
- return _.diag(SPV_ERROR_INVALID_ID, &inst)
+ return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "Source operand <id> '" << _.getIdName(source_id)
<< "' cannot be a void pointer.";
}
if (target_type->id() != source_type->id()) {
- return _.diag(SPV_ERROR_INVALID_ID, &inst)
+ return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "Target <id> '" << _.getIdName(source_id)
<< "'s type does not match Source <id> '"
<< _.getIdName(source_type->id()) << "'s type.";
}
} else {
- const auto size_id = inst.GetOperandAs<uint32_t>(2);
+ const auto size_id = inst->GetOperandAs<uint32_t>(2);
const auto size = _.FindDef(size_id);
if (!size) {
- return _.diag(SPV_ERROR_INVALID_ID, &inst)
+ return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "Size operand <id> '" << _.getIdName(size_id)
<< "' is not defined.";
}
const auto size_type = _.FindDef(size->type_id());
if (!_.IsIntScalarType(size_type->id())) {
- return _.diag(SPV_ERROR_INVALID_ID, &inst)
+ return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "Size operand <id> '" << _.getIdName(size_id)
<< "' must be a scalar integer type.";
}
@@ -441,13 +441,13 @@
bool is_zero = true;
switch (size->opcode()) {
case SpvOpConstantNull:
- return _.diag(SPV_ERROR_INVALID_ID, &inst)
+ return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "Size operand <id> '" << _.getIdName(size_id)
<< "' cannot be a constant zero.";
case SpvOpConstant:
if (size_type->word(3) == 1 &&
size->word(size->words().size() - 1) & 0x80000000) {
- return _.diag(SPV_ERROR_INVALID_ID, &inst)
+ return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "Size operand <id> '" << _.getIdName(size_id)
<< "' cannot have the sign bit set to 1.";
}
@@ -455,7 +455,7 @@
is_zero &= (size->word(i) == 0);
}
if (is_zero) {
- return _.diag(SPV_ERROR_INVALID_ID, &inst)
+ return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "Size operand <id> '" << _.getIdName(size_id)
<< "' cannot be a constant zero.";
}
@@ -469,16 +469,16 @@
}
spv_result_t ValidateAccessChain(ValidationState_t& _,
- const Instruction& inst) {
+ const Instruction* inst) {
std::string instr_name =
- "Op" + std::string(spvOpcodeString(static_cast<SpvOp>(inst.opcode())));
+ "Op" + std::string(spvOpcodeString(static_cast<SpvOp>(inst->opcode())));
// The result type must be OpTypePointer.
- auto result_type = _.FindDef(inst.type_id());
+ auto result_type = _.FindDef(inst->type_id());
if (SpvOpTypePointer != result_type->opcode()) {
- return _.diag(SPV_ERROR_INVALID_ID, &inst)
+ return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "The Result Type of " << instr_name << " <id> '"
- << _.getIdName(inst.id()) << "' must be OpTypePointer. Found Op"
+ << _.getIdName(inst->id()) << "' must be OpTypePointer. Found Op"
<< spvOpcodeString(static_cast<SpvOp>(result_type->opcode())) << ".";
}
@@ -489,11 +489,11 @@
// Base must be a pointer, pointing to the base of a composite object.
const auto base_index = 2;
- const auto base_id = inst.GetOperandAs<uint32_t>(base_index);
+ const auto base_id = inst->GetOperandAs<uint32_t>(base_index);
const auto base = _.FindDef(base_id);
const auto base_type = _.FindDef(base->type_id());
if (!base_type || SpvOpTypePointer != base_type->opcode()) {
- return _.diag(SPV_ERROR_INVALID_ID, &inst)
+ return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "The Base <id> '" << _.getIdName(base_id) << "' in " << instr_name
<< " instruction must be a pointer.";
}
@@ -503,7 +503,7 @@
auto result_type_storage_class = result_type->word(2);
auto base_type_storage_class = base_type->word(2);
if (result_type_storage_class != base_type_storage_class) {
- return _.diag(SPV_ERROR_INVALID_ID, &inst)
+ return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "The result pointer storage class and base "
"pointer storage class in "
<< instr_name << " do not match.";
@@ -515,9 +515,9 @@
// Check Universal Limit (SPIR-V Spec. Section 2.17).
// The number of indexes passed to OpAccessChain may not exceed 255
// The instruction includes 4 words + N words (for N indexes)
- size_t num_indexes = inst.words().size() - 4;
- if (inst.opcode() == SpvOpPtrAccessChain ||
- inst.opcode() == SpvOpInBoundsPtrAccessChain) {
+ size_t num_indexes = inst->words().size() - 4;
+ if (inst->opcode() == SpvOpPtrAccessChain ||
+ inst->opcode() == SpvOpInBoundsPtrAccessChain) {
// In pointer access chains, the element operand is required, but not
// counted as an index.
--num_indexes;
@@ -525,7 +525,7 @@
const size_t num_indexes_limit =
_.options()->universal_limits_.max_access_chain_indexes;
if (num_indexes > num_indexes_limit) {
- return _.diag(SPV_ERROR_INVALID_ID, &inst)
+ return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "The number of indexes in " << instr_name << " may not exceed "
<< num_indexes_limit << ". Found " << num_indexes << " indexes.";
}
@@ -537,18 +537,18 @@
// on. Once any non-composite type is reached, there must be no remaining
// (unused) indexes.
auto starting_index = 4;
- if (inst.opcode() == SpvOpPtrAccessChain ||
- inst.opcode() == SpvOpInBoundsPtrAccessChain) {
+ if (inst->opcode() == SpvOpPtrAccessChain ||
+ inst->opcode() == SpvOpInBoundsPtrAccessChain) {
++starting_index;
}
- for (size_t i = starting_index; i < inst.words().size(); ++i) {
- const uint32_t cur_word = inst.words()[i];
+ for (size_t i = starting_index; i < inst->words().size(); ++i) {
+ const uint32_t cur_word = inst->words()[i];
// Earlier ID checks ensure that cur_word definition exists.
auto cur_word_instr = _.FindDef(cur_word);
// The index must be a scalar integer type (See OpAccessChain in the Spec.)
auto index_type = _.FindDef(cur_word_instr->type_id());
if (!index_type || SpvOpTypeInt != index_type->opcode()) {
- return _.diag(SPV_ERROR_INVALID_ID, &inst)
+ return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "Indexes passed to " << instr_name
<< " must be of type integer.";
}
@@ -607,7 +607,7 @@
// At this point, we have fully walked down from the base using the indeces.
// The type being pointed to should be the same as the result type.
if (type_pointee->id() != result_type_pointee->id()) {
- return _.diag(SPV_ERROR_INVALID_ID, &inst)
+ return _.diag(SPV_ERROR_INVALID_ID, inst)
<< instr_name << " result type (Op"
<< spvOpcodeString(static_cast<SpvOp>(result_type_pointee->opcode()))
<< ") does not match the type that results from indexing into the "
@@ -622,27 +622,26 @@
} // namespace
-spv_result_t ValidateMemoryInstructions(ValidationState_t& _,
- const Instruction* inst) {
+spv_result_t MemoryPass(ValidationState_t& _, const Instruction* inst) {
switch (inst->opcode()) {
case SpvOpVariable:
- if (auto error = ValidateVariable(_, *inst)) return error;
+ if (auto error = ValidateVariable(_, inst)) return error;
break;
case SpvOpLoad:
- if (auto error = ValidateLoad(_, *inst)) return error;
+ if (auto error = ValidateLoad(_, inst)) return error;
break;
case SpvOpStore:
- if (auto error = ValidateStore(_, *inst)) return error;
+ if (auto error = ValidateStore(_, inst)) return error;
break;
case SpvOpCopyMemory:
case SpvOpCopyMemorySized:
- if (auto error = ValidateCopyMemory(_, *inst)) return error;
+ if (auto error = ValidateCopyMemory(_, inst)) return error;
break;
case SpvOpAccessChain:
case SpvOpInBoundsAccessChain:
case SpvOpPtrAccessChain:
case SpvOpInBoundsPtrAccessChain:
- if (auto error = ValidateAccessChain(_, *inst)) return error;
+ if (auto error = ValidateAccessChain(_, inst)) return error;
break;
case SpvOpImageTexelPointer:
case SpvOpArrayLength: