spirv/subgroups: Refactor to use vtn_push_ssa

Reviewed-by: Caio Marcelo de Oliveira Filho <caio.oliveira@intel.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/5278>
diff --git a/src/compiler/spirv/vtn_subgroup.c b/src/compiler/spirv/vtn_subgroup.c
index aa8ddff..8e4c3f2 100644
--- a/src/compiler/spirv/vtn_subgroup.c
+++ b/src/compiler/spirv/vtn_subgroup.c
@@ -23,10 +23,9 @@
 
 #include "vtn_private.h"
 
-static void
+static struct vtn_ssa_value *
 vtn_build_subgroup_instr(struct vtn_builder *b,
                          nir_intrinsic_op nir_op,
-                         struct vtn_ssa_value *dst,
                          struct vtn_ssa_value *src0,
                          nir_ssa_def *index,
                          unsigned const_idx0,
@@ -39,14 +38,16 @@
    if (index && index->bit_size != 32)
       index = nir_u2u32(&b->nb, index);
 
+   struct vtn_ssa_value *dst = vtn_create_ssa_value(b, src0->type);
+
    vtn_assert(dst->type == src0->type);
    if (!glsl_type_is_vector_or_scalar(dst->type)) {
       for (unsigned i = 0; i < glsl_get_length(dst->type); i++) {
-         vtn_build_subgroup_instr(b, nir_op, dst->elems[i],
-                                  src0->elems[i], index,
-                                  const_idx0, const_idx1);
+         dst->elems[0] =
+            vtn_build_subgroup_instr(b, nir_op, src0->elems[i], index,
+                                     const_idx0, const_idx1);
       }
-      return;
+      return dst;
    }
 
    nir_intrinsic_instr *intrin =
@@ -65,33 +66,33 @@
    nir_builder_instr_insert(&b->nb, &intrin->instr);
 
    dst->def = &intrin->dest.ssa;
+
+   return dst;
 }
 
 void
 vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
                     const uint32_t *w, unsigned count)
 {
-   struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_ssa);
-
-   val->ssa = vtn_create_ssa_value(b, val->type->type);
+   struct vtn_type *dest_type = vtn_get_type(b, w[1]);
 
    switch (opcode) {
    case SpvOpGroupNonUniformElect: {
-      vtn_fail_if(val->type->type != glsl_bool_type(),
+      vtn_fail_if(dest_type->type != glsl_bool_type(),
                   "OpGroupNonUniformElect must return a Bool");
       nir_intrinsic_instr *elect =
          nir_intrinsic_instr_create(b->nb.shader, nir_intrinsic_elect);
       nir_ssa_dest_init_for_type(&elect->instr, &elect->dest,
-                                 val->type->type, NULL);
+                                 dest_type->type, NULL);
       nir_builder_instr_insert(&b->nb, &elect->instr);
-      val->ssa->def = &elect->dest.ssa;
+      vtn_push_nir_ssa(b, w[2], &elect->dest.ssa);
       break;
    }
 
    case SpvOpGroupNonUniformBallot:
    case SpvOpSubgroupBallotKHR: {
       bool has_scope = (opcode != SpvOpSubgroupBallotKHR);
-      vtn_fail_if(val->type->type != glsl_vector_type(GLSL_TYPE_UINT, 4),
+      vtn_fail_if(dest_type->type != glsl_vector_type(GLSL_TYPE_UINT, 4),
                   "OpGroupNonUniformBallot must return a uvec4");
       nir_intrinsic_instr *ballot =
          nir_intrinsic_instr_create(b->nb.shader, nir_intrinsic_ballot);
@@ -99,7 +100,7 @@
       nir_ssa_dest_init(&ballot->instr, &ballot->dest, 4, 32, NULL);
       ballot->num_components = 4;
       nir_builder_instr_insert(&b->nb, &ballot->instr);
-      val->ssa->def = &ballot->dest.ssa;
+      vtn_push_nir_ssa(b, w[2], &ballot->dest.ssa);
       break;
    }
 
@@ -116,10 +117,10 @@
       intrin->src[1] = nir_src_for_ssa(nir_load_subgroup_invocation(&b->nb));
 
       nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
-                                 val->type->type, NULL);
+                                 dest_type->type, NULL);
       nir_builder_instr_insert(&b->nb, &intrin->instr);
 
-      val->ssa->def = &intrin->dest.ssa;
+      vtn_push_nir_ssa(b, w[2], &intrin->dest.ssa);
       break;
    }
 
@@ -171,19 +172,20 @@
          intrin->src[1] = nir_src_for_ssa(src1);
 
       nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
-                                 val->type->type, NULL);
+                                 dest_type->type, NULL);
       nir_builder_instr_insert(&b->nb, &intrin->instr);
 
