Verify resource subtype is supplied for tf.VarHandleOp.

If no resource subtype is supplied, the dtype and shape attributes cannot be
derived. Verifying that exactly one subtype is present turns an assert failure
into a proper verification failure when those attributes are accessed on a
tf.VarHandleOp of type tensor<*x!tf.resource>.
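
For illustration, a minimal sketch of what the verifier now rejects and
accepts (mirroring the tf-ops.mlir test added below):

  // Rejected: no subtype, so dtype and shape cannot be derived.
  %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>

  // Accepted: the single subtype tensor<f32> supplies dtype and shape.
  %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<!tf.resource<tensor<f32>>>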

PiperOrigin-RevId: 350805993
Change-Id: I049fb7e292993f6e760bb850a01337a3a8c09452
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td
index a0f4baf..39a45f1 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td
@@ -568,26 +568,6 @@
   "return (*getOperation()->result_type_begin()).cast<ShapedType>();",
   [{ mlir::TF::ShapeAttr::get($_ctx, $_self) }]>;
 
-// A derived attribute that returns the element type of the tensor held by a
-// named resource-type operand or result.
-class TF_DerivedOperandOrResultHandleTypeAttr<string name> : DerivedTypeAttr<
-  "auto resource_type =\n"
-  "  mlir::getElementTypeOrSelf(this->" # name # "())\n"
-  "  .cast<TF::ResourceType>();\n"
-  "assert(!resource_type.getSubtypes().empty() && \"unknown type\");\n"
-  "return mlir::getElementTypeOrSelf(*resource_type.getSubtypes().begin());">;
-
-// A derived attribute that returns the shape of the tensor held by a named
-// resource-type operand or result.
-class TF_DerivedOperandOrResultHandleShapeAttr<string name> : DerivedAttr<
-  "ShapedType",
-  "auto resource_type =\n"
-  "  mlir::getElementTypeOrSelf(this->" # name # "())\n"
-  "  .cast<TF::ResourceType>();\n"
-  "assert(!resource_type.getSubtypes().empty() && \"unknown shape\");\n"
-  "return resource_type.getSubtypes().begin()->cast<ShapedType>();",
-  [{ mlir::TF::ShapeAttr::get($_ctx, $_self) }]>;
-
 def TF_IntTypeAttr : TypeAttrBase<"IntegerType", "integer type"> {
   let returnType = "Type";
 }
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td
index 4615064..93792ab 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td
@@ -865,16 +865,35 @@
     Res<TF_ResourceTensor, "", [TF_VariableAlloc]>:$resource
   );
 
-  TF_DerivedOperandOrResultHandleTypeAttr dtype =
-    TF_DerivedOperandOrResultHandleTypeAttr<"resource">;
-  TF_DerivedOperandOrResultHandleShapeAttr shape =
-    TF_DerivedOperandOrResultHandleShapeAttr<"resource">;
+  let verifier = [{
+    // VarHandleOp requires the resource handle to supply a single subtype from
+    // which to derive the dtype and shape attributes.
+    if (resource_type().getSubtypes().size() != 1) {
+      return emitOpError(
+          "must have exactly one subtype in the result resource type");
+    }
+
+    return success();
+  }];
+
+  DerivedTypeAttr dtype = DerivedTypeAttr<
+      "return getElementTypeOrSelf(resource_subtype());">;
+  DerivedAttr shape = DerivedAttr<
+      "ShapedType",
+      "return resource_subtype().cast<ShapedType>();",
+      [{ mlir::TF::ShapeAttr::get($_ctx, $_self) }]>;
 
   let extraClassDeclaration = [{
     // TF_ResourceHandleAllocatorInterface:
     ResourceHandleValueAndId GetResourceHandleValueAndId(
       llvm::SmallDenseMap<ResourceHandle, int64_t> &resource_handle_id_map,
       int64_t &next_id);
+
+    TensorType resource_subtype() { return resource_type().getSubtypes()[0]; }
+
+    ResourceType resource_type() {
+      return getElementTypeOrSelf(resource()).cast<TF::ResourceType>();
+    }
   }];
 }
 
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/decompose_resource_ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/decompose_resource_ops.mlir
index 920f535..5d8cdc7 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/decompose_resource_ops.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/decompose_resource_ops.mlir
@@ -28,7 +28,7 @@
 // CHECK-LABEL: func @decompose_assign_add_variable_op
 func @decompose_assign_add_variable_op() -> () {
 
-  %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
+  %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<!tf.resource<tensor<i32>>>
 
   // CHECK: %[[ONE:[0-9]*]] = "tf.Const"() {value = dense<1> : tensor<i32>}
   // CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"
@@ -36,7 +36,7 @@
   // CHECK: "tf.AssignVariableOp"
 
   %1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
-  "tf.AssignAddVariableOp"(%0, %1) {dtype = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource>, tensor<i32>) -> ()
+  "tf.AssignAddVariableOp"(%0, %1) {dtype = "tfdtype$DT_INT32"} : (tensor<!tf.resource<tensor<i32>>>, tensor<i32>) -> ()
 
   return
 }
@@ -49,7 +49,7 @@
 // CHECK-LABEL: func @decompose_assign_sub_variable_op
 func @decompose_assign_sub_variable_op() -> () {
 
-  %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
+  %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<!tf.resource<tensor<i32>>>
 
   // CHECK: %[[ONE:[0-9]*]] = "tf.Const"() {value = dense<1> : tensor<i32>}
   // CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"
@@ -57,7 +57,7 @@
   // CHECK: "tf.AssignVariableOp"
 
   %1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
-  "tf.AssignSubVariableOp"(%0, %1) {dtype = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource>, tensor<i32>) -> ()
+  "tf.AssignSubVariableOp"(%0, %1) {dtype = "tfdtype$DT_INT32"} : (tensor<!tf.resource<tensor<i32>>>, tensor<i32>) -> ()
 
   return
 }
@@ -70,7 +70,7 @@
 // CHECK-SAME: (%[[DELTA:.*]]: tensor<f32>)
 func @decompose_resource_apply_gradient_descent(%arg0: tensor<f32>) -> () {
 
-  %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
+  %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<!tf.resource<tensor<f32>>>
 
   // CHECK: %[[ALPHA:[0-9]*]] = "tf.Const"
   // CHECK: %[[RES_HANDLE:[0-9]*]] = "tf.VarHandleOp"
@@ -80,7 +80,7 @@
   // CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[SUB]])
 
   %1 = "tf.Const"() {T = f32, value = dense<[0.5]> : tensor<1xf32>} : () -> tensor<f32>
-  "tf.ResourceApplyGradientDescent"(%0, %1, %arg0) {use_locking = false} : (tensor<*x!tf.resource>, tensor<f32>, tensor<f32>) -> ()
+  "tf.ResourceApplyGradientDescent"(%0, %1, %arg0) {use_locking = false} : (tensor<!tf.resource<tensor<f32>>>, tensor<f32>, tensor<f32>) -> ()
 
   return
 }
@@ -96,8 +96,8 @@
 
   // CHECK: [[VAR_HANDLE:%.*]] = "tf.VarHandleOp"
   // CHECK: [[ACCUM_HANDLE:%.*]] = "tf.VarHandleOp"
-  %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
-  %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
+  %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<!tf.resource<tensor<f32>>>
+  %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<!tf.resource<tensor<f32>>>
 
   // CHECK: [[ACCUM:%.*]] = "tf.ReadVariableOp"([[ACCUM_HANDLE]])
   // CHECK: [[ACCUM_MOMENTUM:%.*]] = "tf.Mul"([[ACCUM]], [[MOMENTUM]])
@@ -107,7 +107,7 @@
   // CHECK: [[VAR:%.*]] = "tf.ReadVariableOp"([[VAR_HANDLE]])
   // CHECK: [[VAR_NEW:%.*]] = "tf.Sub"([[VAR]], [[ACCUM_NEW_LR]])
   // CHECK: "tf.AssignVariableOp"([[VAR_HANDLE]], [[VAR_NEW]])
-  "tf.ResourceApplyMomentum"(%0, %1, %arg0, %arg1, %arg2) {use_locking = false, use_nesterov = false} : (tensor<*x!tf.resource>, tensor<*x!tf.resource>, tensor<f32>, tensor<f32>, tensor<f32>) -> ()
+  "tf.ResourceApplyMomentum"(%0, %1, %arg0, %arg1, %arg2) {use_locking = false, use_nesterov = false} : (tensor<!tf.resource<tensor<f32>>>, tensor<!tf.resource<tensor<f32>>>, tensor<f32>, tensor<f32>, tensor<f32>) -> ()
   return
 }
 
@@ -122,8 +122,8 @@
 
   // CHECK: [[VAR_HANDLE:%.*]] = "tf.VarHandleOp"
   // CHECK: [[ACCUM_HANDLE:%.*]] = "tf.VarHandleOp"
-  %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
-  %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
+  %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<!tf.resource<tensor<f32>>>
+  %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<!tf.resource<tensor<f32>>>
 
   // CHECK: [[ACCUM:%.*]] = "tf.ReadVariableOp"([[ACCUM_HANDLE]])
   // CHECK: [[ACCUM_MOMENTUM:%.*]] = "tf.Mul"([[ACCUM]], [[MOMENTUM]])
@@ -136,7 +136,7 @@
   // CHECK: [[VAR:%.*]] = "tf.ReadVariableOp"([[VAR_HANDLE]])
   // CHECK: [[VAR_NEW:%.*]] = "tf.Sub"([[VAR]], [[DELTA]])
   // CHECK: "tf.AssignVariableOp"([[VAR_HANDLE]], [[VAR_NEW]])
-  "tf.ResourceApplyMomentum"(%0, %1, %arg0, %arg1, %arg2) {use_locking = false, use_nesterov = true} : (tensor<*x!tf.resource>, tensor<*x!tf.resource>, tensor<f32>, tensor<f32>, tensor<f32>) -> ()
+  "tf.ResourceApplyMomentum"(%0, %1, %arg0, %arg1, %arg2) {use_locking = false, use_nesterov = true} : (tensor<!tf.resource<tensor<f32>>>, tensor<!tf.resource<tensor<f32>>>, tensor<f32>, tensor<f32>, tensor<f32>) -> ()
   return
 }
 
@@ -151,10 +151,10 @@
 
   // CHECK: %[[VAR_HANDLE:[0-9]*]] = "tf.VarHandleOp"
   // CHECK: %[[ACCUM_HANDLE:[0-9]*]] = "tf.VarHandleOp"
-  %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
-  %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
+  %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<*xf32>>>
+  %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<*xf32>>>
 
-  // CHECK: %[[ACCUM:[0-9]*]] = "tf.ReadVariableOp"(%[[ACCUM_HANDLE]]) : (tensor<*x!tf.resource>) -> tensor<*xf32>
+  // CHECK: %[[ACCUM:[0-9]*]] = "tf.ReadVariableOp"(%[[ACCUM_HANDLE]]) : (tensor<*x!tf.resource<tensor<*xf32>>>) -> tensor<*xf32>
   // CHECK: %[[ACCUM_MOMENTUM:[0-9]*]] = "tf.Mul"(%[[ACCUM]], %[[MOMENTUM]])
   // CHECK: %[[GRAD_LR:[0-9]*]] = "tf.Mul"(%[[GRAD]], %[[LR]])
   // CHECK: %[[NEW_ACCUM:[0-9]*]] = "tf.Sub"(%[[ACCUM_MOMENTUM]], %[[GRAD_LR]])
@@ -164,7 +164,7 @@
   // CHECK: %[[NEW_VAR:[0-9]*]] = "tf.AddV2"(%[[VAR]], %[[NEW_ACCUM]])
   // CHECK: "tf.AssignVariableOp"(%[[VAR_HANDLE]], %[[NEW_VAR]])
 
-  "tf.ResourceApplyKerasMomentum"(%0, %1, %arg0, %arg1, %arg2) {use_locking = false, use_nesterov = false} : (tensor<*x!tf.resource>, tensor<*x!tf.resource>, tensor<f32>, tensor<f32>, tensor<f32>) -> ()
+  "tf.ResourceApplyKerasMomentum"(%0, %1, %arg0, %arg1, %arg2) {use_locking = false, use_nesterov = false} : (tensor<*x!tf.resource<tensor<*xf32>>>, tensor<*x!tf.resource<tensor<*xf32>>>, tensor<f32>, tensor<f32>, tensor<f32>) -> ()
 
   return
 }
@@ -180,10 +180,10 @@
 
   // CHECK: %[[VAR_HANDLE:[0-9]*]] = "tf.VarHandleOp"
   // CHECK: %[[ACCUM_HANDLE:[0-9]*]] = "tf.VarHandleOp"
-  %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
-  %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
+  %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<*xf32>>>
+  %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<*xf32>>>
 
-  // CHECK: %[[ACCUM:[0-9]*]] = "tf.ReadVariableOp"(%[[ACCUM_HANDLE]]) : (tensor<*x!tf.resource>) -> tensor<*xf32>
+  // CHECK: %[[ACCUM:[0-9]*]] = "tf.ReadVariableOp"(%[[ACCUM_HANDLE]]) : (tensor<*x!tf.resource<tensor<*xf32>>>) -> tensor<*xf32>
   // CHECK: %[[ACCUM_MOMENTUM:[0-9]*]] = "tf.Mul"(%[[ACCUM]], %[[MOMENTUM]])
   // CHECK: %[[GRAD_LR:[0-9]*]] = "tf.Mul"(%[[GRAD]], %[[LR]])
   // CHECK: %[[NEW_ACCUM:[0-9]*]] = "tf.Sub"(%[[ACCUM_MOMENTUM]], %[[GRAD_LR]])
@@ -195,7 +195,7 @@
   // CHECK: %[[NEW_VAR:[0-9]*]] = "tf.AddV2"(%[[VAR]], %[[NEW_DELTA]])
   // CHECK: "tf.AssignVariableOp"(%[[VAR_HANDLE]], %[[NEW_VAR]])
 
-  "tf.ResourceApplyKerasMomentum"(%0, %1, %arg0, %arg1, %arg2) {use_locking = false, use_nesterov = true} : (tensor<*x!tf.resource>, tensor<*x!tf.resource>, tensor<f32>, tensor<f32>, tensor<f32>) -> ()
+  "tf.ResourceApplyKerasMomentum"(%0, %1, %arg0, %arg1, %arg2) {use_locking = false, use_nesterov = true} : (tensor<*x!tf.resource<tensor<*xf32>>>, tensor<*x!tf.resource<tensor<*xf32>>>, tensor<f32>, tensor<f32>, tensor<f32>) -> ()
 
   return
 }
@@ -212,21 +212,21 @@
 // CHECK: [[VAR_HANDLE:%.*]] = "tf.VarHandleOp"()
 // CHECK: [[ACC_HANDLE:%.*]] = "tf.VarHandleOp"()
 // CHECK: [[GRAD_SQUARE:%.*]] = "tf.Mul"([[GRAD]], [[GRAD]]) : (tensor<f32>, tensor<f32>) -> tensor<f32>
-// CHECK: [[OLD_ACC:%.*]] = "tf.ReadVariableOp"([[ACC_HANDLE]]) : (tensor<*x!tf.resource>) -> tensor<*xf32>
+// CHECK: [[OLD_ACC:%.*]] = "tf.ReadVariableOp"([[ACC_HANDLE]]) : (tensor<*x!tf.resource<tensor<*xf32>>>) -> tensor<*xf32>
 // CHECK: [[NEW_ACC:%.*]] = "tf.AddV2"([[OLD_ACC]], [[GRAD_SQUARE]]) : (tensor<*xf32>, tensor<f32>) -> tensor<*xf32>
 // CHECK: [[LR_MULTIPLY:%.*]] = "tf.Mul"([[LR]], [[GRAD]]) : (tensor<f32>, tensor<f32>) -> tensor<f32>
 // CHECK: [[SQRT:%.*]] = "tf.Sqrt"([[NEW_ACC]]) : (tensor<*xf32>) -> tensor<*xf32>
 // CHECK: [[DIVISOR:%.*]] = "tf.AddV2"([[SQRT]], [[EPSILON]]) : (tensor<*xf32>, tensor<f32>) -> tensor<*xf32>
 // CHECK: [[VAR_DELTA:%.*]] = "tf.Div"([[LR_MULTIPLY]], [[DIVISOR]]) : (tensor<f32>, tensor<*xf32>) -> tensor<*xf32>
-// CHECK: [[OLD_VAR:%.*]] = "tf.ReadVariableOp"([[VAR_HANDLE]]) : (tensor<*x!tf.resource>) -> tensor<*xf32>
+// CHECK: [[OLD_VAR:%.*]] = "tf.ReadVariableOp"([[VAR_HANDLE]]) : (tensor<*x!tf.resource<tensor<*xf32>>>) -> tensor<*xf32>
 // CHECK: [[NEW_VAR:%.*]] = "tf.Sub"(%9, %8) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
-// CHECK: "tf.AssignVariableOp"([[VAR_HANDLE]], [[NEW_VAR]]) : (tensor<*x!tf.resource>, tensor<*xf32>) -> ()
-// CHECK: "tf.AssignVariableOp"([[ACC_HANDLE]], [[NEW_ACC]]) : (tensor<*x!tf.resource>, tensor<*xf32>) -> ()
+// CHECK: "tf.AssignVariableOp"([[VAR_HANDLE]], [[NEW_VAR]]) : (tensor<*x!tf.resource<tensor<*xf32>>>, tensor<*xf32>) -> ()
+// CHECK: "tf.AssignVariableOp"([[ACC_HANDLE]], [[NEW_ACC]]) : (tensor<*x!tf.resource<tensor<*xf32>>>, tensor<*xf32>) -> ()
 
-  %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
-  %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
+  %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<*xf32>>>
+  %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<*xf32>>>
 
-  "tf.ResourceApplyAdagradV2"(%0, %1, %arg0, %arg1, %arg2) {update_slots = true, use_locking = true} : (tensor<*x!tf.resource>, tensor<*x!tf.resource>, tensor<f32>, tensor<f32>, tensor<f32>) -> ()
+  "tf.ResourceApplyAdagradV2"(%0, %1, %arg0, %arg1, %arg2) {update_slots = true, use_locking = true} : (tensor<*x!tf.resource<tensor<*xf32>>>, tensor<*x!tf.resource<tensor<*xf32>>>, tensor<f32>, tensor<f32>, tensor<f32>) -> ()
 
   return
 }
@@ -236,22 +236,22 @@
 // CHECK-SAME:  (%[[LR:.*]]: tensor<f32>, %[[GRAD:.*]]: tensor<f32>)
 func @decompose_resource_apply_adagrad(%arg0: tensor<f32>, %arg1: tensor<f32>) -> () {
 
-  // CHECK: %[[VAR_HANDLE:.*]] = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
-  // CHECK: %[[ACCUM_HANDLE:.*]] = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
+  // CHECK: %[[VAR_HANDLE:.*]] = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<*xf32>>>
+  // CHECK: %[[ACCUM_HANDLE:.*]] = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<*xf32>>>
   // CHECK: %[[GRAD_SQUARE:.*]] = "tf.Mul"(%[[GRAD]], %[[GRAD]]) : (tensor<f32>, tensor<f32>) -> tensor<f32>
-  // CHECK: %[[ACCUM:.*]] = "tf.ReadVariableOp"(%[[ACCUM_HANDLE]]) : (tensor<*x!tf.resource>) -> tensor<*xf32>
+  // CHECK: %[[ACCUM:.*]] = "tf.ReadVariableOp"(%[[ACCUM_HANDLE]]) : (tensor<*x!tf.resource<tensor<*xf32>>>) -> tensor<*xf32>
   // CHECK: %[[ACCUM_NEW:.*]] = "tf.AddV2"(%[[ACCUM]], %[[GRAD_SQUARE]]) : (tensor<*xf32>, tensor<f32>) -> tensor<*xf32>
   // CHECK: %[[LR_MULTIPLY:.*]] = "tf.Mul"(%[[LR]], %[[GRAD]]) : (tensor<f32>, tensor<f32>) -> tensor<f32>
   // CHECK: %[[SQRT:.*]] = "tf.Sqrt"(%[[ACCUM_NEW]]) : (tensor<*xf32>) -> tensor<*xf32>
   // CHECK: %[[DIV:.*]] = "tf.Div"(%[[LR_MULTIPLY]], %[[SQRT]]) : (tensor<f32>, tensor<*xf32>) -> tensor<*xf32>
-  // CHECK: %[[VAR:.*]] = "tf.ReadVariableOp"(%[[VAR_HANDLE]]) : (tensor<*x!tf.resource>) -> tensor<*xf32>
+  // CHECK: %[[VAR:.*]] = "tf.ReadVariableOp"(%[[VAR_HANDLE]]) : (tensor<*x!tf.resource<tensor<*xf32>>>) -> tensor<*xf32>
   // CHECK: %[[VAR_NEW:.*]] = "tf.Sub"(%[[VAR]], %[[DIV]]) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
-  // CHECK: "tf.AssignVariableOp"(%[[VAR_HANDLE]], %[[VAR_NEW]]) : (tensor<*x!tf.resource>, tensor<*xf32>) -> ()
-  // CHECK: "tf.AssignVariableOp"(%[[ACCUM_HANDLE]], %[[ACCUM_NEW]]) : (tensor<*x!tf.resource>, tensor<*xf32>) -> ()
-  %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
-  %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
+  // CHECK: "tf.AssignVariableOp"(%[[VAR_HANDLE]], %[[VAR_NEW]]) : (tensor<*x!tf.resource<tensor<*xf32>>>, tensor<*xf32>) -> ()
+  // CHECK: "tf.AssignVariableOp"(%[[ACCUM_HANDLE]], %[[ACCUM_NEW]]) : (tensor<*x!tf.resource<tensor<*xf32>>>, tensor<*xf32>) -> ()
+  %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<*xf32>>>
+  %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<*xf32>>>
 
-  "tf.ResourceApplyAdagrad"(%0, %1, %arg0, %arg1) {update_slots = true, use_locking = true} : (tensor<*x!tf.resource>, tensor<*x!tf.resource>, tensor<f32>, tensor<f32>) -> ()
+  "tf.ResourceApplyAdagrad"(%0, %1, %arg0, %arg1) {update_slots = true, use_locking = true} : (tensor<*x!tf.resource<tensor<*xf32>>>, tensor<*x!tf.resource<tensor<*xf32>>>, tensor<f32>, tensor<f32>) -> ()
 
   return
 }
@@ -274,12 +274,12 @@
 // CHECK: [[ONE_MINUS_BETA1_POWER:%.*]] = "tf.Sub"([[ONE]], [[BETA1_POWER]])
 // CHECK: [[ALPHA_NO_LR:%.*]] = "tf.Div"([[SQRT_ONE_MINUS_BETA2_POWER]], [[ONE_MINUS_BETA1_POWER]])
 // CHECK: [[ALPHA:%.*]] = "tf.Mul"([[LR]], [[ALPHA_NO_LR]])
-// CHECK: [[OLD_M:%.*]] = "tf.ReadVariableOp"([[M_HANDLE]]) : (tensor<*x!tf.resource>) -> tensor<*xf32>
+// CHECK: [[OLD_M:%.*]] = "tf.ReadVariableOp"([[M_HANDLE]]) : (tensor<*x!tf.resource<tensor<*xf32>>>) -> tensor<*xf32>
 // CHECK: [[BETA1_OLD_M:%.*]] = "tf.Mul"([[BETA1]], [[OLD_M]])
 // CHECK: [[ONE_MINUS_BETA1:%.*]] = "tf.Sub"([[ONE]], [[BETA1]])
 // CHECK: [[ONE_MINUS_BETA1_GRAD:%.*]] = "tf.Mul"([[ONE_MINUS_BETA1]], [[GRAD]])
 // CHECK: [[NEW_M:%.*]] = "tf.AddV2"([[BETA1_OLD_M]], [[ONE_MINUS_BETA1_GRAD]])
-// CHECK: [[OLD_V:%.*]] = "tf.ReadVariableOp"([[V_HANDLE]]) : (tensor<*x!tf.resource>) -> tensor<*xf32>
+// CHECK: [[OLD_V:%.*]] = "tf.ReadVariableOp"([[V_HANDLE]]) : (tensor<*x!tf.resource<tensor<*xf32>>>) -> tensor<*xf32>
 // CHECK: [[BETA2_OLD_V:%.*]] = "tf.Mul"([[BETA2]], [[OLD_V]])
 // CHECK: [[ONE_MINUS_BETA2:%.*]] = "tf.Sub"([[ONE]], [[BETA2]])
 // CHECK: [[GRAD_SQUARE:%.*]] = "tf.Square"([[GRAD]])
@@ -289,17 +289,17 @@
 // CHECK: [[SQRT_NEW_V:%.*]] = "tf.Sqrt"([[NEW_V]])
 // CHECK: [[SQRT_NEW_V_EPSILON:%.*]] = "tf.AddV2"([[SQRT_NEW_V]], [[EPSILON]])
 // CHECK: [[VAR_DELTA:%.*]] = "tf.Div"([[ALPHA_NEW_M]], [[SQRT_NEW_V_EPSILON]])
-// CHECK: [[OLD_VAR:%.*]] = "tf.ReadVariableOp"([[VAR_HANDLE]]) : (tensor<*x!tf.resource>) -> tensor<*xf32>
+// CHECK: [[OLD_VAR:%.*]] = "tf.ReadVariableOp"([[VAR_HANDLE]]) : (tensor<*x!tf.resource<tensor<*xf32>>>) -> tensor<*xf32>
 // CHECK: [[NEW_VAR:%.*]] = "tf.Sub"([[OLD_VAR]], [[VAR_DELTA]])
 // CHECK: "tf.AssignVariableOp"([[VAR_HANDLE]], [[NEW_VAR]])
 // CHECK: "tf.AssignVariableOp"([[M_HANDLE]], [[NEW_M]])
 // CHECK: "tf.AssignVariableOp"([[V_HANDLE]], [[NEW_V]])
 
-  %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
-  %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
-  %2 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
+  %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<*xf32>>>
+  %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<*xf32>>>
+  %2 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<*xf32>>>
 
-  "tf.ResourceApplyAdam"(%0, %1, %2, %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) {use_locking = false, use_nesterov = false} : (tensor<*x!tf.resource>, tensor<*x!tf.resource>, tensor<*x!tf.resource>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>) -> ()
+  "tf.ResourceApplyAdam"(%0, %1, %2, %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) {use_locking = false, use_nesterov = false} : (tensor<*x!tf.resource<tensor<*xf32>>>, tensor<*x!tf.resource<tensor<*xf32>>>, tensor<*x!tf.resource<tensor<*xf32>>>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>) -> ()
 
   return
 }
@@ -322,12 +322,12 @@
 // CHECK: [[VAL_84:%.*]] = "tf.Sub"([[ONE]], [[BETA1_POWER]])
 // CHECK: [[VAL_85:%.*]] = "tf.Div"([[VAL_83]], [[VAL_84]])
 // CHECK: [[VAL_86:%.*]] = "tf.Mul"([[LR]], [[VAL_85]])
-// CHECK: [[OLD_M:%.*]] = "tf.ReadVariableOp"([[M_HANDLE]]) : (tensor<*x!tf.resource>) -> tensor<*xf32>
+// CHECK: [[OLD_M:%.*]] = "tf.ReadVariableOp"([[M_HANDLE]]) : (tensor<*x!tf.resource<tensor<*xf32>>>) -> tensor<*xf32>
 // CHECK: [[VAL_88:%.*]] = "tf.Mul"([[BETA1]], [[OLD_M]])
 // CHECK: [[VAL_89:%.*]] = "tf.Sub"([[ONE]], [[BETA1]])
 // CHECK: [[VAL_90:%.*]] = "tf.Mul"([[VAL_89]], [[GRAD]])
 // CHECK: [[NEW_M:%.*]] = "tf.AddV2"([[VAL_88]], [[VAL_90]])
-// CHECK: [[OLD_V:%.*]] = "tf.ReadVariableOp"([[V_HANDLE]]) : (tensor<*x!tf.resource>) -> tensor<*xf32>
+// CHECK: [[OLD_V:%.*]] = "tf.ReadVariableOp"([[V_HANDLE]]) : (tensor<*x!tf.resource<tensor<*xf32>>>) -> tensor<*xf32>
 // CHECK: [[VAL_93:%.*]] = "tf.Mul"([[BETA2]], [[OLD_V]])
 // CHECK: [[VAL_94:%.*]] = "tf.Sub"([[ONE]], [[BETA2]])
 // CHECK: [[VAL_95:%.*]] = "tf.Square"([[GRAD]])
@@ -341,17 +341,17 @@
 // CHECK: [[VAL_103:%.*]] = "tf.Sqrt"([[NEW_V]])
 // CHECK: [[VAL_104:%.*]] = "tf.AddV2"([[VAL_103]], [[EPSILON]])
 // CHECK: [[VAL_105:%.*]] = "tf.Div"([[VAL_102]], [[VAL_104]])
-// CHECK: [[OLD_VAR:%.*]] = "tf.ReadVariableOp"([[VAR_HANDLE]]) : (tensor<*x!tf.resource>) -> tensor<*xf32>
+// CHECK: [[OLD_VAR:%.*]] = "tf.ReadVariableOp"([[VAR_HANDLE]]) : (tensor<*x!tf.resource<tensor<*xf32>>>) -> tensor<*xf32>
 // CHECK: [[NEW_VAR:%.*]] = "tf.Sub"([[OLD_VAR]], [[VAL_105]])
-// CHECK: "tf.AssignVariableOp"([[VAR_HANDLE]], [[NEW_VAR]]) : (tensor<*x!tf.resource>, tensor<*xf32>) -> ()
-// CHECK: "tf.AssignVariableOp"([[M_HANDLE]], [[NEW_M]]) : (tensor<*x!tf.resource>, tensor<*xf32>) -> ()
-// CHECK: "tf.AssignVariableOp"([[V_HANDLE]], [[NEW_V]]) : (tensor<*x!tf.resource>, tensor<*xf32>) -> ()
+// CHECK: "tf.AssignVariableOp"([[VAR_HANDLE]], [[NEW_VAR]]) : (tensor<*x!tf.resource<tensor<*xf32>>>, tensor<*xf32>) -> ()
+// CHECK: "tf.AssignVariableOp"([[M_HANDLE]], [[NEW_M]]) : (tensor<*x!tf.resource<tensor<*xf32>>>, tensor<*xf32>) -> ()
+// CHECK: "tf.AssignVariableOp"([[V_HANDLE]], [[NEW_V]]) : (tensor<*x!tf.resource<tensor<*xf32>>>, tensor<*xf32>) -> ()
 
-  %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
-  %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
-  %2 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
+  %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<*xf32>>>
+  %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<*xf32>>>
+  %2 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<*xf32>>>
 
-  "tf.ResourceApplyAdam"(%0, %1, %2, %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) {use_locking = false, use_nesterov = true} : (tensor<*x!tf.resource>, tensor<*x!tf.resource>, tensor<*x!tf.resource>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>) -> ()
+  "tf.ResourceApplyAdam"(%0, %1, %2, %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) {use_locking = false, use_nesterov = true} : (tensor<*x!tf.resource<tensor<*xf32>>>, tensor<*x!tf.resource<tensor<*xf32>>>, tensor<*x!tf.resource<tensor<*xf32>>>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>) -> ()
 
   return
 }
@@ -366,12 +366,12 @@
   // CHECK: [[ZERO:%.+]] = "tf.Const"() {value = dense<0> : tensor<i64>}
 
   // CHECK: [[VAR:%.+]] = "tf.VarHandleOp"
-  %resource = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
+  %resource = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<*xi32>>>
 
   // CHECK: [[READVAR:%.+]] = "tf.ReadVariableOp"([[VAR]])
   // CHECK: [[GATHER:%.+]] = "tf.GatherV2"([[READVAR]], [[INDEX]], [[ZERO]]) {batch_dims = 0 : i64} : (tensor<*xi32>, tensor<?xi32>, tensor<i64>) -> tensor<*xi32>
   // CHECK: return [[GATHER]]
-  %0 = "tf.ResourceGather"(%resource, %indices) : (tensor<*x!tf.resource>, tensor<?xi32>) -> (tensor<*xi32>)
+  %0 = "tf.ResourceGather"(%resource, %indices) : (tensor<*x!tf.resource<tensor<*xi32>>>, tensor<?xi32>) -> (tensor<*xi32>)
 
   return %0: tensor<*xi32>
 }
@@ -403,10 +403,10 @@
   // CHECK: [[MG_HANDLE:%.*]] = "tf.VarHandleOp"
   // CHECK: [[MS_HANDLE:%.*]] = "tf.VarHandleOp"
   // CHECK: [[MOM_HANDLE:%.*]] = "tf.VarHandleOp"
-  %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
-  %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
-  %2 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
-  %3 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
+  %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<f32>>>
+  %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<f32>>>
+  %2 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<f32>>>
+  %3 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<f32>>>
 
   // CHECK: [[GRADSQ:%.*]] = "tf.Mul"([[GRAD]], [[GRAD]])
   // CHECK: [[SB:%.*]] = "tf.Sub"([[ONE]], [[RHO]])
@@ -438,7 +438,7 @@
   // CHECK: [[VAR_NEW:%.*]] = "tf.Sub"([[VAR]], [[MOM_NEW]])
   // CHECK: "tf.AssignVariableOp"([[VAR_HANDLE]], [[VAR_NEW]])
 
-  "tf.ResourceApplyCenteredRMSProp"(%0, %1, %2, %3, %arg4, %arg5, %arg6, %arg7, %arg8) {use_locking = false} : (tensor<*x!tf.resource>, tensor<*x!tf.resource>, tensor<*x!tf.resource>, tensor<*x!tf.resource>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>) -> ()
+  "tf.ResourceApplyCenteredRMSProp"(%0, %1, %2, %3, %arg4, %arg5, %arg6, %arg7, %arg8) {use_locking = false} : (tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>) -> ()
   return
 }
 // -----
@@ -477,12 +477,12 @@
 // CHECK-SAME: ([[INDEX:%.+]]: tensor<2x?xi32>, [[UPDATE:%.+]]: tensor<?x?x?xi32>)
 func @decompose_resource_scatter_update_op(%indices : tensor<2x?xi32>, %updates: tensor<?x?x?xi32>) {
   // CHECK: [[VAR:%.+]] = "tf.VarHandleOp"
-  %resource = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
+  %resource = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<*xi32>>>
 
   // CHECK: [[READ:%.+]] = "tf.ReadVariableOp"([[VAR]])
   // CHECK: [[TENSOR:%.+]] = "tf.TensorScatterUpdate"([[READ]], [[INDEX]], [[UPDATE]]) : (tensor<*xi32>, tensor<2x?xi32>, tensor<?x?x?xi32>) -> tensor<*xi32>
   // CHECK: "tf.AssignVariableOp"([[VAR]], [[TENSOR]])
-  "tf.ResourceScatterUpdate"(%resource, %indices, %updates) : (tensor<*x!tf.resource>, tensor<2x?xi32>, tensor<?x?x?xi32>) -> ()
+  "tf.ResourceScatterUpdate"(%resource, %indices, %updates) : (tensor<*x!tf.resource<tensor<*xi32>>>, tensor<2x?xi32>, tensor<?x?x?xi32>) -> ()
 
   return
 }
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/promote_resources_to_args.mlir b/tensorflow/compiler/mlir/tensorflow/tests/promote_resources_to_args.mlir
index b18b03b..de9ef7b 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/promote_resources_to_args.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/promote_resources_to_args.mlir
@@ -340,7 +340,7 @@
 // Tests main function with invalid VarHandleOp resource subtype.
 
 func @main() {
-  // expected-error@+1 {{expects resource type to have one subtype, got '!tf.resource'}}
+  // expected-error @+1 {{must have exactly one subtype in the result resource type}}
   %0 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor<!tf.resource>
   return
 }
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/promote_var_handles_to_args.mlir b/tensorflow/compiler/mlir/tensorflow/tests/promote_var_handles_to_args.mlir
index f2c045a..903d57e 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/promote_var_handles_to_args.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/promote_var_handles_to_args.mlir
@@ -12,18 +12,18 @@
 // -----
 
 // CHECK-LABEL: func @no_args
-// CHECK-SAME: (%arg0: tensor<!tf.resource> {tf.resource_name = "x"})
+// CHECK-SAME: (%arg0: tensor<!tf.resource<tensor<f32>>> {tf.resource_name = "x"})
 // CHECK-NOT: "tf.VarHandleOp"
 func @no_args() {
-  %0 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor<!tf.resource>
+  %0 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor<!tf.resource<tensor<f32>>>
   return
 }
 
 // CHECK-LABEL: func @some_args
-// CHECK-SAME: (%arg0: tensor<i1>, %arg1: tensor<!tf.resource> {tf.resource_name = "x"})
+// CHECK-SAME: (%arg0: tensor<i1>, %arg1: tensor<!tf.resource<tensor<f32>>> {tf.resource_name = "x"})
 // CHECK-NOT: "tf.VarHandleOp"
 func @some_args(%arg0: tensor<i1>) {
-  %0 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor<!tf.resource>
+  %0 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor<!tf.resource<tensor<f32>>>
   return
 }
 
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir b/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir
index b4df968..5ad8336 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir
@@ -6,7 +6,7 @@
 func @only_resource_load() -> tensor<*xi32> {
 
   // CHECK: %[[RES_HANDLE:[0-9]*]] = "tf.VarHandleOp"
-  %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
+  %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<*xi32>>>
 
   // CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]])
   // CHECK: "tf_device.cluster"
@@ -16,7 +16,7 @@
   // CHECK-SAME: () -> tensor<*xi32>
 
   %1 = "tf_device.cluster"() ( {
-    %2 = "tf.ReadVariableOp"(%0) {dtype = i32} : (tensor<*x!tf.resource>) -> tensor<*xi32>
+    %2 = "tf.ReadVariableOp"(%0) {dtype = i32} : (tensor<*x!tf.resource<tensor<*xi32>>>) -> tensor<*xi32>
     %3 = "tf.SomeComputation"(%2) : (tensor<*xi32>) -> (tensor<*xi32>)
     tf_device.return %3 : tensor<*xi32>
   }) {cluster_attr = "cluster_attr"} : () -> tensor<*xi32>
@@ -32,7 +32,7 @@
 func @only_resource_store() -> tensor<*xi32> {
 
   // CHECK: %[[RES_HANDLE:[0-9]*]] = "tf.VarHandleOp"
-  %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
+  %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<*xi32>>>
 
   // CHECK: %[[CLUSTER_RES:[0-9]*]]:2 = "tf_device.cluster"
   // CHECK: %[[COMPUTE_RES:[0-9]*]] = "tf.SomeComputation"()
@@ -43,7 +43,7 @@
 
   %1 = "tf_device.cluster"() ( {
     %2 = "tf.SomeComputation"() : () -> (tensor<*xi32>)
-    "tf.AssignVariableOp"(%0, %2) {dtype = i32} : (tensor<*x!tf.resource>, tensor<*xi32>) -> ()
+    "tf.AssignVariableOp"(%0, %2) {dtype = i32} : (tensor<*x!tf.resource<tensor<*xi32>>>, tensor<*xi32>) -> ()
     tf_device.return %2 : tensor<*xi32>
   }) {cluster_attr = "cluster_attr"} : () -> tensor<*xi32>
 
@@ -59,7 +59,7 @@
 func @same_resource_load_and_store() -> tensor<*xi32> {
 
   // CHECK: %[[RES_HANDLE:[0-9]*]] = "tf.VarHandleOp"
-  %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
+  %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<*xi32>>>
 
   // CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]])
   // CHECK: %[[CLUSTER_RES:[0-9]*]]:2 = "tf_device.cluster"
@@ -70,9 +70,9 @@
   // CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[CLUSTER_RES]]#1)
 
   %1 = "tf_device.cluster"() ( {
-    %2 = "tf.ReadVariableOp"(%0) {dtype = i32} : (tensor<*x!tf.resource>) -> tensor<*xi32>
+    %2 = "tf.ReadVariableOp"(%0) {dtype = i32} : (tensor<*x!tf.resource<tensor<*xi32>>>) -> tensor<*xi32>
     %3 = "tf.SomeComputation"(%2) : (tensor<*xi32>) -> (tensor<*xi32>)
-    "tf.AssignVariableOp"(%0, %3) {dtype = i32} : (tensor<*x!tf.resource>, tensor<*xi32>) -> ()
+    "tf.AssignVariableOp"(%0, %3) {dtype = i32} : (tensor<*x!tf.resource<tensor<*xi32>>>, tensor<*xi32>) -> ()
     tf_device.return %3 : tensor<*xi32>
   }) {cluster_attr = "cluster_attr"} : () -> tensor<*xi32>
 
@@ -89,7 +89,7 @@
 func @same_resource_load_and_store_cast() -> tensor<1xi32> {
 
   // CHECK: %[[RES_HANDLE:[0-9]*]] = "tf.VarHandleOp"
-  %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
+  %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<*xi32>>>
 
   // CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]])
   // CHECK: %[[CLUSTER_RES:[0-9]*]]:2 = "tf_device.cluster"
@@ -101,10 +101,10 @@
   // CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[CLUSTER_RES]]#1)
 
   %1 = "tf_device.cluster"() ( {
-    %2 = "tf.ReadVariableOp"(%0) {dtype = i32} : (tensor<*x!tf.resource>) -> tensor<1xi32>
+    %2 = "tf.ReadVariableOp"(%0) {dtype = i32} : (tensor<*x!tf.resource<tensor<*xi32>>>) -> tensor<1xi32>
     %3 = "tf.SomeComputation"(%2) : (tensor<1xi32>) -> (tensor<*xi32>)
-    "tf.AssignVariableOp"(%0, %3) {dtype = i32} : (tensor<*x!tf.resource>, tensor<*xi32>) -> ()
-    %4 = "tf.ReadVariableOp"(%0) {dtype = i32} : (tensor<*x!tf.resource>) -> tensor<1xi32>
+    "tf.AssignVariableOp"(%0, %3) {dtype = i32} : (tensor<*x!tf.resource<tensor<*xi32>>>, tensor<*xi32>) -> ()
+    %4 = "tf.ReadVariableOp"(%0) {dtype = i32} : (tensor<*x!tf.resource<tensor<*xi32>>>) -> tensor<1xi32>
     tf_device.return %4 : tensor<1xi32>
   }) {cluster_attr = "cluster_attr"} : () -> tensor<1xi32>
 
@@ -123,16 +123,16 @@
   %0 = "tf_device.cluster"() ( {
 
     // CHECK: %[[RES_HANDLE:[0-9]*]] = "tf.VarHandleOp"
-    %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
+    %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<*xi32>>>
 
     // CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]])
-    %2 = "tf.ReadVariableOp"(%1) {dtype = i32} : (tensor<*x!tf.resource>) -> tensor<*xi32>
+    %2 = "tf.ReadVariableOp"(%1) {dtype = i32} : (tensor<*x!tf.resource<tensor<*xi32>>>) -> tensor<*xi32>
 
     // CHECK: %[[COMPUTE_RES:[0-9]*]] = "tf.SomeComputation"(%[[RES_READ_VAL]])
     %3 = "tf.SomeComputation"(%2) : (tensor<*xi32>) -> (tensor<*xi32>)
 
     // CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[COMPUTE_RES]])
-    "tf.AssignVariableOp"(%1, %3) {dtype = i32} : (tensor<*x!tf.resource>, tensor<*xi32>) -> ()
+    "tf.AssignVariableOp"(%1, %3) {dtype = i32} : (tensor<*x!tf.resource<tensor<*xi32>>>, tensor<*xi32>) -> ()
 
     // CHECK: tf_device.return %[[COMPUTE_RES]]
     tf_device.return %3 : tensor<*xi32>
@@ -1006,10 +1006,10 @@
   // CHECK: tf_device.return
   // CHECK: {cluster_attr = "cluster_attr"}
   // CHECK: return
-  %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
+  %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<*xi32>>>
   %1 = "tf_device.cluster"() ( {
-    %2 = "tf.ReadVariableOp"(%0) {dtype = i32} : (tensor<*x!tf.resource>) -> tensor<*xi32>
-    "tf.SomeResourceOperation"(%0) : (tensor<*x!tf.resource>) -> ()
+    %2 = "tf.ReadVariableOp"(%0) {dtype = i32} : (tensor<*x!tf.resource<tensor<*xi32>>>) -> tensor<*xi32>
+    "tf.SomeResourceOperation"(%0) : (tensor<*x!tf.resource<tensor<*xi32>>>) -> ()
     %3 = "tf.SomeComputation"(%2) : (tensor<*xi32>) -> (tensor<*xi32>)
     tf_device.return %3 : tensor<*xi32>
   }) {cluster_attr = "cluster_attr"} : () -> tensor<*xi32>
@@ -1032,12 +1032,12 @@
   // CHECK-SAME: else_branch = @else_fn, is_stateless = true, then_branch = @then_fn
   // CHECK: tf_device.return
   // CHECK: return
-  %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
-  %1 = "tf.VarHandleOp"() {container = "d", shared_name = "w"} : () -> tensor<*x!tf.resource>
+  %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<*xi32>>>
+  %1 = "tf.VarHandleOp"() {container = "d", shared_name = "w"} : () -> tensor<*x!tf.resource<tensor<*xi32>>>
   %2 = "tf_device.cluster"() ( {
     %3 = "tf.If"(%arg0, %0, %1)
           { else_branch = @else_fn, then_branch = @then_fn, is_stateless = true}
-          : (tensor<i1>, tensor<*x!tf.resource>, tensor<*x!tf.resource>) -> tensor<*xi32>
+          : (tensor<i1>, tensor<*x!tf.resource<tensor<*xi32>>>, tensor<*x!tf.resource<tensor<*xi32>>>) -> tensor<*xi32>
     tf_device.return %3 : tensor<*xi32>
   }) {cluster_attr = "cluster_attr"} : () -> tensor<*xi32>
   return %2 : tensor<*xi32>
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir
index fddb45f..1b1ae77 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir
@@ -4135,3 +4135,27 @@
   "tf.TPUExecuteAndUpdateVariables"(%arg0, %arg1) {device_var_reads_indices = [0], device_var_updates_indices = [-2]} : (tensor<!tf.resource<tensor<i32>>>, tensor<3x!tf.string>) -> ()
   return
 }
+
+// -----
+
+// Valid VarHandleOp operation.
+// CHECK-LABEL: func @testVarHandleOp
+func @testVarHandleOp() -> tensor<!tf.resource<tensor<*xf32>>> {
+  %0 = "tf.VarHandleOp"() {
+    container = "",
+    shared_name = "cd2c89b7-88b7-44c8-ad83-06c2a9158347"
+  } : () -> tensor<!tf.resource<tensor<*xf32>>>
+  return %0 : tensor<!tf.resource<tensor<*xf32>>>
+}
+
+// -----
+
+// VarHandleOp operation missing the required resource subtype.
+func @testVarHandleOp() -> tensor<*x!tf.resource> {
+  // expected-error @+1 {{must have exactly one subtype in the result resource type}}
+  %0 = "tf.VarHandleOp"() {
+    container = "",
+    shared_name = "cd2c89b7-88b7-44c8-ad83-06c2a9158347"
+  } : () -> tensor<*x!tf.resource>
+  return %0 : tensor<*x!tf.resource>
+}
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu-dynamic-layout-pass.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu-dynamic-layout-pass.mlir
index 7b670cd..fb82bbe 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tpu-dynamic-layout-pass.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu-dynamic-layout-pass.mlir
@@ -184,11 +184,11 @@
       mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
     tf_device.return %1#0, %1#1 : tensor<!tf.string>, tensor<2x!tf.string>
   }) {device = "/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
-  %var = "tf.VarHandleOp"() {container = "c", shared_name = "v", device = "/device:TPU:0"} : () -> tensor<*x!tf.resource>
+  %var = "tf.VarHandleOp"() {container = "c", shared_name = "v", device = "/device:TPU:0"} : () -> tensor<!tf.resource<tensor<3x3x1x32xf32>>>
   // CHECK-NOT: "tf.TPUGetLayoutOp"
   // CHECK-NOT: "tf.TPUCopyWithLayout"
   %2:2 = "tf.IteratorGetNext"(%var) {device = "/device:CPU:0"}
-    : (tensor<*x!tf.resource>) -> (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>)
+    : (tensor<!tf.resource<tensor<3x3x1x32xf32>>>) -> (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>)
   "tf_device.launch"() ( {
     "tf.TPUCompileSucceededAssert"(%compile#0) : (tensor<!tf.string>) -> ()
     tf_device.return
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc b/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc
index 0a23912..08dfd31 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc
@@ -103,16 +103,9 @@
   return composite_users;
 }
 
-// Checks if `tf.VarHandleOp` has a valid resource subtype and its users are of
-// `tf.ReadVariableOp` and `tf.AssignVariableOp` only.
+// Checks that the only users of `tf.VarHandleOp` are
+// `tf.ReadVariableOp` and `tf.AssignVariableOp`.
 mlir::LogicalResult ValidateVarHandle(TF::VarHandleOp var_handle_op) {
-  auto resource_type =
-      getElementTypeOrSelf(var_handle_op.getType()).cast<TF::ResourceType>();
-  if (resource_type.getSubtypes().size() != 1)
-    return var_handle_op.emitOpError()
-           << "expects resource type to have one subtype, got "
-           << resource_type;
-
   auto composite_ops = GetCompositeResourceUserNames(var_handle_op);
   if (!composite_ops.empty())
     return var_handle_op.emitOpError()
diff --git a/tensorflow/compiler/mlir/tfr/integration/node_expansion_pass.cc b/tensorflow/compiler/mlir/tfr/integration/node_expansion_pass.cc
index f3e8780..54ffc57 100644
--- a/tensorflow/compiler/mlir/tfr/integration/node_expansion_pass.cc
+++ b/tensorflow/compiler/mlir/tfr/integration/node_expansion_pass.cc
@@ -45,7 +45,7 @@
   // isn't a composite op. The following ops are explicitly skipped here because
   // their "no-op" expansion is known to cause problems in some cases.
   static const char* kOpsToSkip[] = {"IdentityOp", "NoOp", "OptionalHasValue",
-                                     "OptionalGetValue", "VarHandleOp"};
+                                     "OptionalGetValue"};
   for (const char* skip : kOpsToSkip) {
     if (absl::StartsWith(orig_op->op_name(), skip)) return Status::OK();
   }