spirv: Add generic pointer support

Most of this is fairly straightforward; we just set all the modes on any
derefs which are generic.  The one tricky bit is OpGenericCastToPtrExplicit.
Instead of adding NIR intrinsics to do the cast, we add NIR intrinsics
to do a storage class check and then bcsel based on that.

Reviewed-by: Jesse Natalie <jenatali@microsoft.com>
Reviewed-by: Caio Marcelo de Oliveira Filho <caio.oliveira@intel.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/6332>
diff --git a/src/compiler/nir/nir_builder.h b/src/compiler/nir/nir_builder.h
index 96c89c3..163f623 100644
--- a/src/compiler/nir/nir_builder.h
+++ b/src/compiler/nir/nir_builder.h
@@ -1416,6 +1416,19 @@
 }
 
 static inline nir_ssa_def *
+nir_build_deref_mode_is(nir_builder *build, nir_deref_instr *deref,
+                        nir_variable_mode mode)
+{
+   nir_intrinsic_instr *intrin =
+      nir_intrinsic_instr_create(build->shader, nir_intrinsic_deref_mode_is);
+   intrin->src[0] = nir_src_for_ssa(&deref->dest.ssa);
+   nir_intrinsic_set_memory_modes(intrin, mode);
+   nir_ssa_dest_init(&intrin->instr, &intrin->dest, 1, 1, NULL);
+   nir_builder_instr_insert(build, &intrin->instr);
+   return &intrin->dest.ssa;
+}
+
+static inline nir_ssa_def *
 nir_load_var(nir_builder *build, nir_variable *var)
 {
    return nir_load_deref(build, nir_build_deref_var(build, var));
diff --git a/src/compiler/nir/nir_intrinsics.py b/src/compiler/nir/nir_intrinsics.py
index c64886a..ef09758 100644
--- a/src/compiler/nir/nir_intrinsics.py
+++ b/src/compiler/nir/nir_intrinsics.py
@@ -212,6 +212,11 @@
 intrinsic("get_ubo_size", src_comp=[-1], dest_comp=1,
           flags=[CAN_ELIMINATE, CAN_REORDER])
 
+# Intrinsics which provide a run-time mode-check.  Unlike the compile-time
+# mode checks, a pointer can only have exactly one mode at runtime.
+intrinsic("deref_mode_is", src_comp=[-1], dest_comp=1,
+          indices=[MEMORY_MODES], flags=[CAN_ELIMINATE, CAN_REORDER])
+
 # a barrier is an intrinsic with no inputs/outputs but which can't be moved
 # around/optimized in general
 def barrier(name):
diff --git a/src/compiler/nir/nir_validate.c b/src/compiler/nir/nir_validate.c
index e6b4980..d68c2d2 100644
--- a/src/compiler/nir/nir_validate.c
+++ b/src/compiler/nir/nir_validate.c
@@ -670,6 +670,11 @@
       validate_assert(state, nir_src_bit_size(instr->src[0]) >= 8);
       break;
 
+   case nir_intrinsic_deref_mode_is:
+      validate_assert(state,
+         util_bitcount(nir_intrinsic_memory_modes(instr)) == 1);
+      break;
+
    default:
       break;
    }
diff --git a/src/compiler/shader_info.h b/src/compiler/shader_info.h
index 3bb6779..f7c3904 100644
--- a/src/compiler/shader_info.h
+++ b/src/compiler/shader_info.h
@@ -49,6 +49,7 @@
    bool float64_atomic_add;
    bool fragment_shader_sample_interlock;
    bool fragment_shader_pixel_interlock;
+   bool generic_pointers;
    bool geometry_streams;
    bool image_ms_array;
    bool image_read_without_format;
diff --git a/src/compiler/spirv/spirv_to_nir.c b/src/compiler/spirv/spirv_to_nir.c
index 35e2f55..f5d7fe4 100644
--- a/src/compiler/spirv/spirv_to_nir.c
+++ b/src/compiler/spirv/spirv_to_nir.c
@@ -4185,6 +4185,10 @@
          spv_check_supported(kernel, cap);
          break;
 
+      case SpvCapabilityGenericPointer:
+         spv_check_supported(generic_pointers, cap);
+         break;
+
       case SpvCapabilityImageBasic:
          spv_check_supported(kernel_image, cap);
          break;
