[mlir][hlo] Refactor rank specialization to allow an arbitrary number of inputs

This actually simplifies the code a bit.

PiperOrigin-RevId: 358201038
Change-Id: I6c9cc8ea391c988c8d5315d8e768debcfa869bf7
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc
index 7c47b6f..3f12d51 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc
@@ -202,6 +202,149 @@
   }
 };
 
+template <typename ChloOpTy, typename HloOpTy>
+struct ConvertUnrankedDynamicBroadcastOpHelper {
+  // Returns the dynamic result of checking the given value is effectively a
+  // scalar shape (i.e. the number of elements is 1).
+  static Value GreaterRankIsN(OpBuilder &builder, Location loc,
+                              Value actual_rank, int targeted_rank) {
+    return builder.create<CmpIOp>(
+        loc, CmpIPredicate::eq, actual_rank,
+        builder.create<ConstantIndexOp>(loc, targeted_rank));
+  }
+
+  static scf::IfOp createIfOpForRankSpecializedBroadcastAndOp(
+      OpBuilder &builder, ChloOpTy op, Value actual_rank, int targeted_rank) {
+    // Create the if block to place the current specialized logic in.
+    Value greater_rank_is_n =
+        GreaterRankIsN(builder, op.getLoc(), actual_rank, targeted_rank);
+    return builder.create<scf::IfOp>(op.getLoc(), op.getResult().getType(),
+                                     greater_rank_is_n, true);
+  }
+
+  static Value createBroadcastToKnownRank(OpBuilder &builder, ChloOpTy op,
+                                          Value value, int targeted_rank) {
+    auto loc = op.getLoc();
+    Value shape = builder.create<shape::ShapeOfOp>(loc, value);
+    SmallVector<int64_t, 6> ranked_shape(targeted_rank, 1);
+    auto unknown_rank_extent_tensor_type = RankedTensorType::get(
+        {RankedTensorType::kDynamicSize}, builder.getIndexType());
+    auto known_rank_extent_tensor_type =
+        RankedTensorType::get({targeted_rank}, builder.getIndexType());
+    Value ranked_shape_val = builder.create<shape::ConstShapeOp>(
+        loc, known_rank_extent_tensor_type,
+        mlir::DenseIntElementsAttr::get(known_rank_extent_tensor_type,
+                                        ranked_shape));
+    Value extended_value = builder.create<shape::BroadcastOp>(
+        loc, unknown_rank_extent_tensor_type, shape, ranked_shape_val, nullptr);
+    return builder.create<tensor::CastOp>(loc, known_rank_extent_tensor_type,
+                                          extended_value);
+  }
+
+  // Create the if statement and code for a broadcasting op with a result of a
+  // given rank.
+  static void createRankSpecializedBroadcastAndOp(OpBuilder &if_builder,
+                                                  ChloOpTy op,
+                                                  ValueRange operands,
+                                                  int targeted_rank) {
+    auto loc = op.getLoc();
+    SmallVector<Value, 2> reshaped_operands;
+
+    auto dynamic_dimensions = llvm::SmallVector<int64_t, 6>(
+        targeted_rank, RankedTensorType::kDynamicSize);
+
+    for (Value operand : operands) {
+      // Handle shape broadcasting and inference.
+      Value extended_operand_casted =
+          createBroadcastToKnownRank(if_builder, op, operand, targeted_rank);
+
+      // 1. Reshape operands to the given rank (with the same number of
+      // elements)
+      // 2. Compute the ranked-broadcasted ChloOp (which will assert that the
+      // ops
+      //    can be broadcasted and do the actual broadcasting)
+      // 3. Type erase the output back to unranked
+      auto reshaped_type = RankedTensorType::get(
+          dynamic_dimensions,
+          operand.getType().template dyn_cast<TensorType>().getElementType());
+      Value reshaped_operand = if_builder.create<mhlo::DynamicReshapeOp>(
+          loc, reshaped_type, operand, extended_operand_casted);
+      reshaped_operands.push_back(reshaped_operand);
+    }
+    auto result_element_type = op.getResult()
+                                   .getType()
+                                   .template dyn_cast<TensorType>()
+                                   .getElementType();
+    auto result_type =
+        RankedTensorType::get(dynamic_dimensions, result_element_type);
+    Value result = if_builder.create<ChloOpTy>(
+        loc, ArrayRef<Type>{result_type}, reshaped_operands, op.getAttrs());
+    Value reshaped_result = if_builder.create<tensor::CastOp>(
+        loc, UnrankedTensorType::get(result_element_type), result);
+    if_builder.create<scf::YieldOp>(loc, reshaped_result);
+  }
+
+  // Iterates over the desired ranks to be specialized and generates the code
+  // snippet for each case.
+  static Value HandleBroadcastAndOp(OpBuilder &rewriter, ChloOpTy op,
+                                    ValueRange operands) {
+    auto loc = op.getLoc();
+
+    // Find the larger rank of the operands.
+    auto extent_tensor_type = RankedTensorType::get({ShapedType::kDynamicSize},
+                                                    rewriter.getIndexType());
+    Value greater_rank;
+    for (Value operand : operands) {
+      Value shape =
+          rewriter.create<shape::ShapeOfOp>(loc, extent_tensor_type, operand);
+      Value rank =
+          rewriter.create<shape::RankOp>(loc, rewriter.getIndexType(), shape);
+      if (!greater_rank) {
+        greater_rank = rank;
+      } else {
+        Value greater_rank_compare = rewriter.create<CmpIOp>(
+            loc, CmpIPredicate::sgt, greater_rank, rank);
+        greater_rank = rewriter.create<SelectOp>(loc, greater_rank_compare,
+                                                 greater_rank, rank);
+      }
+    }
+
+    // Generate a list of nested if/else statements to handle rank
+    // specializations from 1 to `kMaxRankSpecialization`.
+    scf::IfOp if_op = createIfOpForRankSpecializedBroadcastAndOp(
+        rewriter, op, greater_rank, 1);
+    OpBuilder if_builder = if_op.getThenBodyBuilder(rewriter.getListener());
+    createRankSpecializedBroadcastAndOp(if_builder, op, operands, 1);
+
+    // Put each subsequent rank specialization inside the else statement of the
+    // previous one.
+    OpBuilder else_builder = if_op.getElseBodyBuilder(rewriter.getListener());
+    constexpr int kMaxRankSpecialization = 6;
+    for (int i = 2; i < kMaxRankSpecialization; i++) {
+      auto inner_if = createIfOpForRankSpecializedBroadcastAndOp(
+          else_builder, op, greater_rank, i);
+      if_builder = inner_if.getThenBodyBuilder(rewriter.getListener());
+      createRankSpecializedBroadcastAndOp(if_builder, op, operands, i);
+      else_builder.create<scf::YieldOp>(loc, inner_if.getResult(0));
+      else_builder = inner_if.getElseBodyBuilder(rewriter.getListener());
+    }
+    // Fire an assertion if none of the rank specializations applied (one of
+    // the ranks was greater than `kMaxRankSpecialization`).
+    else_builder.create<AssertOp>(
+        loc,
+        GreaterRankIsN(else_builder, op.getLoc(), greater_rank,
+                       kMaxRankSpecialization),
+        "Input for dynamic binary op lowering was of a rank greater than " +
+            std::to_string(kMaxRankSpecialization));
+    // Add the rank 6 specialization to the innermost else block.
+    createRankSpecializedBroadcastAndOp(else_builder, op, operands,
+                                        kMaxRankSpecialization);
+
+    // Return the result of the outermost if statement.
+    return if_op.getResult(0);
+  }
+};
+
 // Handles lowering of the following pattern to patterns that will be further
 // matched by other patterns until they result in LHLO:
 //   %result = "chlo.op"(%lhs, %rhs) : (<*xTy>, <*xTy>) -> <*xTy>
