ac/nir: handle all lowered IO intrinsics
Acked-by: Pierre-Eric Pelloux-Prayer <pierre-eric.pelloux-prayer@amd.com>
Reviewed-by: Samuel Pitoiset <samuel.pitoiset@gmail.com>
Reviewed-by: Connor Abbott <cwabbott0@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/6445>
diff --git a/src/amd/llvm/ac_nir_to_llvm.c b/src/amd/llvm/ac_nir_to_llvm.c
index 6dc155d..ed0cb80 100644
--- a/src/amd/llvm/ac_nir_to_llvm.c
+++ b/src/amd/llvm/ac_nir_to_llvm.c
@@ -2269,6 +2269,7 @@
switch (mode) {
case nir_var_shader_in:
+ /* TODO: remove this after RADV switches to lowered IO */
if (ctx->stage == MESA_SHADER_TESS_CTRL ||
ctx->stage == MESA_SHADER_TESS_EVAL) {
return load_tess_varyings(ctx, instr, true);
@@ -2324,6 +2325,7 @@
}
break;
case nir_var_shader_out:
+ /* TODO: remove this after RADV switches to lowered IO */
if (ctx->stage == MESA_SHADER_TESS_CTRL) {
return load_tess_varyings(ctx, instr, false);
}
@@ -2444,7 +2446,7 @@
switch (deref->mode) {
case nir_var_shader_out:
-
+ /* TODO: remove this after RADV switches to lowered IO */
if (ctx->stage == MESA_SHADER_TESS_CTRL) {
LLVMValueRef vertex_index = NULL;
LLVMValueRef indir_index = NULL;
@@ -2459,7 +2461,9 @@
ctx->abi->store_tcs_outputs(ctx->abi, var,
vertex_index, indir_index,
- const_index, src, writemask);
+ const_index, src, writemask,
+ var->data.location_frac,
+ var->data.driver_location);
break;
}
@@ -2581,6 +2585,71 @@
ac_build_endif(&ctx->ac, 7002);
}
+static void
+visit_store_output(struct ac_nir_context *ctx, nir_intrinsic_instr *instr)
+{
+ if (ctx->ac.postponed_kill) {
+ LLVMValueRef cond = LLVMBuildLoad(ctx->ac.builder,
+ ctx->ac.postponed_kill, "");
+ ac_build_ifcc(&ctx->ac, cond, 7002);
+ }
+
+ unsigned base = nir_intrinsic_base(instr);
+ unsigned writemask = nir_intrinsic_write_mask(instr);
+ unsigned component = nir_intrinsic_component(instr);
+ LLVMValueRef src = ac_to_float(&ctx->ac, get_src(ctx, instr->src[0]));
+ nir_src offset = *nir_get_io_offset_src(instr);
+ LLVMValueRef indir_index = NULL;
+
+ if (nir_src_is_const(offset))
+ assert(nir_src_as_uint(offset) == 0);
+ else
+ indir_index = get_src(ctx, offset);
+
+ switch (ac_get_elem_bits(&ctx->ac, LLVMTypeOf(src))) {
+ case 32:
+ break;
+ case 64:
+ writemask = widen_mask(writemask, 2);
+ src = LLVMBuildBitCast(ctx->ac.builder, src,
+ LLVMVectorType(ctx->ac.f32, ac_get_llvm_num_components(src) * 2),
+ "");
+ break;
+ default:
+ unreachable("unhandled store_output bit size");
+ return;
+ }
+
+ writemask <<= component;
+
+ if (ctx->stage == MESA_SHADER_TESS_CTRL) {
+ nir_src *vertex_index_src = nir_get_io_vertex_index_src(instr);
+ LLVMValueRef vertex_index =
+ vertex_index_src ? get_src(ctx, *vertex_index_src) : NULL;
+
+ ctx->abi->store_tcs_outputs(ctx->abi, NULL,
+ vertex_index, indir_index,
+ 0, src, writemask,
+ component, base * 4);
+ return;
+ }
+
+ /* No indirect indexing is allowed after this point. */
+ assert(!indir_index);
+
+ for (unsigned chan = 0; chan < 8; chan++) {
+ if (!(writemask & (1 << chan)))
+ continue;
+
+ LLVMValueRef value = ac_llvm_extract_elem(&ctx->ac, src, chan - component);
+ LLVMBuildStore(ctx->ac.builder, value,
+ ctx->abi->outputs[base * 4 + chan]);
+ }
+
+ if (ctx->ac.postponed_kill)
+ ac_build_endif(&ctx->ac, 7002);
+}
+
static int image_type_to_components_count(enum glsl_sampler_dim dim, bool array)
{
switch (dim) {
@@ -3578,18 +3647,82 @@
return ac_to_integer(&ctx->ac, ac_build_gather_values(&ctx->ac, values, num_components));
}
-static LLVMValueRef load_input(struct ac_nir_context *ctx,
- nir_intrinsic_instr *instr)
+static LLVMValueRef visit_load(struct ac_nir_context *ctx,
+ nir_intrinsic_instr *instr, bool is_output)
{
- unsigned offset_idx = instr->intrinsic == nir_intrinsic_load_input ? 0 : 1;
-
- /* We only lower inputs for fragment shaders ATM */
- ASSERTED nir_const_value *offset = nir_src_as_const_value(instr->src[offset_idx]);
- assert(offset);
- assert(offset[0].i32 == 0);
-
+ LLVMValueRef values[8];
+ LLVMTypeRef dest_type = get_def_type(ctx, &instr->dest.ssa);
+ LLVMTypeRef component_type;
+ unsigned base = nir_intrinsic_base(instr);
unsigned component = nir_intrinsic_component(instr);
- unsigned index = nir_intrinsic_base(instr);
+ unsigned count = instr->dest.ssa.num_components *
+ (instr->dest.ssa.bit_size == 64 ? 2 : 1);
+ nir_src *vertex_index_src = nir_get_io_vertex_index_src(instr);
+ LLVMValueRef vertex_index =
+ vertex_index_src ? get_src(ctx, *vertex_index_src) : NULL;
+ nir_src offset = *nir_get_io_offset_src(instr);
+ LLVMValueRef indir_index = NULL;
+
+ if (LLVMGetTypeKind(dest_type) == LLVMVectorTypeKind)
+ component_type = LLVMGetElementType(dest_type);
+ else
+ component_type = dest_type;
+
+ if (nir_src_is_const(offset))
+ assert(nir_src_as_uint(offset) == 0);
+ else
+ indir_index = get_src(ctx, offset);
+
+ if (ctx->stage == MESA_SHADER_TESS_CTRL ||
+ (ctx->stage == MESA_SHADER_TESS_EVAL && !is_output)) {
+ LLVMValueRef result =
+ ctx->abi->load_tess_varyings(ctx->abi, component_type,
+ vertex_index, indir_index,
+ 0, 0, base * 4,
+ component,
+ instr->num_components,
+ false, false, !is_output);
+ if (instr->dest.ssa.bit_size == 16) {
+ result = ac_to_integer(&ctx->ac, result);
+ result = LLVMBuildTrunc(ctx->ac.builder, result, dest_type, "");
+ }
+ return LLVMBuildBitCast(ctx->ac.builder, result, dest_type, "");
+ }
+
+ /* No indirect indexing is allowed after this point. */
+ assert(!indir_index);
+
+ if (ctx->stage == MESA_SHADER_GEOMETRY) {
+ LLVMTypeRef type = LLVMIntTypeInContext(ctx->ac.context, instr->dest.ssa.bit_size);
+ assert(nir_src_is_const(*vertex_index_src));
+
+ return ctx->abi->load_inputs(ctx->abi, 0, base * 4, component,
+ instr->num_components,
+ nir_src_as_uint(*vertex_index_src),
+ 0, type);
+ }
+
+ if (ctx->stage == MESA_SHADER_FRAGMENT && is_output &&
+ nir_intrinsic_io_semantics(instr).fb_fetch_output)
+ return ctx->abi->emit_fbfetch(ctx->abi);
+
+ /* Other non-fragment cases have inputs and outputs in temporaries. */
+ if (ctx->stage != MESA_SHADER_FRAGMENT) {
+ for (unsigned chan = component; chan < count + component; chan++) {
+ if (is_output) {
+ values[chan] = LLVMBuildLoad(ctx->ac.builder,
+ ctx->abi->outputs[base * 4 + chan], "");
+ } else {
+ values[chan] = ctx->abi->inputs[base * 4 + chan];
+ if (!values[chan])
+ values[chan] = LLVMGetUndef(ctx->ac.i32);
+ }
+ }
+ LLVMValueRef result = ac_build_varying_gather_values(&ctx->ac, values, count, component);
+ return LLVMBuildBitCast(ctx->ac.builder, result, dest_type, "");
+ }
+
+ /* Fragment shader inputs. */
unsigned vertex_id = 2; /* P0 */
if (instr->intrinsic == nir_intrinsic_load_input_vertex) {
@@ -3610,18 +3743,11 @@
}
}
- LLVMValueRef attr_number = LLVMConstInt(ctx->ac.i32, index, false);
- LLVMValueRef values[8];
+ LLVMValueRef attr_number = LLVMConstInt(ctx->ac.i32, base, false);
- /* Each component of a 64-bit value takes up two GL-level channels. */
- unsigned num_components = instr->dest.ssa.num_components;
- unsigned bit_size = instr->dest.ssa.bit_size;
- unsigned channels =
- bit_size == 64 ? num_components * 2 : num_components;
-
- for (unsigned chan = 0; chan < channels; chan++) {
+ for (unsigned chan = 0; chan < count; chan++) {
if (component + chan > 4)
- attr_number = LLVMConstInt(ctx->ac.i32, index + 1, false);
+ attr_number = LLVMConstInt(ctx->ac.i32, base + 1, false);
LLVMValueRef llvm_chan = LLVMConstInt(ctx->ac.i32, (component + chan) % 4, false);
values[chan] = ac_build_fs_interp_mov(&ctx->ac,
LLVMConstInt(ctx->ac.i32, vertex_id, false),
@@ -3630,16 +3756,12 @@
ac_get_arg(&ctx->ac, ctx->args->prim_mask));
values[chan] = LLVMBuildBitCast(ctx->ac.builder, values[chan], ctx->ac.i32, "");
values[chan] = LLVMBuildTruncOrBitCast(ctx->ac.builder, values[chan],
- bit_size == 16 ? ctx->ac.i16 : ctx->ac.i32, "");
+ instr->dest.ssa.bit_size == 16 ? ctx->ac.i16
+ : ctx->ac.i32, "");
}
- LLVMValueRef result = ac_build_gather_values(&ctx->ac, values, channels);
- if (bit_size == 64) {
- LLVMTypeRef type = num_components == 1 ? ctx->ac.i64 :
- LLVMVectorType(ctx->ac.i64, num_components);
- result = LLVMBuildBitCast(ctx->ac.builder, result, type, "");
- }
- return result;
+ LLVMValueRef result = ac_build_gather_values(&ctx->ac, values, count);
+ return LLVMBuildBitCast(ctx->ac.builder, result, dest_type, "");
}
static void visit_intrinsic(struct ac_nir_context *ctx,
@@ -3836,6 +3958,19 @@
case nir_intrinsic_store_deref:
visit_store_var(ctx, instr);
break;
+ case nir_intrinsic_load_input:
+ case nir_intrinsic_load_input_vertex:
+ case nir_intrinsic_load_per_vertex_input:
+ result = visit_load(ctx, instr, false);
+ break;
+ case nir_intrinsic_load_output:
+ case nir_intrinsic_load_per_vertex_output:
+ result = visit_load(ctx, instr, true);
+ break;
+ case nir_intrinsic_store_output:
+ case nir_intrinsic_store_per_vertex_output:
+ visit_store_output(ctx, instr);
+ break;
case nir_intrinsic_load_shared:
result = visit_load_shared(ctx, instr);
break;
@@ -4003,10 +4138,6 @@
instr->dest.ssa.bit_size);
break;
}
- case nir_intrinsic_load_input:
- case nir_intrinsic_load_input_vertex:
- result = load_input(ctx, instr);
- break;
case nir_intrinsic_emit_vertex:
ctx->abi->emit_vertex(ctx->abi, nir_intrinsic_stream_id(instr), ctx->abi->outputs);
break;
@@ -5339,9 +5470,13 @@
ctx.main_function = LLVMGetBasicBlockParent(LLVMGetInsertBlock(ctx.ac.builder));
- nir_foreach_shader_out_variable(variable, nir)
- ac_handle_shader_output_decl(&ctx.ac, ctx.abi, nir, variable,
- ctx.stage);
+ /* TODO: remove this after RADV switches to lowered IO */
+ if (!nir->info.io_lowered) {
+ nir_foreach_shader_out_variable(variable, nir) {
+ ac_handle_shader_output_decl(&ctx.ac, ctx.abi, nir, variable,
+ ctx.stage);
+ }
+ }
ctx.defs = _mesa_hash_table_create(NULL, _mesa_hash_pointer,
_mesa_key_pointer_equal);
diff --git a/src/amd/llvm/ac_shader_abi.h b/src/amd/llvm/ac_shader_abi.h
index 80b1554..359e948 100644
--- a/src/amd/llvm/ac_shader_abi.h
+++ b/src/amd/llvm/ac_shader_abi.h
@@ -113,7 +113,9 @@
LLVMValueRef param_index,
unsigned const_index,
LLVMValueRef src,
- unsigned writemask);
+ unsigned writemask,
+ unsigned component,
+ unsigned driver_location);
LLVMValueRef (*load_tess_coord)(struct ac_shader_abi *abi);
diff --git a/src/amd/vulkan/radv_nir_to_llvm.c b/src/amd/vulkan/radv_nir_to_llvm.c
index b962f25..db21ad8 100644
--- a/src/amd/vulkan/radv_nir_to_llvm.c
+++ b/src/amd/vulkan/radv_nir_to_llvm.c
@@ -589,11 +589,12 @@
LLVMValueRef param_index,
unsigned const_index,
LLVMValueRef src,
- unsigned writemask)
+ unsigned writemask,
+ unsigned component,
+ unsigned driver_location)
{
struct radv_shader_context *ctx = radv_shader_context_from_abi(abi);
const unsigned location = var->data.location;
- unsigned component = var->data.location_frac;
const bool is_patch = var->data.patch;
const bool is_compact = var->data.compact;
LLVMValueRef dw_addr;
diff --git a/src/gallium/drivers/radeonsi/si_shader_llvm_tess.c b/src/gallium/drivers/radeonsi/si_shader_llvm_tess.c
index 13bed5f..e54b4a8 100644
--- a/src/gallium/drivers/radeonsi/si_shader_llvm_tess.c
+++ b/src/gallium/drivers/radeonsi/si_shader_llvm_tess.c
@@ -509,12 +509,11 @@
static void si_nir_store_output_tcs(struct ac_shader_abi *abi, const struct nir_variable *var,
LLVMValueRef vertex_index, LLVMValueRef param_index,
- unsigned const_index, LLVMValueRef src, unsigned writemask)
+ unsigned const_index, LLVMValueRef src, unsigned writemask,
+ unsigned component, unsigned driver_location)
{
struct si_shader_context *ctx = si_shader_context_from_abi(abi);
struct si_shader_info *info = &ctx->shader->selector->info;
- unsigned component = var->data.location_frac;
- unsigned driver_location = var->data.driver_location;
LLVMValueRef dw_addr, stride;
LLVMValueRef buffer, base, addr;
LLVMValueRef values[8];