Fold StridedSliceOp when input is defined by ShapeOp.
This pattern is common in the TF Python library, e.g. height = tf.shape(x)[1].
When x has dynamic dimensions (typically the batch dimension), tf.shape cannot be constant-folded, so height cannot be inferred as a constant.
This PR folds such patterns to improve sub-shape constant folding.
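
For illustration, here is a sketch of the rewrite (based on the examples in the added fold implementation; the value names are illustrative). Given

  %shape = "tf.Shape"(%arg) : (tensor<?x2x3x1xf32>) -> tensor<4xi32>
  // begin = [1], end = [2], strides = [1]
  %height = "tf.StridedSlice"(%shape, %begin, %end, %strides) {shrink_axis_mask = 1 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>

the fold replaces %height with a constant, even though %arg has a dynamic batch dimension:

  %height = "tf.Const"() {value = dense<2> : tensor<i32>} : () -> tensor<i32>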
Rename some test cases
Correctly handle negative strides
Add test cases for out-of-bound begin and end
clang-format
Address comments
Fix Windows build: templated lambdas are not supported in MSVC.
Fix shrink_axis_mask with negative begin
Use canonicalization pattern instead of folder to better support unranked and dynamic output
Switch back to folder
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
index 7cc9d51..c472b24 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
@@ -14412,6 +14412,8 @@
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
TF_DerivedOperandTypeAttr Index = TF_DerivedOperandTypeAttr<1>;
+ let hasFolder = 1;
+
let verifier = [{ return VerifyStridedSliceBase(*this); }];
let extraClassDeclaration = [{
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc
index 0f8a423..eba0646 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc
@@ -1886,6 +1886,124 @@
return true;
}
+OpFoldResult StridedSliceOp::fold(ArrayRef<Attribute> operands) {
+ // Fold StridedSlice operation if it extracts statically known dimensions.
+ //
+ // For example,
+ //
+ // %shape = tf.Shape(%arg) // %arg: tensor<?x2x3x1xf32>
+ // %height = tf.StridedSlice(%shape, 1, 2, 1)
+ //
+ // In this case %height can be replaced with a constant 2.
+ //
+ // Or,
+ //
+ // %shape = tf.Shape(%arg) // %arg: tensor<?x2x3x1xf32>
+ // %spatial_shape = tf.StridedSlice(%shape, 1, 3, 1)
+ //
+ // In this case %spatial_shape can be replaced with a constant [2, 3].
+
+ // The input to the strided slice op must be defined by a shape operation.
+ auto shape_op = input().getDefiningOp<ShapeOp>();
+ if (!shape_op) {
+ return {};
+ }
+
+ // `begin`, `end` and `strides` must be single-element constants in order
+ // to infer a static dimension.
+ DenseIntElementsAttr begin_attr, end_attr, strides_attr;
+ if (!matchPattern(begin(), m_Constant(&begin_attr)) ||
+ !matchPattern(end(), m_Constant(&end_attr)) ||
+ !matchPattern(strides(), m_Constant(&strides_attr)) ||
+ begin_attr.getNumElements() != 1 || end_attr.getNumElements() != 1 ||
+ strides_attr.getNumElements() != 1) {
+ return {};
+ }
+
+ // Do not fold when `new_axis_mask` is set, as it is likely to change the
+ // shape of the output. Typically `new_axis_mask` is not set in the
+ // patterns targeted by this fold.
+ if (new_axis_mask() != 0) return {};
+
+ // Only a ranked input tensor can be folded.
+ auto tensor_ty = shape_op.input().getType().dyn_cast<RankedTensorType>();
+ if (!tensor_ty) return {};
+
+ int64_t rank = tensor_ty.getRank();
+ int64_t begin_int = begin_attr.getValue<APInt>(0).getSExtValue();
+ int64_t end_int = end_attr.getValue<APInt>(0).getSExtValue();
+ int64_t strides_int = strides_attr.getValue<APInt>(0).getSExtValue();
+
+ // Canonicalize `begin` and `end` in case of negative indices.
+ if (begin_int < 0) begin_int += rank;
+ if (end_int < 0) end_int += rank;
+
+ // Derive `begin` and `end` from the `*_mask` attributes. Note that
+ // `new_axis_mask` needs no handling here; it was already rejected above.
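+ // For example (illustrative): with rank = 4 and strides = [-1], a set
+ // `begin_mask` yields begin_int = 3 (the last dimension), and a set
+ // `end_mask` yields end_int = -1 (one past the first dimension, moving
+ // backwards).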
+ if (shrink_axis_mask() == 1) {
+ // When `shrink_axis_mask` is set, output is always a scalar so only
+ // one element is sliced.
+ end_int = begin_int + 1;
+ }
+ if (begin_mask() == 1) {
+ begin_int = (strides_int > 0) ? 0 : rank - 1;
+ }
+ if (end_mask() == 1) {
+ end_int = (strides_int > 0) ? rank : -1;
+ }
+ if (ellipsis_mask() == 1) {
+ begin_int = 0;
+ end_int = rank;
+ }
+
+ // It's possible that `begin` and `end` are out of bounds. See
+ // https://docs.python.org/3/library/stdtypes.html#common-sequence-operations.
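+ // For example (illustrative): with rank = 4, begin = [20] and
+ // strides = [-1], `begin_int` is clamped to 3, so the slice starts at the
+ // last dimension.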
+ if (strides_int > 0) {
+ begin_int = std::min(begin_int, rank);
+ end_int = std::min(end_int, rank);
+ } else {
+ begin_int = std::min(begin_int, rank - 1);
+ end_int = std::min(end_int, rank - 1);
+ }
+
+ SmallVector<int64_t, 2> sub_shape;
+ // Only handle cases that have something to slice, to avoid an infinite for-loop.
+ if ((end_int > begin_int && strides_int > 0) ||
+ (end_int < begin_int && strides_int < 0)) {
+ // Extract the sub-shape only if all of the sliced dimensions are static.
+ for (int64_t i = begin_int; (strides_int > 0) ? i < end_int : i > end_int;
+ i += strides_int) {
+ if (tensor_ty.isDynamicDim(i)) {
+ return {};
+ }
+ sub_shape.push_back(tensor_ty.getDimSize(i));
+ }
+ }
+
+ // For an unranked or dynamic output, infer the output type as either a
+ // scalar or a vector based on `shrink_axis_mask`, since the case of
+ // `new_axis_mask` != 0 has already been rejected.
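+ // For example (illustrative): an unranked i32 result with
+ // `shrink_axis_mask` = 1 becomes tensor<i32>; otherwise it becomes a 1-D
+ // vector whose length is the number of extracted dimensions.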
+ auto output_elt_ty = output().getType().cast<ShapedType>().getElementType();
+ auto output_ty = output().getType().dyn_cast<RankedTensorType>();
+ if (!output_ty || !output_ty.hasStaticShape()) {
+ if (shrink_axis_mask() == 1) {
+ output_ty = RankedTensorType::get({}, output_elt_ty);
+ } else {
+ output_ty = RankedTensorType::get(
+ {static_cast<int64_t>(sub_shape.size())}, output_elt_ty);
+ }
+ }
+
+ // Down-cast to 32-bit integers if needed.
+ if (output_elt_ty.isInteger(32)) {
+ SmallVector<int32_t, 2> sub_shape_i32(sub_shape.size());
+ std::transform(sub_shape.begin(), sub_shape.end(), sub_shape_i32.begin(),
+ [](int64_t d) { return static_cast<int32_t>(d); });
+ return DenseIntElementsAttr::get(output_ty, sub_shape_i32);
+ }
+ return DenseIntElementsAttr::get(output_ty, sub_shape);
+}
+
//===----------------------------------------------------------------------===//
// StridedSliceGradOp
//===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir
index e2a0552..841e6dd 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir
@@ -486,7 +486,7 @@
}
// CHECK-LABEL: func @testPackShapeComputation
-func @testPackShapeComputation(%arg0: tensor<?x1xf32>, %arg1: tensor<?x1x2xf32>, %arg2: tensor<*xf32>) -> (tensor<2xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<*xi32>) {
+func @testPackShapeComputation(%arg0: tensor<?x1xf32>, %arg1: tensor<?x1x2xf32>, %arg2: tensor<*xf32>) -> (tensor<2xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<*xi32>) {
// Test dimensions sizes.
%d1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
%d2 = "tf.Const"() {value = dense<2> : tensor<i32>} : () -> tensor<i32>
@@ -526,26 +526,20 @@
%15 = "tf.Pack"(%14, %d2, %d1) {axis = 0 : i64} : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<3xi32>
// CHECK: %[[PACK0:.*]] = "tf.Pack"
- // StridedSlice takes second dimension from the shape:
- // begin = [1], end = [2], stride = [1]
- %17 = "tf.StridedSlice"(%7, %1, %2, %1) {shrink_axis_mask = 1 : i64} : (tensor<3xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
- %18 = "tf.Pack"(%17, %d1, %d2) {axis = 0 : i64} : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<3xi32>
- // CHECK: %[[PACK1:.*]] = "tf.Pack"
-
// Packed dimensions have higher rank than the reshape operand:
// [?, 1] vs [?, 1, 1]
- %20 = "tf.StridedSlice"(%3, %0, %1, %1) {shrink_axis_mask = 1 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
- %21 = "tf.Pack"(%20, %d1, %d1) {axis = 0 : i64} : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<3xi32>
- // CHECK: %[[PACK2:.*]] = "tf.Pack"
+ %16 = "tf.StridedSlice"(%3, %0, %1, %1) {shrink_axis_mask = 1 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
+ %17 = "tf.Pack"(%16, %d1, %d1) {axis = 0 : i64} : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<3xi32>
+ // CHECK: %[[PACK1:.*]] = "tf.Pack"
// Make sure a dynamic ranked shape doesn't crash the "canonicalize" pass
- %23 = "tf.Shape"(%arg2) : (tensor<*xf32>) -> tensor<*xi32>
- %24 = "tf.StridedSlice"(%23, %0, %1, %1) {shrink_axis_mask = 1 : i64} : (tensor<*xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<*xi32>
- %25 = "tf.Pack"(%24, %d1) {axis = 0 : i64} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
- // CHECK: %[[PACK3:.*]] = "tf.Pack"
+ %18 = "tf.Shape"(%arg2) : (tensor<*xf32>) -> tensor<*xi32>
+ %19 = "tf.StridedSlice"(%18, %0, %1, %1) {shrink_axis_mask = 1 : i64} : (tensor<*xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<*xi32>
+ %20 = "tf.Pack"(%19, %d1) {axis = 0 : i64} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
+ // CHECK: %[[PACK2:.*]] = "tf.Pack"
- // CHECK: return %[[SHAPE0]], %[[SHAPE1]], %[[PACK0]], %[[PACK1]], %[[PACK2]], %[[PACK3]]
- return %5, %9, %15, %18, %21, %25 : tensor<2xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<*xi32>
+ // CHECK: return %[[SHAPE0]], %[[SHAPE1]], %[[PACK0]], %[[PACK1]], %[[PACK2]]
+ return %5, %9, %15, %17, %20 : tensor<2xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<*xi32>
}
// CHECK-LABEL: testSelectScalarPred
@@ -1373,3 +1367,211 @@
// CHECK: return %[[UNPACK]]#0, %[[UNPACK]]#1
return %0, %1 : tensor<?xf32>, tensor<?xf32>
}
+
+// CHECK-LABEL: testFoldStridedSliceShapeI32
+func @testFoldStridedSliceShapeI32(%arg0: tensor<?x1x2x?xf32>) -> (tensor<2xi32>) {
+ %0 = "tf.Const"() {value = dense<3> : tensor<1xi32>} : () -> tensor<1xi32>
+ %1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
+ %2 = "tf.Shape"(%arg0) : (tensor<?x1x2x?xf32>) -> tensor<4xi32>
+ %3 = "tf.StridedSlice"(%2, %1, %0, %1) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
+ return %3 : tensor<2xi32>
+ // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
+ // CHECK: return %[[CST]]
+}
+
+// CHECK-LABEL: testFoldStridedSliceShapeI64
+func @testFoldStridedSliceShapeI64(%arg0: tensor<?x1x2x?xf32>) -> (tensor<2xi64>) {
+ %0 = "tf.Const"() {value = dense<3> : tensor<1xi32>} : () -> tensor<1xi32>
+ %1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
+ %2 = "tf.Shape"(%arg0) : (tensor<?x1x2x?xf32>) -> tensor<4xi64>
+ %3 = "tf.StridedSlice"(%2, %1, %0, %1) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi64>
+ return %3 : tensor<2xi64>
+ // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi64>} : () -> tensor<2xi64>
+ // CHECK: return %[[CST]]
+}
+
+// CHECK-LABEL: testFoldStridedSliceShapeDynamicOutput
+func @testFoldStridedSliceShapeDynamicOutput(%arg0: tensor<?x1x2x?xf32>) -> (tensor<?xi32>) {
+ %0 = "tf.Const"() {value = dense<3> : tensor<1xi32>} : () -> tensor<1xi32>
+ %1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
+ %2 = "tf.Shape"(%arg0) : (tensor<?x1x2x?xf32>) -> tensor<4xi32>
+ %3 = "tf.StridedSlice"(%2, %1, %0, %1) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<?xi32>
+ return %3 : tensor<?xi32>
+ // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi32>} : () -> tensor<?xi32>
+ // CHECK: return %[[CST]]
+}
+
+// CHECK-LABEL: testFoldStridedSliceShapeWithShrinkAxisMaskI32
+func @testFoldStridedSliceShapeWithShrinkAxisMaskI32(%arg0: tensor<?x1x2x?xf32>) -> (tensor<i32>) {
+ %0 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
+ %1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
+ %2 = "tf.Shape"(%arg0) : (tensor<?x1x2x?xf32>) -> tensor<4xi32>
+ %3 = "tf.StridedSlice"(%2, %1, %0, %1) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
+ return %3 : tensor<i32>
+ // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+ // CHECK: return %[[CST]]
+}
+
+// CHECK-LABEL: testFoldStridedSliceShapeWithShrinkAxisMaskI64
+func @testFoldStridedSliceShapeWithShrinkAxisMaskI64(%arg0: tensor<?x1x2x?xf32>) -> (tensor<i64>) {
+ %0 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
+ %1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
+ %2 = "tf.Shape"(%arg0) : (tensor<?x1x2x?xf32>) -> tensor<4xi64>
+ %3 = "tf.StridedSlice"(%2, %1, %0, %1) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<4xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i64>
+ return %3 : tensor<i64>
+ // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<1> : tensor<i64>} : () -> tensor<i64>
+ // CHECK: return %[[CST]]
+}
+
+// CHECK-LABEL: testFoldStridedSliceShapeWithShrinkAxisMaskUnrankedOutput
+func @testFoldStridedSliceShapeWithShrinkAxisMaskUnrankedOutput(%arg0: tensor<?x1x2x?xf32>) -> (tensor<*xi32>) {
+ %0 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
+ %1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
+ %2 = "tf.Shape"(%arg0) : (tensor<?x1x2x?xf32>) -> tensor<4xi32>
+ %3 = "tf.StridedSlice"(%2, %1, %0, %1) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<*xi32>
+ return %3 : tensor<*xi32>
+ // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<*xi32>
+ // CHECK: return %[[CST]]
+}
+
+// CHECK-LABEL: testFoldStridedSliceShapeWithShrinkAxisMaskNegativeBegin1
+func @testFoldStridedSliceShapeWithShrinkAxisMaskNegativeBegin1(%arg0: tensor<?x1x2x3xf32>) -> (tensor<i32>) {
+ %0 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
+ %1 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+ %2 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
+ %3 = "tf.Shape"(%arg0) : (tensor<?x1x2x3xf32>) -> tensor<4xi32>
+ %4 = "tf.StridedSlice"(%3, %0, %1, %2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
+ return %4 : tensor<i32>
+ // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
+ // CHECK: return %[[CST]]
+}
+
+// CHECK-LABEL: testFoldStridedSliceShapeWithShrinkAxisMaskNegativeBegin2
+func @testFoldStridedSliceShapeWithShrinkAxisMaskNegativeBegin2(%arg0: tensor<?x1x2x3xf32>) -> (tensor<i32>) {
+ %0 = "tf.Const"() {value = dense<-2> : tensor<1xi32>} : () -> tensor<1xi32>
+ %1 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
+ %2 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
+ %3 = "tf.Shape"(%arg0) : (tensor<?x1x2x3xf32>) -> tensor<4xi32>
+ %4 = "tf.StridedSlice"(%3, %0, %1, %2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
+ return %4 : tensor<i32>
+ // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<2> : tensor<i32>} : () -> tensor<i32>
+ // CHECK: return %[[CST]]
+}
+
+// CHECK-LABEL: testUnfoldedStridedSliceShape
+func @testUnfoldedStridedSliceShape(%arg0: tensor<?x1x2x?xf32>) -> (tensor<2xi32>) {
+ %0 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+ %1 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
+ %2 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
+ %3 = "tf.Shape"(%arg0) : (tensor<?x1x2x?xf32>) -> tensor<4xi32>
+ %4 = "tf.StridedSlice"(%3, %0, %1, %2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
+ return %4 : tensor<2xi32>
+ // CHECK: %[[SLICE:.*]] = "tf.StridedSlice"
+ // CHECK: return %[[SLICE]]
+}
+
+// CHECK-LABEL: testFoldStridedSliceShapeWithBeginMask
+func @testFoldStridedSliceShapeWithBeginMask(%arg0: tensor<1x2x3x?xf32>) -> (tensor<2xi32>) {
+ %0 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+ %1 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
+ %2 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
+ %3 = "tf.Shape"(%arg0) : (tensor<1x2x3x?xf32>) -> tensor<4xi32>
+ %4 = "tf.StridedSlice"(%3, %0, %1, %2) {begin_mask = 1 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
+ return %4 : tensor<2xi32>
+ // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
+ // CHECK: return %[[CST]]
+}
+
+// CHECK-LABEL: testFoldStridedSliceShapeWithEndMask
+func @testFoldStridedSliceShapeWithEndMask(%arg0: tensor<?x1x2x3xf32>) -> (tensor<3xi32>) {
+ %0 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+ %1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
+ %2 = "tf.Shape"(%arg0) : (tensor<?x1x2x3xf32>) -> tensor<4xi32>
+ %3 = "tf.StridedSlice"(%2, %1, %0, %1) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32>
+ return %3 : tensor<3xi32>
+ // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[1, 2, 3]> : tensor<3xi32>} : () -> tensor<3xi32>
+ // CHECK: return %[[CST]]
+}
+
+// CHECK-LABEL: testFoldStridedSliceShapeWithPositiveStrides
+func @testFoldStridedSliceShapeWithPositiveStrides(%arg0: tensor<1x2x3x4x?xf32>) -> (tensor<2xi32>) {
+ %0 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
+ %1 = "tf.Const"() {value = dense<4> : tensor<1xi32>} : () -> tensor<1xi32>
+ %2 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
+ %3 = "tf.Shape"(%arg0) : (tensor<1x2x3x4x?xf32>) -> tensor<5xi32>
+ %4 = "tf.StridedSlice"(%3, %0, %1, %2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<5xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
+ return %4 : tensor<2xi32>
+ // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[2, 4]> : tensor<2xi32>} : () -> tensor<2xi32>
+ // CHECK: return %[[CST]]
+}
+
+// CHECK-LABEL: testFoldStridedSliceShapeWithPositiveStridesOutOfBoundEnd
+func @testFoldStridedSliceShapeWithPositiveStridesOutOfBoundEnd(%arg0: tensor<?x1x2x3xf32>) -> (tensor<3xi32>) {
+ %0 = "tf.Const"() {value = dense<20> : tensor<1xi32>} : () -> tensor<1xi32>
+ %1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
+ %2 = "tf.Shape"(%arg0) : (tensor<?x1x2x3xf32>) -> tensor<4xi32>
+ %3 = "tf.StridedSlice"(%2, %1, %0, %1) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32>
+ return %3 : tensor<3xi32>
+ // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[1, 2, 3]> : tensor<3xi32>} : () -> tensor<3xi32>
+ // CHECK: return %[[CST]]
+}
+
+// CHECK-LABEL: testFoldStridedSliceShapeWithNegativeStrides
+func @testFoldStridedSliceShapeWithNegativeStrides(%arg0: tensor<1x2x3x?xf32>) -> (tensor<1xi32>) {
+ %0 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
+ %1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
+ %2 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
+ %3 = "tf.Shape"(%arg0) : (tensor<1x2x3x?xf32>) -> tensor<4xi32>
+ %4 = "tf.StridedSlice"(%3, %0, %1, %2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+ return %4 : tensor<1xi32>
+ // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<3> : tensor<1xi32>} : () -> tensor<1xi32>
+ // CHECK: return %[[CST]]
+}
+
+// CHECK-LABEL: testFoldStridedSliceShapeWithNegativeStridesOutOfBoundBegin
+func @testFoldStridedSliceShapeWithNegativeStridesOutOfBoundBegin(%arg0: tensor<?x1x2x3xf32>) -> (tensor<2xi32>) {
+ %0 = "tf.Const"() {value = dense<20> : tensor<1xi32>} : () -> tensor<1xi32>
+ %1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
+ %2 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
+ %3 = "tf.Shape"(%arg0) : (tensor<?x1x2x3xf32>) -> tensor<4xi32>
+ %4 = "tf.StridedSlice"(%3, %0, %1, %2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
+ return %4 : tensor<2xi32>
+ // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[3, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
+ // CHECK: return %[[CST]]
+}
+
+// CHECK-LABEL: testFoldStridedSliceShapeWithNegativeStridesBeginMask
+func @testFoldStridedSliceShapeWithNegativeStridesBeginMask(%arg0: tensor<?x1x2x3xf32>) -> (tensor<2xi32>) {
+ %0 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+ %1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
+ %2 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
+ %3 = "tf.Shape"(%arg0) : (tensor<?x1x2x3xf32>) -> tensor<4xi32>
+ %4 = "tf.StridedSlice"(%3, %0, %1, %2) {begin_mask = 1 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
+ return %4 : tensor<2xi32>
+ // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[3, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
+ // CHECK: return %[[CST]]
+}
+
+// CHECK-LABEL: testFoldStridedSliceShapeWithNegativeStridesEndMask
+func @testFoldStridedSliceShapeWithNegativeStridesEndMask(%arg0: tensor<1x2x3x?xf32>) -> (tensor<3xi32>) {
+ %0 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
+ %1 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+ %2 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
+ %3 = "tf.Shape"(%arg0) : (tensor<1x2x3x?xf32>) -> tensor<4xi32>
+ %4 = "tf.StridedSlice"(%3, %0, %1, %2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32>
+ return %4 : tensor<3xi32>
+ // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[3, 2, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
+ // CHECK: return %[[CST]]
+}
+
+// CHECK-LABEL: testFoldStridedSliceShapeWithEmptySlice
+func @testFoldStridedSliceShapeWithEmptySlice(%arg0: tensor<?x1x2x3xf32>) -> (tensor<0xi32>) {
+ %0 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
+ %1 = "tf.Const"() {value = dense<3> : tensor<1xi32>} : () -> tensor<1xi32>
+ %2 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
+ %3 = "tf.Shape"(%arg0) : (tensor<?x1x2x3xf32>) -> tensor<4xi32>
+ %4 = "tf.StridedSlice"(%3, %0, %1, %2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32>
+ return %4 : tensor<0xi32>
+ // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32>
+ // CHECK: return %[[CST]]
+}