[MLIR][HLO] Add cast when needed to simplify `shape.bcast` ops

PiperOrigin-RevId: 437832615
diff --git a/tensorflow/compiler/mlir/hlo/lib/Transforms/symbolic_shape_optimization.cc b/tensorflow/compiler/mlir/hlo/lib/Transforms/symbolic_shape_optimization.cc
index 03fd24a..008dc2a 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Transforms/symbolic_shape_optimization.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Transforms/symbolic_shape_optimization.cc
@@ -118,8 +118,19 @@
           return rewriter.create<tensor::ExtractOp>(loc, operand, operand_dim)
               .getResult();
         }));
-    rewriter.replaceOpWithNewOp<tensor::FromElementsOp>(
-        op, op->getResultTypes().front(), elements);
+    Type index_ty = rewriter.getIndexType();
+    Type concrete_result_ty = RankedTensorType::get(
+        {static_cast<int64_t>(elements.size())}, index_ty);
+    Value result = rewriter.create<tensor::FromElementsOp>(
+        loc, concrete_result_ty, elements);
+
+    // Insert cast, if needed.
+    Type expected_ty = op.getResult().getType();
+    if (result.getType() != expected_ty) {
+      result = rewriter.create<tensor::CastOp>(loc, expected_ty, result);
+    }
+
+    rewriter.replaceOp(op, result);
     return success();
   }
 };
diff --git a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/symbolic-shape-optimization.mlir b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/symbolic-shape-optimization.mlir
index 7aa2212..e6ba7e5 100644
--- a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/symbolic-shape-optimization.mlir
+++ b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/symbolic-shape-optimization.mlir
@@ -689,3 +689,31 @@
       -> tensor<1xindex>
   func.return %2 : tensor<1xindex>
 }
+
+// -----
+
+// CHECK-LABEL: @broadcast_w_dyn_ty
+// CHECK-SAME:  %[[ARG:.*]]: tensor<1xindex>
+func.func @broadcast_w_dyn_ty(%arg0: tensor<1xindex>) -> tensor<?xindex> {
+  // CHECK: %[[C0:.*]] = arith.constant 0
+  // CHECK: %[[D0:.*]] = tensor.extract %[[ARG]][%[[C0]]]
+  // CHECK: %[[UNCAST:.*]] = tensor.from_elements %[[D0]]
+  // CHECK: %[[CAST:.*]] = tensor.cast %[[UNCAST]] : tensor<1xindex> to tensor<?xindex>
+  // CHECK: return %[[CAST]]
+  %0 = shape.broadcast %arg0, %arg0
+      : tensor<1xindex>, tensor<1xindex> -> tensor<?xindex>
+  return %0 : tensor<?xindex>
+}
+
+// -----
+
+// CHECK-LABEL: @broadcast_scalar_w_dyn_ty
+// CHECK-SAME:  %[[ARG:.*]]: tensor<0xindex>
+func.func @broadcast_scalar_w_dyn_ty(%arg0: tensor<0xindex>) -> tensor<?xindex> {
+  // CHECK: %[[UNCAST:.*]] = arith.constant dense<> : tensor<0xindex>
+  // CHECK: %[[CAST:.*]] = tensor.cast %[[UNCAST]] : tensor<0xindex> to tensor<?xindex>
+  // CHECK: return %[[CAST]]
+  %0 = shape.broadcast %arg0, %arg0
+      : tensor<0xindex>, tensor<0xindex> -> tensor<?xindex>
+  return %0 : tensor<?xindex>
+}