[XLA] [NFC] Make reduction rewriter tests more resilient: run only the pass under test, and do not execute the generated code
We rely on integration tests for numerical correctness/overall flow.
PiperOrigin-RevId: 454848300
diff --git a/tensorflow/compiler/xla/service/gpu/tests/BUILD b/tensorflow/compiler/xla/service/gpu/tests/BUILD
index da1ce6b..59b2075 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/tests/BUILD
@@ -273,17 +273,13 @@
srcs = [
"reduction_degenerate_dim_remover_test.cc",
],
- tags = tf_cuda_tests_tags(),
deps = [
- ":gpu_codegen_test",
"//tensorflow/compiler/xla:debug_options_flags",
"//tensorflow/compiler/xla:statusor",
- "//tensorflow/compiler/xla/service:gpu_plugin",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_module_config",
"//tensorflow/compiler/xla/service:hlo_parser",
- "//tensorflow/compiler/xla/service/gpu:gemm_rewriter",
- "//tensorflow/compiler/xla/service/gpu:gpu_executable",
+ "//tensorflow/compiler/xla/service/gpu:reduction_degenerate_dim_remover",
"//tensorflow/compiler/xla/tests:filecheck",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:llvm_irgen_test_base",
@@ -300,24 +296,19 @@
srcs = [
"reduction_layout_normalizer_test.cc",
],
- tags = tf_cuda_tests_tags(),
deps = [
- ":gpu_codegen_test",
"//tensorflow/compiler/xla:debug_options_flags",
"//tensorflow/compiler/xla:statusor",
- "//tensorflow/compiler/xla/service:gpu_plugin",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_module_config",
"//tensorflow/compiler/xla/service:hlo_parser",
- "//tensorflow/compiler/xla/service/gpu:gemm_rewriter",
- "//tensorflow/compiler/xla/service/gpu:gpu_executable",
+ "//tensorflow/compiler/xla/service/gpu:reduction_layout_normalizer",
"//tensorflow/compiler/xla/tests:filecheck",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:llvm_irgen_test_base",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
- "//tensorflow/stream_executor/lib",
"@com_google_absl//absl/memory",
],
)
@@ -327,17 +318,13 @@
srcs = [
"tree_reduction_rewriter_test.cc",
],
- tags = tf_cuda_tests_tags(),
deps = [
- ":gpu_codegen_test",
"//tensorflow/compiler/xla:debug_options_flags",
"//tensorflow/compiler/xla:statusor",
- "//tensorflow/compiler/xla/service:gpu_plugin",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_module_config",
"//tensorflow/compiler/xla/service:hlo_parser",
- "//tensorflow/compiler/xla/service/gpu:gemm_rewriter",
- "//tensorflow/compiler/xla/service/gpu:gpu_executable",
+ "//tensorflow/compiler/xla/service/gpu:tree_reduction_rewriter",
"//tensorflow/compiler/xla/tests:filecheck",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:llvm_irgen_test_base",
@@ -346,6 +333,7 @@
"//tensorflow/core:test_main",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
],
)
@@ -408,17 +396,13 @@
srcs = [
"reduction_dimension_grouper_test.cc",
],
- tags = tf_cuda_tests_tags(),
deps = [
- ":gpu_codegen_test",
"//tensorflow/compiler/xla:debug_options_flags",
"//tensorflow/compiler/xla:statusor",
- "//tensorflow/compiler/xla/service:gpu_plugin",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_module_config",
"//tensorflow/compiler/xla/service:hlo_parser",
- "//tensorflow/compiler/xla/service/gpu:gemm_rewriter",
- "//tensorflow/compiler/xla/service/gpu:gpu_executable",
+ "//tensorflow/compiler/xla/service/gpu:reduction_dimension_grouper",
"//tensorflow/compiler/xla/tests:filecheck",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:llvm_irgen_test_base",
diff --git a/tensorflow/compiler/xla/service/gpu/tests/reduction_degenerate_dim_remover_test.cc b/tensorflow/compiler/xla/service/gpu/tests/reduction_degenerate_dim_remover_test.cc
index ab9890f..e1ecf40 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/reduction_degenerate_dim_remover_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/tests/reduction_degenerate_dim_remover_test.cc
@@ -13,10 +13,11 @@
limitations under the License.
==============================================================================*/
+#include "tensorflow/compiler/xla/service/gpu/reduction_degenerate_dim_remover.h"
+
+#include <optional>
#include <utility>
-#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
-#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
@@ -28,23 +29,20 @@
#include "tensorflow/stream_executor/lib/statusor.h"
namespace xla {
-namespace gpu {
namespace {
-class ReductionDegenerateDimRemoverTest : public GpuCodegenTest {
- DebugOptions GetDebugOptionsForTest() override {
- DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest();
- debug_options.add_xla_disable_hlo_passes("reduction-layout-normalizer");
- debug_options.add_xla_disable_hlo_passes("reduction-dimension-grouper");
- debug_options.add_xla_disable_hlo_passes("reduction-splitter");
- debug_options.add_xla_disable_hlo_passes("gpu-tree-reduction-rewriter");
- return debug_options;
+class ReductionDegenerateDimRemoverTest : public HloTestBase {
+ public:
+ void CheckDegenerateDimRemover(absl::string_view hlo,
+ std::optional<absl::string_view> expected) {
+ RunAndFilecheckHloRewrite(hlo, gpu::ReductionDegenerateDimRemover{},
+ expected);
}
};
TEST_F(ReductionDegenerateDimRemoverTest, ReductionWithDegenerateDimensions) {
- const char* hlo_text = R"(
+ const char* hlo = R"(
HloModule ReduceWithDegenerateDimensions
add {
@@ -62,16 +60,16 @@
)";
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
- MatchOptimizedHloWithShapes(hlo_text,
- R"(
-// CHECK: f32[] reduce(f32[3,4,5]{2,1,0} {{.+}}, f32[] {{.+}}), dimensions={0,1,2}, to_apply=%add
- )");
+ CheckDegenerateDimRemover(hlo, R"(
+// CHECK: [[bitcast_0:%[^ ]+]] = f32[3,4,5]{2,1,0} bitcast([[input_1:%[^ ]+]])
+// CHECK: [[reduce_2:%[^ ]+]] = f32[] reduce([[bitcast_0]], [[zero_3:%[^ ]+]]), dimensions={0,1,2}, to_apply=[[add_4:%[^ ]+]]
+// CHECK: ROOT [[bitcast_1_5:%[^ ]+]] = f32[1,1,1,1]{3,2,1,0} bitcast([[reduce_2]])
+ )");
}
TEST_F(ReductionDegenerateDimRemoverTest,
ReductionWithDegenerateDimensionsVariadic) {
- const char* hlo_text = R"(
+ const char* hlo = R"(
HloModule ReduceWithDegenerateDimensions
argmax {
@@ -102,15 +100,20 @@
)";
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
- MatchOptimizedHloWithShapes(hlo_text,
- R"(
-// CHECK: (f32[], u32[]) reduce(f32[3,4,5]{2,1,0} %bitcast.5, u32[3,4,5]{2,1,0} %bitcast.4, f32[] %zero_1, u32[] %zero_idx_1), dimensions={0,1,2}
+ CheckDegenerateDimRemover(hlo, R"(
+// CHECK: [[bitcast_0:%[^ ]+]] = f32[3,4,5]{2,1,0} bitcast([[input_1:%[^ ]+]])
+// CHECK: [[bitcast_1_2:%[^ ]+]] = u32[3,4,5]{2,1,0} bitcast([[idxs_3:%[^ ]+]])
+// CHECK: [[reduce_4:%[^ ]+]] = (f32[], u32[]) reduce([[bitcast_0]], [[bitcast_1_2]], [[zero_5:%[^ ]+]], [[zero_idx_6:%[^ ]+]]), dimensions={0,1,2}, to_apply=[[argmax_7:%[^ ]+]]
+// CHECK-NEXT: [[get_tuple_element_8:%[^ ]+]] = f32[] get-tuple-element([[reduce_4]]), index=0
+// CHECK-NEXT: [[bitcast_2_9:%[^ ]+]] = f32[1,1,1,1]{3,2,1,0} bitcast([[get_tuple_element_8]])
+// CHECK-NEXT: [[get_tuple_element_1_10:%[^ ]+]] = u32[] get-tuple-element([[reduce_4]]), index=1
+// CHECK-NEXT: [[bitcast_3_11:%[^ ]+]] = u32[1,1,1,1]{3,2,1,0} bitcast([[get_tuple_element_1_10]])
+// CHECK-NEXT: ROOT [[tuple_12:%[^ ]+]] = (f32[1,1,1,1]{3,2,1,0}, u32[1,1,1,1]{3,2,1,0}) tuple([[bitcast_2_9]], [[bitcast_3_11]])
)");
}
TEST_F(ReductionDegenerateDimRemoverTest, DegenerateWithEmptyDimension) {
- const char* hlo_text = R"(
+ const char* hlo = R"(
HloModule ReduceWithDegenerateDimensions
add {
@@ -125,22 +128,13 @@
ROOT out = f32[3,4,5,1] reduce(input, zero), dimensions={0,2,4}, to_apply=add
}
-
)";
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
- // Copy instruction is added after bitcast because of copy-insertion pass,
- // so we check the entire hlo module to verify there is no reduce instruction
- // in this case.
- MatchOptimizedHloWithShapes(hlo_text,
- R"(
-// CHECK: ENTRY %main (input: f32[1,3,1,4,1,5,1]) -> f32[3,4,5,1] {
-// CHECK: %input = f32[1,3,1,4,1,5,1]{6,5,4,3,2,1,0} parameter(0)
-// CHECK: %bitcast{{.+}} = f32[3,4,5,1]{3,2,1,0} bitcast(f32[1,3,1,4,1,5,1]{6,5,4,3,2,1,0} %input)
-// CHECK: ROOT %copy{{.+}} = f32[3,4,5,1]{3,2,1,0} copy(f32[3,4,5,1]{3,2,1,0} %bitcast{{.+}})
+ CheckDegenerateDimRemover(hlo,
+ R"(
+// CHECK: ROOT [[bitcast_0:%[^ ]+]] = f32[3,4,5,1]{3,2,1,0} bitcast([[input_1:%[^ ]+]])
)");
}
} // namespace
-} // namespace gpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/tests/reduction_dimension_grouper_test.cc b/tensorflow/compiler/xla/service/gpu/tests/reduction_dimension_grouper_test.cc
index 4e45753..74818d1 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/reduction_dimension_grouper_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/tests/reduction_dimension_grouper_test.cc
@@ -13,10 +13,11 @@
limitations under the License.
==============================================================================*/
+#include "tensorflow/compiler/xla/service/gpu/reduction_dimension_grouper.h"
+
+#include <optional>
#include <utility>
-#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
-#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/tests/filecheck.h"
@@ -25,21 +26,19 @@
#include "tensorflow/core/platform/test.h"
namespace xla {
-namespace gpu {
namespace {
-class ReductionDimensionGrouperTest : public GpuCodegenTest {
- DebugOptions GetDebugOptionsForTest() override {
- DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest();
- debug_options.add_xla_disable_hlo_passes("reduction-layout-normalizer");
- debug_options.add_xla_disable_hlo_passes("layout-assignment");
- return debug_options;
+class ReductionDimensionGrouperTest : public HloTestBase {
+ public:
+ void CheckDimensionGrouper(absl::string_view hlo,
+ std::optional<absl::string_view> expected) {
+ RunAndFilecheckHloRewrite(hlo, gpu::ReductionDimensionGrouper{}, expected);
}
};
TEST_F(ReductionDimensionGrouperTest, ReductionWithGrouping) {
- const char* hlo_text = R"(
+ const char* hlo = R"(
HloModule ReductionWithGrouping
add {
@@ -54,19 +53,18 @@
ROOT out = f32[100,10]{0,1} reduce(input, zero), dimensions={2,3}, to_apply=add
}
-
-
)";
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
- MatchOptimizedHloWithShapes(hlo_text,
- R"(
-// CHECK: f32[100,10]{0,1} reduce(f32[100,10,96]{2,1,0} {{.+}}, f32[] {{.+}}), dimensions={2}, to_apply=%add
+ CheckDimensionGrouper(hlo,
+ R"(
+// CHECK: [[input_0:%[^ ]+]] = f32[100,10,32,3]{3,2,1,0} parameter(0)
+// CHECK: [[bitcast_1:%[^ ]+]] = f32[100,10,96]{2,1,0} bitcast([[input_0]])
+// CHECK: ROOT [[out_1_2:%[^ ]+]] = f32[100,10]{0,1} reduce([[bitcast_1]], [[zero_3:%[^ ]+]]), dimensions={2}, to_apply=[[add_4:%[^ ]+]]
)");
}
TEST_F(ReductionDimensionGrouperTest, ReductionWithGroupingVariadic) {
- const char* hlo_text = R"(
+ const char* hlo = R"(
HloModule ReductionWithGrouping
argmax {
@@ -94,16 +92,16 @@
ROOT out = (f32[100,10]{1,0}, u32[100,10]{1,0}) reduce(input, idxs, zero, zero_idx), dimensions={2,3}, to_apply=argmax
}
-
-
)";
- MatchOptimizedHloWithShapes(hlo_text,
- R"(
-// CHECK: (f32[100,10]{1,0}, u32[100,10]{1,0}) reduce(f32[100,10,96]{2,1,0} %bitcast.3, u32[100,10,96]{2,1,0} %bitcast.2, f32[] %zero_1, u32[] %zero_idx_1), dimensions={2}
+ CheckDimensionGrouper(hlo, R"(
+// CHECK: [[input_0:%[^ ]+]] = f32[100,10,32,3]{3,2,1,0} parameter(0)
+// CHECK: [[bitcast_1:%[^ ]+]] = f32[100,10,96]{2,1,0} bitcast([[input_0]])
+// CHECK: [[idxs_2:%[^ ]+]] = u32[100,10,32,3]{3,2,1,0} parameter(1)
+// CHECK: [[bitcast_1_3:%[^ ]+]] = u32[100,10,96]{2,1,0} bitcast([[idxs_2]])
+// CHECK: ROOT [[out_1_4:%[^ ]+]] = (f32[100,10]{1,0}, u32[100,10]{1,0}) reduce([[bitcast_1]], [[bitcast_1_3]], [[zero_5:%[^ ]+]], [[zero_idx_6:%[^ ]+]]), dimensions={2}, to_apply=[[argmax_7:%[^ ]+]]
)");
}
} // namespace
-} // namespace gpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/tests/reduction_layout_normalizer_test.cc b/tensorflow/compiler/xla/service/gpu/tests/reduction_layout_normalizer_test.cc
index 28a700e..20a58f9 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/reduction_layout_normalizer_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/tests/reduction_layout_normalizer_test.cc
@@ -13,10 +13,11 @@
limitations under the License.
==============================================================================*/
+#include "tensorflow/compiler/xla/service/gpu/reduction_layout_normalizer.h"
+
+#include <optional>
#include <utility>
-#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
-#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/tests/filecheck.h"
@@ -25,25 +26,19 @@
#include "tensorflow/core/platform/test.h"
namespace xla {
-namespace gpu {
namespace {
-// TODO(b/210165681): The tests in this file are fragile to HLO op names.
-
-class ReductionLayoutNormalizerTest : public GpuCodegenTest {
- DebugOptions GetDebugOptionsForTest() override {
- DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest();
- debug_options.add_xla_disable_hlo_passes("reduction-dimension-grouper");
- debug_options.add_xla_disable_hlo_passes("reduction-splitter");
- debug_options.add_xla_disable_hlo_passes("layout-assignment");
- debug_options.add_xla_disable_hlo_passes("gpu-tree-reduction-rewriter");
- return debug_options;
+class ReductionLayoutNormalizerTest : public HloTestBase {
+ public:
+ void CheckReductionLayoutNormalizer(
+ absl::string_view hlo, std::optional<absl::string_view> expected) {
+ RunAndFilecheckHloRewrite(hlo, gpu::ReductionLayoutNormalizer{}, expected);
}
};
TEST_F(ReductionLayoutNormalizerTest, LayoutCanonicalizerTest) {
- const char* hlo_text = R"(
+ const char* hlo = R"(
HloModule ReduceWithLayoutChange
add {
@@ -61,15 +56,16 @@
)";
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
- MatchOptimizedHloWithShapes(hlo_text,
- R"(
-// CHECK: f32[4,12,12,16,5]{2,1,3,4,0} reduce(f32[5,3,3,4,12,12,16,5]{7,6,5,4,3,2,1,0} {{.+}}, f32[] {{.+}}), dimensions={0,1,2}, to_apply=%add
+ CheckReductionLayoutNormalizer(hlo,
+ R"(
+// CHECK: [[bitcast_0:%[^ ]+]] = f32[5,3,3,4,12,12,16,5]{7,6,5,4,3,2,1,0} bitcast([[arg0_1:%[^ ]+]])
+// CHECK: [[reduce_2:%[^ ]+]] = f32[4,12,12,16,5]{2,1,3,4,0} reduce([[bitcast_0]], [[constant0_3:%[^ ]+]]), dimensions={0,1,2}, to_apply=[[add_4:%[^ ]+]]
+// CHECK: ROOT [[bitcast_1_5:%[^ ]+]] = f32[4,5,16,12,12]{4,3,2,1,0} bitcast([[reduce_2]])
)");
}
TEST_F(ReductionLayoutNormalizerTest, LayoutCanonicalizerTestVariadic) {
- const char* hlo_text = R"(
+ const char* hlo = R"(
HloModule ReduceWithLayoutChangeVariadic
@@ -104,16 +100,24 @@
)";
- MatchOptimizedHloWithShapes(hlo_text,
- R"(
-// CHECK: %reduce.1 = (f32[4,12,12,16,5]{2,1,3,4,0}, u32[4,12,12,16,5]{2,1,3,4,0}) reduce(f32[5,3,3,4,12,12,16,5]{7,6,5,4,3,2,1,0} %bitcast.5, u32[5,3,3,4,12,12,16,5]{7,6,5,4,3,2,1,0} %bitcast.4, f32[] %constant0_1, u32[] %constant1_1), dimensions={0,1,2}, to_apply=%argmax
-//
+ CheckReductionLayoutNormalizer(hlo,
+ R"(
+// CHECK: [[arg0_0:%[^ ]+]] = f32[4,5,5,16,12,12,3,3]{2,3,5,4,0,7,6,1} parameter(0)
+// CHECK: [[bitcast_1:%[^ ]+]] = f32[5,3,3,4,12,12,16,5]{7,6,5,4,3,2,1,0} bitcast([[arg0_0]])
+// CHECK: [[idxs_2:%[^ ]+]] = u32[4,5,5,16,12,12,3,3]{2,3,5,4,0,7,6,1} parameter(1)
+// CHECK: [[bitcast_1_3:%[^ ]+]] = u32[5,3,3,4,12,12,16,5]{7,6,5,4,3,2,1,0} bitcast([[idxs_2]])
+// CHECK: [[reduce_4:%[^ ]+]] = (f32[4,12,12,16,5]{2,1,3,4,0}, u32[4,12,12,16,5]{2,1,3,4,0}) reduce([[bitcast_1]], [[bitcast_1_3]], [[constant0_5:%[^ ]+]], [[constant1_6:%[^ ]+]]), dimensions={0,1,2}, to_apply=[[argmax_7:%[^ ]+]]
+// CHECK: [[get_tuple_element_8:%[^ ]+]] = f32[4,12,12,16,5]{2,1,3,4,0} get-tuple-element([[reduce_4]]), index=0
+// CHECK: [[bitcast_2_9:%[^ ]+]] = f32[4,5,16,12,12]{4,3,2,1,0} bitcast([[get_tuple_element_8]])
+// CHECK: [[get_tuple_element_1_10:%[^ ]+]] = u32[4,12,12,16,5]{2,1,3,4,0} get-tuple-element([[reduce_4]]), index=1
+// CHECK: [[bitcast_3_11:%[^ ]+]] = u32[4,5,16,12,12]{4,3,2,1,0} bitcast([[get_tuple_element_1_10]])
+// CHECK: ROOT [[tuple_12:%[^ ]+]] = (f32[4,5,16,12,12]{4,3,2,1,0}, u32[4,5,16,12,12]{4,3,2,1,0}) tuple([[bitcast_2_9]], [[bitcast_3_11]])
)");
}
TEST_F(ReductionLayoutNormalizerTest,
LayoutCanonicalizerTestVariadicDifferentLayouts) {
- const char* hlo_text = R"(
+ const char* hlo = R"(
HloModule ReduceWithLayoutChangeVariadicDifferent
argmax {
@@ -147,27 +151,17 @@
)";
- MatchOptimizedHloWithShapes(hlo_text,
- R"(
-// CHECK: %fused_computation (param_0.1: u32[2,3,4,7]) -> u32[7,2,3,4] {
-// CHECK: %param_0.1 = u32[2,3,4,7]{3,2,1,0} parameter(0)
-// CHECK: %copy.1 = u32[2,3,4,7]{2,1,0,3} copy(u32[2,3,4,7]{3,2,1,0} %param_0.1)
-// CHECK: ROOT %bitcast.2 = u32[7,2,3,4]{3,2,1,0} bitcast(u32[2,3,4,7]{2,1,0,3} %copy.1)
-// CHECK: }
-//
-// CHECK: ENTRY %main (arg0: f32[2,3,4,7], idxs: u32[2,3,4,7]) -> (f32[2,3,4], u32[2,3,4]) {
-// CHECK: %arg0 = f32[2,3,4,7]{2,1,0,3} parameter(0)
-// CHECK: %bitcast = f32[7,2,3,4]{3,2,1,0} bitcast(f32[2,3,4,7]{2,1,0,3} %arg0)
-// CHECK: %idxs = u32[2,3,4,7]{3,2,1,0} parameter(1)
-// CHECK: %fusion = u32[7,2,3,4]{3,2,1,0} fusion(u32[2,3,4,7]{3,2,1,0} %idxs), kind=kLoop, calls=%fused_computation
-// CHECK: %constant0 = f32[] constant(0)
-// CHECK: %constant1 = u32[] constant(0)
-// CHECK: ROOT %reduce0 = (f32[2,3,4]{2,1,0}, u32[2,3,4]{2,1,0}) reduce(f32[7,2,3,4]{3,2,1,0} %bitcast, u32[7,2,3,4]{3,2,1,0} %fusion, f32[] %constant0, u32[] %constant1), dimensions={0}, to_apply=%argmax
-// CHECK: }
+ CheckReductionLayoutNormalizer(hlo,
+ R"(
+// CHECK: [[arg0_0:%[^ ]+]] = f32[2,3,4,7]{2,1,0,3} parameter(0)
+// CHECK: [[bitcast_1:%[^ ]+]] = f32[7,2,3,4]{3,2,1,0} bitcast([[arg0_0]])
+// CHECK: [[idxs_2:%[^ ]+]] = u32[2,3,4,7]{3,2,1,0} parameter(1)
+// CHECK: [[copy_3:%[^ ]+]] = u32[2,3,4,7]{2,1,0,3} copy([[idxs_2]])
+// CHECK: [[bitcast_1_4:%[^ ]+]] = u32[7,2,3,4]{3,2,1,0} bitcast([[copy_3]])
+// CHECK: ROOT [[reduce0_5:%[^ ]+]] = (f32[2,3,4]{2,1,0}, u32[2,3,4]{2,1,0}) reduce([[bitcast_1]], [[bitcast_1_4]], [[constant0_6:%[^ ]+]], [[constant1_7:%[^ ]+]]), dimensions={0}, to_apply=[[argmax_8:%[^ ]+]]
)");
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+ EXPECT_TRUE(RunAndCompare(hlo, ErrorSpec{1e-5, 1e-5}));
}
} // namespace
-} // namespace gpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/tests/tree_reduction_rewriter_test.cc b/tensorflow/compiler/xla/service/gpu/tests/tree_reduction_rewriter_test.cc
index 99c942c..c36d842 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/tree_reduction_rewriter_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/tests/tree_reduction_rewriter_test.cc
@@ -13,10 +13,12 @@
limitations under the License.
==============================================================================*/
+#include "tensorflow/compiler/xla/service/gpu/tree_reduction_rewriter.h"
+
+#include <optional>
#include <utility>
-#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
-#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h"
+#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
@@ -28,16 +30,21 @@
#include "tensorflow/stream_executor/lib/statusor.h"
namespace xla {
-namespace gpu {
namespace {
-// TODO(b/210165681): The tests in this file are fragile to HLO op names.
-
-class TreeReductionRewriterTest : public GpuCodegenTest {};
+class TreeReductionRewriterTest : public HloTestBase {
+ public:
+ void CheckTreeRewriter(absl::string_view hlo,
+ std::optional<absl::string_view> expected) {
+ RunAndFilecheckHloRewrite(
+ hlo, gpu::GpuTreeReductionRewriter{se::CudaComputeCapability{8, 1}},
+ expected);
+ }
+};
TEST_F(TreeReductionRewriterTest, RowReductionSingleDimensionNoBatched) {
- const char* hlo_text = R"(
+ const char* hlo = R"(
HloModule ReduceWithPadding
add {
@@ -53,27 +60,17 @@
}
)";
- MatchOptimizedHloWithShapes(hlo_text,
- R"(
-// CHECK: (param_0.2: f32[50000]) -> f32[224] {
-// CHECK: [[INSTR_0:%[^ ]+]] = f32[50000]{0} parameter(0)
-// CHECK: [[INSTR_1:%[^ ]+]] = f32[] constant(0)
-// CHECK: [[INSTR_2:%[^ ]+]] = f32[50176]{0} pad(f32[50000]{0} [[INSTR_0]], f32[] [[INSTR_1]]), padding=0_176
-// CHECK: [[INSTR_3:%[^ ]+]] = f32[224,224]{1,0} bitcast(f32[50176]{0} [[INSTR_2]])
-// CHECK: ROOT [[INSTR_4:%[^ ]+]] = f32[224]{0} reduce(f32[224,224]{1,0} [[INSTR_3]], f32[] [[INSTR_1]]), dimensions={1}, to_apply=[[INSTR_5:%[^ ]+]]
-// CHECK: }
-// CHECK: ENTRY %main (input: f32[50000]) -> f32[] {
-// CHECK: [[INSTR_0:%[^ ]+]] = f32[50000]{0} parameter(0)
-// CHECK: [[INSTR_1:%[^ ]+]] = f32[224]{0} fusion(f32[50000]{0} [[INSTR_0]]), kind=kInput, calls=[[INSTR_2:%[^ ]+]]
-// CHECK: [[INSTR_3:%[^ ]+]] = f32[] constant(0)
-// CHECK: ROOT [[INSTR_4:%[^ ]+]] = f32[] reduce(f32[224]{0} [[INSTR_1]], f32[] [[INSTR_3]]), dimensions={0}, to_apply=[[INSTR_5:%[^ ]+]]
-// CHECK: }
+ CheckTreeRewriter(hlo,
+ R"(
+// CHECK: [[pad_0:%[^ ]+]] = f32[50176]{0} pad([[input_1:%[^ ]+]], [[zero_2:%[^ ]+]]), padding=0_176
+// CHECK: [[bitcast_3:%[^ ]+]] = f32[224,224]{1,0} bitcast([[pad_0]])
+// CHECK: [[reduce_4:%[^ ]+]] = f32[224]{0} reduce([[bitcast_3]], [[zero_2]]), dimensions={1}, to_apply=[[add_5:%[^ ]+]]
+// CHECK: ROOT [[out_1_6:%[^ ]+]] = f32[] reduce([[reduce_4]], [[zero_2]]), dimensions={0}, to_apply=[[add_5]]
)");
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
}
TEST_F(TreeReductionRewriterTest, RowReductionWeirdOutputLayout) {
- const char* hlo_text = R"(
+ const char* hlo = R"(
HloModule ReduceWithPadding
add {
@@ -90,16 +87,15 @@
)";
// Check that we preserve the layout.
- MatchOptimizedHloWithShapes(hlo_text,
- R"(
+ CheckTreeRewriter(hlo,
+ R"(
// CHECK: f32[2,4]{0,1} reduce(
)");
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
}
TEST_F(TreeReductionRewriterTest,
RowReductionSingleDimensionNoBatchedDivisible) {
- const char* hlo_text = R"(
+ const char* hlo = R"(
HloModule ReduceWithPadding
add {
@@ -115,26 +111,18 @@
}
)";
- MatchOptimizedHloWithShapes(hlo_text,
- R"(
-// CHECK: (param_0.2: f32[49952]) -> f32[223] {
-// CHECK: [[INSTR_0:%[^ ]+]] = f32[49952]{0} parameter(0)
-// CHECK: [[INSTR_1:%[^ ]+]] = f32[223,224]{1,0} bitcast(f32[49952]{0} [[INSTR_0]])
-// CHECK: [[INSTR_2:%[^ ]+]] = f32[] constant(0)
-// CHECK: ROOT [[INSTR_3:%[^ ]+]] = f32[223]{0} reduce(f32[223,224]{1,0} [[INSTR_1]], f32[] [[INSTR_2]]), dimensions={1}, to_apply=[[INSTR_4:%[^ ]+]]
-// CHECK: }
-// CHECK: ENTRY %main (input: f32[49952]) -> f32[] {
-// CHECK: [[INSTR_0:%[^ ]+]] = f32[49952]{0} parameter(0)
-// CHECK: [[INSTR_1:%[^ ]+]] = f32[223]{0} fusion(f32[49952]{0} [[INSTR_0]]), kind=kInput, calls=[[INSTR_2:%[^ ]+]]
-// CHECK: [[INSTR_3:%[^ ]+]] = f32[] constant(0)
-// CHECK: ROOT [[INSTR_4:%[^ ]+]] = f32[] reduce(f32[223]{0} [[INSTR_1]], f32[] [[INSTR_3]]), dimensions={0}, to_apply=[[INSTR_5:%[^ ]+]]
-// CHECK: }
+ CheckTreeRewriter(hlo,
+ R"(
+// CHECK: [[input_0:%[^ ]+]] = f32[49952]{0} parameter(0)
+// CHECK: [[bitcast_1:%[^ ]+]] = f32[223,224]{1,0} bitcast([[input_0]])
+// CHECK: [[zero_2:%[^ ]+]] = f32[] constant(0)
+// CHECK: [[reduce_3:%[^ ]+]] = f32[223]{0} reduce([[bitcast_1]], [[zero_2]]), dimensions={1}, to_apply=[[add_4:%[^ ]+]]
+// CHECK: ROOT [[out_1_5:%[^ ]+]] = f32[] reduce([[reduce_3]], [[zero_2]]), dimensions={0}, to_apply=[[add_4]]
)");
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
}
TEST_F(TreeReductionRewriterTest, RowReductionNoBatched) {
- const char* hlo_text = R"(
+ const char* hlo = R"(
HloModule ReduceWithPadding
add {
@@ -150,28 +138,18 @@
}
)";
- MatchOptimizedHloWithShapes(hlo_text,
- R"(
-// CHECK: (param_0.2: f32[100,10,90000]) -> f32[100,10,300] {
-// CHECK: [[INSTR_0:%[^ ]+]] = f32[100,10,90000]{2,1,0} parameter(0)
-// CHECK: [[INSTR_1:%[^ ]+]] = f32[100,10,300,300]{3,2,1,0} bitcast(f32[100,10,90000]{2,1,0} [[INSTR_0]])
-// CHECK: [[INSTR_2:%[^ ]+]] = f32[] constant(0)
-// CHECK: ROOT [[INSTR_3:%[^ ]+]] = f32[100,10,300]{2,1,0} reduce(f32[100,10,300,300]{3,2,1,0} [[INSTR_1]], f32[] [[INSTR_2]]), dimensions={3}, to_apply=[[INSTR_4:%[^ ]+]]
-// CHECK: }
-// CHECK: ENTRY %main (input: f32[100,10,90000]) -> f32[100,10] {
-// CHECK: [[INSTR_0:%[^ ]+]] = f32[100,10,90000]{2,1,0} parameter(0)
-// CHECK: [[INSTR_1:%[^ ]+]] = f32[100,10,300]{2,1,0} fusion(f32[100,10,90000]{2,1,0} [[INSTR_0]]), kind=kInput, calls=[[INSTR_2:%[^ ]+]]
-// CHECK: [[INSTR_3:%[^ ]+]] = f32[] constant(0)
-// CHECK: ROOT [[INSTR_4:%[^ ]+]] = f32[100,10]{1,0} reduce(f32[100,10,300]{2,1,0} [[INSTR_1]], f32[] [[INSTR_3]]), dimensions={2}, to_apply=[[INSTR_5:%[^ ]+]]
-// CHECK: }
+ CheckTreeRewriter(hlo,
+ R"(
+// CHECK: [[bitcast_0:%[^ ]+]] = f32[100,10,300,300]{3,2,1,0} bitcast([[input_1:%[^ ]+]])
+// CHECK: [[zero_2:%[^ ]+]] = f32[] constant(0)
+// CHECK: [[reduce_3:%[^ ]+]] = f32[100,10,300]{2,1,0} reduce([[bitcast_0]], [[zero_2]]), dimensions={3}, to_apply=[[add_4:%[^ ]+]]
+// CHECK: ROOT [[out_1_5:%[^ ]+]] = f32[100,10]{1,0} reduce([[reduce_3]], [[zero_2]]), dimensions={2}, to_apply=[[add_4]]
)");
-
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
}
TEST_F(TreeReductionRewriterTest,
RowReductionSingleDimensionNoBatchedLargeInput) {
- const char* hlo_text = R"(
+ const char* hlo = R"(
HloModule ReduceWithPadding
add {
@@ -187,26 +165,18 @@
}
)";
- MatchOptimizedHloWithShapes(hlo_text,
- R"(
-// CHECK: (param_0.2: f32[1000000]) -> f32[1000] {
-// CHECK: [[INSTR_0:%[^ ]+]] = f32[1000000]{0} parameter(0)
-// CHECK: [[INSTR_1:%[^ ]+]] = f32[1000,1000]{1,0} bitcast(f32[1000000]{0} [[INSTR_0]])
-// CHECK: [[INSTR_2:%[^ ]+]] = f32[] constant(0)
-// CHECK: ROOT [[INSTR_3:%[^ ]+]] = f32[1000]{0} reduce(f32[1000,1000]{1,0} [[INSTR_1]], f32[] [[INSTR_2]]), dimensions={1}, to_apply=[[INSTR_4:%[^ ]+]]
-// CHECK: }
-// CHECK: ENTRY %main (input: f32[1000000]) -> f32[] {
-// CHECK: [[INSTR_0:%[^ ]+]] = f32[1000000]{0} parameter(0)
-// CHECK: [[INSTR_1:%[^ ]+]] = f32[1000]{0} fusion(f32[1000000]{0} [[INSTR_0]]), kind=kInput, calls=[[INSTR_2:%[^ ]+]]
-// CHECK: [[INSTR_3:%[^ ]+]] = f32[] constant(0)
-// CHECK: ROOT [[INSTR_4:%[^ ]+]] = f32[] reduce(f32[1000]{0} [[INSTR_1]], f32[] [[INSTR_3]]), dimensions={0}, to_apply=[[INSTR_5:%[^ ]+]]
-// CHECK: }
+ CheckTreeRewriter(hlo,
+ R"(
+// CHECK: [[input_0:%[^ ]+]] = f32[1000000]{0} parameter(0)
+// CHECK: [[bitcast_1:%[^ ]+]] = f32[1000,1000]{1,0} bitcast([[input_0]])
+// CHECK: [[zero_2:%[^ ]+]] = f32[] constant(0)
+// CHECK: [[reduce_3:%[^ ]+]] = f32[1000]{0} reduce([[bitcast_1]], [[zero_2]]), dimensions={1}, to_apply=[[add_4:%[^ ]+]]
+// CHECK: ROOT [[out_1_5:%[^ ]+]] = f32[] reduce([[reduce_3]], [[zero_2]]), dimensions={0}, to_apply=[[add_4]]
)");
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
}
TEST_F(TreeReductionRewriterTest, RowReductionBatchedDimensionFits) {
- const char* hlo_text = R"(
+ const char* hlo = R"(
HloModule ReduceWithPadding
add {
@@ -222,27 +192,17 @@
}
)";
- MatchOptimizedHloWithShapes(hlo_text,
- R"(
-// CHECK: (param_0.2: f32[8,100,90000]) -> f32[100,300] {
-// CHECK: [[INSTR_0:%[^ ]+]] = f32[8,100,90000]{2,1,0} parameter(0)
-// CHECK: [[INSTR_1:%[^ ]+]] = f32[8,100,300,300]{3,2,1,0} bitcast(f32[8,100,90000]{2,1,0} [[INSTR_0]])
-// CHECK: [[INSTR_2:%[^ ]+]] = f32[] constant(0)
-// CHECK: ROOT [[INSTR_3:%[^ ]+]] = f32[100,300]{1,0} reduce(f32[8,100,300,300]{3,2,1,0} [[INSTR_1]], f32[] [[INSTR_2]]), dimensions={3,0}, to_apply=[[INSTR_4:%[^ ]+]]
-// CHECK: }
-// CHECK: ENTRY %main (input: f32[8,100,90000]) -> f32[100] {
-// CHECK: [[INSTR_0:%[^ ]+]] = f32[8,100,90000]{2,1,0} parameter(0)
-// CHECK: [[INSTR_1:%[^ ]+]] = f32[100,300]{1,0} fusion(f32[8,100,90000]{2,1,0} [[INSTR_0]]), kind=kInput, calls=[[INSTR_2:%[^ ]+]]
-// CHECK: [[INSTR_3:%[^ ]+]] = f32[] constant(0)
-// CHECK: ROOT [[INSTR_4:%[^ ]+]] = f32[100]{0} reduce(f32[100,300]{1,0} [[INSTR_1]], f32[] [[INSTR_3]]), dimensions={1}, to_apply=[[INSTR_5:%[^ ]+]]
-// CHECK: }
+ CheckTreeRewriter(hlo,
+ R"(
+// CHECK: [[bitcast_0:%[^ ]+]] = f32[8,100,300,300]{3,2,1,0} bitcast([[input_1:%[^ ]+]])
+// CHECK: [[zero_2:%[^ ]+]] = f32[] constant(0)
+// CHECK: [[reduce_3:%[^ ]+]] = f32[100,300]{1,0} reduce([[bitcast_0]], [[zero_2]]), dimensions={3,0}, to_apply=[[add_4:%[^ ]+]]
+// CHECK: ROOT [[out_1_5:%[^ ]+]] = f32[100]{0} reduce([[reduce_3]], [[zero_2]]), dimensions={1}, to_apply=[[add_4]]
)");
-
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
}
TEST_F(TreeReductionRewriterTest, RowReductionBatchedDimensionDoesNotFit) {
- const char* hlo_text = R"(
+ const char* hlo = R"(
HloModule ReduceWithPadding
add {
@@ -258,30 +218,15 @@
}
)";
- MatchOptimizedHloWithShapes(hlo_text,
- R"(
-// CHECK: (param_0.2: f32[32,100,90000]) -> f32[32,100,300] {
-// CHECK: [[INSTR_0:%[^ ]+]] = f32[32,100,90000]{2,1,0} parameter(0)
-// CHECK: [[INSTR_1:%[^ ]+]] = f32[32,100,300,300]{3,2,1,0} bitcast(f32[32,100,90000]{2,1,0} [[INSTR_0]])
-// CHECK: [[INSTR_2:%[^ ]+]] = f32[] constant(0)
-// CHECK: ROOT [[INSTR_3:%[^ ]+]] = f32[32,100,300]{2,1,0} reduce(f32[32,100,300,300]{3,2,1,0} [[INSTR_1]], f32[] [[INSTR_2]]), dimensions={3}, to_apply=[[INSTR_4:%[^ ]+]]
-// CHECK: }
-// CHECK: ENTRY %main (input: f32[32,100,90000]) -> f32[100] {
-// CHECK: [[INSTR_0:%[^ ]+]] = f32[32,100,90000]{2,1,0} parameter(0)
-// CHECK: [[INSTR_1:%[^ ]+]] = f32[32,100,300]{2,1,0} fusion(f32[32,100,90000]{2,1,0} [[INSTR_0]]), kind=kInput, calls=[[INSTR_2:%[^ ]+]]
-// CHECK: [[INSTR_3:%[^ ]+]] = f32[] constant(0)
-// CHECK: [[INSTR_4:%[^ ]+]] = f32[32,100]{1,0} reduce(f32[32,100,300]{2,1,0} [[INSTR_1]], f32[] [[INSTR_3]]), dimensions={2}, to_apply=[[INSTR_5:%[^ ]+]]
-// CHECK: ROOT [[INSTR_6:%[^ ]+]] = f32[100]{0} reduce(f32[32,100]{1,0} [[INSTR_4]], f32[] [[INSTR_3]]), dimensions={0}, to_apply=[[INSTR_5]]
-// CHECK: }
+ CheckTreeRewriter(hlo,
+ R"(
+// CHECK: [[reduce_0:%[^ ]+]] = f32[32,100]{1,0} reduce([[input_1:%[^ ]+]], [[zero_2:%[^ ]+]]), dimensions={2}, to_apply=[[add_3:%[^ ]+]]
+// CHECK: ROOT [[out_1_4:%[^ ]+]] = f32[100]{0} reduce([[reduce_0]], [[zero_2]]), dimensions={0}, to_apply=[[add_3]]
)");
-
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
}
TEST_F(TreeReductionRewriterTest, ColumnReductionSimple) {
- // TODO(cheshire): reduce duplication for HLO text, factor out the common
- // part.
- const char* hlo_text = R"(
+ const char* hlo = R"(
HloModule ReduceWithPadding
add {
@@ -297,27 +242,18 @@
}
)";
- MatchOptimizedHloWithShapes(hlo_text,
- R"(
-// CHECK: (param_0.2: f32[10000,100]) -> f32[100,100] {
-// CHECK: [[INSTR_0:%[^ ]+]] = f32[10000,100]{1,0} parameter(0)
-// CHECK: [[INSTR_1:%[^ ]+]] = f32[100,100,100]{2,1,0} bitcast(f32[10000,100]{1,0}
-// %param_0.2)
-// CHECK: [[INSTR_0:%[^ ]+]] = f32[] constant(0)
-// CHECK: ROOT [[INSTR_1:%[^ ]+]] = f32[100,100]{1,0} reduce(f32[100,100,100]{2,1,0} [[INSTR_2:%[^ ]+]], f32[] [[INSTR_0]]), dimensions={0}, to_apply=[[INSTR_3:%[^ ]+]]
-// CHECK: }
-// CHECK: ENTRY %main (input: f32[10000,100]) -> f32[100] {
-// CHECK: [[INSTR_0:%[^ ]+]] = f32[10000,100]{1,0} parameter(0)
-// CHECK: [[INSTR_1:%[^ ]+]] = f32[100,100]{1,0} fusion(f32[10000,100]{1,0} [[INSTR_0]]), kind=kInput, calls=[[INSTR_2:%[^ ]+]]
-// CHECK: [[INSTR_3:%[^ ]+]] = f32[] constant(0)
-// CHECK: ROOT [[INSTR_4:%[^ ]+]] = f32[100]{0} reduce(f32[100,100]{1,0} [[INSTR_1]], f32[] [[INSTR_3]]), dimensions={0}, to_apply=[[INSTR_5:%[^ ]+]]
-// CHECK: }
+ CheckTreeRewriter(hlo,
+ R"(
+
+// CHECK: [[input_0:%[^ ]+]] = f32[10000,100]{1,0} parameter(0)
+// CHECK: [[bitcast_1:%[^ ]+]] = f32[100,100,100]{2,1,0} bitcast([[input_0]])
+// CHECK: [[reduce_2:%[^ ]+]] = f32[100,100]{1,0} reduce([[bitcast_1]], [[zero_3:%[^ ]+]]), dimensions={0}, to_apply=[[add_4:%[^ ]+]]
+// CHECK: ROOT [[out_1_5:%[^ ]+]] = f32[100]{0} reduce([[reduce_2]], [[zero_3]]), dimensions={0}, to_apply=[[add_4]]
)");
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
}
TEST_F(TreeReductionRewriterTest, ColumnReductionSimpleNoSquareDivisible) {
- const char* hlo_text = R"(
+ const char* hlo = R"(
HloModule ReduceWithPadding
add {
@@ -333,27 +269,18 @@
}
)";
- MatchOptimizedHloWithShapes(hlo_text,
- R"(
-// CHECK: (param_0.2: f32[10302,100]) -> f32[102,100] {
-// CHECK: [[INSTR_0:%[^ ]+]] = f32[10302,100]{1,0} parameter(0)
-// CHECK: [[INSTR_1:%[^ ]+]] = f32[101,102,100]{2,1,0} bitcast(f32[10302,100]{1,0}
-// %param_0.2)
-// CHECK: [[INSTR_0:%[^ ]+]] = f32[] constant(0)
-// CHECK: ROOT [[INSTR_1:%[^ ]+]] = f32[102,100]{1,0} reduce(f32[101,102,100]{2,1,0} [[INSTR_2:%[^ ]+]], f32[] [[INSTR_0]]), dimensions={0}, to_apply=[[INSTR_3:%[^ ]+]]
-// CHECK: }
-// CHECK: ENTRY %main (input: f32[10302,100]) -> f32[100] {
-// CHECK: [[INSTR_0:%[^ ]+]] = f32[10302,100]{1,0} parameter(0)
-// CHECK: [[INSTR_1:%[^ ]+]] = f32[102,100]{1,0} fusion(f32[10302,100]{1,0} [[INSTR_0]]), kind=kInput, calls=[[INSTR_2:%[^ ]+]]
-// CHECK: [[INSTR_3:%[^ ]+]] = f32[] constant(0)
-// CHECK: ROOT [[INSTR_4:%[^ ]+]] = f32[100]{0} reduce(f32[102,100]{1,0} [[INSTR_1]], f32[] [[INSTR_3]]), dimensions={0}, to_apply=[[INSTR_5:%[^ ]+]]
-// CHECK: }
+ CheckTreeRewriter(hlo,
+ R"(
+// CHECK: [[input_0:%[^ ]+]] = f32[10302,100]{1,0} parameter(0)
+// CHECK: [[bitcast_1:%[^ ]+]] = f32[101,102,100]{2,1,0} bitcast([[input_0]])
+// CHECK: [[zero_2:%[^ ]+]] = f32[] constant(0)
+// CHECK: [[reduce_3:%[^ ]+]] = f32[102,100]{1,0} reduce([[bitcast_1]], [[zero_2]]), dimensions={0}, to_apply=[[add_4:%[^ ]+]]
+// CHECK: ROOT [[out_1_5:%[^ ]+]] = f32[100]{0} reduce([[reduce_3]], [[zero_2]]), dimensions={0}, to_apply=[[add_4]]
)");
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
}
TEST_F(TreeReductionRewriterTest, ColumnReductionOtherIndex) {
- const char* hlo_text = R"(
+ const char* hlo = R"(
HloModule ReduceWithPadding
add {
@@ -369,28 +296,18 @@
}
)";
- MatchOptimizedHloWithShapes(hlo_text,
- R"(
-// CHECK: (param_0.2: f32[10000,2,2,2]) -> f32[100,2,2,2] {
-// CHECK: [[INSTR_0:%[^ ]+]] = f32[10000,2,2,2]{3,2,1,0} parameter(0)
-// CHECK: [[INSTR_1:%[^ ]+]] = f32[100,100,2,2,2]{4,3,2,1,0} bitcast(f32[10000,2,2,2]{3,2,1,0} [[INSTR_0]])
-// CHECK: [[INSTR_2:%[^ ]+]] = f32[] constant(0)
-// CHECK: ROOT [[INSTR_3:%[^ ]+]] = f32[100,2,2,2]{3,2,1,0} reduce(f32[100,100,2,2,2]{4,3,2,1,0} [[INSTR_1]], f32[] [[INSTR_2]]), dimensions={0}, to_apply=[[INSTR_4:%[^ ]+]]
-// CHECK: }
-// CHECK: ENTRY %main (input: f32[10000,2,2,2]) -> f32[2,2,2] {
-// CHECK: [[INSTR_0:%[^ ]+]] = f32[10000,2,2,2]{3,2,1,0} parameter(0)
-// CHECK: [[INSTR_1:%[^ ]+]] = f32[100,2,2,2]{3,2,1,0} fusion(f32[10000,2,2,2]{3,2,1,0} [[INSTR_0]]), kind=kInput, calls=[[INSTR_2:%[^ ]+]]
-// CHECK: [[INSTR_3:%[^ ]+]] = f32[] constant(0)
-// CHECK: ROOT [[INSTR_4:%[^ ]+]] = f32[2,2,2]{2,1,0} reduce(f32[100,2,2,2]{3,2,1,0} [[INSTR_1]], f32[] [[INSTR_3]]), dimensions={0}, to_apply=[[INSTR_5:%[^ ]+]]
-// CHECK: }
+ CheckTreeRewriter(hlo,
+ R"(
+// CHECK: [[input_0:%[^ ]+]] = f32[10000,2,2,2]{3,2,1,0} parameter(0)
+// CHECK: [[bitcast_1:%[^ ]+]] = f32[100,100,2,2,2]{4,3,2,1,0} bitcast([[input_0]])
+// CHECK: [[zero_2:%[^ ]+]] = f32[] constant(0)
+// CHECK: [[reduce_3:%[^ ]+]] = f32[100,2,2,2]{3,2,1,0} reduce([[bitcast_1]], [[zero_2]]), dimensions={0}, to_apply=[[add_4:%[^ ]+]]
+// CHECK: ROOT [[out_1_5:%[^ ]+]] = f32[2,2,2]{2,1,0} reduce([[reduce_3]], [[zero_2]]), dimensions={0}, to_apply=[[add_4]]
)");
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
}
TEST_F(TreeReductionRewriterTest, ColumnReductionVeryLargeInput) {
- // TODO(cheshire): reduce duplication for HLO text, factor out the common
- // part.
- const char* hlo_text = R"(
+ const char* hlo = R"(
HloModule ReduceWithPadding
add {
@@ -406,27 +323,18 @@
}
)";
- MatchOptimizedHloWithShapes(hlo_text,
- R"(
-// CHECK: (param_0.2: f32[1000000,5]) -> f32[1000,5] {
-// CHECK: [[INSTR_0:%[^ ]+]] = f32[1000000,5]{1,0} parameter(0)
-// CHECK: [[INSTR_1:%[^ ]+]] = f32[1000,1000,5]{2,1,0} bitcast(f32[1000000,5]{1,0}
-// %param_0.2)
-// CHECK: [[INSTR_0:%[^ ]+]] = f32[] constant(0)
-// CHECK: ROOT [[INSTR_1:%[^ ]+]] = f32[1000,5]{1,0} reduce(f32[1000,1000,5]{2,1,0} [[INSTR_2:%[^ ]+]], f32[] [[INSTR_0]]), dimensions={0}, to_apply=[[INSTR_3:%[^ ]+]]
-// CHECK: }
-// CHECK: ENTRY %main (input: f32[1000000,5]) -> f32[5] {
-// CHECK: [[INSTR_0:%[^ ]+]] = f32[1000000,5]{1,0} parameter(0)
-// CHECK: [[INSTR_1:%[^ ]+]] = f32[1000,5]{1,0} fusion(f32[1000000,5]{1,0} [[INSTR_0]]), kind=kInput, calls=[[INSTR_2:%[^ ]+]]
-// CHECK: [[INSTR_3:%[^ ]+]] = f32[] constant(0)
-// CHECK: ROOT [[INSTR_4:%[^ ]+]] = f32[5]{0} reduce(f32[1000,5]{1,0} [[INSTR_1]], f32[] [[INSTR_3]]), dimensions={0}, to_apply=[[INSTR_5:%[^ ]+]]
-// CHECK: }
+ CheckTreeRewriter(hlo,
+ R"(
+
+// CHECK: [[bitcast_0:%[^ ]+]] = f32[1000,1000,5]{2,1,0} bitcast([[input_1:%[^ ]+]])
+// CHECK: [[zero_2:%[^ ]+]] = f32[] constant(0)
+// CHECK: [[reduce_3:%[^ ]+]] = f32[1000,5]{1,0} reduce([[bitcast_0]], [[zero_2]]), dimensions={0}, to_apply=[[add_4:%[^ ]+]]
+// CHECK: ROOT [[out_1_5:%[^ ]+]] = f32[5]{0} reduce([[reduce_3]], [[zero_2]]), dimensions={0}, to_apply=[[add_4]]
)");
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
}
TEST_F(TreeReductionRewriterTest, VariadicReductionLargeRow) {
- const char* hlo_text = R"(
+ const char* hlo = R"(
HloModule Reduce_R1x2_to_R0x2_argmax
argmax {
@@ -459,44 +367,22 @@
}
)";
- MatchOptimizedHloWithShapes(hlo_text,
- R"(
-// CHECK: (param_0.6: f32[], param_1.7: f32[], param_2.8: u32[], param_3.5: u32[]) -> (f32[], u32[]) {
-// CHECK: [[INSTR_0:%[^ ]+]] = f32[] parameter(0)
-// CHECK: [[INSTR_1:%[^ ]+]] = f32[] parameter(1)
-// CHECK: [[INSTR_2:%[^ ]+]] = pred[] compare(f32[] [[INSTR_0]], f32[] [[INSTR_1]]), direction=GT
-// CHECK: [[INSTR_3:%[^ ]+]] = f32[] select(pred[] [[INSTR_2]], f32[] [[INSTR_0]], f32[] [[INSTR_1]])
-// CHECK: [[INSTR_4:%[^ ]+]] = u32[] parameter(2)
-// CHECK: [[INSTR_5:%[^ ]+]] = u32[] parameter(3)
-// CHECK: [[INSTR_6:%[^ ]+]].clone.1 = u32[] select(pred[] [[INSTR_2]], u32[] [[INSTR_4]], u32[] [[INSTR_5]])
-// CHECK: ROOT [[INSTR_7:%[^ ]+]] = (f32[], u32[]) tuple(f32[] [[INSTR_3]], u32[] [[INSTR_6]].clone.1)
-// CHECK: }
-// CHECK: (param_0.2: f32[2,100000]) -> (f32[2,317], u32[2,317]) {
-// CHECK: [[INSTR_0:%[^ ]+]] = f32[2,100000]{1,0} parameter(0)
-// CHECK: [[INSTR_1:%[^ ]+]] = f32[] constant(0)
-// CHECK: [[INSTR_2:%[^ ]+]] = f32[2,100489]{1,0} pad(f32[2,100000]{1,0} [[INSTR_0]], f32[] [[INSTR_1]]), padding=0_0x0_489
-// CHECK: [[INSTR_3:%[^ ]+]] = f32[2,317,317]{2,1,0} bitcast(f32[2,100489]{1,0} [[INSTR_2]])
-// CHECK: [[INSTR_4:%[^ ]+]] = u32[2,100000]{1,0} iota(), iota_dimension=0
-// CHECK: [[INSTR_5:%[^ ]+]] = u32[] constant(0)
-// CHECK: [[INSTR_6:%[^ ]+]] = u32[2,100489]{1,0} pad(u32[2,100000]{1,0} [[INSTR_4]], u32[] [[INSTR_5]]), padding=0_0x0_489
-// CHECK: [[INSTR_7:%[^ ]+]] = u32[2,317,317]{2,1,0} bitcast(u32[2,100489]{1,0} [[INSTR_6]])
-// CHECK: ROOT [[INSTR_8:%[^ ]+]] = (f32[2,317]{1,0}, u32[2,317]{1,0}) reduce(f32[2,317,317]{2,1,0} [[INSTR_3]], u32[2,317,317]{2,1,0} [[INSTR_7]], f32[] [[INSTR_1]], u32[] [[INSTR_5]]), dimensions={2}, to_apply=[[INSTR_9:%[^ ]+]]
-// CHECK: }
-// CHECK: ENTRY %main (input: f32[2,100000]) -> (f32[2], u32[2]) {
-// CHECK: [[INSTR_0:%[^ ]+]] = f32[2,100000]{1,0} parameter(0)
-// CHECK: [[INSTR_1:%[^ ]+]] = (f32[2,317]{1,0}, u32[2,317]{1,0}) fusion(f32[2,100000]{1,0} [[INSTR_0]]), kind=kInput, calls=[[INSTR_2:%[^ ]+]]
-// CHECK: [[INSTR_3:%[^ ]+]] = f32[2,317]{1,0} get-tuple-element((f32[2,317]{1,0}, u32[2,317]{1,0}) [[INSTR_1]]), index=0
-// CHECK: [[INSTR_4:%[^ ]+]] = u32[2,317]{1,0} get-tuple-element((f32[2,317]{1,0}, u32[2,317]{1,0}) [[INSTR_1]]), index=1
-// CHECK: [[INSTR_5:%[^ ]+]] = f32[] constant(0)
-// CHECK: [[INSTR_6:%[^ ]+]] = u32[] constant(0)
-// CHECK: ROOT [[INSTR_7:%[^ ]+]] = (f32[2]{0}, u32[2]{0}) reduce(f32[2,317]{1,0} [[INSTR_3]], u32[2,317]{1,0} [[INSTR_4]], f32[] [[INSTR_5]], u32[] [[INSTR_6]]), dimensions={1}, to_apply=[[INSTR_8:%[^ ]+]]
-// CHECK: }
+ CheckTreeRewriter(hlo,
+ R"(
+// CHECK: [[pad_0:%[^ ]+]] = f32[2,100489]{1,0} pad([[input_1:%[^ ]+]], [[zero_2:%[^ ]+]]), padding=0_0x0_489
+// CHECK: [[bitcast_3:%[^ ]+]] = f32[2,317,317]{2,1,0} bitcast([[pad_0]])
+// CHECK: [[zero_idx_4:%[^ ]+]] = u32[] constant(0)
+// CHECK: [[pad_1_5:%[^ ]+]] = u32[2,100489]{1,0} pad([[idxs_6:%[^ ]+]], [[zero_idx_4]]), padding=0_0x0_489
+// CHECK: [[bitcast_1_7:%[^ ]+]] = u32[2,317,317]{2,1,0} bitcast([[pad_1_5]])
+// CHECK: [[reduce_8:%[^ ]+]] = (f32[2,317]{1,0}, u32[2,317]{1,0}) reduce([[bitcast_3]], [[bitcast_1_7]], [[zero_2]], [[zero_idx_4]]), dimensions={2}, to_apply=[[argmax_9:%[^ ]+]]
+// CHECK: [[get_tuple_element_10:%[^ ]+]] = f32[2,317]{1,0} get-tuple-element([[reduce_8]]), index=0
+// CHECK: [[get_tuple_element_1_11:%[^ ]+]] = u32[2,317]{1,0} get-tuple-element([[reduce_8]]), index=1
+// CHECK: ROOT [[out_1_12:%[^ ]+]] = (f32[2]{0}, u32[2]{0}) reduce([[get_tuple_element_10]], [[get_tuple_element_1_11]], [[zero_2]], [[zero_idx_4]]), dimensions={1}, to_apply=[[argmax_9]]
)");
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
}
TEST_F(TreeReductionRewriterTest, VariadicReductionLargeBatchSize) {
- const char* hlo_text = R"(
+ const char* hlo = R"(
HloModule Reduce_R1x2_to_R0x2_argmax
argmax {
@@ -529,38 +415,14 @@
}
)";
- MatchOptimizedHloWithShapes(hlo_text,
- R"(
-// CHECK: (param_0.4: f32[], param_1.5: f32[], param_2.6: u32[], param_3.3: u32[]) -> (f32[], u32[]) {
-// CHECK: [[INSTR_0:%[^ ]+]] = f32[] parameter(0)
-// CHECK: [[INSTR_1:%[^ ]+]] = f32[] parameter(1)
-// CHECK: [[INSTR_2:%[^ ]+]] = pred[] compare(f32[] [[INSTR_0]], f32[] [[INSTR_1]]), direction=GT
-// CHECK: [[INSTR_3:%[^ ]+]] = f32[] select(pred[] [[INSTR_2]], f32[] [[INSTR_0]], f32[] [[INSTR_1]])
-// CHECK: [[INSTR_4:%[^ ]+]] = u32[] parameter(2)
-// CHECK: [[INSTR_5:%[^ ]+]] = u32[] parameter(3)
-// CHECK: [[INSTR_6:%[^ ]+]].clone.1 = u32[] select(pred[] [[INSTR_2]], u32[] [[INSTR_4]], u32[] [[INSTR_5]])
-// CHECK: ROOT [[INSTR_7:%[^ ]+]] = (f32[], u32[]) tuple(f32[] [[INSTR_3]], u32[] [[INSTR_6]].clone.1)
-// CHECK: }
-// CHECK: (param_0: f32[20,2,100]) -> (f32[20,2], u32[20,2]) {
-// CHECK: [[INSTR_0:%[^ ]+]] = f32[20,2,100]{2,1,0} parameter(0)
-// CHECK: [[INSTR_1:%[^ ]+]] = u32[20,2,100]{2,1,0} iota(), iota_dimension=0
-// CHECK: [[INSTR_2:%[^ ]+]] = f32[] constant(0)
-// CHECK: [[INSTR_3:%[^ ]+]] = u32[] constant(0)
-// CHECK: ROOT [[INSTR_4:%[^ ]+]] = (f32[20,2]{1,0}, u32[20,2]{1,0}) reduce(f32[20,2,100]{2,1,0} [[INSTR_0]], u32[20,2,100]{2,1,0} [[INSTR_1]], f32[] [[INSTR_2]], u32[] [[INSTR_3]]), dimensions={2}, to_apply=[[INSTR_5:%[^ ]+]]
-// CHECK: }
-// CHECK: ENTRY %main (input: f32[20,2,100]) -> (f32[2], u32[2]) {
-// CHECK: [[INSTR_0:%[^ ]+]] = f32[20,2,100]{2,1,0} parameter(0)
-// CHECK: [[INSTR_1:%[^ ]+]] = (f32[20,2]{1,0}, u32[20,2]{1,0}) fusion(f32[20,2,100]{2,1,0} [[INSTR_0]]), kind=kInput, calls=[[INSTR_2:%[^ ]+]]
-// CHECK: [[INSTR_3:%[^ ]+]] = f32[20,2]{1,0} get-tuple-element((f32[20,2]{1,0}, u32[20,2]{1,0}) [[INSTR_1]]), index=0
-// CHECK: [[INSTR_4:%[^ ]+]] = u32[20,2]{1,0} get-tuple-element((f32[20,2]{1,0}, u32[20,2]{1,0}) [[INSTR_1]]), index=1
-// CHECK: [[INSTR_5:%[^ ]+]] = f32[] constant(0)
-// CHECK: [[INSTR_6:%[^ ]+]] = u32[] constant(0)
-// CHECK: ROOT [[INSTR_7:%[^ ]+]] = (f32[2]{0}, u32[2]{0}) reduce(f32[20,2]{1,0} [[INSTR_3]], u32[20,2]{1,0} [[INSTR_4]], f32[] [[INSTR_5]], u32[] [[INSTR_6]]), dimensions={0}, to_apply=[[INSTR_8:%[^ ]+]]
-// CHECK: }
+ CheckTreeRewriter(hlo,
+ R"(
+// CHECK: [[reduce_0:%[^ ]+]] = (f32[20,2]{1,0}, u32[20,2]{1,0}) reduce([[input_1:%[^ ]+]], [[idxs_2:%[^ ]+]], [[zero_3:%[^ ]+]], [[zero_idx_4:%[^ ]+]]), dimensions={2}, to_apply=[[argmax_5:%[^ ]+]]
+// CHECK: [[get_tuple_element_6:%[^ ]+]] = f32[20,2]{1,0} get-tuple-element([[reduce_0]]), index=0
+// CHECK: [[get_tuple_element_1_7:%[^ ]+]] = u32[20,2]{1,0} get-tuple-element([[reduce_0]]), index=1
+// CHECK: ROOT [[out_1_8:%[^ ]+]] = (f32[2]{0}, u32[2]{0}) reduce([[get_tuple_element_6]], [[get_tuple_element_1_7]], [[zero_3]], [[zero_idx_4]]), dimensions={0}, to_apply=[[argmax_5]]
)");
- EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
}
} // namespace
-} // namespace gpu
} // namespace xla