[MLIR:TF] Fold no-op broadcast to operations
PiperOrigin-RevId: 327184432
Change-Id: If324088ca9f7b09c776d5f524476b2c0a0085034
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
index 00e9fdd..8946faf 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
@@ -1353,6 +1353,7 @@
let verifier = [{
return Verify(*this);
}];
+ let hasFolder = 1;
}
def TF_BucketizeOp : TF_Op<"Bucketize", [NoSideEffect, SameOperandsAndResultShape]> {
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc
index f3dfc15..bc38e67 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc
@@ -440,6 +440,19 @@
return success();
}
+OpFoldResult BroadcastToOp::fold(ArrayRef<Attribute> operands) {
+ Value input = this->input();
+
+ // Fold broadcast if operand and result types are the same and all dimensions
+ // are statically known (no-op broadcast).
+ auto result_ty = getType().dyn_cast<ShapedType>();
+ if (result_ty && result_ty.hasStaticShape() && result_ty == input.getType()) {
+ return input;
+ }
+
+ return {};
+}
+
//===----------------------------------------------------------------------===//
// CaseOp
//===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc
index 887473e..737665d 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc
@@ -707,7 +707,6 @@
// Fold reshape if operand and result types are the same and all dimensions
// are statically known (no-op reshape).
- // TODO(ezhulenev): Add the same folding for BroadcastToOp.
auto result_ty = getType().dyn_cast<ShapedType>();
if (result_ty && result_ty.hasStaticShape() &&
result_ty == tensor.getType()) {
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir
index af57794..3bfc388 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir
@@ -444,6 +444,14 @@
return %0 : tensor<2x4xf32>
}
+// CHECK-LABEL: func @testBroadcastToNoOp
+func @testBroadcastToNoOp(%arg0: tensor<2x4xf32>, %arg1: tensor<2xi32>) -> tensor<2x4xf32> {
+ %0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<2x4xf32>, tensor<2xi32>) -> tensor<2x4xf32>
+
+ // CHECK: return %arg0
+ return %0 : tensor<2x4xf32>
+}
+
// CHECK-LABEL: func @testPackShapeComputation
func @testPackShapeComputation(%arg0: tensor<?x1xf32>, %arg1: tensor<?x1x2xf32>, %arg2: tensor<*xf32>) -> (tensor<2xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<*xi32>) {
// Test dimensions sizes.
diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-BatchMatMulV2.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-BatchMatMulV2.mlir
index 69eaeeb..cffb150 100644
--- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-BatchMatMulV2.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-BatchMatMulV2.mlir
@@ -17,9 +17,7 @@
// CHECK: [[LHSSHAPEEXTENTS:%.*]] = shape.to_extent_tensor [[LHSBCASTSHAPE]]
// CHECK: [[LHSBCAST:%.*]] = "mhlo.dynamic_broadcast_in_dim"([[LHS]], [[LHSSHAPEEXTENTS]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<1x4x2xf32>, tensor<3xindex>) -> tensor<3x4x2xf32>
// CHECK: [[RHSBCASTSHAPE:%.*]] = shape.concat [[BCASTHEAD]], [[RHSTAIL]]
-// CHECK: [[RHSSHAPEEXTENTS:%.*]] = shape.to_extent_tensor [[RHSBCASTSHAPE]]
-// CHECK: [[RHSBCAST:%.*]] = "mhlo.dynamic_broadcast_in_dim"([[RHS]], [[RHSSHAPEEXTENTS]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<3x2x4xf32>, tensor<3xindex>) -> tensor<3x2x4xf32>
-// CHECK: [[RESULT:%.*]] = "mhlo.dot_general"([[LHSBCAST]], [[RHSBCAST]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<0> : tensor<1xi64>, lhs_contracting_dimensions = dense<2> : tensor<1xi64>, rhs_batching_dimensions = dense<0> : tensor<1xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}} : (tensor<3x4x2xf32>, tensor<3x2x4xf32>) -> tensor<3x4x4xf32>
+// CHECK: [[RESULT:%.*]] = "mhlo.dot_general"([[LHSBCAST]], [[RHS]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<0> : tensor<1xi64>, lhs_contracting_dimensions = dense<2> : tensor<1xi64>, rhs_batching_dimensions = dense<0> : tensor<1xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}} : (tensor<3x4x2xf32>, tensor<3x2x4xf32>) -> tensor<3x4x4xf32>
// CHECK: return [[RESULT]] : tensor<3x4x4xf32>
// CHECK: }
@@ -29,7 +27,6 @@
func @batchmatmulv2_lhs_batch(%arg0: tensor<3x4x2xf32>, %arg1: tensor<2x4xf32>) -> tensor<3x4x4xf32> {
// CHECK-LABEL: func @batchmatmulv2_lhs_batch
-// CHECK: "mhlo.dynamic_broadcast_in_dim"({{.*}}, {{.*}}) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>}
// CHECK: "mhlo.dynamic_broadcast_in_dim"({{.*}}, {{.*}}) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}
// CHECK: "mhlo.dot_general"({{.*}}, {{.*}}) {dot_dimension_numbers = {
// CHECK-SAME: lhs_batching_dimensions = dense<0> : tensor<1xi64>,
@@ -43,7 +40,6 @@
func @batchmatmulv2_rhs_batch(%arg0: tensor<4x2xf32>, %arg1: tensor<3x2x4xf32>) -> tensor<3x4x4xf32> {
// CHECK-LABEL: func @batchmatmulv2_rhs_batch
// CHECK: "mhlo.dynamic_broadcast_in_dim"({{.*}}, {{.*}}) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}
-// CHECK: "mhlo.dynamic_broadcast_in_dim"({{.*}}, {{.*}}) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>}
// CHECK: "mhlo.dot_general"({{.*}}, {{.*}}) {dot_dimension_numbers = {
// CHECK-SAME: lhs_batching_dimensions = dense<0> : tensor<1xi64>,
// CHECK-SAME: lhs_contracting_dimensions = dense<2> : tensor<1xi64>,