Roll forward with a fix

PiperOrigin-RevId: 304156244
Change-Id: Ib14f8613aa5de72de6f2f8117d98b73d6ed5e297
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index 528a847..089f604 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -106,6 +106,11 @@
 const auto kDimZ = KernelMappingScheme::DimZ;
 const auto kDimTot = KernelMappingScheme::DimTot;
 
+const auto kLinearIndexingX = KernelMappingScheme::LinearIndexingX;
+const auto kStridedIndexingX = KernelMappingScheme::StridedIndexingX;
+const auto kLinearStridedIndexingX =
+    KernelMappingScheme::LinearStridedIndexingX;
+
 // If a dimensions is smaller than this, untiled transposition may be more
 // efficient.
 const int64 kMinDimensionToTransposeTiled = 16;
@@ -1863,9 +1868,8 @@
 bool MayPreventVectorization(const HloInstruction& hlo) {
   if (hlo.opcode() == HloOpcode::kFusion) {
     return absl::c_any_of(hlo.fused_instructions_computation()->instructions(),
-                          [](const HloInstruction* instr) {
+                          [&](const HloInstruction* instr) {
                             switch (instr->opcode()) {
-                              case HloOpcode::kReduce:
                               case HloOpcode::kReduceWindow:
                               case HloOpcode::kSort:
                               case HloOpcode::kDot:
@@ -1892,6 +1896,10 @@
       default:
         return false;
     }
+  } else if (hlo.opcode() == HloOpcode::kReduce) {
+    // TODO(nouiz): check if the to_apply() attribute contains instruction
+    // that break LLVM vectorization.
+    return false;
   }
   return true;
 }
@@ -1920,13 +1928,59 @@
                                     llvm::Value* thread_id_x,
                                     llvm::Type* index_ty,
                                     llvm::IRBuilder<>* b) {
-  if (mapping_scheme.DilatedX()) {
+  auto constant = [&](int64 val) {
+    return llvm::ConstantInt::get(index_ty, val);
+  };
+  if (mapping_scheme.GetIndexingOrder() == kStridedIndexingX) {
     return thread_id_x;
+  } else if (mapping_scheme.GetIndexingOrder() == kLinearStridedIndexingX) {
+    return b->CreateMul(thread_id_x, constant(mapping_scheme.GetVectorSize()));
   }
+  CHECK_EQ(mapping_scheme.GetIndexingOrder(), kLinearIndexingX);
   int64 x_num_steps =
       mapping_scheme.GetTileSizeX() / mapping_scheme.GetNumThreadsX();
-  return b->CreateMul(thread_id_x,
-                      llvm::ConstantInt::get(index_ty, x_num_steps));
+  return b->CreateMul(thread_id_x, constant(x_num_steps));
+}
+
+// Calls `emit_elem_function()` `x_num_steps` times.  If
+// `vector_size`==1, then each element index passed to
+// `emit_elem_function()` will be separated by `step_x`. If `vector_size`>1,
+// then it must be a multiple of `x_num_steps`.  In that case, it
+// triggers a different indexing order that is vectorizable by
+// LLVM. It generates many groups of calls to `emit_elem_function`. Each
+// group is separated by `step_x` elements.  Inside a group, elements
+// are consecutive. If `check_x_tile_bounds` is true, then it will check
+// if the element index is in bound compared to `tile_width` before
+// calling `emit_elem_function`.
+static void UnrollInnerTileLoop(
+    bool check_x_tile_bounds, int64 x_num_steps, int64 step_x,
+    int64 vector_size, const string& loop_name, KernelSupportLibrary* ksl,
+    llvm::Value* start_offset_x, llvm::Value* y_loc, llvm::Value* tile_width,
+    const IrArray::Index& source_idx, llvm::IRBuilder<>* b,
+    const IrEmitterUnnested::EmitElementFunction* emit_elem_function) {
+  llvm::Type* index_ty = tile_width->getType();
+  auto constant = [&](int64 val) {
+    return llvm::ConstantInt::get(index_ty, val);
+  };
+  for (int64 j = 0; j < x_num_steps / vector_size; j++) {
+    for (int64 i = 0; i < vector_size; i++) {
+      int64 linear_index = j * vector_size + i;
+      llvm::Value* x_loc = b->CreateAdd(constant(j * step_x * vector_size + i),
+                                        start_offset_x, "x_loc");
+      IrArray::Index source_idx_x =
+          source_idx.AddOffsetToDim(y_loc, kDimY, b)
+              .AddOffsetToDim(constant(j * step_x * vector_size + i), kDimX, b);
+      auto emit_element = [&] {
+        return (*emit_elem_function)(source_idx_x, y_loc, x_loc, linear_index);
+      };
+      if (check_x_tile_bounds) {
+        ksl->If(loop_name + "_x_in_tile", b->CreateICmpULT(x_loc, tile_width),
+                emit_element);
+      } else {
+        emit_element();
+      }
+    }
+  }
 }
 
 void IrEmitterUnnested::EmitTile(
@@ -1951,7 +2005,9 @@
   // of threads.
   // Otherwise, the stride is one, but we multiply each offset by the limit of
   // number of steps which can be made.
-  int64 step_x = mapping_scheme.DilatedX() ? num_threads_x : 1;
+  int64 step_x =
+      mapping_scheme.GetIndexingOrder() == kLinearIndexingX ? 1 : num_threads_x;
+  int64 vector_size = mapping_scheme.GetVectorSize();
 
   IrArray::Index source_idx =
       tile_origin_index.AddOffsetToDim(start_offset_x, kDimX, &b_);
@@ -1987,21 +2043,29 @@
              llvm::Value* y_loc =
                  b_.CreateAdd(thread_id_info.thread_id_y,
                               b_.CreateMul(y_indvar, num_threads_y));
-             for (int64 j = 0; j < x_num_steps; j++) {
-               llvm::Value* x_loc =
-                   b_.CreateAdd(constant(j * step_x), start_offset_x, "x_loc");
-               IrArray::Index source_idx_x =
-                   source_idx.AddOffsetToDim(y_loc, kDimY, &b_)
-                       .AddOffsetToDim(constant(j * step_x), kDimX, &b_);
-               auto emit_element = [&] {
-                 return emit_elem_function(source_idx_x, y_loc, x_loc, j);
-               };
-               if (!x_tile_fits) {
-                 ksl->If(loop_name + "_x_in_tile",
-                         b_.CreateICmpULT(x_loc, tile_width), emit_element);
-               } else {
-                 emit_element();
-               }
+             auto unrollInnerTileLoop = [&](bool check_x_tile_bounds) {
+               return UnrollInnerTileLoop(check_x_tile_bounds, x_num_steps,
+                                          step_x, vector_size, loop_name, ksl,
+                                          start_offset_x, y_loc, tile_width,
+                                          source_idx, &b_, &emit_elem_function);
+             };
+
+             // Only take this path when we unroll in a way vectorizable by
+             // LLVM. Special case when the tile doesn't fit completely for even
+             // row size. For odd row size every other row isn't aligned to the
+             // vectorized size, so it can't be vectorized by LLVM.
+             if (!x_tile_fits &&
+                 mapping_scheme.GetIndexingOrder() == kLinearStridedIndexingX) {
+               ksl->If(
+                   loop_name + "_is_full_tile",
+                   // For the last block, tile_width will be the number of
+                   // elements left.
+                   b_.CreateICmpEQ(constant(mapping_scheme.GetTileSizeX()),
+                                   tile_width),
+                   [&] { unrollInnerTileLoop(/*check_x_tile_bounds=*/false); },
+                   [&] { unrollInnerTileLoop(/*check_x_tile_bounds=*/true); });
+             } else {
+               unrollInnerTileLoop(/*check_x_tile_bounds=*/!x_tile_fits);
              }
            });
 }
@@ -2035,6 +2099,19 @@
     const Shape& unnormalized_shape, llvm::IRBuilder<>* b_,
     const KernelMappingScheme& kernel_mapping_scheme) {
   DCHECK_EQ(normalized_shape_index.size(), 3);
+  // If the normalization only add a new dimensions of size 1,
+  // generate simpler indexing. LLVM doesn't always simplify the more
+  // complicated indexing and this prevents it from vectorizing some
+  // cases. We do this only for major_to_minor memory layout.
+  if (unnormalized_shape.rank() == 2 && unnormalized_shape.has_layout() &&
+      unnormalized_shape.dimensions()[0] == normalized_shape_index.dims()[1] &&
+      unnormalized_shape.dimensions()[1] == normalized_shape_index.dims()[2] &&
+      unnormalized_shape.layout().minor_to_major(1) == 0) {
+    DCHECK_EQ(normalized_shape_index.dims()[0], 1);
+    auto multidim = normalized_shape_index.multidim();
+    return IrArray::Index({multidim[1], multidim[2]}, unnormalized_shape,
+                          normalized_shape_index.GetType());
+  }
   llvm::Value* linear = normalized_shape_index.Linearize(
       kernel_mapping_scheme.GetDimsInElems(), b_);
   return IrArray::Index(linear, unnormalized_shape, b_);
@@ -2077,21 +2154,6 @@
   }
 }
 
