nir/lower_io: Support generic pointer access

If the pointer is generic and we haven't yet figured out what kind of
pointer it is, we emit an if-ladder based on a runtime mode check.
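
As a rough sketch of the emitted structure (not the exact NIR;
mode_check here stands in for build_runtime_addr_mode_check), a load
through a pointer that may be private, shared, or global becomes:

   if (mode_check(addr, function_temp))
      res = load_scratch ...
   else if (mode_check(addr, shared))
      res = load_shared ...
   else
      res = load_global ...

with the per-branch results joined by a phi. Stores and atomics get the
same ladder; when the address format already treats the whole mode set
as global, no runtime check is emitted at all.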

Reviewed-by: Jesse Natalie <jenatali@microsoft.com>
Reviewed-by: Caio Marcelo de Oliveira Filho <caio.oliveira@intel.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/6332>
diff --git a/src/compiler/nir/nir_lower_io.c b/src/compiler/nir/nir_lower_io.c
index 01918fd..71fb408 100644
--- a/src/compiler/nir/nir_lower_io.c
+++ b/src/compiler/nir/nir_lower_io.c
@@ -732,8 +732,8 @@
 
    ralloc_free(state.dead_ctx);
 
-   nir_metadata_preserve(impl, nir_metadata_block_index |
-                               nir_metadata_dominance);
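+   /* Lowering a generic pointer access may now insert control flow (the
+    * if-ladders below), so block indices and dominance information can no
+    * longer be preserved.
+    */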
+   nir_metadata_preserve(impl, nir_metadata_none);
+
    return progress;
 }
 
@@ -1112,14 +1112,80 @@
    *out_range = ~0;
 }
 
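+/* Generic pointers may carry several possible variable modes.  Collapse the
+ * set so callers only have to distinguish function_temp, mem_shared, and
+ * mem_global: shader_temp is treated as function_temp here.
+ */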
+static nir_variable_mode
+canonicalize_generic_modes(nir_variable_mode modes)
+{
+   assert(modes != 0);
+   if (util_bitcount(modes) == 1)
+      return modes;
+
+   assert(!(modes & ~(nir_var_function_temp | nir_var_shader_temp |
+                      nir_var_mem_shared | nir_var_mem_global)));
+
+   /* Canonicalize by converting shader_temp to function_temp */
+   if (modes & nir_var_shader_temp) {
+      modes &= ~nir_var_shader_temp;
+      modes |= nir_var_function_temp;
+   }
+
+   return modes;
+}
+
 static nir_ssa_def *
 build_explicit_io_load(nir_builder *b, nir_intrinsic_instr *intrin,
                        nir_ssa_def *addr, nir_address_format addr_format,
+                       nir_variable_mode modes,
                        uint32_t align_mul, uint32_t align_offset,
                        unsigned num_components)
 {
    nir_deref_instr *deref = nir_src_as_deref(intrin->src[0]);
-   nir_variable_mode mode = deref->modes;
+   modes = canonicalize_generic_modes(modes);
+
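+   /* More than one possible mode: if the address format treats all of the
+    * remaining modes as a single global address space, just load as global.
+    * Otherwise, branch on a runtime mode check of the address and recurse
+    * with a smaller mode set on each side, joining the results with a phi.
+    */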
+   if (util_bitcount(modes) > 1) {
+      if (addr_format_is_global(addr_format, modes)) {
+         return build_explicit_io_load(b, intrin, addr, addr_format,
+                                       nir_var_mem_global,
+                                       align_mul, align_offset,
+                                       num_components);
+      } else if (modes & nir_var_function_temp) {
+         nir_push_if(b, build_runtime_addr_mode_check(b, addr, addr_format,
+                                                      nir_var_function_temp));
+         nir_ssa_def *res1 =
+            build_explicit_io_load(b, intrin, addr, addr_format,
+                                   nir_var_function_temp,
+                                   align_mul, align_offset,
+                                   num_components);
+         nir_push_else(b, NULL);
+         nir_ssa_def *res2 =
+            build_explicit_io_load(b, intrin, addr, addr_format,
+                                   modes & ~nir_var_function_temp,
+                                   align_mul, align_offset,
+                                   num_components);
+         nir_pop_if(b, NULL);
+         return nir_if_phi(b, res1, res2);
+      } else {
+         nir_push_if(b, build_runtime_addr_mode_check(b, addr, addr_format,
+                                                      nir_var_mem_shared));
+         assert(modes & nir_var_mem_shared);
+         nir_ssa_def *res1 =
+            build_explicit_io_load(b, intrin, addr, addr_format,
+                                   nir_var_mem_shared,
+                                   align_mul, align_offset,
+                                   num_components);
+         nir_push_else(b, NULL);
+         assert(modes & nir_var_mem_global);
+         nir_ssa_def *res2 =
+            build_explicit_io_load(b, intrin, addr, addr_format,
+                                   nir_var_mem_global,
+                                   align_mul, align_offset,
+                                   num_components);
+         nir_pop_if(b, NULL);
+         return nir_if_phi(b, res1, res2);
+      }
+   }
+
+   assert(util_bitcount(modes) == 1);
+   const nir_variable_mode mode = modes;
 
    nir_intrinsic_op op;
    switch (mode) {
@@ -1260,10 +1326,52 @@
 static void
 build_explicit_io_store(nir_builder *b, nir_intrinsic_instr *intrin,
                         nir_ssa_def *addr, nir_address_format addr_format,
+                        nir_variable_mode modes,
                         uint32_t align_mul, uint32_t align_offset,
                         nir_ssa_def *value, nir_component_mask_t write_mask)
 {
-   nir_variable_mode mode = nir_src_as_deref(intrin->src[0])->modes;
+   modes = canonicalize_generic_modes(modes);
+
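+   /* Same if-ladder structure as build_explicit_io_load, but stores produce
+    * no value, so there is no phi to emit.
+    */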
+   if (util_bitcount(modes) > 1) {
+      if (addr_format_is_global(addr_format, modes)) {
+         build_explicit_io_store(b, intrin, addr, addr_format,
+                                 nir_var_mem_global,
+                                 align_mul, align_offset,
+                                 value, write_mask);
+      } else if (modes & nir_var_function_temp) {
+         nir_push_if(b, build_runtime_addr_mode_check(b, addr, addr_format,
+                                                      nir_var_function_temp));
+         build_explicit_io_store(b, intrin, addr, addr_format,
+                                 nir_var_function_temp,
+                                 align_mul, align_offset,
+                                 value, write_mask);
+         nir_push_else(b, NULL);
+         build_explicit_io_store(b, intrin, addr, addr_format,
+                                 modes & ~nir_var_function_temp,
+                                 align_mul, align_offset,
+                                 value, write_mask);
+         nir_pop_if(b, NULL);
+      } else {
+         nir_push_if(b, build_runtime_addr_mode_check(b, addr, addr_format,
+                                                      nir_var_mem_shared));
+         assert(modes & nir_var_mem_shared);
+         build_explicit_io_store(b, intrin, addr, addr_format,
+                                 nir_var_mem_shared,
+                                 align_mul, align_offset,
+                                 value, write_mask);
+         nir_push_else(b, NULL);
+         assert(modes & nir_var_mem_global);
+         build_explicit_io_store(b, intrin, addr, addr_format,
+                                 nir_var_mem_global,
+                                 align_mul, align_offset,
+                                 value, write_mask);
+         nir_pop_if(b, NULL);
+      }
+      return;
+   }
+
+   assert(util_bitcount(modes) == 1);
+   const nir_variable_mode mode = modes;
 
    nir_intrinsic_op op;
    switch (mode) {
@@ -1349,9 +1457,47 @@
 
 static nir_ssa_def *
 build_explicit_io_atomic(nir_builder *b, nir_intrinsic_instr *intrin,
-                         nir_ssa_def *addr, nir_address_format addr_format)
+                         nir_ssa_def *addr, nir_address_format addr_format,
+                         nir_variable_mode modes)
 {
-   nir_variable_mode mode = nir_src_as_deref(intrin->src[0])->modes;
+   modes = canonicalize_generic_modes(modes);
+
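+   /* Same if-ladder structure as build_explicit_io_load. */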
+   if (util_bitcount(modes) > 1) {
+      if (addr_format_is_global(addr_format, modes)) {
+         return build_explicit_io_atomic(b, intrin, addr, addr_format,
+                                         nir_var_mem_global);
+      } else if (modes & nir_var_function_temp) {
+         nir_push_if(b, build_runtime_addr_mode_check(b, addr, addr_format,
+                                                      nir_var_function_temp));
+         nir_ssa_def *res1 =
+            build_explicit_io_atomic(b, intrin, addr, addr_format,
+                                     nir_var_function_temp);
+         nir_push_else(b, NULL);
+         nir_ssa_def *res2 =
+            build_explicit_io_atomic(b, intrin, addr, addr_format,
+                                     modes & ~nir_var_function_temp);
+         nir_pop_if(b, NULL);
+         return nir_if_phi(b, res1, res2);
+      } else {
+         nir_push_if(b, build_runtime_addr_mode_check(b, addr, addr_format,
+                                                      nir_var_mem_shared));
+         assert(modes & nir_var_mem_shared);
+         nir_ssa_def *res1 =
+            build_explicit_io_atomic(b, intrin, addr, addr_format,
+                                     nir_var_mem_shared);
+         nir_push_else(b, NULL);
+         assert(modes & nir_var_mem_global);
+         nir_ssa_def *res2 =
+            build_explicit_io_atomic(b, intrin, addr, addr_format,
+                                     nir_var_mem_global);
+         nir_pop_if(b, NULL);
+         return nir_if_phi(b, res1, res2);
+      }
+   }
+
+   assert(util_bitcount(modes) == 1);
+   const nir_variable_mode mode = modes;
+
    const unsigned num_data_srcs =
       nir_intrinsic_infos[intrin->intrinsic].num_srcs - 1;
 
@@ -1499,7 +1645,8 @@
                                                          deref->modes,
                                                          comp_offset);
             comps[i] = build_explicit_io_load(b, intrin, comp_addr,
-                                              addr_format, align_mul,
+                                              addr_format, deref->modes,
+                                              align_mul,
                                               (align_offset + comp_offset) %
                                                  align_mul,
                                               1);
@@ -1507,7 +1654,7 @@
          value = nir_vec(b, comps, intrin->num_components);
       } else {
          value = build_explicit_io_load(b, intrin, addr, addr_format,
-                                        align_mul, align_offset,
+                                        deref->modes, align_mul, align_offset,
                                         intrin->num_components);
       }
       nir_ssa_def_rewrite_uses(&intrin->dest.ssa, nir_src_for_ssa(value));
@@ -1528,13 +1675,13 @@
                                                          deref->modes,
                                                          comp_offset);
             build_explicit_io_store(b, intrin, comp_addr, addr_format,
-                                    align_mul,
+                                    deref->modes, align_mul,
                                     (align_offset + comp_offset) % align_mul,
                                     nir_channel(b, value, i), 1);
          }
       } else {
          build_explicit_io_store(b, intrin, addr, addr_format,
-                                 align_mul, align_offset,
+                                 deref->modes, align_mul, align_offset,
                                  value, write_mask);
       }
       break;
@@ -1542,7 +1689,7 @@
 
    default: {
       nir_ssa_def *value =
-         build_explicit_io_atomic(b, intrin, addr, addr_format);
+         build_explicit_io_atomic(b, intrin, addr, addr_format, deref->modes);
       nir_ssa_def_rewrite_uses(&intrin->dest.ssa, nir_src_for_ssa(value));
       break;
    }