Remove duplicates from list of interface IDs in OpEntryPoint instruction (#2449)
* Remove duplicates from list of interface IDs in OpEntryPoint instruction
Fixes #2002.
diff --git a/source/opt/folding_rules.cpp b/source/opt/folding_rules.cpp
index 3274319..18d5149 100644
--- a/source/opt/folding_rules.cpp
+++ b/source/opt/folding_rules.cpp
@@ -2167,6 +2167,37 @@
};
}
+// Removes duplicate ids from the interface list of an OpEntryPoint
+// instruction.
+FoldingRule RemoveRedundantOperands() {
+ return [](IRContext*, Instruction* inst,
+ const std::vector<const analysis::Constant*>&) {
+ assert(inst->opcode() == SpvOpEntryPoint &&
+ "Wrong opcode. Should be OpEntryPoint.");
+ bool has_redundant_operand = false;
+ std::unordered_set<uint32_t> seen_operands;
+ std::vector<Operand> new_operands;
+
+ new_operands.emplace_back(inst->GetOperand(0));
+ new_operands.emplace_back(inst->GetOperand(1));
+ new_operands.emplace_back(inst->GetOperand(2));
+ for (uint32_t i = 3; i < inst->NumOperands(); ++i) {
+ if (seen_operands.insert(inst->GetSingleWordOperand(i)).second) {
+ new_operands.emplace_back(inst->GetOperand(i));
+ } else {
+ has_redundant_operand = true;
+ }
+ }
+
+ if (!has_redundant_operand) {
+ return false;
+ }
+
+ inst->SetInOperands(std::move(new_operands));
+ return true;
+ };
+}
+
} // namespace
FoldingRules::FoldingRules() {
@@ -2183,6 +2214,8 @@
rules_[SpvOpDot].push_back(DotProductDoingExtract());
+ rules_[SpvOpEntryPoint].push_back(RemoveRedundantOperands());
+
rules_[SpvOpExtInst].push_back(RedundantFMix());
rules_[SpvOpFAdd].push_back(RedundantFAdd());
diff --git a/test/opt/fold_test.cpp b/test/opt/fold_test.cpp
index b3c3441..7458e1c 100644
--- a/test/opt/fold_test.cpp
+++ b/test/opt/fold_test.cpp
@@ -6211,6 +6211,106 @@
9, true)
));
+using EntryPointFoldingTest =
+::testing::TestWithParam<InstructionFoldingCase<bool>>;
+
+TEST_P(EntryPointFoldingTest, Case) {
+ const auto& tc = GetParam();
+
+ // Build module.
+ std::unique_ptr<IRContext> context =
+ BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, tc.test_body,
+ SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+ ASSERT_NE(nullptr, context);
+
+ // Fold the instruction to test.
+ Instruction* inst = nullptr;
+ inst = &*context->module()->entry_points().begin();
+ assert(inst && "Invalid test. Could not find entry point instruction to fold.");
+ std::unique_ptr<Instruction> original_inst(inst->Clone(context.get()));
+ bool succeeded = context->get_instruction_folder().FoldInstruction(inst);
+ EXPECT_EQ(succeeded, tc.expected_result);
+ if (succeeded) {
+ Match(tc.test_body, context.get());
+ }
+}
+
+INSTANTIATE_TEST_SUITE_P(OpEntryPointFoldingTest, EntryPointFoldingTest,
+::testing::Values(
+ // Test case 0: Basic test 1
+ InstructionFoldingCase<bool>(std::string() +
+ "; CHECK: OpEntryPoint Fragment %2 \"main\" %3\n" +
+ "OpCapability Shader\n" +
+ "%1 = OpExtInstImport \"GLSL.std.450\"\n" +
+ "OpMemoryModel Logical GLSL450\n" +
+ "OpEntryPoint Fragment %2 \"main\" %3 %3 %3\n" +
+ "OpExecutionMode %2 OriginUpperLeft\n" +
+ "OpSource GLSL 430\n" +
+ "OpDecorate %3 Location 0\n" +
+ "%void = OpTypeVoid\n" +
+ "%5 = OpTypeFunction %void\n" +
+ "%float = OpTypeFloat 32\n" +
+ "%v4float = OpTypeVector %float 4\n" +
+ "%_ptr_Output_v4float = OpTypePointer Output %v4float\n" +
+ "%3 = OpVariable %_ptr_Output_v4float Output\n" +
+ "%int = OpTypeInt 32 1\n" +
+ "%int_0 = OpConstant %int 0\n" +
+"%_ptr_PushConstant_v4float = OpTypePointer PushConstant %v4float\n" +
+ "%2 = OpFunction %void None %5\n" +
+ "%12 = OpLabel\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd\n",
+ 9, true),
+ InstructionFoldingCase<bool>(std::string() +
+ "; CHECK: OpEntryPoint Fragment %2 \"main\" %3 %4\n" +
+ "OpCapability Shader\n" +
+ "%1 = OpExtInstImport \"GLSL.std.450\"\n" +
+ "OpMemoryModel Logical GLSL450\n" +
+ "OpEntryPoint Fragment %2 \"main\" %3 %4 %3\n" +
+ "OpExecutionMode %2 OriginUpperLeft\n" +
+ "OpSource GLSL 430\n" +
+ "OpDecorate %3 Location 0\n" +
+ "%void = OpTypeVoid\n" +
+ "%5 = OpTypeFunction %void\n" +
+ "%float = OpTypeFloat 32\n" +
+ "%v4float = OpTypeVector %float 4\n" +
+ "%_ptr_Output_v4float = OpTypePointer Output %v4float\n" +
+ "%3 = OpVariable %_ptr_Output_v4float Output\n" +
+ "%4 = OpVariable %_ptr_Output_v4float Output\n" +
+ "%int = OpTypeInt 32 1\n" +
+ "%int_0 = OpConstant %int 0\n" +
+"%_ptr_PushConstant_v4float = OpTypePointer PushConstant %v4float\n" +
+ "%2 = OpFunction %void None %5\n" +
+ "%12 = OpLabel\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd\n",
+ 9, true),
+ InstructionFoldingCase<bool>(std::string() +
+ "; CHECK: OpEntryPoint Fragment %2 \"main\" %4 %3\n" +
+ "OpCapability Shader\n" +
+ "%1 = OpExtInstImport \"GLSL.std.450\"\n" +
+ "OpMemoryModel Logical GLSL450\n" +
+ "OpEntryPoint Fragment %2 \"main\" %4 %4 %3\n" +
+ "OpExecutionMode %2 OriginUpperLeft\n" +
+ "OpSource GLSL 430\n" +
+ "OpDecorate %3 Location 0\n" +
+ "%void = OpTypeVoid\n" +
+ "%5 = OpTypeFunction %void\n" +
+ "%float = OpTypeFloat 32\n" +
+ "%v4float = OpTypeVector %float 4\n" +
+ "%_ptr_Output_v4float = OpTypePointer Output %v4float\n" +
+ "%3 = OpVariable %_ptr_Output_v4float Output\n" +
+ "%4 = OpVariable %_ptr_Output_v4float Output\n" +
+ "%int = OpTypeInt 32 1\n" +
+ "%int_0 = OpConstant %int 0\n" +
+"%_ptr_PushConstant_v4float = OpTypePointer PushConstant %v4float\n" +
+ "%2 = OpFunction %void None %5\n" +
+ "%12 = OpLabel\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd\n",
+ 9, true)
+));
+
} // namespace
} // namespace opt
} // namespace spvtools