[XLA/GPU] Use intra-block communication for row reduction
This improves the precision and performance of row reduction: the new
algorithm requires only one global atomic per block, that is, up to
<tiling size> * 1024 elements can be reduced without atomics. It also allows
us to reduce the "step size" of tree reduction.
The algorithm is, roughly:
```
__global__ void reduce(int num_rows, float *in, float *out) {
  __shared__ float cache[32];
  int offset = blockDim.x * blockIdx.x + threadIdx.x;
  if (offset >= num_rows) return;
  int tile_bound = std::min(offset + kTileSizeX, num_rows);
  float accum = 0;
  for (int i = offset; i < tile_bound; i += blockDim.x) {
    accum += in[i];
  }
  accum = warp_reduce(accum);
  if (threadIdx.x % kWarpSize == 0) {
    cache[threadIdx.x / kWarpSize] = accum;
  }
  __syncthreads();
  if (threadIdx.x / kWarpSize == 0) {
    bool warp_exists = threadIdx.x < (blockDim.x / kWarpSize);
    float block_accum = warp_exists ? cache[threadIdx.x % kWarpSize] : 0;
    block_accum = warp_reduce(block_accum);
    if (threadIdx.x == 0) {
      atomicAdd(out, block_accum);  // the single global atomic per block
    }
  }
}
```
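The `warp_reduce` helper above is left abstract; in the emitter it corresponds
to the shuffle-down loop produced by EmitFullWarpShuffleDownLoopForReduce. A
minimal CUDA sketch of such a helper, assuming kWarpSize == 32 and a full-warp
mask (the name and constants here are illustrative, not the emitter's actual
output):
```
constexpr int kWarpSize = 32;

__device__ float warp_reduce(float val) {
  // Shuffle-down reduction across one warp; after the loop, lane 0 holds the
  // sum of the inputs of all 32 lanes.
  for (int delta = kWarpSize / 2; delta > 0; delta /= 2) {
    val += __shfl_down_sync(0xffffffffu, val, delta);
  }
  return val;
}
```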
PiperOrigin-RevId: 295224089
Change-Id: Ifd6e2990d6b7724a7d9a27f13492eddf23b27e82
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
index c535325..8646b4b 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
@@ -133,7 +133,8 @@
CHECK_EQ(reduction_dimensions.dimensions[0], 1);
return {tile_z, 1, 16};
}
- if (reduction_dimensions.dimensions[2] % (kWarpSize * 64) == 0) {
+ if (reduction_dimensions.dimensions[2] % (kWarpSize * kWarpSize * 64) ==
+ 0) {
return {tile_z, 1, 64};
}
return {tile_z, 1, 8};
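Rough arithmetic behind the tightened divisibility check (this rationale is
inferred, not stated in the commit): a row-reduction block can now have up to
kWarpSize * kWarpSize threads in x (see the num_threads_x change below), each
owning a 64-element tile, so the 64-wide tiling is only picked when such a
block covers the reduced dimension evenly. Assuming kWarpSize == 32:
```
// Illustrative numbers only; kTileX mirrors the 64-element tile above.
constexpr long long kWarpSize = 32;
constexpr long long kTileX = 64;
static_assert(kWarpSize * kWarpSize * kTileX == 65536,
              "one block of up to 1024 threads, each with a 64-element tile");
// dimensions[2] == 131072: divisible by 65536 -> tiling {tile_z, 1, 64}
// dimensions[2] ==   4096: not divisible      -> tiling {tile_z, 1, 8}
//   (the old check, % (kWarpSize * 64) == % 2048, would still have picked 64)
```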
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index 37204af..c6b167f 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -2127,29 +2127,36 @@
Store(init_ir_value,
InBoundsGEP(partial_result_address, {b_.getInt32(i)}));
}
+ reduction_info->GetMutableInitialValues()->push_back(init_ir_value);
- // Allocate __shared__ cache[num_partial_results][num_threads][num_threads +
- // 1], where num_threads == num_threads_x == num_threads_y. The "+1" is
- // used to avoid bank conflicts.
- if (!reduction_info->IsRowReduction()) {
- auto& mapping_scheme = reduction_info->GetKernelMappingScheme();
- int64 num_threads = mapping_scheme.GetNumThreadsX();
- CHECK_EQ(num_threads, mapping_scheme.GetNumThreadsY());
- llvm::Type* primitive_type = llvm_ir::PrimitiveTypeToIrType(
- reduce_inst->shape().element_type(), module_);
- llvm::Type* buffer_type = llvm::ArrayType::get(
- llvm::ArrayType::get(
- llvm::ArrayType::get(primitive_type, num_threads + 1),
- num_threads),
- num_partial_results);
-
- llvm::GlobalVariable* shared_cache_per_reduce =
- llvm_ir::AllocateSharedMemoryTile(b_.GetInsertBlock()->getModule(),
- buffer_type,
- absl::StrCat("shared_cache_", i));
- reduction_info->GetMutableSharedCache()->push_back(
- shared_cache_per_reduce);
- }
+ auto& mapping_scheme = reduction_info->GetKernelMappingScheme();
+ int64 num_threads_x = mapping_scheme.GetNumThreadsX();
+ llvm::Type* primitive_type = llvm_ir::PrimitiveTypeToIrType(
+ reduce_inst->shape().element_type(), module_);
+ llvm::Type* buffer_type = [&] {
+ if (reduction_info->IsRowReduction()) {
+ // Allocate __shared__ cache[num_partial_results][num_threads].
+ return llvm::ArrayType::get(
+ llvm::ArrayType::get(primitive_type, num_threads_x),
+ num_partial_results);
+ } else {
+ // Allocate __shared__
+ // cache[num_partial_results][num_threads][num_threads + 1], where
+ // num_threads == num_threads_x == num_threads_y. The "+1" is used to
+ // avoid bank conflicts.
+ CHECK_EQ(num_threads_x, mapping_scheme.GetNumThreadsY());
+ return llvm::ArrayType::get(
+ llvm::ArrayType::get(
+ llvm::ArrayType::get(primitive_type, num_threads_x + 1),
+ num_threads_x),
+ num_partial_results);
+ }
+ }();
+ llvm::GlobalVariable* shared_cache_per_reduce =
+ llvm_ir::AllocateSharedMemoryTile(b_.GetInsertBlock()->getModule(),
+ buffer_type,
+ absl::StrCat("shared_cache_", i));
+ reduction_info->GetMutableSharedCache()->push_back(shared_cache_per_reduce);
}
}
@@ -2241,13 +2248,6 @@
int num_reduces = reducers.size();
absl::Span<llvm::AllocaInst* const> partial_result_addresses =
reduction_info.GetPartialResultAddresses();
- if (reduction_info.IsRowReduction()) {
- EmitFullWarpShuffleDownLoopForAllReduces(reducers,
- partial_result_addresses);
- llvm_ir::LlvmIfData if_lane_id_is_zero_data = llvm_ir::EmitIfThenElse(
- ICmpEQ(thread_id_info.lane_id, constant(0)), "lane_id_is_zero", &b_);
- llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block, &b_);
- }
int num_partial_results = GetNumberOfPartialResults(reduction_info);
@@ -2285,25 +2285,67 @@
element_index.GetType());
llvm::Value* output_address = output_array.EmitArrayElementAddress(
output_index, &b_, "output_element_address");
-
llvm::Value* current_output = b_.CreateInBoundsGEP(
partial_result_addresses[i], {constant(j)}, "current_output");
+ llvm::GlobalVariable* shared_cache = reduction_info.GetSharedCache()[i];
+
+ // __shared__ memory uses a different address space, so we cast it to
+ // global address space before writing or reading.
+ auto shared_to_global = [&](llvm::Value* input, llvm::Twine name = "") {
+ return b_.CreateAddrSpaceCast(
+ input,
+ llvm::PointerType::get(input->getType()->getPointerElementType(),
+ /*AddressSpace=*/0),
+ name);
+ };
+
+ auto is_zero = [&](llvm::Value* value) {
+ return b_.CreateICmpEQ(value, constant(0));
+ };
+
+ KernelSupportLibrary ksl(&b_);
+ llvm::Type* element_type =
+ partial_result_addresses[i]->getType()->getElementType();
if (reduction_info.IsRowReduction()) {
- TF_CHECK_OK(EmitAtomicOperationForNestedComputation(
- *reducers[i], output_address, current_output));
+ EmitFullWarpShuffleDownLoopForReduce(reducers[i], element_type,
+ current_output);
+ llvm::Value* warp_id =
+ b_.CreateUDiv(thread_id_info.thread_id_x, constant(kWarpSize));
+ ksl.If(is_zero(thread_id_info.lane_id), [&] {
+ llvm::Value* shmem_output_addr =
+ shared_to_global(b_.CreateInBoundsGEP(
+ shared_cache, {b_.getInt32(0), constant(j), warp_id}));
+ b_.CreateStore(b_.CreateLoad(current_output), shmem_output_addr);
+ });
+
+ EmitSyncThreads();
+ ksl.If(is_zero(warp_id), [&] {
+ llvm::Value* block_accum_addr = shared_to_global(b_.CreateInBoundsGEP(
+ shared_cache,
+ {b_.getInt32(0), constant(j), thread_id_info.lane_id}));
+ llvm::Value* initial_value = reduction_info.GetInitialValues()[i];
+ llvm::Value* initial_value_addr = b_.CreateAlloca(element_type);
+ b_.CreateStore(initial_value, initial_value_addr);
+
+ llvm::Value* warp_exists = b_.CreateICmpULT(
+ thread_id_info.thread_id_x,
+ constant(mapping_scheme.GetNumThreadsX() / kWarpSize));
+
+ llvm::Value* selected_value = b_.CreateSelect(
+ warp_exists, block_accum_addr, initial_value_addr);
+
+ EmitFullWarpShuffleDownLoopForReduce(
+ reducers[i], element_type,
+ /*block_accum_addr*/ selected_value);
+ ksl.If(is_zero(thread_id_info.thread_id_x), [&] {
+ TF_CHECK_OK(EmitAtomicOperationForNestedComputation(
+ *reducers[i], output_address, block_accum_addr));
+ });
+ });
+
} else {
- llvm::GlobalVariable* shared_cache = reduction_info.GetSharedCache()[i];
- auto addr_cast = [&](llvm::Value* input, llvm::Twine name = "") {
- // __shared__ memory uses a different address space, so we cast it to
- // global address space before writing or reading.
- return b_.CreateAddrSpaceCast(
- input,
- llvm::PointerType::get(input->getType()->getPointerElementType(),
- /*AddressSpace=*/0),
- name);
- };
- llvm::Value* shmem_output_addr = addr_cast(
+ llvm::Value* shmem_output_addr = shared_to_global(
b_.CreateInBoundsGEP(shared_cache, {b_.getInt32(0), constant(j),
thread_id_info.thread_id_x,
thread_id_info.thread_id_y}),
@@ -2311,19 +2353,18 @@
llvm::Value* current_output_value = b_.CreateLoad(current_output);
b_.CreateStore(current_output_value, shmem_output_addr);
- EmitCallToTargetIntrinsic(TargetIntrinsicID::kBarrierId, {}, {}, &b_);
+ EmitSyncThreads();
// Get transposed element from shared memory.
- llvm::Value* shmem_transposed_addr = addr_cast(b_.CreateInBoundsGEP(
- shared_cache,
- {b_.getInt32(0), constant(j), thread_id_info.thread_id_y,
- thread_id_info.thread_id_x},
- "shmem_transposed_addr"));
+ llvm::Value* shmem_transposed_addr =
+ shared_to_global(b_.CreateInBoundsGEP(
+ shared_cache,
+ {b_.getInt32(0), constant(j), thread_id_info.thread_id_y,
+ thread_id_info.thread_id_x},
+ "shmem_transposed_addr"));
- EmitFullWarpShuffleDownLoopForReduce(
- reducers[i],
- partial_result_addresses[i]->getType()->getElementType(),
- shmem_transposed_addr);
+ EmitFullWarpShuffleDownLoopForReduce(reducers[i], element_type,
+ shmem_transposed_addr);
// Some threads in the block are completely outside of the bound of the
// tensor, so they should not write any output at all.
@@ -2335,13 +2376,10 @@
b_.CreateICmpULT(thread_id_info.thread_id_x,
tiling_kernel_info.output_tile_bounds[kDimY]));
- KernelSupportLibrary ksl(&b_);
- ksl.If(b_.CreateAnd(has_output, b_.CreateICmpEQ(thread_id_info.lane_id,
- constant(0))),
- [&] {
- TF_CHECK_OK(EmitAtomicOperationForNestedComputation(
- *reducers[i], output_address, shmem_transposed_addr));
- });
+ ksl.If(b_.CreateAnd(has_output, is_zero(thread_id_info.lane_id)), [&] {
+ TF_CHECK_OK(EmitAtomicOperationForNestedComputation(
+ *reducers[i], output_address, shmem_transposed_addr));
+ });
}
}
}
@@ -3050,7 +3088,16 @@
}
int64 num_threads_y = reduction_dimensions.is_row_reduction ? 1 : kWarpSize;
- int64 num_threads_x = kWarpSize;
+ int64 num_threads_x = [&] {
+ if (reduction_dimensions.is_row_reduction) {
+ return std::min(
+ kWarpSize * kWarpSize,
+ RoundUpToNearest(CeilOfRatio(reduction_dimensions.dimensions[2],
+ reduction_tiling[2]),
+ kWarpSize));
+ }
+ return kWarpSize;
+ }();
KernelMappingScheme mapping_scheme(
reduction_dimensions.dimensions,
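To make the num_threads_x formula above concrete, here is a standalone sketch;
CeilOfRatio and RoundUpToNearest are reimplemented locally as stand-ins for
the XLA utilities of the same name, and the sample shapes are made up:
```
#include <algorithm>
#include <cstdint>

constexpr int64_t kWarpSize = 32;

int64_t CeilOfRatio(int64_t a, int64_t b) { return (a + b - 1) / b; }
int64_t RoundUpToNearest(int64_t a, int64_t b) { return CeilOfRatio(a, b) * b; }

// Pick enough x-threads that each one covers one tile of the reduced
// dimension, rounded up to whole warps and capped at kWarpSize * kWarpSize.
int64_t RowReductionNumThreadsX(int64_t reduced_dim, int64_t tile_x) {
  return std::min(kWarpSize * kWarpSize,
                  RoundUpToNearest(CeilOfRatio(reduced_dim, tile_x), kWarpSize));
}

// RowReductionNumThreadsX(300, 8)      ->   64  (ceil(300 / 8) = 38 -> 2 warps)
// RowReductionNumThreadsX(1000000, 64) -> 1024  (capped at 32 * 32 threads)
```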
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
index 3d6d095..a49f8c4 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
@@ -181,11 +181,38 @@
// Generates code for reduction to contiguous dimensions.
//
- // TODO(cheshire): Pseudocode for row reduction.
- // Column reduction uses the following algorithm described in CUDA-like
+ // Row reduction uses the following algorithm described in CUDA-like
// pseudocode:
//
// ```
+ // __global__ void reduce(int num_rows, float *in, float *out) {
+ //   __shared__ float cache[32];
+ //   int offset = blockDim.x * blockIdx.x + threadIdx.x;
+ //   if (offset >= num_rows) return;
+ //   int tile_bound = std::min(offset + kTileSizeX, num_rows);
+ //   float accum = 0;
+ //   for (int i = offset; i < tile_bound; i += blockDim.x) {
+ //     accum += in[i];
+ //   }
+ //   accum = warp_reduce(accum);
+ //   if (threadIdx.x % kWarpSize == 0) {
+ //     cache[threadIdx.x / kWarpSize] = accum;
+ //   }
+ //   __syncthreads();
+ //   if (threadIdx.x / kWarpSize == 0) {
+ //     bool warp_exists = threadIdx.x < (blockDim.x / kWarpSize);
+ //     float block_accum = warp_exists ? cache[threadIdx.x % kWarpSize] : 0;
+ //     block_accum = warp_reduce(block_accum);
+ //     if (threadIdx.x == 0) {
+ //       atomicAdd(out, block_accum);  // one global atomic per block
+ //     }
+ //   }
+ // }
+ // ```
+ //
+ // Column reduction uses the following algorithm:
+ //
+ // ```
// void reduce(float** in, float* out) {
// __shared__ float[32][33] cache;
// int thread_id = GetThreadId();
diff --git a/tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h b/tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h
index d4183ef..eeab8d4 100644
--- a/tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h
+++ b/tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h
@@ -170,6 +170,14 @@
return &reduction_input_addresses_;
}
+ std::vector<llvm::Value*>* GetMutableInitialValues() {
+ return &initial_values_;
+ }
+
+ absl::Span<llvm::Value* const> GetInitialValues() const {
+ return initial_values_;
+ }
+
// Returns the address of the input element to perform the reduction with.
absl::Span<llvm::AllocaInst* const> GetReductionInputAddresses() const {
return reduction_input_addresses_;
@@ -189,6 +197,7 @@
private:
std::vector<llvm::GlobalVariable*> shared_cache_;
+ std::vector<llvm::Value*> initial_values_;
const KernelMappingScheme mapping_scheme_;
AddressVector partial_result_addresses_;
AddressVector reduction_input_addresses_;
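For context on the new initial_values_ plumbing: in the final intra-block step
only warp 0 reduces the per-warp partial results, and lanes whose index
exceeds the number of warps actually present in the block have no shared-cache
entry, so the emitter feeds them the reduction's recorded init value instead
(the warp_exists select in ir_emitter_unnested.cc). In the commit's CUDA-like
pseudocode, with `init` standing in for the value recorded via
GetMutableInitialValues():
```
// Warp 0's final reduce: inactive lanes contribute the reduction's init
// value rather than a stale cache slot; for float addition that is 0.
bool warp_exists = threadIdx.x < blockDim.x / kWarpSize;
float block_accum = warp_exists ? cache[threadIdx.x] : init;
block_accum = warp_reduce(block_accum);
```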