@@ -298,7 +441,9 @@
     OpBuilder if_neq_shapes_builder =
         if_eq_shapes_op.getElseBodyBuilder(rewriter.getListener());
     if_neq_shapes_builder.create<scf::YieldOp>(
-        loc, HandleBroadcastAndOp(if_neq_shapes_builder, op, lhs, rhs));
+        loc, ConvertUnrankedDynamicBroadcastOpHelper<
+                 ChloOpTy, HloOpTy>::HandleBroadcastAndOp(if_neq_shapes_builder,
+                                                          op, {lhs, rhs}));
 
     rewriter.replaceOp(op, {if_op.getResult(0)});
     return success();
@@ -318,23 +463,6 @@
                                    rewriter.create<ConstantIndexOp>(loc, 1));
   }
 
-  Value GreaterRankIsN(OpBuilder &builder, Location loc, Value actual_rank,
-                       int targeted_rank) const {
-    return builder.create<CmpIOp>(
-        loc, CmpIPredicate::eq, actual_rank,
-        builder.create<ConstantIndexOp>(loc, targeted_rank));
-  }
-
-  scf::IfOp createIfOpForRankSpecializedBroadcastAndOp(
-      OpBuilder &builder, ChloOpTy op, Value actual_rank,
-      int targeted_rank) const {
-    // Create the if block to place the current specialized logic in.
-    Value greater_rank_is_n =
-        GreaterRankIsN(builder, op.getLoc(), actual_rank, targeted_rank);
-    return builder.create<scf::IfOp>(op.getLoc(), op.getResult().getType(),
-                                     greater_rank_is_n, true);
-  }
-
   Value extendToBroadcastShape(OpBuilder &builder, Location loc, Value value,
                                Value shape_of_lhs, Value shape_of_rhs) const {
     auto unknown_rank_extent_tensor_type = RankedTensorType::get(
@@ -345,122 +473,6 @@
     return builder.create<mhlo::DynamicReshapeOp>(loc, value.getType(), value,
                                                   broadcast_shape);
   }
-
-  Value createBroadcastToKnownRank(OpBuilder &builder, ChloOpTy op, Value value,
-                                   int targeted_rank) const {
-    auto loc = op.getLoc();
-    Value shape = builder.create<shape::ShapeOfOp>(loc, value);
-    SmallVector<int64_t, 6> ranked_shape(targeted_rank, 1);
-    auto unknown_rank_extent_tensor_type = RankedTensorType::get(
-        {RankedTensorType::kDynamicSize}, builder.getIndexType());
-    auto known_rank_extent_tensor_type =
-        RankedTensorType::get({targeted_rank}, builder.getIndexType());
-    Value ranked_shape_val = builder.create<shape::ConstShapeOp>(
-        loc, known_rank_extent_tensor_type,
-        mlir::DenseIntElementsAttr::get(known_rank_extent_tensor_type,
-                                        ranked_shape));
-    Value extended_value = builder.create<shape::BroadcastOp>(
-        loc, unknown_rank_extent_tensor_type, shape, ranked_shape_val, nullptr);
-    return builder.create<tensor::CastOp>(loc, known_rank_extent_tensor_type,
-                                          extended_value);
-  }
-
-  // Create the if statement and code for a broadcasting op with a result of a
-  // given rank.
-  void createRankSpecializedBroadcastAndOp(OpBuilder &if_builder, ChloOpTy op,
-                                           Value lhs, Value rhs,
-                                           int targeted_rank) const {
-    auto loc = op.getLoc();
-
-    // Handle shape broadcasting and inference.
-    Value extended_lhs_casted =
-        createBroadcastToKnownRank(if_builder, op, lhs, targeted_rank);
-    Value extended_rhs_casted =
-        createBroadcastToKnownRank(if_builder, op, rhs, targeted_rank);
-    auto dynamic_dimensions = llvm::SmallVector<int64_t, 6>(
-        targeted_rank, RankedTensorType::kDynamicSize);
-    auto reshaped_type = RankedTensorType::get(
-        dynamic_dimensions,
-        lhs.getType().template dyn_cast<TensorType>().getElementType());
-
-    // 1. Reshape operands to the given rank (with the same number of elements)
-    // 2. Compute the ranked-broadcasted ChloOp (which will assert that the ops
-    //    can be broadcasted and do the actual broadcasting)
-    // 3. Type erase the output back to unranked
-    Value reshaped_lhs = if_builder.create<mhlo::DynamicReshapeOp>(
-        loc, reshaped_type, lhs, extended_lhs_casted);
-    Value reshaped_rhs = if_builder.create<mhlo::DynamicReshapeOp>(
-        loc, reshaped_type, rhs, extended_rhs_casted);
-    auto result_element_type = op.getResult()
-                                   .getType()
-                                   .template dyn_cast<TensorType>()
-                                   .getElementType();
-    auto result_type =
-        RankedTensorType::get(dynamic_dimensions, result_element_type);
-    Value result = if_builder.create<ChloOpTy>(
-        loc, ArrayRef<Type>{result_type},
-        ArrayRef<Value>{reshaped_lhs, reshaped_rhs}, op.getAttrs());
-    Value reshaped_result = if_builder.create<tensor::CastOp>(
-        loc, UnrankedTensorType::get(result_element_type), result);
-    if_builder.create<scf::YieldOp>(loc, reshaped_result);
-  }
-
-  // Iterates over the desired ranks to be specialized and generates the code
-  // snippet for each case.
-  Value HandleBroadcastAndOp(OpBuilder &rewriter, ChloOpTy op, Value lhs,
-                             Value rhs) const {
-    auto loc = op.getLoc();
-
-    // Find the larger rank of the 2 operands.
-    auto extent_tensor_type = RankedTensorType::get({ShapedType::kDynamicSize},
-                                                    rewriter.getIndexType());
-    Value lhs_shape =
-        rewriter.create<shape::ShapeOfOp>(loc, extent_tensor_type, lhs);
-    Value rhs_shape =
-        rewriter.create<shape::ShapeOfOp>(loc, extent_tensor_type, rhs);
-    Value lhs_rank =
-        rewriter.create<shape::RankOp>(loc, rewriter.getIndexType(), lhs_shape);
-    Value rhs_rank =
-        rewriter.create<shape::RankOp>(loc, rewriter.getIndexType(), rhs_shape);
-    Value greater_rank_lhs =
-        rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt, lhs_rank, rhs_rank);
-    Value greater_rank =
-        rewriter.create<SelectOp>(loc, greater_rank_lhs, lhs_rank, rhs_rank);
-
-    // Generate a list of nested if/else statements to handle rank
-    // specializations from 1 to `kMaxRankSpecialization`.
-    scf::IfOp if_op = createIfOpForRankSpecializedBroadcastAndOp(
-        rewriter, op, greater_rank, 1);
-    OpBuilder if_builder = if_op.getThenBodyBuilder(rewriter.getListener());
-    createRankSpecializedBroadcastAndOp(if_builder, op, lhs, rhs, 1);
-
-    // Put each subsequent rank specialization inside the else statement of the
-    // previous one.
-    OpBuilder else_builder = if_op.getElseBodyBuilder(rewriter.getListener());
-    constexpr int kMaxRankSpecialization = 6;
-    for (int i = 2; i < kMaxRankSpecialization; i++) {
-      auto inner_if = createIfOpForRankSpecializedBroadcastAndOp(
-          else_builder, op, greater_rank, i);
-      if_builder = inner_if.getThenBodyBuilder(rewriter.getListener());
-      createRankSpecializedBroadcastAndOp(if_builder, op, lhs, rhs, i);
-      else_builder.create<scf::YieldOp>(loc, inner_if.getResult(0));
-      else_builder = inner_if.getElseBodyBuilder(rewriter.getListener());
-    }
-    // Fire an assertion if none of the rank specializations applied (one of
-    // the ranks was greater than `kMaxRankSpecialization`).
-    else_builder.create<AssertOp>(
-        loc,
-        GreaterRankIsN(else_builder, op.getLoc(), greater_rank,
-                       kMaxRankSpecialization),
-        "Input for dynamic binary op lowering was of a rank greater than " +
-            std::to_string(kMaxRankSpecialization));
-    // Add the rank 6 specialization to the innermost else block.
-    createRankSpecializedBroadcastAndOp(else_builder, op, lhs, rhs,
-                                        kMaxRankSpecialization);
-
-    // Return the result of the outermost if statement.
-    return if_op.getResult(0);
-  }
 };
 
 struct TransformUnrankedHloPass