-      val->ssa->def = &intrin->dest.ssa;
+      vtn_push_nir_ssa(b, w[2], &intrin->dest.ssa);
       break;
    }
 
    case SpvOpGroupNonUniformBroadcastFirst:
    case SpvOpSubgroupFirstInvocationKHR: {
       bool has_scope = (opcode != SpvOpSubgroupFirstInvocationKHR);
-      vtn_build_subgroup_instr(b, nir_intrinsic_read_first_invocation,
-                               val->ssa, vtn_ssa_value(b, w[3 + has_scope]),
-                               NULL, 0, 0);
+      vtn_push_ssa_value(b, w[2],
+         vtn_build_subgroup_instr(b, nir_intrinsic_read_first_invocation,
+                                  vtn_ssa_value(b, w[3 + has_scope]),
+                                  NULL, 0, 0));
       break;
    }
 
@@ -191,9 +193,10 @@
    case SpvOpGroupBroadcast:
    case SpvOpSubgroupReadInvocationKHR: {
       bool has_scope = (opcode != SpvOpSubgroupReadInvocationKHR);
-      vtn_build_subgroup_instr(b, nir_intrinsic_read_invocation,
-                               val->ssa, vtn_ssa_value(b, w[3 + has_scope]),
-                               vtn_get_nir_ssa(b, w[4 + has_scope]), 0, 0);
+      vtn_push_ssa_value(b, w[2],
+         vtn_build_subgroup_instr(b, nir_intrinsic_read_invocation,
+                                  vtn_ssa_value(b, w[3 + has_scope]),
+                                  vtn_get_nir_ssa(b, w[4 + has_scope]), 0, 0));
       break;
    }
 
@@ -205,7 +208,7 @@
    case SpvOpSubgroupAllKHR:
    case SpvOpSubgroupAnyKHR:
    case SpvOpSubgroupAllEqualKHR: {
-      vtn_fail_if(val->type->type != glsl_bool_type(),
+      vtn_fail_if(dest_type->type != glsl_bool_type(),
                   "OpGroupNonUniform(All|Any|AllEqual) must return a bool");
       nir_intrinsic_op op;
       switch (opcode) {
@@ -262,10 +265,10 @@
          intrin->num_components = src0->num_components;
       intrin->src[0] = nir_src_for_ssa(src0);
       nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
-                                 val->type->type, NULL);
+                                 dest_type->type, NULL);
       nir_builder_instr_insert(&b->nb, &intrin->instr);
 
-      val->ssa->def = &intrin->dest.ssa;
+      vtn_push_nir_ssa(b, w[2], &intrin->dest.ssa);
       break;
    }
 
@@ -290,15 +293,17 @@
       default:
          unreachable("Invalid opcode");
       }
-      vtn_build_subgroup_instr(b, op, val->ssa, vtn_ssa_value(b, w[4]),
-                               vtn_get_nir_ssa(b, w[5]), 0, 0);
+      vtn_push_ssa_value(b, w[2],
+         vtn_build_subgroup_instr(b, op, vtn_ssa_value(b, w[4]),
+                                  vtn_get_nir_ssa(b, w[5]), 0, 0));
       break;
    }
 
    case SpvOpGroupNonUniformQuadBroadcast:
-      vtn_build_subgroup_instr(b, nir_intrinsic_quad_broadcast,
-                               val->ssa, vtn_ssa_value(b, w[4]),
-                               vtn_get_nir_ssa(b, w[5]), 0, 0);
+      vtn_push_ssa_value(b, w[2],
+         vtn_build_subgroup_instr(b, nir_intrinsic_quad_broadcast,
+                                  vtn_ssa_value(b, w[4]),
+                                  vtn_get_nir_ssa(b, w[5]), 0, 0));
       break;
 
    case SpvOpGroupNonUniformQuadSwap: {
@@ -317,8 +322,8 @@
       default:
          vtn_fail("Invalid constant value in OpGroupNonUniformQuadSwap");
       }
-      vtn_build_subgroup_instr(b, op, val->ssa, vtn_ssa_value(b, w[4]),
-                               NULL, 0, 0);
+      vtn_push_ssa_value(b, w[2],
+         vtn_build_subgroup_instr(b, op, vtn_ssa_value(b, w[4]), NULL, 0, 0));
       break;
    }
 
@@ -439,8 +444,9 @@
          unreachable("Invalid group operation");
       }
 
-      vtn_build_subgroup_instr(b, op, val->ssa, vtn_ssa_value(b, w[5]),
-                               NULL, reduction_op, cluster_size);
+      vtn_push_ssa_value(b, w[2],
+         vtn_build_subgroup_instr(b, op, vtn_ssa_value(b, w[5]), NULL,
+                                  reduction_op, cluster_size));
       break;
    }