[tfrt:tf] Update shape optimization pass to handle const shapes

- fix a bug in shape constraints optimization
- add an option to optimize only constraints, because this enables more efficient movement of mhlo broadcasts

PiperOrigin-RevId: 392166188
Change-Id: I01a653e316c3a05fed094fb47f5c33b31b160299
diff --git a/tensorflow/compiler/mlir/tfrt/jit/tf_cpurt_pipeline.cc b/tensorflow/compiler/mlir/tfrt/jit/tf_cpurt_pipeline.cc
index 491bb6c..16b2a29 100644
--- a/tensorflow/compiler/mlir/tfrt/jit/tf_cpurt_pipeline.cc
+++ b/tensorflow/compiler/mlir/tfrt/jit/tf_cpurt_pipeline.cc
@@ -74,7 +74,8 @@
   // Resolve all shape constraints (e.g. broadcast constraints that can be
   // proved statically and changed to const witness) early to allow more
   // efficient broadcast operations moving.
-  pm.addNestedPass<mlir::FuncOp>(CreateSymbolicShapeOptimizationPass());
+  pm.addNestedPass<mlir::FuncOp>(
+      CreateSymbolicShapeOptimizationPass(/*constraints_only=*/true));
 
   // Move up broadcasting operations to allow for more fusion opportunities.
   pm.addNestedPass<mlir::FuncOp>(mlir::mhlo::createBroadcastPropagationPass());
diff --git a/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_cpurt_passes.h b/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_cpurt_passes.h
index 805493f..8cdd15e 100644
--- a/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_cpurt_passes.h
+++ b/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_cpurt_passes.h
@@ -45,7 +45,8 @@
 std::unique_ptr<mlir::FunctionPass> CreateFissionPass();
 
 // Pass to optimize broadcasts based on the symbolic shape constraints.
-std::unique_ptr<mlir::FunctionPass> CreateSymbolicShapeOptimizationPass();
+std::unique_ptr<mlir::FunctionPass> CreateSymbolicShapeOptimizationPass(
+    bool constraints_only = false);
 
 // Creates `tf_device.cluster` operations according to the TF CPURT clustering
 // policy.
diff --git a/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_cpurt_passes.td b/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_cpurt_passes.td
index 0f2f4cf..652f532 100644
--- a/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_cpurt_passes.td
+++ b/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_cpurt_passes.td
@@ -86,18 +86,27 @@
   ];
 }
 
