spirv: Add a helpers for getting types of values

Reviewed-by: Caio Marcelo de Oliveira Filho <caio.oliveira@intel.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/5278>
diff --git a/src/compiler/spirv/spirv_to_nir.c b/src/compiler/spirv/spirv_to_nir.c
index 93969bf..8a83f7f 100644
--- a/src/compiler/spirv/spirv_to_nir.c
+++ b/src/compiler/spirv/spirv_to_nir.c
@@ -1173,7 +1173,7 @@
    }
 
    case SpvOpTypeVector: {
-      struct vtn_type *base = vtn_value(b, w[2], vtn_value_type_type)->type;
+      struct vtn_type *base = vtn_get_type(b, w[2]);
       unsigned elems = w[3];
 
       vtn_fail_if(base->base_type != vtn_base_type_scalar,
@@ -1191,7 +1191,7 @@
    }
 
    case SpvOpTypeMatrix: {
-      struct vtn_type *base = vtn_value(b, w[2], vtn_value_type_type)->type;
+      struct vtn_type *base = vtn_get_type(b, w[2]);
       unsigned columns = w[3];
 
       vtn_fail_if(base->base_type != vtn_base_type_vector,
@@ -1215,8 +1215,7 @@
 
    case SpvOpTypeRuntimeArray:
    case SpvOpTypeArray: {
-      struct vtn_type *array_element =
-         vtn_value(b, w[2], vtn_value_type_type)->type;
+      struct vtn_type *array_element = vtn_get_type(b, w[2]);
 
       if (opcode == SpvOpTypeRuntimeArray) {
          /* A length of 0 is used to denote unsized arrays */
@@ -1246,8 +1245,7 @@
 
       NIR_VLA(struct glsl_struct_field, fields, count);
       for (unsigned i = 0; i < num_fields; i++) {
-         val->type->members[i] =
-            vtn_value(b, w[i + 2], vtn_value_type_type)->type;
+         val->type->members[i] = vtn_get_type(b, w[i + 2]);
          fields[i] = (struct glsl_struct_field) {
             .type = val->type->members[i]->type,
             .name = ralloc_asprintf(b, "field%d", i),
@@ -1296,14 +1294,13 @@
       val->type->base_type = vtn_base_type_function;
       val->type->type = NULL;
 
-      val->type->return_type = vtn_value(b, w[2], vtn_value_type_type)->type;
+      val->type->return_type = vtn_get_type(b, w[2]);
 
       const unsigned num_params = count - 3;
       val->type->length = num_params;
       val->type->params = ralloc_array(b, struct vtn_type *, num_params);
       for (unsigned i = 0; i < count - 3; i++) {
-         val->type->params[i] =
-            vtn_value(b, w[i + 3], vtn_value_type_type)->type;
+         val->type->params[i] = vtn_get_type(b, w[i + 3]);
       }
       break;
    }
@@ -1344,7 +1341,7 @@
                      "forward declaration of a pointer, OpTypePointer can "
                      "only be used once for a given id.");
 
-         val->type->deref = vtn_value(b, w[3], vtn_value_type_type)->type;
+         val->type->deref = vtn_get_type(b, w[3]);
 
          /* Only certain storage classes use ArrayStride.  The others (in
           * particular Workgroup) are expected to be laid out by the driver.
@@ -1381,9 +1378,7 @@
    case SpvOpTypeImage: {
       val->type->base_type = vtn_base_type_image;
 
-      const struct vtn_type *sampled_type =
-         vtn_value(b, w[2], vtn_value_type_type)->type;
-
+      const struct vtn_type *sampled_type = vtn_get_type(b, w[2]);
       vtn_fail_if(sampled_type->base_type != vtn_base_type_scalar ||
                   glsl_get_bit_size(sampled_type->type) != 32,
                   "Sampled type of OpTypeImage must be a 32-bit scalar");
@@ -1443,7 +1438,7 @@
 
    case SpvOpTypeSampledImage:
       val->type->base_type = vtn_base_type_sampled_image;
-      val->type->image = vtn_value(b, w[2], vtn_value_type_type)->type;
+      val->type->image = vtn_get_type(b, w[2]);
       val->type->type = val->type->image->type;
       break;
 
@@ -1813,11 +1808,9 @@
          case SpvOpUConvert:
             /* We have a source in a conversion */
             src_alu_type =
-               nir_get_nir_type_for_glsl_type(
-                  vtn_value(b, w[4], vtn_value_type_constant)->type->type);
+               nir_get_nir_type_for_glsl_type(vtn_get_value_type(b, w[4])->type);
             /* We use the bitsize of the conversion source to evaluate the opcode later */
-            bit_size = glsl_get_bit_size(
-               vtn_value(b, w[4], vtn_value_type_constant)->type->type);
+            bit_size = glsl_get_bit_size(vtn_get_value_type(b, w[4])->type);
             break;
          default:
             bit_size = glsl_get_bit_size(val->type->type);
@@ -2306,7 +2299,7 @@
       return;
    }
 
-   struct vtn_type *ret_type = vtn_value(b, w[1], vtn_value_type_type)->type;
+   struct vtn_type *ret_type = vtn_get_type(b, w[1]);
 
    struct vtn_pointer *image = NULL, *sampler = NULL;
    struct vtn_value *sampled_val = vtn_untyped_value(b, w[3]);
@@ -3018,7 +3011,7 @@
       vtn_emit_memory_barrier(b, scope, before_semantics);
 
    if (opcode != SpvOpImageWrite && opcode != SpvOpAtomicStore) {
-      struct vtn_type *type = vtn_value(b, w[1], vtn_value_type_type)->type;
+      struct vtn_type *type = vtn_get_type(b, w[1]);
 
       unsigned dest_components = glsl_get_vector_elements(type->type);
       if (nir_intrinsic_infos[op].dest_components == 0)
@@ -3319,7 +3312,7 @@
       vtn_emit_memory_barrier(b, scope, before_semantics);
 
    if (opcode != SpvOpAtomicStore) {
-      struct vtn_type *type = vtn_value(b, w[1], vtn_value_type_type)->type;
+      struct vtn_type *type = vtn_get_type(b, w[1]);
 
       nir_ssa_dest_init(&atomic->instr, &atomic->dest,
                         glsl_get_vector_elements(type->type),
@@ -3536,7 +3529,7 @@
 vtn_handle_composite(struct vtn_builder *b, SpvOp opcode,
                      const uint32_t *w, unsigned count)
 {
-   struct vtn_type *type = vtn_value(b, w[1], vtn_value_type_type)->type;
+   struct vtn_type *type = vtn_get_type(b, w[1]);
    struct vtn_ssa_value *ssa = vtn_create_ssa_value(b, type->type);
 
    switch (opcode) {
@@ -4724,7 +4717,7 @@
       vtn_fail("Result type of OpSelect must be a scalar, composite, or pointer");
    }
 
-   struct vtn_type *res_type = vtn_value(b, w[1], vtn_value_type_type)->type;
+   struct vtn_type *res_type = vtn_get_type(b, w[1]);
    struct vtn_ssa_value *ssa = vtn_nir_select(b,
       vtn_ssa_value(b, w[3]), vtn_ssa_value(b, w[4]), vtn_ssa_value(b, w[5]));
 
@@ -4735,8 +4728,8 @@
 vtn_handle_ptr(struct vtn_builder *b, SpvOp opcode,
                const uint32_t *w, unsigned count)
 {
-   struct vtn_type *type1 = vtn_untyped_value(b, w[3])->type;
-   struct vtn_type *type2 = vtn_untyped_value(b, w[4])->type;
+   struct vtn_type *type1 = vtn_get_value_type(b, w[3]);
+   struct vtn_type *type2 = vtn_get_value_type(b, w[4]);
    vtn_fail_if(type1->base_type != vtn_base_type_pointer ||
                type2->base_type != vtn_base_type_pointer,
                "%s operands must have pointer types",
@@ -4745,8 +4738,7 @@
                "%s operands must have the same storage class",
                spirv_op_to_string(opcode));
 
-   struct vtn_type *vtn_type =
-      vtn_value(b, w[1], vtn_value_type_type)->type;
+   struct vtn_type *vtn_type = vtn_get_type(b, w[1]);
    const struct glsl_type *type = vtn_type->type;
 
    nir_address_format addr_format = vtn_mode_to_address_format(
@@ -4805,7 +4797,7 @@
 
    case SpvOpUndef: {
       struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_undef);
-      val->type = vtn_value(b, w[1], vtn_value_type_type)->type;
+      val->type = vtn_get_type(b, w[1]);
       break;
    }
 
@@ -5142,8 +5134,7 @@
       nir_ssa_dest_init(&intrin->instr, &intrin->dest, 1, 1, NULL);
       nir_builder_instr_insert(&b->nb, &intrin->instr);
 
-      struct vtn_type *res_type =
-         vtn_value(b, w[1], vtn_value_type_type)->type;
+      struct vtn_type *res_type = vtn_get_type(b, w[1]);
       struct vtn_ssa_value *val = vtn_create_ssa_value(b, res_type->type);
       val->def = &intrin->dest.ssa;
 
@@ -5175,7 +5166,7 @@
       nir_intrinsic_set_memory_scope(intrin, nir_scope);
       nir_builder_instr_insert(&b->nb, &intrin->instr);
 
-      struct vtn_type *type = vtn_value(b, w[1], vtn_value_type_type)->type;
+      struct vtn_type *type = vtn_get_type(b, w[1]);
       const struct glsl_type *dest_type = type->type;
       nir_ssa_def *result;
 
diff --git a/src/compiler/spirv/vtn_alu.c b/src/compiler/spirv/vtn_alu.c
index 9ebadba..60e8814 100644
--- a/src/compiler/spirv/vtn_alu.c
+++ b/src/compiler/spirv/vtn_alu.c
@@ -415,8 +415,7 @@
                const uint32_t *w, unsigned count)
 {
    struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_ssa);
-   const struct glsl_type *type =
-      vtn_value(b, w[1], vtn_value_type_type)->type->type;
+   const struct glsl_type *type = vtn_get_type(b, w[1])->type;
 
    vtn_foreach_decoration(b, val, handle_no_contraction, NULL);
 
@@ -697,7 +696,7 @@
     *    L) maps its lower-ordered bits to the lower-numbered components of L."
     */
 
-   struct vtn_type *type = vtn_value(b, w[1], vtn_value_type_type)->type;
+   struct vtn_type *type = vtn_get_type(b, w[1]);
    struct vtn_ssa_value *vtn_src = vtn_ssa_value(b, w[3]);
    struct nir_ssa_def *src = vtn_src->def;
    struct vtn_ssa_value *val = vtn_create_ssa_value(b, type->type);
diff --git a/src/compiler/spirv/vtn_amd.c b/src/compiler/spirv/vtn_amd.c
index f180e3b..ae7c7af 100644
--- a/src/compiler/spirv/vtn_amd.c
+++ b/src/compiler/spirv/vtn_amd.c
@@ -30,8 +30,7 @@
 vtn_handle_amd_gcn_shader_instruction(struct vtn_builder *b, SpvOp ext_opcode,
                                       const uint32_t *w, unsigned count)
 {
-   const struct glsl_type *dest_type =
-                           vtn_value(b, w[1], vtn_value_type_type)->type->type;
+   const struct glsl_type *dest_type = vtn_get_type(b, w[1])->type;
    struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_ssa);
    val->ssa = vtn_create_ssa_value(b, dest_type);
 
@@ -61,8 +60,7 @@
 vtn_handle_amd_shader_ballot_instruction(struct vtn_builder *b, SpvOp ext_opcode,
                                          const uint32_t *w, unsigned count)
 {
-   const struct glsl_type *dest_type =
-                           vtn_value(b, w[1], vtn_value_type_type)->type->type;
+   const struct glsl_type *dest_type = vtn_get_type(b, w[1])->type;
    struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_ssa);
    val->ssa = vtn_create_ssa_value(b, dest_type);
 
@@ -124,8 +122,7 @@
                                                  const uint32_t *w, unsigned count)
 {
    struct nir_builder *nb = &b->nb;
-   const struct glsl_type *dest_type =
-      vtn_value(b, w[1], vtn_value_type_type)->type->type;
+   const struct glsl_type *dest_type = vtn_get_type(b, w[1])->type;
    struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_ssa);
    val->ssa = vtn_create_ssa_value(b, dest_type);
 
@@ -175,9 +172,7 @@
 vtn_handle_amd_shader_explicit_vertex_parameter_instruction(struct vtn_builder *b, SpvOp ext_opcode,
                                                             const uint32_t *w, unsigned count)
 {
-   const struct glsl_type *dest_type =
-      vtn_value(b, w[1], vtn_value_type_type)->type->type;
-
+   const struct glsl_type *dest_type = vtn_get_type(b, w[1])->type;
    struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_ssa);
    val->ssa = vtn_create_ssa_value(b, dest_type);
 
diff --git a/src/compiler/spirv/vtn_cfg.c b/src/compiler/spirv/vtn_cfg.c
index a31047b..cbdd6af 100644
--- a/src/compiler/spirv/vtn_cfg.c
+++ b/src/compiler/spirv/vtn_cfg.c
@@ -185,7 +185,7 @@
 vtn_handle_function_call(struct vtn_builder *b, SpvOp opcode,
                          const uint32_t *w, unsigned count)
 {
-   struct vtn_type *res_type = vtn_value(b, w[1], vtn_value_type_type)->type;
+   struct vtn_type *res_type = vtn_get_type(b, w[1]);
    struct vtn_function *vtn_callee =
       vtn_value(b, w[3], vtn_value_type_function)->func;
    struct nir_function *callee = vtn_callee->impl->function;
@@ -256,12 +256,11 @@
       list_inithead(&b->func->body);
       b->func->control = w[3];
 
-      UNUSED const struct glsl_type *result_type =
-         vtn_value(b, w[1], vtn_value_type_type)->type->type;
+      UNUSED const struct glsl_type *result_type = vtn_get_type(b, w[1])->type;
       struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_function);
       val->func = b->func;
 
-      b->func->type = vtn_value(b, w[4], vtn_value_type_type)->type;
+      b->func->type = vtn_get_type(b, w[4]);
       const struct vtn_type *func_type = b->func->type;
 
       vtn_assert(func_type->return_type->type == result_type);
@@ -314,7 +313,7 @@
       break;
 
    case SpvOpFunctionParameter: {
-      struct vtn_type *type = vtn_value(b, w[1], vtn_value_type_type)->type;
+      struct vtn_type *type = vtn_get_type(b, w[1]);
 
       vtn_assert(b->func_param_idx < b->func->impl->function->num_params);
 
@@ -984,7 +983,7 @@
     * algorithm all over again.  It's easier if we just let
     * lower_vars_to_ssa do that for us instead of repeating it here.
     */
-   struct vtn_type *type = vtn_value(b, w[1], vtn_value_type_type)->type;
+   struct vtn_type *type = vtn_get_type(b, w[1]);
    nir_variable *phi_var =
       nir_local_variable_create(b->nb.impl, type->type, "phi");
    _mesa_hash_table_insert(b->phi_table, w, phi_var);
diff --git a/src/compiler/spirv/vtn_glsl450.c b/src/compiler/spirv/vtn_glsl450.c
index 947d33c..061ffd0 100644
--- a/src/compiler/spirv/vtn_glsl450.c
+++ b/src/compiler/spirv/vtn_glsl450.c
@@ -310,9 +310,7 @@
                    const uint32_t *w, unsigned count)
 {
    struct nir_builder *nb = &b->nb;
-   const struct glsl_type *dest_type =
-      vtn_value(b, w[1], vtn_value_type_type)->type->type;
-
+   const struct glsl_type *dest_type = vtn_get_type(b, w[1])->type;
    struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_ssa);
    val->ssa = vtn_create_ssa_value(b, dest_type);
 
@@ -559,9 +557,7 @@
 handle_glsl450_interpolation(struct vtn_builder *b, enum GLSLstd450 opcode,
                              const uint32_t *w, unsigned count)
 {
-   const struct glsl_type *dest_type =
-      vtn_value(b, w[1], vtn_value_type_type)->type->type;
-
+   const struct glsl_type *dest_type = vtn_get_type(b, w[1])->type;
    struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_ssa);
    val->ssa = vtn_create_ssa_value(b, dest_type);
 
@@ -636,7 +632,7 @@
    case GLSLstd450Determinant: {
       struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_ssa);
       val->ssa = rzalloc(b, struct vtn_ssa_value);
-      val->ssa->type = vtn_value(b, w[1], vtn_value_type_type)->type->type;
+      val->ssa->type = vtn_get_type(b, w[1])->type;
       val->ssa->def = build_mat_det(b, vtn_ssa_value(b, w[5]));
       break;
    }
diff --git a/src/compiler/spirv/vtn_opencl.c b/src/compiler/spirv/vtn_opencl.c
index e332810..57d39ee 100644
--- a/src/compiler/spirv/vtn_opencl.c
+++ b/src/compiler/spirv/vtn_opencl.c
@@ -39,8 +39,7 @@
 handle_instr(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
              const uint32_t *w, unsigned count, nir_handler handler)
 {
-   const struct glsl_type *dest_type =
-      vtn_value(b, w[1], vtn_value_type_type)->type->type;
+   const struct glsl_type *dest_type = vtn_get_type(b, w[1])->type;
 
    unsigned num_srcs = count - 5;
    nir_ssa_def *srcs[3] = { NULL };
@@ -225,9 +224,9 @@
 {
    struct vtn_type *type;
    if (load)
-      type = vtn_value(b, w[1], vtn_value_type_type)->type;
+      type = vtn_get_type(b, w[1]);
    else
-      type = vtn_untyped_value(b, w[5])->type;
+      type = vtn_get_value_type(b, w[5]);
    unsigned a = load ? 0 : 1;
 
    const struct glsl_type *dest_type = type->type;
diff --git a/src/compiler/spirv/vtn_private.h b/src/compiler/spirv/vtn_private.h
index 5cd5a46..b8163ea 100644
--- a/src/compiler/spirv/vtn_private.h
+++ b/src/compiler/spirv/vtn_private.h
@@ -769,6 +769,20 @@
    }
 }
 
+static inline struct vtn_type *
+vtn_get_value_type(struct vtn_builder *b, uint32_t value_id)
+{
+   struct vtn_value *val = vtn_untyped_value(b, value_id);
+   vtn_fail_if(val->type == NULL, "Value %u does not have a type", value_id);
+   return val->type;
+}
+
+static inline struct vtn_type *
+vtn_get_type(struct vtn_builder *b, uint32_t value_id)
+{
+   return vtn_value(b, value_id, vtn_value_type_type)->type;
+}
+
 struct vtn_ssa_value *vtn_ssa_value(struct vtn_builder *b, uint32_t value_id);
 
 struct vtn_value *vtn_push_value_pointer(struct vtn_builder *b,
diff --git a/src/compiler/spirv/vtn_variables.c b/src/compiler/spirv/vtn_variables.c
index 58ec051..f453834 100644
--- a/src/compiler/spirv/vtn_variables.c
+++ b/src/compiler/spirv/vtn_variables.c
@@ -2491,12 +2491,12 @@
    switch (opcode) {
    case SpvOpUndef: {
       struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_undef);
-      val->type = vtn_value(b, w[1], vtn_value_type_type)->type;
+      val->type = vtn_get_type(b, w[1]);
       break;
    }
 
    case SpvOpVariable: {
-      struct vtn_type *ptr_type = vtn_value(b, w[1], vtn_value_type_type)->type;
+      struct vtn_type *ptr_type = vtn_get_type(b, w[1]);
 
       struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_pointer);
 
@@ -2544,7 +2544,7 @@
          idx++;
       }
 
-      struct vtn_type *ptr_type = vtn_value(b, w[1], vtn_value_type_type)->type;
+      struct vtn_type *ptr_type = vtn_get_type(b, w[1]);
       struct vtn_value *base_val = vtn_untyped_value(b, w[3]);
       if (base_val->value_type == vtn_value_type_sampled_image) {
          /* This is rather insane.  SPIR-V allows you to use OpSampledImage
@@ -2586,8 +2586,7 @@
    }
 
    case SpvOpLoad: {
-      struct vtn_type *res_type =
-         vtn_value(b, w[1], vtn_value_type_type)->type;
+      struct vtn_type *res_type = vtn_get_type(b, w[1]);
       struct vtn_value *src_val = vtn_value(b, w[3], vtn_value_type_pointer);
       struct vtn_pointer *src = src_val->pointer;