spirv,nir: Add ray-tracing intrinsics

Reviewed-by: Bas Nieuwenhuizen <bas@basnieuwenhuizen.nl>
Reviewed-by: Caio Marcelo de Oliveira Filho <caio.oliveira@intel.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/6479>
diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h
index e43062f..ef77bba 100644
--- a/src/compiler/nir/nir.h
+++ b/src/compiler/nir/nir.h
@@ -1944,7 +1944,7 @@
    unsigned _pad:7;
 } nir_io_semantics;
 
-#define NIR_INTRINSIC_MAX_INPUTS 5
+#define NIR_INTRINSIC_MAX_INPUTS 11
 
 typedef struct {
    const char *name;
diff --git a/src/compiler/nir/nir_intrinsics.py b/src/compiler/nir/nir_intrinsics.py
index d93b551..df9f22f 100644
--- a/src/compiler/nir/nir_intrinsics.py
+++ b/src/compiler/nir/nir_intrinsics.py
@@ -373,6 +373,28 @@
 # Contains the final total vertex and primitive counts in the current GS thread.
 intrinsic("set_vertex_and_primitive_count", src_comp=[1, 1], indices=[STREAM_ID])
 
+# Trace a ray through an acceleration structure
+#
+# This instruction has a lot of parameters:
+#   0. Acceleration Structure
+#   1. Ray Flags
+#   2. Cull Mask
+#   3. SBT Offset
+#   4. SBT Stride
+#   5. Miss shader index
+#   6. Ray Origin
+#   7. Ray Tmin
+#   8. Ray Direction
+#   9. Ray Tmax
+#   10. Payload
+intrinsic("trace_ray", src_comp=[-1, 1, 1, 1, 1, 1, 3, 1, 3, 1, -1])
+# src[] = { hit_t, hit_kind }
+intrinsic("report_ray_intersection", src_comp=[1, 1], dest_comp=1)
+intrinsic("ignore_ray_intersection")
+intrinsic("terminate_ray")
+# src[] = { sbt_index, payload }
+intrinsic("execute_callable", src_comp=[1, -1])
+
 # Atomic counters
 #
 # The *_var variants take an atomic_uint nir_variable, while the other,
diff --git a/src/compiler/spirv/spirv_to_nir.c b/src/compiler/spirv/spirv_to_nir.c
index e0eda4c..9c92bd8 100644
--- a/src/compiler/spirv/spirv_to_nir.c
+++ b/src/compiler/spirv/spirv_to_nir.c
@@ -5067,6 +5067,65 @@
    vtn_push_nir_ssa(b, w[2], def);
 }
 
+static void
+vtn_handle_ray_intrinsic(struct vtn_builder *b, SpvOp opcode,
+                         const uint32_t *w, unsigned count)
+{
+   nir_intrinsic_instr *intrin;
+
+   switch (opcode) {
+   case SpvOpTraceRayKHR: {
+      intrin = nir_intrinsic_instr_create(b->nb.shader,
+                                          nir_intrinsic_trace_ray);
+
+      /* The sources are in the same order in the NIR intrinsic */
+      for (unsigned i = 0; i < 10; i++)
+         intrin->src[i] = nir_src_for_ssa(vtn_ssa_value(b, w[i + 1])->def);
+
+      nir_deref_instr *payload = vtn_get_call_payload_for_location(b, w[11]);
+      intrin->src[10] = nir_src_for_ssa(&payload->dest.ssa);
+      nir_builder_instr_insert(&b->nb, &intrin->instr);
+      break;
+   }
+
+   case SpvOpReportIntersectionKHR: {
+      intrin = nir_intrinsic_instr_create(b->nb.shader,
+                                          nir_intrinsic_report_ray_intersection);
+      intrin->src[0] = nir_src_for_ssa(vtn_ssa_value(b, w[3])->def);
+      intrin->src[1] = nir_src_for_ssa(vtn_ssa_value(b, w[4])->def);
+      nir_ssa_dest_init(&intrin->instr, &intrin->dest, 1, 1, NULL);
+      nir_builder_instr_insert(&b->nb, &intrin->instr);
+      vtn_push_nir_ssa(b, w[2], &intrin->dest.ssa);
+      break;
+   }
+
+   case SpvOpIgnoreIntersectionKHR:
+      intrin = nir_intrinsic_instr_create(b->nb.shader,
+                                          nir_intrinsic_ignore_ray_intersection);
+      nir_builder_instr_insert(&b->nb, &intrin->instr);
+      break;
+
+   case SpvOpTerminateRayKHR:
+      intrin = nir_intrinsic_instr_create(b->nb.shader,
+                                          nir_intrinsic_terminate_ray);
+      nir_builder_instr_insert(&b->nb, &intrin->instr);
+      break;
+
+   case SpvOpExecuteCallableKHR: {
+      intrin = nir_intrinsic_instr_create(b->nb.shader,
+                                          nir_intrinsic_execute_callable);
+      intrin->src[0] = nir_src_for_ssa(vtn_ssa_value(b, w[1])->def);
+      nir_deref_instr *payload = vtn_get_call_payload_for_location(b, w[2]);
+      intrin->src[1] = nir_src_for_ssa(&payload->dest.ssa);
+      nir_builder_instr_insert(&b->nb, &intrin->instr);
+      break;
+   }
+
+   default:
+      vtn_fail_with_opcode("Unhandled opcode", opcode);
+   }
+}
+
 static bool
 vtn_handle_body_instruction(struct vtn_builder *b, SpvOp opcode,
                             const uint32_t *w, unsigned count)