@@ -4197,7 +4201,6 @@
       case SpvCapabilityImageMipmap:
       case SpvCapabilityPipes:
       case SpvCapabilityDeviceEnqueue:
-      case SpvCapabilityGenericPointer:
          vtn_warn("Unsupported OpenCL-style SPIR-V capability: %s",
                   spirv_capability_to_string(cap));
          break;
@@ -5031,6 +5034,8 @@
    case SpvOpArrayLength:
    case SpvOpConvertPtrToU:
    case SpvOpConvertUToPtr:
+   case SpvOpGenericCastToPtrExplicit:
+   case SpvOpGenericPtrMemSemantics:
       vtn_handle_variables(b, opcode, w, count);
       break;
 
diff --git a/src/compiler/spirv/vtn_alu.c b/src/compiler/spirv/vtn_alu.c
index efc7eea..1814fa7 100644
--- a/src/compiler/spirv/vtn_alu.c
+++ b/src/compiler/spirv/vtn_alu.c
@@ -355,6 +355,10 @@
       nir_alu_type dst_type = convert_op_dst_type(opcode) | dst_bit_size;
       return nir_type_conversion_op(src_type, dst_type, nir_rounding_mode_undef);
    }
+
+   case SpvOpPtrCastToGeneric:   return nir_op_mov;
+   case SpvOpGenericCastToPtr:   return nir_op_mov;
+
    /* Derivatives: */
    case SpvOpDPdx:         return nir_op_fddx;
    case SpvOpDPdy:         return nir_op_fddy;
diff --git a/src/compiler/spirv/vtn_private.h b/src/compiler/spirv/vtn_private.h
index ec6493d..7c5e9b8 100644
--- a/src/compiler/spirv/vtn_private.h
+++ b/src/compiler/spirv/vtn_private.h
@@ -487,6 +487,7 @@
    vtn_variable_mode_push_constant,
    vtn_variable_mode_workgroup,
    vtn_variable_mode_cross_workgroup,
+   vtn_variable_mode_generic,
    vtn_variable_mode_constant,
    vtn_variable_mode_input,
    vtn_variable_mode_output,
diff --git a/src/compiler/spirv/vtn_variables.c b/src/compiler/spirv/vtn_variables.c
index 672d087..1cc5c2b 100644
--- a/src/compiler/spirv/vtn_variables.c
+++ b/src/compiler/spirv/vtn_variables.c
@@ -1383,6 +1383,9 @@
       nir_mode = nir_var_mem_ubo;
       break;
    case SpvStorageClassGeneric:
+      mode = vtn_variable_mode_generic;
+      nir_mode = nir_var_mem_generic;
+      break;
    default:
       vtn_fail("Unhandled variable storage class: %s (%u)",
                spirv_storageclass_to_string(class), class);
@@ -1413,6 +1416,7 @@
    case vtn_variable_mode_workgroup:
       return b->options->shared_addr_format;
 
+   case vtn_variable_mode_generic:
    case vtn_variable_mode_cross_workgroup:
       return b->options->global_addr_format;
 
@@ -1636,6 +1640,10 @@
          glsl_get_explicit_size(without_array->type, false);
       break;
 
+   case vtn_variable_mode_generic:
+      vtn_fail("Cannot create a variable with the Generic storage class");
+      break;
+
    case vtn_variable_mode_image:
       vtn_fail("Cannot create a variable with the Image storage class");
       break;
@@ -1692,11 +1700,12 @@
       break;
 
    case vtn_variable_mode_workgroup:
+   case vtn_variable_mode_cross_workgroup:
       /* Create the variable normally */
       var->var = rzalloc(b->shader, nir_variable);
       var->var->name = ralloc_strdup(var->var, val->name);
       var->var->type = vtn_type_get_nir_type(b, var->type, var->mode);
-      var->var->data.mode = nir_var_mem_shared;
+      var->var->data.mode = nir_mode;
       break;
 
    case vtn_variable_mode_input:
@@ -1791,12 +1800,9 @@
       break;
    }
 
-   case vtn_variable_mode_cross_workgroup:
-      /* These don't need actual variables. */
-      break;
-
    case vtn_variable_mode_image:
    case vtn_variable_mode_phys_ssbo:
+   case vtn_variable_mode_generic:
       unreachable("Should have been caught before");
    }
 
