Fix metadata emission for dynamically shaped tensors during TPU rewrite

Shapes of dynamically shaped tensors should also be emitted in the metadata for
each argument, as should the shapes of unranked tensors.
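
For reference, a rough sketch (not part of this change; it assumes the standard
TensorShapeProto and PartialTensorShape APIs from tensorflow/core/framework) of
the shape protos the pass now records for a partially shaped tensor<?x?x3xi32>
and an unranked tensor<*xi32>, and why the later shape-op insertion treats
neither as fully defined:

    // Rough sketch only; Example() is a hypothetical helper, not a function in the pass.
    #include "tensorflow/core/framework/tensor_shape.h"
    #include "tensorflow/core/framework/tensor_shape.pb.h"

    void Example() {
      // tensor<?x?x3xi32>: ranked with dynamic dimensions -> dim sizes of -1.
      tensorflow::TensorShapeProto partial;
      partial.add_dim()->set_size(-1);
      partial.add_dim()->set_size(-1);
      partial.add_dim()->set_size(3);

      // tensor<*xi32>: unranked -> only unknown_rank is set.
      tensorflow::TensorShapeProto unranked;
      unranked.set_unknown_rank(true);

      // Neither shape is fully defined, so a tf.Shape op is still inserted for
      // such operands rather than skipped.
      bool partial_defined = tensorflow::PartialTensorShape(partial).IsFullyDefined();    // false
      bool unranked_defined = tensorflow::PartialTensorShape(unranked).IsFullyDefined();  // false
      (void)partial_defined;
      (void)unranked_defined;
    }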

PiperOrigin-RevId: 271511200
diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD
index ad7eef5..9eedbdc 100644
--- a/tensorflow/compiler/mlir/tensorflow/BUILD
+++ b/tensorflow/compiler/mlir/tensorflow/BUILD
@@ -196,6 +196,7 @@
         "//tensorflow/compiler/mlir/lite:validators",
         "//tensorflow/compiler/xla:xla_data_proto",
         "//tensorflow/compiler/xla:xla_proto",
+        "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_proto_cc",
         "//tensorflow/core/platform:logging",
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir
index aed835f..7a9f9e1 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir
@@ -177,15 +177,14 @@
 
 // -----
 
-// Tests argument with unranked shape. No shape should be populated in the
+// Tests argument with unranked shape. Unknown rank shape should be set in the
 // metadata for associated argument.
 
 // CHECK-LABEL: func @unranked_shape_arg
 func @unranked_shape_arg(%arg0: tensor<*xi32>) -> tensor<*xi32> {
   %0 = "tf_device.launch_func"(%arg0) {_tpu_replicate = "cluster0", device = "tpu0", func = @_func, num_replicas = 1, num_cores_per_replica = 1, step_marker_location = "", padding_map = []} : (tensor<*xi32>) -> tensor<*xi32>
   // CHECK:      metadata
-  // CHECK-SAME: args
-  // CHECK-NOT:  shape
+  // CHECK-SAME: shape {\0A unknown_rank: true
 
   return %0: tensor<*xi32>
 }
@@ -195,16 +194,14 @@
 
 // -----
 
-// Tests argument with partial shape. No shape should be populated in the
-// metadata for associated argument.
+// Tests argument with partial shape.
 
 // CHECK-LABEL: func @partial_shape_arg
 func @partial_shape_arg(%arg0: tensor<?x?x3xi32>) -> tensor<?x?x3xi32> {
   %0 = "tf_device.launch_func"(%arg0) {_tpu_replicate = "cluster0", device = "tpu0", func = @_func, num_replicas = 1, num_cores_per_replica = 1, step_marker_location = "", padding_map = []} : (tensor<?x?x3xi32>) -> tensor<?x?x3xi32>
   // CHECK:      metadata
   // CHECK-SAME: args
-  // CHECK-NOT:  shape
-
+  // CHECK-SAME: shape {\0A dim {\0A size: -1\0A }\0A dim {\0A size: -1\0A }\0A dim {\0A size: 3\0A }\0A }
   return %0: tensor<?x?x3xi32>
 }
 func @_func(%arg0: tensor<?x?x3xi32>) -> tensor<?x?x3xi32> {
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc
index 0095a69..3e13e84 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc
@@ -39,6 +39,7 @@
 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
 #include "tensorflow/compiler/xla/xla.pb.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/framework/tensor_shape.pb.h"
 #include "tensorflow/core/framework/types.pb.h"
 #include "tensorflow/core/lib/core/status.h"
@@ -245,13 +246,14 @@
     else
       arg->set_kind(tensorflow::tpu::TPUCompileMetadataProto::Arg::PARAMETER);
 
-    // Unranked and partial shapes are not populated.
+    // Populate argument shapes.
+    *arg->mutable_shape() = tensorflow::TensorShapeProto();
     if (auto ranked_tensor_type = operand_type.dyn_cast<RankedTensorType>()) {
-      if (ranked_tensor_type.hasStaticShape()) {
-        tensorflow::TensorShapeProto shape_proto;
-        ConvertToTensorShapeProto(ranked_tensor_type.getShape(), &shape_proto);
-        *arg->mutable_shape() = std::move(shape_proto);
-      }
+      tensorflow::TensorShapeProto shape_proto;
+      ConvertToTensorShapeProto(ranked_tensor_type.getShape(), &shape_proto);
+      *arg->mutable_shape() = std::move(shape_proto);
+    } else {
+      arg->mutable_shape()->set_unknown_rank(true);
     }
 
     // TODO(lyandy): Determine proper sharding of args once topology and devices
@@ -307,7 +309,9 @@
 
   for (auto operand_and_idx : llvm::enumerate(launch_func.getOperands())) {
     // Skip adding shape op for operands that have static shapes.
-    if (metadata.args(operand_and_idx.index()).has_shape()) continue;
+    tensorflow::PartialTensorShape shape(
+        metadata.args(operand_and_idx.index()).shape());
+    if (shape.IsFullyDefined()) continue;
 
     auto shape_op = builder->create<TF::ShapeOp>(
         launch_func.getLoc(),