Fix reachability in the validator (#3541)

Fixes #3529

* Make BasicBlock::reachable() only consider static reachability
* Fix reachability calculation to be independent of block order
* add tests
diff --git a/source/val/basic_block.cpp b/source/val/basic_block.cpp
index a53103c..b2a8793 100644
--- a/source/val/basic_block.cpp
+++ b/source/val/basic_block.cpp
@@ -58,15 +58,9 @@
   for (auto& block : next_blocks) {
     block->predecessors_.push_back(this);
     successors_.push_back(block);
-    if (block->reachable_ == false) block->set_reachable(reachable_);
   }
 }
 
-void BasicBlock::RegisterBranchInstruction(SpvOp branch_instruction) {
-  if (branch_instruction == SpvOpUnreachable) reachable_ = false;
-  return;
-}
-
 bool BasicBlock::dominates(const BasicBlock& other) const {
   return (this == &other) ||
          !(other.dom_end() ==
diff --git a/source/val/basic_block.h b/source/val/basic_block.h
index 876105c..5eea4f9 100644
--- a/source/val/basic_block.h
+++ b/source/val/basic_block.h
@@ -106,9 +106,6 @@
   /// Returns the immedate post dominator of this basic block
   const BasicBlock* immediate_post_dominator() const;
 
-  /// Ends the block without a successor
-  void RegisterBranchInstruction(SpvOp branch_instruction);
-
   /// Returns the label instruction for the block, or nullptr if not set.
   const Instruction* label() const { return label_; }
 
diff --git a/source/val/function.cpp b/source/val/function.cpp
index 0281770..249c866 100644
--- a/source/val/function.cpp
+++ b/source/val/function.cpp
@@ -130,7 +130,6 @@
     undefined_blocks_.erase(block_id);
     current_block_ = &inserted_block->second;
     ordered_blocks_.push_back(current_block_);
-    if (IsFirstBlock(block_id)) current_block_->set_reachable(true);
   } else if (success) {  // Block doesn't exsist but this is not a definition
     undefined_blocks_.insert(block_id);
   }
@@ -138,8 +137,7 @@
   return SPV_SUCCESS;
 }
 
-void Function::RegisterBlockEnd(std::vector<uint32_t> next_list,
-                                SpvOp branch_instruction) {
+void Function::RegisterBlockEnd(std::vector<uint32_t> next_list) {
   assert(
       current_block_ &&
       "RegisterBlockEnd can only be called when parsing a binary in a block");
@@ -174,7 +172,6 @@
     }
   }
 
-  current_block_->RegisterBranchInstruction(branch_instruction);
   current_block_->RegisterSuccessors(next_blocks);
   current_block_ = nullptr;
   return;
diff --git a/source/val/function.h b/source/val/function.h
index 0d6873d..400bb63 100644
--- a/source/val/function.h
+++ b/source/val/function.h
@@ -97,9 +97,7 @@
   /// Registers the end of the block
   ///
   /// @param[in] successors_list A list of ids to the block's successors
-  /// @param[in] branch_instruction the branch instruction that ended the block
-  void RegisterBlockEnd(std::vector<uint32_t> successors_list,
-                        SpvOp branch_instruction);
+  void RegisterBlockEnd(std::vector<uint32_t> successors_list);
 
   /// Registers the end of the function.  This is idempotent.
   void RegisterFunctionEnd();
diff --git a/source/val/validate.cpp b/source/val/validate.cpp
index 168968d..f964b9b 100644
--- a/source/val/validate.cpp
+++ b/source/val/validate.cpp
@@ -368,6 +368,10 @@
   // Catch undefined forward references before performing further checks.
   if (auto error = ValidateForwardDecls(*vstate)) return error;
 
+  // Calculate reachability after all the blocks are parsed, but early that it
+  // can be relied on in subsequent pases.
+  ReachabilityPass(*vstate);
+
   // ID usage needs be handled in its own iteration of the instructions,
   // between the two others. It depends on the first loop to have been
   // finished, so that all instructions have been registered. And the following
diff --git a/source/val/validate.h b/source/val/validate.h
index 31a775b..3fc183d 100644
--- a/source/val/validate.h
+++ b/source/val/validate.h
@@ -197,6 +197,9 @@
 /// Validates correctness of miscellaneous instructions.
 spv_result_t MiscPass(ValidationState_t& _, const Instruction* inst);
 
+/// Calculates the reachability of basic blocks.
+void ReachabilityPass(ValidationState_t& _);
+
 /// Validates execution limitations.
 ///
 /// Verifies execution models are allowed for all functionality they contain.
diff --git a/source/val/validate_cfg.cpp b/source/val/validate_cfg.cpp
index 1e33e51..a2fe882 100644
--- a/source/val/validate_cfg.cpp
+++ b/source/val/validate_cfg.cpp
@@ -1062,7 +1062,7 @@
       uint32_t target = inst->GetOperandAs<uint32_t>(0);
       CFG_ASSERT(FirstBlockAssert, target);
 
-      _.current_function().RegisterBlockEnd({target}, opcode);
+      _.current_function().RegisterBlockEnd({target});
     } break;
     case SpvOpBranchConditional: {
       uint32_t tlabel = inst->GetOperandAs<uint32_t>(1);
@@ -1070,7 +1070,7 @@
       CFG_ASSERT(FirstBlockAssert, tlabel);
       CFG_ASSERT(FirstBlockAssert, flabel);
 
-      _.current_function().RegisterBlockEnd({tlabel, flabel}, opcode);
+      _.current_function().RegisterBlockEnd({tlabel, flabel});
     } break;
 
     case SpvOpSwitch: {
@@ -1080,7 +1080,7 @@
         CFG_ASSERT(FirstBlockAssert, target);
         cases.push_back(target);
       }
