ac/llvm: don't lower bool to int32, switch to native i1 bool

Acked-by: Pierre-Eric Pelloux-Prayer <pierre-eric.pelloux-prayer@amd.com>
Reviewed-by: Samuel Pitoiset <samuel.pitoiset@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/7077>
diff --git a/src/amd/llvm/ac_llvm_build.c b/src/amd/llvm/ac_llvm_build.c
index e0ddf45..ada9c4d 100644
--- a/src/amd/llvm/ac_llvm_build.c
+++ b/src/amd/llvm/ac_llvm_build.c
@@ -203,7 +203,9 @@
 
 static LLVMTypeRef to_integer_type_scalar(struct ac_llvm_context *ctx, LLVMTypeRef t)
 {
-   if (t == ctx->i8)
+   if (t == ctx->i1)
+      return ctx->i1;
+   else if (t == ctx->i8)
       return ctx->i8;
    else if (t == ctx->f16 || t == ctx->i16)
       return ctx->i16;
@@ -435,6 +437,9 @@
 {
    const char *name;
 
+   if (LLVMTypeOf(value) == ctx->i1)
+      value = LLVMBuildZExt(ctx->builder, value, ctx->i32, "");
+
    if (LLVM_VERSION_MAJOR >= 9) {
       if (ctx->wave_size == 64)
          name = "llvm.amdgcn.icmp.i64.i32";
@@ -3171,19 +3176,6 @@
    LLVMPositionBuilderAtEnd(ctx->builder, if_block);
 }
 
-void ac_build_if(struct ac_llvm_context *ctx, LLVMValueRef value, int label_id)
-{
-   LLVMValueRef cond = LLVMBuildFCmp(ctx->builder, LLVMRealUNE, value, ctx->f32_0, "");
-   ac_build_ifcc(ctx, cond, label_id);
-}
-
-void ac_build_uif(struct ac_llvm_context *ctx, LLVMValueRef value, int label_id)
-{
-   LLVMValueRef cond =
-      LLVMBuildICmp(ctx->builder, LLVMIntNE, ac_to_integer(ctx, value), ctx->i32_0, "");
-   ac_build_ifcc(ctx, cond, label_id);
-}
-
 LLVMValueRef ac_build_alloca_undef(struct ac_llvm_context *ac, LLVMTypeRef type, const char *name)
 {
    LLVMBuilderRef builder = ac->builder;
@@ -3631,7 +3623,18 @@
 static LLVMValueRef get_reduction_identity(struct ac_llvm_context *ctx, nir_op op,
                                            unsigned type_size)
 {
-   if (type_size == 1) {
+
+   if (type_size == 0) {
+      switch (op) {
+      case nir_op_ior:
+      case nir_op_ixor:
+         return LLVMConstInt(ctx->i1, 0, 0);
+      case nir_op_iand:
+         return LLVMConstInt(ctx->i1, 1, 0);
+      default:
+         unreachable("bad reduction intrinsic");
+      }
+   } else if (type_size == 1) {
       switch (op) {
       case nir_op_iadd:
          return ctx->i8_0;
@@ -4366,8 +4369,7 @@
 {
    LLVMValueRef result =
       ac_build_intrinsic(ctx, "llvm.amdgcn.ps.live", ctx->i1, NULL, 0, AC_FUNC_ATTR_READNONE);
-   result = LLVMBuildNot(ctx->builder, result, "");
-   return LLVMBuildSExt(ctx->builder, result, ctx->i32, "");
+   return LLVMBuildNot(ctx->builder, result, "");
 }
 
 LLVMValueRef ac_build_is_helper_invocation(struct ac_llvm_context *ctx)
@@ -4380,10 +4382,7 @@
       ac_build_intrinsic(ctx, "llvm.amdgcn.ps.live", ctx->i1, NULL, 0, AC_FUNC_ATTR_READNONE);
 
    LLVMValueRef postponed = LLVMBuildLoad(ctx->builder, ctx->postponed_kill, "");
-   LLVMValueRef result = LLVMBuildAnd(ctx->builder, exact, postponed, "");
-
-   return LLVMBuildSelect(ctx->builder, result, ctx->i32_0,
-                          LLVMConstInt(ctx->i32, 0xFFFFFFFF, false), "");
+   return LLVMBuildNot(ctx->builder, LLVMBuildAnd(ctx->builder, exact, postponed, ""), "");
 }
 
 LLVMValueRef ac_build_call(struct ac_llvm_context *ctx, LLVMValueRef func, LLVMValueRef *args,
diff --git a/src/amd/llvm/ac_llvm_build.h b/src/amd/llvm/ac_llvm_build.h
index d34f6c5..72b349f 100644
--- a/src/amd/llvm/ac_llvm_build.h
+++ b/src/amd/llvm/ac_llvm_build.h
@@ -466,8 +466,6 @@
 void ac_build_endif(struct ac_llvm_context *ctx, int lable_id);
 void ac_build_endloop(struct ac_llvm_context *ctx, int lable_id);
 void ac_build_ifcc(struct ac_llvm_context *ctx, LLVMValueRef cond, int label_id);
-void ac_build_if(struct ac_llvm_context *ctx, LLVMValueRef value, int lable_id);
-void ac_build_uif(struct ac_llvm_context *ctx, LLVMValueRef value, int lable_id);
 
 LLVMValueRef ac_build_alloca(struct ac_llvm_context *ac, LLVMTypeRef type, const char *name);
 LLVMValueRef ac_build_alloca_undef(struct ac_llvm_context *ac, LLVMTypeRef type, const char *name);
diff --git a/src/amd/llvm/ac_nir_to_llvm.c b/src/amd/llvm/ac_nir_to_llvm.c
index 325baf5..646b752 100644
--- a/src/amd/llvm/ac_nir_to_llvm.c
+++ b/src/amd/llvm/ac_nir_to_llvm.c
@@ -146,20 +146,15 @@
       src0 = LLVMBuildIntToPtr(ctx->builder, src0, src1_type, "");
    }
 
-   LLVMValueRef result = LLVMBuildICmp(ctx->builder, pred, src0, src1, "");
-   return LLVMBuildSelect(ctx->builder, result, LLVMConstInt(ctx->i32, 0xFFFFFFFF, false),
-                          ctx->i32_0, "");
+   return LLVMBuildICmp(ctx->builder, pred, src0, src1, "");
 }
 
 static LLVMValueRef emit_float_cmp(struct ac_llvm_context *ctx, LLVMRealPredicate pred,
                                    LLVMValueRef src0, LLVMValueRef src1)
 {
-   LLVMValueRef result;
    src0 = ac_to_float(ctx, src0);
    src1 = ac_to_float(ctx, src1);
-   result = LLVMBuildFCmp(ctx->builder, pred, src0, src1, "");
-   return LLVMBuildSelect(ctx->builder, result, LLVMConstInt(ctx->i32, 0xFFFFFFFF, false),
-                          ctx->i32_0, "");
+   return LLVMBuildFCmp(ctx->builder, pred, src0, src1, "");
 }
 
 static LLVMValueRef emit_intrin_1f_param(struct ac_llvm_context *ctx, const char *intrin,
@@ -250,9 +245,7 @@
       src1 = LLVMBuildIntToPtr(ctx->builder, src1, src2_type, "");
    }
 
-   LLVMValueRef v =
-      LLVMBuildICmp(ctx->builder, LLVMIntNE, src0, LLVMConstNull(LLVMTypeOf(src0)), "");
-   return LLVMBuildSelect(ctx->builder, v, ac_to_integer_or_pointer(ctx, src1),
+   return LLVMBuildSelect(ctx->builder, src0, ac_to_integer_or_pointer(ctx, src1),
                           ac_to_integer_or_pointer(ctx, src2), "");
 }
 
@@ -279,20 +272,25 @@
 
 static LLVMValueRef emit_b2f(struct ac_llvm_context *ctx, LLVMValueRef src0, unsigned bitsize)
 {
-   assert(ac_get_elem_bits(ctx, LLVMTypeOf(src0)) == 32);
-   LLVMValueRef result =
-      LLVMBuildAnd(ctx->builder, src0, ac_const_uint_vec(ctx, LLVMTypeOf(src0), 0x3f800000), "");
-   result = ac_to_float(ctx, result);
+   assert(ac_get_elem_bits(ctx, LLVMTypeOf(src0)) == 1);
 
    switch (bitsize) {
-   case 16: {
-      bool vec2 = LLVMGetTypeKind(LLVMTypeOf(result)) == LLVMVectorTypeKind;
-      return LLVMBuildFPTrunc(ctx->builder, result, vec2 ? ctx->v2f16 : ctx->f16, "");
-   }
+   case 16:
+      if (LLVMGetTypeKind(LLVMTypeOf(src0)) == LLVMVectorTypeKind) {
+         assert(LLVMGetVectorSize(LLVMTypeOf(src0)) == 2);
+         LLVMValueRef f[] = {
+            LLVMBuildSelect(ctx->builder, ac_llvm_extract_elem(ctx, src0, 0),
+                            ctx->f16_1, ctx->f16_0, ""),
+            LLVMBuildSelect(ctx->builder, ac_llvm_extract_elem(ctx, src0, 1),
+                            ctx->f16_1, ctx->f16_0, ""),
+         };
+         return ac_build_gather_values(ctx, f, 2);
+      }
+      return LLVMBuildSelect(ctx->builder, src0, ctx->f16_1, ctx->f16_0, "");
    case 32:
-      return result;
+      return LLVMBuildSelect(ctx->builder, src0, ctx->f32_1, ctx->f32_0, "");
    case 64:
-      return LLVMBuildFPExt(ctx->builder, result, ctx->f64, "");
+      return LLVMBuildSelect(ctx->builder, src0, ctx->f64_1, ctx->f64_0, "");
    default:
       unreachable("Unsupported bit size.");
    }
@@ -302,23 +300,20 @@
 {
    src0 = ac_to_float(ctx, src0);
    LLVMValueRef zero = LLVMConstNull(LLVMTypeOf(src0));
-   return LLVMBuildSExt(ctx->builder, LLVMBuildFCmp(ctx->builder, LLVMRealUNE, src0, zero, ""),
-                        ctx->i32, "");
+   return LLVMBuildFCmp(ctx->builder, LLVMRealUNE, src0, zero, "");
 }
 
 static LLVMValueRef emit_b2i(struct ac_llvm_context *ctx, LLVMValueRef src0, unsigned bitsize)
 {
-   LLVMValueRef result = LLVMBuildAnd(ctx->builder, src0, ctx->i32_1, "");
-
    switch (bitsize) {
    case 8:
-      return LLVMBuildTrunc(ctx->builder, result, ctx->i8, "");
+      return LLVMBuildSelect(ctx->builder, src0, ctx->i8_1, ctx->i8_0, "");
    case 16:
-      return LLVMBuildTrunc(ctx->builder, result, ctx->i16, "");
+      return LLVMBuildSelect(ctx->builder, src0, ctx->i16_1, ctx->i16_0, "");
    case 32:
-      return result;
+      return LLVMBuildSelect(ctx->builder, src0, ctx->i32_1, ctx->i32_0, "");
    case 64:
-      return LLVMBuildZExt(ctx->builder, result, ctx->i64, "");
+      return LLVMBuildSelect(ctx->builder, src0, ctx->i64_1, ctx->i64_0, "");
    default:
       unreachable("Unsupported bit size.");
    }
@@ -327,8 +322,7 @@
 static LLVMValueRef emit_i2b(struct ac_llvm_context *ctx, LLVMValueRef src0)
 {
    LLVMValueRef zero = LLVMConstNull(LLVMTypeOf(src0));
-   return LLVMBuildSExt(ctx->builder, LLVMBuildICmp(ctx->builder, LLVMIntNE, src0, zero, ""),
-                        ctx->i32, "");
+   return LLVMBuildICmp(ctx->builder, LLVMIntNE, src0, zero, "");
 }
 
 static LLVMValueRef emit_f2f16(struct ac_llvm_context *ctx, LLVMValueRef src0)
@@ -703,34 +697,34 @@
          src[1] = LLVMBuildTrunc(ctx->ac.builder, src[1], LLVMTypeOf(src[0]), "");
       result = LLVMBuildLShr(ctx->ac.builder, src[0], src[1], "");
       break;
-   case nir_op_ilt32:
+   case nir_op_ilt:
       result = emit_int_cmp(&ctx->ac, LLVMIntSLT, src[0], src[1]);
       break;
-   case nir_op_ine32:
+   case nir_op_ine:
       result = emit_int_cmp(&ctx->ac, LLVMIntNE, src[0], src[1]);
       break;
-   case nir_op_ieq32:
+   case nir_op_ieq:
       result = emit_int_cmp(&ctx->ac, LLVMIntEQ, src[0], src[1]);
       break;
-   case nir_op_ige32:
+   case nir_op_ige:
       result = emit_int_cmp(&ctx->ac, LLVMIntSGE, src[0], src[1]);
       break;
-   case nir_op_ult32:
+   case nir_op_ult:
       result = emit_int_cmp(&ctx->ac, LLVMIntULT, src[0], src[1]);
       break;
-   case nir_op_uge32:
+   case nir_op_uge:
       result = emit_int_cmp(&ctx->ac, LLVMIntUGE, src[0], src[1]);
       break;
-   case nir_op_feq32:
+   case nir_op_feq:
       result = emit_float_cmp(&ctx->ac, LLVMRealOEQ, src[0], src[1]);
       break;
-   case nir_op_fneu32:
+   case nir_op_fneu:
       result = emit_float_cmp(&ctx->ac, LLVMRealUNE, src[0], src[1]);
       break;
-   case nir_op_flt32:
+   case nir_op_flt:
       result = emit_float_cmp(&ctx->ac, LLVMRealOLT, src[0], src[1]);
       break;
-   case nir_op_fge32:
+   case nir_op_fge:
       result = emit_float_cmp(&ctx->ac, LLVMRealOGE, src[0], src[1]);
       break;
    case nir_op_fabs:
@@ -987,7 +981,7 @@
       else
          result = LLVMBuildTrunc(ctx->ac.builder, src[0], def_type, "");
       break;
-   case nir_op_b32csel:
+   case nir_op_bcsel:
       result = emit_bcsel(&ctx->ac, src[0], src[1], src[2]);
       break;
    case nir_op_find_lsb:
@@ -1010,7 +1004,7 @@
    case nir_op_b2f64:
       result = emit_b2f(&ctx->ac, src[0], instr->dest.dest.ssa.bit_size);
       break;
-   case nir_op_f2b32:
+   case nir_op_f2b1:
       result = emit_f2b(&ctx->ac, src[0]);
       break;
    case nir_op_b2i8:
@@ -1019,9 +1013,16 @@
    case nir_op_b2i64:
       result = emit_b2i(&ctx->ac, src[0], instr->dest.dest.ssa.bit_size);
       break;
-   case nir_op_i2b32:
+   case nir_op_i2b1:
+   case nir_op_b2b1: /* after loads */
       result = emit_i2b(&ctx->ac, src[0]);
       break;
+   case nir_op_b2b16: /* before stores */
+      result = LLVMBuildZExt(ctx->ac.builder, src[0], ctx->ac.i16, "");
+      break;
+   case nir_op_b2b32: /* before stores */
+      result = LLVMBuildZExt(ctx->ac.builder, src[0], ctx->ac.i32, "");
+      break;
    case nir_op_fquantize2f16:
       result = emit_f2f16(&ctx->ac, src[0]);
       break;
@@ -1179,6 +1180,9 @@
 
    for (unsigned i = 0; i < instr->def.num_components; ++i) {
       switch (instr->def.bit_size) {
+      case 1:
+         values[i] = LLVMConstInt(element_type, instr->value[i].b, false);
+         break;
       case 8:
          values[i] = LLVMConstInt(element_type, instr->value[i].u8, false);
          break;
@@ -2758,8 +2762,7 @@
    LLVMValueRef cond;
 
    if (instr->intrinsic == nir_intrinsic_discard_if) {
-      cond =
-         LLVMBuildICmp(ctx->ac.builder, LLVMIntEQ, get_src(ctx, instr->src[0]), ctx->ac.i32_0, "");
+      cond = LLVMBuildNot(ctx->ac.builder, get_src(ctx, instr->src[0]), "");
    } else {
       assert(instr->intrinsic == nir_intrinsic_discard);
       cond = ctx->ac.i1false;
@@ -2773,8 +2776,7 @@
    LLVMValueRef cond;
 
    if (instr->intrinsic == nir_intrinsic_demote_if) {
-      cond =
-         LLVMBuildICmp(ctx->ac.builder, LLVMIntEQ, get_src(ctx, instr->src[0]), ctx->ac.i32_0, "");
+      cond = LLVMBuildNot(ctx->ac.builder, get_src(ctx, instr->src[0]), "");
    } else {
       assert(instr->intrinsic == nir_intrinsic_demote);
       cond = ctx->ac.i1false;
@@ -3337,7 +3339,7 @@
       result = ctx->abi->inputs[ac_llvm_reg_index_soa(VARYING_SLOT_LAYER, 0)];
       break;
    case nir_intrinsic_load_front_face:
-      result = ac_get_arg(&ctx->ac, ctx->args->front_face);
+      result = emit_i2b(&ctx->ac, ac_get_arg(&ctx->ac, ctx->args->front_face));
       break;
    case nir_intrinsic_load_helper_invocation:
       result = ac_build_load_helper_invocation(&ctx->ac);
@@ -3637,13 +3639,11 @@
       result = ctx->abi->load_patch_vertices_in(ctx->abi);
       break;
    case nir_intrinsic_vote_all: {
-      LLVMValueRef tmp = ac_build_vote_all(&ctx->ac, get_src(ctx, instr->src[0]));
-      result = LLVMBuildSExt(ctx->ac.builder, tmp, ctx->ac.i32, "");
+      result = ac_build_vote_all(&ctx->ac, get_src(ctx, instr->src[0]));
       break;
    }
    case nir_intrinsic_vote_any: {
-      LLVMValueRef tmp = ac_build_vote_any(&ctx->ac, get_src(ctx, instr->src[0]));
-      result = LLVMBuildSExt(ctx->ac.builder, tmp, ctx->ac.i32, "");
+      result = ac_build_vote_any(&ctx->ac, get_src(ctx, instr->src[0]));
       break;
    }
    case nir_intrinsic_shuffle:
@@ -4251,7 +4251,9 @@
       result = build_tex_intrinsic(ctx, instr, &txf_args);
 
       result = LLVMBuildExtractElement(ctx->ac.builder, result, ctx->ac.i32_0, "");
-      result = emit_int_cmp(&ctx->ac, LLVMIntEQ, result, ctx->ac.i32_0);
+      result = LLVMBuildSExt(ctx->ac.builder,
+                             emit_int_cmp(&ctx->ac, LLVMIntEQ, result, ctx->ac.i32_0),
+                             ctx->ac.i32, "");
       goto write_result;
    }
 
@@ -4643,7 +4645,7 @@
 
    nir_block *then_block = (nir_block *)exec_list_get_head(&if_stmt->then_list);
 
-   ac_build_uif(&ctx->ac, value, then_block->index);
+   ac_build_ifcc(&ctx->ac, value, then_block->index);
 
    visit_cf_list(ctx, &if_stmt->then_list);
 
diff --git a/src/amd/vulkan/radv_pipeline.c b/src/amd/vulkan/radv_pipeline.c
index b21b4c8..9751ff0 100644
--- a/src/amd/vulkan/radv_pipeline.c
+++ b/src/amd/vulkan/radv_pipeline.c
@@ -3019,9 +3019,7 @@
 			/* do this again since information such as outputs_read can be out-of-date */
 			nir_shader_gather_info(nir[i], nir_shader_get_entrypoint(nir[i]));
 
-			if (radv_use_llvm_for_stage(device, i)) {
-				NIR_PASS_V(nir[i], nir_lower_bool_to_int32);
-			} else {
+			if (!radv_use_llvm_for_stage(device, i)) {
 				NIR_PASS_V(nir[i], nir_lower_non_uniform_access,
 				           nir_lower_non_uniform_ubo_access |
 				           nir_lower_non_uniform_ssbo_access |
diff --git a/src/gallium/drivers/radeonsi/si_shader.c b/src/gallium/drivers/radeonsi/si_shader.c
index ac85ec6..f6a592f 100644
--- a/src/gallium/drivers/radeonsi/si_shader.c
+++ b/src/gallium/drivers/radeonsi/si_shader.c
@@ -1611,8 +1611,6 @@
       return NULL;
    }
 
-   NIR_PASS_V(nir, nir_lower_bool_to_int32);
-
    return nir;
 }