@@ -2340,6 +2346,84 @@
       break;
    }
 
+   case SpvOpGenericCastToPtrExplicit: {
+      struct vtn_type *dst_type = vtn_get_type(b, w[1]);
+      struct vtn_type *src_type = vtn_get_value_type(b, w[3]);
+      SpvStorageClass storage_class = w[4];
+
+      vtn_fail_if(dst_type->base_type != vtn_base_type_pointer ||
+                  dst_type->storage_class != storage_class,
+                  "Result type of an SpvOpGenericCastToPtrExplicit must be "
+                  "an OpTypePointer. Its Storage Class must match the "
+                  "storage class specified in the instruction");
+
+      vtn_fail_if(src_type->base_type != vtn_base_type_pointer ||
+                  src_type->deref->id != dst_type->deref->id,
+                  "Source pointer of an SpvOpGenericCastToPtrExplicit must "
+                  "have a type of OpTypePointer whose Type is the same as "
+                  "the Type of Result Type");
+
+      vtn_fail_if(src_type->storage_class != SpvStorageClassGeneric,
+                  "Source pointer of an SpvOpGenericCastToPtrExplicit must "
+                  "point to the Generic Storage Class.");
+
+      vtn_fail_if(storage_class != SpvStorageClassWorkgroup &&
+                  storage_class != SpvStorageClassCrossWorkgroup &&
+                  storage_class != SpvStorageClassFunction,
+                  "Storage must be one of the following literal values from "
+                  "Storage Class: Workgroup, CrossWorkgroup, or Function.");
+
+      nir_deref_instr *src_deref = vtn_nir_deref(b, w[3]);
+
+      nir_variable_mode nir_mode;
+      enum vtn_variable_mode mode =
+         vtn_storage_class_to_mode(b, storage_class, dst_type->deref, &nir_mode);
+      nir_address_format addr_format = vtn_mode_to_address_format(b, mode);
+
+      nir_ssa_def *null_value =
+         nir_build_imm(&b->nb, nir_address_format_num_components(addr_format),
+                               nir_address_format_bit_size(addr_format),
+                               nir_address_format_null_value(addr_format));
+
+      nir_ssa_def *valid = nir_build_deref_mode_is(&b->nb, src_deref, nir_mode);
+      vtn_push_nir_ssa(b, w[2], nir_bcsel(&b->nb, valid,
+                                                  &src_deref->dest.ssa,
+                                                  null_value));
+      break;
+   }
+
+   case SpvOpGenericPtrMemSemantics: {
+      struct vtn_type *dst_type = vtn_get_type(b, w[1]);
+      struct vtn_type *src_type = vtn_get_value_type(b, w[3]);
+
+      vtn_fail_if(dst_type->base_type != vtn_base_type_scalar ||
+                  dst_type->type != glsl_uint_type(),
+                  "Result type of an SpvOpGenericPtrMemSemantics must be "
+                  "an OpTypeInt with 32-bit Width and 0 Signedness.");
+
+      vtn_fail_if(src_type->base_type != vtn_base_type_pointer ||
+                  src_type->storage_class != SpvStorageClassGeneric,
+                  "Source pointer of an SpvOpGenericPtrMemSemantics must "
+                  "point to the Generic Storage Class");
+
+      nir_deref_instr *src_deref = vtn_nir_deref(b, w[3]);
+
+      nir_ssa_def *global_bit =
+         nir_bcsel(&b->nb, nir_build_deref_mode_is(&b->nb, src_deref,
+                                                   nir_var_mem_global),
+                   nir_imm_int(&b->nb, SpvMemorySemanticsCrossWorkgroupMemoryMask),
+                   nir_imm_int(&b->nb, 0));
+
+      nir_ssa_def *shared_bit =
+         nir_bcsel(&b->nb, nir_build_deref_mode_is(&b->nb, src_deref,
+                                                   nir_var_mem_shared),
+                   nir_imm_int(&b->nb, SpvMemorySemanticsWorkgroupMemoryMask),
+                   nir_imm_int(&b->nb, 0));
+
+      vtn_push_nir_ssa(b, w[2], nir_iand(&b->nb, global_bit, shared_bit));
+      break;
+   }
+
    default:
       vtn_fail_with_opcode("Unhandled opcode", opcode);
    }