[XLA GPU] [NFC] Tiling codegen: minor simplifications and renames, add more comments.
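For reference, here is a scalar C++ model of the per-thread x mapping that the
renamed EmitTile helper emits. The names are illustrative only; the real code
builds LLVM IR through IRBuilder instead of executing this loop:

    #include <cstdint>
    #include <functional>

    // Scalar sketch of how one thread with x-id `x` covers its share of a
    // tile row; `emit_elem` stands in for the per-element code generator.
    void EmitTileModel(int64_t tile_size_x, int64_t num_threads_x,
                       bool dilated_x, int64_t x, int64_t y_loc,
                       int64_t tile_width,
                       const std::function<void(int64_t, int64_t)>& emit_elem) {
      const int64_t x_num_steps = tile_size_x / num_threads_x;
      for (int64_t j = 0; j < x_num_steps; ++j) {
        // Dilated: stride of num_threads_x, so consecutive threads touch
        // consecutive elements on every step (coalesced accesses).
        // Undilated: each thread owns a contiguous run of x_num_steps
        // elements, stepping by one.
        const int64_t x_loc =
            dilated_x ? x + j * num_threads_x : x * x_num_steps + j;
        if (x_loc < tile_width) emit_elem(y_loc, x_loc);
      }
    }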
PiperOrigin-RevId: 270153908
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index 87f7575..4eed95b 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -1854,25 +1854,6 @@
return emit_status;
}
-namespace {
-
-std::tuple<llvm::Value*, int64> GetStartOffsetAndStepForX(
- const KernelMappingScheme& mapping_scheme, llvm::IRBuilder<>* b,
- llvm::Value* x, const IrEmitterUnnested::ConstantGenerator& constant) {
- llvm::Value* start_offset_x;
- int64 step_x;
- if (mapping_scheme.DilatedX()) {
- start_offset_x = x;
- step_x = mapping_scheme.GetNumberOfThreadsForDimensionX();
- } else {
- start_offset_x = b->CreateMul(
- x, constant(mapping_scheme.GetTileSizeForDimensionX() /
- mapping_scheme.GetNumberOfThreadsForDimensionX()));
- step_x = 1;
- }
- return std::make_tuple(start_offset_x, step_x);
-}
-
// Emits code to process up to
// (tile_size_x/num_threads_x * tile_size_y/num_threads_y) elements in a tile,
// given `emit_elem_function` is the function to emit code to process one
@@ -1884,23 +1865,21 @@
//
// Pseudocode:
//
-// for (y_indvar = 0; y_indvar < tile_height_bound; y_indvar += num_threads_y) {
-// if (y_indvar < tile_height) {
-// for (j = 0; j < tile_size_x / num_threads_x; j++) {
-// if (dilated) {
-// x_pos = x + j * num_threads_x;
-// } else {
-// x_pos = x * (tile_size_x / num_threads_x) + j;
-// }
+// for (y_loc = 0; y_loc < tile_height; y_loc += num_threads_y) {
+// for (j = 0; j < tile_size_x / num_threads_x; j++) { // unrolled
+// if (dilated) {
+// x_loc = x + j * num_threads_x;
+// } else {
+// x_loc = x * (tile_size_x / num_threads_x) + j;
+// }
//
-// if (x_pos < tile_width) {
-// EmitElementary(y + y_indvar, x_pos);
-// }
+// if (x_loc < tile_width) {
+// emit_elem_function(y + y_loc, x_loc);
// }
// }
// }
//
-void EmitTiledElementalCodeWithBoundsCheck(
+static void EmitTile(
const KernelMappingScheme& mapping_scheme,
const IrArray::Index& tile_origin_index, const string& loop_name,
KernelSupportLibrary* ksl, llvm::IRBuilder<>* b, llvm::Value* y,
@@ -1914,23 +1893,35 @@
int64 num_threads_y = mapping_scheme.GetNumberOfThreadsForDimensionY();
int64 tile_size_x = mapping_scheme.GetTileSizeForDimensionX();
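+ // Each thread visits x_num_steps positions along the x dimension.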
+ int64 x_num_steps = tile_size_x / num_threads_x;
llvm::Value* start_offset_x;
int64 step_x;
- std::tie(start_offset_x, step_x) =
- GetStartOffsetAndStepForX(mapping_scheme, b, x, constant);
+
+ if (mapping_scheme.DilatedX()) {
+ // In the dilated mapping scheme, each thread steps with a stride equal to
+ // the number of threads in the x dimension.
+ start_offset_x = x;
+ step_x = num_threads_x;
+ } else {
+ // Otherwise, the stride is one, and each thread's starting offset is its
+ // x-id multiplied by the number of steps it takes.
+ start_offset_x = b->CreateMul(x, constant(x_num_steps));
+ step_x = 1;
+ }
+
IrArray::Index source_idx = tile_origin_index.AddOffsetToDim(
start_offset_x, KernelMappingScheme::DimX, b);
ksl->For(
- loop_name,
+ loop_name + "_y_in_tile",
/*start=*/y,
/*end=*/tile_height,
/*step=*/constant(num_threads_y), [&](llvm::Value* y_loc) {
IrArray::Index source_idx_y =
source_idx.AddOffsetToDim(y_loc, KernelMappingScheme::DimY, b);
- for (int64 j = 0; j < tile_size_x / num_threads_x; j++) {
+ for (int64 j = 0; j < x_num_steps; j++) {
llvm::Value* x_loc =
- b->CreateAdd(constant(j * step_x), start_offset_x);
+ b->CreateAdd(constant(j * step_x), start_offset_x, "x_loc");
IrArray::Index source_idx_x = source_idx_y.AddOffsetToDim(
constant(j * step_x), KernelMappingScheme::DimX, b);
// The if-statement below always evaluates to true for the blocks
@@ -1940,7 +1931,6 @@
}
});
}
-} // namespace
// Emits code to process a tensor element in a tile for the given kCopy HLO that
// performs a 0-2-1 transpose.
@@ -2366,7 +2356,7 @@
llvm::Value* IrEmitterUnnested::EmitTilingKernel(
const KernelMappingScheme& mapping_scheme, llvm::Type* index_ty,
- TileElementGenerator tile_element_generator) {
+ const TileElementGenerator& tile_element_generator) {
absl::Span<const int64> dims_in_tile = mapping_scheme.GetDimensionsInTiles();
absl::Span<const int64> dims_in_block =
mapping_scheme.GetDimensionsInBlocks();
@@ -2407,15 +2397,12 @@
PRED /*arbitrary*/, mapping_scheme.GetDimensionsInBlocks()),
&b_);
- std::vector<llvm::Value*> multidim;
- multidim.reserve(3);
- for (int i = 0; i < 3; ++i) {
- multidim.push_back(
- b_.CreateMul(starting_block[i],
- llvm::ConstantInt::get(starting_block[i]->getType(),
- mapping_scheme.BlockSize(i)),
- "block_origin." + std::to_string(i)));
- }
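+ // Only the z dimension may pack multiple tiles into one block, so only the
+ // z coordinate of the block origin needs scaling by the block size.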
+ std::vector<llvm::Value*> multidim = {
+ b_.CreateMul(starting_block[0],
+ llvm::ConstantInt::get(starting_block[0]->getType(),
+ mapping_scheme.BlockSizeZ()),
+ "block_origin.z"),
+ starting_block[1], starting_block[2]};
return IrArray::Index(multidim, mapping_scheme.GetDimensionsInTiles(),
starting_block.GetType());
}();
@@ -2441,18 +2428,17 @@
};
int dim_z = KernelMappingScheme::DimZ;
-
- if (mapping_scheme.BlockSize(dim_z) == 1) {
+ if (mapping_scheme.BlockSizeZ() == 1) {
emit_tile(starting_tile);
} else {
llvm::Value* starting_tile_index_for_dim = starting_tile[dim_z];
- llvm::Value* block_size_for_dim = constant(mapping_scheme.BlockSize(dim_z));
+ llvm::Value* block_size_for_dim = constant(mapping_scheme.BlockSizeZ());
llvm::Value* block_id_for_dim =
b_.CreateUDiv(starting_tile_index_for_dim, block_size_for_dim);
llvm::Value* last_block_for_dim = constant(dims_in_block[dim_z] - 1);
llvm::Value* last_block_size_for_dim =
constant(dims_in_tile[dim_z] -
- (dims_in_block[dim_z] - 1) * mapping_scheme.BlockSize(dim_z));
+ (dims_in_block[dim_z] - 1) * mapping_scheme.BlockSizeZ());
llvm::Value* num_tiles_in_block =
b_.CreateSelect(b_.CreateICmpEQ(last_block_for_dim, block_id_for_dim),
@@ -2497,6 +2483,7 @@
HloInstruction* hlo, Thunk* kernel_thunk,
absl::Span<const int64> reduced_output_dims,
absl::Span<const int64> tiled_param_ids) {
constexpr int kNumRows = 4;
KernelMappingScheme mapping_scheme(
reduced_output_dims, /*tile_size_y=*/kWarpSize,
@@ -2581,33 +2568,33 @@
// tile[y, x] = input[index]
// Note that tile_width and tile_height are flipped here because we
// are reading a transposed tile.
- EmitTiledElementalCodeWithBoundsCheck(
- mapping_scheme, input_tile_origin, "input", ksl, &b_, y, x,
- tile_width, tile_height,
- [&](const IrArray::Index& index, llvm::Value* y_loc,
- llvm::Value* x_loc, int64 /*x_iter_num*/) {
- for (int64 id : tiled_param_ids) {
- IrArray& input_in_logical_shape =
- param_in_reduced_shape_arrays[id];
+ EmitTile(mapping_scheme, input_tile_origin, "input", ksl, &b_, y, x,
+ tile_width, tile_height,
+ [&](const IrArray::Index& index, llvm::Value* y_loc,
+ llvm::Value* x_loc, int64 /*x_iter_num*/) {
+ for (int64 id : tiled_param_ids) {
+ IrArray& input_in_logical_shape =
+ param_in_reduced_shape_arrays[id];
- llvm::Value* shmem_buffer = param_shmem_buffers[id];
- llvm::Value* zero = llvm::ConstantInt::get(index_type, 0);
- // TODO(jlebar): Add AA metadata to this store. Tile buffers
- // are global variables, so LLVM can't infer much about it.
- Store(input_in_logical_shape.EmitReadArrayElement(
- index, &b_, "input_element"),
- GEP(shmem_buffer, {zero, y_loc, x_loc}));
- }
- });
+ llvm::Value* shmem_buffer = param_shmem_buffers[id];
+ llvm::Value* zero =
+ llvm::ConstantInt::get(index_type, 0);
+ // TODO(jlebar): Add AA metadata to this store. Tile
+ // buffers are global variables, so LLVM can't infer much
+ // about it.
+ Store(input_in_logical_shape.EmitReadArrayElement(
+ index, &b_, "input_element"),
+ GEP(shmem_buffer, {zero, y_loc, x_loc}));
+ }
+ });
// Wait for all threads to reach this point using `__syncthreads` in
// CUDA.
EmitCallToTargetIntrinsic(TargetIntrinsicID::kBarrierId, {}, {}, &b_);
}
- EmitTiledElementalCodeWithBoundsCheck(mapping_scheme, index, loop_name,
- ksl, &b_, y, x, tile_height,
- tile_width, element_generator);
+ EmitTile(mapping_scheme, index, loop_name, ksl, &b_, y, x, tile_height,
+ tile_width, element_generator);
bool block_contains_multi_tiles =
mapping_scheme.GetNumberOfTilesInOneBlock() > 1;
@@ -2994,15 +2981,16 @@
ThreadsPerBlockLimit(ir_emitter_context_->device_description());
if (IsUnrollingColumnReductionBeneficial(unnested_hlo, input_shape,
dims_in_elem[2])) {
+ // Vectorized loads: two elements per thread.
tile_size_x = std::min(2 * hw_threads_per_block_limit, dims_in_elem[2]);
num_threads_x = tile_size_x / 2;
dilated_x = false;
} else {
+ // One element per thread.
tile_size_x = std::min(hw_threads_per_block_limit, dims_in_elem[2]);
num_threads_x = tile_size_x;
}
- int64 kNumElementsPerPartialSum = 128;
- tile_size_y = kNumElementsPerPartialSum;
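+ // Each partial sum accumulates 128 elements along the y dimension.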
+ tile_size_y = 128;
}
KernelMappingScheme mapping_scheme(dims_in_elem, tile_size_y, tile_size_x,
@@ -3103,9 +3091,8 @@
[&](llvm::Value* y, llvm::Value* x, const IrArray::Index& index,
const string& loop_name, llvm::Value* tile_height,
llvm::Value* tile_width, KernelSupportLibrary* ksl) {
- EmitTiledElementalCodeWithBoundsCheck(
- reduction_info.GetKernelMappingScheme(), index, loop_name, ksl, &b_,
- y, x, tile_height, tile_width, emit_reduction_tile);
+ EmitTile(reduction_info.GetKernelMappingScheme(), index, loop_name, ksl,
+ &b_, y, x, tile_height, tile_width, emit_reduction_tile);
});
EmitEpilogueForReduction(unnested_hlo, reduction_info, reduce_instructions,
reduction_output_shape_indices, reducers, lane_id);
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
index 3afcde8..a5b75c8 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
@@ -209,9 +209,9 @@
// scheme.
//
// Returns lane_id as an LLVM value.
- llvm::Value* EmitTilingKernel(const KernelMappingScheme& mapping_scheme,
- llvm::Type* index_ty,
- TileElementGenerator tile_element_generator);
+ llvm::Value* EmitTilingKernel(
+ const KernelMappingScheme& mapping_scheme, llvm::Type* index_ty,
+ const TileElementGenerator& tile_element_generator);
// Emits code to process a tensor element in a tile for the given kCopy HLO
// that performs a 0-2-1 transpose.
diff --git a/tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h b/tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h
index e25f1b6..f955283 100644
--- a/tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h
+++ b/tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h
@@ -85,7 +85,7 @@
dims_in_tiles_{dims_in_elems[0],
CeilOfRatio<int64>(dims_in_elems[1], tile_size_y),
CeilOfRatio<int64>(dims_in_elems[2], tile_size_x)},
- dims_in_blocks_{dims_in_elems[0] / block_size_z, dims_in_tiles_[1],
+ dims_in_blocks_{dims_in_tiles_[0] / block_size_z, dims_in_tiles_[1],
dims_in_tiles_[2]},
block_size_z_{block_size_z},
num_threads_x_(num_threads_x),
@@ -109,8 +109,8 @@
return dims_in_elems_;
}
- // Ratio of elements in each dimension over tile sizes for Z/Y/X
- // respectively.
+ // Number of tiles required to cover the input tensor in each dimension (Z/Y/X
+ // respectively).
absl::Span<const int64> GetDimensionsInTiles() const {
return dims_in_tiles_;
}
@@ -126,18 +126,14 @@
int64 GetNumberOfTilesInOneBlock() const { return block_size_z_; }
- int64 BlockSize(int d) const {
- DCHECK(d >= DimZ && d <= DimX);
- if (d == DimZ) {
- return block_size_z_;
- }
- return 1;
- }
+ int64 BlockSizeZ() const { return block_size_z_; }
int64 GetNumberOfBlocks() const {
return absl::c_accumulate(dims_in_blocks_, 1, std::multiplies<int64>());
}
+ // Tile size for a given dimension. Tiles are assigned per thread block and
+ // are processed by all threads in the block.
int64 GetTileSizeForDimension(int d) const { return tile_sizes_.at(d); }
int64 GetTileSizeForDimensionX() const {
return GetTileSizeForDimension(DimX);