[XLA:GPU] Do not unroll the reduction X-tile in the emitters

Dramatically reduces compilation time for large reductions: instead of the
emitter producing x_num_steps copies of the loop body, the X-tile loop is now
emitted as a single rolled loop tagged llvm_ir::UnrollMode::kFullyUnroll, and
LLVM's unroller expands it during optimization.

On BERT, backend compilation time goes from 2.79min to 1.86min (a 1.5x speedup).
IR emission time goes from 12s to ~2ms, and LLVM verifier time from ~4s to <1ms.
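
As a standalone illustration of the idea (not part of this change; the `tile`
function and all names below are hypothetical), here is a self-contained LLVM
C++ sketch that emits one rolled loop and attaches llvm.loop.unroll.full loop
metadata, the unroll hint that kFullyUnroll corresponds to:

```
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Verifier.h"
#include "llvm/Support/raw_ostream.h"

int main() {
  llvm::LLVMContext ctx;
  llvm::Module mod("tile_demo", ctx);
  llvm::IRBuilder<> b(ctx);

  // void tile(float* out): out[j] = (float)j for j in [0, 8).
  auto* fty = llvm::FunctionType::get(
      b.getVoidTy(), {b.getFloatTy()->getPointerTo()}, /*isVarArg=*/false);
  auto* fn = llvm::Function::Create(
      fty, llvm::Function::ExternalLinkage, "tile", mod);
  auto* entry = llvm::BasicBlock::Create(ctx, "entry", fn);
  auto* header = llvm::BasicBlock::Create(ctx, "tile_loop.header", fn);
  auto* body = llvm::BasicBlock::Create(ctx, "tile_loop.body", fn);
  auto* exit_bb = llvm::BasicBlock::Create(ctx, "tile_loop.exit", fn);

  b.SetInsertPoint(entry);
  b.CreateBr(header);

  // One rolled loop: a phi for j plus a single body, instead of 8 unrolled
  // copies of the body in the emitted IR.
  b.SetInsertPoint(header);
  llvm::PHINode* j = b.CreatePHI(b.getInt64Ty(), 2, "j");
  j->addIncoming(b.getInt64(0), entry);
  b.CreateCondBr(b.CreateICmpULT(j, b.getInt64(8)), body, exit_bb);

  b.SetInsertPoint(body);
  llvm::Value* slot = b.CreateGEP(b.getFloatTy(), fn->getArg(0), j);
  b.CreateStore(b.CreateUIToFP(j, b.getFloatTy()), slot);
  j->addIncoming(b.CreateAdd(j, b.getInt64(1)), body);
  llvm::BranchInst* latch = b.CreateBr(header);

  // Ask LLVM (rather than the emitter) to fully unroll: attach a
  // self-referential !llvm.loop node carrying "llvm.loop.unroll.full".
  llvm::MDNode* unroll = llvm::MDNode::get(
      ctx, {llvm::MDString::get(ctx, "llvm.loop.unroll.full")});
  llvm::MDNode* loop_id = llvm::MDNode::getDistinct(ctx, {nullptr, unroll});
  loop_id->replaceOperandWith(0, loop_id);
  latch->setMetadata(llvm::LLVMContext::MD_loop, loop_id);

  b.SetInsertPoint(exit_bb);
  b.CreateRetVoid();

  llvm::verifyModule(mod, &llvm::errs());
  mod.print(llvm::outs(), nullptr);
}
```

LLVM still unrolls the loop during optimization, which is why the optimized IR
(and runtime) stays essentially unchanged while IR emission and verification
get much cheaper.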

Similar timings are seen for micro-benchmarks:

For the following reduction:

```
HloModule module

scalar_add_computation {
  scalar_lhs = f32[] parameter(0)
  scalar_rhs = f32[] parameter(1)
  ROOT add.15 = f32[] add(scalar_lhs, scalar_rhs)
}

ENTRY computation {
  t = f32[70000,8192] parameter(0)
  zero = f32[] constant(0)
  ROOT out = reduce(t, zero), dimensions={1}, to_apply=scalar_add_computation
}
```

Before: 163ms LLVM compilation time
After: 121ms LLVM compilation time

Runtime: ~3302us before, ~3304us after

After LLVM optimizations, the IR for individual reductions is nearly
identical, so the runtime difference is likely just noise.
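
One consequence visible throughout the diff: the X-iteration index is no
longer a compile-time constant, so the element callback takes it as a runtime
value, and GenerateElementForReducer indexes the partial-result buffer with it
directly:

```
// New callback signature (from ir_emitter_unnested.h below); x_iter_num
// used to be an int64_t known at emission time.
using EmitElementFunction = std::function<void(
    const llvm_ir::IrArray::Index& index, llvm::Value* y_loc,
    llvm::Value* x_loc, llvm::Value* x_iter_num)>;
```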

PiperOrigin-RevId: 401825158
Change-Id: I490bd05936271a81a712a2940778414f53aa1bcf
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index 72d3d12..f4ba1ba 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -3593,24 +3593,35 @@
     return llvm::ConstantInt::get(index_ty, val);
   };
   IrArray::Index source_idx_x_base = source_idx.AddOffsetToDim(y_loc, kDimY, b);
-  for (int64_t j = 0; j < x_num_steps / vector_size; j++) {
-    for (int64_t i = 0; i < vector_size; i++) {
-      int64_t linear_index = j * vector_size + i;
-      llvm::Value* x_loc = b->CreateAdd(constant(j * step_x * vector_size + i),
-                                        start_offset_x, "x_loc");
-      IrArray::Index source_idx_x = source_idx_x_base.AddOffsetToDim(
-          constant(j * step_x * vector_size + i), kDimX, b);
-      auto emit_element = [&] {
-        return (*emit_elem_function)(source_idx_x, y_loc, x_loc, linear_index);
-      };
-      if (check_x_tile_bounds) {
-        ksl->If(loop_name + "_x_in_tile", b->CreateICmpULT(x_loc, tile_width),
-                emit_element);
-      } else {
-        emit_element();
-      }
-    }
-  }
+  KernelSupportLibrary unrolled_ksl(b, llvm_ir::UnrollMode::kFullyUnroll);
+  unrolled_ksl.For(
+      "tile_loop",
+      /*start=*/constant(0),
+      /*end=*/constant(x_num_steps / vector_size),
+      /*step=*/1, [&](llvm::Value* j) {
+        for (int64_t i = 0; i < vector_size; i++) {
+          llvm::Value* linear_index =
+              b->CreateAdd(b->CreateMul(j, constant(vector_size)), constant(i));
+          llvm::Value* x_loc = b->CreateAdd(
+              b->CreateAdd(b->CreateMul(j, constant(step_x * vector_size)),
+                           constant(i)),
+              start_offset_x, "x_loc");
+          IrArray::Index source_idx_x = source_idx_x_base.AddOffsetToDim(
+              b->CreateAdd(b->CreateMul(j, constant(step_x * vector_size)),
+                           constant(i)),
+              kDimX, b);
+          auto emit_element = [&] {
+            return (*emit_elem_function)(source_idx_x, y_loc, x_loc,
+                                         linear_index);
+          };
+          if (check_x_tile_bounds) {
+            ksl->If(loop_name + "_x_in_tile",
+                    b->CreateICmpULT(x_loc, tile_width), emit_element);
+          } else {
+            emit_element();
+          }
+        }
+      });
 }
 
 void IrEmitterUnnested::EmitTile(
@@ -4472,7 +4483,7 @@
 
   EmitElementFunction element_generator =
       [&](const llvm_ir::IrArray::Index& index, llvm::Value* y_loc,
-          llvm::Value* x_loc, int64_t x_iter_num) {
+          llvm::Value* x_loc, llvm::Value* x_iter_num) {
         auto fusion = mlir::dyn_cast<mlir::lmhlo::FusionOp>(op);
         EmitTileElementForFusion(fusion, operand_arrays, output_arrays, index,
                                  tiling_scheme, y_loc, x_loc,
@@ -4499,7 +4510,7 @@
           EmitTile(tiling_scheme, input_tile_origin, "input", ksl,
                    thread_id_info, tile_width, tile_height,
                    [&](const IrArray::Index& index, llvm::Value* y_loc,
-                       llvm::Value* x_loc, int64_t /*x_iter_num*/) {
+                       llvm::Value* x_loc, llvm::Value* /*x_iter_num*/) {
                      for (int64_t id : tiled_param_ids) {
                        IrArray& input_in_logical_shape =
                            param_in_reduced_shape_arrays.at(id);
@@ -4963,7 +4974,7 @@
 // Generate a single element of the tile (update the accumulator state) for a
 // given reducer of index `i`.
 void IrEmitterUnnested::GenerateElementForReducer(
-    int i, int partial_result_index, const HloComputation* reducer,
+    int i, llvm::Value* partial_result_index, const HloComputation* reducer,
     const ReductionCodegenState& codegen_state,
     const llvm_ir::IrArray::Index& index_without_linear,
     const IrArray::Index& input_index, int num_partial_results,
@@ -4983,7 +4994,7 @@
         num_partial_results > 1 ? index_without_linear : input_index);
     b_.CreateStore(input_ir_value, input_address);
     llvm::Value* partial_result_address = b_.CreateInBoundsGEP(
-        partial_reduction_result_address, {b_.getInt32(partial_result_index)});
+        partial_reduction_result_address, {partial_result_index});
     reduction_accumulators.push_back(partial_result_address);
     reduction_input_value.push_back(input_address);
   }
@@ -5045,12 +5056,12 @@
 
   EmitElementFunction emit_reduction_element =
       [&](const llvm_ir::IrArray::Index& index, llvm::Value* y_loc,
-          llvm::Value* x_loc, int64_t x_iter_num) {
+          llvm::Value* x_loc, llvm::Value* x_iter_num) {
         IrArray::Index input_index = GetUnnormalizedIndex(
             index, input_shape, &b_, codegen_state.GetTilingScheme());
 
-        int partial_result_index =
-            codegen_state.IsRowReduction() ? 0 : x_iter_num;
+        llvm::Value* partial_result_index =
+            codegen_state.IsRowReduction() ? b_.getInt32(0) : x_iter_num;
 
         // Clear the linear index field of the IrArray::Index to enable the use
         // of GetElementPointer with array types. This enables the vectorization
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
index 7eeeb35..0897351 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
@@ -104,7 +104,7 @@
  //             has a value of 0..N-1 to identify the element being processed.
   using EmitElementFunction = std::function<void(
       const llvm_ir::IrArray::Index& index, llvm::Value* y_loc,
-      llvm::Value* x_loc, int64_t x_iter_num)>;
+      llvm::Value* x_loc, llvm::Value* x_iter_num)>;
 
   using ConstantGenerator = std::function<llvm::Value*(int64_t)>;
 
@@ -546,7 +546,7 @@
   // Generate a single element of the tile (update the accumulator state) for a
   // given reducer of index `i`.
   void GenerateElementForReducer(
-      int i, int partial_result_index, const HloComputation* reducer,
+      int i, llvm::Value* partial_result_index, const HloComputation* reducer,
       const ReductionCodegenState& codegen_state,
       const llvm_ir::IrArray::Index& index_without_linear,
       const llvm_ir::IrArray::Index& input_index, int num_partial_results,
diff --git a/tensorflow/compiler/xla/service/gpu/tests/copy.hlo b/tensorflow/compiler/xla/service/gpu/tests/copy.hlo
index b71915f..ac2c599 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/copy.hlo
+++ b/tensorflow/compiler/xla/service/gpu/tests/copy.hlo
@@ -2,85 +2,129 @@
 
 // NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
 
+// The script is designed to make adding checks to
+// a test case fast, it is *not* designed to be authoritative
+// about what constitutes a good test! The CHECK should be
+// minimized and named to reflect the test intent.
+
+
 // CHECK-LABEL: entry:
 // CHECK:         %[[VAL_0:.*]] = alloca i32, align 4
 // CHECK:         %[[VAL_1:.*]] = alloca i32, align 4
-// CHECK:         %[[VAL_2:.*]] = getelementptr inbounds i8, i8* %[[VAL_3:.*]], i64 0
-// CHECK:         %[[VAL_4:.*]] = bitcast i8* %[[VAL_2]] to [100 x [200 x float]]*
-// CHECK:         %[[VAL_5:.*]] = getelementptr inbounds i8, i8* %[[VAL_6:.*]], i64 0
-// CHECK:         %[[VAL_7:.*]] = bitcast i8* %[[VAL_5]] to [200 x [100 x float]]*
-// CHECK:         %[[VAL_8:.*]] = bitcast [100 x [200 x float]]* %[[VAL_4]] to [1 x [100 x [200 x float]]]*
-// CHECK:         %[[VAL_9:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !2
-// CHECK:         %[[VAL_10:.*]] = urem i32 %[[VAL_9]], 32
-// CHECK:         %[[VAL_11:.*]] = udiv i32 %[[VAL_9]], 32
-// CHECK:         %[[VAL_12:.*]] = urem i32 %[[VAL_9]], 32
-// CHECK:         %[[VAL_13:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !3
-// CHECK:         %[[VAL_14:.*]] = udiv i32 %[[VAL_13]], 1
-// CHECK:         %[[VAL_15:.*]] = urem i32 %[[VAL_14]], 4
-// CHECK:         %[[VAL_16:.*]] = udiv i32 %[[VAL_13]], 4
-// CHECK:         %[[VAL_17:.*]] = urem i32 %[[VAL_16]], 7
-// CHECK:         %[[VAL_18:.*]] = udiv i32 %[[VAL_13]], 28
-// CHECK:         %[[VAL_19:.*]] = mul i32 %[[VAL_18]], 1
-// CHECK:         %[[VAL_20:.*]] = icmp eq i32 %[[VAL_17]], 6
-// CHECK:         %[[VAL_21:.*]] = select i1 %[[VAL_20]], i32 8, i32 32
-// CHECK:         %[[VAL_22:.*]] = icmp eq i32 %[[VAL_15]], 3
-// CHECK:         %[[VAL_23:.*]] = select i1 %[[VAL_22]], i32 4, i32 32
-// CHECK:         %[[VAL_24:.*]] = mul i32 %[[VAL_17]], 32
-// CHECK:         %[[VAL_25:.*]] = mul i32 %[[VAL_15]], 32
-// CHECK:         %[[VAL_26:.*]] = mul i32 %[[VAL_10]], 1
-// CHECK:         %[[VAL_27:.*]] = add i32 %[[VAL_24]], %[[VAL_26]]
-// CHECK:         store i32 %[[VAL_11]], i32* %[[VAL_1]], align 4
-// CHECK:         br label %[[VAL_28:.*]]
-// CHECK:       input_y_in_tile.loop_header:                      ; preds = %[[VAL_29:.*]], %[[VAL_30:.*]]
-// CHECK:         %[[VAL_31:.*]] = load i32, i32* %[[VAL_1]], align 4
-// CHECK:         %[[VAL_32:.*]] = icmp uge i32 %[[VAL_31]], %[[VAL_23]]
-// CHECK:         br i1 %[[VAL_32]], label %[[VAL_33:.*]], label %[[VAL_34:.*]]
-// CHECK:       input_y_in_tile.loop_body:                        ; preds = %[[VAL_28]]
-// CHECK:         %[[VAL_35:.*]] = add nuw nsw i32 %[[VAL_31]], 4
-// CHECK:         store i32 %[[VAL_35]], i32* %[[VAL_1]], align 4
-// CHECK:         %[[VAL_36:.*]] = icmp eq i32 %[[VAL_31]], %[[VAL_11]]
-// CHECK:         %[[VAL_37:.*]] = add i32 %[[VAL_25]], %[[VAL_31]]
-// CHECK:         %[[VAL_38:.*]] = add i32 0, %[[VAL_26]]
-// CHECK:         %[[VAL_39:.*]] = add i32 %[[VAL_27]], 0
-// CHECK:         %[[VAL_40:.*]] = icmp ult i32 %[[VAL_38]], %[[VAL_21]]
-// CHECK:         br i1 %[[VAL_40]], label %[[VAL_41:.*]], label %[[VAL_29]]
-// CHECK:       input_x_in_tile-after:                            ; preds = %[[VAL_41]], %[[VAL_34]]
-// CHECK:         br label %[[VAL_28]], !llvm.loop !4
-// CHECK:       input_y_in_tile.loop_exit:                        ; preds = %[[VAL_28]]
+// CHECK:         %[[VAL_2:.*]] = alloca i32, align 4
+// CHECK:         %[[VAL_3:.*]] = alloca i32, align 4
+// CHECK:         %[[VAL_4:.*]] = getelementptr inbounds i8, i8* %[[VAL_5:.*]], i64 0
+// CHECK:         %[[VAL_6:.*]] = bitcast i8* %[[VAL_4]] to [100 x [200 x float]]*
+// CHECK:         %[[VAL_7:.*]] = getelementptr inbounds i8, i8* %[[VAL_8:.*]], i64 0
+// CHECK:         %[[VAL_9:.*]] = bitcast i8* %[[VAL_7]] to [200 x [100 x float]]*
+// CHECK:         %[[VAL_10:.*]] = bitcast [100 x [200 x float]]* %[[VAL_6]] to [1 x [100 x [200 x float]]]*
+// CHECK:         %[[VAL_11:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !2
+// CHECK:         %[[VAL_12:.*]] = urem i32 %[[VAL_11]], 32
+// CHECK:         %[[VAL_13:.*]] = udiv i32 %[[VAL_11]], 32
+// CHECK:         %[[VAL_14:.*]] = urem i32 %[[VAL_11]], 32
+// CHECK:         %[[VAL_15:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !3
+// CHECK:         %[[VAL_16:.*]] = udiv i32 %[[VAL_15]], 1
+// CHECK:         %[[VAL_17:.*]] = urem i32 %[[VAL_16]], 4
+// CHECK:         %[[VAL_18:.*]] = udiv i32 %[[VAL_15]], 4
+// CHECK:         %[[VAL_19:.*]] = urem i32 %[[VAL_18]], 7
+// CHECK:         %[[VAL_20:.*]] = udiv i32 %[[VAL_15]], 28
+// CHECK:         %[[VAL_21:.*]] = mul i32 %[[VAL_20]], 1
+// CHECK:         %[[VAL_22:.*]] = icmp eq i32 %[[VAL_19]], 6
+// CHECK:         %[[VAL_23:.*]] = select i1 %[[VAL_22]], i32 8, i32 32
+// CHECK:         %[[VAL_24:.*]] = icmp eq i32 %[[VAL_17]], 3
+// CHECK:         %[[VAL_25:.*]] = select i1 %[[VAL_24]], i32 4, i32 32
+// CHECK:         %[[VAL_26:.*]] = mul i32 %[[VAL_19]], 32
+// CHECK:         %[[VAL_27:.*]] = mul i32 %[[VAL_17]], 32
+// CHECK:         %[[VAL_28:.*]] = mul i32 %[[VAL_12]], 1
+// CHECK:         %[[VAL_29:.*]] = add i32 %[[VAL_26]], %[[VAL_28]]
+// CHECK:         store i32 %[[VAL_13]], i32* %[[VAL_3]], align 4
+// CHECK:         br label %[[VAL_30:.*]]
+// CHECK:       input_y_in_tile.loop_header:                      ; preds = %[[VAL_31:.*]], %[[VAL_32:.*]]
+// CHECK:         %[[VAL_33:.*]] = load i32, i32* %[[VAL_3]], align 4
+// CHECK:         %[[VAL_34:.*]] = icmp uge i32 %[[VAL_33]], %[[VAL_25]]
+// CHECK:         br i1 %[[VAL_34]], label %[[VAL_35:.*]], label %[[VAL_36:.*]]
+// CHECK:       input_y_in_tile.loop_body:                        ; preds = %[[VAL_30]]
+// CHECK:         %[[VAL_37:.*]] = add nuw nsw i32 %[[VAL_33]], 4
+// CHECK:         store i32 %[[VAL_37]], i32* %[[VAL_3]], align 4
+// CHECK:         %[[VAL_38:.*]] = icmp eq i32 %[[VAL_33]], %[[VAL_13]]
+// CHECK:         %[[VAL_39:.*]] = add i32 %[[VAL_27]], %[[VAL_33]]
+// CHECK:         store i32 0, i32* %[[VAL_2]], align 4
+// CHECK:         br label %[[VAL_40:.*]]
+// CHECK:       tile_loop.loop_header:                            ; preds = %[[VAL_41:.*]], %[[VAL_36]]
+// CHECK:         %[[VAL_42:.*]] = load i32, i32* %[[VAL_2]], align 4
+// CHECK:         %[[VAL_43:.*]] = icmp uge i32 %[[VAL_42]], 1
+// CHECK:         br i1 %[[VAL_43]], label %[[VAL_31]], label %[[VAL_44:.*]]
+// CHECK:       tile_loop.loop_body:                              ; preds = %[[VAL_40]]
+// CHECK:         %[[VAL_45:.*]] = add nuw nsw i32 %[[VAL_42]], 1
+// CHECK:         store i32 %[[VAL_45]], i32* %[[VAL_2]], align 4
+// CHECK:         %[[VAL_46:.*]] = icmp eq i32 %[[VAL_42]], 0
+// CHECK:         %[[VAL_47:.*]] = mul i32 %[[VAL_42]], 1
+// CHECK:         %[[VAL_48:.*]] = add i32 %[[VAL_47]], 0
+// CHECK:         %[[VAL_49:.*]] = mul i32 %[[VAL_42]], 1
+// CHECK:         %[[VAL_50:.*]] = add i32 %[[VAL_49]], 0
+// CHECK:         %[[VAL_51:.*]] = add i32 %[[VAL_50]], %[[VAL_28]]
+// CHECK:         %[[VAL_52:.*]] = mul i32 %[[VAL_42]], 1
+// CHECK:         %[[VAL_53:.*]] = add i32 %[[VAL_52]], 0
+// CHECK:         %[[VAL_54:.*]] = add i32 %[[VAL_29]], %[[VAL_53]]
+// CHECK:         %[[VAL_55:.*]] = icmp ult i32 %[[VAL_51]], %[[VAL_23]]
+// CHECK:         br i1 %[[VAL_55]], label %[[VAL_56:.*]], label %[[VAL_41]]
+// CHECK:       input_x_in_tile-after:                            ; preds = %[[VAL_56]], %[[VAL_44]]
+// CHECK:         br label %[[VAL_40]], !llvm.loop !4
+// CHECK:       tile_loop.loop_exit:                              ; preds = %[[VAL_40]]
+// CHECK:         br label %[[VAL_30]], !llvm.loop !7
+// CHECK:       input_y_in_tile.loop_exit:                        ; preds = %[[VAL_30]]
 // CHECK:         call void @llvm.nvvm.barrier0()
-// CHECK:         %[[VAL_42:.*]] = mul i32 %[[VAL_10]], 1
-// CHECK:         %[[VAL_43:.*]] = add i32 %[[VAL_25]], %[[VAL_42]]
-// CHECK:         store i32 %[[VAL_11]], i32* %[[VAL_0]], align 4
-// CHECK:         br label %[[VAL_44:.*]]
-// CHECK:       output_y_in_tile.loop_header:                     ; preds = %[[VAL_45:.*]], %[[VAL_33]]
-// CHECK:         %[[VAL_46:.*]] = load i32, i32* %[[VAL_0]], align 4
-// CHECK:         %[[VAL_47:.*]] = icmp uge i32 %[[VAL_46]], %[[VAL_21]]
-// CHECK:         br i1 %[[VAL_47]], label %[[VAL_48:.*]], label %[[VAL_49:.*]]
-// CHECK:       output_y_in_tile.loop_body:                       ; preds = %[[VAL_44]]
-// CHECK:         %[[VAL_50:.*]] = add nuw nsw i32 %[[VAL_46]], 4
-// CHECK:         store i32 %[[VAL_50]], i32* %[[VAL_0]], align 4
-// CHECK:         %[[VAL_51:.*]] = icmp eq i32 %[[VAL_46]], %[[VAL_11]]
-// CHECK:         %[[VAL_52:.*]] = add i32 %[[VAL_24]], %[[VAL_46]]
-// CHECK:         %[[VAL_53:.*]] = add i32 0, %[[VAL_42]]
-// CHECK:         %[[VAL_54:.*]] = add i32 %[[VAL_43]], 0
-// CHECK:         %[[VAL_55:.*]] = icmp ult i32 %[[VAL_53]], %[[VAL_23]]
-// CHECK:         br i1 %[[VAL_55]], label %[[VAL_56:.*]], label %[[VAL_45]]
-// CHECK:       output_x_in_tile-after:                           ; preds = %[[VAL_56]], %[[VAL_49]]
-// CHECK:         br label %[[VAL_44]], !llvm.loop !6
-// CHECK:       output_y_in_tile.loop_exit:                       ; preds = %[[VAL_44]]
+// CHECK:         %[[VAL_57:.*]] = mul i32 %[[VAL_12]], 1
+// CHECK:         %[[VAL_58:.*]] = add i32 %[[VAL_27]], %[[VAL_57]]
+// CHECK:         store i32 %[[VAL_13]], i32* %[[VAL_1]], align 4
+// CHECK:         br label %[[VAL_59:.*]]
+// CHECK:       output_y_in_tile.loop_header:                     ; preds = %[[VAL_60:.*]], %[[VAL_35]]
+// CHECK:         %[[VAL_61:.*]] = load i32, i32* %[[VAL_1]], align 4
+// CHECK:         %[[VAL_62:.*]] = icmp uge i32 %[[VAL_61]], %[[VAL_23]]
+// CHECK:         br i1 %[[VAL_62]], label %[[VAL_63:.*]], label %[[VAL_64:.*]]
+// CHECK:       output_y_in_tile.loop_body:                       ; preds = %[[VAL_59]]
+// CHECK:         %[[VAL_65:.*]] = add nuw nsw i32 %[[VAL_61]], 4
+// CHECK:         store i32 %[[VAL_65]], i32* %[[VAL_1]], align 4
+// CHECK:         %[[VAL_66:.*]] = icmp eq i32 %[[VAL_61]], %[[VAL_13]]
+// CHECK:         %[[VAL_67:.*]] = add i32 %[[VAL_26]], %[[VAL_61]]
+// CHECK:         store i32 0, i32* %[[VAL_0]], align 4
+// CHECK:         br label %[[VAL_68:.*]]
+// CHECK:       tile_loop.loop_header6:                           ; preds = %[[VAL_69:.*]], %[[VAL_64]]
+// CHECK:         %[[VAL_70:.*]] = load i32, i32* %[[VAL_0]], align 4
+// CHECK:         %[[VAL_71:.*]] = icmp uge i32 %[[VAL_70]], 1
+// CHECK:         br i1 %[[VAL_71]], label %[[VAL_60]], label %[[VAL_72:.*]]
+// CHECK:       tile_loop.loop_body7:                             ; preds = %[[VAL_68]]
+// CHECK:         %[[VAL_73:.*]] = add nuw nsw i32 %[[VAL_70]], 1
+// CHECK:         store i32 %[[VAL_73]], i32* %[[VAL_0]], align 4
+// CHECK:         %[[VAL_74:.*]] = icmp eq i32 %[[VAL_70]], 0
+// CHECK:         %[[VAL_75:.*]] = mul i32 %[[VAL_70]], 1
+// CHECK:         %[[VAL_76:.*]] = add i32 %[[VAL_75]], 0
+// CHECK:         %[[VAL_77:.*]] = mul i32 %[[VAL_70]], 1
+// CHECK:         %[[VAL_78:.*]] = add i32 %[[VAL_77]], 0
+// CHECK:         %[[VAL_79:.*]] = add i32 %[[VAL_78]], %[[VAL_57]]
+// CHECK:         %[[VAL_80:.*]] = mul i32 %[[VAL_70]], 1
+// CHECK:         %[[VAL_81:.*]] = add i32 %[[VAL_80]], 0
+// CHECK:         %[[VAL_82:.*]] = add i32 %[[VAL_58]], %[[VAL_81]]
+// CHECK:         %[[VAL_83:.*]] = icmp ult i32 %[[VAL_79]], %[[VAL_25]]
+// CHECK:         br i1 %[[VAL_83]], label %[[VAL_84:.*]], label %[[VAL_69]]
+// CHECK:       output_x_in_tile-after:                           ; preds = %[[VAL_84]], %[[VAL_72]]
+// CHECK:         br label %[[VAL_68]], !llvm.loop !8
+// CHECK:       tile_loop.loop_exit5:                             ; preds = %[[VAL_68]]
+// CHECK:         br label %[[VAL_59]], !llvm.loop !9
+// CHECK:       output_y_in_tile.loop_exit:                       ; preds = %[[VAL_59]]
 // CHECK:         ret void
-// CHECK:       input_x_in_tile-true:                             ; preds = %[[VAL_34]]
-// CHECK:         %[[VAL_57:.*]] = getelementptr inbounds [1 x [100 x [200 x float]]], [1 x [100 x [200 x float]]]* %[[VAL_8]], i32 0, i32 0, i32 %[[VAL_37]], i32 %[[VAL_39]]
-// CHECK:         %[[VAL_58:.*]] = load float, float* %[[VAL_57]], align 4, !invariant.load !7
-// CHECK:         %[[VAL_59:.*]] = getelementptr [32 x [33 x float]], [32 x [33 x float]] addrspace(3)* @b.tile0, i32 0, i32 %[[VAL_31]], i32 %[[VAL_38]]
-// CHECK:         store float %[[VAL_58]], float addrspace(3)* %[[VAL_59]], align 4
-// CHECK:         br label %[[VAL_29]]
-// CHECK:       output_x_in_tile-true:                            ; preds = %[[VAL_49]]
-// CHECK:         %[[VAL_60:.*]] = getelementptr [32 x [33 x float]], [32 x [33 x float]] addrspace(3)* @b.tile0, i32 0, i32 %[[VAL_53]], i32 %[[VAL_46]]
-// CHECK:         %[[VAL_61:.*]] = load float, float addrspace(3)* %[[VAL_60]], align 4
-// CHECK:         %[[VAL_62:.*]] = getelementptr inbounds [200 x [100 x float]], [200 x [100 x float]]* %[[VAL_7]], i32 0, i32 %[[VAL_52]], i32 %[[VAL_54]]
-// CHECK:         store float %[[VAL_61]], float* %[[VAL_62]], align 4
-// CHECK:         br label %[[VAL_45]]
+// CHECK:       input_x_in_tile-true:                             ; preds = %[[VAL_44]]
+// CHECK:         %[[VAL_85:.*]] = getelementptr inbounds [1 x [100 x [200 x float]]], [1 x [100 x [200 x float]]]* %[[VAL_10]], i32 0, i32 0, i32 %[[VAL_39]], i32 %[[VAL_54]]
+// CHECK:         %[[VAL_86:.*]] = load float, float* %[[VAL_85]], align 4, !invariant.load !10
+// CHECK:         %[[VAL_87:.*]] = getelementptr [32 x [33 x float]], [32 x [33 x float]] addrspace(3)* @b.tile0, i32 0, i32 %[[VAL_33]], i32 %[[VAL_51]]
+// CHECK:         store float %[[VAL_86]], float addrspace(3)* %[[VAL_87]], align 4
+// CHECK:         br label %[[VAL_41]]
+// CHECK:       output_x_in_tile-true:                            ; preds = %[[VAL_72]]
+// CHECK:         %[[VAL_88:.*]] = getelementptr [32 x [33 x float]], [32 x [33 x float]] addrspace(3)* @b.tile0, i32 0, i32 %[[VAL_79]], i32 %[[VAL_61]]
+// CHECK:         %[[VAL_89:.*]] = load float, float addrspace(3)* %[[VAL_88]], align 4
+// CHECK:         %[[VAL_90:.*]] = getelementptr inbounds [200 x [100 x float]], [200 x [100 x float]]* %[[VAL_9]], i32 0, i32 %[[VAL_67]], i32 %[[VAL_82]]
+// CHECK:         store float %[[VAL_89]], float* %[[VAL_90]], align 4
+// CHECK:         br label %[[VAL_69]]
 
 HloModule Test