Remove the EarlyBroadcastInDimOp pattern from the merge-assuming-ops pass.

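This drops the pattern-based broadcast propagation that used to live in the
merge-assuming-ops pass behind the `propagate-broadcasts` option, together with
the option itself and the corresponding tests. Moving dynamic broadcasts up over
element-wise producers is left to the dedicated BroadcastPropagationPass, which
the affected pipelines already run right after createMergeAssumingOpsPass().

For reference, this is the kind of rewrite the removed
EarlyBroadcastInDimOpPattern performed. The input is taken from the deleted
@bcast_nary test case; the result shown in the trailing comments is a sketch of
the expected propagated form, not output verified against the current passes:

  // Before: an element-wise op whose result is dynamically broadcast.
  func @bcast_nary(%arg0 : tensor<?x32xf32>, %arg1 : tensor<?x32xf32>,
      %out_dims : tensor<3xindex>) -> tensor<?x?x32xf32> {
    %0 = mhlo.subtract %arg0, %arg1 : tensor<?x32xf32>
    %1 = "mhlo.dynamic_broadcast_in_dim"(%0, %out_dims) {
        broadcast_dimensions = dense<[0, 1]> : tensor<2xi64> } :
        (tensor<?x32xf32>, tensor<3xindex>) -> tensor<?x?x32xf32>
    return %1 : tensor<?x?x32xf32>
  }
  // After propagation, the operands are broadcast instead of the result:
  //   %0 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %out_dims)
  //       {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}
  //       : (tensor<?x32xf32>, tensor<3xindex>) -> tensor<?x?x32xf32>
  //   %1 = "mhlo.dynamic_broadcast_in_dim"(%arg1, %out_dims)
  //       {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}
  //       : (tensor<?x32xf32>, tensor<3xindex>) -> tensor<?x?x32xf32>
  //   %2 = mhlo.subtract %0, %1 : tensor<?x?x32xf32>
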
PiperOrigin-RevId: 422999899
Change-Id: Iabc6cdd2739895dcec6368c961b1d5e0ec861145
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td
index efc1d55..a6042b1 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td
@@ -164,18 +164,11 @@
   let constructor = "createBroadcastPropagationPass()";
 }
 
-// TODO(frgossen): Limit this pass to merging of assuming regions and factor out
-// broadcast propagation into its own pass.
 def MergeAssumingOpsPass : FunctionPass<"mhlo-merge-assuming-ops"> {
-  let summary = "Move dynamic broadcasts up over element-wise operations and "
-    "broadcast the operands rather than the result. This will eventually allow "
-    "for larger fusions.";
+  let summary = "Prepare moving dynamic broadcasts up over element-wise "
+    "operations and broadcast the operands rather than the result. This will "
+    "eventually allow for larger fusions.";
   let constructor = "createMergeAssumingOpsPass()";
-  let options = [
-      Option<"propagate_broadcasts_", "propagate-broadcasts", "bool",
-             /*default=*/"true",
-             "Also propagate broadcasts with a pattern-based rewrite.">,
-  ];
 }
 
 def GroupReductionDimensionsPass
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h
index 1fc3066..948c669 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h
@@ -80,13 +80,10 @@
 // fusions.
 std::unique_ptr<FunctionPass> createBroadcastPropagationPass();
 
-// Move dynamic broadcasts up over element-wise operations and broadcast the
-// operands rather than the result. This will eventually allow for larger
-// fusions.
-// TODO(frgossen): Limit this pass to merging of assuming regions and factor out
-// broadcast propagation into its own pass.
-std::unique_ptr<FunctionPass> createMergeAssumingOpsPass(
-    bool propagate_broadcasts = true);
+// Prepare moving dynamic broadcasts up over element-wise operations and
+// broadcast the operands rather than the result. This will eventually allow for
+// larger fusions.
+std::unique_ptr<FunctionPass> createMergeAssumingOpsPass();
 
 // Group reduction and parallel dimensions of reduction operations and realize
 // them through equivalent 1D or 2D reductions, if possible.
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h
index ca37c87..9cfd991 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h
@@ -118,12 +118,11 @@
 void PopulateTrigonometricToApproximationPatterns(
     MLIRContext *context, OwningRewritePatternList *patterns);
 
-// Populate patterns to move dynamic broadcasts up over element-wise operations
-// and broadcast the operands rather than the result. This will eventually allow
-// for larger fusions.
+// Populate patterns to prepare moving dynamic broadcasts up over element-wise
+// operations and broadcast the operands rather than the result. This will
+// eventually allow for larger fusions.
 void PopulateMergeAssumingOpsPatterns(MLIRContext *context,
-                                      OwningRewritePatternList *patterns,
-                                      bool propagate_broadcasts);
+                                      OwningRewritePatternList *patterns);
 
 // Populate patterns to group reduction and parallel dimensions of reduction
 // operations and realize them through equivalent 1D or 2D reductions, if
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/merge_assuming_ops.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/merge_assuming_ops.cc
index 1e5b590..c282b46 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/merge_assuming_ops.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/merge_assuming_ops.cc
@@ -14,6 +14,7 @@
 
 ==============================================================================*/
 
+#include <memory>
 #include <utility>
 
 #include "llvm/ADT/STLExtras.h"
@@ -425,80 +426,8 @@
   }
 };
 
-struct EarlyBroadcastInDimOpPattern
-    : public OpRewritePattern<DynamicBroadcastInDimOp> {
-  using OpRewritePattern<DynamicBroadcastInDimOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(DynamicBroadcastInDimOp bcast_op,
-                                PatternRewriter &rewriter) const override {
-    // Dynamic broadcasting only works on ranked tensors.
-    auto result_ty = bcast_op.getType().dyn_cast<RankedTensorType>();
-    if (!result_ty) return failure();
-
-    // Find producer op.
-    Operation *producer_op = bcast_op.operand().getDefiningOp();
-    if (!producer_op) return failure();
-    auto producer_result_ty =
-        producer_op->getResultTypes().front().dyn_cast<RankedTensorType>();
-    if (!producer_result_ty) return failure();
-
-    // Only apply to broadcasting elementwise producers.
-    bool is_broadcasting_elementwise =
-        (producer_op->hasTrait<mlir::OpTrait::SameOperandsAndResultShape>() &&
-         producer_op->hasTrait<mlir::OpTrait::Elementwise>()) ||
-        producer_op->hasTrait<mlir::mhlo::OpTrait::BroadcastingElementwise>();
-    if (!is_broadcasting_elementwise) return failure();
-
-    // Materialize broadcast on operands.
-    SmallVector<Value, 2> bcasted_operands;
-    Location loc = bcast_op.getLoc();
-    ArrayRef<int64_t> ty_shape = bcast_op.getType().getShape();
-    for (Value producer_operand : producer_op->getOperands()) {
-      // The broadcast only works on ranked operations.
-      auto producer_operand_ty =
-          producer_operand.getType().dyn_cast<RankedTensorType>();
-      if (!producer_operand_ty) {
-        return bcast_op.emitError()
-               << "Can only move up broadcasts over ranked tensor operands.";
-      }
-
-      // Materialize dynamic broadcast. The operand shape is either the same as
-      // the result shape and we can reuse the broadcast dimensions, or it is a
-      // scalar and we can create empty broadcast dimensions.
-      assert((producer_operand_ty.getRank() == 0 ||
-              producer_operand_ty.getRank() == producer_result_ty.getRank()) &&
-             "expect scalar or same shape");
-      auto bcast_dims = producer_operand_ty.getRank() == 0
-                            ? rewriter.getI64TensorAttr({})
-                            : bcast_op.broadcast_dimensions();
-      auto bcasted_operand_ty =
-          RankedTensorType::get(ty_shape, producer_operand_ty.getElementType());
-      bcasted_operands.push_back(rewriter.create<DynamicBroadcastInDimOp>(
-          loc, bcasted_operand_ty, producer_operand,
-          bcast_op.output_dimensions(), bcast_dims));
-    }
-
-    // Create a copy of the producer op with the new broadcasted operands.
-    OperationState new_producer_op_state(
-        loc, producer_op->getName().getStringRef(), bcasted_operands, result_ty,
-        producer_op->getAttrs());
-    Operation *new_producer_op =
-        rewriter.createOperation(new_producer_op_state);
-
-    // The original result of the broadcast now falls directly out of the new
-    // producer op. Use it instead.
-    rewriter.replaceOp(bcast_op, new_producer_op->getResults());
-
-    return success();
-  }
-};
-
 struct MergeAssumingOpsPass
     : public MergeAssumingOpsPassBase<MergeAssumingOpsPass> {
-  explicit MergeAssumingOpsPass(bool propagate_broadcasts) {
-    propagate_broadcasts_ = propagate_broadcasts;
-  }
-
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<shape::ShapeDialect, mhlo::MhloDialect>();
   }
@@ -506,8 +435,7 @@
   void runOnFunction() override {
     MLIRContext *ctx = &getContext();
     RewritePatternSet patterns(ctx);
-    mhlo::PopulateMergeAssumingOpsPatterns(ctx, &patterns,
-                                           propagate_broadcasts_);
+    mhlo::PopulateMergeAssumingOpsPatterns(ctx, &patterns);
     GreedyRewriteConfig config;
     config.maxIterations = GreedyRewriteConfig::kNoIterationLimit;
     if (failed(applyPatternsAndFoldGreedily(getFunction(), std::move(patterns),
@@ -520,8 +448,7 @@
 }  // namespace
 
 void PopulateMergeAssumingOpsPatterns(MLIRContext *context,
-                                      OwningRewritePatternList *patterns,
-                                      bool propagate_broadcasts) {
+                                      OwningRewritePatternList *patterns) {
   // clang-format off
   patterns->insert<
       EliminateDuplicateCstrBroadcastableOps,
@@ -537,8 +464,6 @@
       MoveUpOutOfAssumingOpPattern<shape::ShapeOfOp>,
       ShapeReificationPattern>(context);
   // clang-format on
-  if (propagate_broadcasts)
-    patterns->insert<EarlyBroadcastInDimOpPattern>(context);
   mhlo::DynamicBroadcastInDimOp::getCanonicalizationPatterns(*patterns,
                                                              context);
   mhlo::DynamicReshapeOp::getCanonicalizationPatterns(*patterns, context);
@@ -549,9 +474,8 @@
   tensor::CastOp::getCanonicalizationPatterns(*patterns, context);
 }
 
-std::unique_ptr<FunctionPass> createMergeAssumingOpsPass(
-    bool propagate_broadcasts) {
-  return std::make_unique<MergeAssumingOpsPass>(propagate_broadcasts);
+std::unique_ptr<FunctionPass> createMergeAssumingOpsPass() {
+  return std::make_unique<MergeAssumingOpsPass>();
 }
 
 }  // namespace mhlo
diff --git a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/merge_assuming_ops.mlir b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/merge_assuming_ops.mlir
index 56caaf7..4f990a7 100644
--- a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/merge_assuming_ops.mlir
+++ b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/merge_assuming_ops.mlir
@@ -2,11 +2,6 @@
 // RUN:   --mhlo-merge-assuming-ops --canonicalize --cse %s | \
 // RUN: FileCheck %s
 
-// RUN: mlir-hlo-opt --split-input-file --allow-unregistered-dialect \
-// RUN:   --mhlo-merge-assuming-ops="propagate-broadcasts=false" \
-// RUN:   --canonicalize --cse %s | \
-// RUN: FileCheck --check-prefix=CHECK-NOPROP %s
-
 // Shape computations shall be reified.
 // CHECK-LABEL: @shape_of_unary
 // CHECK-SAME: (%[[ARG:.*]]: tensor<?x32xi16>)
@@ -36,82 +31,6 @@
 
 // -----
 
-// Broadcasts can be moved up over unary shape-preserving operations.
-// CHECK-LABEL: @bcast_unary
-// CHECK-SAME: (%[[ARG:.*]]: tensor<?x32xi16>, %[[OUT_DIMS:.*]]: tensor<3xindex>)
-func @bcast_unary(%arg : tensor<?x32xi16>, %out_dims : tensor<3xindex>)
-    -> tensor<?x?x32xf16> {
-  // CHECK:      %[[BCASTED_OPERAND:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG]], %[[OUT_DIMS]])
-  // CHECK-SAME: broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x32xi16>, tensor<3xindex>) -> tensor<?x?x32xi16>
-  // CHECK:      "mhlo.convert"(%[[BCASTED_OPERAND]]) : (tensor<?x?x32xi16>) -> tensor<?x?x32xf16>
-  %0 = "mhlo.convert"(%arg) : (tensor<?x32xi16>) -> tensor<?x32xf16>
-  %1 = "mhlo.dynamic_broadcast_in_dim"(%0, %out_dims) {
-      broadcast_dimensions = dense<[0, 1]> : tensor<2xi64> } :
-      (tensor<?x32xf16>, tensor<3xindex>) -> tensor<?x?x32xf16>
-  return %1 : tensor<?x?x32xf16>
-}
-
-// CHECK-NOPROP-LABEL: @bcast_unary
-// CHECK-NOPROP-SAME: (%[[ARG:.*]]: tensor<?x32xi16>, %[[OUT_DIMS:.*]]: tensor<3xindex>)
-
-// CHECK-NOPROP: %[[TMP:.*]] = "mhlo.convert"(%[[ARG]])
-// CHECK-NOPROP: %[[RES:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[TMP]], %[[OUT_DIMS]])
-// return %[[RES]]
-
-// -----
-
-// Broadcasts can be moved up over n-ary shape-preserving operations.
-// CHECK-LABEL: @bcast_nary
-// CHECK-SAME: (%[[ARG0:.*]]: tensor<?x32xf32>, %[[ARG1:.*]]: tensor<?x32xf32>, %[[OUT_DIMS:.*]]: tensor<3xindex>)
-func @bcast_nary(%arg0 : tensor<?x32xf32>, %arg1 : tensor<?x32xf32>,
-    %out_dims : tensor<3xindex>) -> tensor<?x?x32xf32> {
-  // CHECK-NOT: subtract
-  // CHECK:     %[[BCASTED_ARG0:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[OUT_DIMS]])
-  // CHECK:     %[[BCASTED_ARG1:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[OUT_DIMS]])
-  // CHECK:     %{{.*}} = mhlo.subtract %[[BCASTED_ARG0]], %[[BCASTED_ARG1]] : tensor<?x?x32xf32>
-  %0 = mhlo.subtract %arg0, %arg1 : tensor<?x32xf32>
-  %1 = "mhlo.dynamic_broadcast_in_dim"(%0, %out_dims) {
-      broadcast_dimensions = dense<[0, 1]> : tensor<2xi64> } :
-      (tensor<?x32xf32>, tensor<3xindex>) -> tensor<?x?x32xf32>
-  return %1 : tensor<?x?x32xf32>
-}
-
-// -----
-
-// Exemplary IR as it appears in the lowering with `tf.Sub` and `tf.Cast`.
-// CHECK-LABEL: @cast_sub
-// CHECK-SAME: (%[[ARG0:.*]]: tensor<?x32xi16>, %[[ARG1:.*]]: tensor<?x?x32xf16>) -> tensor<?x?x32xf16>
-func @cast_sub(%arg0: tensor<?x32xi16>, %arg1: tensor<?x?x32xf16>)
-    -> tensor<?x?x32xf16> {
-  // CHECK-NOT: convert
-  // CHECK:     %[[BCASTED_ARG1:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %{{.*}})
-  // CHECK:     %[[BCASTED_ARG0:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %{{.*}})
-  // CHECK:     %[[CONVERTED_BCASTED_ARG0:.*]] = "mhlo.convert"(%[[BCASTED_ARG0]]) : (tensor<?x?x32xi16>) -> tensor<?x?x32xf16>
-  // CHECK:     %{{.*}} = mhlo.subtract %[[BCASTED_ARG1]], %[[CONVERTED_BCASTED_ARG0]] : tensor<?x?x32xf16>
-  %0 = "mhlo.convert"(%arg0) : (tensor<?x32xi16>) -> tensor<?x32xf16>
-  %1 = shape.shape_of %arg1 : tensor<?x?x32xf16> -> tensor<?xindex>
-  %2 = shape.shape_of %0 : tensor<?x32xf16> -> tensor<?xindex>
-  %3 = shape.cstr_broadcastable %1, %2 : tensor<?xindex>, tensor<?xindex>
-  %4 = shape.assuming %3 -> (tensor<?x?x32xf16>) {
-    %5 = shape.shape_of %arg1 : tensor<?x?x32xf16> -> tensor<?xindex>
-    %6 = shape.shape_of %0 : tensor<?x32xf16> -> tensor<?xindex>
-    %7 = shape.broadcast %5, %6 : tensor<?xindex>, tensor<?xindex>
-        -> tensor<?xindex>
-    %8 = tensor.cast %7 : tensor<?xindex> to tensor<3xindex>
-    %9 = "mhlo.dynamic_broadcast_in_dim"(%arg1, %8) {
-        broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} :
-        (tensor<?x?x32xf16>, tensor<3xindex>) -> tensor<?x?x32xf16>
-    %10 = "mhlo.dynamic_broadcast_in_dim"(%0, %8) {
-        broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} :
-        (tensor<?x32xf16>, tensor<3xindex>) -> tensor<?x?x32xf16>
-    %11 = mhlo.subtract %9, %10 : tensor<?x?x32xf16>
-    shape.assuming_yield %11 : tensor<?x?x32xf16>
-  }
-  return %4 : tensor<?x?x32xf16>
-}
-
-// -----
-
 // CHECK-LABEL: @inline_bcasted_shape_operands
 // CHECK-SAME: (%[[A:.*]]: tensor<?xindex>, %[[B:.*]]: tensor<?xindex>, %[[C:.*]]: tensor<?xindex>)
 func @inline_bcasted_shape_operands(%a : tensor<?xindex>, %b : tensor<?xindex>,
@@ -364,58 +283,6 @@
 
 // -----
 
-// Exemplary IR as it appears in the lowering of two subsequent `tf.Sub` ops.
-// CHECK-LABEL: @sub_sub
-// CHECK-SAME: (%[[ARG0:.*]]: tensor<?x32xf16>, %[[ARG1:.*]]: tensor<?x32xf16>, %[[ARG2:.*]]: tensor<?x?x32xf16>)
-func @sub_sub(%arg0: tensor<?x32xf16>, %arg1 : tensor<?x32xf16>,
-    %arg2: tensor<?x?x32xf16>) -> tensor<?x?x32xf16> {
-  // CHECK-DAG:  %[[SHAPE0:.*]] = shape.shape_of %[[ARG0]]
-  // CHECK-DAG:  %[[SHAPE1:.*]] = shape.shape_of %[[ARG1]]
-  // CHECK-DAG:  %[[SHAPE2:.*]] = shape.shape_of %[[ARG2]]
-  // CHECK-DAG:  %[[WITNESS0:.*]] = shape.cstr_broadcastable %[[SHAPE0]], %[[SHAPE1]]
-  // CHECK-DAG:  %[[WITNESS1:.*]] = shape.cstr_broadcastable %[[SHAPE2]], %[[SHAPE0]], %[[SHAPE1]]
-  // CHECK-DAG:  %[[COMBINED_WITNESS:.*]] = shape.assuming_all %[[WITNESS0]], %[[WITNESS1]]
-  // CHECK:      %[[ASSUMING_RESULT:.*]] = shape.assuming %[[COMBINED_WITNESS]]
-  // CHECK:        %[[BCASTED_SHAPE01:.*]] = shape.broadcast %[[SHAPE0]], %[[SHAPE1]]
-  // CHECK:        %[[BCASTED_SHAPE012:.*]] = shape.broadcast %[[SHAPE2]], %[[BCASTED_SHAPE01]]
-  // CHECK:        %[[BCASTED_ARG2:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG2]], %[[BCASTED_SHAPE012]]) {broadcast_dimensions = dense<[0, 1, 2]>
-  // CHECK:        %[[BCASTED_ARG0:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[BCASTED_SHAPE012]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}
-  // CHECK:        %[[BCASTED_ARG1:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[BCASTED_SHAPE012]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}
-  // CHECK:        %[[TMP:.*]] = mhlo.subtract %[[BCASTED_ARG0]], %[[BCASTED_ARG1]]
-  // CHECK:        %[[RESULT:.*]] = mhlo.subtract %[[BCASTED_ARG2]], %[[TMP]]
-  // CHECK:        shape.assuming_yield %[[RESULT]]
-  // CHECK:      return %[[ASSUMING_RESULT]]
-  %0 = shape.shape_of %arg0 : tensor<?x32xf16> -> tensor<2xindex>
-  %1 = shape.shape_of %arg1 : tensor<?x32xf16> -> tensor<2xindex>
-  %2 = shape.cstr_broadcastable %0, %1 : tensor<2xindex>, tensor<2xindex>
-  %3 = shape.assuming %2 -> (tensor<?x32xf16>) {
-    %8 = shape.shape_of %arg0 : tensor<?x32xf16> -> tensor<2xindex>
-    %9 = shape.shape_of %arg1 : tensor<?x32xf16> -> tensor<2xindex>
-    %10 = shape.broadcast %8, %9 : tensor<2xindex>, tensor<2xindex> -> tensor<?xindex>
-    %11 = tensor.cast %10 : tensor<?xindex> to tensor<2xindex>
-    %12 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %11) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x32xf16>, tensor<2xindex>) -> tensor<?x32xf16>
-    %13 = "mhlo.dynamic_broadcast_in_dim"(%arg1, %11) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x32xf16>, tensor<2xindex>) -> tensor<?x32xf16>
-    %14 = mhlo.subtract %12, %13 : tensor<?x32xf16>
-    shape.assuming_yield %14 : tensor<?x32xf16>
-  }
-  %4 = shape.shape_of %arg2 : tensor<?x?x32xf16> -> tensor<3xindex>
-  %5 = shape.shape_of %3 : tensor<?x32xf16> -> tensor<2xindex>
-  %6 = shape.cstr_broadcastable %4, %5 : tensor<3xindex>, tensor<2xindex>
-  %7 = shape.assuming %6 -> (tensor<?x?x32xf16>) {
-    %8 = shape.shape_of %arg2 : tensor<?x?x32xf16> -> tensor<3xindex>
-    %9 = shape.shape_of %3 : tensor<?x32xf16> -> tensor<2xindex>
-    %10 = shape.broadcast %8, %9 : tensor<3xindex>, tensor<2xindex> -> tensor<?xindex>
-    %11 = tensor.cast %10 : tensor<?xindex> to tensor<3xindex>
-    %12 = "mhlo.dynamic_broadcast_in_dim"(%arg2, %11) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<?x?x32xf16>, tensor<3xindex>) -> tensor<?x?x32xf16>
-    %13 = "mhlo.dynamic_broadcast_in_dim"(%3, %11) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<?x32xf16>, tensor<3xindex>) -> tensor<?x?x32xf16>
-    %14 = mhlo.subtract %12, %13 : tensor<?x?x32xf16>
-    shape.assuming_yield %14 : tensor<?x?x32xf16>
-  }
-  return %7 : tensor<?x?x32xf16>
-}
-
-// -----
-
 // CHECK-LABEL: @redundant_cstr_broadcastable
 // CHECK-SAME: (%[[ARG0:.*]]: tensor<?xindex>, %[[ARG1:.*]]: tensor<?xindex>)
 func @redundant_cstr_broadcastable(%arg0: tensor<?xindex>,
@@ -468,28 +335,6 @@
 
 // -----
 
-// CHECK-LABEL: @bcast_select_scalar_pred
-// CHECK-SAME:  %[[PRED:.*]]: tensor<i1>, %[[LHS:.*]]: tensor<?x?xf32>, %[[RHS:.*]]: tensor<?x?xf32>, %[[SHAPE:.*]]: tensor<2xindex>
-func @bcast_select_scalar_pred(%pred : tensor<i1>, %arg0 : tensor<?x?xf32>,
-    %arg1 : tensor<?x?xf32>, %shape : tensor<2xindex>) -> tensor<?x?xf32> {
-  // CHECK:      %[[BCASTED_PRED:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[PRED]], %[[SHAPE]])
-  // CHECK-SAME:   broadcast_dimensions = dense<>
-  // CHECK:      %[[BCASTED_LHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[LHS]], %[[SHAPE]])
-  // CHECK-SAME:   broadcast_dimensions = dense<[0, 1]>
-  // CHECK:      %[[BCASTED_RHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[RHS]], %[[SHAPE]])
-  // CHECK-SAME:   broadcast_dimensions = dense<[0, 1]>
-  // CHECK:      %[[RESULT:.*]] = "mhlo.select"(%[[BCASTED_PRED]], %[[BCASTED_LHS]], %[[BCASTED_RHS]])
-  // CHECK:      return %[[RESULT]]
-  %0 = "mhlo.select"(%pred, %arg0, %arg1)
-      : (tensor<i1>, tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
-  %1 = "mhlo.dynamic_broadcast_in_dim"(%0, %shape)
-      { broadcast_dimensions = dense<[0, 1]> : tensor<2xi64> }
-      : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
-  return %1 : tensor<?x?xf32>
-}
-
-// -----
-
 // CHECK-LABEL: @move_down_into_assuming
 // CHECK-SAME:  (%[[ARG:.*]]: tensor<?x32xi16>, %[[W:.*]]: !shape.witness)
 func @move_down_into_assuming(%arg0: tensor<?x32xi16>, %w: !shape.witness) -> tensor<?x32xf16> {
diff --git a/tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_pipeline.cc b/tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_pipeline.cc
index ad76f81..c386987 100644
--- a/tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_pipeline.cc
+++ b/tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_pipeline.cc
@@ -127,8 +127,7 @@
   // Add the broadcast propagation pass first, because it can help to avoid
   // exponential complexity from the EarlyBroadcastInDimOp pattern which is used
   // in the merge assuming ops pass further down.
-  pm.addNestedPass<FuncOp>(
-      mlir::mhlo::createMergeAssumingOpsPass(/*propagate_broadcasts=*/false));
+  pm.addNestedPass<FuncOp>(mlir::mhlo::createMergeAssumingOpsPass());
   pm.addNestedPass<FuncOp>(mlir::mhlo::createBroadcastPropagationPass());
   pm.addPass(mlir::createCSEPass());
   pm.addPass(mlir::createCanonicalizerPass());
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc b/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc
index 79228c0..b741949 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc
@@ -246,8 +246,7 @@
   pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
   pm.addNestedPass<mlir::FuncOp>(
       mlir::kernel_gen::transforms::CreateShapeSimplification());
-  pm.addNestedPass<mlir::FuncOp>(
-      mlir::mhlo::createMergeAssumingOpsPass(/*propagate_broadcasts=*/false));
+  pm.addNestedPass<mlir::FuncOp>(mlir::mhlo::createMergeAssumingOpsPass());
   pm.addNestedPass<mlir::FuncOp>(mlir::mhlo::createBroadcastPropagationPass());
   pm.addNestedPass<mlir::FuncOp>(mlir::createCSEPass());