Remove validate_datarules.cpp (#2911)
* Checks moved into individual opcode validation
* removes duplicated checks
* Add check that forward pointer points to struct
diff --git a/Android.mk b/Android.mk
index 2eb4368..cb7062c 100644
--- a/Android.mk
+++ b/Android.mk
@@ -49,7 +49,6 @@
source/val/validate_composites.cpp \
source/val/validate_constants.cpp \
source/val/validate_conversion.cpp \
- source/val/validate_datarules.cpp \
source/val/validate_debug.cpp \
source/val/validate_decorations.cpp \
source/val/validate_derivatives.cpp \
diff --git a/BUILD.gn b/BUILD.gn
index b1ed713..a8bd134 100644
--- a/BUILD.gn
+++ b/BUILD.gn
@@ -410,7 +410,6 @@
"source/val/validate_composites.cpp",
"source/val/validate_constants.cpp",
"source/val/validate_conversion.cpp",
- "source/val/validate_datarules.cpp",
"source/val/validate_debug.cpp",
"source/val/validate_decorations.cpp",
"source/val/validate_derivatives.cpp",
diff --git a/source/CMakeLists.txt b/source/CMakeLists.txt
index 37348e9..a083995 100644
--- a/source/CMakeLists.txt
+++ b/source/CMakeLists.txt
@@ -283,7 +283,6 @@
${CMAKE_CURRENT_SOURCE_DIR}/val/validate_composites.cpp
${CMAKE_CURRENT_SOURCE_DIR}/val/validate_constants.cpp
${CMAKE_CURRENT_SOURCE_DIR}/val/validate_conversion.cpp
- ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_datarules.cpp
${CMAKE_CURRENT_SOURCE_DIR}/val/validate_debug.cpp
${CMAKE_CURRENT_SOURCE_DIR}/val/validate_decorations.cpp
${CMAKE_CURRENT_SOURCE_DIR}/val/validate_derivatives.cpp
diff --git a/source/val/validate.cpp b/source/val/validate.cpp
index e4123a4..7f4b0dc 100644
--- a/source/val/validate.cpp
+++ b/source/val/validate.cpp
@@ -343,7 +343,6 @@
}
if (auto error = CapabilityPass(*vstate, &instruction)) return error;
- if (auto error = DataRulesPass(*vstate, &instruction)) return error;
if (auto error = ModuleLayoutPass(*vstate, &instruction)) return error;
if (auto error = CfgPass(*vstate, &instruction)) return error;
if (auto error = InstructionPass(*vstate, &instruction)) return error;
@@ -352,6 +351,9 @@
{
Instruction* inst = const_cast<Instruction*>(&instruction);
vstate->RegisterInstruction(inst);
+ if (inst->opcode() == SpvOpTypeForwardPointer) {
+ vstate->RegisterForwardPointer(inst->GetOperandAs<uint32_t>(0));
+ }
}
}
diff --git a/source/val/validate.h b/source/val/validate.h
index da3d0b8..31a775b 100644
--- a/source/val/validate.h
+++ b/source/val/validate.h
@@ -123,11 +123,6 @@
/// Performs Id and SSA validation of a module
spv_result_t IdPass(ValidationState_t& _, Instruction* inst);
-/// Performs validation of the Data Rules subsection of 2.16.1 Universal
-/// Validation Rules.
-/// TODO(ehsann): add more comments here as more validation code is added.
-spv_result_t DataRulesPass(ValidationState_t& _, const Instruction* inst);
-
/// Performs instruction validation.
spv_result_t InstructionPass(ValidationState_t& _, const Instruction* inst);
diff --git a/source/val/validate_builtins.cpp b/source/val/validate_builtins.cpp
index 42497cf..7623d49 100644
--- a/source/val/validate_builtins.cpp
+++ b/source/val/validate_builtins.cpp
@@ -2504,20 +2504,20 @@
switch (execution_model) {
case SpvExecutionModelGeometry:
case SpvExecutionModelFragment:
- case SpvExecutionModelMeshNV: {
+ case SpvExecutionModelMeshNV:
// Ok.
break;
- case SpvExecutionModelVertex:
- case SpvExecutionModelTessellationEvaluation:
- if (!_.HasCapability(SpvCapabilityShaderViewportIndexLayerEXT)) {
- return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst)
- << "Using BuiltIn "
- << _.grammar().lookupOperandName(SPV_OPERAND_TYPE_BUILT_IN,
- decoration.params()[0])
- << " in Vertex or Tessellation execution model requires "
- "the ShaderViewportIndexLayerEXT capability.";
- }
- break;
+ case SpvExecutionModelVertex:
+ case SpvExecutionModelTessellationEvaluation: {
+ if (!_.HasCapability(SpvCapabilityShaderViewportIndexLayerEXT)) {
+ return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst)
+ << "Using BuiltIn "
+ << _.grammar().lookupOperandName(SPV_OPERAND_TYPE_BUILT_IN,
+ decoration.params()[0])
+ << " in Vertex or Tessellation execution model requires "
+ "the ShaderViewportIndexLayerEXT capability.";
+ }
+ break;
}
default: {
return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst)
diff --git a/source/val/validate_constants.cpp b/source/val/validate_constants.cpp
index 565518b..95ec168 100644
--- a/source/val/validate_constants.cpp
+++ b/source/val/validate_constants.cpp
@@ -342,6 +342,21 @@
return SPV_SUCCESS;
}
+// Validates that OpSpecConstant specializes to either int or float type.
+spv_result_t ValidateSpecConstant(ValidationState_t& _,
+ const Instruction* inst) {
+ // Operand 0 is the <id> of the type that we're specializing to.
+ auto type_id = inst->GetOperandAs<const uint32_t>(0);
+ auto type_instruction = _.FindDef(type_id);
+ auto type_opcode = type_instruction->opcode();
+ if (type_opcode != SpvOpTypeInt && type_opcode != SpvOpTypeFloat) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Specialization constant "
+ "must be an integer or "
+ "floating-point number.";
+ }
+ return SPV_SUCCESS;
+}
+
spv_result_t ValidateSpecConstantOp(ValidationState_t& _,
const Instruction* inst) {
const auto op = inst->GetOperandAs<SpvOp>(2);
@@ -422,6 +437,9 @@
case SpvOpConstantNull:
if (auto error = ValidateConstantNull(_, inst)) return error;
break;
+ case SpvOpSpecConstant:
+ if (auto error = ValidateSpecConstant(_, inst)) return error;
+ break;
case SpvOpSpecConstantOp:
if (auto error = ValidateSpecConstantOp(_, inst)) return error;
break;
diff --git a/source/val/validate_datarules.cpp b/source/val/validate_datarules.cpp
deleted file mode 100644
index 826eb8d..0000000
--- a/source/val/validate_datarules.cpp
+++ /dev/null
@@ -1,286 +0,0 @@
-// Copyright (c) 2016 Google Inc.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Ensures Data Rules are followed according to the specifications.
-
-#include "source/val/validate.h"
-
-#include <cassert>
-#include <sstream>
-#include <string>
-
-#include "source/diagnostic.h"
-#include "source/opcode.h"
-#include "source/operand.h"
-#include "source/val/instruction.h"
-#include "source/val/validation_state.h"
-
-namespace spvtools {
-namespace val {
-namespace {
-
-// Validates that the number of components in the vector is valid.
-// Vector types can only be parameterized as having 2, 3, or 4 components.
-// If the Vector16 capability is added, 8 and 16 components are also allowed.
-spv_result_t ValidateVecNumComponents(ValidationState_t& _,
- const Instruction* inst) {
- // Operand 2 specifies the number of components in the vector.
- auto num_components = inst->GetOperandAs<const uint32_t>(2);
- if (num_components == 2 || num_components == 3 || num_components == 4) {
- return SPV_SUCCESS;
- }
- if (num_components == 8 || num_components == 16) {
- if (_.HasCapability(SpvCapabilityVector16)) {
- return SPV_SUCCESS;
- }
- return _.diag(SPV_ERROR_INVALID_DATA, inst)
- << "Having " << num_components << " components for "
- << spvOpcodeString(inst->opcode())
- << " requires the Vector16 capability";
- }
- return _.diag(SPV_ERROR_INVALID_DATA, inst)
- << "Illegal number of components (" << num_components << ") for "
- << spvOpcodeString(inst->opcode());
-}
-
-// Validates that the number of bits specifed for a float type is valid.
-// Scalar floating-point types can be parameterized only with 32-bits.
-// Float16 capability allows using a 16-bit OpTypeFloat.
-// Float16Buffer capability allows creation of a 16-bit OpTypeFloat.
-// Float64 capability allows using a 64-bit OpTypeFloat.
-spv_result_t ValidateFloatSize(ValidationState_t& _, const Instruction* inst) {
- // Operand 1 is the number of bits for this float
- auto num_bits = inst->GetOperandAs<const uint32_t>(1);
- if (num_bits == 32) {
- return SPV_SUCCESS;
- }
- if (num_bits == 16) {
- if (_.features().declare_float16_type) {
- return SPV_SUCCESS;
- }
- return _.diag(SPV_ERROR_INVALID_DATA, inst)
- << "Using a 16-bit floating point "
- << "type requires the Float16 or Float16Buffer capability,"
- " or an extension that explicitly enables 16-bit floating point.";
- }
- if (num_bits == 64) {
- if (_.HasCapability(SpvCapabilityFloat64)) {
- return SPV_SUCCESS;
- }
- return _.diag(SPV_ERROR_INVALID_DATA, inst)
- << "Using a 64-bit floating point "
- << "type requires the Float64 capability.";
- }
- return _.diag(SPV_ERROR_INVALID_DATA, inst)
- << "Invalid number of bits (" << num_bits << ") used for OpTypeFloat.";
-}
-
-// Validates that the number of bits specified for an Int type is valid.
-// Scalar integer types can be parameterized only with 32-bits.
-// Int8, Int16, and Int64 capabilities allow using 8-bit, 16-bit, and 64-bit
-// integers, respectively.
-spv_result_t ValidateIntSize(ValidationState_t& _, const Instruction* inst) {
- // Operand 1 is the number of bits for this integer.
- auto num_bits = inst->GetOperandAs<const uint32_t>(1);
- if (num_bits == 32) {
- return SPV_SUCCESS;
- }
- if (num_bits == 8) {
- if (_.features().declare_int8_type) {
- return SPV_SUCCESS;
- }
- return _.diag(SPV_ERROR_INVALID_DATA, inst)
- << "Using an 8-bit integer type requires the Int8 capability,"
- " or an extension that explicitly enables 8-bit integers.";
- }
- if (num_bits == 16) {
- if (_.features().declare_int16_type) {
- return SPV_SUCCESS;
- }
- return _.diag(SPV_ERROR_INVALID_DATA, inst)
- << "Using a 16-bit integer type requires the Int16 capability,"
- " or an extension that explicitly enables 16-bit integers.";
- }
- if (num_bits == 64) {
- if (_.HasCapability(SpvCapabilityInt64)) {
- return SPV_SUCCESS;
- }
- return _.diag(SPV_ERROR_INVALID_DATA, inst)
- << "Using a 64-bit integer type requires the Int64 capability.";
- }
- return _.diag(SPV_ERROR_INVALID_DATA, inst)
- << "Invalid number of bits (" << num_bits << ") used for OpTypeInt.";
-}
-
-// Validates that the matrix is parameterized with floating-point types.
-spv_result_t ValidateMatrixColumnType(ValidationState_t& _,
- const Instruction* inst) {
- // Find the component type of matrix columns (must be vector).
- // Operand 1 is the <id> of the type specified for matrix columns.
- auto type_id = inst->GetOperandAs<const uint32_t>(1);
- auto col_type_instr = _.FindDef(type_id);
- if (col_type_instr->opcode() != SpvOpTypeVector) {
- return _.diag(SPV_ERROR_INVALID_ID, inst)
- << "Columns in a matrix must be of type vector.";
- }
-
- // Trace back once more to find out the type of components in the vector.
- // Operand 1 is the <id> of the type of data in the vector.
- auto comp_type_id =
- col_type_instr->words()[col_type_instr->operands()[1].offset];
- auto comp_type_instruction = _.FindDef(comp_type_id);
- if (comp_type_instruction->opcode() != SpvOpTypeFloat) {
- return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Matrix types can only be "
- "parameterized with "
- "floating-point types.";
- }
- return SPV_SUCCESS;
-}
-
-// Validates that the matrix has 2,3, or 4 columns.
-spv_result_t ValidateMatrixNumCols(ValidationState_t& _,
- const Instruction* inst) {
- // Operand 2 is the number of columns in the matrix.
- auto num_cols = inst->GetOperandAs<const uint32_t>(2);
- if (num_cols != 2 && num_cols != 3 && num_cols != 4) {
- return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Matrix types can only be "
- "parameterized as having "
- "only 2, 3, or 4 columns.";
- }
- return SPV_SUCCESS;
-}
-
-// Validates that OpSpecConstant specializes to either int or float type.
-spv_result_t ValidateSpecConstNumerical(ValidationState_t& _,
- const Instruction* inst) {
- // Operand 0 is the <id> of the type that we're specializing to.
- auto type_id = inst->GetOperandAs<const uint32_t>(0);
- auto type_instruction = _.FindDef(type_id);
- auto type_opcode = type_instruction->opcode();
- if (type_opcode != SpvOpTypeInt && type_opcode != SpvOpTypeFloat) {
- return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Specialization constant "
- "must be an integer or "
- "floating-point number.";
- }
- return SPV_SUCCESS;
-}
-
-// Validates that OpSpecConstantTrue and OpSpecConstantFalse specialize to bool.
-spv_result_t ValidateSpecConstBoolean(ValidationState_t& _,
- const Instruction* inst) {
- // Find out the type that we're specializing to.
- auto type_instruction = _.FindDef(inst->type_id());
- if (type_instruction->opcode() != SpvOpTypeBool) {
- return _.diag(SPV_ERROR_INVALID_ID, inst)
- << "Specialization constant must be a boolean type.";
- }
- return SPV_SUCCESS;
-}
-
-// Records the <id> of the forward pointer to be used for validation.
-spv_result_t ValidateForwardPointer(ValidationState_t& _,
- const Instruction* inst) {
- // Record the <id> (which is operand 0) to ensure it's used properly.
- // OpTypeStruct can only include undefined pointers that are
- // previously declared as a ForwardPointer
- return (_.RegisterForwardPointer(inst->GetOperandAs<uint32_t>(0)));
-}
-
-// Validates that any undefined component of the struct is a forward pointer.
-// It is valid to declare a forward pointer, and use its <id> as one of the
-// components of a struct.
-spv_result_t ValidateStruct(ValidationState_t& _, const Instruction* inst) {
- // Struct components are operands 1, 2, etc.
- for (unsigned i = 1; i < inst->operands().size(); i++) {
- auto type_id = inst->GetOperandAs<const uint32_t>(i);
- auto type_instruction = _.FindDef(type_id);
- if (type_instruction == nullptr && !_.IsForwardPointer(type_id)) {
- return _.diag(SPV_ERROR_INVALID_ID, inst)
- << "Forward reference operands in an OpTypeStruct must first be "
- "declared using OpTypeForwardPointer.";
- }
- }
- return SPV_SUCCESS;
-}
-
-// Validates that any undefined type of the array is a forward pointer.
-// It is valid to declare a forward pointer, and use its <id> as the element
-// type of the array.
-spv_result_t ValidateArray(ValidationState_t& _, const Instruction* inst) {
- auto element_type_id = inst->GetOperandAs<const uint32_t>(1);
- auto element_type_instruction = _.FindDef(element_type_id);
- if (element_type_instruction == nullptr &&
- !_.IsForwardPointer(element_type_id)) {
- return _.diag(SPV_ERROR_INVALID_ID, inst)
- << "Forward reference operands in an OpTypeArray must first be "
- "declared using OpTypeForwardPointer.";
- }
- return SPV_SUCCESS;
-}
-
-} // namespace
-
-// Validates that Data Rules are followed according to the specifications.
-// (Data Rules subsection of 2.16.1 Universal Validation Rules)
-spv_result_t DataRulesPass(ValidationState_t& _, const Instruction* inst) {
- switch (inst->opcode()) {
- case SpvOpTypeVector: {
- if (auto error = ValidateVecNumComponents(_, inst)) return error;
- break;
- }
- case SpvOpTypeFloat: {
- if (auto error = ValidateFloatSize(_, inst)) return error;
- break;
- }
- case SpvOpTypeInt: {
- if (auto error = ValidateIntSize(_, inst)) return error;
- break;
- }
- case SpvOpTypeMatrix: {
- if (auto error = ValidateMatrixColumnType(_, inst)) return error;
- if (auto error = ValidateMatrixNumCols(_, inst)) return error;
- break;
- }
- // TODO(ehsan): Add OpSpecConstantComposite validation code.
- // TODO(ehsan): Add OpSpecConstantOp validation code (if any).
- case SpvOpSpecConstant: {
- if (auto error = ValidateSpecConstNumerical(_, inst)) return error;
- break;
- }
- case SpvOpSpecConstantFalse:
- case SpvOpSpecConstantTrue: {
- if (auto error = ValidateSpecConstBoolean(_, inst)) return error;
- break;
- }
- case SpvOpTypeForwardPointer: {
- if (auto error = ValidateForwardPointer(_, inst)) return error;
- break;
- }
- case SpvOpTypeStruct: {
- if (auto error = ValidateStruct(_, inst)) return error;
- break;
- }
- case SpvOpTypeArray: {
- if (auto error = ValidateArray(_, inst)) return error;
- break;
- }
- // TODO(ehsan): add more data rules validation here.
- default: { break; }
- }
-
- return SPV_SUCCESS;
-}
-
-} // namespace val
-} // namespace spvtools
diff --git a/source/val/validate_type.cpp b/source/val/validate_type.cpp
index afc0656..d3872da 100644
--- a/source/val/validate_type.cpp
+++ b/source/val/validate_type.cpp
@@ -67,6 +67,39 @@
}
spv_result_t ValidateTypeInt(ValidationState_t& _, const Instruction* inst) {
+ // Validates that the number of bits specified for an Int type is valid.
+ // Scalar integer types can be parameterized only with 32-bits.
+ // Int8, Int16, and Int64 capabilities allow using 8-bit, 16-bit, and 64-bit
+ // integers, respectively.
+ auto num_bits = inst->GetOperandAs<const uint32_t>(1);
+ if (num_bits != 32) {
+ if (num_bits == 8) {
+ if (_.features().declare_int8_type) {
+ return SPV_SUCCESS;
+ }
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Using an 8-bit integer type requires the Int8 capability,"
+ " or an extension that explicitly enables 8-bit integers.";
+ } else if (num_bits == 16) {
+ if (_.features().declare_int16_type) {
+ return SPV_SUCCESS;
+ }
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Using a 16-bit integer type requires the Int16 capability,"
+ " or an extension that explicitly enables 16-bit integers.";
+ } else if (num_bits == 64) {
+ if (_.HasCapability(SpvCapabilityInt64)) {
+ return SPV_SUCCESS;
+ }
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Using a 64-bit integer type requires the Int64 capability.";
+ } else {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Invalid number of bits (" << num_bits
+ << ") used for OpTypeInt.";
+ }
+ }
+
const auto signedness_index = 2;
const auto signedness = inst->GetOperandAs<uint32_t>(signedness_index);
if (signedness != 0 && signedness != 1) {
@@ -76,6 +109,36 @@
return SPV_SUCCESS;
}
+spv_result_t ValidateTypeFloat(ValidationState_t& _, const Instruction* inst) {
+ // Validates that the number of bits specified for an Int type is valid.
+ // Scalar integer types can be parameterized only with 32-bits.
+ // Int8, Int16, and Int64 capabilities allow using 8-bit, 16-bit, and 64-bit
+ // integers, respectively.
+ auto num_bits = inst->GetOperandAs<const uint32_t>(1);
+ if (num_bits == 32) {
+ return SPV_SUCCESS;
+ }
+ if (num_bits == 16) {
+ if (_.features().declare_float16_type) {
+ return SPV_SUCCESS;
+ }
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Using a 16-bit floating point "
+ << "type requires the Float16 or Float16Buffer capability,"
+ " or an extension that explicitly enables 16-bit floating point.";
+ }
+ if (num_bits == 64) {
+ if (_.HasCapability(SpvCapabilityFloat64)) {
+ return SPV_SUCCESS;
+ }
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Using a 64-bit floating point "
+ << "type requires the Float64 capability.";
+ }
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Invalid number of bits (" << num_bits << ") used for OpTypeFloat.";
+}
+
spv_result_t ValidateTypeVector(ValidationState_t& _, const Instruction* inst) {
const auto component_index = 1;
const auto component_id = inst->GetOperandAs<uint32_t>(component_index);
@@ -85,6 +148,27 @@
<< "OpTypeVector Component Type <id> '" << _.getIdName(component_id)
<< "' is not a scalar type.";
}
+
+ // Validates that the number of components in the vector is valid.
+ // Vector types can only be parameterized as having 2, 3, or 4 components.
+ // If the Vector16 capability is added, 8 and 16 components are also allowed.
+ auto num_components = inst->GetOperandAs<const uint32_t>(2);
+ if (num_components == 2 || num_components == 3 || num_components == 4) {
+ return SPV_SUCCESS;
+ } else if (num_components == 8 || num_components == 16) {
+ if (_.HasCapability(SpvCapabilityVector16)) {
+ return SPV_SUCCESS;
+ }
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Having " << num_components << " components for "
+ << spvOpcodeString(inst->opcode())
+ << " requires the Vector16 capability";
+ } else {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Illegal number of components (" << num_components << ") for "
+ << spvOpcodeString(inst->opcode());
+ }
+
return SPV_SUCCESS;
}
@@ -94,9 +178,27 @@
const auto column_type = _.FindDef(column_type_id);
if (!column_type || SpvOpTypeVector != column_type->opcode()) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
- << "OpTypeMatrix Column Type <id> '" << _.getIdName(column_type_id)
- << "' is not a vector.";
+ << "Columns in a matrix must be of type vector.";
}
+
+ // Trace back once more to find out the type of components in the vector.
+ // Operand 1 is the <id> of the type of data in the vector.
+ const auto comp_type_id = column_type->GetOperandAs<uint32_t>(1);
+ auto comp_type_instruction = _.FindDef(comp_type_id);
+ if (comp_type_instruction->opcode() != SpvOpTypeFloat) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Matrix types can only be "
+ "parameterized with "
+ "floating-point types.";
+ }
+
+ // Validates that the matrix has 2,3, or 4 columns.
+ auto num_cols = inst->GetOperandAs<const uint32_t>(2);
+ if (num_cols != 2 && num_cols != 3 && num_cols != 4) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Matrix types can only be "
+ "parameterized as having "
+ "only 2, 3, or 4 columns.";
+ }
+
return SPV_SUCCESS;
}
@@ -224,6 +326,11 @@
for (size_t member_type_index = 1;
member_type_index < inst->operands().size(); ++member_type_index) {
auto member_type_id = inst->GetOperandAs<uint32_t>(member_type_index);
+ if (member_type_id == inst->id()) {
+ return _.diag(SPV_ERROR_INVALID_ID, inst)
+ << "Structure members may not be self references";
+ }
+
auto member_type = _.FindDef(member_type_id);
if (!member_type || !spvOpcodeGeneratesType(member_type->opcode())) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
@@ -245,22 +352,6 @@
<< " contains structure <id> " << _.getIdName(member_type_id)
<< ".";
}
- if (_.IsForwardPointer(member_type_id)) {
- // If we're dealing with a forward pointer:
- // Find out the type that the pointer is pointing to (must be struct)
- // word 3 is the <id> of the type being pointed to.
- auto type_pointing_to = _.FindDef(member_type->words()[3]);
- if (type_pointing_to && type_pointing_to->opcode() != SpvOpTypeStruct) {
- // Forward declared operands of a struct may only point to a struct.
- return _.diag(SPV_ERROR_INVALID_ID, inst)
- << "A forward reference operand in an OpTypeStruct must be an "
- "OpTypePointer that points to an OpTypeStruct. "
- "Found OpTypePointer that points to Op"
- << spvOpcodeString(
- static_cast<SpvOp>(type_pointing_to->opcode()))
- << ".";
- }
- }
if (spvIsVulkanOrWebGPUEnv(_.context()->target_env) &&
member_type->opcode() == SpvOpTypeRuntimeArray) {
@@ -356,7 +447,6 @@
}
return SPV_SUCCESS;
}
-} // namespace
spv_result_t ValidateTypeFunction(ValidationState_t& _,
const Instruction* inst) {
@@ -425,6 +515,13 @@
<< "pointer definition.";
}
+ const auto pointee_type_id = pointer_type_inst->GetOperandAs<uint32_t>(2);
+ const auto pointee_type = _.FindDef(pointee_type_id);
+ if (!pointee_type || pointee_type->opcode() != SpvOpTypeStruct) {
+ return _.diag(SPV_ERROR_INVALID_ID, inst)
+ << "Forward pointers must point to a structure";
+ }
+
return SPV_SUCCESS;
}
@@ -474,6 +571,7 @@
return SPV_SUCCESS;
}
+} // namespace
spv_result_t TypePass(ValidationState_t& _, const Instruction* inst) {
if (!spvOpcodeGeneratesType(inst->opcode()) &&
@@ -487,6 +585,9 @@
case SpvOpTypeInt:
if (auto error = ValidateTypeInt(_, inst)) return error;
break;
+ case SpvOpTypeFloat:
+ if (auto error = ValidateTypeFloat(_, inst)) return error;
+ break;
case SpvOpTypeVector:
if (auto error = ValidateTypeVector(_, inst)) return error;
break;
diff --git a/test/val/val_data_test.cpp b/test/val/val_data_test.cpp
index 4690f97..0178fa7 100644
--- a/test/val/val_data_test.cpp
+++ b/test/val/val_data_test.cpp
@@ -629,15 +629,26 @@
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
-TEST_F(ValidateData, specialize_boolean_to_int) {
+TEST_F(ValidateData, specialize_boolean_true_to_int) {
std::string str = header + R"(
%2 = OpTypeInt 32 1
-%3 = OpSpecConstantTrue %2
+%3 = OpSpecConstantTrue %2)";
+ CompileSuccessfully(str.c_str());
+ ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("OpSpecConstantTrue Result Type <id> '1[%int]' is not "
+ "a boolean type"));
+}
+
+TEST_F(ValidateData, specialize_boolean_false_to_int) {
+ std::string str = header + R"(
+%2 = OpTypeInt 32 1
%4 = OpSpecConstantFalse %2)";
CompileSuccessfully(str.c_str());
ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(),
- HasSubstr("Specialization constant must be a boolean"));
+ HasSubstr("OpSpecConstantFalse Result Type <id> '1[%int]' is not "
+ "a boolean type"));
}
TEST_F(ValidateData, missing_forward_pointer_decl) {
@@ -647,8 +658,9 @@
)";
CompileSuccessfully(str.c_str());
ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
- EXPECT_THAT(getDiagnosticString(),
- HasSubstr("must first be declared using OpTypeForwardPointer"));
+ EXPECT_THAT(
+ getDiagnosticString(),
+ HasSubstr("The following forward referenced IDs have not been defined:"));
}
TEST_F(ValidateData, missing_forward_pointer_decl_self_reference) {
@@ -659,7 +671,7 @@
CompileSuccessfully(str.c_str());
ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(),
- HasSubstr("must first be declared using OpTypeForwardPointer"));
+ HasSubstr("Structure members may not be self references"));
}
TEST_F(ValidateData, forward_pointer_missing_definition) {
@@ -698,9 +710,7 @@
CompileSuccessfully(str.c_str());
ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(),
- HasSubstr("A forward reference operand in an OpTypeStruct must "
- "be an OpTypePointer that points to an OpTypeStruct. "
- "Found OpTypePointer that points to OpTypeInt."));
+ HasSubstr("Forward pointers must point to a structure"));
}
TEST_F(ValidateData, struct_forward_pointer_good) {
@@ -934,23 +944,6 @@
"OpTypeStruct %_runtimearr_uint %uint\n"));
}
-TEST_F(ValidateData, invalid_forward_reference_in_array) {
- std::string str = R"(
- OpCapability Shader
- OpCapability Linkage
- OpMemoryModel Logical GLSL450
- %uint = OpTypeInt 32 0
-%uint_1 = OpConstant %uint 1
-%_arr_3_uint_1 = OpTypeArray %_arr_3_uint_1 %uint_1
-)";
-
- CompileSuccessfully(str.c_str(), SPV_ENV_UNIVERSAL_1_3);
- ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3));
- EXPECT_THAT(getDiagnosticString(),
- HasSubstr("Forward reference operands in an OpTypeArray must "
- "first be declared using OpTypeForwardPointer."));
-}
-
} // namespace
} // namespace val
} // namespace spvtools
diff --git a/test/val/val_id_test.cpp b/test/val/val_id_test.cpp
index 4c2a73b..43e3197 100644
--- a/test/val/val_id_test.cpp
+++ b/test/val/val_id_test.cpp
@@ -1476,7 +1476,8 @@
CompileSuccessfully(spirv.c_str());
EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(),
- HasSubstr("Specialization constant must be a boolean type."));
+ HasSubstr("OpSpecConstantTrue Result Type <id> '1[%void]' is not "
+ "a boolean type"));
}
TEST_F(ValidateIdWithMessage, OpSpecConstantFalseGood) {
@@ -1492,8 +1493,10 @@
%2 = OpSpecConstantFalse %1)";
CompileSuccessfully(spirv.c_str());
EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
- EXPECT_THAT(getDiagnosticString(),
- HasSubstr("Specialization constant must be a boolean type."));
+ EXPECT_THAT(
+ getDiagnosticString(),
+ HasSubstr("OpSpecConstantFalse Result Type <id> '1[%void]' is not "
+ "a boolean type"));
}
TEST_F(ValidateIdWithMessage, OpSpecConstantGood) {
@@ -5175,9 +5178,9 @@
)";
CompileSuccessfully(spirv.c_str());
EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
- EXPECT_THAT(getDiagnosticString(),
- HasSubstr("Forward reference operands in an OpTypeStruct must "
- "first be declared using OpTypeForwardPointer."));
+ EXPECT_THAT(
+ getDiagnosticString(),
+ HasSubstr("The following forward referenced IDs have not been defined"));
}
TEST_F(ValidateIdWithMessage, UndefinedIdScope) {
diff --git a/test/val/val_ssa_test.cpp b/test/val/val_ssa_test.cpp
index e10abd7..035c710 100644
--- a/test/val/val_ssa_test.cpp
+++ b/test/val/val_ssa_test.cpp
@@ -1415,7 +1415,8 @@
OpName %intptrt "intptrt"
OpTypeForwardPointer %intptrt UniformConstant
%uint = OpTypeInt 32 0
- %intptrt = OpTypePointer UniformConstant %uint
+ %struct = OpTypeStruct %uint
+ %intptrt = OpTypePointer UniformConstant %struct
)";
CompileSuccessfully(str);
diff --git a/test/val/val_type_unique_test.cpp b/test/val/val_type_unique_test.cpp
index 67ceadd..45a4d50 100644
--- a/test/val/val_type_unique_test.cpp
+++ b/test/val/val_type_unique_test.cpp
@@ -210,9 +210,11 @@
OpTypeForwardPointer %ptr Generic
OpTypeForwardPointer %ptr2 Generic
%intt = OpTypeInt 32 0
+%int_struct = OpTypeStruct %intt
%floatt = OpTypeFloat 32
-%ptr = OpTypePointer Generic %intt
-%ptr2 = OpTypePointer Generic %floatt
+%ptr = OpTypePointer Generic %int_struct
+%float_struct = OpTypeStruct %floatt
+%ptr2 = OpTypePointer Generic %float_struct
)";
CompileSuccessfully(str.c_str());
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());