ac/llvm: don't lower bool to int32, switch to native i1 bool
Acked-by: Pierre-Eric Pelloux-Prayer <pierre-eric.pelloux-prayer@amd.com>
Reviewed-by: Samuel Pitoiset <samuel.pitoiset@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/7077>
diff --git a/src/amd/llvm/ac_llvm_build.c b/src/amd/llvm/ac_llvm_build.c
index e0ddf45..ada9c4d 100644
--- a/src/amd/llvm/ac_llvm_build.c
+++ b/src/amd/llvm/ac_llvm_build.c
@@ -203,7 +203,9 @@
static LLVMTypeRef to_integer_type_scalar(struct ac_llvm_context *ctx, LLVMTypeRef t)
{
- if (t == ctx->i8)
+ if (t == ctx->i1)
+ return ctx->i1;
+ else if (t == ctx->i8)
return ctx->i8;
else if (t == ctx->f16 || t == ctx->i16)
return ctx->i16;
@@ -435,6 +437,9 @@
{
const char *name;
+ if (LLVMTypeOf(value) == ctx->i1)
+ value = LLVMBuildZExt(ctx->builder, value, ctx->i32, "");
+
if (LLVM_VERSION_MAJOR >= 9) {
if (ctx->wave_size == 64)
name = "llvm.amdgcn.icmp.i64.i32";
@@ -3171,19 +3176,6 @@
LLVMPositionBuilderAtEnd(ctx->builder, if_block);
}
-void ac_build_if(struct ac_llvm_context *ctx, LLVMValueRef value, int label_id)
-{
- LLVMValueRef cond = LLVMBuildFCmp(ctx->builder, LLVMRealUNE, value, ctx->f32_0, "");
- ac_build_ifcc(ctx, cond, label_id);
-}
-
-void ac_build_uif(struct ac_llvm_context *ctx, LLVMValueRef value, int label_id)
-{
- LLVMValueRef cond =
- LLVMBuildICmp(ctx->builder, LLVMIntNE, ac_to_integer(ctx, value), ctx->i32_0, "");
- ac_build_ifcc(ctx, cond, label_id);
-}
-
LLVMValueRef ac_build_alloca_undef(struct ac_llvm_context *ac, LLVMTypeRef type, const char *name)
{
LLVMBuilderRef builder = ac->builder;
@@ -3631,7 +3623,18 @@
static LLVMValueRef get_reduction_identity(struct ac_llvm_context *ctx, nir_op op,
unsigned type_size)
{
- if (type_size == 1) {
+
+ if (type_size == 0) {
+ switch (op) {
+ case nir_op_ior:
+ case nir_op_ixor:
+ return LLVMConstInt(ctx->i1, 0, 0);
+ case nir_op_iand:
+ return LLVMConstInt(ctx->i1, 1, 0);
+ default:
+ unreachable("bad reduction intrinsic");
+ }
+ } else if (type_size == 1) {
switch (op) {
case nir_op_iadd:
return ctx->i8_0;
@@ -4366,8 +4369,7 @@
{
LLVMValueRef result =
ac_build_intrinsic(ctx, "llvm.amdgcn.ps.live", ctx->i1, NULL, 0, AC_FUNC_ATTR_READNONE);
- result = LLVMBuildNot(ctx->builder, result, "");
- return LLVMBuildSExt(ctx->builder, result, ctx->i32, "");
+ return LLVMBuildNot(ctx->builder, result, "");
}
LLVMValueRef ac_build_is_helper_invocation(struct ac_llvm_context *ctx)
@@ -4380,10 +4382,7 @@
ac_build_intrinsic(ctx, "llvm.amdgcn.ps.live", ctx->i1, NULL, 0, AC_FUNC_ATTR_READNONE);
LLVMValueRef postponed = LLVMBuildLoad(ctx->builder, ctx->postponed_kill, "");
- LLVMValueRef result = LLVMBuildAnd(ctx->builder, exact, postponed, "");
-
- return LLVMBuildSelect(ctx->builder, result, ctx->i32_0,
- LLVMConstInt(ctx->i32, 0xFFFFFFFF, false), "");
+ return LLVMBuildNot(ctx->builder, LLVMBuildAnd(ctx->builder, exact, postponed, ""), "");
}
LLVMValueRef ac_build_call(struct ac_llvm_context *ctx, LLVMValueRef func, LLVMValueRef *args,
diff --git a/src/amd/llvm/ac_llvm_build.h b/src/amd/llvm/ac_llvm_build.h
index d34f6c5..72b349f 100644
--- a/src/amd/llvm/ac_llvm_build.h
+++ b/src/amd/llvm/ac_llvm_build.h
@@ -466,8 +466,6 @@
void ac_build_endif(struct ac_llvm_context *ctx, int lable_id);
void ac_build_endloop(struct ac_llvm_context *ctx, int lable_id);
void ac_build_ifcc(struct ac_llvm_context *ctx, LLVMValueRef cond, int label_id);
-void ac_build_if(struct ac_llvm_context *ctx, LLVMValueRef value, int lable_id);
-void ac_build_uif(struct ac_llvm_context *ctx, LLVMValueRef value, int lable_id);
LLVMValueRef ac_build_alloca(struct ac_llvm_context *ac, LLVMTypeRef type, const char *name);
LLVMValueRef ac_build_alloca_undef(struct ac_llvm_context *ac, LLVMTypeRef type, const char *name);
diff --git a/src/amd/llvm/ac_nir_to_llvm.c b/src/amd/llvm/ac_nir_to_llvm.c
index 325baf5..646b752 100644
--- a/src/amd/llvm/ac_nir_to_llvm.c
+++ b/src/amd/llvm/ac_nir_to_llvm.c
@@ -146,20 +146,15 @@
src0 = LLVMBuildIntToPtr(ctx->builder, src0, src1_type, "");
}
- LLVMValueRef result = LLVMBuildICmp(ctx->builder, pred, src0, src1, "");
- return LLVMBuildSelect(ctx->builder, result, LLVMConstInt(ctx->i32, 0xFFFFFFFF, false),
- ctx->i32_0, "");
+ return LLVMBuildICmp(ctx->builder, pred, src0, src1, "");
}
static LLVMValueRef emit_float_cmp(struct ac_llvm_context *ctx, LLVMRealPredicate pred,
LLVMValueRef src0, LLVMValueRef src1)
{
- LLVMValueRef result;
src0 = ac_to_float(ctx, src0);
src1 = ac_to_float(ctx, src1);
- result = LLVMBuildFCmp(ctx->builder, pred, src0, src1, "");
- return LLVMBuildSelect(ctx->builder, result, LLVMConstInt(ctx->i32, 0xFFFFFFFF, false),
- ctx->i32_0, "");
+ return LLVMBuildFCmp(ctx->builder, pred, src0, src1, "");
}
static LLVMValueRef emit_intrin_1f_param(struct ac_llvm_context *ctx, const char *intrin,
@@ -250,9 +245,7 @@
src1 = LLVMBuildIntToPtr(ctx->builder, src1, src2_type, "");
}
- LLVMValueRef v =
- LLVMBuildICmp(ctx->builder, LLVMIntNE, src0, LLVMConstNull(LLVMTypeOf(src0)), "");
- return LLVMBuildSelect(ctx->builder, v, ac_to_integer_or_pointer(ctx, src1),
+ return LLVMBuildSelect(ctx->builder, src0, ac_to_integer_or_pointer(ctx, src1),
ac_to_integer_or_pointer(ctx, src2), "");
}
@@ -279,20 +272,25 @@
static LLVMValueRef emit_b2f(struct ac_llvm_context *ctx, LLVMValueRef src0, unsigned bitsize)
{
- assert(ac_get_elem_bits(ctx, LLVMTypeOf(src0)) == 32);
- LLVMValueRef result =
- LLVMBuildAnd(ctx->builder, src0, ac_const_uint_vec(ctx, LLVMTypeOf(src0), 0x3f800000), "");
- result = ac_to_float(ctx, result);
+ assert(ac_get_elem_bits(ctx, LLVMTypeOf(src0)) == 1);
switch (bitsize) {
- case 16: {
- bool vec2 = LLVMGetTypeKind(LLVMTypeOf(result)) == LLVMVectorTypeKind;
- return LLVMBuildFPTrunc(ctx->builder, result, vec2 ? ctx->v2f16 : ctx->f16, "");
- }
+ case 16:
+ if (LLVMGetTypeKind(LLVMTypeOf(src0)) == LLVMVectorTypeKind) {
+ assert(LLVMGetVectorSize(LLVMTypeOf(src0)) == 2);
+ LLVMValueRef f[] = {
+ LLVMBuildSelect(ctx->builder, ac_llvm_extract_elem(ctx, src0, 0),
+ ctx->f16_1, ctx->f16_0, ""),
+ LLVMBuildSelect(ctx->builder, ac_llvm_extract_elem(ctx, src0, 1),
+ ctx->f16_1, ctx->f16_0, ""),
+ };
+ return ac_build_gather_values(ctx, f, 2);
+ }
+ return LLVMBuildSelect(ctx->builder, src0, ctx->f16_1, ctx->f16_0, "");
case 32:
- return result;
+ return LLVMBuildSelect(ctx->builder, src0, ctx->f32_1, ctx->f32_0, "");
case 64:
- return LLVMBuildFPExt(ctx->builder, result, ctx->f64, "");
+ return LLVMBuildSelect(ctx->builder, src0, ctx->f64_1, ctx->f64_0, "");
default:
unreachable("Unsupported bit size.");
}
@@ -302,23 +300,20 @@
{
src0 = ac_to_float(ctx, src0);
LLVMValueRef zero = LLVMConstNull(LLVMTypeOf(src0));
- return LLVMBuildSExt(ctx->builder, LLVMBuildFCmp(ctx->builder, LLVMRealUNE, src0, zero, ""),
- ctx->i32, "");
+ return LLVMBuildFCmp(ctx->builder, LLVMRealUNE, src0, zero, "");
}
static LLVMValueRef emit_b2i(struct ac_llvm_context *ctx, LLVMValueRef src0, unsigned bitsize)
{
- LLVMValueRef result = LLVMBuildAnd(ctx->builder, src0, ctx->i32_1, "");
-
switch (bitsize) {
case 8:
- return LLVMBuildTrunc(ctx->builder, result, ctx->i8, "");
+ return LLVMBuildSelect(ctx->builder, src0, ctx->i8_1, ctx->i8_0, "");
case 16:
- return LLVMBuildTrunc(ctx->builder, result, ctx->i16, "");
+ return LLVMBuildSelect(ctx->builder, src0, ctx->i16_1, ctx->i16_0, "");
case 32:
- return result;
+ return LLVMBuildSelect(ctx->builder, src0, ctx->i32_1, ctx->i32_0, "");
case 64:
- return LLVMBuildZExt(ctx->builder, result, ctx->i64, "");
+ return LLVMBuildSelect(ctx->builder, src0, ctx->i64_1, ctx->i64_0, "");
default:
unreachable("Unsupported bit size.");
}
@@ -327,8 +322,7 @@
static LLVMValueRef emit_i2b(struct ac_llvm_context *ctx, LLVMValueRef src0)
{
LLVMValueRef zero = LLVMConstNull(LLVMTypeOf(src0));
- return LLVMBuildSExt(ctx->builder, LLVMBuildICmp(ctx->builder, LLVMIntNE, src0, zero, ""),
- ctx->i32, "");
+ return LLVMBuildICmp(ctx->builder, LLVMIntNE, src0, zero, "");
}
static LLVMValueRef emit_f2f16(struct ac_llvm_context *ctx, LLVMValueRef src0)
@@ -703,34 +697,34 @@
src[1] = LLVMBuildTrunc(ctx->ac.builder, src[1], LLVMTypeOf(src[0]), "");
result = LLVMBuildLShr(ctx->ac.builder, src[0], src[1], "");
break;
- case nir_op_ilt32:
+ case nir_op_ilt:
result = emit_int_cmp(&ctx->ac, LLVMIntSLT, src[0], src[1]);
break;
- case nir_op_ine32:
+ case nir_op_ine:
result = emit_int_cmp(&ctx->ac, LLVMIntNE, src[0], src[1]);
break;
- case nir_op_ieq32:
+ case nir_op_ieq:
result = emit_int_cmp(&ctx->ac, LLVMIntEQ, src[0], src[1]);
break;
- case nir_op_ige32:
+ case nir_op_ige:
result = emit_int_cmp(&ctx->ac, LLVMIntSGE, src[0], src[1]);
break;
- case nir_op_ult32:
+ case nir_op_ult:
result = emit_int_cmp(&ctx->ac, LLVMIntULT, src[0], src[1]);
break;
- case nir_op_uge32:
+ case nir_op_uge:
result = emit_int_cmp(&ctx->ac, LLVMIntUGE, src[0], src[1]);
break;
- case nir_op_feq32:
+ case nir_op_feq:
result = emit_float_cmp(&ctx->ac, LLVMRealOEQ, src[0], src[1]);
break;
- case nir_op_fneu32:
+ case nir_op_fneu:
result = emit_float_cmp(&ctx->ac, LLVMRealUNE, src[0], src[1]);
break;
- case nir_op_flt32:
+ case nir_op_flt:
result = emit_float_cmp(&ctx->ac, LLVMRealOLT, src[0], src[1]);
break;
- case nir_op_fge32:
+ case nir_op_fge:
result = emit_float_cmp(&ctx->ac, LLVMRealOGE, src[0], src[1]);
break;
case nir_op_fabs:
@@ -987,7 +981,7 @@
else
result = LLVMBuildTrunc(ctx->ac.builder, src[0], def_type, "");
break;
- case nir_op_b32csel:
+ case nir_op_bcsel:
result = emit_bcsel(&ctx->ac, src[0], src[1], src[2]);
break;
case nir_op_find_lsb:
@@ -1010,7 +1004,7 @@
case nir_op_b2f64:
result = emit_b2f(&ctx->ac, src[0], instr->dest.dest.ssa.bit_size);
break;
- case nir_op_f2b32:
+ case nir_op_f2b1:
result = emit_f2b(&ctx->ac, src[0]);
break;
case nir_op_b2i8:
@@ -1019,9 +1013,16 @@
case nir_op_b2i64:
result = emit_b2i(&ctx->ac, src[0], instr->dest.dest.ssa.bit_size);
break;
- case nir_op_i2b32:
+ case nir_op_i2b1:
+ case nir_op_b2b1: /* after loads */
result = emit_i2b(&ctx->ac, src[0]);
break;
+ case nir_op_b2b16: /* before stores */
+ result = LLVMBuildZExt(ctx->ac.builder, src[0], ctx->ac.i16, "");
+ break;
+ case nir_op_b2b32: /* before stores */
+ result = LLVMBuildZExt(ctx->ac.builder, src[0], ctx->ac.i32, "");
+ break;
case nir_op_fquantize2f16:
result = emit_f2f16(&ctx->ac, src[0]);
break;
@@ -1179,6 +1180,9 @@
for (unsigned i = 0; i < instr->def.num_components; ++i) {
switch (instr->def.bit_size) {
+ case 1:
+ values[i] = LLVMConstInt(element_type, instr->value[i].b, false);
+ break;
case 8:
values[i] = LLVMConstInt(element_type, instr->value[i].u8, false);
break;
@@ -2758,8 +2762,7 @@
LLVMValueRef cond;
if (instr->intrinsic == nir_intrinsic_discard_if) {
- cond =
- LLVMBuildICmp(ctx->ac.builder, LLVMIntEQ, get_src(ctx, instr->src[0]), ctx->ac.i32_0, "");
+ cond = LLVMBuildNot(ctx->ac.builder, get_src(ctx, instr->src[0]), "");
} else {
assert(instr->intrinsic == nir_intrinsic_discard);
cond = ctx->ac.i1false;
@@ -2773,8 +2776,7 @@
LLVMValueRef cond;
if (instr->intrinsic == nir_intrinsic_demote_if) {
- cond =
- LLVMBuildICmp(ctx->ac.builder, LLVMIntEQ, get_src(ctx, instr->src[0]), ctx->ac.i32_0, "");
+ cond = LLVMBuildNot(ctx->ac.builder, get_src(ctx, instr->src[0]), "");
} else {
assert(instr->intrinsic == nir_intrinsic_demote);
cond = ctx->ac.i1false;
@@ -3337,7 +3339,7 @@
result = ctx->abi->inputs[ac_llvm_reg_index_soa(VARYING_SLOT_LAYER, 0)];
break;
case nir_intrinsic_load_front_face:
- result = ac_get_arg(&ctx->ac, ctx->args->front_face);
+ result = emit_i2b(&ctx->ac, ac_get_arg(&ctx->ac, ctx->args->front_face));
break;
case nir_intrinsic_load_helper_invocation:
result = ac_build_load_helper_invocation(&ctx->ac);
@@ -3637,13 +3639,11 @@
result = ctx->abi->load_patch_vertices_in(ctx->abi);
break;
case nir_intrinsic_vote_all: {
- LLVMValueRef tmp = ac_build_vote_all(&ctx->ac, get_src(ctx, instr->src[0]));
- result = LLVMBuildSExt(ctx->ac.builder, tmp, ctx->ac.i32, "");
+ result = ac_build_vote_all(&ctx->ac, get_src(ctx, instr->src[0]));
break;
}
case nir_intrinsic_vote_any: {
- LLVMValueRef tmp = ac_build_vote_any(&ctx->ac, get_src(ctx, instr->src[0]));
- result = LLVMBuildSExt(ctx->ac.builder, tmp, ctx->ac.i32, "");
+ result = ac_build_vote_any(&ctx->ac, get_src(ctx, instr->src[0]));
break;
}
case nir_intrinsic_shuffle:
@@ -4251,7 +4251,9 @@
result = build_tex_intrinsic(ctx, instr, &txf_args);
result = LLVMBuildExtractElement(ctx->ac.builder, result, ctx->ac.i32_0, "");
- result = emit_int_cmp(&ctx->ac, LLVMIntEQ, result, ctx->ac.i32_0);
+ result = LLVMBuildSExt(ctx->ac.builder,
+ emit_int_cmp(&ctx->ac, LLVMIntEQ, result, ctx->ac.i32_0),
+ ctx->ac.i32, "");
goto write_result;
}
@@ -4643,7 +4645,7 @@
nir_block *then_block = (nir_block *)exec_list_get_head(&if_stmt->then_list);
- ac_build_uif(&ctx->ac, value, then_block->index);
+ ac_build_ifcc(&ctx->ac, value, then_block->index);
visit_cf_list(ctx, &if_stmt->then_list);
diff --git a/src/amd/vulkan/radv_pipeline.c b/src/amd/vulkan/radv_pipeline.c
index b21b4c8..9751ff0 100644
--- a/src/amd/vulkan/radv_pipeline.c
+++ b/src/amd/vulkan/radv_pipeline.c
@@ -3019,9 +3019,7 @@
/* do this again since information such as outputs_read can be out-of-date */
nir_shader_gather_info(nir[i], nir_shader_get_entrypoint(nir[i]));
- if (radv_use_llvm_for_stage(device, i)) {
- NIR_PASS_V(nir[i], nir_lower_bool_to_int32);
- } else {
+ if (!radv_use_llvm_for_stage(device, i)) {
NIR_PASS_V(nir[i], nir_lower_non_uniform_access,
nir_lower_non_uniform_ubo_access |
nir_lower_non_uniform_ssbo_access |
diff --git a/src/gallium/drivers/radeonsi/si_shader.c b/src/gallium/drivers/radeonsi/si_shader.c
index ac85ec6..f6a592f 100644
--- a/src/gallium/drivers/radeonsi/si_shader.c
+++ b/src/gallium/drivers/radeonsi/si_shader.c
@@ -1611,8 +1611,6 @@
return NULL;
}
- NIR_PASS_V(nir, nir_lower_bool_to_int32);
-
return nir;
}