Separate ShapeNToShape canonicalization pattern

Remove folder

Add tests

Use getElementTypeOrSelf utility
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
index 56baed1..ce69118 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
@@ -8713,8 +8713,6 @@
   }];
 
   let hasCanonicalizer = 1;
-
-  let hasFolder = 1;
 }
 
 def TF_ShardedFilenameOp : TF_Op<"ShardedFilename", [NoSideEffect]> {
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc
index d726894..c5da5e0 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc
@@ -932,21 +932,6 @@
   return success();
 }
 
-LogicalResult ShapeNOp::fold(ArrayRef<Attribute> operands,
-                             SmallVectorImpl<OpFoldResult> &results) {
-  if (getNumOperands() == 0) return success();
-  int width =
-      getType(0).cast<ShapedType>().getElementType().getIntOrFloatBitWidth();
-
-  for (Type input_ty : getOperandTypes()) {
-    OpFoldResult result = ConvertShapeToAttr(input_ty, width);
-    if (!result) return failure();
-
-    results.push_back(result);
-  }
-  return success();
-}
-
 namespace {
 // Canonicalization pattern for ShapeNOp that don't have all
 // static input shapes. Replacing output values corresponding to static input
@@ -955,12 +940,12 @@
   using OpRewritePattern<ShapeNOp>::OpRewritePattern;
   LogicalResult matchAndRewrite(ShapeNOp op,
                                 PatternRewriter &rewriter) const override {
-    // ShapeNOp::fold handles this case.
-    if (op.getNumOperands() == 0) return success();
-    int width = op.getType(0)
-                    .cast<ShapedType>()
-                    .getElementType()
-                    .getIntOrFloatBitWidth();
+    if (op.getNumOperands() == 0) {
+      rewriter.eraseOp(op);
+      return success();
+    }
+
+    int width = getElementTypeOrSelf(op.getType(0)).getIntOrFloatBitWidth();
 
     SmallVector<Value, 4> results(op.getNumOperands());
     SmallVector<int64_t, 4> dynamic_indices;
@@ -982,12 +967,8 @@
       return failure();
     }
 
-    // Create a ShapeOp when there is only one dynamic input.
-    // Or create a ShapeNOp when there are two or more dynamic inputs.
-    if (dynamic_inputs.size() == 1) {
-      results[dynamic_indices[0]] = rewriter.create<TF::ShapeOp>(
-          op.getLoc(), result_types[0], dynamic_inputs[0]);
-    } else if (dynamic_inputs.size() >= 2) {
+    // Create a ShapeNOp for all dynamic inputs.
+    if (!dynamic_inputs.empty()) {
       auto dynamic_shape_n = rewriter.create<TF::ShapeNOp>(
           op.getLoc(), result_types, dynamic_inputs);
       for (auto index_result :
@@ -1000,11 +981,26 @@
     return success();
   }
 };
+
+// Canonicalize ShapeNOp to ShapeOp if there is only one operand.
+class ShapeNToShape : public OpRewritePattern<ShapeNOp> {
+  using OpRewritePattern<ShapeNOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(ShapeNOp op,
+                                PatternRewriter &rewriter) const override {
+    if (op.getNumOperands() != 1) {
+      return failure();
+    }
+    auto shape = rewriter.create<TF::ShapeOp>(
+        op.getLoc(), op.getType(0), op.getOperand(0));
+    rewriter.replaceOp(op, {shape});
+    return success();
+  }
+};
 }  // namespace
 
 void ShapeNOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                            MLIRContext *context) {
-  results.insert<ShapeNPartialStaticInputShape>(context);
+  results.insert<ShapeNToShape, ShapeNPartialStaticInputShape>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir b/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir
index 0e0ba40..f114d17 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir
@@ -121,6 +121,15 @@
   return %0#0, %0#1, %0#2 : tensor<0xi64>, tensor<4xi64>, tensor<?xi64>
 }
 
+// CHECK-LABEL: func @testShapeNToShape
+func @testShapeNToShape(%arg0: tensor<*xf32>) -> tensor<?xi64> {
+  // CHECK: %[[SHAPE0:.*]] = "tf.Shape"(%arg0) : (tensor<*xf32>) -> tensor<?xi64>
+  %0:1 = "tf.ShapeN"(%arg0) : (tensor<*xf32>) -> tensor<?xi64>
+
+  // CHECK: return %[[SHAPE0]]
+  return %0#0 : tensor<?xi64>
+}
+
 // CHECK-LABEL: func @testLeakyRelu
 func @testLeakyRelu(%arg0 : tensor<16xf32>) -> (tensor<16xf32>, tensor<f32>, tensor<f32>, tensor<16xf32>) {
   %pos = constant dense<5.0> : tensor<f32>