[MLIR:TF] Fold no-op BroadcastTo operations

PiperOrigin-RevId: 327184432
Change-Id: If324088ca9f7b09c776d5f524476b2c0a0085034
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
index 00e9fdd..8946faf 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
@@ -1353,6 +1353,7 @@
   let verifier = [{
     return Verify(*this);
   }];
+  let hasFolder = 1;
 }
 
 def TF_BucketizeOp : TF_Op<"Bucketize", [NoSideEffect, SameOperandsAndResultShape]> {
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc
index f3dfc15..bc38e67 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc
@@ -440,6 +440,19 @@
   return success();
 }
 
+OpFoldResult BroadcastToOp::fold(ArrayRef<Attribute> operands) {
+  // A broadcast is a no-op when its result type is a fully static shaped type
+  // that is identical to the type of the input. Only in that case is the op
+  // folded to its input; in every other case no fold is performed.
+  auto shaped_ty = getType().dyn_cast<ShapedType>();
+  if (!shaped_ty || !shaped_ty.hasStaticShape()) return {};
+
+  Value source = input();
+  if (shaped_ty != source.getType()) return {};
+
+  return source;
+}
+
 //===----------------------------------------------------------------------===//
 // CaseOp
 //===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc
index 887473e..737665d 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc
@@ -707,7 +707,6 @@
 
   // Fold reshape if operand and result types are the same and all dimensions
   // are statically known (no-op reshape).
-  // TODO(ezhulenev): Add the same folding for BroadcastToOp.
   auto result_ty = getType().dyn_cast<ShapedType>();
   if (result_ty && result_ty.hasStaticShape() &&
       result_ty == tensor.getType()) {
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir
index af57794..3bfc388 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir
@@ -444,6 +444,14 @@
   return %0 : tensor<2x4xf32>
 }
 
+// CHECK-LABEL: func @testBroadcastToNoOp
+func @testBroadcastToNoOp(%arg0: tensor<2x4xf32>, %arg1: tensor<2xi32>) -> tensor<2x4xf32> {
+  %broadcast = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<2x4xf32>, tensor<2xi32>) -> tensor<2x4xf32>
+  // Input and result types are the same static tensor<2x4xf32>, so the op folds to %arg0.
+  // CHECK: return %arg0
+  return %broadcast : tensor<2x4xf32>
+}
+
 // CHECK-LABEL: func @testPackShapeComputation
 func @testPackShapeComputation(%arg0: tensor<?x1xf32>, %arg1: tensor<?x1x2xf32>, %arg2: tensor<*xf32>) -> (tensor<2xi32>, tensor<3xi32>, tensor<3xi32>,  tensor<3xi32>, tensor<3xi32>, tensor<*xi32>) {
   // Test dimensions sizes.
diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-BatchMatMulV2.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-BatchMatMulV2.mlir
index 69eaeeb..cffb150 100644
--- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-BatchMatMulV2.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-BatchMatMulV2.mlir
@@ -17,9 +17,7 @@
 // CHECK:           [[LHSSHAPEEXTENTS:%.*]] = shape.to_extent_tensor [[LHSBCASTSHAPE]]
 // CHECK:           [[LHSBCAST:%.*]] = "mhlo.dynamic_broadcast_in_dim"([[LHS]], [[LHSSHAPEEXTENTS]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<1x4x2xf32>, tensor<3xindex>) -> tensor<3x4x2xf32>
 // CHECK:           [[RHSBCASTSHAPE:%.*]] = shape.concat [[BCASTHEAD]], [[RHSTAIL]]
-// CHECK:           [[RHSSHAPEEXTENTS:%.*]] = shape.to_extent_tensor [[RHSBCASTSHAPE]]
-// CHECK:           [[RHSBCAST:%.*]] = "mhlo.dynamic_broadcast_in_dim"([[RHS]], [[RHSSHAPEEXTENTS]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<3x2x4xf32>, tensor<3xindex>) -> tensor<3x2x4xf32>
-// CHECK:           [[RESULT:%.*]] = "mhlo.dot_general"([[LHSBCAST]], [[RHSBCAST]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<0> : tensor<1xi64>, lhs_contracting_dimensions = dense<2> : tensor<1xi64>, rhs_batching_dimensions = dense<0> : tensor<1xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}} : (tensor<3x4x2xf32>, tensor<3x2x4xf32>) -> tensor<3x4x4xf32>
+// CHECK:           [[RESULT:%.*]] = "mhlo.dot_general"([[LHSBCAST]], [[RHS]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<0> : tensor<1xi64>, lhs_contracting_dimensions = dense<2> : tensor<1xi64>, rhs_batching_dimensions = dense<0> : tensor<1xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}} : (tensor<3x4x2xf32>, tensor<3x2x4xf32>) -> tensor<3x4x4xf32>
 // CHECK:           return [[RESULT]] : tensor<3x4x4xf32>
 // CHECK:         }
 
@@ -29,7 +27,6 @@
 
 func @batchmatmulv2_lhs_batch(%arg0: tensor<3x4x2xf32>, %arg1: tensor<2x4xf32>) -> tensor<3x4x4xf32> {
 // CHECK-LABEL:   func @batchmatmulv2_lhs_batch
-// CHECK:           "mhlo.dynamic_broadcast_in_dim"({{.*}}, {{.*}}) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>}
 // CHECK:           "mhlo.dynamic_broadcast_in_dim"({{.*}}, {{.*}}) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}
 // CHECK:           "mhlo.dot_general"({{.*}}, {{.*}}) {dot_dimension_numbers = {
 // CHECK-SAME:        lhs_batching_dimensions = dense<0> : tensor<1xi64>,
@@ -43,7 +40,6 @@
 func @batchmatmulv2_rhs_batch(%arg0: tensor<4x2xf32>, %arg1: tensor<3x2x4xf32>) -> tensor<3x4x4xf32> {
 // CHECK-LABEL:   func @batchmatmulv2_rhs_batch
 // CHECK:           "mhlo.dynamic_broadcast_in_dim"({{.*}}, {{.*}}) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}
-// CHECK:           "mhlo.dynamic_broadcast_in_dim"({{.*}}, {{.*}}) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>}
 // CHECK:           "mhlo.dot_general"({{.*}}, {{.*}}) {dot_dimension_numbers = {
 // CHECK-SAME:        lhs_batching_dimensions = dense<0> : tensor<1xi64>,
 // CHECK-SAME:        lhs_contracting_dimensions = dense<2> : tensor<1xi64>,