-      _.current_function().RegisterBlockEnd({cases}, opcode);
+      _.current_function().RegisterBlockEnd({cases});
     } break;
     case SpvOpReturn: {
       const uint32_t return_type = _.current_function().GetResultTypeId();
@@ -1090,13 +1090,13 @@
         return _.diag(SPV_ERROR_INVALID_CFG, inst)
                << "OpReturn can only be called from a function with void "
                << "return type.";
-      _.current_function().RegisterBlockEnd(std::vector<uint32_t>(), opcode);
+      _.current_function().RegisterBlockEnd(std::vector<uint32_t>());
       break;
     }
     case SpvOpKill:
     case SpvOpReturnValue:
     case SpvOpUnreachable:
-      _.current_function().RegisterBlockEnd(std::vector<uint32_t>(), opcode);
+      _.current_function().RegisterBlockEnd(std::vector<uint32_t>());
       if (opcode == SpvOpKill) {
         _.current_function().RegisterExecutionModelLimitation(
             SpvExecutionModelFragment,
@@ -1109,6 +1109,27 @@
   return SPV_SUCCESS;
 }
 
+void ReachabilityPass(ValidationState_t& _) {
+  for (auto& f : _.functions()) {
+    std::vector<BasicBlock*> stack;
+    auto entry = f.first_block();
+    // Skip function declarations.
+    if (entry) stack.push_back(entry);
+
+    while (!stack.empty()) {
+      auto block = stack.back();
+      stack.pop_back();
+
+      if (block->reachable()) continue;
+
+      block->set_reachable(true);
+      for (auto succ : *block->successors()) {
+        stack.push_back(succ);
+      }
+    }
+  }
+}
+
 spv_result_t ControlFlowPass(ValidationState_t& _, const Instruction* inst) {
   switch (inst->opcode()) {
     case SpvOpPhi:
diff --git a/test/val/val_cfg_test.cpp b/test/val/val_cfg_test.cpp
index 0d09642..630a19d 100644
--- a/test/val/val_cfg_test.cpp
+++ b/test/val/val_cfg_test.cpp
@@ -1204,14 +1204,6 @@
   EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
 }
 
-TEST_F(ValidateCFG, WebGPUUnreachableMergeWithBranchUse) {
-  CompileSuccessfully(
-      GetUnreachableMergeWithBranchUse(SpvCapabilityShader, SPV_ENV_WEBGPU_0));
-  ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions(SPV_ENV_WEBGPU_0));
-  EXPECT_THAT(getDiagnosticString(),
-              HasSubstr("cannot be the target of a branch."));
-}
-
 std::string GetUnreachableMergeWithMultipleUses(SpvCapability cap,
                                                 spv_target_env env) {
   std::string header =
@@ -4503,6 +4495,76 @@
   ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
 }
 
+TEST_F(ValidateCFG, UnreachableIsStaticallyReachable) {
+  const std::string text = R"(
+OpCapability Shader
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+%1 = OpTypeVoid
+%2 = OpTypeFunction %1
+%3 = OpFunction %1 None %2
+%4 = OpLabel
+OpBranch %5
+%5 = OpLabel
+OpUnreachable
+OpFunctionEnd
+)";
+
+  CompileSuccessfully(text);
+  EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState());
+
+  auto f = vstate_->function(3);
+  auto entry = f->GetBlock(4).first;
+  ASSERT_TRUE(entry->reachable());
+  auto end = f->GetBlock(5).first;
+  ASSERT_TRUE(end->reachable());
+}
+
+TEST_F(ValidateCFG, BlockOrderDoesNotAffectReachability) {
+  const std::string text = R"(
+OpCapability Shader
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+%1 = OpTypeVoid
+%2 = OpTypeFunction %1
+%3 = OpTypeBool
+%4 = OpUndef %3
+%5 = OpFunction %1 None %2
+%6 = OpLabel
+OpBranch %7
+%7 = OpLabel
+OpSelectionMerge %8 None
+OpBranchConditional %4 %9 %10
+%8 = OpLabel
+OpReturn
+%9 = OpLabel
+OpBranch %8
+%10 = OpLabel
+OpBranch %8
+%11 = OpLabel
+OpUnreachable
+OpFunctionEnd
+)";
+
+  CompileSuccessfully(text);
+  EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState());
+
+  auto f = vstate_->function(5);
+  auto b6 = f->GetBlock(6).first;
+  auto b7 = f->GetBlock(7).first;
+  auto b8 = f->GetBlock(8).first;
+  auto b9 = f->GetBlock(9).first;
+  auto b10 = f->GetBlock(10).first;
+  auto b11 = f->GetBlock(11).first;
+
+  ASSERT_TRUE(b6->reachable());
+  ASSERT_TRUE(b7->reachable());
+  ASSERT_TRUE(b8->reachable());
+  ASSERT_TRUE(b9->reachable());
+  ASSERT_TRUE(b10->reachable());
+  ASSERT_FALSE(b11->reachable());
+}
+
 }  // namespace
 }  // namespace val
 }  // namespace spvtools