[XLA GPU] [NFC] Tiling codegen: minor simplifications and renames, add more comments.

PiperOrigin-RevId: 270153908
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index 87f7575..4eed95b 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -1854,25 +1854,6 @@
   return emit_status;
 }
 
-namespace {
-
-std::tuple<llvm::Value*, int64> GetStartOffsetAndStepForX(
-    const KernelMappingScheme& mapping_scheme, llvm::IRBuilder<>* b,
-    llvm::Value* x, const IrEmitterUnnested::ConstantGenerator& constant) {
-  llvm::Value* start_offset_x;
-  int64 step_x;
-  if (mapping_scheme.DilatedX()) {
-    start_offset_x = x;
-    step_x = mapping_scheme.GetNumberOfThreadsForDimensionX();
-  } else {
-    start_offset_x = b->CreateMul(
-        x, constant(mapping_scheme.GetTileSizeForDimensionX() /
-                    mapping_scheme.GetNumberOfThreadsForDimensionX()));
-    step_x = 1;
-  }
-  return std::make_tuple(start_offset_x, step_x);
-}
-
 // Emits code to process up to
 // (tile_size_x/num_threads_x * tile_size_y/num_threads_y) elements in a tile,
 // given `emit_elem_function` is the function to emit code to process one
@@ -1884,23 +1865,21 @@
 //
 // Pseudocode:
 //
-// for (y_indvar = 0; y_indvar < tile_height_bound; y_indvar += num_threads_y) {
-//   if (y_indvar < tile_height) {
-//     for (j = 0; j < tile_size_x / num_threads_x; j++) {
-//       if (dilated) {
-//         x_pos = x + j * num_threads_x;
-//       } else {
-//         x_pos = x * (tile_size_x / num_threads_x) + j;
-//       }
+// for (y_loc = 0; y_loc < tile_height; y_loc += num_threads_y) {
+//   for (j = 0; j < tile_size_x / num_threads_x; j++) { // unrolled
+//     if (dilated) {
+//       x_loc = x + j * num_threads_x;
+//     } else {
+//       x_loc = x * (tile_size_x / num_threads_x) + j;
+//     }
 //
-//       if (x_pos < tile_width) {
-//         EmitElementary(y + y_indvar, x_pos);
-//       }
+//     if (x_loc < tile_width) {
+//       emit_elem_function(y + y_loc, x_loc);
 //     }
 //   }
 // }
 //
-void EmitTiledElementalCodeWithBoundsCheck(
+static void EmitTile(
     const KernelMappingScheme& mapping_scheme,
     const IrArray::Index& tile_origin_index, const string& loop_name,
     KernelSupportLibrary* ksl, llvm::IRBuilder<>* b, llvm::Value* y,
@@ -1914,23 +1893,35 @@
   int64 num_threads_y = mapping_scheme.GetNumberOfThreadsForDimensionY();
   int64 tile_size_x = mapping_scheme.GetTileSizeForDimensionX();
 
+  int64 x_num_steps = tile_size_x / num_threads_x;
   llvm::Value* start_offset_x;
   int64 step_x;
-  std::tie(start_offset_x, step_x) =
-      GetStartOffsetAndStepForX(mapping_scheme, b, x, constant);
+
+  if (mapping_scheme.DilatedX()) {
+    // With the dilated mapping scheme, each thread starts at its own offset
+    // and steps with a stride equal to the number of threads in X.
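+    // For example, with num_threads_x = 4 and x_num_steps = 2, the thread at
+    // x = 1 visits x_loc = 1 and x_loc = 5.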
+    start_offset_x = x;
+    step_x = num_threads_x;
+  } else {
+    // Otherwise, the stride is one, and each thread's starting offset is its
+    // thread index multiplied by the number of steps it will take.
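+    // For example, with num_threads_x = 4 and x_num_steps = 2, the thread at
+    // x = 1 visits x_loc = 2 and x_loc = 3.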
+    start_offset_x = b->CreateMul(x, constant(x_num_steps));
+    step_x = 1;
+  }
+
   IrArray::Index source_idx = tile_origin_index.AddOffsetToDim(
       start_offset_x, KernelMappingScheme::DimX, b);
 
   ksl->For(
-      loop_name,
+      loop_name + "_y_in_tile",
       /*start=*/y,
       /*end=*/tile_height,
       /*step=*/constant(num_threads_y), [&](llvm::Value* y_loc) {
         IrArray::Index source_idx_y =
             source_idx.AddOffsetToDim(y_loc, KernelMappingScheme::DimY, b);
-        for (int64 j = 0; j < tile_size_x / num_threads_x; j++) {
+        for (int64 j = 0; j < x_num_steps; j++) {
           llvm::Value* x_loc =
-              b->CreateAdd(constant(j * step_x), start_offset_x);
+              b->CreateAdd(constant(j * step_x), start_offset_x, "x_loc");
           IrArray::Index source_idx_x = source_idx_y.AddOffsetToDim(
               constant(j * step_x), KernelMappingScheme::DimX, b);
           // The if-statement below always evaluates to true for the blocks
@@ -1940,7 +1931,6 @@
         }
       });
 }
-}  // namespace
 
 // Emits code to process a tensor element in a tile for the given kCopy HLO that
 // performs a 0-2-1 transpose.
@@ -2366,7 +2356,7 @@
 
 llvm::Value* IrEmitterUnnested::EmitTilingKernel(
     const KernelMappingScheme& mapping_scheme, llvm::Type* index_ty,
-    TileElementGenerator tile_element_generator) {
+    const TileElementGenerator& tile_element_generator) {
   absl::Span<const int64> dims_in_tile = mapping_scheme.GetDimensionsInTiles();
   absl::Span<const int64> dims_in_block =
       mapping_scheme.GetDimensionsInBlocks();
@@ -2407,15 +2397,12 @@
             PRED /*arbitrary*/, mapping_scheme.GetDimensionsInBlocks()),
         &b_);
 
-    std::vector<llvm::Value*> multidim;
-    multidim.reserve(3);
-    for (int i = 0; i < 3; ++i) {
-      multidim.push_back(
-          b_.CreateMul(starting_block[i],
-                       llvm::ConstantInt::get(starting_block[i]->getType(),
-                                              mapping_scheme.BlockSize(i)),
-                       "block_origin." + std::to_string(i)));
-    }
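+    // Only the most-major (Z) dimension can hold more than one tile per
+    // block; the block sizes for Y and X are always 1, so for those
+    // dimensions the tile index equals the block index.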
+    std::vector<llvm::Value*> multidim = {
+        b_.CreateMul(starting_block[0],
+                     llvm::ConstantInt::get(starting_block[0]->getType(),
+                                            mapping_scheme.BlockSizeZ()),
+                     "block_origin.z"),
+        starting_block[1], starting_block[2]};
     return IrArray::Index(multidim, mapping_scheme.GetDimensionsInTiles(),
                           starting_block.GetType());
   }();
@@ -2441,18 +2428,17 @@
   };
 
   int dim_z = KernelMappingScheme::DimZ;
-
-  if (mapping_scheme.BlockSize(dim_z) == 1) {
+  if (mapping_scheme.BlockSizeZ() == 1) {
     emit_tile(starting_tile);
   } else {
     llvm::Value* starting_tile_index_for_dim = starting_tile[dim_z];
-    llvm::Value* block_size_for_dim = constant(mapping_scheme.BlockSize(dim_z));
+    llvm::Value* block_size_for_dim = constant(mapping_scheme.BlockSizeZ());
     llvm::Value* block_id_for_dim =
         b_.CreateUDiv(starting_tile_index_for_dim, block_size_for_dim);
     llvm::Value* last_block_for_dim = constant(dims_in_block[dim_z] - 1);
     llvm::Value* last_block_size_for_dim =
         constant(dims_in_tile[dim_z] -
-                 (dims_in_block[dim_z] - 1) * mapping_scheme.BlockSize(dim_z));
+                 (dims_in_block[dim_z] - 1) * mapping_scheme.BlockSizeZ());
 
     llvm::Value* num_tiles_in_block =
         b_.CreateSelect(b_.CreateICmpEQ(last_block_for_dim, block_id_for_dim),
@@ -2497,6 +2483,7 @@
     HloInstruction* hlo, Thunk* kernel_thunk,
     absl::Span<const int64> reduced_output_dims,
     absl::Span<const int64> tiled_param_ids) {
   constexpr int kNumRows = 4;
   KernelMappingScheme mapping_scheme(
       reduced_output_dims, /*tile_size_y=*/kWarpSize,
@@ -2581,33 +2568,33 @@
           // tile[y, x] = input[index]
           // Note that tile_width and tile_height are flipped here because we
           // are reading a transposed tile.
-          EmitTiledElementalCodeWithBoundsCheck(
-              mapping_scheme, input_tile_origin, "input", ksl, &b_, y, x,
-              tile_width, tile_height,
-              [&](const IrArray::Index& index, llvm::Value* y_loc,
-                  llvm::Value* x_loc, int64 /*x_iter_num*/) {
-                for (int64 id : tiled_param_ids) {
-                  IrArray& input_in_logical_shape =
-                      param_in_reduced_shape_arrays[id];
+          EmitTile(mapping_scheme, input_tile_origin, "input", ksl, &b_, y, x,
+                   tile_width, tile_height,
+                   [&](const IrArray::Index& index, llvm::Value* y_loc,
+                       llvm::Value* x_loc, int64 /*x_iter_num*/) {
+                     for (int64 id : tiled_param_ids) {
+                       IrArray& input_in_logical_shape =
+                           param_in_reduced_shape_arrays[id];
 
-                  llvm::Value* shmem_buffer = param_shmem_buffers[id];
-                  llvm::Value* zero = llvm::ConstantInt::get(index_type, 0);
-                  // TODO(jlebar): Add AA metadata to this store.  Tile buffers
-                  // are global variables, so LLVM can't infer much about it.
-                  Store(input_in_logical_shape.EmitReadArrayElement(
-                            index, &b_, "input_element"),
-                        GEP(shmem_buffer, {zero, y_loc, x_loc}));
-                }
-              });
+                       llvm::Value* shmem_buffer = param_shmem_buffers[id];
+                       llvm::Value* zero =
+                           llvm::ConstantInt::get(index_type, 0);
+                       // TODO(jlebar): Add AA metadata to this store.  Tile
+                       // buffers are global variables, so LLVM can't infer much
+                       // about it.
+                       Store(input_in_logical_shape.EmitReadArrayElement(
+                                 index, &b_, "input_element"),
+                             GEP(shmem_buffer, {zero, y_loc, x_loc}));
+                     }
+                   });
 
           // Wait for all threads to reach this point using `__syncthreads` in
           // CUDA.
           EmitCallToTargetIntrinsic(TargetIntrinsicID::kBarrierId, {}, {}, &b_);
         }
 
-        EmitTiledElementalCodeWithBoundsCheck(mapping_scheme, index, loop_name,
-                                              ksl, &b_, y, x, tile_height,
-                                              tile_width, element_generator);
+        EmitTile(mapping_scheme, index, loop_name, ksl, &b_, y, x, tile_height,
+                 tile_width, element_generator);
         bool block_contains_multi_tiles =
             mapping_scheme.GetNumberOfTilesInOneBlock() > 1;
 
@@ -2994,15 +2981,16 @@
         ThreadsPerBlockLimit(ir_emitter_context_->device_description());
     if (IsUnrollingColumnReductionBeneficial(unnested_hlo, input_shape,
                                              dims_in_elem[2])) {
+      // Vectorized loads: two elements per thread.
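+      // For example, with a 1024-thread block limit and dims_in_elem[2] =
+      // 4096, this gives tile_size_x = 2048 and num_threads_x = 1024.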
       tile_size_x = std::min(2 * hw_threads_per_block_limit, dims_in_elem[2]);
       num_threads_x = tile_size_x / 2;
       dilated_x = false;
     } else {
+      // One element per thread.
       tile_size_x = std::min(hw_threads_per_block_limit, dims_in_elem[2]);
       num_threads_x = tile_size_x;
     }
-    int64 kNumElementsPerPartialSum = 128;
-    tile_size_y = kNumElementsPerPartialSum;
+    tile_size_y = 128;
   }
 
   KernelMappingScheme mapping_scheme(dims_in_elem, tile_size_y, tile_size_x,
@@ -3103,9 +3091,8 @@
       [&](llvm::Value* y, llvm::Value* x, const IrArray::Index& index,
           const string& loop_name, llvm::Value* tile_height,
           llvm::Value* tile_width, KernelSupportLibrary* ksl) {
-        EmitTiledElementalCodeWithBoundsCheck(
-            reduction_info.GetKernelMappingScheme(), index, loop_name, ksl, &b_,
-            y, x, tile_height, tile_width, emit_reduction_tile);
+        EmitTile(reduction_info.GetKernelMappingScheme(), index, loop_name, ksl,
+                 &b_, y, x, tile_height, tile_width, emit_reduction_tile);
       });
   EmitEpilogueForReduction(unnested_hlo, reduction_info, reduce_instructions,
                            reduction_output_shape_indices, reducers, lane_id);
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
index 3afcde8..a5b75c8 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
@@ -209,9 +209,9 @@
   // scheme.
   //
   // Returns lane_id as an LLVM value.
-  llvm::Value* EmitTilingKernel(const KernelMappingScheme& mapping_scheme,
-                                llvm::Type* index_ty,
-                                TileElementGenerator tile_element_generator);
+  llvm::Value* EmitTilingKernel(
+      const KernelMappingScheme& mapping_scheme, llvm::Type* index_ty,
+      const TileElementGenerator& tile_element_generator);
 
   // Emits code to process a tensor element in a tile for the given kCopy HLO
   // that performs a 0-2-1 transpose.
diff --git a/tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h b/tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h
index e25f1b6..f955283 100644
--- a/tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h
+++ b/tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h
@@ -85,7 +85,7 @@
         dims_in_tiles_{dims_in_elems[0],
                        CeilOfRatio<int64>(dims_in_elems[1], tile_size_y),
                        CeilOfRatio<int64>(dims_in_elems[2], tile_size_x)},
-        dims_in_blocks_{dims_in_elems[0] / block_size_z, dims_in_tiles_[1],
+        dims_in_blocks_{dims_in_tiles_[0] / block_size_z, dims_in_tiles_[1],
                         dims_in_tiles_[2]},
         block_size_z_{block_size_z},
         num_threads_x_(num_threads_x),
@@ -109,8 +109,8 @@
     return dims_in_elems_;
   }
 
-  // Ratio of elements in each dimension over tile sizes for Z/Y/X
-  // respectively.
+  // Number of tiles required to cover the input tensor in each dimension (Z/Y/X
+  // respectively).
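+  // For example, dims_in_elems = {8, 100, 64} with tile_size_y = 32 and
+  // tile_size_x = 32 yields dims_in_tiles = {8, 4, 2}.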
   absl::Span<const int64> GetDimensionsInTiles() const {
     return dims_in_tiles_;
   }
@@ -126,18 +126,14 @@
 
   int64 GetNumberOfTilesInOneBlock() const { return block_size_z_; }
 
-  int64 BlockSize(int d) const {
-    DCHECK(d >= DimZ && d <= DimX);
-    if (d == DimZ) {
-      return block_size_z_;
-    }
-    return 1;
-  }
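+  // Number of tiles processed by one thread block along the Z dimension.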
+  int64 BlockSizeZ() const { return block_size_z_; }
 
   int64 GetNumberOfBlocks() const {
     return absl::c_accumulate(dims_in_blocks_, 1, std::multiplies<int64>());
   }
 
+  // Tile size for a given dimension. Tiles are assigned per thread block,
+  // and are processed by all threads in the block.
   int64 GetTileSizeForDimension(int d) const { return tile_sizes_.at(d); }
   int64 GetTileSizeForDimensionX() const {
     return GetTileSizeForDimension(DimX);