-// Gets the number of partial results accumulated by a single thread performing
-// reduction.
-static int GetNumberOfPartialResults(
-    const ReductionCodegenInfo& reduction_info) {
-  const KernelMappingScheme& mapping_scheme =
-      reduction_info.GetKernelMappingScheme();
-  if (reduction_info.IsRowReduction()) {
-    return 1;
-  }
-  int64 num_partial_results = mapping_scheme.DilatedX() ? 1 : 2;
-  CHECK_EQ(num_partial_results,
-           (mapping_scheme.GetTileSizeX() / mapping_scheme.GetNumThreadsX()));
-  return num_partial_results;
-}
-
 void IrEmitterUnnested::EmitPrologueForReduction(
     HloInstruction* unnested_hlo, ReductionCodegenInfo* reduction_info,
     absl::Span<HloInstruction* const> reduce_instructions,
@@ -2118,7 +2180,7 @@
     llvm::AllocaInst* reduction_input_address = Alloca(element_type);
     reduction_input_addresses->push_back(reduction_input_address);
 
-    int num_partial_results = GetNumberOfPartialResults(*reduction_info);
+    int num_partial_results = reduction_info->GetNumPartialResults();
     AddressVector* partial_result_addresses =
         reduction_info->GetMutablePartialResultAddresses();
     llvm::AllocaInst* partial_result_address =
@@ -2270,7 +2332,7 @@
   absl::Span<llvm::AllocaInst* const> partial_result_addresses =
       reduction_info.GetPartialResultAddresses();
 
-  int num_partial_results = GetNumberOfPartialResults(reduction_info);
+  int num_partial_results = reduction_info.GetNumPartialResults();
 
   // Emit an atomic operation that accumulates the partial reduction to the
   // output element. For row reduction, this is only for lane 0 due to the
@@ -2484,7 +2546,7 @@
   // GetElementPointer with array types. This enables the vectorization of
   // the computation for different partial results. Use this index if
   // 'num_partial_results > 1'.
-  int num_partial_results = GetNumberOfPartialResults(reduction_info);
+  int num_partial_results = reduction_info.GetNumPartialResults();
   auto index_without_linear = IrArray::Index(
       input_index.multidim(), reduction_operand_shape, input_index.GetType());
 
@@ -2670,7 +2732,8 @@
                                      /*tile_sizes=*/{1, kWarpSize, kWarpSize},
                                      /*num_threads_y=*/kNumRows,
                                      /*num_threads_x=*/kWarpSize,
-                                     /*is_dilated_x=*/false);
+                                     /*indexing_order=*/kLinearIndexingX,
+                                     /*vector_size=*/1);
   LaunchDimensions launch_dimensions(mapping_scheme.GetNumberOfBlocks(),
                                      mapping_scheme.GetThreadsPerBlock());
   llvm::Type* index_type =