diff --git a/tensorflow/compiler/mlir/hlo/tests/hlo-transform-unranked.mlir b/tensorflow/compiler/mlir/hlo/tests/hlo-transform-unranked.mlir
index a074763..43270f3 100644
--- a/tensorflow/compiler/mlir/hlo/tests/hlo-transform-unranked.mlir
+++ b/tensorflow/compiler/mlir/hlo/tests/hlo-transform-unranked.mlir
@@ -209,9 +209,9 @@
 // CHECK-NEXT:                   %[[CONST_SHAPE_1:.*]] = shape.const_shape [1]
 // CHECK-NEXT:                   %[[BROADCASTED_LHS_1:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
 // CHECK-NEXT:                   %[[CASTED_LHS_1:.*]] = tensor.cast %[[BROADCASTED_LHS_1]] : tensor<?xindex> to tensor<1xindex>
+// CHECK-NEXT:                   %[[RESHAPED_LHS_1:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_1]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
 // CHECK-NEXT:                   %[[BROADCASTED_RHS_1:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
 // CHECK-NEXT:                   %[[CASTED_RHS_1:.*]] = tensor.cast %[[BROADCASTED_RHS_1]] : tensor<?xindex> to tensor<1xindex>
-// CHECK-NEXT:                   %[[RESHAPED_LHS_1:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_1]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
 // CHECK-NEXT:                   %[[RESHAPED_RHS_1:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_1]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
 // CHECK-NEXT:                   %[[RESULT_RANK_1:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_1]], %[[RESHAPED_RHS_1]] : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
 // CHECK-NEXT:                   %[[RESULT_1:.*]] = tensor.cast %[[RESULT_RANK_1]] : tensor<?xf32> to tensor<*xf32>
