[Fix] fix a bug in ConvertTFStridedSlice when generating the result shape based on `new_axis_mask`.

PiperOrigin-RevId: 289561527
Change-Id: I9c9dec9249b2e126c4597bc4788f5026e080bc80
diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir
index 5793c84..a6f651b 100644
--- a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir
@@ -426,8 +426,8 @@
   // CHECK: %0 = "tf.StridedSlice"(%arg0, %cst, %cst, %cst_0) {begin_mask = 15 : i64, ellipsis_mask = 0 : i64, end_mask = 15 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<1x2x3x1xf32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor<1x2x3x1xf32>
 }
 
-// CHECK-LABEL: @PadStridedSliceNewAxisMask
-func @PadStridedSliceNewAxisMask(%arg0: tensor<2x3xf32>) -> tensor<1x2x3x1xf32> {
+// CHECK-LABEL: @PadStridedSliceNewAxisMask1
+func @PadStridedSliceNewAxisMask1(%arg0: tensor<2x3xf32>) -> tensor<1x2x3x1xf32> {
   %cst = constant dense<0> : tensor<4xi32>
   %cst_0 = constant dense<1> : tensor<4xi32>
   %0 = "tf.StridedSlice"(%arg0, %cst, %cst, %cst_0) {begin_mask = 6 : i64, ellipsis_mask = 0 : i64, end_mask = 6 : i64, new_axis_mask = 9 : i64, shrink_axis_mask = 0 : i64} : (tensor<2x3xf32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor<1x2x3x1xf32>
@@ -439,3 +439,12 @@
   // CHECK: %0 = "tf.Reshape"(%arg0, %[[cst_1]]) : (tensor<2x3xf32>, tensor<4xi32>) -> tensor<1x2x3x1xf32>
   // CHECK: %1 = "tf.StridedSlice"(%0, %cst, %cst, %cst_0) {begin_mask = 15 : i64, ellipsis_mask = 0 : i64, end_mask = 15 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<1x2x3x1xf32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor<1x2x3x1xf32>
 }
+
+// CHECK-LABEL: @PadStridedSliceNewAxisMask2
+func @PadStridedSliceNewAxisMask2(%arg0: tensor<4x64x64x1xf32>) -> tensor<1x4x64x64xf32> {
+  %cst = constant dense<0> : tensor<3xi32>
+  %cst_0 = constant dense<1> : tensor<3xi32>
+  %0 = "tf.Squeeze"(%arg0) {T = f32, _output_shapes = ["tfshape$dim { size: 4 } dim { size: 64 } dim { size: 64 }"], device = "", squeeze_dims = []} : (tensor<4x64x64x1xf32>) -> tensor<4x64x64xf32>
+  %1 = "tf.StridedSlice"(%0, %cst, %cst, %cst_0) {Index = i32, T = f32, _output_shapes = ["tfshape$dim { size: 1 } dim { size: 4 } dim { size: 64 } dim { size: 64 }"], begin_mask = 6 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 6 : i64, new_axis_mask = 1 : i64, shrink_axis_mask = 0 : i64} : (tensor<4x64x64xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x4x64x64xf32>
+  return %1 : tensor<1x4x64x64xf32>
+}
diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc
index ab4d30e..3df25292 100644
--- a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc
@@ -425,8 +425,7 @@
     // TODO(renjieliu): Consider expand the transformation for ellipsis & shrink
     // mask as well.
     TF::StridedSliceOp strided_slice_op = llvm::cast<TF::StridedSliceOp>(op);
-    const uint64_t new_axis_mask =
-        strided_slice_op.new_axis_mask().getZExtValue();
+    uint64_t new_axis_mask = strided_slice_op.new_axis_mask().getZExtValue();
     if (new_axis_mask == 0) return matchFailure();
 
     // Insert a new reshape op.
@@ -435,22 +434,18 @@
         original_input.getType().cast<RankedTensorType>();
     const ArrayRef<int64_t> &original_input_shape =
         original_input_type.getShape();
-    RankedTensorType begin_type =
-        strided_slice_op.begin().getType().cast<RankedTensorType>();
-    const int dim_size = begin_type.getShape()[0];
     SmallVector<int64_t, 4> new_shape;
-    int mask = 1;
     int index = 0;
-    for (int i = 0; i < dim_size; ++i) {
-      if (mask & new_axis_mask) {
+    while (index < original_input_shape.size() || new_axis_mask) {
+      if (new_axis_mask & 1) {
         new_shape.emplace_back(1);
       } else {
-        new_shape.emplace_back(original_input_shape[index]);
-        ++index;
+        new_shape.emplace_back(original_input_shape[index++]);
       }
-      mask = mask << 1;
+      new_axis_mask >>= 1;
     }
 
+    const int dim_size = new_shape.size();
     Location loc = strided_slice_op.getLoc();
     auto shape_type =
         RankedTensorType::get({dim_size}, rewriter.getIntegerType(32));