@@ -3111,15 +3174,6 @@
   std::array<int64, 3> reduction_tiling =
       GetReductionTiling(reduction_dimensions, smallest_input_dtype_bits,
                          &ir_emitter_context_->device_description());
-  bool dilated_x =
-      reduction_dimensions.is_row_reduction ||
-      !IsUnrollingColumnReductionBeneficial(unnested_hlo, input_shape,
-                                            reduction_dimensions.dimensions[2]);
-
-  if (!dilated_x && !reduction_dimensions.is_row_reduction) {
-    // Vectorized loads: a single thread reduces two adjacent columns.
-    reduction_tiling[2] *= 2;
-  }
 
   int64 num_threads_y = reduction_dimensions.is_row_reduction ? 1 : kWarpSize;
   int64 num_threads_x = [&] {
@@ -3133,12 +3187,54 @@
     return kWarpSize;
   }();
 
+  bool tile_fit = reduction_dimensions.dimensions[kDimX] %
+                      (reduction_tiling[2] * num_threads_x) ==
+                  0;
+
+  int cc_major = 0, cc_minor = 0;
+  ir_emitter_context_->device_description().cuda_compute_capability(&cc_major,
+                                                                    &cc_minor);
+
+  int num_partial_results = 1;
+  KernelMappingScheme::IndexingOrder indexing_order = [&]() {
+    if (reduction_dimensions.is_row_reduction &&
+        // P100, only try to vectorize+coales memory access when the
+        // tile size fits exactly and dtypes <= 32 bits
+        ((cc_major == 6 && smallest_input_dtype_bits <= 32 && tile_fit) ||
+         // On V100, only try to vectorize+coales memory access for
+         // rows of even size.  For odd row sizes, every other row
+         // isn't aligned, so it can't be vectorized.
+         (cc_major >= 7 && reduction_dimensions.dimensions[2] % 2 == 0))) {
+      return kLinearStridedIndexingX;
+    } else if (!reduction_dimensions.is_row_reduction &&
+               IsUnrollingColumnReductionBeneficial(
+                   unnested_hlo, input_shape,
+                   reduction_dimensions.dimensions[2])) {
+      num_partial_results = 2;
+      reduction_tiling[2] *= num_partial_results;
+      return kLinearIndexingX;
+    } else {
+      return kStridedIndexingX;
+    }
+  }();
+
+  int vector_size = 1;
+  if (indexing_order == kLinearStridedIndexingX) {
+    if (reduction_dimensions.dimensions[2] % 2 == 0 &&
+        // Assuming XLA will perform the unrolling and LLVM will vectorize,
+        // disable the unroll for the cases that LLVM doesn't vectorize.
+        !MayPreventVectorization(*unnested_hlo)) {
+      vector_size = 2;
+    } else {
+      indexing_order = kStridedIndexingX;
+    }
+  }
   KernelMappingScheme mapping_scheme(
       reduction_dimensions.dimensions,
       {reduction_tiling[0], reduction_tiling[1] * num_threads_y,
        reduction_tiling[2] * num_threads_x},
-      num_threads_y, num_threads_x, dilated_x);
-  return ReductionCodegenInfo(mapping_scheme,
+      num_threads_y, num_threads_x, indexing_order, vector_size);
+  return ReductionCodegenInfo(mapping_scheme, num_partial_results,
                               reduction_dimensions.is_row_reduction);
 }
 
