aco/ngg: Refactor gs_alloc_req in preparation for NGG GS.

Previously, this function inferred the vertex and primitive counts
from the gs_tg_info shader argument, but in case of NGG GS, it will
need to be calculated in runtime.

Signed-off-by: Timur Kristóf <timur.kristof@gmail.com>
Reviewed-by: Rhys Perry <pendingchaos02@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/6964>
diff --git a/src/amd/compiler/aco_instruction_selection.cpp b/src/amd/compiler/aco_instruction_selection.cpp
index d6071e0..5242aff 100644
--- a/src/amd/compiler/aco_instruction_selection.cpp
+++ b/src/amd/compiler/aco_instruction_selection.cpp
@@ -10692,7 +10692,21 @@
    return true;
 }
 
-void ngg_emit_sendmsg_gs_alloc_req(isel_context *ctx)
+Temp ngg_max_vertex_count(isel_context *ctx)
+{
+   Builder bld(ctx->program, ctx->block);
+   return bld.sop2(aco_opcode::s_bfe_u32, bld.def(s1), bld.def(s1, scc),
+                   get_arg(ctx, ctx->args->gs_tg_info), Operand(12u | (9u << 16u)));
+}
+
+Temp ngg_max_primitive_count(isel_context *ctx)
+{
+   Builder bld(ctx->program, ctx->block);
+   return bld.sop2(aco_opcode::s_bfe_u32, bld.def(s1), bld.def(s1, scc),
+                   get_arg(ctx, ctx->args->gs_tg_info), Operand(22u | (9u << 16u)));
+}
+
+void ngg_emit_sendmsg_gs_alloc_req(isel_context *ctx, Temp vtx_cnt = Temp(), Temp prm_cnt = Temp())
 {
    Builder bld(ctx->program, ctx->block);
 
@@ -10712,12 +10726,20 @@
    begin_uniform_if_else(ctx, &ic);
    bld.reset(ctx->block);
 
+   /* VS/TES: we infer the vertex and primitive count from arguments
+    * GS: the caller needs to supply them
+    */
+   assert(ctx->shader->info.stage == MESA_SHADER_GEOMETRY
+          ? (vtx_cnt.id() && prm_cnt.id())
+          : (!vtx_cnt.id() && !prm_cnt.id()));
+
    /* Number of vertices output by VS/TES */
-   Temp vtx_cnt = bld.sop2(aco_opcode::s_bfe_u32, bld.def(s1), bld.def(s1, scc),
-                           get_arg(ctx, ctx->args->gs_tg_info), Operand(12u | (9u << 16u)));
+   if (vtx_cnt.id() == 0)
+      vtx_cnt = ngg_max_vertex_count(ctx);
+
    /* Number of primitives output by VS/TES */
-   Temp prm_cnt = bld.sop2(aco_opcode::s_bfe_u32, bld.def(s1), bld.def(s1, scc),
-                           get_arg(ctx, ctx->args->gs_tg_info), Operand(22u | (9u << 16u)));
+   if (prm_cnt.id() == 0)
+      prm_cnt = ngg_max_primitive_count(ctx);
 
    /* Put the number of vertices and primitives into m0 for the GS_ALLOC_REQ */
    Temp tmp = bld.sop2(aco_opcode::s_lshl_b32, bld.def(s1), bld.def(s1, scc), prm_cnt, Operand(12u));