[XLA:GPU] [NFC] Remove unused function, simplify some checks

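EmitFullWarpShuffleDownLoopForAllReduces had no callers, so it is deleted.
EmitFullWarpShuffleDownLoopForReduce no longer takes `element_type` or
`threads_per_block`: the element type is recovered from the pointer type of
`partial_result_address`, and the multiple-of-32 block-size check is done once
where the launch dimensions are computed.

For illustration only (not part of this change), a minimal standalone sketch of
the type-recovery pattern, assuming typed (non-opaque) LLVM pointers; the
helper below and all names in it are hypothetical, not XLA code:

  #include "llvm/IR/BasicBlock.h"
  #include "llvm/IR/DerivedTypes.h"
  #include "llvm/IR/Function.h"
  #include "llvm/IR/IRBuilder.h"
  #include "llvm/IR/LLVMContext.h"
  #include "llvm/IR/Module.h"
  #include "llvm/Support/Casting.h"

  // Sketch only: ElementTypeOf is a hypothetical helper. Given a pointer-typed
  // llvm::Value (e.g. an alloca holding a partial result), recover the pointee
  // type instead of passing it around separately.
  llvm::Type* ElementTypeOf(llvm::Value* address) {
    return llvm::cast<llvm::PointerType>(address->getType())->getElementType();
  }

  int main() {
    llvm::LLVMContext context;
    llvm::Module module("sketch", context);
    llvm::IRBuilder<> b(context);

    auto* fn_type = llvm::FunctionType::get(b.getVoidTy(), /*isVarArg=*/false);
    auto* fn = llvm::Function::Create(fn_type, llvm::Function::ExternalLinkage,
                                      "f", &module);
    b.SetInsertPoint(llvm::BasicBlock::Create(context, "entry", fn));

    // An f32 partial-result slot; ElementTypeOf recovers the f32 type from it.
    llvm::AllocaInst* partial_result = b.CreateAlloca(b.getFloatTy());
    llvm::Type* element_type = ElementTypeOf(partial_result);
    (void)element_type;
    b.CreateRetVoid();
    return 0;
  }
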
PiperOrigin-RevId: 395585703
Change-Id: I8cffec344febccd0f1f5fe339e8521908ebfa2e1
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index bf46c59..d1699b1 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -3861,23 +3861,12 @@
   return reduction_codegen_state;
 }
 
-void IrEmitterUnnested::EmitFullWarpShuffleDownLoopForAllReduces(
-    absl::Span<HloComputation* const> reducers,
-    absl::Span<llvm::AllocaInst* const> partial_result_addresses,
-    int threads_per_block) {
-  CHECK_EQ(reducers.size(), partial_result_addresses.size());
-  for (int i = 0; i != reducers.size(); i++) {
-    EmitFullWarpShuffleDownLoopForReduce(
-        reducers[i], partial_result_addresses[i]->getType()->getElementType(),
-        partial_result_addresses[i], threads_per_block);
-  }
-}
-
 void IrEmitterUnnested::EmitFullWarpShuffleDownLoopForReduce(
-    HloComputation* reducer, llvm::Type* element_type,
-    llvm::Value* partial_result_address, int threads_per_block) {
+    HloComputation* reducer, llvm::Value* partial_result_address) {
   // This only works when the block size is a multiple of 32 threads.
-  CHECK_EQ(threads_per_block % 32, 0);
+  llvm::Type* element_type =
+      llvm::cast<llvm::PointerType>(partial_result_address->getType())
+          ->getElementType();
   for (int distance = 16; distance >= 1; distance /= 2) {
     int bit_width = llvm_ir::GetSizeInBits(element_type);
     llvm::Value* result_from_other_lane = llvm_ir::EmitAllocaAtFunctionEntry(
@@ -4071,8 +4060,7 @@
   const KernelMappingScheme& mapping_scheme =
       reduction_info.GetKernelMappingScheme();
 
-  EmitFullWarpShuffleDownLoopForReduce(reducer, element_type, current_output,
-                                       mapping_scheme.GetThreadsPerBlock());
+  EmitFullWarpShuffleDownLoopForReduce(reducer, current_output);
   llvm::Value* warp_id =
       b_.CreateUDiv(thread_id_info.thread_id_x, constant(kWarpSize));
   ksl.If("intra_warp_reduce_write", is_zero(thread_id_info.lane_id), [&] {
@@ -4102,9 +4090,8 @@
     llvm::Value* selected_value =
         b_.CreateSelect(warp_exists, block_accum_addr, initial_value_addr);
 
-    EmitFullWarpShuffleDownLoopForReduce(reducer, element_type,
-                                         /*block_accum_addr*/ selected_value,
-                                         mapping_scheme.GetThreadsPerBlock());
+    EmitFullWarpShuffleDownLoopForReduce(reducer,
+                                         /*block_accum_addr*/ selected_value);
     ksl.If("reduction_write_output", is_zero(thread_id_info.thread_id_x), [&] {
       if (reduction_info.IsRaceFree()) {
         VLOG(10) << "Using deterministic reductions: writing out "
@@ -4155,9 +4142,7 @@
        thread_id_info.thread_id_x},
       "shmem_transposed_addr"));
 
-  EmitFullWarpShuffleDownLoopForReduce(reducer, element_type,
-                                       shmem_transposed_addr,
-                                       mapping_scheme.GetThreadsPerBlock());
+  EmitFullWarpShuffleDownLoopForReduce(reducer, shmem_transposed_addr);
 
   // Some warps in the block are completely outside of the bound of the
   // tensor, so they should not write any output at all.
@@ -4908,6 +4893,8 @@
-  CHECK(!reducers.empty()) << " expect at least one reduce instructions.";
+  CHECK(!reducers.empty()) << "Expect at least one reduce instruction.";
   const KernelMappingScheme& mapping_scheme =
       reduction_info.GetKernelMappingScheme();
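+  // The shuffle-down reduction requires a block size that is a multiple of 32.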
+  CHECK_EQ(mapping_scheme.GetThreadsPerBlock() % 32, 0);
   LaunchDimensions launch_dimensions(mapping_scheme.GetNumberOfBlocks(),
                                      mapping_scheme.GetThreadsPerBlock());
   llvm::Type* index_ty =
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
index e49d2bb..dcc1f13 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
@@ -543,19 +543,10 @@
                           const Shape& input_shape,
                           const FusionLayoutAnalysis& layout_analysis);
 
-  // For each reducer, emits the shuffle-down loop to accumulate the partial
-  // result to the global result.
-  void EmitFullWarpShuffleDownLoopForAllReduces(
-      absl::Span<HloComputation* const> reducers,
-      absl::Span<llvm::AllocaInst* const> partial_result_addresses,
-      int threads_per_block);
-
-  // Emits shuffle-down reduction for the `partial_result_address` using the
-  // reduction computation `reducer` over types `element_type`.
+  // Emits the shuffle-down reduction loop for `partial_result_address` using
+  // reduction computation `reducer`; the element type comes from the pointer.
-  void EmitFullWarpShuffleDownLoopForReduce(HloComputation* reducer,
-                                            llvm::Type* element_type,
-                                            llvm::Value* partial_result_address,
-                                            int threads_per_block);
+  void EmitFullWarpShuffleDownLoopForReduce(
+      HloComputation* reducer, llvm::Value* partial_result_address);
 
   StatusOr<std::unique_ptr<Thunk>> BuildKernelThunkImpl(
       absl::string_view name, Thunk::ThunkInfo thunk_info,