diff --git a/tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h b/tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h
index eeab8d4..cd690c9 100644
--- a/tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h
+++ b/tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h
@@ -76,19 +76,33 @@
 class KernelMappingScheme {
  public:
   enum { DimZ = 0, DimY, DimX, DimTot };
+  enum IndexingOrder {
+    // Thread reads consecutive elements.
+    LinearIndexingX,
+    // Thread reads strided elements while keeping memory coalescing.
+    StridedIndexingX,
+    // Thread reads a few consecutive elements then take a strided
+    // step. This can trigger vectorized reads and keep memory
+    // coalescing.
+    LinearStridedIndexingX
+  };
+
   KernelMappingScheme(absl::Span<const int64> dims_in_elems,
                       absl::Span<const int64> tile_sizes, int64 num_threads_y,
-                      int64 num_threads_x, bool is_dilated_x)
+                      int64 num_threads_x, IndexingOrder indexing_order,
+                      int vector_size)
       : dims_in_elems_{dims_in_elems[0], dims_in_elems[1], dims_in_elems[2]},
         tile_sizes_{tile_sizes[0], tile_sizes[1], tile_sizes[2]},
         num_threads_x_(num_threads_x),
         num_threads_y_(num_threads_y),
-        dilated_x_(is_dilated_x) {
+        indexing_order_(indexing_order),
+        vector_size_(vector_size) {
     CHECK_EQ(tile_sizes[1] % num_threads_y_, 0);
     CHECK_EQ(tile_sizes[2] % num_threads_x_, 0);
     VLOG(10) << "dims_in_elems_ = " << absl::StrJoin(dims_in_elems_, ",");
-    if (!dilated_x_) {
-      // dilated_x_=false is for the purpose of vectorization, which requires
+    if (indexing_order != LinearIndexingX) {
+      // StridedIndexingX, and LinearStridedIndexingX
+      // is for the purpose of vectorization, which requires
       // GetTileSizeFor(DimX) to be a multiplier of num_threads_x_.
       CHECK_EQ(GetTileSizeFor(DimX) % num_threads_x_, 0);
     }
@@ -118,7 +132,8 @@
     return GetNumThreadsX() * GetNumThreadsY();
   }
 
-  bool DilatedX() const { return dilated_x_; }
+  IndexingOrder GetIndexingOrder() const { return indexing_order_; }
+  int GetVectorSize() const { return vector_size_; }
 
  private:
   // The number of elements in each dimension.
@@ -133,12 +148,17 @@
   // Number of threads used to process elements in the Y direction of a tile.
   const int64 num_threads_y_;
 
-  // When num_threads_x threads process a total of tile_size_x elements in the
-  // X dimension of a tile, each threads process n=tile_size_x/num_threads_x
-  // elements. When dilated_x=false, the n elements processed by a thread are
-  // contiguous. On the other hand, when dilated_x=true the n elements are
-  // dilated by a factor of num_threads_x.
-  const bool dilated_x_;
+  // When num_threads_x threads process a total of tile_size_x
+  // elements in the X dimension of a tile, each threads process
+  // n=tile_size_x/num_threads_x elements.
+  // indexing_order defines which tile's elements each thread reads.
+  const IndexingOrder indexing_order_;
+
+  // vector_size_ only supported for row reduction and must be a divisor
+  // of tile_sizes_[2]/num_threads_x.  Interesting values are 2 and 4
+  // to trigger vectorized loads on GPUs while keeping memory
+  // coalescing.
+  const int vector_size_;
 };
 
 // Information to support the code generation for a tiled reduction kernel.
@@ -146,8 +166,15 @@
 class ReductionCodegenInfo {
  public:
   explicit ReductionCodegenInfo(KernelMappingScheme mapping_scheme,
-                                bool is_row_reduction)
-      : mapping_scheme_(mapping_scheme), is_row_reduction_(is_row_reduction) {}
+                                int num_partial_results, bool is_row_reduction)
+      : mapping_scheme_(mapping_scheme),
+        num_partial_results_(num_partial_results),
+        is_row_reduction_(is_row_reduction) {
+    if (num_partial_results > 1) {
+      CHECK_EQ(num_partial_results, (mapping_scheme.GetTileSizeX() /
+                                     mapping_scheme.GetNumThreadsX()));
+    }
+  }
 
   const KernelMappingScheme& GetKernelMappingScheme() const {
     return mapping_scheme_;
@@ -183,6 +210,7 @@
     return reduction_input_addresses_;
   }
 
+  int GetNumPartialResults() const { return num_partial_results_; }
   bool IsRowReduction() const { return is_row_reduction_; }
 
   // Gets a pointer to a mutable shared cache used by reduction.
@@ -201,6 +229,7 @@
   const KernelMappingScheme mapping_scheme_;
   AddressVector partial_result_addresses_;
   AddressVector reduction_input_addresses_;
+  int num_partial_results_;
   bool is_row_reduction_;
 };
 
