// blob: 4a8f0dd6f69e5900d3b79aba7405296a049a5fe4 [file] [log] [blame]
// RUN: hlo_to_llvm_ir %s | FileCheck %s
// CHECK-LABEL: entry:
// CHECK: %[[VAL_0:.*]] = getelementptr inbounds i8, i8* %[[VAL_1:.*]], i64 0
// CHECK: %[[VAL_2:.*]] = bitcast i8* %[[VAL_0]] to [1024 x half]*
// CHECK: %[[VAL_3:.*]] = getelementptr inbounds i8, i8* %[[VAL_4:.*]], i64 0
// CHECK: %[[VAL_5:.*]] = bitcast i8* %[[VAL_3]] to [1024 x half]*
// CHECK: %[[VAL_6:.*]] = getelementptr inbounds i8, i8* %[[VAL_7:.*]], i64 0
// CHECK: %[[VAL_8:.*]] = bitcast i8* %[[VAL_6]] to [1023 x half]*
// CHECK: %[[VAL_9:.*]] = getelementptr inbounds i8, i8* %[[VAL_10:.*]], i64 0
// CHECK: %[[VAL_11:.*]] = bitcast i8* %[[VAL_9]] to [1023 x half]*
// CHECK: %[[VAL_12:.*]] = getelementptr inbounds i8, i8* %[[VAL_13:.*]], i64 0
// CHECK: %[[VAL_14:.*]] = bitcast i8* %[[VAL_12]] to [1024 x half]*
// CHECK: %[[VAL_15:.*]] = getelementptr inbounds i8, i8* %[[VAL_16:.*]], i64 0
// CHECK: %[[VAL_17:.*]] = bitcast i8* %[[VAL_15]] to [1023 x half]*
// CHECK: %[[VAL_18:.*]] = getelementptr inbounds i8, i8* %[[VAL_19:.*]], i64 0
// CHECK: %[[VAL_20:.*]] = bitcast i8* %[[VAL_18]] to [0 x half]*
// CHECK: %[[VAL_21:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !2
// CHECK: %[[VAL_22:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !3
// CHECK: %[[VAL_23:.*]] = mul nuw nsw i32 %[[VAL_21]], 1024
// CHECK: %[[VAL_24:.*]] = add nuw nsw i32 %[[VAL_23]], %[[VAL_22]]
// CHECK: %[[VAL_25:.*]] = icmp ult i32 %[[VAL_24]], 2048
// CHECK: call void @llvm.assume(i1 %[[VAL_25]])
// CHECK: %[[VAL_26:.*]] = udiv i32 %[[VAL_24]], 1
// CHECK: %[[VAL_27:.*]] = icmp ult i32 %[[VAL_24]], 2047
// CHECK: br i1 %[[VAL_27]], label %[[VAL_28:.*]], label %[[VAL_29:.*]]
// CHECK: fusion.in_bounds-after: ; preds = %[[VAL_30:.*]], %[[VAL_31:.*]]
// CHECK: ret void
// CHECK: fusion.in_bounds-true: ; preds = %[[VAL_31]]
// CHECK: %[[VAL_32:.*]] = icmp ult i32 %[[VAL_26]], 1024
// CHECK: br i1 %[[VAL_32]], label %[[VAL_33:.*]], label %[[VAL_34:.*]]
// CHECK: concat_index_from_operand_id0: ; preds = %[[VAL_28]]
// CHECK: %[[VAL_35:.*]] = phi i32 [ 0, %[[VAL_28]] ]
// CHECK: %[[VAL_36:.*]] = sub nsw i32 %[[VAL_26]], %[[VAL_35]]
// CHECK: %[[VAL_37:.*]] = getelementptr inbounds [1024 x half], [1024 x half]* %[[VAL_2]], i32 0, i32 %[[VAL_36]]
// CHECK: %[[VAL_38:.*]] = load half, half* %[[VAL_37]], align 2, !invariant.load !4
// CHECK: %[[VAL_39:.*]] = getelementptr inbounds [1024 x half], [1024 x half]* %[[VAL_5]], i32 0, i32 %[[VAL_36]]
// CHECK: %[[VAL_40:.*]] = load half, half* %[[VAL_39]], align 2, !invariant.load !4
// CHECK: %[[VAL_41:.*]] = fmul half %[[VAL_38]], %[[VAL_40]]
// CHECK: br label %[[VAL_42:.*]]
// CHECK: concat_index_from_operand_id1: ; preds = %[[VAL_34]]
// CHECK: %[[VAL_43:.*]] = phi i32 [ 1024, %[[VAL_34]] ]
// CHECK: %[[VAL_44:.*]] = sub nsw i32 %[[VAL_26]], %[[VAL_43]]
// CHECK: %[[VAL_45:.*]] = getelementptr inbounds [1023 x half], [1023 x half]* %[[VAL_8]], i32 0, i32 %[[VAL_44]]
// CHECK: %[[VAL_46:.*]] = load half, half* %[[VAL_45]], align 2, !invariant.load !4
// CHECK: %[[VAL_47:.*]] = getelementptr inbounds [1023 x half], [1023 x half]* %[[VAL_11]], i32 0, i32 %[[VAL_44]]
// CHECK: %[[VAL_48:.*]] = load half, half* %[[VAL_47]], align 2, !invariant.load !4
// CHECK: %[[VAL_49:.*]] = fadd half %[[VAL_46]], %[[VAL_48]]
// CHECK: br label %[[VAL_42]]
// CHECK: concat_index_not_from_operand0: ; preds = %[[VAL_28]]
// CHECK: %[[VAL_50:.*]] = icmp ult i32 %[[VAL_26]], 2047
// CHECK: br i1 %[[VAL_50]], label %[[VAL_51:.*]], label %[[VAL_52:.*]]
// CHECK: concat_index_not_from_operand1: ; preds = %[[VAL_34]]
// CHECK: unreachable
// CHECK: concatenate.7.merge: ; preds = %[[VAL_51]], %[[VAL_33]]
// CHECK: %[[VAL_53:.*]] = phi half [ %[[VAL_41]], %[[VAL_33]] ], [ %[[VAL_49]], %[[VAL_51]] ]
// CHECK: %[[VAL_54:.*]] = icmp sge i32 %[[VAL_26]], 0
// CHECK: %[[VAL_55:.*]] = icmp slt i32 %[[VAL_26]], 1024
// CHECK: %[[VAL_56:.*]] = and i1 %[[VAL_54]], %[[VAL_55]]
// CHECK: br i1 %[[VAL_56]], label %[[VAL_57:.*]], label %[[VAL_58:.*]]
// CHECK: slice0-after: ; preds = %[[VAL_57]], %[[VAL_42]]
// CHECK: %[[VAL_59:.*]] = icmp sge i32 %[[VAL_26]], 1024
// CHECK: %[[VAL_60:.*]] = icmp slt i32 %[[VAL_26]], 2047
// CHECK: %[[VAL_61:.*]] = and i1 %[[VAL_59]], %[[VAL_60]]
// CHECK: br i1 %[[VAL_61]], label %[[VAL_62:.*]], label %[[VAL_63:.*]]
// CHECK: slice1-after: ; preds = %[[VAL_62]], %[[VAL_58]]
// CHECK: %[[VAL_64:.*]] = icmp sge i32 %[[VAL_26]], 2047
// CHECK: %[[VAL_65:.*]] = icmp slt i32 %[[VAL_26]], 2047
// CHECK: %[[VAL_66:.*]] = and i1 %[[VAL_64]], %[[VAL_65]]
// CHECK: br i1 %[[VAL_66]], label %[[VAL_67:.*]], label %[[VAL_30]]
// CHECK: slice2-after: ; preds = %[[VAL_67]], %[[VAL_63]]
// CHECK: br label %[[VAL_29]]
// CHECK: slice0-true: ; preds = %[[VAL_42]]
// CHECK: %[[VAL_68:.*]] = sub i32 %[[VAL_26]], 0
// CHECK: %[[VAL_69:.*]] = getelementptr inbounds [1024 x half], [1024 x half]* %[[VAL_14]], i32 0, i32 %[[VAL_68]]
// CHECK: store half %[[VAL_53]], half* %[[VAL_69]], align 2
// CHECK: br label %[[VAL_58]]
// CHECK: slice1-true: ; preds = %[[VAL_58]]
// CHECK: %[[VAL_70:.*]] = sub i32 %[[VAL_26]], 1024
// CHECK: %[[VAL_71:.*]] = getelementptr inbounds [1023 x half], [1023 x half]* %[[VAL_17]], i32 0, i32 %[[VAL_70]]
// CHECK: store half %[[VAL_53]], half* %[[VAL_71]], align 2
// CHECK: br label %[[VAL_63]]
// CHECK: slice2-true: ; preds = %[[VAL_63]]
// CHECK: %[[VAL_72:.*]] = sub i32 %[[VAL_26]], 2047
// CHECK: %[[VAL_73:.*]] = getelementptr inbounds [0 x half], [0 x half]* %[[VAL_20]], i32 0, i32 %[[VAL_72]]
// CHECK: store half %[[VAL_53]], half* %[[VAL_73]], align 2
// CHECK: br label %[[VAL_30]]
// Module under test: one kInput fusion whose root is a tuple of three slices
// taken from a single concatenate result (see fused_computation below).
HloModule input_fusion_with_a_tuple_of_slices
// Fused body: multiplies the two f16[1024] inputs, adds the two f16[1023]
// inputs, concatenates both results along dimension 0 into f16[2047], and
// returns three slices of that concatenation as a tuple.
fused_computation {
// Operands: two 1024-element and two 1023-element f16 vectors.
arg.1 = f16[1024]{0} parameter(0)
arg.2 = f16[1024]{0} parameter(1)
arg.3 = f16[1023]{0} parameter(2)
arg.4 = f16[1023]{0} parameter(3)
// Producers of the two concatenate operands.
mul.1 = f16[1024]{0} multiply(arg.1, arg.2)
add.1 = f16[1023]{0} add(arg.3, arg.4)
// 1024 + 1023 = 2047 elements; mul.1 occupies [0:1024), add.1 [1024:2047).
// NOTE(review): the expected IR uses the label `concatenate.7.merge`, not a
// `concat.1`-based one -- whether labels track this name is not visible here.
concat.1 = f16[2047]{0} concatenate(mul.1, add.1), dimensions={0}
// slice.1 and slice.2 cover exactly the mul.1 and add.1 ranges of concat.1.
slice.1 = f16[1024]{0} slice(concat.1), slice={[0:1024]}
slice.2 = f16[1023]{0} slice(concat.1), slice={[1024:2047]}
// Zero-length slice (start == limit == 2047): exercises an empty f16[0] output.
slice.3 = f16[0]{0} slice(concat.1), slice={[2047:2047]}
// The root instruction is split across two lines: tuple shape, then operands.
ROOT tuple.1 = (f16[1024]{0}, f16[1023]{0}, f16[0]{0})
tuple(slice.1, slice.2, slice.3)
}
// Entry computation: forwards its four parameters to a single kInput fusion
// that calls fused_computation and returns the fusion's 3-element tuple.
ENTRY kernel_entry {
// Shapes match the parameters of fused_computation one-to-one.
arg.1 = f16[1024]{0} parameter(0)
arg.2 = f16[1024]{0} parameter(1)
arg.3 = f16[1023]{0} parameter(2)
arg.4 = f16[1023]{0} parameter(3)
// The root instruction is split across two lines: result shape, then the call.
// The expected IR in this file contains `fusion.`-prefixed block labels
// (e.g. fusion.in_bounds-true); presumably they relate to this instruction.
ROOT fusion = (f16[1024]{0}, f16[1023]{0}, f16[0]{0})
fusion(arg.1, arg.2, arg.3, arg.4), kind=kInput, calls=fused_computation
}