Remove EarlyBroadcastInDim pattern.
PiperOrigin-RevId: 422999899
Change-Id: Iabc6cdd2739895dcec6368c961b1d5e0ec861145
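
With the EarlyBroadcastInDimOp pattern gone, the merge-assuming-ops pass no longer takes a propagate-broadcasts option: createMergeAssumingOpsPass() is now argument-free, and broadcast propagation is handled solely by the separate BroadcastPropagationPass that the affected pipelines already schedule right after it. The following is a minimal illustrative sketch of that call-site pairing, not a file touched by this change; the helper name is hypothetical and includes/registration are abbreviated.

// Sketch only: how a pipeline pairs the two passes after this change.
// The helper function name is hypothetical; the FuncOp include and dialect
// registration are omitted for brevity.
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"

void addMergeAssumingAndBroadcastPasses(mlir::OpPassManager &pm) {
  // No propagate_broadcasts flag anymore; the pass only merges assuming ops
  // and prepares broadcast movement.
  pm.addNestedPass<mlir::FuncOp>(mlir::mhlo::createMergeAssumingOpsPass());
  // Broadcast propagation now happens exclusively in its dedicated pass.
  pm.addNestedPass<mlir::FuncOp>(mlir::mhlo::createBroadcastPropagationPass());
  pm.addPass(mlir::createCSEPass());
  pm.addPass(mlir::createCanonicalizerPass());
}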
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td
index efc1d55..a6042b1 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td
@@ -164,18 +164,11 @@
let constructor = "createBroadcastPropagationPass()";
}
-// TODO(frgossen): Limit this pass to merging of assuming regions and factor out
-// broadcast propagation into its own pass.
def MergeAssumingOpsPass : FunctionPass<"mhlo-merge-assuming-ops"> {
- let summary = "Move dynamic broadcasts up over element-wise operations and "
- "broadcast the operands rather than the result. This will eventually allow "
- "for larger fusions.";
+ let summary = "Prepare moving dynamic broadcasts up over element-wise "
+ "operations and broadcast the operands rather than the result. This will "
+ "eventually allow for larger fusions.";
let constructor = "createMergeAssumingOpsPass()";
- let options = [
- Option<"propagate_broadcasts_", "propagate-broadcasts", "bool",
- /*default=*/"true",
- "Also propagate broadcasts with a pattern-based rewrite.">,
- ];
}
def GroupReductionDimensionsPass
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h
index 1fc3066..948c669 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h
@@ -80,13 +80,10 @@
// fusions.
std::unique_ptr<FunctionPass> createBroadcastPropagationPass();
-// Move dynamic broadcasts up over element-wise operations and broadcast the
-// operands rather than the result. This will eventually allow for larger
-// fusions.
-// TODO(frgossen): Limit this pass to merging of assuming regions and factor out
-// broadcast propagation into its own pass.
-std::unique_ptr<FunctionPass> createMergeAssumingOpsPass(
- bool propagate_broadcasts = true);
+// Prepare moving dynamic broadcasts up over element-wise operations and
+// broadcast the operands rather than the result. This will eventually allow for
+// larger fusions.
+std::unique_ptr<FunctionPass> createMergeAssumingOpsPass();
// Group reduction and parallel dimensions of reduction operations and realize
// them through equivalent 1D or 2D reductions, if possible.
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h
index ca37c87..9cfd991 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h
@@ -118,12 +118,11 @@
void PopulateTrigonometricToApproximationPatterns(
MLIRContext *context, OwningRewritePatternList *patterns);
-// Populate patterns to move dynamic broadcasts up over element-wise operations
-// and broadcast the operands rather than the result. This will eventually allow
-// for larger fusions.
+// Populate patterns to prepare moving dynamic broadcasts up over element-wise
+// operations and broadcast the operands rather than the result. This will
+// eventually allow for larger fusions.
void PopulateMergeAssumingOpsPatterns(MLIRContext *context,
- OwningRewritePatternList *patterns,
- bool propagate_broadcasts);
+ OwningRewritePatternList *patterns);
// Populate patterns to group reduction and parallel dimensions of reduction
// operations and realize them through equivalent 1D or 2D reductions, if
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/merge_assuming_ops.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/merge_assuming_ops.cc
index 1e5b590..c282b46 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/merge_assuming_ops.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/merge_assuming_ops.cc
@@ -14,6 +14,7 @@
==============================================================================*/
+#include <memory>
#include <utility>
#include "llvm/ADT/STLExtras.h"
@@ -425,80 +426,8 @@
}
};
-struct EarlyBroadcastInDimOpPattern
- : public OpRewritePattern<DynamicBroadcastInDimOp> {
- using OpRewritePattern<DynamicBroadcastInDimOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(DynamicBroadcastInDimOp bcast_op,
- PatternRewriter &rewriter) const override {
- // Dynamic broadcasting only works on ranked tensors.
- auto result_ty = bcast_op.getType().dyn_cast<RankedTensorType>();
- if (!result_ty) return failure();
-
- // Find producer op.
- Operation *producer_op = bcast_op.operand().getDefiningOp();
- if (!producer_op) return failure();
- auto producer_result_ty =
- producer_op->getResultTypes().front().dyn_cast<RankedTensorType>();
- if (!producer_result_ty) return failure();
-
- // Only apply to broadcasting elementwise producers.
- bool is_broadcasting_elementwise =
- (producer_op->hasTrait<mlir::OpTrait::SameOperandsAndResultShape>() &&
- producer_op->hasTrait<mlir::OpTrait::Elementwise>()) ||
- producer_op->hasTrait<mlir::mhlo::OpTrait::BroadcastingElementwise>();
- if (!is_broadcasting_elementwise) return failure();
-
- // Materialize broadcast on operands.
- SmallVector<Value, 2> bcasted_operands;
- Location loc = bcast_op.getLoc();
- ArrayRef<int64_t> ty_shape = bcast_op.getType().getShape();
- for (Value producer_operand : producer_op->getOperands()) {
- // The broadcast only works on ranked operations.
- auto producer_operand_ty =
- producer_operand.getType().dyn_cast<RankedTensorType>();
- if (!producer_operand_ty) {
- return bcast_op.emitError()
- << "Can only move up broadcasts over ranked tensor operands.";
- }
-
- // Materialize dynamic broadcast. The operand shape is either the same as
- // the result shape and we can reuse the broadcast dimensions, or it is a
- // scalar and we can create empty broadcast dimensions.
- assert((producer_operand_ty.getRank() == 0 ||
- producer_operand_ty.getRank() == producer_result_ty.getRank()) &&
- "expect scalar or same shape");
- auto bcast_dims = producer_operand_ty.getRank() == 0
- ? rewriter.getI64TensorAttr({})
- : bcast_op.broadcast_dimensions();
- auto bcasted_operand_ty =
- RankedTensorType::get(ty_shape, producer_operand_ty.getElementType());
- bcasted_operands.push_back(rewriter.create<DynamicBroadcastInDimOp>(
- loc, bcasted_operand_ty, producer_operand,
- bcast_op.output_dimensions(), bcast_dims));
- }
-
- // Create a copy of the producer op with the new broadcasted operands.
- OperationState new_producer_op_state(
- loc, producer_op->getName().getStringRef(), bcasted_operands, result_ty,
- producer_op->getAttrs());
- Operation *new_producer_op =
- rewriter.createOperation(new_producer_op_state);
-
- // The original result of the broadcast now falls directly out of the new
- // producer op. Use it instead.
- rewriter.replaceOp(bcast_op, new_producer_op->getResults());
-
- return success();
- }
-};
-
struct MergeAssumingOpsPass
: public MergeAssumingOpsPassBase<MergeAssumingOpsPass> {
- explicit MergeAssumingOpsPass(bool propagate_broadcasts) {
- propagate_broadcasts_ = propagate_broadcasts;
- }
-
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<shape::ShapeDialect, mhlo::MhloDialect>();
}
@@ -506,8 +435,7 @@
void runOnFunction() override {
MLIRContext *ctx = &getContext();
RewritePatternSet patterns(ctx);
- mhlo::PopulateMergeAssumingOpsPatterns(ctx, &patterns,
- propagate_broadcasts_);
+ mhlo::PopulateMergeAssumingOpsPatterns(ctx, &patterns);
GreedyRewriteConfig config;
config.maxIterations = GreedyRewriteConfig::kNoIterationLimit;
if (failed(applyPatternsAndFoldGreedily(getFunction(), std::move(patterns),
@@ -520,8 +448,7 @@
} // namespace
void PopulateMergeAssumingOpsPatterns(MLIRContext *context,
- OwningRewritePatternList *patterns,
- bool propagate_broadcasts) {
+ OwningRewritePatternList *patterns) {
// clang-format off
patterns->insert<
EliminateDuplicateCstrBroadcastableOps,
@@ -537,8 +464,6 @@
MoveUpOutOfAssumingOpPattern<shape::ShapeOfOp>,
ShapeReificationPattern>(context);
// clang-format on
- if (propagate_broadcasts)
- patterns->insert<EarlyBroadcastInDimOpPattern>(context);
mhlo::DynamicBroadcastInDimOp::getCanonicalizationPatterns(*patterns,
context);
mhlo::DynamicReshapeOp::getCanonicalizationPatterns(*patterns, context);
@@ -549,9 +474,8 @@
tensor::CastOp::getCanonicalizationPatterns(*patterns, context);
}
-std::unique_ptr<FunctionPass> createMergeAssumingOpsPass(
- bool propagate_broadcasts) {
- return std::make_unique<MergeAssumingOpsPass>(propagate_broadcasts);
+std::unique_ptr<FunctionPass> createMergeAssumingOpsPass() {
+ return std::make_unique<MergeAssumingOpsPass>();
}
} // namespace mhlo
diff --git a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/merge_assuming_ops.mlir b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/merge_assuming_ops.mlir
index 56caaf7..4f990a7 100644
--- a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/merge_assuming_ops.mlir
+++ b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/merge_assuming_ops.mlir
@@ -2,11 +2,6 @@
// RUN: --mhlo-merge-assuming-ops --canonicalize --cse %s | \
// RUN: FileCheck %s
-// RUN: mlir-hlo-opt --split-input-file --allow-unregistered-dialect \
-// RUN: --mhlo-merge-assuming-ops="propagate-broadcasts=false" \
-// RUN: --canonicalize --cse %s | \
-// RUN: FileCheck --check-prefix=CHECK-NOPROP %s
-
// Shape computations shall be reified.
// CHECK-LABEL: @shape_of_unary
// CHECK-SAME: (%[[ARG:.*]]: tensor<?x32xi16>)
@@ -36,82 +31,6 @@
// -----
-// Broadcasts can be moved up over unary shape-preserving operations.
-// CHECK-LABEL: @bcast_unary
-// CHECK-SAME: (%[[ARG:.*]]: tensor<?x32xi16>, %[[OUT_DIMS:.*]]: tensor<3xindex>)
-func @bcast_unary(%arg : tensor<?x32xi16>, %out_dims : tensor<3xindex>)
- -> tensor<?x?x32xf16> {
- // CHECK: %[[BCASTED_OPERAND:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG]], %[[OUT_DIMS]])
- // CHECK-SAME: broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x32xi16>, tensor<3xindex>) -> tensor<?x?x32xi16>
- // CHECK: "mhlo.convert"(%[[BCASTED_OPERAND]]) : (tensor<?x?x32xi16>) -> tensor<?x?x32xf16>
- %0 = "mhlo.convert"(%arg) : (tensor<?x32xi16>) -> tensor<?x32xf16>
- %1 = "mhlo.dynamic_broadcast_in_dim"(%0, %out_dims) {
- broadcast_dimensions = dense<[0, 1]> : tensor<2xi64> } :
- (tensor<?x32xf16>, tensor<3xindex>) -> tensor<?x?x32xf16>
- return %1 : tensor<?x?x32xf16>
-}
-
-// CHECK-NOPROP-LABEL: @bcast_unary
-// CHECK-NOPROP-SAME: (%[[ARG:.*]]: tensor<?x32xi16>, %[[OUT_DIMS:.*]]: tensor<3xindex>)
-
-// CHECK-NOPROP: %[[TMP:.*]] = "mhlo.convert"(%[[ARG]])
-// CHECK-NOPROP: %[[RES:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[TMP]], %[[OUT_DIMS]])
-// return %[[RES]]
-
-// -----
-
-// Broadcasts can be moved up over n-ary shape-preserving operations.
-// CHECK-LABEL: @bcast_nary
-// CHECK-SAME: (%[[ARG0:.*]]: tensor<?x32xf32>, %[[ARG1:.*]]: tensor<?x32xf32>, %[[OUT_DIMS:.*]]: tensor<3xindex>)
-func @bcast_nary(%arg0 : tensor<?x32xf32>, %arg1 : tensor<?x32xf32>,
- %out_dims : tensor<3xindex>) -> tensor<?x?x32xf32> {
- // CHECK-NOT: subtract
- // CHECK: %[[BCASTED_ARG0:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[OUT_DIMS]])
- // CHECK: %[[BCASTED_ARG1:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[OUT_DIMS]])
- // CHECK: %{{.*}} = mhlo.subtract %[[BCASTED_ARG0]], %[[BCASTED_ARG1]] : tensor<?x?x32xf32>
- %0 = mhlo.subtract %arg0, %arg1 : tensor<?x32xf32>
- %1 = "mhlo.dynamic_broadcast_in_dim"(%0, %out_dims) {
- broadcast_dimensions = dense<[0, 1]> : tensor<2xi64> } :
- (tensor<?x32xf32>, tensor<3xindex>) -> tensor<?x?x32xf32>
- return %1 : tensor<?x?x32xf32>
-}
-
-// -----
-
-// Exemplary IR as it appears in the lowering with `tf.Sub` and `tf.Cast`.
-// CHECK-LABEL: @cast_sub
-// CHECK-SAME: (%[[ARG0:.*]]: tensor<?x32xi16>, %[[ARG1:.*]]: tensor<?x?x32xf16>) -> tensor<?x?x32xf16>
-func @cast_sub(%arg0: tensor<?x32xi16>, %arg1: tensor<?x?x32xf16>)
- -> tensor<?x?x32xf16> {
- // CHECK-NOT: convert
- // CHECK: %[[BCASTED_ARG1:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %{{.*}})
- // CHECK: %[[BCASTED_ARG0:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %{{.*}})
- // CHECK: %[[CONVERTED_BCASTED_ARG0:.*]] = "mhlo.convert"(%[[BCASTED_ARG0]]) : (tensor<?x?x32xi16>) -> tensor<?x?x32xf16>
- // CHECK: %{{.*}} = mhlo.subtract %[[BCASTED_ARG1]], %[[CONVERTED_BCASTED_ARG0]] : tensor<?x?x32xf16>
- %0 = "mhlo.convert"(%arg0) : (tensor<?x32xi16>) -> tensor<?x32xf16>
- %1 = shape.shape_of %arg1 : tensor<?x?x32xf16> -> tensor<?xindex>
- %2 = shape.shape_of %0 : tensor<?x32xf16> -> tensor<?xindex>
- %3 = shape.cstr_broadcastable %1, %2 : tensor<?xindex>, tensor<?xindex>
- %4 = shape.assuming %3 -> (tensor<?x?x32xf16>) {
- %5 = shape.shape_of %arg1 : tensor<?x?x32xf16> -> tensor<?xindex>
- %6 = shape.shape_of %0 : tensor<?x32xf16> -> tensor<?xindex>
- %7 = shape.broadcast %5, %6 : tensor<?xindex>, tensor<?xindex>
- -> tensor<?xindex>
- %8 = tensor.cast %7 : tensor<?xindex> to tensor<3xindex>
- %9 = "mhlo.dynamic_broadcast_in_dim"(%arg1, %8) {
- broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} :
- (tensor<?x?x32xf16>, tensor<3xindex>) -> tensor<?x?x32xf16>
- %10 = "mhlo.dynamic_broadcast_in_dim"(%0, %8) {
- broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} :
- (tensor<?x32xf16>, tensor<3xindex>) -> tensor<?x?x32xf16>
- %11 = mhlo.subtract %9, %10 : tensor<?x?x32xf16>
- shape.assuming_yield %11 : tensor<?x?x32xf16>
- }
- return %4 : tensor<?x?x32xf16>
-}
-
-// -----
-
// CHECK-LABEL: @inline_bcasted_shape_operands
// CHECK-SAME: (%[[A:.*]]: tensor<?xindex>, %[[B:.*]]: tensor<?xindex>, %[[C:.*]]: tensor<?xindex>)
func @inline_bcasted_shape_operands(%a : tensor<?xindex>, %b : tensor<?xindex>,
@@ -364,58 +283,6 @@
// -----
-// Exemplary IR as it appears in the lowering of two subsequent `tf.Sub` ops.
-// CHECK-LABEL: @sub_sub
-// CHECK-SAME: (%[[ARG0:.*]]: tensor<?x32xf16>, %[[ARG1:.*]]: tensor<?x32xf16>, %[[ARG2:.*]]: tensor<?x?x32xf16>)
-func @sub_sub(%arg0: tensor<?x32xf16>, %arg1 : tensor<?x32xf16>,
- %arg2: tensor<?x?x32xf16>) -> tensor<?x?x32xf16> {
- // CHECK-DAG: %[[SHAPE0:.*]] = shape.shape_of %[[ARG0]]
- // CHECK-DAG: %[[SHAPE1:.*]] = shape.shape_of %[[ARG1]]
- // CHECK-DAG: %[[SHAPE2:.*]] = shape.shape_of %[[ARG2]]
- // CHECK-DAG: %[[WITNESS0:.*]] = shape.cstr_broadcastable %[[SHAPE0]], %[[SHAPE1]]
- // CHECK-DAG: %[[WITNESS1:.*]] = shape.cstr_broadcastable %[[SHAPE2]], %[[SHAPE0]], %[[SHAPE1]]
- // CHECK-DAG: %[[COMBINED_WITNESS:.*]] = shape.assuming_all %[[WITNESS0]], %[[WITNESS1]]
- // CHECK: %[[ASSUMING_RESULT:.*]] = shape.assuming %[[COMBINED_WITNESS]]
- // CHECK: %[[BCASTED_SHAPE01:.*]] = shape.broadcast %[[SHAPE0]], %[[SHAPE1]]
- // CHECK: %[[BCASTED_SHAPE012:.*]] = shape.broadcast %[[SHAPE2]], %[[BCASTED_SHAPE01]]
- // CHECK: %[[BCASTED_ARG2:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG2]], %[[BCASTED_SHAPE012]]) {broadcast_dimensions = dense<[0, 1, 2]>
- // CHECK: %[[BCASTED_ARG0:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[BCASTED_SHAPE012]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}
- // CHECK: %[[BCASTED_ARG1:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[BCASTED_SHAPE012]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}
- // CHECK: %[[TMP:.*]] = mhlo.subtract %[[BCASTED_ARG0]], %[[BCASTED_ARG1]]
- // CHECK: %[[RESULT:.*]] = mhlo.subtract %[[BCASTED_ARG2]], %[[TMP]]
- // CHECK: shape.assuming_yield %[[RESULT]]
- // CHECK: return %[[ASSUMING_RESULT]]
- %0 = shape.shape_of %arg0 : tensor<?x32xf16> -> tensor<2xindex>
- %1 = shape.shape_of %arg1 : tensor<?x32xf16> -> tensor<2xindex>
- %2 = shape.cstr_broadcastable %0, %1 : tensor<2xindex>, tensor<2xindex>
- %3 = shape.assuming %2 -> (tensor<?x32xf16>) {
- %8 = shape.shape_of %arg0 : tensor<?x32xf16> -> tensor<2xindex>
- %9 = shape.shape_of %arg1 : tensor<?x32xf16> -> tensor<2xindex>
- %10 = shape.broadcast %8, %9 : tensor<2xindex>, tensor<2xindex> -> tensor<?xindex>
- %11 = tensor.cast %10 : tensor<?xindex> to tensor<2xindex>
- %12 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %11) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x32xf16>, tensor<2xindex>) -> tensor<?x32xf16>
- %13 = "mhlo.dynamic_broadcast_in_dim"(%arg1, %11) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x32xf16>, tensor<2xindex>) -> tensor<?x32xf16>
- %14 = mhlo.subtract %12, %13 : tensor<?x32xf16>
- shape.assuming_yield %14 : tensor<?x32xf16>
- }
- %4 = shape.shape_of %arg2 : tensor<?x?x32xf16> -> tensor<3xindex>
- %5 = shape.shape_of %3 : tensor<?x32xf16> -> tensor<2xindex>
- %6 = shape.cstr_broadcastable %4, %5 : tensor<3xindex>, tensor<2xindex>
- %7 = shape.assuming %6 -> (tensor<?x?x32xf16>) {
- %8 = shape.shape_of %arg2 : tensor<?x?x32xf16> -> tensor<3xindex>
- %9 = shape.shape_of %3 : tensor<?x32xf16> -> tensor<2xindex>
- %10 = shape.broadcast %8, %9 : tensor<3xindex>, tensor<2xindex> -> tensor<?xindex>
- %11 = tensor.cast %10 : tensor<?xindex> to tensor<3xindex>
- %12 = "mhlo.dynamic_broadcast_in_dim"(%arg2, %11) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<?x?x32xf16>, tensor<3xindex>) -> tensor<?x?x32xf16>
- %13 = "mhlo.dynamic_broadcast_in_dim"(%3, %11) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<?x32xf16>, tensor<3xindex>) -> tensor<?x?x32xf16>
- %14 = mhlo.subtract %12, %13 : tensor<?x?x32xf16>
- shape.assuming_yield %14 : tensor<?x?x32xf16>
- }
- return %7 : tensor<?x?x32xf16>
-}
-
-// -----
-
// CHECK-LABEL: @redundant_cstr_broadcastable
// CHECK-SAME: (%[[ARG0:.*]]: tensor<?xindex>, %[[ARG1:.*]]: tensor<?xindex>)
func @redundant_cstr_broadcastable(%arg0: tensor<?xindex>,
@@ -468,28 +335,6 @@
// -----
-// CHECK-LABEL: @bcast_select_scalar_pred
-// CHECK-SAME: %[[PRED:.*]]: tensor<i1>, %[[LHS:.*]]: tensor<?x?xf32>, %[[RHS:.*]]: tensor<?x?xf32>, %[[SHAPE:.*]]: tensor<2xindex>
-func @bcast_select_scalar_pred(%pred : tensor<i1>, %arg0 : tensor<?x?xf32>,
- %arg1 : tensor<?x?xf32>, %shape : tensor<2xindex>) -> tensor<?x?xf32> {
- // CHECK: %[[BCASTED_PRED:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[PRED]], %[[SHAPE]])
- // CHECK-SAME: broadcast_dimensions = dense<>
- // CHECK: %[[BCASTED_LHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[LHS]], %[[SHAPE]])
- // CHECK-SAME: broadcast_dimensions = dense<[0, 1]>
- // CHECK: %[[BCASTED_RHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[RHS]], %[[SHAPE]])
- // CHECK-SAME: broadcast_dimensions = dense<[0, 1]>
- // CHECK: %[[RESULT:.*]] = "mhlo.select"(%[[BCASTED_PRED]], %[[BCASTED_LHS]], %[[BCASTED_RHS]])
- // CHECK: return %[[RESULT]]
- %0 = "mhlo.select"(%pred, %arg0, %arg1)
- : (tensor<i1>, tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
- %1 = "mhlo.dynamic_broadcast_in_dim"(%0, %shape)
- { broadcast_dimensions = dense<[0, 1]> : tensor<2xi64> }
- : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
- return %1 : tensor<?x?xf32>
-}
-
-// -----
-
// CHECK-LABEL: @move_down_into_assuming
// CHECK-SAME: (%[[ARG:.*]]: tensor<?x32xi16>, %[[W:.*]]: !shape.witness)
func @move_down_into_assuming(%arg0: tensor<?x32xi16>, %w: !shape.witness) -> tensor<?x32xf16> {
diff --git a/tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_pipeline.cc b/tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_pipeline.cc
index ad76f81..c386987 100644
--- a/tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_pipeline.cc
+++ b/tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_pipeline.cc
@@ -127,8 +127,7 @@
// Add the broadcast propagation pass first, because it can help to avoid
// exponential complexity from the EarlyBroadcastInDimOp pattern which is used
// in the merge assuming ops pass further down.
- pm.addNestedPass<FuncOp>(
- mlir::mhlo::createMergeAssumingOpsPass(/*propagate_broadcasts=*/false));
+ pm.addNestedPass<FuncOp>(mlir::mhlo::createMergeAssumingOpsPass());
pm.addNestedPass<FuncOp>(mlir::mhlo::createBroadcastPropagationPass());
pm.addPass(mlir::createCSEPass());
pm.addPass(mlir::createCanonicalizerPass());
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc b/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc
index 79228c0..b741949 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc
@@ -246,8 +246,7 @@
pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
pm.addNestedPass<mlir::FuncOp>(
mlir::kernel_gen::transforms::CreateShapeSimplification());
- pm.addNestedPass<mlir::FuncOp>(
- mlir::mhlo::createMergeAssumingOpsPass(/*propagate_broadcasts=*/false));
+ pm.addNestedPass<mlir::FuncOp>(mlir::mhlo::createMergeAssumingOpsPass());
pm.addNestedPass<mlir::FuncOp>(mlir::mhlo::createBroadcastPropagationPass());
pm.addNestedPass<mlir::FuncOp>(mlir::createCSEPass());