-def SymbolicShapeOptimization : FunctionPass<"tf-cpurt-symbolic-shape-optimization"> {
+def SymbolicShapeOptimization
+    : FunctionPass<"tf-cpurt-symbolic-shape-optimization"> {
   let summary = "Optimizes broadcasts based on the symbolic shapes";
   let constructor = "tensorflow::CreateSymbolicShapeOptimizationPass()";
   let description = [{
-    A simple pass that rewrites mhlo.broadcast_in_dim operations with
-    linalg.generic broadcasts using the symbolic shape attributes defined
-    on the entrypoint function arguments.
+    A simple pass that replaces shape constraints with const witnesses and
+    rewrites mhlo.broadcast_in_dim operations with linalg.generic broadcasts
+    using the symbolic shape attributes defined on the entrypoint function
+    arguments.
   }];
   let dependentDialects = [
     "mlir::mhlo::MhloDialect",
     "mlir::linalg::LinalgDialect"
   ];
+
+  let options = [
+   Option<"optimize_only_constraints", "optimize-only-constraints",
+          "bool", /*default=*/"false",
+          "Optimize only shape constraints and do not touch broadcasts.">,
+
+  ];
 }
 
 def Clustering : FunctionPass<"tf-cpurt-clustering"> {
diff --git a/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_cpurt_symbolic_shape_optimization.cc b/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_cpurt_symbolic_shape_optimization.cc
index ca46399..0bf2334 100644
--- a/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_cpurt_symbolic_shape_optimization.cc
+++ b/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_cpurt_symbolic_shape_optimization.cc
@@ -47,6 +47,7 @@
 using mlir::AffineExpr;
 using mlir::AffineMap;
 using mlir::ConstantIndexOp;
+using mlir::ConstantOp;
 using mlir::DenseIntElementsAttr;
 using mlir::dyn_cast;
 using mlir::dyn_cast_or_null;
@@ -105,6 +106,12 @@
     Operation* defined_by_op = operand.getDefiningOp();
     if (!defined_by_op) return failure();
 
+    // Check if the shape is a constant.
+    if (auto const_shape = dyn_cast<shape::ConstShapeOp>(defined_by_op)) {
+      bcasted_shapes.emplace_back(const_shape.shape().getValues<int64_t>());
+      continue;
+    }
+
     // Check if the shape is a result of shape.shape_of operation.
     if (auto shape_of = dyn_cast<shape::ShapeOfOp>(defined_by_op)) {
       if (auto shape = GetSymbolicShape(shape_of.arg(), symbolic_shapes)) {
@@ -217,6 +224,17 @@
   if (failed(GetSymbolicShapes(op, symbolic_shapes_, bcasted_shapes)))
     return failure();
 
+  // Find the maximum rank of the operands.
+  size_t rank = 0;
+  for (const SymbolicShape& bcasted_shape : bcasted_shapes)
+    rank = std::max(rank, bcasted_shape.size());
+
+  // Prepend `1` to all shapes to match the maximum rank.
+  for (size_t i = 0; i < bcasted_shapes.size(); ++i) {
+    bcasted_shapes[i].insert(bcasted_shapes[i].begin(),
+                             rank - bcasted_shapes[i].size(), 1);
+  }
+
   // Pick the first shape as the initialization value for the output shape, and
   // check if the broadcast can be statically proven to be successful.
   SymbolicShape output_shape = bcasted_shapes[0];
@@ -327,6 +345,12 @@
         if (bcasted_shapes[i][d] == dim) {
           Operation* operand_src = bcast.getOperand(i).getDefiningOp();
 
+          // Shape defined by the shape.const_shape operation.
+          if (auto shape = dyn_cast_or_null<shape::ConstShapeOp>(operand_src)) {
+            return rewriter.create<ConstantOp>(
+                loc, shape.shape().getValue({static_cast<unsigned>(dim)}));
+          }
+
           // Shape defined by the shape.shape_of operation.
           if (auto shape_of = dyn_cast_or_null<shape::ShapeOfOp>(operand_src)) {
             return rewriter.create<tensor::DimOp>(loc, shape_of.arg(),
@@ -434,6 +458,12 @@
 
 struct SymbolicShapeOptimizationPass
     : public SymbolicShapeOptimizationBase<SymbolicShapeOptimizationPass> {
+  SymbolicShapeOptimizationPass() = default;
+
+  explicit SymbolicShapeOptimizationPass(bool constraints_only) {
+    this->optimize_only_constraints = constraints_only;
+  }
+
   void runOnFunction() override {
     FuncOp func = getFunction();
 
@@ -444,10 +474,12 @@
     MLIRContext* ctx = &getContext();
     mlir::RewritePatternSet patterns(ctx);
 
-    // Rewrite broadcasts and constraints based on the symbolic shapes.
-    patterns
-        .insert<CstrBroadcastableOpLowering, DynamicBroadcastInDimOpLowering>(
-            ctx, symbolic_shapes);
+    // Rewrite constraints based on the symbolic shapes.
+    patterns.insert<CstrBroadcastableOpLowering>(ctx, symbolic_shapes);
+
+    // Rewrite broadcasts based on the symbolic shapes if enabled.
+    if (!optimize_only_constraints)
+      patterns.insert<DynamicBroadcastInDimOpLowering>(ctx, symbolic_shapes);
 
     // Add shape dialect canonicalization patterns to fold shape operations
     // after constraints are replaced with constant witness.
@@ -463,8 +495,9 @@
 
 }  // namespace
 
-std::unique_ptr<FunctionPass> CreateSymbolicShapeOptimizationPass() {
-  return std::make_unique<SymbolicShapeOptimizationPass>();
+std::unique_ptr<FunctionPass> CreateSymbolicShapeOptimizationPass(
+    bool constraints_only) {
+  return std::make_unique<SymbolicShapeOptimizationPass>(constraints_only);
 }
 
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/mlir/tfrt/tests/jit/symbolic_shape_optimization.mlir b/tensorflow/compiler/mlir/tfrt/tests/jit/symbolic_shape_optimization.mlir
index 778a23f..94709ce 100644
--- a/tensorflow/compiler/mlir/tfrt/tests/jit/symbolic_shape_optimization.mlir
+++ b/tensorflow/compiler/mlir/tfrt/tests/jit/symbolic_shape_optimization.mlir
@@ -15,7 +15,6 @@
   return %2: !shape.witness
 }
 
-
 // -----
 
 // CHECK-LABEL: @optimize_1dx1d_constraint_with_static_shape
@@ -33,6 +32,21 @@
 
 // -----
 
+// CHECK-LABEL: @optimize_1dx1d_constraint_with_const_shape
+func @optimize_1dx1d_constraint_with_const_shape(
+  %arg0: tensor<512xf32>,
+  %arg1: tensor<?x512xf32>
+    {cpurt.symbolic_shape = dense<[-2,512]> : tensor<2xi64>}
+) -> !shape.witness {
+  %0 = shape.const_shape [512] : tensor<1xindex>
+  %1 = shape.shape_of %arg1 : tensor<?x512xf32> -> tensor<2xindex>
+  // CHECK: shape.const_witness true
+  %2 = shape.cstr_broadcastable %0, %1 : tensor<1xindex>, tensor<2xindex>
+  return %2: !shape.witness
+}
+
+// -----
+
 // CHECK: #[[MAP:.*]] = affine_map<(d0) -> (d0)>
 
 // CHECK:       @optimize_1dx1d_bcast(
@@ -66,6 +80,39 @@
 
 // -----
 
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d1)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+
+// CHECK:       @optimize_1dx2d_bcast_const_shape(
+// CHECK-SAME:    %[[ARG0:[a-z0-9]+]]: tensor<512xf32>
+// CHECK-SAME:    %[[ARG1:[a-z0-9]+]]: tensor<?x512xf32>
+func @optimize_1dx2d_bcast_const_shape(
+  %arg0: tensor<512xf32>,
+  %arg1: tensor<?x512xf32>
+    {cpurt.symbolic_shape = dense<[-2, 512]> : tensor<2xi64>}
+) -> tensor<?x512xf32> {
+  %0 = shape.const_shape [512] : tensor<1xindex>
+  %1 = shape.shape_of %arg1 : tensor<?x512xf32> -> tensor<2xindex>
+  %2 = shape.broadcast %0, %1 : tensor<1xindex>, tensor<2xindex>
+                             -> tensor<2xindex>
+
+  // CHECK:      %[[C0:.*]] = constant 0 : index
+  // CHECK:      %[[D0:.*]] = tensor.dim %[[ARG1]], %[[C0]]
+  // CHECK:      %[[OUT:.*]] = linalg.init_tensor [%[[D0]], 512]
+  // CHECK:      %[[RET:.*]] = linalg.generic
+  // CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
+  // CHECK-SAME: iterator_types = ["parallel", "parallel"]
+  // CHECK-SAME: ins(%[[ARG0]] : tensor<512xf32>)
+  // CHECK-SAME: outs(%[[OUT]] : tensor<?x512xf32>)
+  %3 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %2)
+         {broadcast_dimensions = dense<[1]> : tensor<1xi64>}
+       : (tensor<512xf32>, tensor<2xindex>) -> tensor<?x512xf32>
+
+  return %3: tensor<?x512xf32>
+}
+
+// -----
+
 // CHECK: #[[MAP:.*]] = affine_map<(d0) -> (d0)>
 
 // CHECK:       @optimize_1dx1dx1d_bcast(
diff --git a/tensorflow/compiler/mlir/tfrt/tests/jit/tf_cpurt_pipeline.mlir b/tensorflow/compiler/mlir/tfrt/tests/jit/tf_cpurt_pipeline.mlir
index a283fdb..6f305d7 100644
--- a/tensorflow/compiler/mlir/tfrt/tests/jit/tf_cpurt_pipeline.mlir
+++ b/tensorflow/compiler/mlir/tfrt/tests/jit/tf_cpurt_pipeline.mlir
@@ -109,6 +109,28 @@
 
 // -----
 
+// CHECK: add_vec_tensor_tensor
+func @add_vec_tensor_tensor(
+  %arg0: tensor<512xf32>,
+  %arg1: tensor<1x?x512xf32>
+    {cpurt.symbolic_shape = dense<[1, -2, 512]> : tensor<3xi64>},
+  %arg2: tensor<1x?x512xf32>
+    {cpurt.symbolic_shape = dense<[1, -2, 512]> : tensor<3xi64>}
+) -> tensor<1x?x512xf32> {
+  // CHECK-NOT: memref.reinterpret_cast
+  // CHECK: linalg.generic
+  // CHECK:   addf
+  // CHECK:   addf
+  // CHECK-NOT: linalg.generic
+  %0 = "tf.AddV2"(%arg0, %arg1)
+        : (tensor<512xf32>, tensor<1x?x512xf32>) -> tensor<1x?x512xf32>
+  %1 = "tf.AddV2"(%arg2, %0)
+        : (tensor<1x?x512xf32>, tensor<1x?x512xf32>) -> tensor<1x?x512xf32>
+  return %1 : tensor<1x?x512xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @tf_binary_with_bcast
 func @tf_binary_with_bcast(%arg0: tensor<?x1xf32>,
                            %arg1: tensor<?x4xf32>) -> tensor<?x4xf32> {