[XLA/GPU] Use intra-block communication for row reduction

This improves the precision and performance of row reduction: the new
algorithm needs only one global atomic per block, so up to
<tiling size> * 1024 elements can be reduced without atomics (a block now has
up to kWarpSize * kWarpSize = 1024 threads along x, each accumulating
<tiling size> elements locally). It also lets us reduce the "step size" of the
tree reduction.

Roughly, the algorithm is:

```
  __global__ void reduce(int num_rows, float *in, float *out) {
    __shared__ float cache[32];
    int offset = blockDim.x * blockIdx.x + threadIdx.x;
    if (offset >= num_rows) return;
    int tile_bound = std::min(offset + kTileSizeX, num_rows);

    // Each thread accumulates its elements of the tile locally.
    float accum = 0;
    for (int i = offset; i < tile_bound; i += blockDim.x) {
      accum += in[i];
    }

    // Reduce within each warp via shuffles, then publish one partial result
    // per warp to shared memory.
    accum = warp_reduce(accum);
    if (threadIdx.x % kWarpSize == 0) {
      cache[threadIdx.x / kWarpSize] = accum;
    }
    __syncthreads();

    // The first warp reduces the per-warp partial results; thread 0 issues
    // the block's single global atomic.
    if (threadIdx.x / kWarpSize == 0) {
      bool warp_exists = threadIdx.x < blockDim.x / kWarpSize;
      float block_accum = warp_exists ? cache[threadIdx.x % kWarpSize] : 0;
      block_accum = warp_reduce(block_accum);
      if (threadIdx.x == 0) {
        atomicAdd(out, block_accum);
      }
    }
  }
```
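
The warp_reduce above is the usual shuffle-down reduction within a single
warp; in the emitter this corresponds roughly to what
EmitFullWarpShuffleDownLoopForReduce generates. A minimal CUDA sketch,
assuming a full active mask (the helper itself is not part of this change):

```
  constexpr int kWarpSize = 32;  // warp size on NVIDIA GPUs

  // Shuffle-down tree reduction: after the loop, lane 0 holds the sum of
  // value across all lanes of the warp.
  __device__ float warp_reduce(float value) {
    for (int distance = kWarpSize / 2; distance > 0; distance /= 2) {
      value += __shfl_down_sync(0xffffffff, value, distance);
    }
    return value;
  }
```

Only lane 0 ends up with the full warp sum, which is why the kernel above
guards the shared-memory store and the final atomic on lane 0 / thread 0.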

PiperOrigin-RevId: 295224089
Change-Id: Ifd6e2990d6b7724a7d9a27f13492eddf23b27e82
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
index c535325..8646b4b 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
@@ -133,7 +133,8 @@
       CHECK_EQ(reduction_dimensions.dimensions[0], 1);
       return {tile_z, 1, 16};
     }
-    if (reduction_dimensions.dimensions[2] % (kWarpSize * 64) == 0) {
+    if (reduction_dimensions.dimensions[2] % (kWarpSize * kWarpSize * 64) ==
+        0) {
       return {tile_z, 1, 64};
     }
     return {tile_z, 1, 8};
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index 37204af..c6b167f 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -2127,29 +2127,36 @@
       Store(init_ir_value,
             InBoundsGEP(partial_result_address, {b_.getInt32(i)}));
     }
+    reduction_info->GetMutableInitialValues()->push_back(init_ir_value);
 
-    // Allocate __shared__ cache[num_partial_results][num_threads][num_threads +
-    // 1], where num_threads == num_threads_x == num_threads_y.  The "+1" is
-    // used to avoid bank conflicts.
-    if (!reduction_info->IsRowReduction()) {
-      auto& mapping_scheme = reduction_info->GetKernelMappingScheme();
-      int64 num_threads = mapping_scheme.GetNumThreadsX();
-      CHECK_EQ(num_threads, mapping_scheme.GetNumThreadsY());
-      llvm::Type* primitive_type = llvm_ir::PrimitiveTypeToIrType(
-          reduce_inst->shape().element_type(), module_);
-      llvm::Type* buffer_type = llvm::ArrayType::get(
-          llvm::ArrayType::get(
-              llvm::ArrayType::get(primitive_type, num_threads + 1),
-              num_threads),
-          num_partial_results);
-
-      llvm::GlobalVariable* shared_cache_per_reduce =
-          llvm_ir::AllocateSharedMemoryTile(b_.GetInsertBlock()->getModule(),
-                                            buffer_type,
-                                            absl::StrCat("shared_cache_", i));
-      reduction_info->GetMutableSharedCache()->push_back(
-          shared_cache_per_reduce);
-    }
+    auto& mapping_scheme = reduction_info->GetKernelMappingScheme();
+    int64 num_threads_x = mapping_scheme.GetNumThreadsX();
+    llvm::Type* primitive_type = llvm_ir::PrimitiveTypeToIrType(
+        reduce_inst->shape().element_type(), module_);
+    llvm::Type* buffer_type = [&] {
+      if (reduction_info->IsRowReduction()) {
+        // Allocate __shared__ cache[num_partial_results][num_threads].
+        return llvm::ArrayType::get(
+            llvm::ArrayType::get(primitive_type, num_threads_x),
+            num_partial_results);
+      } else {
+        // Allocate __shared__
+        // cache[num_partial_results][num_threads][num_threads + 1], where
+        // num_threads == num_threads_x == num_threads_y.  The "+1" is used to
+        // avoid bank conflicts.
+        CHECK_EQ(num_threads_x, mapping_scheme.GetNumThreadsY());
+        return llvm::ArrayType::get(
+            llvm::ArrayType::get(
+                llvm::ArrayType::get(primitive_type, num_threads_x + 1),
+                num_threads_x),
+            num_partial_results);
+      }
+    }();
+    llvm::GlobalVariable* shared_cache_per_reduce =
+        llvm_ir::AllocateSharedMemoryTile(b_.GetInsertBlock()->getModule(),
+                                          buffer_type,
+                                          absl::StrCat("shared_cache_", i));
+    reduction_info->GetMutableSharedCache()->push_back(shared_cache_per_reduce);
   }
 }
 
@@ -2241,13 +2248,6 @@
   int num_reduces = reducers.size();
   absl::Span<llvm::AllocaInst* const> partial_result_addresses =
       reduction_info.GetPartialResultAddresses();
-  if (reduction_info.IsRowReduction()) {
-    EmitFullWarpShuffleDownLoopForAllReduces(reducers,
-                                             partial_result_addresses);
-    llvm_ir::LlvmIfData if_lane_id_is_zero_data = llvm_ir::EmitIfThenElse(
-        ICmpEQ(thread_id_info.lane_id, constant(0)), "lane_id_is_zero", &b_);
-    llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block, &b_);
-  }
 
   int num_partial_results = GetNumberOfPartialResults(reduction_info);
 
@@ -2285,25 +2285,67 @@
                                   element_index.GetType());
       llvm::Value* output_address = output_array.EmitArrayElementAddress(
           output_index, &b_, "output_element_address");
-
       llvm::Value* current_output = b_.CreateInBoundsGEP(
           partial_result_addresses[i], {constant(j)}, "current_output");
 
+      llvm::GlobalVariable* shared_cache = reduction_info.GetSharedCache()[i];
+
+      // __shared__ memory uses a different address space, so we cast it to
+      // global address space before writing or reading.
+      auto shared_to_global = [&](llvm::Value* input, llvm::Twine name = "") {
+        return b_.CreateAddrSpaceCast(
+            input,
+            llvm::PointerType::get(input->getType()->getPointerElementType(),
+                                   /*AddressSpace=*/0),
+            name);
+      };
+
+      auto is_zero = [&](llvm::Value* value) {
+        return b_.CreateICmpEQ(value, constant(0));
+      };
+
+      KernelSupportLibrary ksl(&b_);
+      llvm::Type* element_type =
+          partial_result_addresses[i]->getType()->getElementType();
       if (reduction_info.IsRowReduction()) {
-        TF_CHECK_OK(EmitAtomicOperationForNestedComputation(
-            *reducers[i], output_address, current_output));
+        EmitFullWarpShuffleDownLoopForReduce(reducers[i], element_type,
+                                             current_output);
+        llvm::Value* warp_id =
+            b_.CreateUDiv(thread_id_info.thread_id_x, constant(kWarpSize));
+        ksl.If(is_zero(thread_id_info.lane_id), [&] {
+          llvm::Value* shmem_output_addr =
+              shared_to_global(b_.CreateInBoundsGEP(
+                  shared_cache, {b_.getInt32(0), constant(j), warp_id}));
+          b_.CreateStore(b_.CreateLoad(current_output), shmem_output_addr);
+        });
+
+        EmitSyncThreads();
+        ksl.If(is_zero(warp_id), [&] {
+          llvm::Value* block_accum_addr = shared_to_global(b_.CreateInBoundsGEP(
+              shared_cache,
+              {b_.getInt32(0), constant(j), thread_id_info.lane_id}));
+          llvm::Value* initial_value = reduction_info.GetInitialValues()[i];
+          llvm::Value* initial_value_addr = b_.CreateAlloca(element_type);
+          b_.CreateStore(initial_value, initial_value_addr);
+
+          llvm::Value* warp_exists = b_.CreateICmpULT(
+              thread_id_info.thread_id_x,
+              constant(mapping_scheme.GetNumThreadsX() / kWarpSize));
+
+          llvm::Value* selected_value = b_.CreateSelect(
+              warp_exists, block_accum_addr, initial_value_addr);
+
+          EmitFullWarpShuffleDownLoopForReduce(
+              reducers[i], element_type,
+              /*block_accum_addr*/ selected_value);
+          ksl.If(is_zero(thread_id_info.thread_id_x), [&] {
+            TF_CHECK_OK(EmitAtomicOperationForNestedComputation(
+                *reducers[i], output_address, block_accum_addr));
+          });
+        });
+
       } else {
-        llvm::GlobalVariable* shared_cache = reduction_info.GetSharedCache()[i];
-        auto addr_cast = [&](llvm::Value* input, llvm::Twine name = "") {
-          // __shared__ memory uses a different address space, so we cast it to
-          // global address space before writing or reading.
-          return b_.CreateAddrSpaceCast(
-              input,
-              llvm::PointerType::get(input->getType()->getPointerElementType(),
-                                     /*AddressSpace=*/0),
-              name);
-        };
-        llvm::Value* shmem_output_addr = addr_cast(
+        llvm::Value* shmem_output_addr = shared_to_global(
             b_.CreateInBoundsGEP(shared_cache, {b_.getInt32(0), constant(j),
                                                 thread_id_info.thread_id_x,
                                                 thread_id_info.thread_id_y}),
@@ -2311,19 +2353,18 @@
         llvm::Value* current_output_value = b_.CreateLoad(current_output);
         b_.CreateStore(current_output_value, shmem_output_addr);
 
-        EmitCallToTargetIntrinsic(TargetIntrinsicID::kBarrierId, {}, {}, &b_);
+        EmitSyncThreads();
 
         // Get transposed element from shared memory.
-        llvm::Value* shmem_transposed_addr = addr_cast(b_.CreateInBoundsGEP(
-            shared_cache,
-            {b_.getInt32(0), constant(j), thread_id_info.thread_id_y,
-             thread_id_info.thread_id_x},
-            "shmem_transposed_addr"));
+        llvm::Value* shmem_transposed_addr =
+            shared_to_global(b_.CreateInBoundsGEP(
+                shared_cache,
+                {b_.getInt32(0), constant(j), thread_id_info.thread_id_y,
+                 thread_id_info.thread_id_x},
+                "shmem_transposed_addr"));
 
-        EmitFullWarpShuffleDownLoopForReduce(
-            reducers[i],
-            partial_result_addresses[i]->getType()->getElementType(),
-            shmem_transposed_addr);
+        EmitFullWarpShuffleDownLoopForReduce(reducers[i], element_type,
+                                             shmem_transposed_addr);
 
         // Some threads in the block are completely outside of the bound of the
         // tensor, so they should not write any output at all.
@@ -2335,13 +2376,10 @@
             b_.CreateICmpULT(thread_id_info.thread_id_x,
                              tiling_kernel_info.output_tile_bounds[kDimY]));
 
-        KernelSupportLibrary ksl(&b_);
-        ksl.If(b_.CreateAnd(has_output, b_.CreateICmpEQ(thread_id_info.lane_id,
-                                                        constant(0))),
-               [&] {
-                 TF_CHECK_OK(EmitAtomicOperationForNestedComputation(
-                     *reducers[i], output_address, shmem_transposed_addr));
-               });
+        ksl.If(b_.CreateAnd(has_output, is_zero(thread_id_info.lane_id)), [&] {
+          TF_CHECK_OK(EmitAtomicOperationForNestedComputation(
+              *reducers[i], output_address, shmem_transposed_addr));
+        });
       }
     }
   }
@@ -3050,7 +3088,16 @@
   }
 
   int64 num_threads_y = reduction_dimensions.is_row_reduction ? 1 : kWarpSize;
-  int64 num_threads_x = kWarpSize;
+  int64 num_threads_x = [&] {
+    if (reduction_dimensions.is_row_reduction) {
+      return std::min(
+          kWarpSize * kWarpSize,
+          RoundUpToNearest(CeilOfRatio(reduction_dimensions.dimensions[2],
+                                       reduction_tiling[2]),
+                           kWarpSize));
+    }
+    return kWarpSize;
+  }();
 
   KernelMappingScheme mapping_scheme(
       reduction_dimensions.dimensions,
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
index 3d6d095..a49f8c4 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
@@ -181,11 +181,38 @@
 
   // Generates code for reduction to contiguous dimensions.
   //
-  // TODO(cheshire): Pseudocode for row reduction.
-  // Column reduction uses the following algorithm described in CUDA-like
+  // Row reduction uses the following algorithm described in CUDA-like
   // pseudocode:
   //
   // ```
+  //  __global__ void reduce(int num_rows, float *in, float *out) {
+  //    __shared__ float cache[32];
+  //    int offset = blockDim.x * blockIdx.x + threadIdx.x;
+  //    if (offset >= num_rows) return;
+  //    int tile_bound = std::min(offset + kTileSizeX, num_rows);
+  //    float accum = 0;
+  //    for (int i = offset; i < tile_bound; i += blockDim.x) {
+  //      accum += in[i];
+  //    }
+  //    accum = warp_reduce(accum);
+  //    if (threadIdx.x % kWarpSize == 0) {
+  //      cache[threadIdx.x / kWarpSize] = accum;
+  //    }
+  //    __syncthreads();
+  //    if (threadIdx.x / kWarpSize == 0) {
+  //      bool warp_exists = threadIdx.x < (blockDim.x / kWarpSize);
+  //      float block_accum = warp_exists ? cache[threadIdx.x % kWarpSize] : 0;
+  //      block_accum = warp_reduce(block_accum);
+  //      if (threadIdx.x == 0) {
+  //        atomicAdd(out, block_accum);
+  //      }
+  //    }
+  //  }
+  // ```
+  //
+  // Column reduction uses the following algorithm:
+  //
+  // ```
   // void reduce(float** in, float* out) {
   //   __shared__ float[32][33] cache;
   //   int thread_id = GetThreadId();
diff --git a/tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h b/tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h
index d4183ef..eeab8d4 100644
--- a/tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h
+++ b/tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h
@@ -170,6 +170,14 @@
     return &reduction_input_addresses_;
   }
 
+  std::vector<llvm::Value*>* GetMutableInitialValues() {
+    return &initial_values_;
+  }
+
+  absl::Span<llvm::Value* const> GetInitialValues() const {
+    return initial_values_;
+  }
+
   // Returns the address of the input element to perform the reduction with.
   absl::Span<llvm::AllocaInst* const> GetReductionInputAddresses() const {
     return reduction_input_addresses_;
@@ -189,6 +197,7 @@
 
  private:
   std::vector<llvm::GlobalVariable*> shared_cache_;
+  std::vector<llvm::Value*> initial_values_;
   const KernelMappingScheme mapping_scheme_;
   AddressVector partial_result_addresses_;
   AddressVector reduction_input_addresses_;