[XLA:GPU] Do not generate inter-warp shuffle reduce in row reduction if not required.

PiperOrigin-RevId: 465363219
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index 023f9ca..3897d0b 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -4058,9 +4058,16 @@
       selected_values.push_back({selected_value, element_type});
     }
 
-    EmitFullWarpShuffleDownLoopForReduce(reducer,
-                                         absl::MakeSpan(selected_values),
-                                         tiling_scheme.GetNumThreadsPerBlock());
+    // If only one warp is present in the block, then we don't need inter-warp
+    // reduction.
+    // TODO(b/241414088) If only warp is present, then inter-warp communication
+    // using shared memory and synchronization using barrier is also unnecessary
+    // and should be removed.
+    if (tiling_scheme.GetNumThreadsPerBlock() > WarpSize()) {
+      EmitFullWarpShuffleDownLoopForReduce(
+          reducer, absl::MakeSpan(selected_values),
+          tiling_scheme.GetNumThreadsPerBlock());
+    }
 
     ksl.If("reduction_write_output", is_zero(thread_id_info.thread_id_x), [&] {
       for (int oidx = 0; oidx < num_outputs; oidx++) {
diff --git a/tensorflow/compiler/xla/service/gpu/tests/reduction_vectorization_test.cc b/tensorflow/compiler/xla/service/gpu/tests/reduction_vectorization_test.cc
index 0b59468..9717f70 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/reduction_vectorization_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/tests/reduction_vectorization_test.cc
@@ -112,10 +112,18 @@
   ROOT %reduce.277 = f32[6400,4,8]{2,1,0} reduce(f32[6400,4,8,32]{3,2,1,0} %param_0, f32[] %constant_0), dimensions={3}, to_apply=%search_fn
 }
 )";
+
   const char* expected_optimized_llvm_ir = R"(
 CHECK:  %[[thread_id:.*]] = tail call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
 CHECK:  %[[masked_thread_id:.*]] = and i32 %[[thread_id]], 31
+// Verify that there is no comparison masking half the warp.
 CHECK-NOT: icmp ult i32 %[[masked_thread_id]], 16
+// Verify that we only do one warp reducton by checking that there are 6
+// shfl.sync corresponding to 1 declaration and 5 shuffle instructions.  The
+// second warp reduction was originally produced for inter-warp reduction
+// which we have now optimized away.
+CHECK-COUNT-6: llvm.nvvm.shfl.sync.down.f32
+CHECK-NOT: llvm.nvvm.shfl.sync.down.f32
 )";
 
   CompileAndVerifyIr(hlo_text, expected_optimized_llvm_ir, true);
@@ -123,6 +131,7 @@
   // Check that there is a single scalar load.
   const char* expected_ptx = R"(
 CHECK: ld.global.nc.f32
+CHECK: shfl.sync.down
 CHECK-NOT: ld.global.nc.f32
 CHECK-NOT: ld.global.v2.f32
 )";