diff --git a/tensorflow/compiler/xla/service/gpu/tests/BUILD b/tensorflow/compiler/xla/service/gpu/tests/BUILD
index 1fd51c7..e04dba4 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/tests/BUILD
@@ -165,6 +165,33 @@
 )
 
 tf_cc_test(
+    name = "reduction_vectorization_test",
+    srcs = [
+        "reduction_vectorization_test.cc",
+    ],
+    tags = tf_cuda_tests_tags() + ["no_rocm"],
+    deps = [
+        ":gpu_codegen_test",
+        "//tensorflow/compiler/xla:debug_options_flags",
+        "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/compiler/xla/service:gpu_plugin",
+        "//tensorflow/compiler/xla/service:hlo",
+        "//tensorflow/compiler/xla/service:hlo_module_config",
+        "//tensorflow/compiler/xla/service:hlo_parser",
+        "//tensorflow/compiler/xla/service/gpu:gemm_rewriter",
+        "//tensorflow/compiler/xla/service/gpu:gpu_executable",
+        "//tensorflow/compiler/xla/tests:filecheck",
+        "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:llvm_irgen_test_base",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/stream_executor/lib",
+        "@com_google_absl//absl/memory",
+    ],
+)
+
+tf_cc_test(
     name = "reduction_dimension_grouper_test",
     srcs = [
         "reduction_dimension_grouper_test.cc",
diff --git a/tensorflow/compiler/xla/service/gpu/tests/reduction_vectorization_test.cc b/tensorflow/compiler/xla/service/gpu/tests/reduction_vectorization_test.cc
new file mode 100644
index 0000000..5f27df0
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/tests/reduction_vectorization_test.cc
@@ -0,0 +1,299 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <utility>
+
+#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
+#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module_config.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/tests/filecheck.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/stream_executor/lib/statusor.h"
+
+namespace xla {
+namespace gpu {
+
+namespace {
+
+class ReductionVectorizationTest : public GpuCodegenTest {};
+
+TEST_F(ReductionVectorizationTest, Power2) {
+  const char* hlo_text = R"(
+HloModule ReducePower2
+
+%max_ {
+  %x = f32[] parameter(0)
+  %y = f32[] parameter(1)
+  ROOT %maximum.7 = f32[] maximum(f32[] %x, f32[] %y)
+}
+
+ENTRY %main {
+  %param_0 = f32[5,131072] parameter(0)
+  %constant.3 = f32[] constant(0)
+  ROOT %reduce.8 = f32[5] reduce(f32[5,131072] %param_0, f32[] %constant.3), dimensions={1}, to_apply=%max_
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> optimized_module,
+                          ParseAndReturnVerifiedModule(hlo_text));
+  se::StreamExecutor* executor = backend().default_stream_executor();
+  int cc_major = 0, cc_minor = 0;
+  executor->GetDeviceDescription().cuda_compute_capability(&cc_major,
+                                                           &cc_minor);
+  string expected_ptx;
+  if (cc_major >= 6) {
+    expected_ptx = R"(
+CHECK: ld.global.nc.v2.f32
+CHECK: ld.global.nc.v2.f32
+CHECK: ld.global.nc.v2.f32
+CHECK: ld.global.nc.v2.f32
+)";
+  } else {
+    expected_ptx = R"(
+CHECK-NOT: ld.global.nc.v2.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+)";
+  }
+  CompileAndOptionallyVerifyPtx(std::move(optimized_module), expected_ptx);
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+}
+
+TEST_F(ReductionVectorizationTest, TileFit) {
+  const char* hlo_text = R"(
+HloModule ReduceTileFit
+
+%max_ {
+  %x = f32[] parameter(0)
+  %y = f32[] parameter(1)
+  ROOT %maximum.7 = f32[] maximum(f32[] %x, f32[] %y)
+}
+
+ENTRY %main {
+  %param_0 = f32[5,122880] parameter(0)
+  %constant.3 = f32[] constant(0)
+  ROOT %reduce.8 = f32[5] reduce(f32[5,122880] %param_0, f32[] %constant.3), dimensions={1}, to_apply=%max_
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> optimized_module,
+                          ParseAndReturnVerifiedModule(hlo_text));
+  se::StreamExecutor* executor = backend().default_stream_executor();
+  int cc_major = 0, cc_minor = 0;
+  executor->GetDeviceDescription().cuda_compute_capability(&cc_major,
+                                                           &cc_minor);
+  string expected_ptx;
+  if (cc_major >= 6) {
+    expected_ptx = R"(
+CHECK: ld.global.nc.v2.f32
+CHECK: ld.global.nc.v2.f32
+CHECK: ld.global.nc.v2.f32
+CHECK: ld.global.nc.v2.f32
+)";
+  } else {
+    expected_ptx = R"(
+CHECK-NOT: ld.global.nc.v2.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+)";
+  }
+  CompileAndOptionallyVerifyPtx(std::move(optimized_module), expected_ptx);
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+}
+
+TEST_F(ReductionVectorizationTest, EvenColumns) {
+  const char* hlo_text = R"(
+HloModule ReducePower2
+
+%max_ {
+  %x = f32[] parameter(0)
+  %y = f32[] parameter(1)
+  ROOT %maximum.7 = f32[] maximum(f32[] %x, f32[] %y)
+}
+
+ENTRY %main {
+  %param_0 = f32[5,131070] parameter(0)
+  %constant.3 = f32[] constant(0)
+  ROOT %reduce.8 = f32[5] reduce(f32[5,131070] %param_0, f32[] %constant.3), dimensions={1}, to_apply=%max_
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> optimized_module,
+                          ParseAndReturnVerifiedModule(hlo_text));
+  se::StreamExecutor* executor = backend().default_stream_executor();
+  int cc_major = 0, cc_minor = 0;
+  executor->GetDeviceDescription().cuda_compute_capability(&cc_major,
+                                                           &cc_minor);
+  string expected_ptx;
+  if (cc_major >= 7) {
+    expected_ptx = R"(
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.v2.f32
+CHECK: ld.global.nc.v2.f32
+CHECK: ld.global.nc.v2.f32
+CHECK-NOT: ld.global.nc.v2.f32
+// TODO: Make this a vectorized load
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+)";
+  } else {
+    expected_ptx = R"(
+CHECK-NOT: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+)";
+  }
+  CompileAndOptionallyVerifyPtx(std::move(optimized_module), expected_ptx);
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+}
+
+TEST_F(ReductionVectorizationTest, DisableOddColumns) {
+  const char* hlo_text = R"(
+HloModule ReduceTileFit
+
+%max_ {
+  %x = f32[] parameter(0)
+  %y = f32[] parameter(1)
+  ROOT %maximum.7 = f32[] maximum(%x, %y)
+}
+
+ENTRY %main {
+  %param_0 = f32[5,131071] parameter(0)
+  %constant.3 = f32[] constant(0)
+  ROOT %reduce.8 = f32[5] reduce(f32[5,131071] %param_0, f32[] %constant.3), dimensions={1}, to_apply=%max_
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> optimized_module,
+                          ParseAndReturnVerifiedModule(hlo_text));
+  CompileAndOptionallyVerifyPtx(std::move(optimized_module),
+                                R"(
+CHECK-NOT: ld.global.nc.v2.f32
+CHECK-NOT: ld.global.nc.v4.f32
+CHECK-NOT: ld.global.nc.u64
+CHECK-NOT: ld.global.u64
+)");
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+}
+
+TEST_F(ReductionVectorizationTest, Exp) {
+  const char* hlo_text = R"(
+HloModule DisableSin
+
+%add_float {
+  %x = f32[] parameter(0)
+  %y = f32[] parameter(1)
+  ROOT %add.17 = f32[] add(f32[] %x, f32[] %y)
+}
+
+ENTRY %main {
+  %arg0.1 = f32[5,131072] parameter(0)
+  %sine = f32[5,131072] exponential(f32[5,131072] %arg0.1)
+  %constant.0 = f32[] constant(0)
+  ROOT %reduce.18 = f32[5] reduce(f32[5,131072] %sine, f32[] %constant.0), dimensions={1}, to_apply=%add_float
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> optimized_module,
+                          ParseAndReturnVerifiedModule(hlo_text));
+  se::StreamExecutor* executor = backend().default_stream_executor();
+  int cc_major = 0, cc_minor = 0;
+  executor->GetDeviceDescription().cuda_compute_capability(&cc_major,
+                                                           &cc_minor);
+  string expected_ptx;
+  if (cc_major >= 6) {
+    expected_ptx = R"(
+CHECK: ld.global.nc.v2.f32
+CHECK: ld.global.nc.v2.f32
+CHECK: ld.global.nc.v2.f32
+CHECK: ld.global.nc.v2.f32
+)";
+  } else {
+    expected_ptx = R"(
+CHECK-NOT: ld.global.nc.v2.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+)";
+  }
+  CompileAndOptionallyVerifyPtx(std::move(optimized_module), expected_ptx);
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+}
+
+TEST_F(ReductionVectorizationTest, DisableSin) {
+  const char* hlo_text = R"(
+HloModule DisableSin
+
+%add_float {
+  %x = f32[] parameter(0)
+  %y = f32[] parameter(1)
+  ROOT %add.17 = f32[] add(f32[] %x, f32[] %y)
+}
+
+ENTRY %main {
+  %arg0.1 = f32[5,131072] parameter(0)
+  %sine = f32[5,131072] sine(f32[5,131072] %arg0.1)
+  %constant.0 = f32[] constant(0)
+  ROOT %reduce.18 = f32[5] reduce(f32[5,131072] %sine, f32[] %constant.0), dimensions={1}, to_apply=%add_float
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> optimized_module,
+                          ParseAndReturnVerifiedModule(hlo_text));
+  CompileAndOptionallyVerifyPtx(std::move(optimized_module),
+                                R"(
+CHECK-NOT: ld.global.nc.v2.f32
+CHECK-NOT: ld.global.nc.v4.f32
+CHECK-NOT: ld.global.nc.u64
+CHECK-NOT: ld.global.u64
+)");
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+}
+
+}  // namespace
+}  // namespace gpu
+}  // namespace xla