Roll forward with a fix
PiperOrigin-RevId: 304156244
Change-Id: Ib14f8613aa5de72de6f2f8117d98b73d6ed5e297
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index 528a847..089f604 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -106,6 +106,11 @@
const auto kDimZ = KernelMappingScheme::DimZ;
const auto kDimTot = KernelMappingScheme::DimTot;
+const auto kLinearIndexingX = KernelMappingScheme::LinearIndexingX;
+const auto kStridedIndexingX = KernelMappingScheme::StridedIndexingX;
+const auto kLinearStridedIndexingX =
+ KernelMappingScheme::LinearStridedIndexingX;
+
// If a dimensions is smaller than this, untiled transposition may be more
// efficient.
const int64 kMinDimensionToTransposeTiled = 16;
@@ -1863,9 +1868,8 @@
bool MayPreventVectorization(const HloInstruction& hlo) {
if (hlo.opcode() == HloOpcode::kFusion) {
return absl::c_any_of(hlo.fused_instructions_computation()->instructions(),
- [](const HloInstruction* instr) {
+ [&](const HloInstruction* instr) {
switch (instr->opcode()) {
- case HloOpcode::kReduce:
case HloOpcode::kReduceWindow:
case HloOpcode::kSort:
case HloOpcode::kDot:
@@ -1892,6 +1896,10 @@
default:
return false;
}
+ } else if (hlo.opcode() == HloOpcode::kReduce) {
+ // TODO(nouiz): check if the to_apply() attribute contains instruction
+ // that break LLVM vectorization.
+ return false;
}
return true;
}
@@ -1920,13 +1928,59 @@
llvm::Value* thread_id_x,
llvm::Type* index_ty,
llvm::IRBuilder<>* b) {
- if (mapping_scheme.DilatedX()) {
+ auto constant = [&](int64 val) {
+ return llvm::ConstantInt::get(index_ty, val);
+ };
+ if (mapping_scheme.GetIndexingOrder() == kStridedIndexingX) {
return thread_id_x;
+ } else if (mapping_scheme.GetIndexingOrder() == kLinearStridedIndexingX) {
+ return b->CreateMul(thread_id_x, constant(mapping_scheme.GetVectorSize()));
}
+ CHECK_EQ(mapping_scheme.GetIndexingOrder(), kLinearIndexingX);
int64 x_num_steps =
mapping_scheme.GetTileSizeX() / mapping_scheme.GetNumThreadsX();
- return b->CreateMul(thread_id_x,
- llvm::ConstantInt::get(index_ty, x_num_steps));
+ return b->CreateMul(thread_id_x, constant(x_num_steps));
+}
+
+// Calls `emit_elem_function()` `x_num_steps` times. If
+// `vector_size`==1, then each element index passed to
+// `emit_elem_function()` will be separated by `step_x`. If `vector_size`>1,
+// then it must be a multiple of `x_num_steps`. In that case, it
+// triggers a different indexing order that is vectorizable by
+// LLVM. It generates many groups of calls to `emit_elem_function`. Each
+// group is separated by `step_x` elements. Inside a group, elements
+// are consecutive. If `check_x_tile_bounds` is true, then it will check
+// if the element index is in bound compared to `tile_width` before
+// calling `emit_elem_function`.
+static void UnrollInnerTileLoop(
+ bool check_x_tile_bounds, int64 x_num_steps, int64 step_x,
+ int64 vector_size, const string& loop_name, KernelSupportLibrary* ksl,
+ llvm::Value* start_offset_x, llvm::Value* y_loc, llvm::Value* tile_width,
+ const IrArray::Index& source_idx, llvm::IRBuilder<>* b,
+ const IrEmitterUnnested::EmitElementFunction* emit_elem_function) {
+ llvm::Type* index_ty = tile_width->getType();
+ auto constant = [&](int64 val) {
+ return llvm::ConstantInt::get(index_ty, val);
+ };
+ for (int64 j = 0; j < x_num_steps / vector_size; j++) {
+ for (int64 i = 0; i < vector_size; i++) {
+ int64 linear_index = j * vector_size + i;
+ llvm::Value* x_loc = b->CreateAdd(constant(j * step_x * vector_size + i),
+ start_offset_x, "x_loc");
+ IrArray::Index source_idx_x =
+ source_idx.AddOffsetToDim(y_loc, kDimY, b)
+ .AddOffsetToDim(constant(j * step_x * vector_size + i), kDimX, b);
+ auto emit_element = [&] {
+ return (*emit_elem_function)(source_idx_x, y_loc, x_loc, linear_index);
+ };
+ if (check_x_tile_bounds) {
+ ksl->If(loop_name + "_x_in_tile", b->CreateICmpULT(x_loc, tile_width),
+ emit_element);
+ } else {
+ emit_element();
+ }
+ }
+ }
}
void IrEmitterUnnested::EmitTile(
@@ -1951,7 +2005,9 @@
// of threads.
// Otherwise, the stride is one, but we multiply each offset by the limit of
// number of steps which can be made.
- int64 step_x = mapping_scheme.DilatedX() ? num_threads_x : 1;
+ int64 step_x =
+ mapping_scheme.GetIndexingOrder() == kLinearIndexingX ? 1 : num_threads_x;
+ int64 vector_size = mapping_scheme.GetVectorSize();
IrArray::Index source_idx =
tile_origin_index.AddOffsetToDim(start_offset_x, kDimX, &b_);
@@ -1987,21 +2043,29 @@
llvm::Value* y_loc =
b_.CreateAdd(thread_id_info.thread_id_y,
b_.CreateMul(y_indvar, num_threads_y));
- for (int64 j = 0; j < x_num_steps; j++) {
- llvm::Value* x_loc =
- b_.CreateAdd(constant(j * step_x), start_offset_x, "x_loc");
- IrArray::Index source_idx_x =
- source_idx.AddOffsetToDim(y_loc, kDimY, &b_)
- .AddOffsetToDim(constant(j * step_x), kDimX, &b_);
- auto emit_element = [&] {
- return emit_elem_function(source_idx_x, y_loc, x_loc, j);
- };
- if (!x_tile_fits) {
- ksl->If(loop_name + "_x_in_tile",
- b_.CreateICmpULT(x_loc, tile_width), emit_element);
- } else {
- emit_element();
- }
+ auto unrollInnerTileLoop = [&](bool check_x_tile_bounds) {
+ return UnrollInnerTileLoop(check_x_tile_bounds, x_num_steps,
+ step_x, vector_size, loop_name, ksl,
+ start_offset_x, y_loc, tile_width,
+ source_idx, &b_, &emit_elem_function);
+ };
+
+ // Only take this path when we unroll in a way vectorizable by
+ // LLVM. Special case when the tile doesn't fit completely for even
+ // row size. For odd row size every other row isn't aligned to the
+ // vectorized size, so it can't be vectorized by LLVM.
+ if (!x_tile_fits &&
+ mapping_scheme.GetIndexingOrder() == kLinearStridedIndexingX) {
+ ksl->If(
+ loop_name + "_is_full_tile",
+ // For the last block, tile_width will be the number of
+ // elements left.
+ b_.CreateICmpEQ(constant(mapping_scheme.GetTileSizeX()),
+ tile_width),
+ [&] { unrollInnerTileLoop(/*check_x_tile_bounds=*/false); },
+ [&] { unrollInnerTileLoop(/*check_x_tile_bounds=*/true); });
+ } else {
+ unrollInnerTileLoop(/*check_x_tile_bounds=*/!x_tile_fits);
}
});
}
@@ -2035,6 +2099,19 @@
const Shape& unnormalized_shape, llvm::IRBuilder<>* b_,
const KernelMappingScheme& kernel_mapping_scheme) {
DCHECK_EQ(normalized_shape_index.size(), 3);
+ // If the normalization only add a new dimensions of size 1,
+ // generate simpler indexing. LLVM doesn't always simplify the more
+ // complicated indexing and this prevents it from vectorizing some
+ // cases. We do this only for major_to_minor memory layout.
+ if (unnormalized_shape.rank() == 2 && unnormalized_shape.has_layout() &&
+ unnormalized_shape.dimensions()[0] == normalized_shape_index.dims()[1] &&
+ unnormalized_shape.dimensions()[1] == normalized_shape_index.dims()[2] &&
+ unnormalized_shape.layout().minor_to_major(1) == 0) {
+ DCHECK_EQ(normalized_shape_index.dims()[0], 1);
+ auto multidim = normalized_shape_index.multidim();
+ return IrArray::Index({multidim[1], multidim[2]}, unnormalized_shape,
+ normalized_shape_index.GetType());
+ }
llvm::Value* linear = normalized_shape_index.Linearize(
kernel_mapping_scheme.GetDimsInElems(), b_);
return IrArray::Index(linear, unnormalized_shape, b_);
@@ -2077,21 +2154,6 @@
}
}
-// Gets the number of partial results accumulated by a single thread performing
-// reduction.
-static int GetNumberOfPartialResults(
- const ReductionCodegenInfo& reduction_info) {
- const KernelMappingScheme& mapping_scheme =
- reduction_info.GetKernelMappingScheme();
- if (reduction_info.IsRowReduction()) {
- return 1;
- }
- int64 num_partial_results = mapping_scheme.DilatedX() ? 1 : 2;
- CHECK_EQ(num_partial_results,
- (mapping_scheme.GetTileSizeX() / mapping_scheme.GetNumThreadsX()));
- return num_partial_results;
-}
-
void IrEmitterUnnested::EmitPrologueForReduction(
HloInstruction* unnested_hlo, ReductionCodegenInfo* reduction_info,
absl::Span<HloInstruction* const> reduce_instructions,
@@ -2118,7 +2180,7 @@
llvm::AllocaInst* reduction_input_address = Alloca(element_type);
reduction_input_addresses->push_back(reduction_input_address);
- int num_partial_results = GetNumberOfPartialResults(*reduction_info);
+ int num_partial_results = reduction_info->GetNumPartialResults();
AddressVector* partial_result_addresses =
reduction_info->GetMutablePartialResultAddresses();
llvm::AllocaInst* partial_result_address =
@@ -2270,7 +2332,7 @@
absl::Span<llvm::AllocaInst* const> partial_result_addresses =
reduction_info.GetPartialResultAddresses();
- int num_partial_results = GetNumberOfPartialResults(reduction_info);
+ int num_partial_results = reduction_info.GetNumPartialResults();
// Emit an atomic operation that accumulates the partial reduction to the
// output element. For row reduction, this is only for lane 0 due to the
@@ -2484,7 +2546,7 @@
// GetElementPointer with array types. This enables the vectorization of
// the computation for different partial results. Use this index if
// 'num_partial_results > 1'.
- int num_partial_results = GetNumberOfPartialResults(reduction_info);
+ int num_partial_results = reduction_info.GetNumPartialResults();
auto index_without_linear = IrArray::Index(
input_index.multidim(), reduction_operand_shape, input_index.GetType());
@@ -2670,7 +2732,8 @@
/*tile_sizes=*/{1, kWarpSize, kWarpSize},
/*num_threads_y=*/kNumRows,
/*num_threads_x=*/kWarpSize,
- /*is_dilated_x=*/false);
+ /*indexing_order=*/kLinearIndexingX,
+ /*vector_size=*/1);
LaunchDimensions launch_dimensions(mapping_scheme.GetNumberOfBlocks(),
mapping_scheme.GetThreadsPerBlock());
llvm::Type* index_type =
@@ -3111,15 +3174,6 @@
std::array<int64, 3> reduction_tiling =
GetReductionTiling(reduction_dimensions, smallest_input_dtype_bits,
&ir_emitter_context_->device_description());
- bool dilated_x =
- reduction_dimensions.is_row_reduction ||
- !IsUnrollingColumnReductionBeneficial(unnested_hlo, input_shape,
- reduction_dimensions.dimensions[2]);
-
- if (!dilated_x && !reduction_dimensions.is_row_reduction) {
- // Vectorized loads: a single thread reduces two adjacent columns.
- reduction_tiling[2] *= 2;
- }
int64 num_threads_y = reduction_dimensions.is_row_reduction ? 1 : kWarpSize;
int64 num_threads_x = [&] {
@@ -3133,12 +3187,54 @@
return kWarpSize;
}();
+ bool tile_fit = reduction_dimensions.dimensions[kDimX] %
+ (reduction_tiling[2] * num_threads_x) ==
+ 0;
+
+ int cc_major = 0, cc_minor = 0;
+ ir_emitter_context_->device_description().cuda_compute_capability(&cc_major,
+ &cc_minor);
+
+ int num_partial_results = 1;
+ KernelMappingScheme::IndexingOrder indexing_order = [&]() {
+ if (reduction_dimensions.is_row_reduction &&
+ // P100, only try to vectorize+coales memory access when the
+ // tile size fits exactly and dtypes <= 32 bits
+ ((cc_major == 6 && smallest_input_dtype_bits <= 32 && tile_fit) ||
+ // On V100, only try to vectorize+coales memory access for
+ // rows of even size. For odd row sizes, every other row
+ // isn't aligned, so it can't be vectorized.
+ (cc_major >= 7 && reduction_dimensions.dimensions[2] % 2 == 0))) {
+ return kLinearStridedIndexingX;
+ } else if (!reduction_dimensions.is_row_reduction &&
+ IsUnrollingColumnReductionBeneficial(
+ unnested_hlo, input_shape,
+ reduction_dimensions.dimensions[2])) {
+ num_partial_results = 2;
+ reduction_tiling[2] *= num_partial_results;
+ return kLinearIndexingX;
+ } else {
+ return kStridedIndexingX;
+ }
+ }();
+
+ int vector_size = 1;
+ if (indexing_order == kLinearStridedIndexingX) {
+ if (reduction_dimensions.dimensions[2] % 2 == 0 &&
+ // Assuming XLA will perform the unrolling and LLVM will vectorize,
+ // disable the unroll for the cases that LLVM doesn't vectorize.
+ !MayPreventVectorization(*unnested_hlo)) {
+ vector_size = 2;
+ } else {
+ indexing_order = kStridedIndexingX;
+ }
+ }
KernelMappingScheme mapping_scheme(
reduction_dimensions.dimensions,
{reduction_tiling[0], reduction_tiling[1] * num_threads_y,
reduction_tiling[2] * num_threads_x},
- num_threads_y, num_threads_x, dilated_x);
- return ReductionCodegenInfo(mapping_scheme,
+ num_threads_y, num_threads_x, indexing_order, vector_size);
+ return ReductionCodegenInfo(mapping_scheme, num_partial_results,
reduction_dimensions.is_row_reduction);
}
diff --git a/tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h b/tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h
index eeab8d4..cd690c9 100644
--- a/tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h
+++ b/tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h
@@ -76,19 +76,33 @@
class KernelMappingScheme {
public:
enum { DimZ = 0, DimY, DimX, DimTot };
+ enum IndexingOrder {
+ // Thread reads consecutive elements.
+ LinearIndexingX,
+ // Thread reads strided elements while keeping memory coalescing.
+ StridedIndexingX,
+ // Thread reads a few consecutive elements then take a strided
+ // step. This can trigger vectorized reads and keep memory
+ // coalescing.
+ LinearStridedIndexingX
+ };
+
KernelMappingScheme(absl::Span<const int64> dims_in_elems,
absl::Span<const int64> tile_sizes, int64 num_threads_y,
- int64 num_threads_x, bool is_dilated_x)
+ int64 num_threads_x, IndexingOrder indexing_order,
+ int vector_size)
: dims_in_elems_{dims_in_elems[0], dims_in_elems[1], dims_in_elems[2]},
tile_sizes_{tile_sizes[0], tile_sizes[1], tile_sizes[2]},
num_threads_x_(num_threads_x),
num_threads_y_(num_threads_y),
- dilated_x_(is_dilated_x) {
+ indexing_order_(indexing_order),
+ vector_size_(vector_size) {
CHECK_EQ(tile_sizes[1] % num_threads_y_, 0);
CHECK_EQ(tile_sizes[2] % num_threads_x_, 0);
VLOG(10) << "dims_in_elems_ = " << absl::StrJoin(dims_in_elems_, ",");
- if (!dilated_x_) {
- // dilated_x_=false is for the purpose of vectorization, which requires
+ if (indexing_order != LinearIndexingX) {
+ // StridedIndexingX, and LinearStridedIndexingX
+ // is for the purpose of vectorization, which requires
// GetTileSizeFor(DimX) to be a multiplier of num_threads_x_.
CHECK_EQ(GetTileSizeFor(DimX) % num_threads_x_, 0);
}
@@ -118,7 +132,8 @@
return GetNumThreadsX() * GetNumThreadsY();
}
- bool DilatedX() const { return dilated_x_; }
+ IndexingOrder GetIndexingOrder() const { return indexing_order_; }
+ int GetVectorSize() const { return vector_size_; }
private:
// The number of elements in each dimension.
@@ -133,12 +148,17 @@
// Number of threads used to process elements in the Y direction of a tile.
const int64 num_threads_y_;
- // When num_threads_x threads process a total of tile_size_x elements in the
- // X dimension of a tile, each threads process n=tile_size_x/num_threads_x
- // elements. When dilated_x=false, the n elements processed by a thread are
- // contiguous. On the other hand, when dilated_x=true the n elements are
- // dilated by a factor of num_threads_x.
- const bool dilated_x_;
+ // When num_threads_x threads process a total of tile_size_x
+ // elements in the X dimension of a tile, each threads process
+ // n=tile_size_x/num_threads_x elements.
+ // indexing_order defines which tile's elements each thread reads.
+ const IndexingOrder indexing_order_;
+
+ // vector_size_ only supported for row reduction and must be a divisor
+ // of tile_sizes_[2]/num_threads_x. Interesting values are 2 and 4
+ // to trigger vectorized loads on GPUs while keeping memory
+ // coalescing.
+ const int vector_size_;
};
// Information to support the code generation for a tiled reduction kernel.
@@ -146,8 +166,15 @@
class ReductionCodegenInfo {
public:
explicit ReductionCodegenInfo(KernelMappingScheme mapping_scheme,
- bool is_row_reduction)
- : mapping_scheme_(mapping_scheme), is_row_reduction_(is_row_reduction) {}
+ int num_partial_results, bool is_row_reduction)
+ : mapping_scheme_(mapping_scheme),
+ num_partial_results_(num_partial_results),
+ is_row_reduction_(is_row_reduction) {
+ if (num_partial_results > 1) {
+ CHECK_EQ(num_partial_results, (mapping_scheme.GetTileSizeX() /
+ mapping_scheme.GetNumThreadsX()));
+ }
+ }
const KernelMappingScheme& GetKernelMappingScheme() const {
return mapping_scheme_;
@@ -183,6 +210,7 @@
return reduction_input_addresses_;
}
+ int GetNumPartialResults() const { return num_partial_results_; }
bool IsRowReduction() const { return is_row_reduction_; }
// Gets a pointer to a mutable shared cache used by reduction.
@@ -201,6 +229,7 @@
const KernelMappingScheme mapping_scheme_;
AddressVector partial_result_addresses_;
AddressVector reduction_input_addresses_;
+ int num_partial_results_;
bool is_row_reduction_;
};
diff --git a/tensorflow/compiler/xla/service/gpu/tests/BUILD b/tensorflow/compiler/xla/service/gpu/tests/BUILD
index 1fd51c7..e04dba4 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/tests/BUILD
@@ -165,6 +165,33 @@
)
tf_cc_test(
+ name = "reduction_vectorization_test",
+ srcs = [
+ "reduction_vectorization_test.cc",
+ ],
+ tags = tf_cuda_tests_tags() + ["no_rocm"],
+ deps = [
+ ":gpu_codegen_test",
+ "//tensorflow/compiler/xla:debug_options_flags",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla/service:gpu_plugin",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_module_config",
+ "//tensorflow/compiler/xla/service:hlo_parser",
+ "//tensorflow/compiler/xla/service/gpu:gemm_rewriter",
+ "//tensorflow/compiler/xla/service/gpu:gpu_executable",
+ "//tensorflow/compiler/xla/tests:filecheck",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:llvm_irgen_test_base",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/stream_executor/lib",
+ "@com_google_absl//absl/memory",
+ ],
+)
+
+tf_cc_test(
name = "reduction_dimension_grouper_test",
srcs = [
"reduction_dimension_grouper_test.cc",
diff --git a/tensorflow/compiler/xla/service/gpu/tests/reduction_vectorization_test.cc b/tensorflow/compiler/xla/service/gpu/tests/reduction_vectorization_test.cc
new file mode 100644
index 0000000..5f27df0
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/tests/reduction_vectorization_test.cc
@@ -0,0 +1,299 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <utility>
+
+#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
+#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module_config.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/tests/filecheck.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/stream_executor/lib/statusor.h"
+
+namespace xla {
+namespace gpu {
+
+namespace {
+
+class ReductionVectorizationTest : public GpuCodegenTest {};
+
+TEST_F(ReductionVectorizationTest, Power2) {
+ const char* hlo_text = R"(
+HloModule ReducePower2
+
+%max_ {
+ %x = f32[] parameter(0)
+ %y = f32[] parameter(1)
+ ROOT %maximum.7 = f32[] maximum(f32[] %x, f32[] %y)
+}
+
+ENTRY %main {
+ %param_0 = f32[5,131072] parameter(0)
+ %constant.3 = f32[] constant(0)
+ ROOT %reduce.8 = f32[5] reduce(f32[5,131072] %param_0, f32[] %constant.3), dimensions={1}, to_apply=%max_
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> optimized_module,
+ ParseAndReturnVerifiedModule(hlo_text));
+ se::StreamExecutor* executor = backend().default_stream_executor();
+ int cc_major = 0, cc_minor = 0;
+ executor->GetDeviceDescription().cuda_compute_capability(&cc_major,
+ &cc_minor);
+ string expected_ptx;
+ if (cc_major >= 6) {
+ expected_ptx = R"(
+CHECK: ld.global.nc.v2.f32
+CHECK: ld.global.nc.v2.f32
+CHECK: ld.global.nc.v2.f32
+CHECK: ld.global.nc.v2.f32
+)";
+ } else {
+ expected_ptx = R"(
+CHECK-NOT: ld.global.nc.v2.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+)";
+ }
+ CompileAndOptionallyVerifyPtx(std::move(optimized_module), expected_ptx);
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+}
+
+TEST_F(ReductionVectorizationTest, TileFit) {
+ const char* hlo_text = R"(
+HloModule ReduceTileFit
+
+%max_ {
+ %x = f32[] parameter(0)
+ %y = f32[] parameter(1)
+ ROOT %maximum.7 = f32[] maximum(f32[] %x, f32[] %y)
+}
+
+ENTRY %main {
+ %param_0 = f32[5,122880] parameter(0)
+ %constant.3 = f32[] constant(0)
+ ROOT %reduce.8 = f32[5] reduce(f32[5,122880] %param_0, f32[] %constant.3), dimensions={1}, to_apply=%max_
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> optimized_module,
+ ParseAndReturnVerifiedModule(hlo_text));
+ se::StreamExecutor* executor = backend().default_stream_executor();
+ int cc_major = 0, cc_minor = 0;
+ executor->GetDeviceDescription().cuda_compute_capability(&cc_major,
+ &cc_minor);
+ string expected_ptx;
+ if (cc_major >= 6) {
+ expected_ptx = R"(
+CHECK: ld.global.nc.v2.f32
+CHECK: ld.global.nc.v2.f32
+CHECK: ld.global.nc.v2.f32
+CHECK: ld.global.nc.v2.f32
+)";
+ } else {
+ expected_ptx = R"(
+CHECK-NOT: ld.global.nc.v2.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+)";
+ }
+ CompileAndOptionallyVerifyPtx(std::move(optimized_module), expected_ptx);
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+}
+
+TEST_F(ReductionVectorizationTest, EvenColumns) {
+ const char* hlo_text = R"(
+HloModule ReducePower2
+
+%max_ {
+ %x = f32[] parameter(0)
+ %y = f32[] parameter(1)
+ ROOT %maximum.7 = f32[] maximum(f32[] %x, f32[] %y)
+}
+
+ENTRY %main {
+ %param_0 = f32[5,131070] parameter(0)
+ %constant.3 = f32[] constant(0)
+ ROOT %reduce.8 = f32[5] reduce(f32[5,131070] %param_0, f32[] %constant.3), dimensions={1}, to_apply=%max_
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> optimized_module,
+ ParseAndReturnVerifiedModule(hlo_text));
+ se::StreamExecutor* executor = backend().default_stream_executor();
+ int cc_major = 0, cc_minor = 0;
+ executor->GetDeviceDescription().cuda_compute_capability(&cc_major,
+ &cc_minor);
+ string expected_ptx;
+ if (cc_major >= 7) {
+ expected_ptx = R"(
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.v2.f32
+CHECK: ld.global.nc.v2.f32
+CHECK: ld.global.nc.v2.f32
+CHECK-NOT: ld.global.nc.v2.f32
+// TODO: Make this a vectorized load
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+)";
+ } else {
+ expected_ptx = R"(
+CHECK-NOT: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+)";
+ }
+ CompileAndOptionallyVerifyPtx(std::move(optimized_module), expected_ptx);
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+}
+
+TEST_F(ReductionVectorizationTest, DisableOddColumns) {
+ const char* hlo_text = R"(
+HloModule ReduceTileFit
+
+%max_ {
+ %x = f32[] parameter(0)
+ %y = f32[] parameter(1)
+ ROOT %maximum.7 = f32[] maximum(%x, %y)
+}
+
+ENTRY %main {
+ %param_0 = f32[5,131071] parameter(0)
+ %constant.3 = f32[] constant(0)
+ ROOT %reduce.8 = f32[5] reduce(f32[5,131071] %param_0, f32[] %constant.3), dimensions={1}, to_apply=%max_
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> optimized_module,
+ ParseAndReturnVerifiedModule(hlo_text));
+ CompileAndOptionallyVerifyPtx(std::move(optimized_module),
+ R"(
+CHECK-NOT: ld.global.nc.v2.f32
+CHECK-NOT: ld.global.nc.v4.f32
+CHECK-NOT: ld.global.nc.u64
+CHECK-NOT: ld.global.u64
+)");
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+}
+
+TEST_F(ReductionVectorizationTest, Exp) {
+ const char* hlo_text = R"(
+HloModule DisableSin
+
+%add_float {
+ %x = f32[] parameter(0)
+ %y = f32[] parameter(1)
+ ROOT %add.17 = f32[] add(f32[] %x, f32[] %y)
+}
+
+ENTRY %main {
+ %arg0.1 = f32[5,131072] parameter(0)
+ %sine = f32[5,131072] exponential(f32[5,131072] %arg0.1)
+ %constant.0 = f32[] constant(0)
+ ROOT %reduce.18 = f32[5] reduce(f32[5,131072] %sine, f32[] %constant.0), dimensions={1}, to_apply=%add_float
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> optimized_module,
+ ParseAndReturnVerifiedModule(hlo_text));
+ se::StreamExecutor* executor = backend().default_stream_executor();
+ int cc_major = 0, cc_minor = 0;
+ executor->GetDeviceDescription().cuda_compute_capability(&cc_major,
+ &cc_minor);
+ string expected_ptx;
+ if (cc_major >= 6) {
+ expected_ptx = R"(
+CHECK: ld.global.nc.v2.f32
+CHECK: ld.global.nc.v2.f32
+CHECK: ld.global.nc.v2.f32
+CHECK: ld.global.nc.v2.f32
+)";
+ } else {
+ expected_ptx = R"(
+CHECK-NOT: ld.global.nc.v2.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+)";
+ }
+ CompileAndOptionallyVerifyPtx(std::move(optimized_module), expected_ptx);
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+}
+
+TEST_F(ReductionVectorizationTest, DisableSin) {
+ const char* hlo_text = R"(
+HloModule DisableSin
+
+%add_float {
+ %x = f32[] parameter(0)
+ %y = f32[] parameter(1)
+ ROOT %add.17 = f32[] add(f32[] %x, f32[] %y)
+}
+
+ENTRY %main {
+ %arg0.1 = f32[5,131072] parameter(0)
+ %sine = f32[5,131072] sine(f32[5,131072] %arg0.1)
+ %constant.0 = f32[] constant(0)
+ ROOT %reduce.18 = f32[5] reduce(f32[5,131072] %sine, f32[] %constant.0), dimensions={1}, to_apply=%add_float
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> optimized_module,
+ ParseAndReturnVerifiedModule(hlo_text));
+ CompileAndOptionallyVerifyPtx(std::move(optimized_module),
+ R"(
+CHECK-NOT: ld.global.nc.v2.f32
+CHECK-NOT: ld.global.nc.v4.f32
+CHECK-NOT: ld.global.nc.u64
+CHECK-NOT: ld.global.u64
+)");
+
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+}
+
+} // namespace
+} // namespace gpu
+} // namespace xla