@@ -224,9 +224,9 @@
 // CHECK-NEXT:                     %[[CONST_SHAPE_2:.*]] = shape.const_shape [1, 1]
 // CHECK-NEXT:                     %[[BROADCASTED_LHS_2:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
 // CHECK-NEXT:                     %[[CASTED_LHS_2:.*]] = tensor.cast %[[BROADCASTED_LHS_2]] : tensor<?xindex> to tensor<2xindex>
+// CHECK-NEXT:                     %[[RESHAPED_LHS_2:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
 // CHECK-NEXT:                     %[[BROADCASTED_RHS_2:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
 // CHECK-NEXT:                     %[[CASTED_RHS_2:.*]] = tensor.cast %[[BROADCASTED_RHS_2]] : tensor<?xindex> to tensor<2xindex>
-// CHECK-NEXT:                     %[[RESHAPED_LHS_2:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
 // CHECK-NEXT:                     %[[RESHAPED_RHS_2:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
 // CHECK-NEXT:                     %[[RESULT_RANK_2:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_2]], %[[RESHAPED_RHS_2]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
 // CHECK-NEXT:                     %[[RESULT_2:.*]] = tensor.cast %[[RESULT_RANK_2]] : tensor<?x?xf32> to tensor<*xf32>
@@ -239,9 +239,9 @@
 // CHECK-NEXT:                       %[[CONST_SHAPE_3:.*]] = shape.const_shape [1, 1, 1]
 // CHECK-NEXT:                       %[[BROADCASTED_LHS_3:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex>
 // CHECK-NEXT:                       %[[CASTED_LHS_3:.*]] = tensor.cast %[[BROADCASTED_LHS_3]] : tensor<?xindex> to tensor<3xindex>
