Refactor kernel thunk's launch dimension setting in EmitInputFusibleNonStridedSlices.
Move SetThunkLaunchDimensions() to right after KernelThunk construction. Launch dimension will be passed to KernelThunk's constructor as a parameter.
PiperOrigin-RevId: 387204323
Change-Id: I6e7ec07edc33a7ccc4d2c70cc0a707b89b9ae83f
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index af98790..d11ef12 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -5316,10 +5316,6 @@
constexpr int unroll_factor = 1;
- std::vector<llvm_ir::IrArray> ir_arrays;
- TF_ASSIGN_OR_RETURN(auto kernel_thunk,
- BuildKernelThunk(fusion, GetThunkInfo(op), &ir_arrays));
-
TF_ASSIGN_OR_RETURN(const HloComputation* fused_computation,
GetOrCreateSubComputationFromRegion(&fusion.region(),
/*is_fusion=*/true));
@@ -5330,6 +5326,10 @@
CalculateLaunchDimensions(
element_shape, ir_emitter_context_->gpu_device_info(),
{unroll_factor}));
+
+ std::vector<llvm_ir::IrArray> ir_arrays;
+ TF_ASSIGN_OR_RETURN(auto kernel_thunk,
+ BuildKernelThunk(fusion, GetThunkInfo(op), &ir_arrays));
SetThunkLaunchDimensions(launch_dimensions, kernel_thunk.get(),
ir_emitter_context_->llvm_module());
diff --git a/tensorflow/compiler/xla/service/gpu/tests/fused_slice_different_operands.hlo b/tensorflow/compiler/xla/service/gpu/tests/fused_slice_different_operands.hlo
index 987ab0f..551183b 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/fused_slice_different_operands.hlo
+++ b/tensorflow/compiler/xla/service/gpu/tests/fused_slice_different_operands.hlo
@@ -3,20 +3,6 @@
// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
// CHECK-LABEL: entry:
-// CHECK: %[[VAL_0:.*]] = getelementptr inbounds i8, i8* %[[VAL_1:.*]], i64 0
-// CHECK: %[[VAL_2:.*]] = bitcast i8* %[[VAL_0]] to [1024 x half]*
-// CHECK: %[[VAL_3:.*]] = getelementptr inbounds i8, i8* %[[VAL_4:.*]], i64 0
-// CHECK: %[[VAL_5:.*]] = bitcast i8* %[[VAL_3]] to [1024 x half]*
-// CHECK: %[[VAL_6:.*]] = getelementptr inbounds i8, i8* %[[VAL_7:.*]], i64 0
-// CHECK: %[[VAL_8:.*]] = bitcast i8* %[[VAL_6]] to [1023 x half]*
-// CHECK: %[[VAL_9:.*]] = getelementptr inbounds i8, i8* %[[VAL_10:.*]], i64 0
-// CHECK: %[[VAL_11:.*]] = bitcast i8* %[[VAL_9]] to [1023 x half]*
-// CHECK: %[[VAL_12:.*]] = getelementptr inbounds i8, i8* %[[VAL_13:.*]], i64 0
-// CHECK: %[[VAL_14:.*]] = bitcast i8* %[[VAL_12]] to [1024 x half]*
-// CHECK: %[[VAL_15:.*]] = getelementptr inbounds i8, i8* %[[VAL_16:.*]], i64 0
-// CHECK: %[[VAL_17:.*]] = bitcast i8* %[[VAL_15]] to [1024 x half]*
-// CHECK: ret void
-// CHECK: entry:
// CHECK: %[[VAL_18:.*]] = getelementptr inbounds i8, i8* %[[VAL_19:.*]], i64 0
// CHECK: %[[VAL_20:.*]] = bitcast i8* %[[VAL_18]] to [1024 x half]*
// CHECK: %[[VAL_21:.*]] = getelementptr inbounds i8, i8* %[[VAL_22:.*]], i64 0
@@ -29,8 +15,8 @@
// CHECK: %[[VAL_32:.*]] = bitcast i8* %[[VAL_30]] to [1024 x half]*
// CHECK: %[[VAL_33:.*]] = getelementptr inbounds i8, i8* %[[VAL_34:.*]], i64 0
// CHECK: %[[VAL_35:.*]] = bitcast i8* %[[VAL_33]] to [1024 x half]*
-// CHECK: %[[VAL_36:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !3
-// CHECK: %[[VAL_37:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !4
+// CHECK: %[[VAL_36:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !2
+// CHECK: %[[VAL_37:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !3
// CHECK: %[[VAL_38:.*]] = mul nuw nsw i32 %[[VAL_36]], 256
// CHECK: %[[VAL_39:.*]] = add nuw nsw i32 %[[VAL_38]], %[[VAL_37]]
// CHECK: %[[VAL_40:.*]] = icmp ult i32 %[[VAL_39]], 256
@@ -54,18 +40,18 @@
// CHECK: %[[VAL_57:.*]] = phi i32 [ 0, %[[VAL_56]] ]
// CHECK: %[[VAL_58:.*]] = sub nsw i32 %[[VAL_54]], %[[VAL_57]]
// CHECK: %[[VAL_59:.*]] = getelementptr inbounds [1024 x half], [1024 x half]* %[[VAL_20]], i32 0, i32 %[[VAL_58]]
-// CHECK: %[[VAL_60:.*]] = load half, half* %[[VAL_59]], align 2, !invariant.load !5
+// CHECK: %[[VAL_60:.*]] = load half, half* %[[VAL_59]], align 2, !invariant.load !4
// CHECK: %[[VAL_61:.*]] = getelementptr inbounds [1024 x half], [1024 x half]* %[[VAL_23]], i32 0, i32 %[[VAL_58]]
-// CHECK: %[[VAL_62:.*]] = load half, half* %[[VAL_61]], align 2, !invariant.load !5
+// CHECK: %[[VAL_62:.*]] = load half, half* %[[VAL_61]], align 2, !invariant.load !4
// CHECK: %[[VAL_63:.*]] = fmul half %[[VAL_60]], %[[VAL_62]]
// CHECK: br label %[[VAL_64:.*]]
// CHECK: concat_index_from_operand_id1: ; preds = %[[VAL_65:.*]]
// CHECK: %[[VAL_66:.*]] = phi i32 [ 1024, %[[VAL_65]] ]
// CHECK: %[[VAL_67:.*]] = sub nsw i32 %[[VAL_54]], %[[VAL_66]]
// CHECK: %[[VAL_68:.*]] = getelementptr inbounds [1023 x half], [1023 x half]* %[[VAL_26]], i32 0, i32 %[[VAL_67]]
-// CHECK: %[[VAL_69:.*]] = load half, half* %[[VAL_68]], align 2, !invariant.load !5
+// CHECK: %[[VAL_69:.*]] = load half, half* %[[VAL_68]], align 2, !invariant.load !4
// CHECK: %[[VAL_70:.*]] = getelementptr inbounds [1023 x half], [1023 x half]* %[[VAL_29]], i32 0, i32 %[[VAL_67]]
-// CHECK: %[[VAL_71:.*]] = load half, half* %[[VAL_70]], align 2, !invariant.load !5
+// CHECK: %[[VAL_71:.*]] = load half, half* %[[VAL_70]], align 2, !invariant.load !4
// CHECK: %[[VAL_72:.*]] = fadd half %[[VAL_69]], %[[VAL_71]]
// CHECK: br label %[[VAL_64]]
// CHECK: concatenate.pivot.1024.: ; preds = %[[VAL_50]]
@@ -80,9 +66,9 @@
// CHECK: %[[VAL_77:.*]] = insertvalue { half, half } undef, half %[[VAL_76]], 0
// CHECK: %[[VAL_78:.*]] = add i32 %[[VAL_42]], 0
// CHECK: %[[VAL_79:.*]] = getelementptr inbounds [1024 x half], [1024 x half]* %[[VAL_20]], i32 0, i32 %[[VAL_78]]
-// CHECK: %[[VAL_80:.*]] = load half, half* %[[VAL_79]], align 2, !invariant.load !5
+// CHECK: %[[VAL_80:.*]] = load half, half* %[[VAL_79]], align 2, !invariant.load !4
// CHECK: %[[VAL_81:.*]] = getelementptr inbounds [1024 x half], [1024 x half]* %[[VAL_23]], i32 0, i32 %[[VAL_78]]
-// CHECK: %[[VAL_82:.*]] = load half, half* %[[VAL_81]], align 2, !invariant.load !5
+// CHECK: %[[VAL_82:.*]] = load half, half* %[[VAL_81]], align 2, !invariant.load !4
// CHECK: %[[VAL_83:.*]] = fmul half %[[VAL_80]], %[[VAL_82]]
// CHECK: %[[VAL_84:.*]] = insertvalue { half, half } %[[VAL_77]], half %[[VAL_83]], 1
// CHECK: %[[VAL_85:.*]] = extractvalue { half, half } %[[VAL_84]], 0
@@ -99,18 +85,18 @@
// CHECK: %[[VAL_94:.*]] = phi i32 [ 0, %[[VAL_93]] ]
// CHECK: %[[VAL_95:.*]] = sub nsw i32 %[[VAL_91]], %[[VAL_94]]
// CHECK: %[[VAL_96:.*]] = getelementptr inbounds [1024 x half], [1024 x half]* %[[VAL_20]], i32 0, i32 %[[VAL_95]]
-// CHECK: %[[VAL_97:.*]] = load half, half* %[[VAL_96]], align 2, !invariant.load !5
+// CHECK: %[[VAL_97:.*]] = load half, half* %[[VAL_96]], align 2, !invariant.load !4
// CHECK: %[[VAL_98:.*]] = getelementptr inbounds [1024 x half], [1024 x half]* %[[VAL_23]], i32 0, i32 %[[VAL_95]]
-// CHECK: %[[VAL_99:.*]] = load half, half* %[[VAL_98]], align 2, !invariant.load !5
+// CHECK: %[[VAL_99:.*]] = load half, half* %[[VAL_98]], align 2, !invariant.load !4
// CHECK: %[[VAL_100:.*]] = fmul half %[[VAL_97]], %[[VAL_99]]
// CHECK: br label %[[VAL_101:.*]]
// CHECK: concat_index_from_operand_id16: ; preds = %[[VAL_102:.*]]
// CHECK: %[[VAL_103:.*]] = phi i32 [ 1024, %[[VAL_102]] ]
// CHECK: %[[VAL_104:.*]] = sub nsw i32 %[[VAL_91]], %[[VAL_103]]
// CHECK: %[[VAL_105:.*]] = getelementptr inbounds [1023 x half], [1023 x half]* %[[VAL_26]], i32 0, i32 %[[VAL_104]]
-// CHECK: %[[VAL_106:.*]] = load half, half* %[[VAL_105]], align 2, !invariant.load !5
+// CHECK: %[[VAL_106:.*]] = load half, half* %[[VAL_105]], align 2, !invariant.load !4
// CHECK: %[[VAL_107:.*]] = getelementptr inbounds [1023 x half], [1023 x half]* %[[VAL_29]], i32 0, i32 %[[VAL_104]]
-// CHECK: %[[VAL_108:.*]] = load half, half* %[[VAL_107]], align 2, !invariant.load !5
+// CHECK: %[[VAL_108:.*]] = load half, half* %[[VAL_107]], align 2, !invariant.load !4
// CHECK: %[[VAL_109:.*]] = fadd half %[[VAL_106]], %[[VAL_108]]
// CHECK: br label %[[VAL_101]]
// CHECK: concatenate.pivot.1024.8: ; preds = %[[VAL_64]]
@@ -125,9 +111,9 @@
// CHECK: %[[VAL_114:.*]] = insertvalue { half, half } undef, half %[[VAL_113]], 0
// CHECK: %[[VAL_115:.*]] = add i32 %[[VAL_44]], 0
// CHECK: %[[VAL_116:.*]] = getelementptr inbounds [1024 x half], [1024 x half]* %[[VAL_20]], i32 0, i32 %[[VAL_115]]
-// CHECK: %[[VAL_117:.*]] = load half, half* %[[VAL_116]], align 2, !invariant.load !5
+// CHECK: %[[VAL_117:.*]] = load half, half* %[[VAL_116]], align 2, !invariant.load !4
// CHECK: %[[VAL_118:.*]] = getelementptr inbounds [1024 x half], [1024 x half]* %[[VAL_23]], i32 0, i32 %[[VAL_115]]
-// CHECK: %[[VAL_119:.*]] = load half, half* %[[VAL_118]], align 2, !invariant.load !5
+// CHECK: %[[VAL_119:.*]] = load half, half* %[[VAL_118]], align 2, !invariant.load !4
// CHECK: %[[VAL_120:.*]] = fmul half %[[VAL_117]], %[[VAL_119]]
// CHECK: %[[VAL_121:.*]] = insertvalue { half, half } %[[VAL_114]], half %[[VAL_120]], 1
// CHECK: %[[VAL_122:.*]] = extractvalue { half, half } %[[VAL_121]], 0
@@ -144,18 +130,18 @@
// CHECK: %[[VAL_131:.*]] = phi i32 [ 0, %[[VAL_130]] ]
// CHECK: %[[VAL_132:.*]] = sub nsw i32 %[[VAL_128]], %[[VAL_131]]
// CHECK: %[[VAL_133:.*]] = getelementptr inbounds [1024 x half], [1024 x half]* %[[VAL_20]], i32 0, i32 %[[VAL_132]]
-// CHECK: %[[VAL_134:.*]] = load half, half* %[[VAL_133]], align 2, !invariant.load !5
+// CHECK: %[[VAL_134:.*]] = load half, half* %[[VAL_133]], align 2, !invariant.load !4
// CHECK: %[[VAL_135:.*]] = getelementptr inbounds [1024 x half], [1024 x half]* %[[VAL_23]], i32 0, i32 %[[VAL_132]]
-// CHECK: %[[VAL_136:.*]] = load half, half* %[[VAL_135]], align 2, !invariant.load !5
+// CHECK: %[[VAL_136:.*]] = load half, half* %[[VAL_135]], align 2, !invariant.load !4
// CHECK: %[[VAL_137:.*]] = fmul half %[[VAL_134]], %[[VAL_136]]
// CHECK: br label %[[VAL_138:.*]]
// CHECK: concat_index_from_operand_id115: ; preds = %[[VAL_139:.*]]
// CHECK: %[[VAL_140:.*]] = phi i32 [ 1024, %[[VAL_139]] ]
// CHECK: %[[VAL_141:.*]] = sub nsw i32 %[[VAL_128]], %[[VAL_140]]
// CHECK: %[[VAL_142:.*]] = getelementptr inbounds [1023 x half], [1023 x half]* %[[VAL_26]], i32 0, i32 %[[VAL_141]]
-// CHECK: %[[VAL_143:.*]] = load half, half* %[[VAL_142]], align 2, !invariant.load !5
+// CHECK: %[[VAL_143:.*]] = load half, half* %[[VAL_142]], align 2, !invariant.load !4
// CHECK: %[[VAL_144:.*]] = getelementptr inbounds [1023 x half], [1023 x half]* %[[VAL_29]], i32 0, i32 %[[VAL_141]]
-// CHECK: %[[VAL_145:.*]] = load half, half* %[[VAL_144]], align 2, !invariant.load !5
+// CHECK: %[[VAL_145:.*]] = load half, half* %[[VAL_144]], align 2, !invariant.load !4
// CHECK: %[[VAL_146:.*]] = fadd half %[[VAL_143]], %[[VAL_145]]
// CHECK: br label %[[VAL_138]]
// CHECK: concatenate.pivot.1024.17: ; preds = %[[VAL_101]]
@@ -170,9 +156,9 @@
// CHECK: %[[VAL_151:.*]] = insertvalue { half, half } undef, half %[[VAL_150]], 0
// CHECK: %[[VAL_152:.*]] = add i32 %[[VAL_46]], 0
// CHECK: %[[VAL_153:.*]] = getelementptr inbounds [1024 x half], [1024 x half]* %[[VAL_20]], i32 0, i32 %[[VAL_152]]
-// CHECK: %[[VAL_154:.*]] = load half, half* %[[VAL_153]], align 2, !invariant.load !5
+// CHECK: %[[VAL_154:.*]] = load half, half* %[[VAL_153]], align 2, !invariant.load !4
// CHECK: %[[VAL_155:.*]] = getelementptr inbounds [1024 x half], [1024 x half]* %[[VAL_23]], i32 0, i32 %[[VAL_152]]
-// CHECK: %[[VAL_156:.*]] = load half, half* %[[VAL_155]], align 2, !invariant.load !5
+// CHECK: %[[VAL_156:.*]] = load half, half* %[[VAL_155]], align 2, !invariant.load !4
// CHECK: %[[VAL_157:.*]] = fmul half %[[VAL_154]], %[[VAL_156]]
// CHECK: %[[VAL_158:.*]] = insertvalue { half, half } %[[VAL_151]], half %[[VAL_157]], 1
// CHECK: %[[VAL_159:.*]] = extractvalue { half, half } %[[VAL_158]], 0
@@ -189,18 +175,18 @@
// CHECK: %[[VAL_168:.*]] = phi i32 [ 0, %[[VAL_167]] ]
// CHECK: %[[VAL_169:.*]] = sub nsw i32 %[[VAL_165]], %[[VAL_168]]
// CHECK: %[[VAL_170:.*]] = getelementptr inbounds [1024 x half], [1024 x half]* %[[VAL_20]], i32 0, i32 %[[VAL_169]]
-// CHECK: %[[VAL_171:.*]] = load half, half* %[[VAL_170]], align 2, !invariant.load !5
+// CHECK: %[[VAL_171:.*]] = load half, half* %[[VAL_170]], align 2, !invariant.load !4
// CHECK: %[[VAL_172:.*]] = getelementptr inbounds [1024 x half], [1024 x half]* %[[VAL_23]], i32 0, i32 %[[VAL_169]]
-// CHECK: %[[VAL_173:.*]] = load half, half* %[[VAL_172]], align 2, !invariant.load !5
+// CHECK: %[[VAL_173:.*]] = load half, half* %[[VAL_172]], align 2, !invariant.load !4
// CHECK: %[[VAL_174:.*]] = fmul half %[[VAL_171]], %[[VAL_173]]
// CHECK: br label %[[VAL_52]]
// CHECK: concat_index_from_operand_id124: ; preds = %[[VAL_175:.*]]
// CHECK: %[[VAL_176:.*]] = phi i32 [ 1024, %[[VAL_175]] ]
// CHECK: %[[VAL_177:.*]] = sub nsw i32 %[[VAL_165]], %[[VAL_176]]
// CHECK: %[[VAL_178:.*]] = getelementptr inbounds [1023 x half], [1023 x half]* %[[VAL_26]], i32 0, i32 %[[VAL_177]]
-// CHECK: %[[VAL_179:.*]] = load half, half* %[[VAL_178]], align 2, !invariant.load !5
+// CHECK: %[[VAL_179:.*]] = load half, half* %[[VAL_178]], align 2, !invariant.load !4
// CHECK: %[[VAL_180:.*]] = getelementptr inbounds [1023 x half], [1023 x half]* %[[VAL_29]], i32 0, i32 %[[VAL_177]]
-// CHECK: %[[VAL_181:.*]] = load half, half* %[[VAL_180]], align 2, !invariant.load !5
+// CHECK: %[[VAL_181:.*]] = load half, half* %[[VAL_180]], align 2, !invariant.load !4
// CHECK: %[[VAL_182:.*]] = fadd half %[[VAL_179]], %[[VAL_181]]
// CHECK: br label %[[VAL_52]]
// CHECK: concatenate.pivot.1024.26: ; preds = %[[VAL_138]]
@@ -215,9 +201,9 @@
// CHECK: %[[VAL_187:.*]] = insertvalue { half, half } undef, half %[[VAL_186]], 0
// CHECK: %[[VAL_188:.*]] = add i32 %[[VAL_48]], 0
// CHECK: %[[VAL_189:.*]] = getelementptr inbounds [1024 x half], [1024 x half]* %[[VAL_20]], i32 0, i32 %[[VAL_188]]
-// CHECK: %[[VAL_190:.*]] = load half, half* %[[VAL_189]], align 2, !invariant.load !5
+// CHECK: %[[VAL_190:.*]] = load half, half* %[[VAL_189]], align 2, !invariant.load !4
// CHECK: %[[VAL_191:.*]] = getelementptr inbounds [1024 x half], [1024 x half]* %[[VAL_23]], i32 0, i32 %[[VAL_188]]
-// CHECK: %[[VAL_192:.*]] = load half, half* %[[VAL_191]], align 2, !invariant.load !5
+// CHECK: %[[VAL_192:.*]] = load half, half* %[[VAL_191]], align 2, !invariant.load !4
// CHECK: %[[VAL_193:.*]] = fmul half %[[VAL_190]], %[[VAL_192]]
// CHECK: %[[VAL_194:.*]] = insertvalue { half, half } %[[VAL_187]], half %[[VAL_193]], 1
// CHECK: %[[VAL_195:.*]] = extractvalue { half, half } %[[VAL_194]], 0