[XLA:GPU] [NFC] Remove unused function, simplify some checks
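
EmitFullWarpShuffleDownLoopForReduce now derives the element type from the
pointee type of `partial_result_address` instead of taking it as a separate
parameter, and the `threads_per_block % 32 == 0` requirement is checked once
where the reduction's launch dimensions are set up rather than on every call.
The unused EmitFullWarpShuffleDownLoopForAllReduces wrapper is deleted. For
reference, the loop this helper emits corresponds to the usual shuffle-down
warp reduction; the CUDA sketch below is purely illustrative (the function
name, float accumulator, and addition as the reducer are assumptions, not
part of this change):

    // Sketch: shuffle-down tree reduction across one 32-lane warp.
    __device__ float WarpShuffleDownReduce(float partial) {
      // Halve the shuffle distance each step: 16, 8, 4, 2, 1.
      for (int distance = 16; distance >= 1; distance /= 2) {
        partial += __shfl_down_sync(0xffffffffu, partial, distance);
      }
      return partial;  // Lane 0 now holds the reduction over the whole warp.
    }
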
PiperOrigin-RevId: 395585703
Change-Id: I8cffec344febccd0f1f5fe339e8521908ebfa2e1
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index bf46c59..d1699b1 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -3861,23 +3861,12 @@
return reduction_codegen_state;
}
-void IrEmitterUnnested::EmitFullWarpShuffleDownLoopForAllReduces(
- absl::Span<HloComputation* const> reducers,
- absl::Span<llvm::AllocaInst* const> partial_result_addresses,
- int threads_per_block) {
- CHECK_EQ(reducers.size(), partial_result_addresses.size());
- for (int i = 0; i != reducers.size(); i++) {
- EmitFullWarpShuffleDownLoopForReduce(
- reducers[i], partial_result_addresses[i]->getType()->getElementType(),
- partial_result_addresses[i], threads_per_block);
- }
-}
-
void IrEmitterUnnested::EmitFullWarpShuffleDownLoopForReduce(
- HloComputation* reducer, llvm::Type* element_type,
- llvm::Value* partial_result_address, int threads_per_block) {
+ HloComputation* reducer, llvm::Value* partial_result_address) {
// This only works when the block size is a multiple of 32 threads.
- CHECK_EQ(threads_per_block % 32, 0);
+ llvm::Type* element_type =
+ llvm::cast<llvm::PointerType>(partial_result_address->getType())
+ ->getElementType();
for (int distance = 16; distance >= 1; distance /= 2) {
int bit_width = llvm_ir::GetSizeInBits(element_type);
llvm::Value* result_from_other_lane = llvm_ir::EmitAllocaAtFunctionEntry(
@@ -4071,8 +4060,7 @@
const KernelMappingScheme& mapping_scheme =
reduction_info.GetKernelMappingScheme();
- EmitFullWarpShuffleDownLoopForReduce(reducer, element_type, current_output,
- mapping_scheme.GetThreadsPerBlock());
+ EmitFullWarpShuffleDownLoopForReduce(reducer, current_output);
llvm::Value* warp_id =
b_.CreateUDiv(thread_id_info.thread_id_x, constant(kWarpSize));
ksl.If("intra_warp_reduce_write", is_zero(thread_id_info.lane_id), [&] {
@@ -4102,9 +4090,8 @@
llvm::Value* selected_value =
b_.CreateSelect(warp_exists, block_accum_addr, initial_value_addr);
- EmitFullWarpShuffleDownLoopForReduce(reducer, element_type,
- /*block_accum_addr*/ selected_value,
- mapping_scheme.GetThreadsPerBlock());
+ EmitFullWarpShuffleDownLoopForReduce(reducer,
+ /*block_accum_addr*/ selected_value);
ksl.If("reduction_write_output", is_zero(thread_id_info.thread_id_x), [&] {
if (reduction_info.IsRaceFree()) {
VLOG(10) << "Using deterministic reductions: writing out "
@@ -4155,9 +4142,7 @@
thread_id_info.thread_id_x},
"shmem_transposed_addr"));
- EmitFullWarpShuffleDownLoopForReduce(reducer, element_type,
- shmem_transposed_addr,
- mapping_scheme.GetThreadsPerBlock());
+ EmitFullWarpShuffleDownLoopForReduce(reducer, shmem_transposed_addr);
// Some warps in the block are completely outside of the bound of the
// tensor, so they should not write any output at all.
@@ -4908,6 +4893,7 @@
CHECK(!reducers.empty()) << " expect at least one reduce instructions.";
const KernelMappingScheme& mapping_scheme =
reduction_info.GetKernelMappingScheme();
+ CHECK_EQ(mapping_scheme.GetThreadsPerBlock() % 32, 0);
LaunchDimensions launch_dimensions(mapping_scheme.GetNumberOfBlocks(),
mapping_scheme.GetThreadsPerBlock());
llvm::Type* index_ty =
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
index e49d2bb..dcc1f13 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
@@ -543,19 +543,10 @@
const Shape& input_shape,
const FusionLayoutAnalysis& layout_analysis);
- // For each reducer, emits the shuffle-down loop to accumulate the partial
- // result to the global result.
- void EmitFullWarpShuffleDownLoopForAllReduces(
- absl::Span<HloComputation* const> reducers,
- absl::Span<llvm::AllocaInst* const> partial_result_addresses,
- int threads_per_block);
-
- // Emits shuffle-down reduction for the `partial_result_address` using the
- // reduction computation `reducer` over types `element_type`.
+ // Emits shuffle-down reduction for the `partial_result_address` using the
+ // reduction computation `reducer`; the element type is derived from the
+ // pointee type of `partial_result_address`.
- void EmitFullWarpShuffleDownLoopForReduce(HloComputation* reducer,
- llvm::Type* element_type,
- llvm::Value* partial_result_address,
- int threads_per_block);
+ void EmitFullWarpShuffleDownLoopForReduce(
+ HloComputation* reducer, llvm::Value* partial_result_address);
StatusOr<std::unique_ptr<Thunk>> BuildKernelThunkImpl(
absl::string_view name, Thunk::ThunkInfo thunk_info,