@@ -5476,6 +5535,14 @@
       break;
    }
 
+   case SpvOpTraceRayKHR:
+   case SpvOpReportIntersectionKHR:
+   case SpvOpIgnoreIntersectionKHR:
+   case SpvOpTerminateRayKHR:
+   case SpvOpExecuteCallableKHR:
+      vtn_handle_ray_intrinsic(b, opcode, w, count);
+      break;
+
    case SpvOpLifetimeStart:
    case SpvOpLifetimeStop:
       break;
diff --git a/src/compiler/spirv/vtn_private.h b/src/compiler/spirv/vtn_private.h
index 5dec9af..45187a0 100644
--- a/src/compiler/spirv/vtn_private.h
+++ b/src/compiler/spirv/vtn_private.h
@@ -844,6 +844,9 @@
 vtn_pointer_to_offset(struct vtn_builder *b, struct vtn_pointer *ptr,
                       nir_ssa_def **index_out);
 
+nir_deref_instr *
+vtn_get_call_payload_for_location(struct vtn_builder *b, uint32_t location_id);
+
 struct vtn_ssa_value *
 vtn_local_load(struct vtn_builder *b, nir_deref_instr *src,
                enum gl_access_qualifier access);
diff --git a/src/compiler/spirv/vtn_variables.c b/src/compiler/spirv/vtn_variables.c
index 191192c..168d0e5 100644
--- a/src/compiler/spirv/vtn_variables.c
+++ b/src/compiler/spirv/vtn_variables.c
@@ -1716,6 +1716,18 @@
    }
 }
 
+nir_deref_instr *
+vtn_get_call_payload_for_location(struct vtn_builder *b, uint32_t location_id)
+{
+   uint32_t location = vtn_constant_uint(b, location_id);
+   nir_foreach_variable_with_modes(var, b->nb.shader, nir_var_shader_temp) {
+      if (var->data.explicit_location &&
+          var->data.location == location)
+         return nir_build_deref_var(&b->nb, var);
+   }
+   vtn_fail("Couldn't find variable with a storage class of CallableDataKHR "
+            "or RayPayloadKHR and location %d", location);
+}
 
 static void
 vtn_create_variable(struct vtn_builder *b, struct vtn_value *val,
@@ -1813,6 +1825,14 @@
       var->var = rzalloc(b->shader, nir_variable);
       var->var->name = ralloc_strdup(var->var, val->name);
       var->var->type = vtn_type_get_nir_type(b, var->type, var->mode);
+
+      /* This is a total hack but we need some way to flag variables which are
+       * going to be call payloads.  See get_call_payload_deref.
+       */
+      if (storage_class == SpvStorageClassCallableDataKHR ||
+          storage_class == SpvStorageClassRayPayloadKHR)
+         var->var->data.explicit_location = true;
+
       var->var->data.mode = nir_mode;
       var->var->data.location = -1;
       var->var->interface_type = NULL;