Refactor kernel thunk's launch dimension setting in EmitInputFusibleNonStridedSlices.

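Move the KernelThunk construction (the BuildKernelThunk call) down to just after the launch dimensions are calculated, so that SetThunkLaunchDimensions() is called right after the thunk is constructed. The launch dimensions will be passed to KernelThunk's constructor as a parameter in a later change.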

PiperOrigin-RevId: 387204323
Change-Id: I6e7ec07edc33a7ccc4d2c70cc0a707b89b9ae83f
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index af98790..d11ef12 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -5316,10 +5316,6 @@
 
   constexpr int unroll_factor = 1;
 
-  std::vector<llvm_ir::IrArray> ir_arrays;
-  TF_ASSIGN_OR_RETURN(auto kernel_thunk,
-                      BuildKernelThunk(fusion, GetThunkInfo(op), &ir_arrays));
-
   TF_ASSIGN_OR_RETURN(const HloComputation* fused_computation,
                       GetOrCreateSubComputationFromRegion(&fusion.region(),
                                                           /*is_fusion=*/true));
@@ -5330,6 +5326,10 @@
                       CalculateLaunchDimensions(
                           element_shape, ir_emitter_context_->gpu_device_info(),
                           {unroll_factor}));
+
+  std::vector<llvm_ir::IrArray> ir_arrays;
+  TF_ASSIGN_OR_RETURN(auto kernel_thunk,
+                      BuildKernelThunk(fusion, GetThunkInfo(op), &ir_arrays));
   SetThunkLaunchDimensions(launch_dimensions, kernel_thunk.get(),
                            ir_emitter_context_->llvm_module());
 
diff --git a/tensorflow/compiler/xla/service/gpu/tests/fused_slice_different_operands.hlo b/tensorflow/compiler/xla/service/gpu/tests/fused_slice_different_operands.hlo
index 987ab0f..551183b 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/fused_slice_different_operands.hlo
+++ b/tensorflow/compiler/xla/service/gpu/tests/fused_slice_different_operands.hlo
@@ -3,20 +3,6 @@
 // NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
 
 // CHECK-LABEL: entry:
-// CHECK:         %[[VAL_0:.*]] = getelementptr inbounds i8, i8* %[[VAL_1:.*]], i64 0
-// CHECK:         %[[VAL_2:.*]] = bitcast i8* %[[VAL_0]] to [1024 x half]*
-// CHECK:         %[[VAL_3:.*]] = getelementptr inbounds i8, i8* %[[VAL_4:.*]], i64 0
-// CHECK:         %[[VAL_5:.*]] = bitcast i8* %[[VAL_3]] to [1024 x half]*
-// CHECK:         %[[VAL_6:.*]] = getelementptr inbounds i8, i8* %[[VAL_7:.*]], i64 0
-// CHECK:         %[[VAL_8:.*]] = bitcast i8* %[[VAL_6]] to [1023 x half]*
-// CHECK:         %[[VAL_9:.*]] = getelementptr inbounds i8, i8* %[[VAL_10:.*]], i64 0
-// CHECK:         %[[VAL_11:.*]] = bitcast i8* %[[VAL_9]] to [1023 x half]*
-// CHECK:         %[[VAL_12:.*]] = getelementptr inbounds i8, i8* %[[VAL_13:.*]], i64 0
-// CHECK:         %[[VAL_14:.*]] = bitcast i8* %[[VAL_12]] to [1024 x half]*
-// CHECK:         %[[VAL_15:.*]] = getelementptr inbounds i8, i8* %[[VAL_16:.*]], i64 0
-// CHECK:         %[[VAL_17:.*]] = bitcast i8* %[[VAL_15]] to [1024 x half]*
-// CHECK:         ret void
-// CHECK:       entry:
 // CHECK:         %[[VAL_18:.*]] = getelementptr inbounds i8, i8* %[[VAL_19:.*]], i64 0
 // CHECK:         %[[VAL_20:.*]] = bitcast i8* %[[VAL_18]] to [1024 x half]*
 // CHECK:         %[[VAL_21:.*]] = getelementptr inbounds i8, i8* %[[VAL_22:.*]], i64 0
@@ -29,8 +15,8 @@
 // CHECK:         %[[VAL_32:.*]] = bitcast i8* %[[VAL_30]] to [1024 x half]*
 // CHECK:         %[[VAL_33:.*]] = getelementptr inbounds i8, i8* %[[VAL_34:.*]], i64 0
 // CHECK:         %[[VAL_35:.*]] = bitcast i8* %[[VAL_33]] to [1024 x half]*
-// CHECK:         %[[VAL_36:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !3
-// CHECK:         %[[VAL_37:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !4
+// CHECK:         %[[VAL_36:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !2
+// CHECK:         %[[VAL_37:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !3
 // CHECK:         %[[VAL_38:.*]] = mul nuw nsw i32 %[[VAL_36]], 256
 // CHECK:         %[[VAL_39:.*]] = add nuw nsw i32 %[[VAL_38]], %[[VAL_37]]
 // CHECK:         %[[VAL_40:.*]] = icmp ult i32 %[[VAL_39]], 256
@@ -54,18 +40,18 @@
 // CHECK:         %[[VAL_57:.*]] = phi i32 [ 0, %[[VAL_56]] ]
 // CHECK:         %[[VAL_58:.*]] = sub nsw i32 %[[VAL_54]], %[[VAL_57]]
 // CHECK:         %[[VAL_59:.*]] = getelementptr inbounds [1024 x half], [1024 x half]* %[[VAL_20]], i32 0, i32 %[[VAL_58]]
-// CHECK:         %[[VAL_60:.*]] = load half, half* %[[VAL_59]], align 2, !invariant.load !5
+// CHECK:         %[[VAL_60:.*]] = load half, half* %[[VAL_59]], align 2, !invariant.load !4
 // CHECK:         %[[VAL_61:.*]] = getelementptr inbounds [1024 x half], [1024 x half]* %[[VAL_23]], i32 0, i32 %[[VAL_58]]
-// CHECK:         %[[VAL_62:.*]] = load half, half* %[[VAL_61]], align 2, !invariant.load !5
+// CHECK:         %[[VAL_62:.*]] = load half, half* %[[VAL_61]], align 2, !invariant.load !4
 // CHECK:         %[[VAL_63:.*]] = fmul half %[[VAL_60]], %[[VAL_62]]
 // CHECK:         br label %[[VAL_64:.*]]
 // CHECK:       concat_index_from_operand_id1:                    ; preds = %[[VAL_65:.*]]
 // CHECK:         %[[VAL_66:.*]] = phi i32 [ 1024, %[[VAL_65]] ]
 // CHECK:         %[[VAL_67:.*]] = sub nsw i32 %[[VAL_54]], %[[VAL_66]]
 // CHECK:         %[[VAL_68:.*]] = getelementptr inbounds [1023 x half], [1023 x half]* %[[VAL_26]], i32 0, i32 %[[VAL_67]]
-// CHECK:         %[[VAL_69:.*]] = load half, half* %[[VAL_68]], align 2, !invariant.load !5
+// CHECK:         %[[VAL_69:.*]] = load half, half* %[[VAL_68]], align 2, !invariant.load !4
 // CHECK:         %[[VAL_70:.*]] = getelementptr inbounds [1023 x half], [1023 x half]* %[[VAL_29]], i32 0, i32 %[[VAL_67]]
-// CHECK:         %[[VAL_71:.*]] = load half, half* %[[VAL_70]], align 2, !invariant.load !5
+// CHECK:         %[[VAL_71:.*]] = load half, half* %[[VAL_70]], align 2, !invariant.load !4
 // CHECK:         %[[VAL_72:.*]] = fadd half %[[VAL_69]], %[[VAL_71]]
 // CHECK:         br label %[[VAL_64]]
 // CHECK:       concatenate.pivot.1024.:                          ; preds = %[[VAL_50]]
@@ -80,9 +66,9 @@
 // CHECK:         %[[VAL_77:.*]] = insertvalue { half, half } undef, half %[[VAL_76]], 0
 // CHECK:         %[[VAL_78:.*]] = add i32 %[[VAL_42]], 0
 // CHECK:         %[[VAL_79:.*]] = getelementptr inbounds [1024 x half], [1024 x half]* %[[VAL_20]], i32 0, i32 %[[VAL_78]]
-// CHECK:         %[[VAL_80:.*]] = load half, half* %[[VAL_79]], align 2, !invariant.load !5
+// CHECK:         %[[VAL_80:.*]] = load half, half* %[[VAL_79]], align 2, !invariant.load !4
 // CHECK:         %[[VAL_81:.*]] = getelementptr inbounds [1024 x half], [1024 x half]* %[[VAL_23]], i32 0, i32 %[[VAL_78]]
-// CHECK:         %[[VAL_82:.*]] = load half, half* %[[VAL_81]], align 2, !invariant.load !5
+// CHECK:         %[[VAL_82:.*]] = load half, half* %[[VAL_81]], align 2, !invariant.load !4
 // CHECK:         %[[VAL_83:.*]] = fmul half %[[VAL_80]], %[[VAL_82]]
 // CHECK:         %[[VAL_84:.*]] = insertvalue { half, half } %[[VAL_77]], half %[[VAL_83]], 1
 // CHECK:         %[[VAL_85:.*]] = extractvalue { half, half } %[[VAL_84]], 0
@@ -99,18 +85,18 @@
 // CHECK:         %[[VAL_94:.*]] = phi i32 [ 0, %[[VAL_93]] ]
 // CHECK:         %[[VAL_95:.*]] = sub nsw i32 %[[VAL_91]], %[[VAL_94]]
 // CHECK:         %[[VAL_96:.*]] = getelementptr inbounds [1024 x half], [1024 x half]* %[[VAL_20]], i32 0, i32 %[[VAL_95]]
-// CHECK:         %[[VAL_97:.*]] = load half, half* %[[VAL_96]], align 2, !invariant.load !5
+// CHECK:         %[[VAL_97:.*]] = load half, half* %[[VAL_96]], align 2, !invariant.load !4
 // CHECK:         %[[VAL_98:.*]] = getelementptr inbounds [1024 x half], [1024 x half]* %[[VAL_23]], i32 0, i32 %[[VAL_95]]
-// CHECK:         %[[VAL_99:.*]] = load half, half* %[[VAL_98]], align 2, !invariant.load !5
+// CHECK:         %[[VAL_99:.*]] = load half, half* %[[VAL_98]], align 2, !invariant.load !4
 // CHECK:         %[[VAL_100:.*]] = fmul half %[[VAL_97]], %[[VAL_99]]
 // CHECK:         br label %[[VAL_101:.*]]
 // CHECK:       concat_index_from_operand_id16:                   ; preds = %[[VAL_102:.*]]
 // CHECK:         %[[VAL_103:.*]] = phi i32 [ 1024, %[[VAL_102]] ]
 // CHECK:         %[[VAL_104:.*]] = sub nsw i32 %[[VAL_91]], %[[VAL_103]]
 // CHECK:         %[[VAL_105:.*]] = getelementptr inbounds [1023 x half], [1023 x half]* %[[VAL_26]], i32 0, i32 %[[VAL_104]]
-// CHECK:         %[[VAL_106:.*]] = load half, half* %[[VAL_105]], align 2, !invariant.load !5
+// CHECK:         %[[VAL_106:.*]] = load half, half* %[[VAL_105]], align 2, !invariant.load !4
 // CHECK:         %[[VAL_107:.*]] = getelementptr inbounds [1023 x half], [1023 x half]* %[[VAL_29]], i32 0, i32 %[[VAL_104]]
-// CHECK:         %[[VAL_108:.*]] = load half, half* %[[VAL_107]], align 2, !invariant.load !5
+// CHECK:         %[[VAL_108:.*]] = load half, half* %[[VAL_107]], align 2, !invariant.load !4
 // CHECK:         %[[VAL_109:.*]] = fadd half %[[VAL_106]], %[[VAL_108]]
 // CHECK:         br label %[[VAL_101]]
 // CHECK:       concatenate.pivot.1024.8:                         ; preds = %[[VAL_64]]
@@ -125,9 +111,9 @@
 // CHECK:         %[[VAL_114:.*]] = insertvalue { half, half } undef, half %[[VAL_113]], 0
 // CHECK:         %[[VAL_115:.*]] = add i32 %[[VAL_44]], 0
 // CHECK:         %[[VAL_116:.*]] = getelementptr inbounds [1024 x half], [1024 x half]* %[[VAL_20]], i32 0, i32 %[[VAL_115]]
-// CHECK:         %[[VAL_117:.*]] = load half, half* %[[VAL_116]], align 2, !invariant.load !5
+// CHECK:         %[[VAL_117:.*]] = load half, half* %[[VAL_116]], align 2, !invariant.load !4
 // CHECK:         %[[VAL_118:.*]] = getelementptr inbounds [1024 x half], [1024 x half]* %[[VAL_23]], i32 0, i32 %[[VAL_115]]
-// CHECK:         %[[VAL_119:.*]] = load half, half* %[[VAL_118]], align 2, !invariant.load !5
+// CHECK:         %[[VAL_119:.*]] = load half, half* %[[VAL_118]], align 2, !invariant.load !4
 // CHECK:         %[[VAL_120:.*]] = fmul half %[[VAL_117]], %[[VAL_119]]
 // CHECK:         %[[VAL_121:.*]] = insertvalue { half, half } %[[VAL_114]], half %[[VAL_120]], 1
 // CHECK:         %[[VAL_122:.*]] = extractvalue { half, half } %[[VAL_121]], 0
@@ -144,18 +130,18 @@
 // CHECK:         %[[VAL_131:.*]] = phi i32 [ 0, %[[VAL_130]] ]
 // CHECK:         %[[VAL_132:.*]] = sub nsw i32 %[[VAL_128]], %[[VAL_131]]
 // CHECK:         %[[VAL_133:.*]] = getelementptr inbounds [1024 x half], [1024 x half]* %[[VAL_20]], i32 0, i32 %[[VAL_132]]
-// CHECK:         %[[VAL_134:.*]] = load half, half* %[[VAL_133]], align 2, !invariant.load !5
+// CHECK:         %[[VAL_134:.*]] = load half, half* %[[VAL_133]], align 2, !invariant.load !4
 // CHECK:         %[[VAL_135:.*]] = getelementptr inbounds [1024 x half], [1024 x half]* %[[VAL_23]], i32 0, i32 %[[VAL_132]]
-// CHECK:         %[[VAL_136:.*]] = load half, half* %[[VAL_135]], align 2, !invariant.load !5
+// CHECK:         %[[VAL_136:.*]] = load half, half* %[[VAL_135]], align 2, !invariant.load !4
 // CHECK:         %[[VAL_137:.*]] = fmul half %[[VAL_134]], %[[VAL_136]]
 // CHECK:         br label %[[VAL_138:.*]]
 // CHECK:       concat_index_from_operand_id115:                  ; preds = %[[VAL_139:.*]]
 // CHECK:         %[[VAL_140:.*]] = phi i32 [ 1024, %[[VAL_139]] ]
 // CHECK:         %[[VAL_141:.*]] = sub nsw i32 %[[VAL_128]], %[[VAL_140]]
 // CHECK:         %[[VAL_142:.*]] = getelementptr inbounds [1023 x half], [1023 x half]* %[[VAL_26]], i32 0, i32 %[[VAL_141]]
-// CHECK:         %[[VAL_143:.*]] = load half, half* %[[VAL_142]], align 2, !invariant.load !5
+// CHECK:         %[[VAL_143:.*]] = load half, half* %[[VAL_142]], align 2, !invariant.load !4
 // CHECK:         %[[VAL_144:.*]] = getelementptr inbounds [1023 x half], [1023 x half]* %[[VAL_29]], i32 0, i32 %[[VAL_141]]
-// CHECK:         %[[VAL_145:.*]] = load half, half* %[[VAL_144]], align 2, !invariant.load !5
+// CHECK:         %[[VAL_145:.*]] = load half, half* %[[VAL_144]], align 2, !invariant.load !4
 // CHECK:         %[[VAL_146:.*]] = fadd half %[[VAL_143]], %[[VAL_145]]
 // CHECK:         br label %[[VAL_138]]
 // CHECK:       concatenate.pivot.1024.17:                        ; preds = %[[VAL_101]]
@@ -170,9 +156,9 @@
 // CHECK:         %[[VAL_151:.*]] = insertvalue { half, half } undef, half %[[VAL_150]], 0
 // CHECK:         %[[VAL_152:.*]] = add i32 %[[VAL_46]], 0
 // CHECK:         %[[VAL_153:.*]] = getelementptr inbounds [1024 x half], [1024 x half]* %[[VAL_20]], i32 0, i32 %[[VAL_152]]
-// CHECK:         %[[VAL_154:.*]] = load half, half* %[[VAL_153]], align 2, !invariant.load !5
+// CHECK:         %[[VAL_154:.*]] = load half, half* %[[VAL_153]], align 2, !invariant.load !4
 // CHECK:         %[[VAL_155:.*]] = getelementptr inbounds [1024 x half], [1024 x half]* %[[VAL_23]], i32 0, i32 %[[VAL_152]]
-// CHECK:         %[[VAL_156:.*]] = load half, half* %[[VAL_155]], align 2, !invariant.load !5
+// CHECK:         %[[VAL_156:.*]] = load half, half* %[[VAL_155]], align 2, !invariant.load !4
 // CHECK:         %[[VAL_157:.*]] = fmul half %[[VAL_154]], %[[VAL_156]]
 // CHECK:         %[[VAL_158:.*]] = insertvalue { half, half } %[[VAL_151]], half %[[VAL_157]], 1
 // CHECK:         %[[VAL_159:.*]] = extractvalue { half, half } %[[VAL_158]], 0
@@ -189,18 +175,18 @@
 // CHECK:         %[[VAL_168:.*]] = phi i32 [ 0, %[[VAL_167]] ]
 // CHECK:         %[[VAL_169:.*]] = sub nsw i32 %[[VAL_165]], %[[VAL_168]]
 // CHECK:         %[[VAL_170:.*]] = getelementptr inbounds [1024 x half], [1024 x half]* %[[VAL_20]], i32 0, i32 %[[VAL_169]]
-// CHECK:         %[[VAL_171:.*]] = load half, half* %[[VAL_170]], align 2, !invariant.load !5
+// CHECK:         %[[VAL_171:.*]] = load half, half* %[[VAL_170]], align 2, !invariant.load !4
 // CHECK:         %[[VAL_172:.*]] = getelementptr inbounds [1024 x half], [1024 x half]* %[[VAL_23]], i32 0, i32 %[[VAL_169]]
-// CHECK:         %[[VAL_173:.*]] = load half, half* %[[VAL_172]], align 2, !invariant.load !5
+// CHECK:         %[[VAL_173:.*]] = load half, half* %[[VAL_172]], align 2, !invariant.load !4
 // CHECK:         %[[VAL_174:.*]] = fmul half %[[VAL_171]], %[[VAL_173]]
 // CHECK:         br label %[[VAL_52]]
 // CHECK:       concat_index_from_operand_id124:                  ; preds = %[[VAL_175:.*]]
 // CHECK:         %[[VAL_176:.*]] = phi i32 [ 1024, %[[VAL_175]] ]
 // CHECK:         %[[VAL_177:.*]] = sub nsw i32 %[[VAL_165]], %[[VAL_176]]
 // CHECK:         %[[VAL_178:.*]] = getelementptr inbounds [1023 x half], [1023 x half]* %[[VAL_26]], i32 0, i32 %[[VAL_177]]
-// CHECK:         %[[VAL_179:.*]] = load half, half* %[[VAL_178]], align 2, !invariant.load !5
+// CHECK:         %[[VAL_179:.*]] = load half, half* %[[VAL_178]], align 2, !invariant.load !4
 // CHECK:         %[[VAL_180:.*]] = getelementptr inbounds [1023 x half], [1023 x half]* %[[VAL_29]], i32 0, i32 %[[VAL_177]]
-// CHECK:         %[[VAL_181:.*]] = load half, half* %[[VAL_180]], align 2, !invariant.load !5
+// CHECK:         %[[VAL_181:.*]] = load half, half* %[[VAL_180]], align 2, !invariant.load !4
 // CHECK:         %[[VAL_182:.*]] = fadd half %[[VAL_179]], %[[VAL_181]]
 // CHECK:         br label %[[VAL_52]]
 // CHECK:       concatenate.pivot.1024.26:                        ; preds = %[[VAL_138]]
@@ -215,9 +201,9 @@
 // CHECK:         %[[VAL_187:.*]] = insertvalue { half, half } undef, half %[[VAL_186]], 0
 // CHECK:         %[[VAL_188:.*]] = add i32 %[[VAL_48]], 0
 // CHECK:         %[[VAL_189:.*]] = getelementptr inbounds [1024 x half], [1024 x half]* %[[VAL_20]], i32 0, i32 %[[VAL_188]]
-// CHECK:         %[[VAL_190:.*]] = load half, half* %[[VAL_189]], align 2, !invariant.load !5
+// CHECK:         %[[VAL_190:.*]] = load half, half* %[[VAL_189]], align 2, !invariant.load !4
 // CHECK:         %[[VAL_191:.*]] = getelementptr inbounds [1024 x half], [1024 x half]* %[[VAL_23]], i32 0, i32 %[[VAL_188]]
-// CHECK:         %[[VAL_192:.*]] = load half, half* %[[VAL_191]], align 2, !invariant.load !5
+// CHECK:         %[[VAL_192:.*]] = load half, half* %[[VAL_191]], align 2, !invariant.load !4
 // CHECK:         %[[VAL_193:.*]] = fmul half %[[VAL_190]], %[[VAL_192]]
 // CHECK:         %[[VAL_194:.*]] = insertvalue { half, half } %[[VAL_187]], half %[[VAL_193]], 1
 // CHECK:         %[[VAL_195:.*]] = extractvalue { half, half } %[[VAL_194]], 0