Separate ShapeNToShape canonicalization pattern
Remove the ShapeNOp folder
Add tests
Use the getElementTypeOrSelf utility
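
For context, a rough sketch of the single-operand rewrite the new ShapeNToShape pattern performs (it mirrors the test added below; the exact output of the canonicalizer may differ in value names):

  %0:1 = "tf.ShapeN"(%arg0) : (tensor<*xf32>) -> tensor<?xi64>
  // after canonicalization, roughly:
  %0 = "tf.Shape"(%arg0) : (tensor<*xf32>) -> tensor<?xi64>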
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
index 56baed1..ce69118 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
@@ -8713,8 +8713,6 @@
}];
let hasCanonicalizer = 1;
-
- let hasFolder = 1;
}
def TF_ShardedFilenameOp : TF_Op<"ShardedFilename", [NoSideEffect]> {
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc
index d726894..c5da5e0 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc
@@ -932,21 +932,6 @@
return success();
}
-LogicalResult ShapeNOp::fold(ArrayRef<Attribute> operands,
- SmallVectorImpl<OpFoldResult> &results) {
- if (getNumOperands() == 0) return success();
- int width =
- getType(0).cast<ShapedType>().getElementType().getIntOrFloatBitWidth();
-
- for (Type input_ty : getOperandTypes()) {
- OpFoldResult result = ConvertShapeToAttr(input_ty, width);
- if (!result) return failure();
-
- results.push_back(result);
- }
- return success();
-}
-
namespace {
// Canonicalization pattern for ShapeNOp that don't have all
// static input shapes. Replacing output values corresponding to static input
@@ -955,12 +940,12 @@
using OpRewritePattern<ShapeNOp>::OpRewritePattern;
LogicalResult matchAndRewrite(ShapeNOp op,
PatternRewriter &rewriter) const override {
- // ShapeNOp::fold handles this case.
- if (op.getNumOperands() == 0) return success();
- int width = op.getType(0)
- .cast<ShapedType>()
- .getElementType()
- .getIntOrFloatBitWidth();
+ if (op.getNumOperands() == 0) {
+ rewriter.eraseOp(op);
+ return success();
+ }
+
+ int width = getElementTypeOrSelf(op.getType(0)).getIntOrFloatBitWidth();
SmallVector<Value, 4> results(op.getNumOperands());
SmallVector<int64_t, 4> dynamic_indices;
@@ -982,12 +967,8 @@
return failure();
}
- // Create a ShapeOp when there is only one dynamic input.
- // Or create a ShapeNOp when there are two or more dynamic inputs.
- if (dynamic_inputs.size() == 1) {
- results[dynamic_indices[0]] = rewriter.create<TF::ShapeOp>(
- op.getLoc(), result_types[0], dynamic_inputs[0]);
- } else if (dynamic_inputs.size() >= 2) {
+ // Create a ShapeNOp for all dynamic inputs.
+ if (!dynamic_inputs.empty()) {
auto dynamic_shape_n = rewriter.create<TF::ShapeNOp>(
op.getLoc(), result_types, dynamic_inputs);
for (auto index_result :
@@ -1000,11 +981,26 @@
return success();
}
};
+
+// Canonicalize ShapeNOp to ShapeOp if there is only one operand.
+class ShapeNToShape : public OpRewritePattern<ShapeNOp> {
+ using OpRewritePattern<ShapeNOp>::OpRewritePattern;
+ LogicalResult matchAndRewrite(ShapeNOp op,
+ PatternRewriter &rewriter) const override {
+ if (op.getNumOperands() != 1) {
+ return failure();
+ }
+ auto shape = rewriter.create<TF::ShapeOp>(
+ op.getLoc(), op.getType(0), op.getOperand(0));
+ rewriter.replaceOp(op, {shape});
+ return success();
+ }
+};
} // namespace
void ShapeNOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
- results.insert<ShapeNPartialStaticInputShape>(context);
+ results.insert<ShapeNToShape, ShapeNPartialStaticInputShape>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir b/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir
index 0e0ba40..f114d17 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir
@@ -121,6 +121,15 @@
return %0#0, %0#1, %0#2 : tensor<0xi64>, tensor<4xi64>, tensor<?xi64>
}
+// CHECK-LABEL: func @testShapeNToShape
+func @testShapeNToShape(%arg0: tensor<*xf32>) -> tensor<?xi64> {
+ // CHECK: %[[SHAPE0:.*]] = "tf.Shape"(%arg0) : (tensor<*xf32>) -> tensor<?xi64>
+ %0:1 = "tf.ShapeN"(%arg0) : (tensor<*xf32>) -> tensor<?xi64>
+
+ // CHECK: return %[[SHAPE0]]
+ return %0#0 : tensor<?xi64>
+}
+
// CHECK-LABEL: func @testLeakyRelu
func @testLeakyRelu(%arg0 : tensor<16xf32>) -> (tensor<16xf32>, tensor<f32>, tensor<f32>, tensor<16xf32>) {
%pos = constant dense<5.0> : tensor<f32>