radv: use shared ac_ngg_compute_subgroup_info

Closes: https://gitlab.freedesktop.org/mesa/mesa/-/issues/12496

Reviewed-by: Samuel Pitoiset <samuel.pitoiset@gmail.com>
Reviewed-by: Timur Kristóf <timur.kristof@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/35473>
diff --git a/src/amd/vulkan/radv_shader_info.c b/src/amd/vulkan/radv_shader_info.c
index 7386163..f462280 100644
--- a/src/amd/vulkan/radv_shader_info.c
+++ b/src/amd/vulkan/radv_shader_info.c
@@ -1443,212 +1443,29 @@
 {
    const struct radv_physical_device *pdev = radv_device_physical(device);
    const enum amd_gfx_level gfx_level = pdev->info.gfx_level;
-   const unsigned max_verts_per_prim = radv_get_num_input_vertices(es_info, gs_info);
-   const unsigned min_verts_per_prim = gs_info ? max_verts_per_prim : 1;
-
+   const struct radv_shader_info *stage_info = gs_info ? gs_info : es_info;
    const unsigned gs_num_invocations = gs_info ? MAX2(gs_info->gs.invocations, 1) : 1;
-
    const unsigned input_prim = radv_get_pre_rast_input_topology(es_info, gs_info);
-   const bool uses_adjacency = input_prim == MESA_PRIM_LINES_ADJACENCY || input_prim == MESA_PRIM_TRIANGLES_ADJACENCY;
+   const unsigned gs_vertices_out = gs_info ? gs_info->gs.vertices_out : 0;
+   ac_ngg_subgroup_info info;
 
-   /* All these are in dwords: */
-   /* We can't allow using the whole LDS, because GS waves compete with
-    * other shader stages for LDS space.
-    *
-    * TODO: We should really take the shader's internal LDS use into
-    *       account. The linker will fail if the size is greater than
-    *       8K dwords.
-    */
-   const unsigned max_lds_size = 8 * 1024 - 768;
-   const unsigned target_lds_size = max_lds_size;
-   unsigned esvert_lds_size = 0;
-   unsigned gsprim_lds_size = 0;
+   ac_ngg_compute_subgroup_info(gfx_level, es_info->stage, !!gs_info, input_prim, gs_vertices_out, gs_num_invocations,
+                                128, stage_info->wave_size, es_info->esgs_itemsize, stage_info->ngg_lds_vertex_size,
+                                stage_info->ngg_lds_scratch_size, false, &info);
 
-   /* All these are per subgroup: */
-   const unsigned min_esverts = gfx_level >= GFX11 ? max_verts_per_prim /* gfx11 requires at least 1 primitive per TG */
-                                : gfx_level >= GFX10_3 ? 29
-                                                       : (24 - 1 + max_verts_per_prim);
-   bool max_vert_out_per_gs_instance = false;
-   unsigned max_esverts_base = 128;
-   unsigned max_gsprims_base = 128; /* default prim group size clamp */
+   out->hw_max_esverts = info.hw_max_esverts;
+   out->max_gsprims = info.max_gsprims;
+   out->max_out_verts = info.max_out_verts;
+   out->max_vert_out_per_gs_instance = info.max_vert_out_per_gs_instance;
+   out->ngg_emit_size = info.ngg_out_lds_size;
+   out->esgs_ring_size = info.esgs_lds_size * 4;
+   out->prim_amp_factor = gs_info ? gs_info->gs.vertices_out : 1;
 
-   /* Hardware has the following non-natural restrictions on the value
-    * of GE_CNTL.VERT_GRP_SIZE based on based on the primitive type of
-    * the draw:
-    *  - at most 252 for any line input primitive type
-    *  - at most 251 for any quad input primitive type
-    *  - at most 251 for triangle strips with adjacency (this happens to
-    *    be the natural limit for triangle *lists* with adjacency)
-    */
-   max_esverts_base = MIN2(max_esverts_base, 251 + max_verts_per_prim - 1);
+   const struct radv_shader_info *rinfo = gs_info ? gs_info : es_info;
+   out->lds_size = rinfo->ngg_lds_scratch_size + gfx10_get_ngg_vert_prim_lds_size(device, es_info, gs_info, out);
 
-   if (gs_info) {
-      unsigned max_out_verts_per_gsprim = gs_info->gs.vertices_out * gs_num_invocations;
-
-      if (max_out_verts_per_gsprim <= 256) {
-         if (max_out_verts_per_gsprim) {
-            max_gsprims_base = MIN2(max_gsprims_base, 256 / max_out_verts_per_gsprim);
-         }
-      } else {
-         /* Use special multi-cycling mode in which each GS
-          * instance gets its own subgroup. Does not work with
-          * tessellation. */
-         max_vert_out_per_gs_instance = true;
-         max_gsprims_base = 1;
-         max_out_verts_per_gsprim = gs_info->gs.vertices_out;
-      }
-
-      esvert_lds_size = es_info->esgs_itemsize / 4;
-      gsprim_lds_size = (gs_info->ngg_lds_vertex_size / 4) * max_out_verts_per_gsprim;
-   } else {
-      /* VS and TES. */
-      /* LDS size for passing data from GS to ES. */
-      struct radv_streamout_info *so_info = &es_info->so;
-
-      if (so_info->enabled_stream_buffers_mask) {
-         /* Compute the same pervertex LDS size as the NGG streamout lowering pass which allocates
-          * space for all outputs.
-          * TODO: only alloc space for outputs that really need streamout.
-          */
-         const uint32_t num_outputs =
-            es_info->stage == MESA_SHADER_VERTEX ? es_info->vs.num_outputs : es_info->tes.num_outputs;
-         esvert_lds_size = 4 * num_outputs + 1;
-      }
-
-      /* GS stores Primitive IDs (one DWORD) into LDS at the address
-       * corresponding to the ES thread of the provoking vertex. All
-       * ES threads load and export PrimitiveID for their thread.
-       */
-      if (es_info->stage == MESA_SHADER_VERTEX && es_info->outinfo.export_prim_id)
-         esvert_lds_size = MAX2(esvert_lds_size, 1);
-   }
-
-   unsigned max_gsprims = max_gsprims_base;
-   unsigned max_esverts = max_esverts_base;
-
-   if (esvert_lds_size)
-      max_esverts = MIN2(max_esverts, target_lds_size / esvert_lds_size);
-   if (gsprim_lds_size)
-      max_gsprims = MIN2(max_gsprims, target_lds_size / gsprim_lds_size);
-
-   max_esverts = MIN2(max_esverts, max_gsprims * max_verts_per_prim);
-   clamp_gsprims_to_esverts(&max_gsprims, max_esverts, min_verts_per_prim, uses_adjacency);
-   assert(max_esverts >= max_verts_per_prim && max_gsprims >= 1);
-
-   if (esvert_lds_size || gsprim_lds_size) {
-      /* Now that we have a rough proportionality between esverts
-       * and gsprims based on the primitive type, scale both of them
-       * down simultaneously based on required LDS space.
-       *
-       * We could be smarter about this if we knew how much vertex
-       * reuse to expect.
-       */
-      unsigned lds_total = max_esverts * esvert_lds_size + max_gsprims * gsprim_lds_size;
-      if (lds_total > target_lds_size) {
-         max_esverts = max_esverts * target_lds_size / lds_total;
-         max_gsprims = max_gsprims * target_lds_size / lds_total;
-
-         max_esverts = MIN2(max_esverts, max_gsprims * max_verts_per_prim);
-         clamp_gsprims_to_esverts(&max_gsprims, max_esverts, min_verts_per_prim, uses_adjacency);
-         assert(max_esverts >= max_verts_per_prim && max_gsprims >= 1);
-      }
-   }
-
-   /* Round up towards full wave sizes for better ALU utilization. */
-   if (!max_vert_out_per_gs_instance) {
-      unsigned orig_max_esverts;
-      unsigned orig_max_gsprims;
-      unsigned wavesize;
-
-      if (gs_info) {
-         wavesize = gs_info->wave_size;
-      } else {
-         wavesize = es_info->wave_size;
-      }
-
-      do {
-         orig_max_esverts = max_esverts;
-         orig_max_gsprims = max_gsprims;
-
-         max_esverts = align(max_esverts, wavesize);
-         max_esverts = MIN2(max_esverts, max_esverts_base);
-         if (esvert_lds_size)
-            max_esverts = MIN2(max_esverts, (max_lds_size - max_gsprims * gsprim_lds_size) / esvert_lds_size);
-         max_esverts = MIN2(max_esverts, max_gsprims * max_verts_per_prim);
-
-         /* Hardware restriction: minimum value of max_esverts */
-         if (gfx_level == GFX10)
-            max_esverts = MAX2(max_esverts, min_esverts - 1 + max_verts_per_prim);
-         else
-            max_esverts = MAX2(max_esverts, min_esverts);
-
-         max_gsprims = align(max_gsprims, wavesize);
-         max_gsprims = MIN2(max_gsprims, max_gsprims_base);
-         if (gsprim_lds_size) {
-            /* Don't count unusable vertices to the LDS
-             * size. Those are vertices above the maximum
-             * number of vertices that can occur in the
-             * workgroup, which is e.g. max_gsprims * 3
-             * for triangles.
-             */
-            unsigned usable_esverts = MIN2(max_esverts, max_gsprims * max_verts_per_prim);
-            max_gsprims = MIN2(max_gsprims, (max_lds_size - usable_esverts * esvert_lds_size) / gsprim_lds_size);
-         }
-         clamp_gsprims_to_esverts(&max_gsprims, max_esverts, min_verts_per_prim, uses_adjacency);
-         assert(max_esverts >= max_verts_per_prim && max_gsprims >= 1);
-      } while (orig_max_esverts != max_esverts || orig_max_gsprims != max_gsprims);
-
-      /* Verify the restriction. */
-      if (gfx_level == GFX10)
-         assert(max_esverts >= min_esverts - 1 + max_verts_per_prim);
-      else
-         assert(max_esverts >= min_esverts);
-   } else {
-      /* Hardware restriction: minimum value of max_esverts */
-      if (gfx_level == GFX10)
-         max_esverts = MAX2(max_esverts, min_esverts - 1 + max_verts_per_prim);
-      else
-         max_esverts = MAX2(max_esverts, min_esverts);
-   }
-
-   unsigned max_out_vertices = max_vert_out_per_gs_instance ? gs_info->gs.vertices_out
-                               : gs_info ? max_gsprims * gs_num_invocations * gs_info->gs.vertices_out
-                                         : max_esverts;
-   assert(max_out_vertices <= 256);
-
-   unsigned prim_amp_factor = 1;
-   if (gs_info) {
-      /* Number of output primitives per GS input primitive after
-       * GS instancing. */
-      prim_amp_factor = gs_info->gs.vertices_out;
-   }
-
-   /* On Gfx10, the GE only checks against the maximum number of ES verts
-    * after allocating a full GS primitive. So we need to ensure that
-    * whenever this check passes, there is enough space for a full
-    * primitive without vertex reuse.
-    */
-   if (gfx_level == GFX10)
-      out->hw_max_esverts = max_esverts - max_verts_per_prim + 1;
-   else
-      out->hw_max_esverts = max_esverts;
-
-   out->max_gsprims = max_gsprims;
-   out->max_out_verts = max_out_vertices;
-   out->prim_amp_factor = prim_amp_factor;
-   out->max_vert_out_per_gs_instance = max_vert_out_per_gs_instance;
-   out->ngg_emit_size = max_gsprims * gsprim_lds_size;
-
-   /* Don't count unusable vertices. */
-   out->esgs_ring_size = MIN2(max_esverts, max_gsprims * max_verts_per_prim) * esvert_lds_size * 4;
-
-   assert(out->hw_max_esverts >= min_esverts); /* HW limitation */
-
-   const struct radv_shader_info *info = gs_info ? gs_info : es_info;
-   out->lds_size = info->ngg_lds_scratch_size + gfx10_get_ngg_vert_prim_lds_size(device, es_info, gs_info, out);
-
-   unsigned workgroup_size =
-      ac_compute_ngg_workgroup_size(max_esverts, max_gsprims * gs_num_invocations, max_out_vertices, prim_amp_factor);
+   unsigned workgroup_size = ac_compute_ngg_workgroup_size(info.hw_max_esverts, info.max_gsprims * gs_num_invocations,
+                                                           info.max_out_verts, out->prim_amp_factor);
    if (gs_info) {
       gs_info->workgroup_size = workgroup_size;
    }
diff --git a/src/gallium/drivers/zink/ci/zink-radv-navi10-fails.txt b/src/gallium/drivers/zink/ci/zink-radv-navi10-fails.txt
index a8c4637..4c0f460 100644
--- a/src/gallium/drivers/zink/ci/zink-radv-navi10-fails.txt
+++ b/src/gallium/drivers/zink/ci/zink-radv-navi10-fails.txt
@@ -168,10 +168,6 @@
 spec@arb_framebuffer_object@execution@msaa-alpha-to-coverage_alpha-to-one,Fail
 spec@arb_framebuffer_object@execution@msaa-alpha-to-coverage_alpha-to-one_write-z,Fail
 
-# https://gitlab.freedesktop.org/mesa/mesa/-/issues/12496
-# ../src/amd/vulkan/radv_shader_info.c:1559: gfx10_get_ngg_info: Assertion `max_esverts >= max_verts_per_prim && max_gsprims >= 1' failed.
-spec@glsl-1.50@gs-max-output,Crash
-
 spec@arb_sample_locations@test,Fail
 spec@arb_sample_locations@test@MSAA: 1- X: 0- Y: 0- Grid: false,Fail
 spec@arb_sample_locations@test@MSAA: 1- X: 0- Y: 0- Grid: true,Fail
diff --git a/src/gallium/drivers/zink/ci/zink-radv-navi31-fails.txt b/src/gallium/drivers/zink/ci/zink-radv-navi31-fails.txt
index dd8a4c0..ce5e797 100644
--- a/src/gallium/drivers/zink/ci/zink-radv-navi31-fails.txt
+++ b/src/gallium/drivers/zink/ci/zink-radv-navi31-fails.txt
@@ -196,10 +196,6 @@
 spec@ext_image_dma_buf_import@ext_image_dma_buf_import-sample_yvyu,Fail
 spec@ext_image_dma_buf_import@ext_image_dma_buf_import-transcode-nv12-as-r8-gr88,Fail
 
-# https://gitlab.freedesktop.org/mesa/mesa/-/issues/12496
-# ../src/amd/vulkan/radv_shader_info.c:1559: gfx10_get_ngg_info: Assertion `max_esverts >= max_verts_per_prim && max_gsprims >= 1' failed.
-spec@glsl-1.50@gs-max-output,Crash
-
 spec@arb_sample_locations@test,Fail
 spec@arb_sample_locations@test@MSAA: 1- X: 0- Y: 0- Grid: false,Fail
 spec@arb_sample_locations@test@MSAA: 1- X: 0- Y: 0- Grid: true,Fail
diff --git a/src/gallium/drivers/zink/ci/zink-radv-vangogh-fails.txt b/src/gallium/drivers/zink/ci/zink-radv-vangogh-fails.txt
index bf39d36..2c86a565 100644
--- a/src/gallium/drivers/zink/ci/zink-radv-vangogh-fails.txt
+++ b/src/gallium/drivers/zink/ci/zink-radv-vangogh-fails.txt
@@ -165,10 +165,6 @@
 # Regression noticed in https://gitlab.freedesktop.org/mesa/mesa/-/pipelines/891104
 spec@arb_viewport_array@display-list,Fail
 
-# https://gitlab.freedesktop.org/mesa/mesa/-/issues/12496
-# ../src/amd/vulkan/radv_shader_info.c:1559: gfx10_get_ngg_info: Assertion `max_esverts >= max_verts_per_prim && max_gsprims >= 1' failed.
-spec@glsl-1.50@gs-max-output,Crash
-
 spec@arb_sample_locations@test,Fail
 spec@arb_sample_locations@test@MSAA: 1- X: 0- Y: 0- Grid: false,Fail
 spec@arb_sample_locations@test@MSAA: 1- X: 0- Y: 0- Grid: true,Fail