Rewrite masks while applying ellipsis mask to tf.StridedSlice op
begin_mask and end_mask needs to be adjusted while simplifying StridedSlice op by removing ellipsis_mask. new_axis_mask is already removed before this and we don't yet support shrink_axis_mask.
PiperOrigin-RevId: 317419960
Change-Id: Ie4a5f404f95f5b909065311a54cbbed64d4ccf4b
diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir
index e95f3d0..7194309 100644
--- a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir
@@ -528,6 +528,26 @@
return %1 : tensor<1x4x64x64xf32>
}
+// CHECK-LABEL: @StridedSliceRewriteMasks
+func @StridedSliceRewriteMasks(%arg0: tensor<8x4x16x2xf32>) -> tensor<8x4x16x1xf32> {
+ %cst = "tf.Const"() {device = "", value = dense<[1, 0, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
+ %cst_0 = "tf.Const"() {device = "", value = dense<[1, 0, 0]> : tensor<3xi32>} : () -> tensor<3xi32>
+ %cst_1 = "tf.Const"() {device = "", value = dense<1> : tensor<3xi32>} : () -> tensor<3xi32>
+
+ // CHECK: %[[CST:.*]] = constant dense<[1, 0, 0, 1]> : tensor<4xi32>
+ // CHECK: %[[CST0:.*]] = constant dense<[1, 0, 0, 0]> : tensor<4xi32>
+ // CHECK: %[[CST1:.*]] = constant dense<1> : tensor<4xi32>
+ // CHECK: %[[RESULT:.*]] = "tf.StridedSlice"(%arg0, %[[CST]], %[[CST0]], %[[CST1]])
+ // CHECK-SAME: begin_mask = 7 : i64
+ // CHECK-SAME: ellipsis_mask = 0 : i64
+ // CHECK-SAME: end_mask = 14 : i64
+ // CHECK-SAME: new_axis_mask = 0 : i64
+ // CHECK-SAME: shrink_axis_mask = 0 : i64
+
+ %0 = "tf.StridedSlice"(%arg0, %cst, %cst_0, %cst_1) {begin_mask = 1 : i64, device = "", ellipsis_mask = 2 : i64, end_mask = 4 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<8x4x16x2xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<8x4x16x1xf32>
+ return %0 : tensor<8x4x16x1xf32>
+}
+
// CHECK-LABEL: @MatrixSetDiagV2Conversion
func @MatrixSetDiagV2Conversion(%arg0: tensor<3x3xi32>, %arg1: tensor<3xi32>) -> tensor<3x3xi32> {
%cst = constant dense<0> : tensor<i32>
diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc
index 3310c52..6ee9884 100644
--- a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc
@@ -584,46 +584,50 @@
const int ellipsis_filled_dim_size = input_size - begin_shape[0] + 1;
- llvm::APInt new_begin_mask = strided_slice_op.begin_mask();
- llvm::APInt new_end_mask = strided_slice_op.end_mask();
+ int64_t begin_mask = strided_slice_op.begin_mask().getSExtValue();
+ int64_t end_mask = strided_slice_op.end_mask().getSExtValue();
+ int64_t new_begin_mask = 0;
+ int64_t new_end_mask = 0;
SmallVector<int32_t, 4> padded_begin;
SmallVector<int32_t, 4> padded_end;
SmallVector<int32_t, 4> padded_stride;
// Before the ellipsis.
- uint64_t index = 1;
- int count = 0;
-
- while (index < ellipsis_mask) {
- padded_begin.push_back(begin_dense_elem_attr.getValue<int32_t>(count));
- padded_end.push_back(end_dense_elem_attr.getValue<int32_t>(count));
- padded_stride.push_back(stride_dense_elem_attr.getValue<int32_t>(count));
- index <<= 1;
- count++;
+ int index = 0;
+ int new_index = 0;
+ while (((ellipsis_mask >> index) & 1) == 0) {
+ padded_begin.push_back(begin_dense_elem_attr.getValue<int32_t>(index));
+ padded_end.push_back(end_dense_elem_attr.getValue<int32_t>(index));
+ padded_stride.push_back(stride_dense_elem_attr.getValue<int32_t>(index));
+ if ((begin_mask >> index) & 1) new_begin_mask |= (1 << new_index);
+ if ((end_mask >> index) & 1) new_end_mask |= (1 << new_index);
+ ++index;
+ ++new_index;
}
// Ellipsis.
- for (int i = 0; i < ellipsis_filled_dim_size; ++i) {
- new_begin_mask |= ellipsis_mask;
- new_end_mask |= ellipsis_mask;
+ for (; new_index < index + ellipsis_filled_dim_size; ++new_index) {
+ new_begin_mask |= (1 << new_index);
+ new_end_mask |= (1 << new_index);
// Mimic the begin/end/strides mask behavior.
padded_begin.push_back(0);
padded_end.push_back(0);
padded_stride.push_back(1);
-
- ellipsis_mask <<= 1;
}
// Account for ellipsis mask.
- count++;
+ ++index;
// After the ellipsis.
- for (; count < begin_shape[0]; ++count) {
- padded_begin.push_back(begin_dense_elem_attr.getValue<int32_t>(count));
- padded_end.push_back(end_dense_elem_attr.getValue<int32_t>(count));
- padded_stride.push_back(stride_dense_elem_attr.getValue<int32_t>(count));
+ for (; index < begin_shape[0]; ++index) {
+ padded_begin.push_back(begin_dense_elem_attr.getValue<int32_t>(index));
+ padded_end.push_back(end_dense_elem_attr.getValue<int32_t>(index));
+ padded_stride.push_back(stride_dense_elem_attr.getValue<int32_t>(index));
+
+ if ((begin_mask >> index) & 1) new_begin_mask |= (1 << new_index);
+ if ((end_mask >> index) & 1) new_end_mask |= (1 << new_index);
}
auto attribute_type = rewriter.getIntegerType(64);
@@ -645,7 +649,7 @@
end_op.getResult(), stride_op.getResult(),
rewriter.getIntegerAttr(attribute_type, new_begin_mask),
rewriter.getIntegerAttr(attribute_type, new_end_mask),
- rewriter.getI64IntegerAttr(0),
+ /*ellipsis_maks=*/rewriter.getI64IntegerAttr(0),
rewriter.getIntegerAttr(attribute_type,
strided_slice_op.new_axis_mask()),
rewriter.getIntegerAttr(attribute_type,
@@ -655,10 +659,12 @@
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
- // TODO(renjieliu): Consider expand the transformation for shrink
- // mask as well.
TF::StridedSliceOp strided_slice_op = llvm::cast<TF::StridedSliceOp>(op);
+ // TODO(renjieliu): Consider expand the transformation for shrink mask as
+ // well.
+ if (strided_slice_op.shrink_axis_mask().getZExtValue()) return failure();
+
// Handle new axis mask.
uint64_t new_axis_mask = strided_slice_op.new_axis_mask().getZExtValue();
if (new_axis_mask != 0) {