[Fix] fix a bug in ConvertTFStridedSlice when generating the result shape based on `new_axis_mask`.
PiperOrigin-RevId: 289561527
Change-Id: I9c9dec9249b2e126c4597bc4788f5026e080bc80
diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir
index 5793c84..a6f651b 100644
--- a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir
@@ -426,8 +426,8 @@
// CHECK: %0 = "tf.StridedSlice"(%arg0, %cst, %cst, %cst_0) {begin_mask = 15 : i64, ellipsis_mask = 0 : i64, end_mask = 15 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<1x2x3x1xf32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor<1x2x3x1xf32>
}
-// CHECK-LABEL: @PadStridedSliceNewAxisMask
-func @PadStridedSliceNewAxisMask(%arg0: tensor<2x3xf32>) -> tensor<1x2x3x1xf32> {
+// CHECK-LABEL: @PadStridedSliceNewAxisMask1
+func @PadStridedSliceNewAxisMask1(%arg0: tensor<2x3xf32>) -> tensor<1x2x3x1xf32> {
%cst = constant dense<0> : tensor<4xi32>
%cst_0 = constant dense<1> : tensor<4xi32>
%0 = "tf.StridedSlice"(%arg0, %cst, %cst, %cst_0) {begin_mask = 6 : i64, ellipsis_mask = 0 : i64, end_mask = 6 : i64, new_axis_mask = 9 : i64, shrink_axis_mask = 0 : i64} : (tensor<2x3xf32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor<1x2x3x1xf32>
@@ -439,3 +439,12 @@
// CHECK: %0 = "tf.Reshape"(%arg0, %[[cst_1]]) : (tensor<2x3xf32>, tensor<4xi32>) -> tensor<1x2x3x1xf32>
// CHECK: %1 = "tf.StridedSlice"(%0, %cst, %cst, %cst_0) {begin_mask = 15 : i64, ellipsis_mask = 0 : i64, end_mask = 15 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<1x2x3x1xf32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor<1x2x3x1xf32>
}
+
+// CHECK-LABEL: @PadStridedSliceNewAxisMask2
+func @PadStridedSliceNewAxisMask2(%arg0: tensor<4x64x64x1xf32>) -> tensor<1x4x64x64xf32> {
+ %cst = constant dense<0> : tensor<3xi32>
+ %cst_0 = constant dense<1> : tensor<3xi32>
+ %0 = "tf.Squeeze"(%arg0) {T = f32, _output_shapes = ["tfshape$dim { size: 4 } dim { size: 64 } dim { size: 64 }"], device = "", squeeze_dims = []} : (tensor<4x64x64x1xf32>) -> tensor<4x64x64xf32>
+ %1 = "tf.StridedSlice"(%0, %cst, %cst, %cst_0) {Index = i32, T = f32, _output_shapes = ["tfshape$dim { size: 1 } dim { size: 4 } dim { size: 64 } dim { size: 64 }"], begin_mask = 6 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 6 : i64, new_axis_mask = 1 : i64, shrink_axis_mask = 0 : i64} : (tensor<4x64x64xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x4x64x64xf32>
+ return %1 : tensor<1x4x64x64xf32>
+}
diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc
index ab4d30e..3df25292 100644
--- a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc
@@ -425,8 +425,7 @@
// TODO(renjieliu): Consider expand the transformation for ellipsis & shrink
// mask as well.
TF::StridedSliceOp strided_slice_op = llvm::cast<TF::StridedSliceOp>(op);
- const uint64_t new_axis_mask =
- strided_slice_op.new_axis_mask().getZExtValue();
+ uint64_t new_axis_mask = strided_slice_op.new_axis_mask().getZExtValue();
if (new_axis_mask == 0) return matchFailure();
// Insert a new reshape op.
@@ -435,22 +434,18 @@
original_input.getType().cast<RankedTensorType>();
const ArrayRef<int64_t> &original_input_shape =
original_input_type.getShape();
- RankedTensorType begin_type =
- strided_slice_op.begin().getType().cast<RankedTensorType>();
- const int dim_size = begin_type.getShape()[0];
SmallVector<int64_t, 4> new_shape;
- int mask = 1;
int index = 0;
- for (int i = 0; i < dim_size; ++i) {
- if (mask & new_axis_mask) {
+ while (index < original_input_shape.size() || new_axis_mask) {
+ if (new_axis_mask & 1) {
new_shape.emplace_back(1);
} else {
- new_shape.emplace_back(original_input_shape[index]);
- ++index;
+ new_shape.emplace_back(original_input_shape[index++]);
}
- mask = mask << 1;
+ new_axis_mask >>= 1;
}
+ const int dim_size = new_shape.size();
Location loc = strided_slice_op.getLoc();
auto shape_type =
RankedTensorType::get({dim_size}, rewriter.getIntegerType(32));