Rewrite masks while applying ellipsis mask to tf.StridedSlice op

begin_mask and end_mask needs to be adjusted while simplifying StridedSlice op by removing ellipsis_mask. new_axis_mask is already removed before this and we don't yet support shrink_axis_mask.

PiperOrigin-RevId: 317419960
Change-Id: Ie4a5f404f95f5b909065311a54cbbed64d4ccf4b
diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir
index e95f3d0..7194309 100644
--- a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir
@@ -528,6 +528,26 @@
   return %1 : tensor<1x4x64x64xf32>
 }
 
+// CHECK-LABEL: @StridedSliceRewriteMasks
+func @StridedSliceRewriteMasks(%arg0: tensor<8x4x16x2xf32>) -> tensor<8x4x16x1xf32> {
+  %cst = "tf.Const"() {device = "", value = dense<[1, 0, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
+  %cst_0 = "tf.Const"() {device = "", value = dense<[1, 0, 0]> : tensor<3xi32>} : () -> tensor<3xi32>
+  %cst_1 = "tf.Const"() {device = "", value = dense<1> : tensor<3xi32>} : () -> tensor<3xi32>
+
+  // CHECK: %[[CST:.*]] = constant dense<[1, 0, 0, 1]> : tensor<4xi32>
+  // CHECK: %[[CST0:.*]] = constant dense<[1, 0, 0, 0]> : tensor<4xi32>
+  // CHECK: %[[CST1:.*]] = constant dense<1> : tensor<4xi32>
+  // CHECK: %[[RESULT:.*]] = "tf.StridedSlice"(%arg0, %[[CST]], %[[CST0]], %[[CST1]])
+  // CHECK-SAME: begin_mask = 7 : i64
+  // CHECK-SAME: ellipsis_mask = 0 : i64
+  // CHECK-SAME: end_mask = 14 : i64
+  // CHECK-SAME: new_axis_mask = 0 : i64
+  // CHECK-SAME: shrink_axis_mask = 0 : i64
+
+  %0 = "tf.StridedSlice"(%arg0, %cst, %cst_0, %cst_1) {begin_mask = 1 : i64, device = "", ellipsis_mask = 2 : i64, end_mask = 4 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<8x4x16x2xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<8x4x16x1xf32>
+  return %0 : tensor<8x4x16x1xf32>
+}
+
 // CHECK-LABEL: @MatrixSetDiagV2Conversion
 func @MatrixSetDiagV2Conversion(%arg0: tensor<3x3xi32>, %arg1: tensor<3xi32>) -> tensor<3x3xi32> {
   %cst = constant dense<0> : tensor<i32>
diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc
index 3310c52..6ee9884 100644
--- a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc
@@ -584,46 +584,50 @@
 
     const int ellipsis_filled_dim_size = input_size - begin_shape[0] + 1;
 
-    llvm::APInt new_begin_mask = strided_slice_op.begin_mask();
-    llvm::APInt new_end_mask = strided_slice_op.end_mask();
+    int64_t begin_mask = strided_slice_op.begin_mask().getSExtValue();
+    int64_t end_mask = strided_slice_op.end_mask().getSExtValue();
+    int64_t new_begin_mask = 0;
+    int64_t new_end_mask = 0;
 
     SmallVector<int32_t, 4> padded_begin;
     SmallVector<int32_t, 4> padded_end;
     SmallVector<int32_t, 4> padded_stride;
 
     // Before the ellipsis.
-    uint64_t index = 1;
-    int count = 0;
-
-    while (index < ellipsis_mask) {
-      padded_begin.push_back(begin_dense_elem_attr.getValue<int32_t>(count));
-      padded_end.push_back(end_dense_elem_attr.getValue<int32_t>(count));
-      padded_stride.push_back(stride_dense_elem_attr.getValue<int32_t>(count));
-      index <<= 1;
-      count++;
+    int index = 0;
+    int new_index = 0;
+    while (((ellipsis_mask >> index) & 1) == 0) {
+      padded_begin.push_back(begin_dense_elem_attr.getValue<int32_t>(index));
+      padded_end.push_back(end_dense_elem_attr.getValue<int32_t>(index));
+      padded_stride.push_back(stride_dense_elem_attr.getValue<int32_t>(index));
+      if ((begin_mask >> index) & 1) new_begin_mask |= (1 << new_index);
+      if ((end_mask >> index) & 1) new_end_mask |= (1 << new_index);
+      ++index;
+      ++new_index;
     }
 
     // Ellipsis.
-    for (int i = 0; i < ellipsis_filled_dim_size; ++i) {
-      new_begin_mask |= ellipsis_mask;
-      new_end_mask |= ellipsis_mask;
+    for (; new_index < index + ellipsis_filled_dim_size; ++new_index) {
+      new_begin_mask |= (1 << new_index);
+      new_end_mask |= (1 << new_index);
 
       // Mimic the begin/end/strides mask behavior.
       padded_begin.push_back(0);
       padded_end.push_back(0);
       padded_stride.push_back(1);
-
-      ellipsis_mask <<= 1;
     }
 
     // Account for ellipsis mask.
-    count++;
+    ++index;
 
     // After the ellipsis.
-    for (; count < begin_shape[0]; ++count) {
-      padded_begin.push_back(begin_dense_elem_attr.getValue<int32_t>(count));
-      padded_end.push_back(end_dense_elem_attr.getValue<int32_t>(count));
-      padded_stride.push_back(stride_dense_elem_attr.getValue<int32_t>(count));
+    for (; index < begin_shape[0]; ++index) {
+      padded_begin.push_back(begin_dense_elem_attr.getValue<int32_t>(index));
+      padded_end.push_back(end_dense_elem_attr.getValue<int32_t>(index));
+      padded_stride.push_back(stride_dense_elem_attr.getValue<int32_t>(index));
+
+      if ((begin_mask >> index) & 1) new_begin_mask |= (1 << new_index);
+      if ((end_mask >> index) & 1) new_end_mask |= (1 << new_index);
     }
 
     auto attribute_type = rewriter.getIntegerType(64);
@@ -645,7 +649,7 @@
         end_op.getResult(), stride_op.getResult(),
         rewriter.getIntegerAttr(attribute_type, new_begin_mask),
         rewriter.getIntegerAttr(attribute_type, new_end_mask),
-        rewriter.getI64IntegerAttr(0),
+        /*ellipsis_maks=*/rewriter.getI64IntegerAttr(0),
         rewriter.getIntegerAttr(attribute_type,
                                 strided_slice_op.new_axis_mask()),
         rewriter.getIntegerAttr(attribute_type,
@@ -655,10 +659,12 @@
 
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const override {
-    // TODO(renjieliu): Consider expand the transformation for shrink
-    // mask as well.
     TF::StridedSliceOp strided_slice_op = llvm::cast<TF::StridedSliceOp>(op);
 
+    // TODO(renjieliu): Consider expand the transformation for shrink mask as
+    // well.
+    if (strided_slice_op.shrink_axis_mask().getZExtValue()) return failure();
+
     // Handle new axis mask.
     uint64_t new_axis_mask = strided_slice_op.new_axis_mask().getZExtValue();
     if (new_axis_mask != 0) {