Fix metadata emission for dynamically shaped tensors during TPU rewrite
Shapes of dynamically shaped tensors are now also emitted in the metadata for
each argument, as are those of unranked tensors.
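
As a rough sketch of the resulting metadata (illustrative only; the variable
names below are not part of this change), a ranked tensor<?x?x3xi32> argument
now gets a shape proto with -1 for each dynamic dimension, while an unranked
tensor<*xi32> argument gets unknown_rank set:

    #include "tensorflow/core/framework/tensor_shape.pb.h"

    // Ranked tensor<?x?x3xi32>: dynamic dims are recorded as size -1.
    tensorflow::TensorShapeProto ranked;
    ranked.add_dim()->set_size(-1);
    ranked.add_dim()->set_size(-1);
    ranked.add_dim()->set_size(3);

    // Unranked tensor<*xi32>: only unknown_rank is set, no dims.
    tensorflow::TensorShapeProto unranked;
    unranked.set_unknown_rank(true);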
PiperOrigin-RevId: 271511200
diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD
index ad7eef5..9eedbdc 100644
--- a/tensorflow/compiler/mlir/tensorflow/BUILD
+++ b/tensorflow/compiler/mlir/tensorflow/BUILD
@@ -196,6 +196,7 @@
"//tensorflow/compiler/mlir/lite:validators",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla:xla_proto",
+ "//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_proto_cc",
"//tensorflow/core/platform:logging",
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir
index aed835f..7a9f9e1 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir
@@ -177,15 +177,14 @@
// -----
-// Tests argument with unranked shape. No shape should be populated in the
+// Tests argument with unranked shape. A shape with unknown rank should be populated in the
// metadata for associated argument.
// CHECK-LABEL: func @unranked_shape_arg
func @unranked_shape_arg(%arg0: tensor<*xi32>) -> tensor<*xi32> {
%0 = "tf_device.launch_func"(%arg0) {_tpu_replicate = "cluster0", device = "tpu0", func = @_func, num_replicas = 1, num_cores_per_replica = 1, step_marker_location = "", padding_map = []} : (tensor<*xi32>) -> tensor<*xi32>
// CHECK: metadata
- // CHECK-SAME: args
- // CHECK-NOT: shape
+ // CHECK-SAME: shape {\0A unknown_rank: true
return %0: tensor<*xi32>
}
@@ -195,16 +194,14 @@
// -----
-// Tests argument with partial shape. No shape should be populated in the
-// metadata for associated argument.
+// Tests argument with partial shape. Dynamic dimensions should be populated with size -1 in the metadata for the associated argument.
// CHECK-LABEL: func @partial_shape_arg
func @partial_shape_arg(%arg0: tensor<?x?x3xi32>) -> tensor<?x?x3xi32> {
%0 = "tf_device.launch_func"(%arg0) {_tpu_replicate = "cluster0", device = "tpu0", func = @_func, num_replicas = 1, num_cores_per_replica = 1, step_marker_location = "", padding_map = []} : (tensor<?x?x3xi32>) -> tensor<?x?x3xi32>
// CHECK: metadata
// CHECK-SAME: args
- // CHECK-NOT: shape
-
+ // CHECK-SAME: shape {\0A dim {\0A size: -1\0A }\0A dim {\0A size: -1\0A }\0A dim {\0A size: 3\0A }\0A }
return %0: tensor<?x?x3xi32>
}
func @_func(%arg0: tensor<?x?x3xi32>) -> tensor<?x?x3xi32> {
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc
index 0095a69..3e13e84 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc
@@ -39,6 +39,7 @@
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
#include "tensorflow/compiler/xla/xla.pb.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/status.h"
@@ -245,13 +246,14 @@
else
arg->set_kind(tensorflow::tpu::TPUCompileMetadataProto::Arg::PARAMETER);
- // Unranked and partial shapes are not populated.
+ // Populate argument shapes.
+ *arg->mutable_shape() = tensorflow::TensorShapeProto();
if (auto ranked_tensor_type = operand_type.dyn_cast<RankedTensorType>()) {
- if (ranked_tensor_type.hasStaticShape()) {
- tensorflow::TensorShapeProto shape_proto;
- ConvertToTensorShapeProto(ranked_tensor_type.getShape(), &shape_proto);
- *arg->mutable_shape() = std::move(shape_proto);
- }
+ tensorflow::TensorShapeProto shape_proto;
+ ConvertToTensorShapeProto(ranked_tensor_type.getShape(), &shape_proto);
+ *arg->mutable_shape() = std::move(shape_proto);
+ } else {
+ arg->mutable_shape()->set_unknown_rank(true);
}
// TODO(lyandy): Determine proper sharding of args once topology and devices
@@ -307,7 +309,9 @@
for (auto operand_and_idx : llvm::enumerate(launch_func.getOperands())) {
// Skip adding shape op for operands that have static shapes.
- if (metadata.args(operand_and_idx.index()).has_shape()) continue;
+ tensorflow::PartialTensorShape shape(
+ metadata.args(operand_and_idx.index()).shape());
+ if (shape.IsFullyDefined()) continue;
auto shape_op = builder->create<TF::ShapeOp>(
launch_func.getLoc(),
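
Since every argument now carries a shape proto, has_shape() no longer
distinguishes static from dynamic shapes; the shape-op insertion logic instead
rebuilds a PartialTensorShape and checks IsFullyDefined(). A minimal standalone
sketch of that check (the helper name is hypothetical, not part of this
change):

    #include "tensorflow/core/framework/partial_tensor_shape.h"
    #include "tensorflow/core/framework/tensor_shape.pb.h"

    // Returns true when a runtime tf.Shape op is still needed, i.e. the
    // metadata shape has unknown rank or any dimension of size -1.
    bool NeedsRuntimeShape(const tensorflow::TensorShapeProto& proto) {
      tensorflow::PartialTensorShape shape(proto);
      return !shape.IsFullyDefined();
    }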