+// CHECK-NEXT:                       %[[RESHAPED_LHS_3:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
 // CHECK-NEXT:                       %[[BROADCASTED_RHS_3:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex>
 // CHECK-NEXT:                       %[[CASTED_RHS_3:.*]] = tensor.cast %[[BROADCASTED_RHS_3]] : tensor<?xindex> to tensor<3xindex>
-// CHECK-NEXT:                       %[[RESHAPED_LHS_3:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
 // CHECK-NEXT:                       %[[RESHAPED_RHS_3:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
 // CHECK-NEXT:                       %[[RESULT_RANK_3:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_3]], %[[RESHAPED_RHS_3]] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
 // CHECK-NEXT:                       %[[RESULT_3:.*]] = tensor.cast %[[RESULT_RANK_3]] : tensor<?x?x?xf32> to tensor<*xf32>
@@ -254,9 +254,9 @@
 // CHECK-NEXT:                         %[[CONST_SHAPE_4:.*]] = shape.const_shape [1, 1, 1, 1]
 // CHECK-NEXT:                         %[[BROADCASTED_LHS_4:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex>
 // CHECK-NEXT:                         %[[CASTED_LHS_4:.*]] = tensor.cast %[[BROADCASTED_LHS_4]] : tensor<?xindex> to tensor<4xindex>
