Add VariableShapeOp to TensorFlow MLIR ODS.

VariableShapeOp is generated from the TensorFlow op registry. Add verifier based on Shape/ShapeN for VariableShape.

PiperOrigin-RevId: 277397486
Change-Id: Ib3cb0c7fef6d7e7ec91ae46810d4123de99a0487
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
index bad35e0..8a4b67d 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
@@ -5048,6 +5048,35 @@
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
 }
 
+def TF_VariableShapeOp : TF_Op<"VariableShape", []> {
+  let summary = "Returns the shape of the variable pointed to by `resource`.";
+
+  let description = [{
+This operation returns a 1-D integer tensor representing the shape of `input`.
+
+For example:
+
+```
+# 't' is [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]]
+shape(t) ==> [2, 2, 3]
+```
+  }];
+
+  let arguments = (ins
+    TF_ResourceTensor:$input
+  );
+
+  let results = (outs
+    TF_I32OrI64Tensor:$output
+  );
+
+  TF_DerivedResultTypeAttr out_type = TF_DerivedResultTypeAttr<0>;
+
+  let verifier = [{
+    return Verify(*this);
+  }];
+}
+
 def TF_WhereOp : TF_Op<"Where", [NoSideEffect]> {
   let summary = "Returns locations of nonzero / true values in a tensor.";
 
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
index e2ca339..b7612a2 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
@@ -1244,7 +1244,7 @@
 //===----------------------------------------------------------------------===//
 
 namespace {
-// Validates Shape/ShapeN operand and associated result types.
+// Validates Shape/ShapeN/VariableShape operand and associated result types.
 LogicalResult VerifyShapeOperandAndResult(Operation *op, Type operand_type,
                                           Type result_type,
                                           int variadic_idx = -1) {
@@ -1255,7 +1255,7 @@
   if (!result_ranked_type || result_ranked_type.getShape().size() != 1)
     return op->emitOpError("requires 1D type for result") << variadic_idx_str;
 
-  auto operand_ranked_type = operand_type.dyn_cast<RankedTensorType>();
+  auto operand_ranked_type = operand_type.dyn_cast_or_null<RankedTensorType>();
   if (operand_ranked_type) {
     // The operand is a ranked tensor.
     if (result_ranked_type.hasStaticShape() &&
@@ -1649,6 +1649,29 @@
 }
 
 //===----------------------------------------------------------------------===//
+// VariableShapeOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult Verify(VariableShapeOp op) {
+  auto resource_operand_type = op.input()
+                                   ->getType()
+                                   .cast<TensorType>()
+                                   .getElementType()
+                                   .cast<TF::ResourceType>();
+  auto subtypes = resource_operand_type.getSubtypes();
+  switch (subtypes.size()) {
+    case 1:
+      return VerifyShapeOperandAndResult(
+          op, resource_operand_type.getSubtypes().front(), op.getType());
+    case 0:
+      return VerifyShapeOperandAndResult(op, Type(), op.getType());
+    default:
+      return op.emitOpError(
+          "requires resource input type to have at most 1 subtype");
+  }
+}
+
+//===----------------------------------------------------------------------===//
 // WhileOp
 //===----------------------------------------------------------------------===//
 
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir
index c97ea7e..600360c 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir
@@ -1129,6 +1129,55 @@
 
 // -----
 
+// CHECK-LABEL: func @testValidVariableShape
+func @testValidVariableShape(%arg0: tensor<*x!tf.resource<tensor<1x32x32x16xf32>>>, %arg1: tensor<*x!tf.resource>) -> (tensor<4xi32>, tensor<?xi32>) {
+  %0 = "tf.VariableShape"(%arg0) {output = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource<tensor<1x32x32x16xf32>>>) -> tensor<4xi32>
+  %1 = "tf.VariableShape"(%arg1) {output = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource>) -> tensor<?xi32>
+  return %0, %1 : tensor<4xi32>, tensor<?xi32>
+}
+
+// -----
+
+func @testVariableShapeMultipleSubtypes(%arg0: tensor<*x!tf.resource<tensor<1x32x32x16xf32>, tensor<1x32x32x16xf32>>>) {
+  // expected-error @+1 {{requires resource input type to have at most 1 subtype}}
+  %0 = "tf.VariableShape"(%arg0) {output = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource<tensor<1x32x32x16xf32>, tensor<1x32x32x16xf32>>>) -> tensor<4xi32>
+  return
+}
+
+// -----
+
+func @testVariableShapeWrongResultElemType(%arg0: tensor<*x!tf.resource<tensor<1x32x32x16xf32>>>) -> tensor<?xf32> {
+  // expected-error @+1 {{result #0 must be tensor of 32/64-bit integer values}}
+  %0 = "tf.VariableShape"(%arg0) : (tensor<*x!tf.resource<tensor<1x32x32x16xf32>>>) -> tensor<4xf32>
+  return %0 : tensor<4xf32>
+}
+
+// -----
+
+func @testVariableShapeWrongResultDim(%arg0: tensor<*x!tf.resource<tensor<1x32x32x16xf32>>>) -> tensor<*xi32> {
+  // expected-error @+1 {{requires 1D type for result}}
+  %0 = "tf.VariableShape"(%arg0) {output = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource<tensor<1x32x32x16xf32>>>) -> tensor<*xi32>
+  return %0 : tensor<*xi32>
+}
+
+// -----
+
+func @testVariableShapeMismatchDim(%arg0: tensor<*x!tf.resource<tensor<1x32x32x16xf32>>>) -> tensor<2xi32> {
+  // expected-error @+1 {{requires dimension size of result to match rank of operand}}
+  %0 = "tf.VariableShape"(%arg0) {output = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource<tensor<1x32x32x16xf32>>>) -> tensor<2xi32>
+  return %0 : tensor<2xi32>
+}
+
+// -----
+
+func @testVariableShapeWrongResultDimDynamic(%arg0: tensor<*x!tf.resource<tensor<*xf32>>>) -> tensor<2xi32> {
+  // expected-error @+1 {{requires dynamic shape result for unranked operand}}
+  %0 = "tf.VariableShape"(%arg0) {output = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource<tensor<*xf32>>>) -> tensor<2xi32>
+  return %0 : tensor<2xi32>
+}
+
+// -----
+
 // Test invalid tf.Const
 func @testConst() -> tensor<f32> {
   // expected-error @+1 {{attribute 'value' failed to satisfy constraint: constant vector/tensor}}