nir: allow divergence information to be updated when inserting instruction

Signed-off-by: Rhys Perry <pendingchaos02@gmail.com>
Reviewed-by: Daniel Schürmann <daniel@schuermann.dev>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/6558>
diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h
index 20b31e9..0bf59e4 100644
--- a/src/compiler/nir/nir.h
+++ b/src/compiler/nir/nir.h
@@ -4801,6 +4801,7 @@
 void nir_convert_loop_to_lcssa(nir_loop *loop);
 bool nir_convert_to_lcssa(nir_shader *shader, bool skip_invariants, bool skip_bool_invariants);
 void nir_divergence_analysis(nir_shader *shader);
+bool nir_update_instr_divergence(nir_shader *shader, nir_instr *instr);
 
 /* If phi_webs_only is true, only convert SSA values involved in phi nodes to
  * registers.  If false, convert all values (even those not involved in a phi
diff --git a/src/compiler/nir/nir_builder.h b/src/compiler/nir/nir_builder.h
index 2b340ba..1b43892 100644
--- a/src/compiler/nir/nir_builder.h
+++ b/src/compiler/nir/nir_builder.h
@@ -36,6 +36,10 @@
    /* Whether new ALU instructions will be marked "exact" */
    bool exact;
 
+   /* Whether to run divergence analysis on inserted instructions (loop merge
+    * and header phis are not updated). */
+   bool update_divergence;
+
    nir_shader *shader;
    nir_function_impl *impl;
 } nir_builder;
@@ -54,6 +58,7 @@
                                gl_shader_stage stage,
                                const nir_shader_compiler_options *options)
 {
+   memset(build, 0, sizeof(*build));
    build->shader = nir_shader_create(mem_ctx, stage, options, NULL);
    nir_function *func = nir_function_create(build->shader, "main");
    func->is_entrypoint = true;
@@ -110,6 +115,9 @@
 {
    nir_instr_insert(build->cursor, instr);
 
+   if (build->update_divergence)
+      nir_update_instr_divergence(build->shader, instr);
+
    /* Move the cursor forward. */
    build->cursor = nir_after_instr(instr);
 }
@@ -237,6 +245,8 @@
       return NULL;
 
    nir_instr_insert(nir_before_cf_list(&build->impl->body), &undef->instr);
+   if (build->update_divergence)
+      nir_update_instr_divergence(build->shader, &undef->instr);
 
    return &undef->def;
 }
diff --git a/src/compiler/nir/nir_divergence_analysis.c b/src/compiler/nir/nir_divergence_analysis.c
index ccfc7de..cd33996 100644
--- a/src/compiler/nir/nir_divergence_analysis.c
+++ b/src/compiler/nir/nir_divergence_analysis.c
@@ -615,6 +615,31 @@
 }
 
 static bool
+update_instr_divergence(nir_shader *shader, nir_instr *instr)
+{
+   switch (instr->type) {
+   case nir_instr_type_alu:
+      return visit_alu(nir_instr_as_alu(instr));
+   case nir_instr_type_intrinsic:
+      return visit_intrinsic(shader, nir_instr_as_intrinsic(instr));
+   case nir_instr_type_tex:
+      return visit_tex(nir_instr_as_tex(instr));
+   case nir_instr_type_load_const:
+      return visit_load_const(nir_instr_as_load_const(instr));
+   case nir_instr_type_ssa_undef:
+      return visit_ssa_undef(nir_instr_as_ssa_undef(instr));
+   case nir_instr_type_deref:
+      return visit_deref(shader, nir_instr_as_deref(instr));
+   case nir_instr_type_jump:
+   case nir_instr_type_phi:
+   case nir_instr_type_call:
+   case nir_instr_type_parallel_copy:
+   default:
+      unreachable("NIR divergence analysis: Unsupported instruction type.");
+   }
+}
+
+static bool
 visit_block(nir_block *block, struct divergence_state *state)
 {
    bool has_changed = false;
@@ -627,33 +652,10 @@
       if (state->first_visit)
          nir_foreach_ssa_def(instr, set_ssa_def_not_divergent, NULL);
 
-      switch (instr->type) {
-      case nir_instr_type_alu:
-         has_changed |= visit_alu(nir_instr_as_alu(instr));
-         break;
-      case nir_instr_type_intrinsic:
-         has_changed |= visit_intrinsic(state->shader, nir_instr_as_intrinsic(instr));
-         break;
-      case nir_instr_type_tex:
-         has_changed |= visit_tex(nir_instr_as_tex(instr));
-         break;
-      case nir_instr_type_load_const:
-         has_changed |= visit_load_const(nir_instr_as_load_const(instr));
-         break;
-      case nir_instr_type_ssa_undef:
-         has_changed |= visit_ssa_undef(nir_instr_as_ssa_undef(instr));
-         break;
-      case nir_instr_type_deref:
-         has_changed |= visit_deref(state->shader, nir_instr_as_deref(instr));
-         break;
-      case nir_instr_type_jump:
+      if (instr->type == nir_instr_type_jump)
          has_changed |= visit_jump(nir_instr_as_jump(instr), state);
-         break;
-      case nir_instr_type_phi:
-      case nir_instr_type_call:
-      case nir_instr_type_parallel_copy:
-         unreachable("NIR divergence analysis: Unsupported instruction type.");
-      }
+      else
+         has_changed |= update_instr_divergence(state->shader, instr);
    }
 
    return has_changed;
@@ -903,3 +905,23 @@
    visit_cf_list(&nir_shader_get_entrypoint(shader)->body, &state);
 }
 
+bool nir_update_instr_divergence(nir_shader *shader, nir_instr *instr)
+{
+   nir_foreach_ssa_def(instr, set_ssa_def_not_divergent, NULL);
+
+   if (instr->type == nir_instr_type_phi) {
+      nir_cf_node *prev = nir_cf_node_prev(&instr->block->cf_node);
+      /* can only update gamma/if phis */
+      if (!prev || prev->type != nir_cf_node_if)
+         return false;
+
+      nir_if *nif = nir_cf_node_as_if(prev);
+
+      visit_if_merge_phi(nir_instr_as_phi(instr), nir_src_is_divergent(nif->condition));
+      return true;
+   }
+
+   update_instr_divergence(shader, instr);
+   return true;
+}
+