spirv-fuzz: Avoid the type manager when looking for struct types (#3963)
Fixes #3947.
diff --git a/source/fuzz/fuzzer_util.cpp b/source/fuzz/fuzzer_util.cpp
index 1947853..a691a91 100644
--- a/source/fuzz/fuzzer_util.cpp
+++ b/source/fuzz/fuzzer_util.cpp
@@ -418,6 +418,8 @@
spv_validator_options validator_options,
MessageConsumer consumer) {
if (!IsValid(ir_context, validator_options, consumer)) {
+ // Expression to dump |ir_context| to /data/temp/shader.spv:
+ // DumpShader(ir_context, "/data/temp/shader.spv")
consumer(SPV_MSG_INFO, nullptr, {},
"Module is invalid (set a breakpoint to inspect).");
return false;
@@ -1052,18 +1054,24 @@
uint32_t MaybeGetStructType(opt::IRContext* ir_context,
const std::vector<uint32_t>& component_type_ids) {
- std::vector<const opt::analysis::Type*> component_types;
- component_types.reserve(component_type_ids.size());
-
- for (auto type_id : component_type_ids) {
- const auto* component_type = ir_context->get_type_mgr()->GetType(type_id);
- assert(component_type && !component_type->AsFunction() &&
- "Component type is invalid");
- component_types.push_back(component_type);
+ for (auto& type_or_value : ir_context->types_values()) {
+ if (type_or_value.opcode() != SpvOpTypeStruct ||
+ type_or_value.NumInOperands() !=
+ static_cast<uint32_t>(component_type_ids.size())) {
+ continue;
+ }
+ bool all_components_match = true;
+ for (uint32_t i = 0; i < component_type_ids.size(); i++) {
+ if (type_or_value.GetSingleWordInOperand(i) != component_type_ids[i]) {
+ all_components_match = false;
+ break;
+ }
+ }
+ if (all_components_match) {
+ return type_or_value.result_id();
+ }
}
-
- opt::analysis::Struct type(component_types);
- return ir_context->get_type_mgr()->GetId(&type);
+ return 0;
}
uint32_t MaybeGetVoidType(opt::IRContext* ir_context) {
diff --git a/source/fuzz/fuzzer_util.h b/source/fuzz/fuzzer_util.h
index 566c535..f23826a 100644
--- a/source/fuzz/fuzzer_util.h
+++ b/source/fuzz/fuzzer_util.h
@@ -388,7 +388,8 @@
uint32_t MaybeGetVectorType(opt::IRContext* ir_context,
uint32_t component_type_id, uint32_t element_count);
-// Returns a result id of an OpTypeStruct instruction if present. Returns 0
+// Returns a result id of an OpTypeStruct instruction whose field types exactly
+// match |component_type_ids| if such an instruction is present. Returns 0
// otherwise. |component_type_ids| may not contain a result id of an
// OpTypeFunction.
uint32_t MaybeGetStructType(opt::IRContext* ir_context,
diff --git a/test/fuzz/transformation_replace_params_with_struct_test.cpp b/test/fuzz/transformation_replace_params_with_struct_test.cpp
index 58fff65..afa782e 100644
--- a/test/fuzz/transformation_replace_params_with_struct_test.cpp
+++ b/test/fuzz/transformation_replace_params_with_struct_test.cpp
@@ -163,9 +163,9 @@
.IsApplicable(context.get(), transformation_context));
// |caller_id_to_fresh_composite_id| misses values.
- ASSERT_FALSE(TransformationReplaceParamsWithStruct({16, 17}, 90, 91,
- {{33, 92}, {90, 93}})
- .IsApplicable(context.get(), transformation_context));
+ ASSERT_FALSE(
+ TransformationReplaceParamsWithStruct({16, 17}, 90, 91, {{90, 93}})
+ .IsApplicable(context.get(), transformation_context));
// All fresh ids must be unique.
ASSERT_FALSE(TransformationReplaceParamsWithStruct({16, 17}, 90, 90,
@@ -483,6 +483,45 @@
ASSERT_TRUE(IsEqual(env, after_transformation, context.get()));
}
+TEST(TransformationReplaceParamsWithStructTest, IsomorphicStructs) {
+ std::string shader = R"(
+ OpCapability Shader
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %16 "main"
+ OpExecutionMode %16 OriginUpperLeft
+ OpSource ESSL 310
+ %2 = OpTypeVoid
+ %6 = OpTypeInt 32 1
+ %7 = OpTypeStruct %6
+ %8 = OpTypeStruct %6
+ %9 = OpTypeStruct %8
+ %10 = OpTypeFunction %2 %7
+ %15 = OpTypeFunction %2
+ %16 = OpFunction %2 None %15
+ %17 = OpLabel
+ OpReturn
+ OpFunctionEnd
+ %11 = OpFunction %2 None %10
+ %12 = OpFunctionParameter %7
+ %13 = OpLabel
+ OpReturn
+ OpFunctionEnd
+ )";
+
+ const auto env = SPV_ENV_UNIVERSAL_1_3;
+ const auto consumer = nullptr;
+ const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption);
+ spvtools::ValidatorOptions validator_options;
+ ASSERT_TRUE(fuzzerutil::IsValidAndWellFormed(context.get(), validator_options,
+ kConsoleMessageConsumer));
+ TransformationContext transformation_context(
+ MakeUnique<FactManager>(context.get()), validator_options);
+
+ ASSERT_FALSE(TransformationReplaceParamsWithStruct({12}, 100, 101, {{}})
+ .IsApplicable(context.get(), transformation_context));
+}
+
} // namespace
} // namespace fuzz
} // namespace spvtools