Fix reachability in the validator (#3541)
Fixes #3529
* Make BasicBlock::reachable() only consider static reachability
* Fix reachability calculation to be independent of block order
* add tests
diff --git a/source/val/basic_block.cpp b/source/val/basic_block.cpp
index a53103c..b2a8793 100644
--- a/source/val/basic_block.cpp
+++ b/source/val/basic_block.cpp
@@ -58,15 +58,9 @@
for (auto& block : next_blocks) {
block->predecessors_.push_back(this);
successors_.push_back(block);
- if (block->reachable_ == false) block->set_reachable(reachable_);
}
}
-void BasicBlock::RegisterBranchInstruction(SpvOp branch_instruction) {
- if (branch_instruction == SpvOpUnreachable) reachable_ = false;
- return;
-}
-
bool BasicBlock::dominates(const BasicBlock& other) const {
return (this == &other) ||
!(other.dom_end() ==
diff --git a/source/val/basic_block.h b/source/val/basic_block.h
index 876105c..5eea4f9 100644
--- a/source/val/basic_block.h
+++ b/source/val/basic_block.h
@@ -106,9 +106,6 @@
/// Returns the immedate post dominator of this basic block
const BasicBlock* immediate_post_dominator() const;
- /// Ends the block without a successor
- void RegisterBranchInstruction(SpvOp branch_instruction);
-
/// Returns the label instruction for the block, or nullptr if not set.
const Instruction* label() const { return label_; }
diff --git a/source/val/function.cpp b/source/val/function.cpp
index 0281770..249c866 100644
--- a/source/val/function.cpp
+++ b/source/val/function.cpp
@@ -130,7 +130,6 @@
undefined_blocks_.erase(block_id);
current_block_ = &inserted_block->second;
ordered_blocks_.push_back(current_block_);
- if (IsFirstBlock(block_id)) current_block_->set_reachable(true);
} else if (success) { // Block doesn't exsist but this is not a definition
undefined_blocks_.insert(block_id);
}
@@ -138,8 +137,7 @@
return SPV_SUCCESS;
}
-void Function::RegisterBlockEnd(std::vector<uint32_t> next_list,
- SpvOp branch_instruction) {
+void Function::RegisterBlockEnd(std::vector<uint32_t> next_list) {
assert(
current_block_ &&
"RegisterBlockEnd can only be called when parsing a binary in a block");
@@ -174,7 +172,6 @@
}
}
- current_block_->RegisterBranchInstruction(branch_instruction);
current_block_->RegisterSuccessors(next_blocks);
current_block_ = nullptr;
return;
diff --git a/source/val/function.h b/source/val/function.h
index 0d6873d..400bb63 100644
--- a/source/val/function.h
+++ b/source/val/function.h
@@ -97,9 +97,7 @@
/// Registers the end of the block
///
/// @param[in] successors_list A list of ids to the block's successors
- /// @param[in] branch_instruction the branch instruction that ended the block
- void RegisterBlockEnd(std::vector<uint32_t> successors_list,
- SpvOp branch_instruction);
+ void RegisterBlockEnd(std::vector<uint32_t> successors_list);
/// Registers the end of the function. This is idempotent.
void RegisterFunctionEnd();
diff --git a/source/val/validate.cpp b/source/val/validate.cpp
index 168968d..f964b9b 100644
--- a/source/val/validate.cpp
+++ b/source/val/validate.cpp
@@ -368,6 +368,10 @@
// Catch undefined forward references before performing further checks.
if (auto error = ValidateForwardDecls(*vstate)) return error;
+ // Calculate reachability after all the blocks are parsed, but early that it
+ // can be relied on in subsequent pases.
+ ReachabilityPass(*vstate);
+
// ID usage needs be handled in its own iteration of the instructions,
// between the two others. It depends on the first loop to have been
// finished, so that all instructions have been registered. And the following
diff --git a/source/val/validate.h b/source/val/validate.h
index 31a775b..3fc183d 100644
--- a/source/val/validate.h
+++ b/source/val/validate.h
@@ -197,6 +197,9 @@
/// Validates correctness of miscellaneous instructions.
spv_result_t MiscPass(ValidationState_t& _, const Instruction* inst);
+/// Calculates the reachability of basic blocks.
+void ReachabilityPass(ValidationState_t& _);
+
/// Validates execution limitations.
///
/// Verifies execution models are allowed for all functionality they contain.
diff --git a/source/val/validate_cfg.cpp b/source/val/validate_cfg.cpp
index 1e33e51..a2fe882 100644
--- a/source/val/validate_cfg.cpp
+++ b/source/val/validate_cfg.cpp
@@ -1062,7 +1062,7 @@
uint32_t target = inst->GetOperandAs<uint32_t>(0);
CFG_ASSERT(FirstBlockAssert, target);
- _.current_function().RegisterBlockEnd({target}, opcode);
+ _.current_function().RegisterBlockEnd({target});
} break;
case SpvOpBranchConditional: {
uint32_t tlabel = inst->GetOperandAs<uint32_t>(1);
@@ -1070,7 +1070,7 @@
CFG_ASSERT(FirstBlockAssert, tlabel);
CFG_ASSERT(FirstBlockAssert, flabel);
- _.current_function().RegisterBlockEnd({tlabel, flabel}, opcode);
+ _.current_function().RegisterBlockEnd({tlabel, flabel});
} break;
case SpvOpSwitch: {
@@ -1080,7 +1080,7 @@
CFG_ASSERT(FirstBlockAssert, target);
cases.push_back(target);
}
- _.current_function().RegisterBlockEnd({cases}, opcode);
+ _.current_function().RegisterBlockEnd({cases});
} break;
case SpvOpReturn: {
const uint32_t return_type = _.current_function().GetResultTypeId();
@@ -1090,13 +1090,13 @@
return _.diag(SPV_ERROR_INVALID_CFG, inst)
<< "OpReturn can only be called from a function with void "
<< "return type.";
- _.current_function().RegisterBlockEnd(std::vector<uint32_t>(), opcode);
+ _.current_function().RegisterBlockEnd(std::vector<uint32_t>());
break;
}
case SpvOpKill:
case SpvOpReturnValue:
case SpvOpUnreachable:
- _.current_function().RegisterBlockEnd(std::vector<uint32_t>(), opcode);
+ _.current_function().RegisterBlockEnd(std::vector<uint32_t>());
if (opcode == SpvOpKill) {
_.current_function().RegisterExecutionModelLimitation(
SpvExecutionModelFragment,
@@ -1109,6 +1109,27 @@
return SPV_SUCCESS;
}
+void ReachabilityPass(ValidationState_t& _) {
+ for (auto& f : _.functions()) {
+ std::vector<BasicBlock*> stack;
+ auto entry = f.first_block();
+ // Skip function declarations.
+ if (entry) stack.push_back(entry);
+
+ while (!stack.empty()) {
+ auto block = stack.back();
+ stack.pop_back();
+
+ if (block->reachable()) continue;
+
+ block->set_reachable(true);
+ for (auto succ : *block->successors()) {
+ stack.push_back(succ);
+ }
+ }
+ }
+}
+
spv_result_t ControlFlowPass(ValidationState_t& _, const Instruction* inst) {
switch (inst->opcode()) {
case SpvOpPhi:
diff --git a/test/val/val_cfg_test.cpp b/test/val/val_cfg_test.cpp
index 0d09642..630a19d 100644
--- a/test/val/val_cfg_test.cpp
+++ b/test/val/val_cfg_test.cpp
@@ -1204,14 +1204,6 @@
EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
}
-TEST_F(ValidateCFG, WebGPUUnreachableMergeWithBranchUse) {
- CompileSuccessfully(
- GetUnreachableMergeWithBranchUse(SpvCapabilityShader, SPV_ENV_WEBGPU_0));
- ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions(SPV_ENV_WEBGPU_0));
- EXPECT_THAT(getDiagnosticString(),
- HasSubstr("cannot be the target of a branch."));
-}
-
std::string GetUnreachableMergeWithMultipleUses(SpvCapability cap,
spv_target_env env) {
std::string header =
@@ -4503,6 +4495,76 @@
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
+TEST_F(ValidateCFG, UnreachableIsStaticallyReachable) {
+ const std::string text = R"(
+OpCapability Shader
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+%1 = OpTypeVoid
+%2 = OpTypeFunction %1
+%3 = OpFunction %1 None %2
+%4 = OpLabel
+OpBranch %5
+%5 = OpLabel
+OpUnreachable
+OpFunctionEnd
+)";
+
+ CompileSuccessfully(text);
+ EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState());
+
+ auto f = vstate_->function(3);
+ auto entry = f->GetBlock(4).first;
+ ASSERT_TRUE(entry->reachable());
+ auto end = f->GetBlock(5).first;
+ ASSERT_TRUE(end->reachable());
+}
+
+TEST_F(ValidateCFG, BlockOrderDoesNotAffectReachability) {
+ const std::string text = R"(
+OpCapability Shader
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+%1 = OpTypeVoid
+%2 = OpTypeFunction %1
+%3 = OpTypeBool
+%4 = OpUndef %3
+%5 = OpFunction %1 None %2
+%6 = OpLabel
+OpBranch %7
+%7 = OpLabel
+OpSelectionMerge %8 None
+OpBranchConditional %4 %9 %10
+%8 = OpLabel
+OpReturn
+%9 = OpLabel
+OpBranch %8
+%10 = OpLabel
+OpBranch %8
+%11 = OpLabel
+OpUnreachable
+OpFunctionEnd
+)";
+
+ CompileSuccessfully(text);
+ EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState());
+
+ auto f = vstate_->function(5);
+ auto b6 = f->GetBlock(6).first;
+ auto b7 = f->GetBlock(7).first;
+ auto b8 = f->GetBlock(8).first;
+ auto b9 = f->GetBlock(9).first;
+ auto b10 = f->GetBlock(10).first;
+ auto b11 = f->GetBlock(11).first;
+
+ ASSERT_TRUE(b6->reachable());
+ ASSERT_TRUE(b7->reachable());
+ ASSERT_TRUE(b8->reachable());
+ ASSERT_TRUE(b9->reachable());
+ ASSERT_TRUE(b10->reachable());
+ ASSERT_FALSE(b11->reachable());
+}
+
} // namespace
} // namespace val
} // namespace spvtools