[XLA] [NFC] Make reduction rewriter tests more resilient: only run the selected pass, do not run generated code

We rely on integration tests for numerical correctness and the overall flow.
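
The pattern, repeated across all four tests, is a plain HloTestBase fixture whose helper calls RunAndFilecheckHloRewrite with only the pass under test, so no GPU codegen or execution is involved. A minimal sketch of that pattern, modeled on the dimension-grouper test (the body of the add computation, its parameter names, and the exact CHECK patterns here are illustrative only; the fixture, helper, and pass invocation mirror the diff below):

#include <optional>

#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/service/gpu/reduction_dimension_grouper.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/core/platform/test.h"

namespace xla {
namespace {

class ReductionDimensionGrouperTest : public HloTestBase {
 public:
  // Run only ReductionDimensionGrouper on the given HLO and FileCheck the
  // rewritten module text; nothing is compiled or executed.
  void CheckDimensionGrouper(absl::string_view hlo,
                             std::optional<absl::string_view> expected) {
    RunAndFilecheckHloRewrite(hlo, gpu::ReductionDimensionGrouper{}, expected);
  }
};

TEST_F(ReductionDimensionGrouperTest, ReductionWithGrouping) {
  // Illustrative module: the add computation body is a plausible stand-in.
  const char* hlo = R"(
HloModule ReductionWithGrouping

add {
  accum = f32[] parameter(0)
  op = f32[] parameter(1)
  ROOT out = f32[] add(accum, op)
}

ENTRY main {
  input = f32[100,10,32,3]{3,2,1,0} parameter(0)
  zero = f32[] constant(0)
  ROOT out = f32[100,10]{0,1} reduce(input, zero), dimensions={2,3}, to_apply=add
}
)";

  // Simplified expectations: the adjacent reduced dimensions {2,3} get grouped
  // into one via a bitcast, and the reduce now reduces a single dimension.
  CheckDimensionGrouper(hlo, R"(
// CHECK: f32[100,10,96]{2,1,0} bitcast(
// CHECK: f32[100,10]{0,1} reduce(
// CHECK-SAME: dimensions={2}
)");
}

}  // namespace
}  // namespace xla

Dropping GpuCodegenTest also removes the need for the gpu_plugin dependency, the CUDA-only test tags, and the per-test xla_disable_hlo_passes lists, as the diff shows.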

PiperOrigin-RevId: 454848300
diff --git a/tensorflow/compiler/xla/service/gpu/tests/BUILD b/tensorflow/compiler/xla/service/gpu/tests/BUILD
index da1ce6b..59b2075 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/tests/BUILD
@@ -273,17 +273,13 @@
     srcs = [
         "reduction_degenerate_dim_remover_test.cc",
     ],
-    tags = tf_cuda_tests_tags(),
     deps = [
-        ":gpu_codegen_test",
         "//tensorflow/compiler/xla:debug_options_flags",
         "//tensorflow/compiler/xla:statusor",
-        "//tensorflow/compiler/xla/service:gpu_plugin",
         "//tensorflow/compiler/xla/service:hlo",
         "//tensorflow/compiler/xla/service:hlo_module_config",
         "//tensorflow/compiler/xla/service:hlo_parser",
-        "//tensorflow/compiler/xla/service/gpu:gemm_rewriter",
-        "//tensorflow/compiler/xla/service/gpu:gpu_executable",
+        "//tensorflow/compiler/xla/service/gpu:reduction_degenerate_dim_remover",
         "//tensorflow/compiler/xla/tests:filecheck",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
         "//tensorflow/compiler/xla/tests:llvm_irgen_test_base",
@@ -300,24 +296,19 @@
     srcs = [
         "reduction_layout_normalizer_test.cc",
     ],
-    tags = tf_cuda_tests_tags(),
     deps = [
-        ":gpu_codegen_test",
         "//tensorflow/compiler/xla:debug_options_flags",
         "//tensorflow/compiler/xla:statusor",
-        "//tensorflow/compiler/xla/service:gpu_plugin",
         "//tensorflow/compiler/xla/service:hlo",
         "//tensorflow/compiler/xla/service:hlo_module_config",
         "//tensorflow/compiler/xla/service:hlo_parser",
-        "//tensorflow/compiler/xla/service/gpu:gemm_rewriter",
-        "//tensorflow/compiler/xla/service/gpu:gpu_executable",
+        "//tensorflow/compiler/xla/service/gpu:reduction_layout_normalizer",
         "//tensorflow/compiler/xla/tests:filecheck",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
         "//tensorflow/compiler/xla/tests:llvm_irgen_test_base",
         "//tensorflow/core:lib",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
-        "//tensorflow/stream_executor/lib",
         "@com_google_absl//absl/memory",
     ],
 )
@@ -327,17 +318,13 @@
     srcs = [
         "tree_reduction_rewriter_test.cc",
     ],
-    tags = tf_cuda_tests_tags(),
     deps = [
-        ":gpu_codegen_test",
         "//tensorflow/compiler/xla:debug_options_flags",
         "//tensorflow/compiler/xla:statusor",
-        "//tensorflow/compiler/xla/service:gpu_plugin",
         "//tensorflow/compiler/xla/service:hlo",
         "//tensorflow/compiler/xla/service:hlo_module_config",
         "//tensorflow/compiler/xla/service:hlo_parser",
-        "//tensorflow/compiler/xla/service/gpu:gemm_rewriter",
-        "//tensorflow/compiler/xla/service/gpu:gpu_executable",
+        "//tensorflow/compiler/xla/service/gpu:tree_reduction_rewriter",
         "//tensorflow/compiler/xla/tests:filecheck",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
         "//tensorflow/compiler/xla/tests:llvm_irgen_test_base",
@@ -346,6 +333,7 @@
         "//tensorflow/core:test_main",
         "//tensorflow/stream_executor/lib",
         "@com_google_absl//absl/memory",
+        "@com_google_absl//absl/strings",
     ],
 )
 
@@ -408,17 +396,13 @@
     srcs = [
         "reduction_dimension_grouper_test.cc",
     ],
-    tags = tf_cuda_tests_tags(),
     deps = [
-        ":gpu_codegen_test",
         "//tensorflow/compiler/xla:debug_options_flags",
         "//tensorflow/compiler/xla:statusor",
-        "//tensorflow/compiler/xla/service:gpu_plugin",
         "//tensorflow/compiler/xla/service:hlo",
         "//tensorflow/compiler/xla/service:hlo_module_config",
         "//tensorflow/compiler/xla/service:hlo_parser",
-        "//tensorflow/compiler/xla/service/gpu:gemm_rewriter",
-        "//tensorflow/compiler/xla/service/gpu:gpu_executable",
+        "//tensorflow/compiler/xla/service/gpu:reduction_dimension_grouper",
         "//tensorflow/compiler/xla/tests:filecheck",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
         "//tensorflow/compiler/xla/tests:llvm_irgen_test_base",
diff --git a/tensorflow/compiler/xla/service/gpu/tests/reduction_degenerate_dim_remover_test.cc b/tensorflow/compiler/xla/service/gpu/tests/reduction_degenerate_dim_remover_test.cc
index ab9890f..e1ecf40 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/reduction_degenerate_dim_remover_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/tests/reduction_degenerate_dim_remover_test.cc
@@ -13,10 +13,11 @@
 limitations under the License.
 ==============================================================================*/
 
+#include "tensorflow/compiler/xla/service/gpu/reduction_degenerate_dim_remover.h"
+
+#include <optional>
 #include <utility>
 
-#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
-#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_module_config.h"
 #include "tensorflow/compiler/xla/service/hlo_parser.h"
@@ -28,23 +29,20 @@
 #include "tensorflow/stream_executor/lib/statusor.h"
 
 namespace xla {
-namespace gpu {
 
 namespace {
 
-class ReductionDegenerateDimRemoverTest : public GpuCodegenTest {
-  DebugOptions GetDebugOptionsForTest() override {
-    DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest();
-    debug_options.add_xla_disable_hlo_passes("reduction-layout-normalizer");
-    debug_options.add_xla_disable_hlo_passes("reduction-dimension-grouper");
-    debug_options.add_xla_disable_hlo_passes("reduction-splitter");
-    debug_options.add_xla_disable_hlo_passes("gpu-tree-reduction-rewriter");
-    return debug_options;
+class ReductionDegenerateDimRemoverTest : public HloTestBase {
+ public:
+  void CheckDegenerateDimRemover(absl::string_view hlo,
+                                 std::optional<absl::string_view> expected) {
+    RunAndFilecheckHloRewrite(hlo, gpu::ReductionDegenerateDimRemover{},
+                              expected);
   }
 };
 
 TEST_F(ReductionDegenerateDimRemoverTest, ReductionWithDegenerateDimensions) {
-  const char* hlo_text = R"(
+  const char* hlo = R"(
 HloModule ReduceWithDegenerateDimensions
 
 add {
@@ -62,16 +60,16 @@
 
 )";
 
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
-  MatchOptimizedHloWithShapes(hlo_text,
-                              R"(
-// CHECK: f32[] reduce(f32[3,4,5]{2,1,0} {{.+}}, f32[] {{.+}}), dimensions={0,1,2}, to_apply=%add
-      )");
+  CheckDegenerateDimRemover(hlo, R"(
+// CHECK: [[bitcast_0:%[^ ]+]] = f32[3,4,5]{2,1,0} bitcast([[input_1:%[^ ]+]])
+// CHECK: [[reduce_2:%[^ ]+]] = f32[] reduce([[bitcast_0]], [[zero_3:%[^ ]+]]), dimensions={0,1,2}, to_apply=[[add_4:%[^ ]+]]
+// CHECK: ROOT [[bitcast_1_5:%[^ ]+]] = f32[1,1,1,1]{3,2,1,0} bitcast([[reduce_2]])
+  )");
 }
 
 TEST_F(ReductionDegenerateDimRemoverTest,
        ReductionWithDegenerateDimensionsVariadic) {
-  const char* hlo_text = R"(
+  const char* hlo = R"(
 HloModule ReduceWithDegenerateDimensions
 
 argmax {
@@ -102,15 +100,20 @@
 
 )";
 
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
-  MatchOptimizedHloWithShapes(hlo_text,
-                              R"(
-// CHECK: (f32[], u32[]) reduce(f32[3,4,5]{2,1,0} %bitcast.5, u32[3,4,5]{2,1,0} %bitcast.4, f32[] %zero_1, u32[] %zero_idx_1), dimensions={0,1,2}
+  CheckDegenerateDimRemover(hlo, R"(
+// CHECK:  [[bitcast_0:%[^ ]+]] = f32[3,4,5]{2,1,0} bitcast([[input_1:%[^ ]+]])
+// CHECK:  [[bitcast_1_2:%[^ ]+]] = u32[3,4,5]{2,1,0} bitcast([[idxs_3:%[^ ]+]])
+// CHECK:  [[reduce_4:%[^ ]+]] = (f32[], u32[]) reduce([[bitcast_0]], [[bitcast_1_2]], [[zero_5:%[^ ]+]], [[zero_idx_6:%[^ ]+]]), dimensions={0,1,2}, to_apply=[[argmax_7:%[^ ]+]]
+// CHECK-NEXT:  [[get_tuple_element_8:%[^ ]+]] = f32[] get-tuple-element([[reduce_4]]), index=0
+// CHECK-NEXT:  [[bitcast_2_9:%[^ ]+]] = f32[1,1,1,1]{3,2,1,0} bitcast([[get_tuple_element_8]])
+// CHECK-NEXT:  [[get_tuple_element_1_10:%[^ ]+]] = u32[] get-tuple-element([[reduce_4]]), index=1
+// CHECK-NEXT:  [[bitcast_3_11:%[^ ]+]] = u32[1,1,1,1]{3,2,1,0} bitcast([[get_tuple_element_1_10]])
+// CHECK-NEXT:  ROOT [[tuple_12:%[^ ]+]] = (f32[1,1,1,1]{3,2,1,0}, u32[1,1,1,1]{3,2,1,0}) tuple([[bitcast_2_9]], [[bitcast_3_11]])
 )");
 }
 
 TEST_F(ReductionDegenerateDimRemoverTest, DegenerateWithEmptyDimension) {
-  const char* hlo_text = R"(
+  const char* hlo = R"(
 HloModule ReduceWithDegenerateDimensions
 
 add {
@@ -125,22 +128,13 @@
 
   ROOT out = f32[3,4,5,1] reduce(input, zero), dimensions={0,2,4}, to_apply=add
 }
-
 )";
 
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
-  // Copy instruction is added after bitcast because of copy-insertion pass,
-  // so we check the entire hlo module to verify there is no reduce instruction
-  // in this case.
-  MatchOptimizedHloWithShapes(hlo_text,
-                              R"(
-// CHECK: ENTRY %main (input: f32[1,3,1,4,1,5,1]) -> f32[3,4,5,1] {
-// CHECK:   %input = f32[1,3,1,4,1,5,1]{6,5,4,3,2,1,0} parameter(0)
-// CHECK:   %bitcast{{.+}} = f32[3,4,5,1]{3,2,1,0} bitcast(f32[1,3,1,4,1,5,1]{6,5,4,3,2,1,0} %input)
-// CHECK:   ROOT %copy{{.+}} = f32[3,4,5,1]{3,2,1,0} copy(f32[3,4,5,1]{3,2,1,0} %bitcast{{.+}})
+  CheckDegenerateDimRemover(hlo,
+                            R"(
+// CHECK: ROOT [[bitcast_0:%[^ ]+]] = f32[3,4,5,1]{3,2,1,0} bitcast([[input_1:%[^ ]+]])
       )");
 }
 
 }  // namespace
-}  // namespace gpu
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/tests/reduction_dimension_grouper_test.cc b/tensorflow/compiler/xla/service/gpu/tests/reduction_dimension_grouper_test.cc
index 4e45753..74818d1 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/reduction_dimension_grouper_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/tests/reduction_dimension_grouper_test.cc
@@ -13,10 +13,11 @@
 limitations under the License.
 ==============================================================================*/
 
+#include "tensorflow/compiler/xla/service/gpu/reduction_dimension_grouper.h"
+
+#include <optional>
 #include <utility>
 
-#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
-#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_parser.h"
 #include "tensorflow/compiler/xla/tests/filecheck.h"
@@ -25,21 +26,19 @@
 #include "tensorflow/core/platform/test.h"
 
 namespace xla {
-namespace gpu {
 
 namespace {
 
-class ReductionDimensionGrouperTest : public GpuCodegenTest {
-  DebugOptions GetDebugOptionsForTest() override {
-    DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest();
-    debug_options.add_xla_disable_hlo_passes("reduction-layout-normalizer");
-    debug_options.add_xla_disable_hlo_passes("layout-assignment");
-    return debug_options;
+class ReductionDimensionGrouperTest : public HloTestBase {
+ public:
+  void CheckDimensionGrouper(absl::string_view hlo,
+                             std::optional<absl::string_view> expected) {
+    RunAndFilecheckHloRewrite(hlo, gpu::ReductionDimensionGrouper{}, expected);
   }
 };
 
 TEST_F(ReductionDimensionGrouperTest, ReductionWithGrouping) {
-  const char* hlo_text = R"(
+  const char* hlo = R"(
 HloModule ReductionWithGrouping
 
 add {
@@ -54,19 +53,18 @@
 
   ROOT out = f32[100,10]{0,1} reduce(input, zero), dimensions={2,3}, to_apply=add
 }
-
-
 )";
 
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
-  MatchOptimizedHloWithShapes(hlo_text,
-                              R"(
-// CHECK: f32[100,10]{0,1} reduce(f32[100,10,96]{2,1,0} {{.+}}, f32[] {{.+}}), dimensions={2}, to_apply=%add
+  CheckDimensionGrouper(hlo,
+                        R"(
+// CHECK:  [[input_0:%[^ ]+]] = f32[100,10,32,3]{3,2,1,0} parameter(0)
+// CHECK:  [[bitcast_1:%[^ ]+]] = f32[100,10,96]{2,1,0} bitcast([[input_0]])
+// CHECK:  ROOT [[out_1_2:%[^ ]+]] = f32[100,10]{0,1} reduce([[bitcast_1]], [[zero_3:%[^ ]+]]), dimensions={2}, to_apply=[[add_4:%[^ ]+]]
       )");
 }
 
 TEST_F(ReductionDimensionGrouperTest, ReductionWithGroupingVariadic) {
-  const char* hlo_text = R"(
+  const char* hlo = R"(
 HloModule ReductionWithGrouping
 
 argmax {
@@ -94,16 +92,16 @@
 
   ROOT out = (f32[100,10]{1,0}, u32[100,10]{1,0}) reduce(input, idxs, zero, zero_idx), dimensions={2,3}, to_apply=argmax
 }
-
-
 )";
 
-  MatchOptimizedHloWithShapes(hlo_text,
-                              R"(
-// CHECK: (f32[100,10]{1,0}, u32[100,10]{1,0}) reduce(f32[100,10,96]{2,1,0} %bitcast.3, u32[100,10,96]{2,1,0} %bitcast.2, f32[] %zero_1, u32[] %zero_idx_1), dimensions={2}
+  CheckDimensionGrouper(hlo, R"(
+// CHECK:  [[input_0:%[^ ]+]] = f32[100,10,32,3]{3,2,1,0} parameter(0)
+// CHECK:  [[bitcast_1:%[^ ]+]] = f32[100,10,96]{2,1,0} bitcast([[input_0]])
+// CHECK:  [[idxs_2:%[^ ]+]] = u32[100,10,32,3]{3,2,1,0} parameter(1)
+// CHECK:  [[bitcast_1_3:%[^ ]+]] = u32[100,10,96]{2,1,0} bitcast([[idxs_2]])
+// CHECK:  ROOT [[out_1_4:%[^ ]+]] = (f32[100,10]{1,0}, u32[100,10]{1,0}) reduce([[bitcast_1]], [[bitcast_1_3]], [[zero_5:%[^ ]+]], [[zero_idx_6:%[^ ]+]]), dimensions={2}, to_apply=[[argmax_7:%[^ ]+]]
 )");
 }
 
 }  // namespace
-}  // namespace gpu
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/tests/reduction_layout_normalizer_test.cc b/tensorflow/compiler/xla/service/gpu/tests/reduction_layout_normalizer_test.cc
index 28a700e..20a58f9 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/reduction_layout_normalizer_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/tests/reduction_layout_normalizer_test.cc
@@ -13,10 +13,11 @@
 limitations under the License.
 ==============================================================================*/
 
+#include "tensorflow/compiler/xla/service/gpu/reduction_layout_normalizer.h"
+
+#include <optional>
 #include <utility>
 
-#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
-#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_parser.h"
 #include "tensorflow/compiler/xla/tests/filecheck.h"
@@ -25,25 +26,19 @@
 #include "tensorflow/core/platform/test.h"
 
 namespace xla {
-namespace gpu {
 
 namespace {
 
-// TODO(b/210165681): The tests in this file are fragile to HLO op names.
-
-class ReductionLayoutNormalizerTest : public GpuCodegenTest {
-  DebugOptions GetDebugOptionsForTest() override {
-    DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest();
-    debug_options.add_xla_disable_hlo_passes("reduction-dimension-grouper");
-    debug_options.add_xla_disable_hlo_passes("reduction-splitter");
-    debug_options.add_xla_disable_hlo_passes("layout-assignment");
-    debug_options.add_xla_disable_hlo_passes("gpu-tree-reduction-rewriter");
-    return debug_options;
+class ReductionLayoutNormalizerTest : public HloTestBase {
+ public:
+  void CheckReductionLayoutNormalizer(
+      absl::string_view hlo, std::optional<absl::string_view> expected) {
+    RunAndFilecheckHloRewrite(hlo, gpu::ReductionLayoutNormalizer{}, expected);
   }
 };
 
 TEST_F(ReductionLayoutNormalizerTest, LayoutCanonicalizerTest) {
-  const char* hlo_text = R"(
+  const char* hlo = R"(
 HloModule ReduceWithLayoutChange
 
 add {
@@ -61,15 +56,16 @@
 
 )";
 
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
-  MatchOptimizedHloWithShapes(hlo_text,
-                              R"(
-// CHECK: f32[4,12,12,16,5]{2,1,3,4,0} reduce(f32[5,3,3,4,12,12,16,5]{7,6,5,4,3,2,1,0} {{.+}}, f32[] {{.+}}), dimensions={0,1,2}, to_apply=%add
+  CheckReductionLayoutNormalizer(hlo,
+                                 R"(
+// CHECK:  [[bitcast_0:%[^ ]+]] = f32[5,3,3,4,12,12,16,5]{7,6,5,4,3,2,1,0} bitcast([[arg0_1:%[^ ]+]])
+// CHECK:  [[reduce_2:%[^ ]+]] = f32[4,12,12,16,5]{2,1,3,4,0} reduce([[bitcast_0]], [[constant0_3:%[^ ]+]]), dimensions={0,1,2}, to_apply=[[add_4:%[^ ]+]]
+// CHECK:  ROOT [[bitcast_1_5:%[^ ]+]] = f32[4,5,16,12,12]{4,3,2,1,0} bitcast([[reduce_2]])
       )");
 }
 
 TEST_F(ReductionLayoutNormalizerTest, LayoutCanonicalizerTestVariadic) {
-  const char* hlo_text = R"(
+  const char* hlo = R"(
 HloModule ReduceWithLayoutChangeVariadic
 
 
@@ -104,16 +100,24 @@
 
 )";
 
-  MatchOptimizedHloWithShapes(hlo_text,
-                              R"(
-// CHECK: %reduce.1 = (f32[4,12,12,16,5]{2,1,3,4,0}, u32[4,12,12,16,5]{2,1,3,4,0}) reduce(f32[5,3,3,4,12,12,16,5]{7,6,5,4,3,2,1,0} %bitcast.5, u32[5,3,3,4,12,12,16,5]{7,6,5,4,3,2,1,0} %bitcast.4, f32[] %constant0_1, u32[] %constant1_1), dimensions={0,1,2}, to_apply=%argmax
-//
+  CheckReductionLayoutNormalizer(hlo,
+                                 R"(
+// CHECK:  [[arg0_0:%[^ ]+]] = f32[4,5,5,16,12,12,3,3]{2,3,5,4,0,7,6,1} parameter(0)
+// CHECK:  [[bitcast_1:%[^ ]+]] = f32[5,3,3,4,12,12,16,5]{7,6,5,4,3,2,1,0} bitcast([[arg0_0]])
+// CHECK:  [[idxs_2:%[^ ]+]] = u32[4,5,5,16,12,12,3,3]{2,3,5,4,0,7,6,1} parameter(1)
+// CHECK:  [[bitcast_1_3:%[^ ]+]] = u32[5,3,3,4,12,12,16,5]{7,6,5,4,3,2,1,0} bitcast([[idxs_2]])
+// CHECK:  [[reduce_4:%[^ ]+]] = (f32[4,12,12,16,5]{2,1,3,4,0}, u32[4,12,12,16,5]{2,1,3,4,0}) reduce([[bitcast_1]], [[bitcast_1_3]], [[constant0_5:%[^ ]+]], [[constant1_6:%[^ ]+]]), dimensions={0,1,2}, to_apply=[[argmax_7:%[^ ]+]]
+// CHECK:  [[get_tuple_element_8:%[^ ]+]] = f32[4,12,12,16,5]{2,1,3,4,0} get-tuple-element([[reduce_4]]), index=0
+// CHECK:  [[bitcast_2_9:%[^ ]+]] = f32[4,5,16,12,12]{4,3,2,1,0} bitcast([[get_tuple_element_8]])
+// CHECK:  [[get_tuple_element_1_10:%[^ ]+]] = u32[4,12,12,16,5]{2,1,3,4,0} get-tuple-element([[reduce_4]]), index=1
+// CHECK:  [[bitcast_3_11:%[^ ]+]] = u32[4,5,16,12,12]{4,3,2,1,0} bitcast([[get_tuple_element_1_10]])
+// CHECK:  ROOT [[tuple_12:%[^ ]+]] = (f32[4,5,16,12,12]{4,3,2,1,0}, u32[4,5,16,12,12]{4,3,2,1,0}) tuple([[bitcast_2_9]], [[bitcast_3_11]])
       )");
 }
 
 TEST_F(ReductionLayoutNormalizerTest,
        LayoutCanonicalizerTestVariadicDifferentLayouts) {
-  const char* hlo_text = R"(
+  const char* hlo = R"(
 HloModule ReduceWithLayoutChangeVariadicDifferent
 
 argmax {
@@ -147,27 +151,17 @@
 
 )";
 
-  MatchOptimizedHloWithShapes(hlo_text,
-                              R"(
-// CHECK: %fused_computation (param_0.1: u32[2,3,4,7]) -> u32[7,2,3,4] {
-// CHECK:  %param_0.1 = u32[2,3,4,7]{3,2,1,0} parameter(0)
-// CHECK:  %copy.1 = u32[2,3,4,7]{2,1,0,3} copy(u32[2,3,4,7]{3,2,1,0} %param_0.1)
-// CHECK:  ROOT %bitcast.2 = u32[7,2,3,4]{3,2,1,0} bitcast(u32[2,3,4,7]{2,1,0,3} %copy.1)
-// CHECK: }
-//
-// CHECK: ENTRY %main (arg0: f32[2,3,4,7], idxs: u32[2,3,4,7]) -> (f32[2,3,4], u32[2,3,4]) {
-// CHECK:  %arg0 = f32[2,3,4,7]{2,1,0,3} parameter(0)
-// CHECK:  %bitcast = f32[7,2,3,4]{3,2,1,0} bitcast(f32[2,3,4,7]{2,1,0,3} %arg0)
-// CHECK:  %idxs = u32[2,3,4,7]{3,2,1,0} parameter(1)
-// CHECK:  %fusion = u32[7,2,3,4]{3,2,1,0} fusion(u32[2,3,4,7]{3,2,1,0} %idxs), kind=kLoop, calls=%fused_computation
-// CHECK:  %constant0 = f32[] constant(0)
-// CHECK:  %constant1 = u32[] constant(0)
-// CHECK:  ROOT %reduce0 = (f32[2,3,4]{2,1,0}, u32[2,3,4]{2,1,0}) reduce(f32[7,2,3,4]{3,2,1,0} %bitcast, u32[7,2,3,4]{3,2,1,0} %fusion, f32[] %constant0, u32[] %constant1), dimensions={0}, to_apply=%argmax
-// CHECK: }
+  CheckReductionLayoutNormalizer(hlo,
+                                 R"(
+// CHECK:  [[arg0_0:%[^ ]+]] = f32[2,3,4,7]{2,1,0,3} parameter(0)
+// CHECK:  [[bitcast_1:%[^ ]+]] = f32[7,2,3,4]{3,2,1,0} bitcast([[arg0_0]])
+// CHECK:  [[idxs_2:%[^ ]+]] = u32[2,3,4,7]{3,2,1,0} parameter(1)
+// CHECK:  [[copy_3:%[^ ]+]] = u32[2,3,4,7]{2,1,0,3} copy([[idxs_2]])
+// CHECK:  [[bitcast_1_4:%[^ ]+]] = u32[7,2,3,4]{3,2,1,0} bitcast([[copy_3]])
+// CHECK:  ROOT [[reduce0_5:%[^ ]+]] = (f32[2,3,4]{2,1,0}, u32[2,3,4]{2,1,0}) reduce([[bitcast_1]], [[bitcast_1_4]], [[constant0_6:%[^ ]+]], [[constant1_7:%[^ ]+]]), dimensions={0}, to_apply=[[argmax_8:%[^ ]+]]
       )");
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+  EXPECT_TRUE(RunAndCompare(hlo, ErrorSpec{1e-5, 1e-5}));
 }
 
 }  // namespace
-}  // namespace gpu
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/tests/tree_reduction_rewriter_test.cc b/tensorflow/compiler/xla/service/gpu/tests/tree_reduction_rewriter_test.cc
index 99c942c..c36d842 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/tree_reduction_rewriter_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/tests/tree_reduction_rewriter_test.cc
@@ -13,10 +13,12 @@
 limitations under the License.
 ==============================================================================*/
 
+#include "tensorflow/compiler/xla/service/gpu/tree_reduction_rewriter.h"
+
+#include <optional>
 #include <utility>
 
-#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
-#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h"
+#include "absl/strings/string_view.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_module_config.h"
 #include "tensorflow/compiler/xla/service/hlo_parser.h"
@@ -28,16 +30,21 @@
 #include "tensorflow/stream_executor/lib/statusor.h"
 
 namespace xla {
-namespace gpu {
 
 namespace {
 
-// TODO(b/210165681): The tests in this file are fragile to HLO op names.
-
-class TreeReductionRewriterTest : public GpuCodegenTest {};
+class TreeReductionRewriterTest : public HloTestBase {
+ public:
+  void CheckTreeRewriter(absl::string_view hlo,
+                         std::optional<absl::string_view> expected) {
+    RunAndFilecheckHloRewrite(
+        hlo, gpu::GpuTreeReductionRewriter{se::CudaComputeCapability{8, 1}},
+        expected);
+  }
+};
 
 TEST_F(TreeReductionRewriterTest, RowReductionSingleDimensionNoBatched) {
-  const char* hlo_text = R"(
+  const char* hlo = R"(
 HloModule ReduceWithPadding
 
 add {
@@ -53,27 +60,17 @@
 }
 )";
 
-  MatchOptimizedHloWithShapes(hlo_text,
-                              R"(
-// CHECK: (param_0.2: f32[50000]) -> f32[224] {
-// CHECK:   [[INSTR_0:%[^ ]+]] = f32[50000]{0} parameter(0)
-// CHECK:   [[INSTR_1:%[^ ]+]] = f32[] constant(0)
-// CHECK:   [[INSTR_2:%[^ ]+]] = f32[50176]{0} pad(f32[50000]{0} [[INSTR_0]], f32[] [[INSTR_1]]), padding=0_176
-// CHECK:   [[INSTR_3:%[^ ]+]] = f32[224,224]{1,0} bitcast(f32[50176]{0} [[INSTR_2]])
-// CHECK:   ROOT [[INSTR_4:%[^ ]+]] = f32[224]{0} reduce(f32[224,224]{1,0} [[INSTR_3]], f32[] [[INSTR_1]]), dimensions={1}, to_apply=[[INSTR_5:%[^ ]+]]
-// CHECK: }
-// CHECK: ENTRY %main (input: f32[50000]) -> f32[] {
-// CHECK:   [[INSTR_0:%[^ ]+]] = f32[50000]{0} parameter(0)
-// CHECK:   [[INSTR_1:%[^ ]+]] = f32[224]{0} fusion(f32[50000]{0} [[INSTR_0]]), kind=kInput, calls=[[INSTR_2:%[^ ]+]]
-// CHECK:   [[INSTR_3:%[^ ]+]] = f32[] constant(0)
-// CHECK:   ROOT [[INSTR_4:%[^ ]+]] = f32[] reduce(f32[224]{0} [[INSTR_1]], f32[] [[INSTR_3]]), dimensions={0}, to_apply=[[INSTR_5:%[^ ]+]]
-// CHECK: }
+  CheckTreeRewriter(hlo,
+                    R"(
+// CHECK: [[pad_0:%[^ ]+]] = f32[50176]{0} pad([[input_1:%[^ ]+]], [[zero_2:%[^ ]+]]), padding=0_176
+// CHECK: [[bitcast_3:%[^ ]+]] = f32[224,224]{1,0} bitcast([[pad_0]])
+// CHECK: [[reduce_4:%[^ ]+]] = f32[224]{0} reduce([[bitcast_3]], [[zero_2]]), dimensions={1}, to_apply=[[add_5:%[^ ]+]]
+// CHECK: ROOT [[out_1_6:%[^ ]+]] = f32[] reduce([[reduce_4]], [[zero_2]]), dimensions={0}, to_apply=[[add_5]]
       )");
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
 }
 
 TEST_F(TreeReductionRewriterTest, RowReductionWeirdOutputLayout) {
-  const char* hlo_text = R"(
+  const char* hlo = R"(
 HloModule ReduceWithPadding
 
 add {
@@ -90,16 +87,15 @@
 )";
 
   // Check that we preserve the layout.
-  MatchOptimizedHloWithShapes(hlo_text,
-                              R"(
+  CheckTreeRewriter(hlo,
+                    R"(
 // CHECK: f32[2,4]{0,1} reduce(
       )");
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
 }
 
 TEST_F(TreeReductionRewriterTest,
        RowReductionSingleDimensionNoBatchedDivisible) {
-  const char* hlo_text = R"(
+  const char* hlo = R"(
 HloModule ReduceWithPadding
 
 add {
@@ -115,26 +111,18 @@
 }
 )";
 
-  MatchOptimizedHloWithShapes(hlo_text,
-                              R"(
-// CHECK: (param_0.2: f32[49952]) -> f32[223] {
-// CHECK:   [[INSTR_0:%[^ ]+]] = f32[49952]{0} parameter(0)
-// CHECK:   [[INSTR_1:%[^ ]+]] = f32[223,224]{1,0} bitcast(f32[49952]{0} [[INSTR_0]])
-// CHECK:   [[INSTR_2:%[^ ]+]] = f32[] constant(0)
-// CHECK:   ROOT [[INSTR_3:%[^ ]+]] = f32[223]{0} reduce(f32[223,224]{1,0} [[INSTR_1]], f32[] [[INSTR_2]]), dimensions={1}, to_apply=[[INSTR_4:%[^ ]+]]
-// CHECK: }
-// CHECK: ENTRY %main (input: f32[49952]) -> f32[] {
-// CHECK:   [[INSTR_0:%[^ ]+]] = f32[49952]{0} parameter(0)
-// CHECK:   [[INSTR_1:%[^ ]+]] = f32[223]{0} fusion(f32[49952]{0} [[INSTR_0]]), kind=kInput, calls=[[INSTR_2:%[^ ]+]]
-// CHECK:   [[INSTR_3:%[^ ]+]] = f32[] constant(0)
-// CHECK:   ROOT [[INSTR_4:%[^ ]+]] = f32[] reduce(f32[223]{0} [[INSTR_1]], f32[] [[INSTR_3]]), dimensions={0}, to_apply=[[INSTR_5:%[^ ]+]]
-// CHECK: }
+  CheckTreeRewriter(hlo,
+                    R"(
+// CHECK: [[input_0:%[^ ]+]] = f32[49952]{0} parameter(0)
+// CHECK: [[bitcast_1:%[^ ]+]] = f32[223,224]{1,0} bitcast([[input_0]])
+// CHECK: [[zero_2:%[^ ]+]] = f32[] constant(0)
+// CHECK: [[reduce_3:%[^ ]+]] = f32[223]{0} reduce([[bitcast_1]], [[zero_2]]), dimensions={1}, to_apply=[[add_4:%[^ ]+]]
+// CHECK: ROOT [[out_1_5:%[^ ]+]] = f32[] reduce([[reduce_3]], [[zero_2]]), dimensions={0}, to_apply=[[add_4]]
       )");
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
 }
 
 TEST_F(TreeReductionRewriterTest, RowReductionNoBatched) {
-  const char* hlo_text = R"(
+  const char* hlo = R"(
 HloModule ReduceWithPadding
 
 add {
@@ -150,28 +138,18 @@
 }
 )";
 
-  MatchOptimizedHloWithShapes(hlo_text,
-                              R"(
-// CHECK: (param_0.2: f32[100,10,90000]) -> f32[100,10,300] {
-// CHECK:   [[INSTR_0:%[^ ]+]] = f32[100,10,90000]{2,1,0} parameter(0)
-// CHECK:   [[INSTR_1:%[^ ]+]] = f32[100,10,300,300]{3,2,1,0} bitcast(f32[100,10,90000]{2,1,0} [[INSTR_0]])
-// CHECK:   [[INSTR_2:%[^ ]+]] = f32[] constant(0)
-// CHECK:   ROOT [[INSTR_3:%[^ ]+]] = f32[100,10,300]{2,1,0} reduce(f32[100,10,300,300]{3,2,1,0} [[INSTR_1]], f32[] [[INSTR_2]]), dimensions={3}, to_apply=[[INSTR_4:%[^ ]+]]
-// CHECK: }
-// CHECK: ENTRY %main (input: f32[100,10,90000]) -> f32[100,10] {
-// CHECK:   [[INSTR_0:%[^ ]+]] = f32[100,10,90000]{2,1,0} parameter(0)
-// CHECK:   [[INSTR_1:%[^ ]+]] = f32[100,10,300]{2,1,0} fusion(f32[100,10,90000]{2,1,0} [[INSTR_0]]), kind=kInput, calls=[[INSTR_2:%[^ ]+]]
-// CHECK:   [[INSTR_3:%[^ ]+]] = f32[] constant(0)
-// CHECK:   ROOT [[INSTR_4:%[^ ]+]] = f32[100,10]{1,0} reduce(f32[100,10,300]{2,1,0} [[INSTR_1]], f32[] [[INSTR_3]]), dimensions={2}, to_apply=[[INSTR_5:%[^ ]+]]
-// CHECK: }
+  CheckTreeRewriter(hlo,
+                    R"(
+// CHECK: [[bitcast_0:%[^ ]+]] = f32[100,10,300,300]{3,2,1,0} bitcast([[input_1:%[^ ]+]])
+// CHECK: [[zero_2:%[^ ]+]] = f32[] constant(0)
+// CHECK: [[reduce_3:%[^ ]+]] = f32[100,10,300]{2,1,0} reduce([[bitcast_0]], [[zero_2]]), dimensions={3}, to_apply=[[add_4:%[^ ]+]]
+// CHECK: ROOT [[out_1_5:%[^ ]+]] = f32[100,10]{1,0} reduce([[reduce_3]], [[zero_2]]), dimensions={2}, to_apply=[[add_4]]
       )");
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
 }
 
 TEST_F(TreeReductionRewriterTest,
        RowReductionSingleDimensionNoBatchedLargeInput) {
-  const char* hlo_text = R"(
+  const char* hlo = R"(
 HloModule ReduceWithPadding
 
 add {
@@ -187,26 +165,18 @@
 }
 )";
 
-  MatchOptimizedHloWithShapes(hlo_text,
-                              R"(
-// CHECK: (param_0.2: f32[1000000]) -> f32[1000] {
-// CHECK:   [[INSTR_0:%[^ ]+]] = f32[1000000]{0} parameter(0)
-// CHECK:   [[INSTR_1:%[^ ]+]] = f32[1000,1000]{1,0} bitcast(f32[1000000]{0} [[INSTR_0]])
-// CHECK:   [[INSTR_2:%[^ ]+]] = f32[] constant(0)
-// CHECK:   ROOT [[INSTR_3:%[^ ]+]] = f32[1000]{0} reduce(f32[1000,1000]{1,0} [[INSTR_1]], f32[] [[INSTR_2]]), dimensions={1}, to_apply=[[INSTR_4:%[^ ]+]]
-// CHECK: }
-// CHECK: ENTRY %main (input: f32[1000000]) -> f32[] {
-// CHECK:   [[INSTR_0:%[^ ]+]] = f32[1000000]{0} parameter(0)
-// CHECK:   [[INSTR_1:%[^ ]+]] = f32[1000]{0} fusion(f32[1000000]{0} [[INSTR_0]]), kind=kInput, calls=[[INSTR_2:%[^ ]+]]
-// CHECK:   [[INSTR_3:%[^ ]+]] = f32[] constant(0)
-// CHECK:   ROOT [[INSTR_4:%[^ ]+]] = f32[] reduce(f32[1000]{0} [[INSTR_1]], f32[] [[INSTR_3]]), dimensions={0}, to_apply=[[INSTR_5:%[^ ]+]]
-// CHECK: }
+  CheckTreeRewriter(hlo,
+                    R"(
+// CHECK:  [[input_0:%[^ ]+]] = f32[1000000]{0} parameter(0)
+// CHECK:  [[bitcast_1:%[^ ]+]] = f32[1000,1000]{1,0} bitcast([[input_0]])
+// CHECK:  [[zero_2:%[^ ]+]] = f32[] constant(0)
+// CHECK:  [[reduce_3:%[^ ]+]] = f32[1000]{0} reduce([[bitcast_1]], [[zero_2]]), dimensions={1}, to_apply=[[add_4:%[^ ]+]]
+// CHECK:  ROOT [[out_1_5:%[^ ]+]] = f32[] reduce([[reduce_3]], [[zero_2]]), dimensions={0}, to_apply=[[add_4]]
       )");
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
 }
 
 TEST_F(TreeReductionRewriterTest, RowReductionBatchedDimensionFits) {
-  const char* hlo_text = R"(
+  const char* hlo = R"(
 HloModule ReduceWithPadding
 
 add {
@@ -222,27 +192,17 @@
 }
 )";
 
-  MatchOptimizedHloWithShapes(hlo_text,
-                              R"(
-// CHECK: (param_0.2: f32[8,100,90000]) -> f32[100,300] {
-// CHECK:   [[INSTR_0:%[^ ]+]] = f32[8,100,90000]{2,1,0} parameter(0)
-// CHECK:   [[INSTR_1:%[^ ]+]] = f32[8,100,300,300]{3,2,1,0} bitcast(f32[8,100,90000]{2,1,0} [[INSTR_0]])
-// CHECK:   [[INSTR_2:%[^ ]+]] = f32[] constant(0)
-// CHECK:   ROOT [[INSTR_3:%[^ ]+]] = f32[100,300]{1,0} reduce(f32[8,100,300,300]{3,2,1,0} [[INSTR_1]], f32[] [[INSTR_2]]), dimensions={3,0}, to_apply=[[INSTR_4:%[^ ]+]]
-// CHECK: }
-// CHECK: ENTRY %main (input: f32[8,100,90000]) -> f32[100] {
-// CHECK:   [[INSTR_0:%[^ ]+]] = f32[8,100,90000]{2,1,0} parameter(0)
-// CHECK:   [[INSTR_1:%[^ ]+]] = f32[100,300]{1,0} fusion(f32[8,100,90000]{2,1,0} [[INSTR_0]]), kind=kInput, calls=[[INSTR_2:%[^ ]+]]
-// CHECK:   [[INSTR_3:%[^ ]+]] = f32[] constant(0)
-// CHECK:   ROOT [[INSTR_4:%[^ ]+]] = f32[100]{0} reduce(f32[100,300]{1,0} [[INSTR_1]], f32[] [[INSTR_3]]), dimensions={1}, to_apply=[[INSTR_5:%[^ ]+]]
-// CHECK: }
+  CheckTreeRewriter(hlo,
+                    R"(
+// CHECK:  [[bitcast_0:%[^ ]+]] = f32[8,100,300,300]{3,2,1,0} bitcast([[input_1:%[^ ]+]])
+// CHECK:  [[zero_2:%[^ ]+]] = f32[] constant(0)
+// CHECK:  [[reduce_3:%[^ ]+]] = f32[100,300]{1,0} reduce([[bitcast_0]], [[zero_2]]), dimensions={3,0}, to_apply=[[add_4:%[^ ]+]]
+// CHECK:  ROOT [[out_1_5:%[^ ]+]] = f32[100]{0} reduce([[reduce_3]], [[zero_2]]), dimensions={1}, to_apply=[[add_4]]
       )");
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
 }
 
 TEST_F(TreeReductionRewriterTest, RowReductionBatchedDimensionDoesNotFit) {
-  const char* hlo_text = R"(
+  const char* hlo = R"(
 HloModule ReduceWithPadding
 
 add {
@@ -258,30 +218,15 @@
 }
 )";
 
-  MatchOptimizedHloWithShapes(hlo_text,
-                              R"(
-// CHECK: (param_0.2: f32[32,100,90000]) -> f32[32,100,300] {
-// CHECK:   [[INSTR_0:%[^ ]+]] = f32[32,100,90000]{2,1,0} parameter(0)
-// CHECK:   [[INSTR_1:%[^ ]+]] = f32[32,100,300,300]{3,2,1,0} bitcast(f32[32,100,90000]{2,1,0} [[INSTR_0]])
-// CHECK:   [[INSTR_2:%[^ ]+]] = f32[] constant(0)
-// CHECK:   ROOT [[INSTR_3:%[^ ]+]] = f32[32,100,300]{2,1,0} reduce(f32[32,100,300,300]{3,2,1,0} [[INSTR_1]], f32[] [[INSTR_2]]), dimensions={3}, to_apply=[[INSTR_4:%[^ ]+]]
-// CHECK: }
-// CHECK: ENTRY %main (input: f32[32,100,90000]) -> f32[100] {
-// CHECK:   [[INSTR_0:%[^ ]+]] = f32[32,100,90000]{2,1,0} parameter(0)
-// CHECK:   [[INSTR_1:%[^ ]+]] = f32[32,100,300]{2,1,0} fusion(f32[32,100,90000]{2,1,0} [[INSTR_0]]), kind=kInput, calls=[[INSTR_2:%[^ ]+]]
-// CHECK:   [[INSTR_3:%[^ ]+]] = f32[] constant(0)
-// CHECK:   [[INSTR_4:%[^ ]+]] = f32[32,100]{1,0} reduce(f32[32,100,300]{2,1,0} [[INSTR_1]], f32[] [[INSTR_3]]), dimensions={2}, to_apply=[[INSTR_5:%[^ ]+]]
-// CHECK:   ROOT [[INSTR_6:%[^ ]+]] = f32[100]{0} reduce(f32[32,100]{1,0} [[INSTR_4]], f32[] [[INSTR_3]]), dimensions={0}, to_apply=[[INSTR_5]]
-// CHECK: }
+  CheckTreeRewriter(hlo,
+                    R"(
+// CHECK: [[reduce_0:%[^ ]+]] = f32[32,100]{1,0} reduce([[input_1:%[^ ]+]], [[zero_2:%[^ ]+]]), dimensions={2}, to_apply=[[add_3:%[^ ]+]]
+// CHECK:  ROOT [[out_1_4:%[^ ]+]] = f32[100]{0} reduce([[reduce_0]], [[zero_2]]), dimensions={0}, to_apply=[[add_3]]
       )");
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
 }
 
 TEST_F(TreeReductionRewriterTest, ColumnReductionSimple) {
-  // TODO(cheshire): reduce duplication for HLO text, factor out the common
-  // part.
-  const char* hlo_text = R"(
+  const char* hlo = R"(
 HloModule ReduceWithPadding
 
 add {
@@ -297,27 +242,18 @@
 }
 )";
 
-  MatchOptimizedHloWithShapes(hlo_text,
-                              R"(
-// CHECK: (param_0.2: f32[10000,100]) -> f32[100,100] {
-// CHECK:   [[INSTR_0:%[^ ]+]] = f32[10000,100]{1,0} parameter(0)
-// CHECK:   [[INSTR_1:%[^ ]+]] = f32[100,100,100]{2,1,0} bitcast(f32[10000,100]{1,0}
-// %param_0.2)
-// CHECK:   [[INSTR_0:%[^ ]+]] = f32[] constant(0)
-// CHECK:   ROOT [[INSTR_1:%[^ ]+]] = f32[100,100]{1,0} reduce(f32[100,100,100]{2,1,0} [[INSTR_2:%[^ ]+]], f32[] [[INSTR_0]]), dimensions={0}, to_apply=[[INSTR_3:%[^ ]+]]
-// CHECK: }
-// CHECK: ENTRY %main (input: f32[10000,100]) -> f32[100] {
-// CHECK:   [[INSTR_0:%[^ ]+]] = f32[10000,100]{1,0} parameter(0)
-// CHECK:   [[INSTR_1:%[^ ]+]] = f32[100,100]{1,0} fusion(f32[10000,100]{1,0} [[INSTR_0]]), kind=kInput, calls=[[INSTR_2:%[^ ]+]]
-// CHECK:   [[INSTR_3:%[^ ]+]] = f32[] constant(0)
-// CHECK:   ROOT [[INSTR_4:%[^ ]+]] = f32[100]{0} reduce(f32[100,100]{1,0} [[INSTR_1]], f32[] [[INSTR_3]]), dimensions={0}, to_apply=[[INSTR_5:%[^ ]+]]
-// CHECK: }
+  CheckTreeRewriter(hlo,
+                    R"(
+
+// CHECK:  [[input_0:%[^ ]+]] = f32[10000,100]{1,0} parameter(0)
+// CHECK:  [[bitcast_1:%[^ ]+]] = f32[100,100,100]{2,1,0} bitcast([[input_0]])
+// CHECK:  [[reduce_2:%[^ ]+]] = f32[100,100]{1,0} reduce([[bitcast_1]], [[zero_3:%[^ ]+]]), dimensions={0}, to_apply=[[add_4:%[^ ]+]]
+// CHECK:  ROOT [[out_1_5:%[^ ]+]] = f32[100]{0} reduce([[reduce_2]], [[zero_3]]), dimensions={0}, to_apply=[[add_4]]
       )");
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
 }
 
 TEST_F(TreeReductionRewriterTest, ColumnReductionSimpleNoSquareDivisible) {
-  const char* hlo_text = R"(
+  const char* hlo = R"(
 HloModule ReduceWithPadding
 
 add {
@@ -333,27 +269,18 @@
 }
 )";
 
-  MatchOptimizedHloWithShapes(hlo_text,
-                              R"(
-// CHECK: (param_0.2: f32[10302,100]) -> f32[102,100] {
-// CHECK:   [[INSTR_0:%[^ ]+]] = f32[10302,100]{1,0} parameter(0)
-// CHECK:   [[INSTR_1:%[^ ]+]] = f32[101,102,100]{2,1,0} bitcast(f32[10302,100]{1,0}
-// %param_0.2)
-// CHECK:   [[INSTR_0:%[^ ]+]] = f32[] constant(0)
-// CHECK:   ROOT [[INSTR_1:%[^ ]+]] = f32[102,100]{1,0} reduce(f32[101,102,100]{2,1,0} [[INSTR_2:%[^ ]+]], f32[] [[INSTR_0]]), dimensions={0}, to_apply=[[INSTR_3:%[^ ]+]]
-// CHECK: }
-// CHECK: ENTRY %main (input: f32[10302,100]) -> f32[100] {
-// CHECK:   [[INSTR_0:%[^ ]+]] = f32[10302,100]{1,0} parameter(0)
-// CHECK:   [[INSTR_1:%[^ ]+]] = f32[102,100]{1,0} fusion(f32[10302,100]{1,0} [[INSTR_0]]), kind=kInput, calls=[[INSTR_2:%[^ ]+]]
-// CHECK:   [[INSTR_3:%[^ ]+]] = f32[] constant(0)
-// CHECK:   ROOT [[INSTR_4:%[^ ]+]] = f32[100]{0} reduce(f32[102,100]{1,0} [[INSTR_1]], f32[] [[INSTR_3]]), dimensions={0}, to_apply=[[INSTR_5:%[^ ]+]]
-// CHECK: }
+  CheckTreeRewriter(hlo,
+                    R"(
+// CHECK:  [[input_0:%[^ ]+]] = f32[10302,100]{1,0} parameter(0)
+// CHECK:  [[bitcast_1:%[^ ]+]] = f32[101,102,100]{2,1,0} bitcast([[input_0]])
+// CHECK:  [[zero_2:%[^ ]+]] = f32[] constant(0)
+// CHECK:  [[reduce_3:%[^ ]+]] = f32[102,100]{1,0} reduce([[bitcast_1]], [[zero_2]]), dimensions={0}, to_apply=[[add_4:%[^ ]+]]
+// CHECK:  ROOT [[out_1_5:%[^ ]+]] = f32[100]{0} reduce([[reduce_3]], [[zero_2]]), dimensions={0}, to_apply=[[add_4]]
       )");
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
 }
 
 TEST_F(TreeReductionRewriterTest, ColumnReductionOtherIndex) {
-  const char* hlo_text = R"(
+  const char* hlo = R"(
 HloModule ReduceWithPadding
 
 add {
@@ -369,28 +296,18 @@
 }
 )";
 
-  MatchOptimizedHloWithShapes(hlo_text,
-                              R"(
-// CHECK: (param_0.2: f32[10000,2,2,2]) -> f32[100,2,2,2] {
-// CHECK:   [[INSTR_0:%[^ ]+]] = f32[10000,2,2,2]{3,2,1,0} parameter(0)
-// CHECK:   [[INSTR_1:%[^ ]+]] = f32[100,100,2,2,2]{4,3,2,1,0} bitcast(f32[10000,2,2,2]{3,2,1,0} [[INSTR_0]])
-// CHECK:   [[INSTR_2:%[^ ]+]] = f32[] constant(0)
-// CHECK:   ROOT [[INSTR_3:%[^ ]+]] = f32[100,2,2,2]{3,2,1,0} reduce(f32[100,100,2,2,2]{4,3,2,1,0} [[INSTR_1]], f32[] [[INSTR_2]]), dimensions={0}, to_apply=[[INSTR_4:%[^ ]+]]
-// CHECK: }
-// CHECK: ENTRY %main (input: f32[10000,2,2,2]) -> f32[2,2,2] {
-// CHECK:   [[INSTR_0:%[^ ]+]] = f32[10000,2,2,2]{3,2,1,0} parameter(0)
-// CHECK:   [[INSTR_1:%[^ ]+]] = f32[100,2,2,2]{3,2,1,0} fusion(f32[10000,2,2,2]{3,2,1,0} [[INSTR_0]]), kind=kInput, calls=[[INSTR_2:%[^ ]+]]
-// CHECK:   [[INSTR_3:%[^ ]+]] = f32[] constant(0)
-// CHECK:   ROOT [[INSTR_4:%[^ ]+]] = f32[2,2,2]{2,1,0} reduce(f32[100,2,2,2]{3,2,1,0} [[INSTR_1]], f32[] [[INSTR_3]]), dimensions={0}, to_apply=[[INSTR_5:%[^ ]+]]
-// CHECK: }
+  CheckTreeRewriter(hlo,
+                    R"(
+// CHECK:  [[input_0:%[^ ]+]] = f32[10000,2,2,2]{3,2,1,0} parameter(0)
+// CHECK:  [[bitcast_1:%[^ ]+]] = f32[100,100,2,2,2]{4,3,2,1,0} bitcast([[input_0]])
+// CHECK:  [[zero_2:%[^ ]+]] = f32[] constant(0)
+// CHECK:  [[reduce_3:%[^ ]+]] = f32[100,2,2,2]{3,2,1,0} reduce([[bitcast_1]], [[zero_2]]), dimensions={0}, to_apply=[[add_4:%[^ ]+]]
+// CHECK:  ROOT [[out_1_5:%[^ ]+]] = f32[2,2,2]{2,1,0} reduce([[reduce_3]], [[zero_2]]), dimensions={0}, to_apply=[[add_4]]
       )");
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
 }
 
 TEST_F(TreeReductionRewriterTest, ColumnReductionVeryLargeInput) {
-  // TODO(cheshire): reduce duplication for HLO text, factor out the common
-  // part.
-  const char* hlo_text = R"(
+  const char* hlo = R"(
 HloModule ReduceWithPadding
 
 add {
@@ -406,27 +323,18 @@
 }
 )";
 
-  MatchOptimizedHloWithShapes(hlo_text,
-                              R"(
-// CHECK: (param_0.2: f32[1000000,5]) -> f32[1000,5] {
-// CHECK:   [[INSTR_0:%[^ ]+]] = f32[1000000,5]{1,0} parameter(0)
-// CHECK:   [[INSTR_1:%[^ ]+]] = f32[1000,1000,5]{2,1,0} bitcast(f32[1000000,5]{1,0}
-// %param_0.2)
-// CHECK:   [[INSTR_0:%[^ ]+]] = f32[] constant(0)
-// CHECK:   ROOT [[INSTR_1:%[^ ]+]] = f32[1000,5]{1,0} reduce(f32[1000,1000,5]{2,1,0} [[INSTR_2:%[^ ]+]], f32[] [[INSTR_0]]), dimensions={0}, to_apply=[[INSTR_3:%[^ ]+]]
-// CHECK: }
-// CHECK: ENTRY %main (input: f32[1000000,5]) -> f32[5] {
-// CHECK:   [[INSTR_0:%[^ ]+]] = f32[1000000,5]{1,0} parameter(0)
-// CHECK:   [[INSTR_1:%[^ ]+]] = f32[1000,5]{1,0} fusion(f32[1000000,5]{1,0} [[INSTR_0]]), kind=kInput, calls=[[INSTR_2:%[^ ]+]]
-// CHECK:   [[INSTR_3:%[^ ]+]] = f32[] constant(0)
-// CHECK:   ROOT [[INSTR_4:%[^ ]+]] = f32[5]{0} reduce(f32[1000,5]{1,0} [[INSTR_1]], f32[] [[INSTR_3]]), dimensions={0}, to_apply=[[INSTR_5:%[^ ]+]]
-// CHECK: }
+  CheckTreeRewriter(hlo,
+                    R"(
+
+// CHECK:  [[bitcast_0:%[^ ]+]] = f32[1000,1000,5]{2,1,0} bitcast([[input_1:%[^ ]+]])
+// CHECK:  [[zero_2:%[^ ]+]] = f32[] constant(0)
+// CHECK:  [[reduce_3:%[^ ]+]] = f32[1000,5]{1,0} reduce([[bitcast_0]], [[zero_2]]), dimensions={0}, to_apply=[[add_4:%[^ ]+]]
+// CHECK:  ROOT [[out_1_5:%[^ ]+]] = f32[5]{0} reduce([[reduce_3]], [[zero_2]]), dimensions={0}, to_apply=[[add_4]]
       )");
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
 }
 
 TEST_F(TreeReductionRewriterTest, VariadicReductionLargeRow) {
-  const char* hlo_text = R"(
+  const char* hlo = R"(
 HloModule Reduce_R1x2_to_R0x2_argmax
 
 argmax {
@@ -459,44 +367,22 @@
 }
 )";
 
-  MatchOptimizedHloWithShapes(hlo_text,
-                              R"(
-// CHECK: (param_0.6: f32[], param_1.7: f32[], param_2.8: u32[], param_3.5: u32[]) -> (f32[], u32[]) {
-// CHECK:   [[INSTR_0:%[^ ]+]] = f32[] parameter(0)
-// CHECK:   [[INSTR_1:%[^ ]+]] = f32[] parameter(1)
-// CHECK:   [[INSTR_2:%[^ ]+]] = pred[] compare(f32[] [[INSTR_0]], f32[] [[INSTR_1]]), direction=GT
-// CHECK:   [[INSTR_3:%[^ ]+]] = f32[] select(pred[] [[INSTR_2]], f32[] [[INSTR_0]], f32[] [[INSTR_1]])
-// CHECK:   [[INSTR_4:%[^ ]+]] = u32[] parameter(2)
-// CHECK:   [[INSTR_5:%[^ ]+]] = u32[] parameter(3)
-// CHECK:   [[INSTR_6:%[^ ]+]].clone.1 = u32[] select(pred[] [[INSTR_2]], u32[] [[INSTR_4]], u32[] [[INSTR_5]])
-// CHECK:   ROOT [[INSTR_7:%[^ ]+]] = (f32[], u32[]) tuple(f32[] [[INSTR_3]], u32[] [[INSTR_6]].clone.1)
-// CHECK: }
-// CHECK: (param_0.2: f32[2,100000]) -> (f32[2,317], u32[2,317]) {
-// CHECK:   [[INSTR_0:%[^ ]+]] = f32[2,100000]{1,0} parameter(0)
-// CHECK:   [[INSTR_1:%[^ ]+]] = f32[] constant(0)
-// CHECK:   [[INSTR_2:%[^ ]+]] = f32[2,100489]{1,0} pad(f32[2,100000]{1,0} [[INSTR_0]], f32[] [[INSTR_1]]), padding=0_0x0_489
-// CHECK:   [[INSTR_3:%[^ ]+]] = f32[2,317,317]{2,1,0} bitcast(f32[2,100489]{1,0} [[INSTR_2]])
-// CHECK:   [[INSTR_4:%[^ ]+]] = u32[2,100000]{1,0} iota(), iota_dimension=0
-// CHECK:   [[INSTR_5:%[^ ]+]] = u32[] constant(0)
-// CHECK:   [[INSTR_6:%[^ ]+]] = u32[2,100489]{1,0} pad(u32[2,100000]{1,0} [[INSTR_4]], u32[] [[INSTR_5]]), padding=0_0x0_489
-// CHECK:   [[INSTR_7:%[^ ]+]] = u32[2,317,317]{2,1,0} bitcast(u32[2,100489]{1,0} [[INSTR_6]])
-// CHECK:   ROOT [[INSTR_8:%[^ ]+]] = (f32[2,317]{1,0}, u32[2,317]{1,0}) reduce(f32[2,317,317]{2,1,0} [[INSTR_3]], u32[2,317,317]{2,1,0} [[INSTR_7]], f32[] [[INSTR_1]], u32[] [[INSTR_5]]), dimensions={2}, to_apply=[[INSTR_9:%[^ ]+]]
-// CHECK: }
-// CHECK: ENTRY %main (input: f32[2,100000]) -> (f32[2], u32[2]) {
-// CHECK:   [[INSTR_0:%[^ ]+]] = f32[2,100000]{1,0} parameter(0)
-// CHECK:   [[INSTR_1:%[^ ]+]] = (f32[2,317]{1,0}, u32[2,317]{1,0}) fusion(f32[2,100000]{1,0} [[INSTR_0]]), kind=kInput, calls=[[INSTR_2:%[^ ]+]]
-// CHECK:   [[INSTR_3:%[^ ]+]] = f32[2,317]{1,0} get-tuple-element((f32[2,317]{1,0}, u32[2,317]{1,0}) [[INSTR_1]]), index=0
-// CHECK:   [[INSTR_4:%[^ ]+]] = u32[2,317]{1,0} get-tuple-element((f32[2,317]{1,0}, u32[2,317]{1,0}) [[INSTR_1]]), index=1
-// CHECK:   [[INSTR_5:%[^ ]+]] = f32[] constant(0)
-// CHECK:   [[INSTR_6:%[^ ]+]] = u32[] constant(0)
-// CHECK:   ROOT [[INSTR_7:%[^ ]+]] = (f32[2]{0}, u32[2]{0}) reduce(f32[2,317]{1,0} [[INSTR_3]], u32[2,317]{1,0} [[INSTR_4]], f32[] [[INSTR_5]], u32[] [[INSTR_6]]), dimensions={1}, to_apply=[[INSTR_8:%[^ ]+]]
-// CHECK: }
+  CheckTreeRewriter(hlo,
+                    R"(
+// CHECK:  [[pad_0:%[^ ]+]] = f32[2,100489]{1,0} pad([[input_1:%[^ ]+]], [[zero_2:%[^ ]+]]), padding=0_0x0_489
+// CHECK:  [[bitcast_3:%[^ ]+]] = f32[2,317,317]{2,1,0} bitcast([[pad_0]])
+// CHECK:  [[zero_idx_4:%[^ ]+]] = u32[] constant(0)
+// CHECK:  [[pad_1_5:%[^ ]+]] = u32[2,100489]{1,0} pad([[idxs_6:%[^ ]+]], [[zero_idx_4]]), padding=0_0x0_489
+// CHECK:  [[bitcast_1_7:%[^ ]+]] = u32[2,317,317]{2,1,0} bitcast([[pad_1_5]])
+// CHECK:  [[reduce_8:%[^ ]+]] = (f32[2,317]{1,0}, u32[2,317]{1,0}) reduce([[bitcast_3]], [[bitcast_1_7]], [[zero_2]], [[zero_idx_4]]), dimensions={2}, to_apply=[[argmax_9:%[^ ]+]]
+// CHECK:  [[get_tuple_element_10:%[^ ]+]] = f32[2,317]{1,0} get-tuple-element([[reduce_8]]), index=0
+// CHECK:  [[get_tuple_element_1_11:%[^ ]+]] = u32[2,317]{1,0} get-tuple-element([[reduce_8]]), index=1
+// CHECK:  ROOT [[out_1_12:%[^ ]+]] = (f32[2]{0}, u32[2]{0}) reduce([[get_tuple_element_10]], [[get_tuple_element_1_11]], [[zero_2]], [[zero_idx_4]]), dimensions={1}, to_apply=[[argmax_9]]
       )");
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
 }
 
 TEST_F(TreeReductionRewriterTest, VariadicReductionLargeBatchSize) {
-  const char* hlo_text = R"(
+  const char* hlo = R"(
 HloModule Reduce_R1x2_to_R0x2_argmax
 
 argmax {
@@ -529,38 +415,14 @@
 }
 )";
 
-  MatchOptimizedHloWithShapes(hlo_text,
-                              R"(
-// CHECK: (param_0.4: f32[], param_1.5: f32[], param_2.6: u32[], param_3.3: u32[]) -> (f32[], u32[]) {
-// CHECK:   [[INSTR_0:%[^ ]+]] = f32[] parameter(0)
-// CHECK:   [[INSTR_1:%[^ ]+]] = f32[] parameter(1)
-// CHECK:   [[INSTR_2:%[^ ]+]] = pred[] compare(f32[] [[INSTR_0]], f32[] [[INSTR_1]]), direction=GT
-// CHECK:   [[INSTR_3:%[^ ]+]] = f32[] select(pred[] [[INSTR_2]], f32[] [[INSTR_0]], f32[] [[INSTR_1]])
-// CHECK:   [[INSTR_4:%[^ ]+]] = u32[] parameter(2)
-// CHECK:   [[INSTR_5:%[^ ]+]] = u32[] parameter(3)
-// CHECK:   [[INSTR_6:%[^ ]+]].clone.1 = u32[] select(pred[] [[INSTR_2]], u32[] [[INSTR_4]], u32[] [[INSTR_5]])
-// CHECK:   ROOT [[INSTR_7:%[^ ]+]] = (f32[], u32[]) tuple(f32[] [[INSTR_3]], u32[] [[INSTR_6]].clone.1)
-// CHECK: }
-// CHECK: (param_0: f32[20,2,100]) -> (f32[20,2], u32[20,2]) {
-// CHECK:   [[INSTR_0:%[^ ]+]] = f32[20,2,100]{2,1,0} parameter(0)
-// CHECK:   [[INSTR_1:%[^ ]+]] = u32[20,2,100]{2,1,0} iota(), iota_dimension=0
-// CHECK:   [[INSTR_2:%[^ ]+]] = f32[] constant(0)
-// CHECK:   [[INSTR_3:%[^ ]+]] = u32[] constant(0)
-// CHECK:   ROOT [[INSTR_4:%[^ ]+]] = (f32[20,2]{1,0}, u32[20,2]{1,0}) reduce(f32[20,2,100]{2,1,0} [[INSTR_0]], u32[20,2,100]{2,1,0} [[INSTR_1]], f32[] [[INSTR_2]], u32[] [[INSTR_3]]), dimensions={2}, to_apply=[[INSTR_5:%[^ ]+]]
-// CHECK: }
-// CHECK: ENTRY %main (input: f32[20,2,100]) -> (f32[2], u32[2]) {
-// CHECK:   [[INSTR_0:%[^ ]+]] = f32[20,2,100]{2,1,0} parameter(0)
-// CHECK:   [[INSTR_1:%[^ ]+]] = (f32[20,2]{1,0}, u32[20,2]{1,0}) fusion(f32[20,2,100]{2,1,0} [[INSTR_0]]), kind=kInput, calls=[[INSTR_2:%[^ ]+]]
-// CHECK:   [[INSTR_3:%[^ ]+]] = f32[20,2]{1,0} get-tuple-element((f32[20,2]{1,0}, u32[20,2]{1,0}) [[INSTR_1]]), index=0
-// CHECK:   [[INSTR_4:%[^ ]+]] = u32[20,2]{1,0} get-tuple-element((f32[20,2]{1,0}, u32[20,2]{1,0}) [[INSTR_1]]), index=1
-// CHECK:   [[INSTR_5:%[^ ]+]] = f32[] constant(0)
-// CHECK:   [[INSTR_6:%[^ ]+]] = u32[] constant(0)
-// CHECK:   ROOT [[INSTR_7:%[^ ]+]] = (f32[2]{0}, u32[2]{0}) reduce(f32[20,2]{1,0} [[INSTR_3]], u32[20,2]{1,0} [[INSTR_4]], f32[] [[INSTR_5]], u32[] [[INSTR_6]]), dimensions={0}, to_apply=[[INSTR_8:%[^ ]+]]
-// CHECK: }
+  CheckTreeRewriter(hlo,
+                    R"(
+// CHECK:  [[reduce_0:%[^ ]+]] = (f32[20,2]{1,0}, u32[20,2]{1,0}) reduce([[input_1:%[^ ]+]], [[idxs_2:%[^ ]+]], [[zero_3:%[^ ]+]], [[zero_idx_4:%[^ ]+]]), dimensions={2}, to_apply=[[argmax_5:%[^ ]+]]
+// CHECK:  [[get_tuple_element_6:%[^ ]+]] = f32[20,2]{1,0} get-tuple-element([[reduce_0]]), index=0
+// CHECK:  [[get_tuple_element_1_7:%[^ ]+]] = u32[20,2]{1,0} get-tuple-element([[reduce_0]]), index=1
+// CHECK:  ROOT [[out_1_8:%[^ ]+]] = (f32[2]{0}, u32[2]{0}) reduce([[get_tuple_element_6]], [[get_tuple_element_1_7]], [[zero_3]], [[zero_idx_4]]), dimensions={0}, to_apply=[[argmax_5]]
       )");
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
 }
 
 }  // namespace
-}  // namespace gpu
 }  // namespace xla