+// CHECK-NEXT:                         %[[RESHAPED_LHS_4:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
 // CHECK-NEXT:                         %[[BROADCASTED_RHS_4:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex>
 // CHECK-NEXT:                         %[[CASTED_RHS_4:.*]] = tensor.cast %[[BROADCASTED_RHS_4]] : tensor<?xindex> to tensor<4xindex>
-// CHECK-NEXT:                         %[[RESHAPED_LHS_4:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
 // CHECK-NEXT:                         %[[RESHAPED_RHS_4:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
 // CHECK-NEXT:                         %[[RESULT_RANK_4:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_4]], %[[RESHAPED_RHS_4]] : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
 // CHECK-NEXT:                         %[[RESULT_4:.*]] = tensor.cast %[[RESULT_RANK_4]] : tensor<?x?x?x?xf32> to tensor<*xf32>
@@ -269,9 +269,9 @@
 // CHECK-NEXT:                           %[[CONST_SHAPE_5:.*]] = shape.const_shape [1, 1, 1, 1, 1]
 // CHECK-NEXT:                           %[[BROADCASTED_LHS_5:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex>
 // CHECK-NEXT:                           %[[CASTED_LHS_5:.*]] = tensor.cast %[[BROADCASTED_LHS_5]] : tensor<?xindex> to tensor<5xindex>
+// CHECK-NEXT:                           %[[RESHAPED_LHS_5:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
 // CHECK-NEXT:                           %[[BROADCASTED_RHS_5:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex>
 // CHECK-NEXT:                           %[[CASTED_RHS_5:.*]] = tensor.cast %[[BROADCASTED_RHS_5]] : tensor<?xindex> to tensor<5xindex>
-// CHECK-NEXT:                           %[[RESHAPED_LHS_5:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
 // CHECK-NEXT:                           %[[RESHAPED_RHS_5:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
 // CHECK-NEXT:                           %[[RESULT_RANK_5:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_5]], %[[RESHAPED_RHS_5]] : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
 // CHECK-NEXT:                           %[[RESULT_5:.*]] = tensor.cast %[[RESULT_RANK_5]] : tensor<?x?x?x?x?xf32> to tensor<*xf32>
@@ -284,9 +284,9 @@
 // CHECK-NEXT:                           %[[CONST_SHAPE_6:.*]] = shape.const_shape [1, 1, 1, 1, 1, 1]
 // CHECK-NEXT:                           %[[BROADCASTED_LHS_6:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor<?xindex>, tensor<6xindex> -> tensor<?xindex>
 // CHECK-NEXT:                           %[[CASTED_LHS_6:.*]] = tensor.cast %[[BROADCASTED_LHS_6]] : tensor<?xindex> to tensor<6xindex>
+// CHECK-NEXT:                           %[[RESHAPED_LHS_6:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor<?x?x?x?x?x?xf32>
 // CHECK-NEXT:                           %[[BROADCASTED_RHS_6:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor<?xindex>, tensor<6xindex> -> tensor<?xindex>
 // CHECK-NEXT:                           %[[CASTED_RHS_6:.*]] = tensor.cast %[[BROADCASTED_RHS_6]] : tensor<?xindex> to tensor<6xindex>
-// CHECK-NEXT:                           %[[RESHAPED_LHS_6:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor<?x?x?x?x?x?xf32>
 // CHECK-NEXT:                           %[[RESHAPED_RHS_6:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor<?x?x?x?x?x?xf32>
 // CHECK-NEXT:                           %[[RESULT_RANK_6:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_6]], %[[RESHAPED_RHS_6]] : (tensor<?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?xf32>
 // CHECK-NEXT:                           %[[RESULT_6:.*]] = tensor.cast %[[RESULT_RANK_6]] : tensor<?x?x?x?x?x?xf32> to tensor<*xf32>