Verify resource subtype is supplied for tf.VarHandleOp.
If no resource subtype is supplied, we can't derive the dtype attribute.
Verify that a resource subtype is supplied so we get a verification failure
instead of an assert failure if you try to access the dtype attribute for a
tf.VarHandleOp of type tensor<*x!tf.resource>.
PiperOrigin-RevId: 350805993
Change-Id: I049fb7e292993f6e760bb850a01337a3a8c09452
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td
index a0f4baf..39a45f1 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td
@@ -568,26 +568,6 @@
"return (*getOperation()->result_type_begin()).cast<ShapedType>();",
[{ mlir::TF::ShapeAttr::get($_ctx, $_self) }]>;
-// A derived attribute that returns the element type of the tensor held by a
-// named resource-type operand or result.
-class TF_DerivedOperandOrResultHandleTypeAttr<string name> : DerivedTypeAttr<
- "auto resource_type =\n"
- " mlir::getElementTypeOrSelf(this->" # name # "())\n"
- " .cast<TF::ResourceType>();\n"
- "assert(!resource_type.getSubtypes().empty() && \"unknown type\");\n"
- "return mlir::getElementTypeOrSelf(*resource_type.getSubtypes().begin());">;
-
-// A derived attribute that returns the shape of the tensor held by a named
-// resource-type operand or result.
-class TF_DerivedOperandOrResultHandleShapeAttr<string name> : DerivedAttr<
- "ShapedType",
- "auto resource_type =\n"
- " mlir::getElementTypeOrSelf(this->" # name # "())\n"
- " .cast<TF::ResourceType>();\n"
- "assert(!resource_type.getSubtypes().empty() && \"unknown shape\");\n"
- "return resource_type.getSubtypes().begin()->cast<ShapedType>();",
- [{ mlir::TF::ShapeAttr::get($_ctx, $_self) }]>;
-
def TF_IntTypeAttr : TypeAttrBase<"IntegerType", "integer type"> {
let returnType = "Type";
}
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td
index 4615064..93792ab 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td
@@ -865,16 +865,35 @@
Res<TF_ResourceTensor, "", [TF_VariableAlloc]>:$resource
);
- TF_DerivedOperandOrResultHandleTypeAttr dtype =
- TF_DerivedOperandOrResultHandleTypeAttr<"resource">;
- TF_DerivedOperandOrResultHandleShapeAttr shape =
- TF_DerivedOperandOrResultHandleShapeAttr<"resource">;
+ let verifier = [{
+ // VarHandleOp requires the resource handle supply a single subtype from
+ // which to derive the dtype and shape attributes.
+ if (resource_type().getSubtypes().size() != 1) {
+ return emitOpError(
+ "must have exactly one subtype in the result resource type");
+ }
+
+ return success();
+ }];
+
+ DerivedTypeAttr dtype = DerivedTypeAttr<
+ "return getElementTypeOrSelf(resource_subtype());">;
+ DerivedAttr shape = DerivedAttr<
+ "ShapedType",
+ "return resource_subtype().cast<ShapedType>();",
+ [{ mlir::TF::ShapeAttr::get($_ctx, $_self) }]>;
let extraClassDeclaration = [{
// TF_ResourceHandleAllocatorInterface:
ResourceHandleValueAndId GetResourceHandleValueAndId(
llvm::SmallDenseMap<ResourceHandle, int64_t> &resource_handle_id_map,
int64_t &next_id);
+
+ TensorType resource_subtype() { return resource_type().getSubtypes()[0]; }
+
+ ResourceType resource_type() {
+ return getElementTypeOrSelf(resource()).cast<TF::ResourceType>();
+ }
}];
}
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/decompose_resource_ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/decompose_resource_ops.mlir
index 920f535..5d8cdc7 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/decompose_resource_ops.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/decompose_resource_ops.mlir
@@ -28,7 +28,7 @@
// CHECK-LABEL: func @decompose_assign_add_variable_op
func @decompose_assign_add_variable_op() -> () {
- %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
+ %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<!tf.resource<tensor<i32>>>
// CHECK: %[[ONE:[0-9]*]] = "tf.Const"() {value = dense<1> : tensor<i32>}
// CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"
@@ -36,7 +36,7 @@
// CHECK: "tf.AssignVariableOp"
%1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
- "tf.AssignAddVariableOp"(%0, %1) {dtype = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource>, tensor<i32>) -> ()
+ "tf.AssignAddVariableOp"(%0, %1) {dtype = "tfdtype$DT_INT32"} : (tensor<!tf.resource<tensor<i32>>>, tensor<i32>) -> ()
return
}
@@ -49,7 +49,7 @@
// CHECK-LABEL: func @decompose_assign_sub_variable_op
func @decompose_assign_sub_variable_op() -> () {
- %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
+ %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<!tf.resource<tensor<i32>>>
// CHECK: %[[ONE:[0-9]*]] = "tf.Const"() {value = dense<1> : tensor<i32>}
// CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"
@@ -57,7 +57,7 @@
// CHECK: "tf.AssignVariableOp"
%1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
- "tf.AssignSubVariableOp"(%0, %1) {dtype = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource>, tensor<i32>) -> ()
+ "tf.AssignSubVariableOp"(%0, %1) {dtype = "tfdtype$DT_INT32"} : (tensor<!tf.resource<tensor<i32>>>, tensor<i32>) -> ()
return
}
@@ -70,7 +70,7 @@
// CHECK-SAME: (%[[DELTA:.*]]: tensor<f32>)
func @decompose_resource_apply_gradient_descent(%arg0: tensor<f32>) -> () {
- %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
+ %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<!tf.resource<tensor<f32>>>
// CHECK: %[[ALPHA:[0-9]*]] = "tf.Const"
// CHECK: %[[RES_HANDLE:[0-9]*]] = "tf.VarHandleOp"
@@ -80,7 +80,7 @@
// CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[SUB]])
%1 = "tf.Const"() {T = f32, value = dense<[0.5]> : tensor<1xf32>} : () -> tensor<f32>
- "tf.ResourceApplyGradientDescent"(%0, %1, %arg0) {use_locking = false} : (tensor<*x!tf.resource>, tensor<f32>, tensor<f32>) -> ()
+ "tf.ResourceApplyGradientDescent"(%0, %1, %arg0) {use_locking = false} : (tensor<!tf.resource<tensor<f32>>>, tensor<f32>, tensor<f32>) -> ()
return
}
@@ -96,8 +96,8 @@
// CHECK: [[VAR_HANDLE:%.*]] = "tf.VarHandleOp"
// CHECK: [[ACCUM_HANDLE:%.*]] = "tf.VarHandleOp"
- %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
- %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
+ %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<!tf.resource<tensor<f32>>>
+ %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<!tf.resource<tensor<f32>>>
// CHECK: [[ACCUM:%.*]] = "tf.ReadVariableOp"([[ACCUM_HANDLE]])
// CHECK: [[ACCUM_MOMENTUM:%.*]] = "tf.Mul"([[ACCUM]], [[MOMENTUM]])
@@ -107,7 +107,7 @@
// CHECK: [[VAR:%.*]] = "tf.ReadVariableOp"([[VAR_HANDLE]])
// CHECK: [[VAR_NEW:%.*]] = "tf.Sub"([[VAR]], [[ACCUM_NEW_LR]])
// CHECK: "tf.AssignVariableOp"([[VAR_HANDLE]], [[VAR_NEW]])
- "tf.ResourceApplyMomentum"(%0, %1, %arg0, %arg1, %arg2) {use_locking = false, use_nesterov = false} : (tensor<*x!tf.resource>, tensor<*x!tf.resource>, tensor<f32>, tensor<f32>, tensor<f32>) -> ()
+ "tf.ResourceApplyMomentum"(%0, %1, %arg0, %arg1, %arg2) {use_locking = false, use_nesterov = false} : (tensor<!tf.resource<tensor<f32>>>, tensor<!tf.resource<tensor<f32>>>, tensor<f32>, tensor<f32>, tensor<f32>) -> ()
return
}
@@ -122,8 +122,8 @@
// CHECK: [[VAR_HANDLE:%.*]] = "tf.VarHandleOp"
// CHECK: [[ACCUM_HANDLE:%.*]] = "tf.VarHandleOp"
- %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
- %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
+ %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<!tf.resource<tensor<f32>>>
+ %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<!tf.resource<tensor<f32>>>
// CHECK: [[ACCUM:%.*]] = "tf.ReadVariableOp"([[ACCUM_HANDLE]])
// CHECK: [[ACCUM_MOMENTUM:%.*]] = "tf.Mul"([[ACCUM]], [[MOMENTUM]])
@@ -136,7 +136,7 @@
// CHECK: [[VAR:%.*]] = "tf.ReadVariableOp"([[VAR_HANDLE]])
// CHECK: [[VAR_NEW:%.*]] = "tf.Sub"([[VAR]], [[DELTA]])
// CHECK: "tf.AssignVariableOp"([[VAR_HANDLE]], [[VAR_NEW]])
- "tf.ResourceApplyMomentum"(%0, %1, %arg0, %arg1, %arg2) {use_locking = false, use_nesterov = true} : (tensor<*x!tf.resource>, tensor<*x!tf.resource>, tensor<f32>, tensor<f32>, tensor<f32>) -> ()
+ "tf.ResourceApplyMomentum"(%0, %1, %arg0, %arg1, %arg2) {use_locking = false, use_nesterov = true} : (tensor<!tf.resource<tensor<f32>>>, tensor<!tf.resource<tensor<f32>>>, tensor<f32>, tensor<f32>, tensor<f32>) -> ()
return
}
@@ -151,10 +151,10 @@
// CHECK: %[[VAR_HANDLE:[0-9]*]] = "tf.VarHandleOp"
// CHECK: %[[ACCUM_HANDLE:[0-9]*]] = "tf.VarHandleOp"
- %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
- %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
+ %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<*xf32>>>
+ %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<*xf32>>>
- // CHECK: %[[ACCUM:[0-9]*]] = "tf.ReadVariableOp"(%[[ACCUM_HANDLE]]) : (tensor<*x!tf.resource>) -> tensor<*xf32>
+ // CHECK: %[[ACCUM:[0-9]*]] = "tf.ReadVariableOp"(%[[ACCUM_HANDLE]]) : (tensor<*x!tf.resource<tensor<*xf32>>>) -> tensor<*xf32>
// CHECK: %[[ACCUM_MOMENTUM:[0-9]*]] = "tf.Mul"(%[[ACCUM]], %[[MOMENTUM]])
// CHECK: %[[GRAD_LR:[0-9]*]] = "tf.Mul"(%[[GRAD]], %[[LR]])
// CHECK: %[[NEW_ACCUM:[0-9]*]] = "tf.Sub"(%[[ACCUM_MOMENTUM]], %[[GRAD_LR]])
@@ -164,7 +164,7 @@
// CHECK: %[[NEW_VAR:[0-9]*]] = "tf.AddV2"(%[[VAR]], %[[NEW_ACCUM]])
// CHECK: "tf.AssignVariableOp"(%[[VAR_HANDLE]], %[[NEW_VAR]])
- "tf.ResourceApplyKerasMomentum"(%0, %1, %arg0, %arg1, %arg2) {use_locking = false, use_nesterov = false} : (tensor<*x!tf.resource>, tensor<*x!tf.resource>, tensor<f32>, tensor<f32>, tensor<f32>) -> ()
+ "tf.ResourceApplyKerasMomentum"(%0, %1, %arg0, %arg1, %arg2) {use_locking = false, use_nesterov = false} : (tensor<*x!tf.resource<tensor<*xf32>>>, tensor<*x!tf.resource<tensor<*xf32>>>, tensor<f32>, tensor<f32>, tensor<f32>) -> ()
return
}
@@ -180,10 +180,10 @@
// CHECK: %[[VAR_HANDLE:[0-9]*]] = "tf.VarHandleOp"
// CHECK: %[[ACCUM_HANDLE:[0-9]*]] = "tf.VarHandleOp"
- %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
- %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
+ %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<*xf32>>>
+ %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<*xf32>>>
- // CHECK: %[[ACCUM:[0-9]*]] = "tf.ReadVariableOp"(%[[ACCUM_HANDLE]]) : (tensor<*x!tf.resource>) -> tensor<*xf32>
+ // CHECK: %[[ACCUM:[0-9]*]] = "tf.ReadVariableOp"(%[[ACCUM_HANDLE]]) : (tensor<*x!tf.resource<tensor<*xf32>>>) -> tensor<*xf32>
// CHECK: %[[ACCUM_MOMENTUM:[0-9]*]] = "tf.Mul"(%[[ACCUM]], %[[MOMENTUM]])
// CHECK: %[[GRAD_LR:[0-9]*]] = "tf.Mul"(%[[GRAD]], %[[LR]])
// CHECK: %[[NEW_ACCUM:[0-9]*]] = "tf.Sub"(%[[ACCUM_MOMENTUM]], %[[GRAD_LR]])
@@ -195,7 +195,7 @@
// CHECK: %[[NEW_VAR:[0-9]*]] = "tf.AddV2"(%[[VAR]], %[[NEW_DELTA]])
// CHECK: "tf.AssignVariableOp"(%[[VAR_HANDLE]], %[[NEW_VAR]])
- "tf.ResourceApplyKerasMomentum"(%0, %1, %arg0, %arg1, %arg2) {use_locking = false, use_nesterov = true} : (tensor<*x!tf.resource>, tensor<*x!tf.resource>, tensor<f32>, tensor<f32>, tensor<f32>) -> ()
+ "tf.ResourceApplyKerasMomentum"(%0, %1, %arg0, %arg1, %arg2) {use_locking = false, use_nesterov = true} : (tensor<*x!tf.resource<tensor<*xf32>>>, tensor<*x!tf.resource<tensor<*xf32>>>, tensor<f32>, tensor<f32>, tensor<f32>) -> ()
return
}
@@ -212,21 +212,21 @@
// CHECK: [[VAR_HANDLE:%.*]] = "tf.VarHandleOp"()
// CHECK: [[ACC_HANDLE:%.*]] = "tf.VarHandleOp"()
// CHECK: [[GRAD_SQUARE:%.*]] = "tf.Mul"([[GRAD]], [[GRAD]]) : (tensor<f32>, tensor<f32>) -> tensor<f32>
-// CHECK: [[OLD_ACC:%.*]] = "tf.ReadVariableOp"([[ACC_HANDLE]]) : (tensor<*x!tf.resource>) -> tensor<*xf32>
+// CHECK: [[OLD_ACC:%.*]] = "tf.ReadVariableOp"([[ACC_HANDLE]]) : (tensor<*x!tf.resource<tensor<*xf32>>>) -> tensor<*xf32>
// CHECK: [[NEW_ACC:%.*]] = "tf.AddV2"([[OLD_ACC]], [[GRAD_SQUARE]]) : (tensor<*xf32>, tensor<f32>) -> tensor<*xf32>
// CHECK: [[LR_MULTIPLY:%.*]] = "tf.Mul"([[LR]], [[GRAD]]) : (tensor<f32>, tensor<f32>) -> tensor<f32>
// CHECK: [[SQRT:%.*]] = "tf.Sqrt"([[NEW_ACC]]) : (tensor<*xf32>) -> tensor<*xf32>
// CHECK: [[DIVISOR:%.*]] = "tf.AddV2"([[SQRT]], [[EPSILON]]) : (tensor<*xf32>, tensor<f32>) -> tensor<*xf32>
// CHECK: [[VAR_DELTA:%.*]] = "tf.Div"([[LR_MULTIPLY]], [[DIVISOR]]) : (tensor<f32>, tensor<*xf32>) -> tensor<*xf32>
-// CHECK: [[OLD_VAR:%.*]] = "tf.ReadVariableOp"([[VAR_HANDLE]]) : (tensor<*x!tf.resource>) -> tensor<*xf32>
+// CHECK: [[OLD_VAR:%.*]] = "tf.ReadVariableOp"([[VAR_HANDLE]]) : (tensor<*x!tf.resource<tensor<*xf32>>>) -> tensor<*xf32>
// CHECK: [[NEW_VAR:%.*]] = "tf.Sub"(%9, %8) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
-// CHECK: "tf.AssignVariableOp"([[VAR_HANDLE]], [[NEW_VAR]]) : (tensor<*x!tf.resource>, tensor<*xf32>) -> ()
-// CHECK: "tf.AssignVariableOp"([[ACC_HANDLE]], [[NEW_ACC]]) : (tensor<*x!tf.resource>, tensor<*xf32>) -> ()
+// CHECK: "tf.AssignVariableOp"([[VAR_HANDLE]], [[NEW_VAR]]) : (tensor<*x!tf.resource<tensor<*xf32>>>, tensor<*xf32>) -> ()
+// CHECK: "tf.AssignVariableOp"([[ACC_HANDLE]], [[NEW_ACC]]) : (tensor<*x!tf.resource<tensor<*xf32>>>, tensor<*xf32>) -> ()
- %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
- %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
+ %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<*xf32>>>
+ %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<*xf32>>>
- "tf.ResourceApplyAdagradV2"(%0, %1, %arg0, %arg1, %arg2) {update_slots = true, use_locking = true} : (tensor<*x!tf.resource>, tensor<*x!tf.resource>, tensor<f32>, tensor<f32>, tensor<f32>) -> ()
+ "tf.ResourceApplyAdagradV2"(%0, %1, %arg0, %arg1, %arg2) {update_slots = true, use_locking = true} : (tensor<*x!tf.resource<tensor<*xf32>>>, tensor<*x!tf.resource<tensor<*xf32>>>, tensor<f32>, tensor<f32>, tensor<f32>) -> ()
return
}
@@ -236,22 +236,22 @@
// CHECK-SAME: (%[[LR:.*]]: tensor<f32>, %[[GRAD:.*]]: tensor<f32>)
func @decompose_resource_apply_adagrad(%arg0: tensor<f32>, %arg1: tensor<f32>) -> () {
- // CHECK: %[[VAR_HANDLE:.*]] = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
- // CHECK: %[[ACCUM_HANDLE:.*]] = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
+ // CHECK: %[[VAR_HANDLE:.*]] = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<*xf32>>>
+ // CHECK: %[[ACCUM_HANDLE:.*]] = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<*xf32>>>
// CHECK: %[[GRAD_SQUARE:.*]] = "tf.Mul"(%[[GRAD]], %[[GRAD]]) : (tensor<f32>, tensor<f32>) -> tensor<f32>
- // CHECK: %[[ACCUM:.*]] = "tf.ReadVariableOp"(%[[ACCUM_HANDLE]]) : (tensor<*x!tf.resource>) -> tensor<*xf32>
+ // CHECK: %[[ACCUM:.*]] = "tf.ReadVariableOp"(%[[ACCUM_HANDLE]]) : (tensor<*x!tf.resource<tensor<*xf32>>>) -> tensor<*xf32>
// CHECK: %[[ACCUM_NEW:.*]] = "tf.AddV2"(%[[ACCUM]], %[[GRAD_SQUARE]]) : (tensor<*xf32>, tensor<f32>) -> tensor<*xf32>
// CHECK: %[[LR_MULTIPLY:.*]] = "tf.Mul"(%[[LR]], %[[GRAD]]) : (tensor<f32>, tensor<f32>) -> tensor<f32>
// CHECK: %[[SQRT:.*]] = "tf.Sqrt"(%[[ACCUM_NEW]]) : (tensor<*xf32>) -> tensor<*xf32>
// CHECK: %[[DIV:.*]] = "tf.Div"(%[[LR_MULTIPLY]], %[[SQRT]]) : (tensor<f32>, tensor<*xf32>) -> tensor<*xf32>
- // CHECK: %[[VAR:.*]] = "tf.ReadVariableOp"(%[[VAR_HANDLE]]) : (tensor<*x!tf.resource>) -> tensor<*xf32>
+ // CHECK: %[[VAR:.*]] = "tf.ReadVariableOp"(%[[VAR_HANDLE]]) : (tensor<*x!tf.resource<tensor<*xf32>>>) -> tensor<*xf32>
// CHECK: %[[VAR_NEW:.*]] = "tf.Sub"(%[[VAR]], %[[DIV]]) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
- // CHECK: "tf.AssignVariableOp"(%[[VAR_HANDLE]], %[[VAR_NEW]]) : (tensor<*x!tf.resource>, tensor<*xf32>) -> ()
- // CHECK: "tf.AssignVariableOp"(%[[ACCUM_HANDLE]], %[[ACCUM_NEW]]) : (tensor<*x!tf.resource>, tensor<*xf32>) -> ()
- %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
- %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
+ // CHECK: "tf.AssignVariableOp"(%[[VAR_HANDLE]], %[[VAR_NEW]]) : (tensor<*x!tf.resource<tensor<*xf32>>>, tensor<*xf32>) -> ()
+ // CHECK: "tf.AssignVariableOp"(%[[ACCUM_HANDLE]], %[[ACCUM_NEW]]) : (tensor<*x!tf.resource<tensor<*xf32>>>, tensor<*xf32>) -> ()
+ %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<*xf32>>>
+ %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<*xf32>>>
- "tf.ResourceApplyAdagrad"(%0, %1, %arg0, %arg1) {update_slots = true, use_locking = true} : (tensor<*x!tf.resource>, tensor<*x!tf.resource>, tensor<f32>, tensor<f32>) -> ()
+ "tf.ResourceApplyAdagrad"(%0, %1, %arg0, %arg1) {update_slots = true, use_locking = true} : (tensor<*x!tf.resource<tensor<*xf32>>>, tensor<*x!tf.resource<tensor<*xf32>>>, tensor<f32>, tensor<f32>) -> ()
return
}
@@ -274,12 +274,12 @@
// CHECK: [[ONE_MINUS_BETA1_POWER:%.*]] = "tf.Sub"([[ONE]], [[BETA1_POWER]])
// CHECK: [[ALPHA_NO_LR:%.*]] = "tf.Div"([[SQRT_ONE_MINUS_BETA2_POWER]], [[ONE_MINUS_BETA1_POWER]])
// CHECK: [[ALPHA:%.*]] = "tf.Mul"([[LR]], [[ALPHA_NO_LR]])
-// CHECK: [[OLD_M:%.*]] = "tf.ReadVariableOp"([[M_HANDLE]]) : (tensor<*x!tf.resource>) -> tensor<*xf32>
+// CHECK: [[OLD_M:%.*]] = "tf.ReadVariableOp"([[M_HANDLE]]) : (tensor<*x!tf.resource<tensor<*xf32>>>) -> tensor<*xf32>
// CHECK: [[BETA1_OLD_M:%.*]] = "tf.Mul"([[BETA1]], [[OLD_M]])
// CHECK: [[ONE_MINUS_BETA1:%.*]] = "tf.Sub"([[ONE]], [[BETA1]])
// CHECK: [[ONE_MINUS_BETA1_GRAD:%.*]] = "tf.Mul"([[ONE_MINUS_BETA1]], [[GRAD]])
// CHECK: [[NEW_M:%.*]] = "tf.AddV2"([[BETA1_OLD_M]], [[ONE_MINUS_BETA1_GRAD]])
-// CHECK: [[OLD_V:%.*]] = "tf.ReadVariableOp"([[V_HANDLE]]) : (tensor<*x!tf.resource>) -> tensor<*xf32>
+// CHECK: [[OLD_V:%.*]] = "tf.ReadVariableOp"([[V_HANDLE]]) : (tensor<*x!tf.resource<tensor<*xf32>>>) -> tensor<*xf32>
// CHECK: [[BETA2_OLD_V:%.*]] = "tf.Mul"([[BETA2]], [[OLD_V]])
// CHECK: [[ONE_MINUS_BETA2:%.*]] = "tf.Sub"([[ONE]], [[BETA2]])
// CHECK: [[GRAD_SQUARE:%.*]] = "tf.Square"([[GRAD]])
@@ -289,17 +289,17 @@
// CHECK: [[SQRT_NEW_V:%.*]] = "tf.Sqrt"([[NEW_V]])
// CHECK: [[SQRT_NEW_V_EPSILON:%.*]] = "tf.AddV2"([[SQRT_NEW_V]], [[EPSILON]])
// CHECK: [[VAR_DELTA:%.*]] = "tf.Div"([[ALPHA_NEW_M]], [[SQRT_NEW_V_EPSILON]])
-// CHECK: [[OLD_VAR:%.*]] = "tf.ReadVariableOp"([[VAR_HANDLE]]) : (tensor<*x!tf.resource>) -> tensor<*xf32>
+// CHECK: [[OLD_VAR:%.*]] = "tf.ReadVariableOp"([[VAR_HANDLE]]) : (tensor<*x!tf.resource<tensor<*xf32>>>) -> tensor<*xf32>
// CHECK: [[NEW_VAR:%.*]] = "tf.Sub"([[OLD_VAR]], [[VAR_DELTA]])
// CHECK: "tf.AssignVariableOp"([[VAR_HANDLE]], [[NEW_VAR]])
// CHECK: "tf.AssignVariableOp"([[M_HANDLE]], [[NEW_M]])
// CHECK: "tf.AssignVariableOp"([[V_HANDLE]], [[NEW_V]])
- %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
- %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
- %2 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
+ %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<*xf32>>>
+ %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<*xf32>>>
+ %2 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<*xf32>>>
- "tf.ResourceApplyAdam"(%0, %1, %2, %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) {use_locking = false, use_nesterov = false} : (tensor<*x!tf.resource>, tensor<*x!tf.resource>, tensor<*x!tf.resource>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>) -> ()
+ "tf.ResourceApplyAdam"(%0, %1, %2, %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) {use_locking = false, use_nesterov = false} : (tensor<*x!tf.resource<tensor<*xf32>>>, tensor<*x!tf.resource<tensor<*xf32>>>, tensor<*x!tf.resource<tensor<*xf32>>>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>) -> ()
return
}
@@ -322,12 +322,12 @@
// CHECK: [[VAL_84:%.*]] = "tf.Sub"([[ONE]], [[BETA1_POWER]])
// CHECK: [[VAL_85:%.*]] = "tf.Div"([[VAL_83]], [[VAL_84]])
// CHECK: [[VAL_86:%.*]] = "tf.Mul"([[LR]], [[VAL_85]])
-// CHECK: [[OLD_M:%.*]] = "tf.ReadVariableOp"([[M_HANDLE]]) : (tensor<*x!tf.resource>) -> tensor<*xf32>
+// CHECK: [[OLD_M:%.*]] = "tf.ReadVariableOp"([[M_HANDLE]]) : (tensor<*x!tf.resource<tensor<*xf32>>>) -> tensor<*xf32>
// CHECK: [[VAL_88:%.*]] = "tf.Mul"([[BETA1]], [[OLD_M]])
// CHECK: [[VAL_89:%.*]] = "tf.Sub"([[ONE]], [[BETA1]])
// CHECK: [[VAL_90:%.*]] = "tf.Mul"([[VAL_89]], [[GRAD]])
// CHECK: [[NEW_M:%.*]] = "tf.AddV2"([[VAL_88]], [[VAL_90]])
-// CHECK: [[OLD_V:%.*]] = "tf.ReadVariableOp"([[V_HANDLE]]) : (tensor<*x!tf.resource>) -> tensor<*xf32>
+// CHECK: [[OLD_V:%.*]] = "tf.ReadVariableOp"([[V_HANDLE]]) : (tensor<*x!tf.resource<tensor<*xf32>>>) -> tensor<*xf32>
// CHECK: [[VAL_93:%.*]] = "tf.Mul"([[BETA2]], [[OLD_V]])
// CHECK: [[VAL_94:%.*]] = "tf.Sub"([[ONE]], [[BETA2]])
// CHECK: [[VAL_95:%.*]] = "tf.Square"([[GRAD]])
@@ -341,17 +341,17 @@
// CHECK: [[VAL_103:%.*]] = "tf.Sqrt"([[NEW_V]])
// CHECK: [[VAL_104:%.*]] = "tf.AddV2"([[VAL_103]], [[EPSILON]])
// CHECK: [[VAL_105:%.*]] = "tf.Div"([[VAL_102]], [[VAL_104]])
-// CHECK: [[OLD_VAR:%.*]] = "tf.ReadVariableOp"([[VAR_HANDLE]]) : (tensor<*x!tf.resource>) -> tensor<*xf32>
+// CHECK: [[OLD_VAR:%.*]] = "tf.ReadVariableOp"([[VAR_HANDLE]]) : (tensor<*x!tf.resource<tensor<*xf32>>>) -> tensor<*xf32>
// CHECK: [[NEW_VAR:%.*]] = "tf.Sub"([[OLD_VAR]], [[VAL_105]])
-// CHECK: "tf.AssignVariableOp"([[VAR_HANDLE]], [[NEW_VAR]]) : (tensor<*x!tf.resource>, tensor<*xf32>) -> ()
-// CHECK: "tf.AssignVariableOp"([[M_HANDLE]], [[NEW_M]]) : (tensor<*x!tf.resource>, tensor<*xf32>) -> ()
-// CHECK: "tf.AssignVariableOp"([[V_HANDLE]], [[NEW_V]]) : (tensor<*x!tf.resource>, tensor<*xf32>) -> ()
+// CHECK: "tf.AssignVariableOp"([[VAR_HANDLE]], [[NEW_VAR]]) : (tensor<*x!tf.resource<tensor<*xf32>>>, tensor<*xf32>) -> ()
+// CHECK: "tf.AssignVariableOp"([[M_HANDLE]], [[NEW_M]]) : (tensor<*x!tf.resource<tensor<*xf32>>>, tensor<*xf32>) -> ()
+// CHECK: "tf.AssignVariableOp"([[V_HANDLE]], [[NEW_V]]) : (tensor<*x!tf.resource<tensor<*xf32>>>, tensor<*xf32>) -> ()
- %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
- %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
- %2 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
+ %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<*xf32>>>
+ %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<*xf32>>>
+ %2 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<*xf32>>>
- "tf.ResourceApplyAdam"(%0, %1, %2, %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) {use_locking = false, use_nesterov = true} : (tensor<*x!tf.resource>, tensor<*x!tf.resource>, tensor<*x!tf.resource>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>) -> ()
+ "tf.ResourceApplyAdam"(%0, %1, %2, %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) {use_locking = false, use_nesterov = true} : (tensor<*x!tf.resource<tensor<*xf32>>>, tensor<*x!tf.resource<tensor<*xf32>>>, tensor<*x!tf.resource<tensor<*xf32>>>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>) -> ()
return
}
@@ -366,12 +366,12 @@
// CHECK: [[ZERO:%.+]] = "tf.Const"() {value = dense<0> : tensor<i64>}
// CHECK: [[VAR:%.+]] = "tf.VarHandleOp"
- %resource = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
+ %resource = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<*xi32>>>
// CHECK: [[READVAR:%.+]] = "tf.ReadVariableOp"([[VAR]])
// CHECK: [[GATHER:%.+]] = "tf.GatherV2"([[READVAR]], [[INDEX]], [[ZERO]]) {batch_dims = 0 : i64} : (tensor<*xi32>, tensor<?xi32>, tensor<i64>) -> tensor<*xi32>
// CHECK: return [[GATHER]]
- %0 = "tf.ResourceGather"(%resource, %indices) : (tensor<*x!tf.resource>, tensor<?xi32>) -> (tensor<*xi32>)
+ %0 = "tf.ResourceGather"(%resource, %indices) : (tensor<*x!tf.resource<tensor<*xi32>>>, tensor<?xi32>) -> (tensor<*xi32>)
return %0: tensor<*xi32>
}
@@ -403,10 +403,10 @@
// CHECK: [[MG_HANDLE:%.*]] = "tf.VarHandleOp"
// CHECK: [[MS_HANDLE:%.*]] = "tf.VarHandleOp"
// CHECK: [[MOM_HANDLE:%.*]] = "tf.VarHandleOp"
- %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
- %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
- %2 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
- %3 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
+ %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<f32>>>
+ %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<f32>>>
+ %2 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<f32>>>
+ %3 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<f32>>>
// CHECK: [[GRADSQ:%.*]] = "tf.Mul"([[GRAD]], [[GRAD]])
// CHECK: [[SB:%.*]] = "tf.Sub"([[ONE]], [[RHO]])
@@ -438,7 +438,7 @@
// CHECK: [[VAR_NEW:%.*]] = "tf.Sub"([[VAR]], [[MOM_NEW]])
// CHECK: "tf.AssignVariableOp"([[VAR_HANDLE]], [[VAR_NEW]])
- "tf.ResourceApplyCenteredRMSProp"(%0, %1, %2, %3, %arg4, %arg5, %arg6, %arg7, %arg8) {use_locking = false} : (tensor<*x!tf.resource>, tensor<*x!tf.resource>, tensor<*x!tf.resource>, tensor<*x!tf.resource>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>) -> ()
+ "tf.ResourceApplyCenteredRMSProp"(%0, %1, %2, %3, %arg4, %arg5, %arg6, %arg7, %arg8) {use_locking = false} : (tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>) -> ()
return
}
// -----
@@ -477,12 +477,12 @@
// CHECK-SAME: ([[INDEX:%.+]]: tensor<2x?xi32>, [[UPDATE:%.+]]: tensor<?x?x?xi32>)
func @decompose_resource_scatter_update_op(%indices : tensor<2x?xi32>, %updates: tensor<?x?x?xi32>) {
// CHECK: [[VAR:%.+]] = "tf.VarHandleOp"
- %resource = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
+ %resource = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<*xi32>>>
// CHECK: [[READ:%.+]] = "tf.ReadVariableOp"([[VAR]])
// CHECK: [[TENSOR:%.+]] = "tf.TensorScatterUpdate"([[READ]], [[INDEX]], [[UPDATE]]) : (tensor<*xi32>, tensor<2x?xi32>, tensor<?x?x?xi32>) -> tensor<*xi32>
// CHECK: "tf.AssignVariableOp"([[VAR]], [[TENSOR]])
- "tf.ResourceScatterUpdate"(%resource, %indices, %updates) : (tensor<*x!tf.resource>, tensor<2x?xi32>, tensor<?x?x?xi32>) -> ()
+ "tf.ResourceScatterUpdate"(%resource, %indices, %updates) : (tensor<*x!tf.resource<tensor<*xi32>>>, tensor<2x?xi32>, tensor<?x?x?xi32>) -> ()
return
}
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/promote_resources_to_args.mlir b/tensorflow/compiler/mlir/tensorflow/tests/promote_resources_to_args.mlir
index b18b03b..de9ef7b 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/promote_resources_to_args.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/promote_resources_to_args.mlir
@@ -340,7 +340,7 @@
// Tests main function with invalid VarHandleOp resource subtype.
func @main() {
- // expected-error@+1 {{expects resource type to have one subtype, got '!tf.resource'}}
+ // expected-error @+1 {{must have exactly one subtype in the result resource type}}
%0 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor<!tf.resource>
return
}
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/promote_var_handles_to_args.mlir b/tensorflow/compiler/mlir/tensorflow/tests/promote_var_handles_to_args.mlir
index f2c045a..903d57e 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/promote_var_handles_to_args.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/promote_var_handles_to_args.mlir
@@ -12,18 +12,18 @@
// -----
// CHECK-LABEL: func @no_args
-// CHECK-SAME: (%arg0: tensor<!tf.resource> {tf.resource_name = "x"})
+// CHECK-SAME: (%arg0: tensor<!tf.resource<tensor<f32>>> {tf.resource_name = "x"})
// CHECK-NOT: "tf.VarHandleOp"
func @no_args() {
- %0 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor<!tf.resource>
+ %0 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor<!tf.resource<tensor<f32>>>
return
}
// CHECK-LABEL: func @some_args
-// CHECK-SAME: (%arg0: tensor<i1>, %arg1: tensor<!tf.resource> {tf.resource_name = "x"})
+// CHECK-SAME: (%arg0: tensor<i1>, %arg1: tensor<!tf.resource<tensor<f32>>> {tf.resource_name = "x"})
// CHECK-NOT: "tf.VarHandleOp"
func @some_args(%arg0: tensor<i1>) {
- %0 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor<!tf.resource>
+ %0 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor<!tf.resource<tensor<f32>>>
return
}
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir b/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir
index b4df968..5ad8336 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir
@@ -6,7 +6,7 @@
func @only_resource_load() -> tensor<*xi32> {
// CHECK: %[[RES_HANDLE:[0-9]*]] = "tf.VarHandleOp"
- %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
+ %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<*xi32>>>
// CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]])
// CHECK: "tf_device.cluster"
@@ -16,7 +16,7 @@
// CHECK-SAME: () -> tensor<*xi32>
%1 = "tf_device.cluster"() ( {
- %2 = "tf.ReadVariableOp"(%0) {dtype = i32} : (tensor<*x!tf.resource>) -> tensor<*xi32>
+ %2 = "tf.ReadVariableOp"(%0) {dtype = i32} : (tensor<*x!tf.resource<tensor<*xi32>>>) -> tensor<*xi32>
%3 = "tf.SomeComputation"(%2) : (tensor<*xi32>) -> (tensor<*xi32>)
tf_device.return %3 : tensor<*xi32>
}) {cluster_attr = "cluster_attr"} : () -> tensor<*xi32>
@@ -32,7 +32,7 @@
func @only_resource_store() -> tensor<*xi32> {
// CHECK: %[[RES_HANDLE:[0-9]*]] = "tf.VarHandleOp"
- %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
+ %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<*xi32>>>
// CHECK: %[[CLUSTER_RES:[0-9]*]]:2 = "tf_device.cluster"
// CHECK: %[[COMPUTE_RES:[0-9]*]] = "tf.SomeComputation"()
@@ -43,7 +43,7 @@
%1 = "tf_device.cluster"() ( {
%2 = "tf.SomeComputation"() : () -> (tensor<*xi32>)
- "tf.AssignVariableOp"(%0, %2) {dtype = i32} : (tensor<*x!tf.resource>, tensor<*xi32>) -> ()
+ "tf.AssignVariableOp"(%0, %2) {dtype = i32} : (tensor<*x!tf.resource<tensor<*xi32>>>, tensor<*xi32>) -> ()
tf_device.return %2 : tensor<*xi32>
}) {cluster_attr = "cluster_attr"} : () -> tensor<*xi32>
@@ -59,7 +59,7 @@
func @same_resource_load_and_store() -> tensor<*xi32> {
// CHECK: %[[RES_HANDLE:[0-9]*]] = "tf.VarHandleOp"
- %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
+ %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<*xi32>>>
// CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]])
// CHECK: %[[CLUSTER_RES:[0-9]*]]:2 = "tf_device.cluster"
@@ -70,9 +70,9 @@
// CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[CLUSTER_RES]]#1)
%1 = "tf_device.cluster"() ( {
- %2 = "tf.ReadVariableOp"(%0) {dtype = i32} : (tensor<*x!tf.resource>) -> tensor<*xi32>
+ %2 = "tf.ReadVariableOp"(%0) {dtype = i32} : (tensor<*x!tf.resource<tensor<*xi32>>>) -> tensor<*xi32>
%3 = "tf.SomeComputation"(%2) : (tensor<*xi32>) -> (tensor<*xi32>)
- "tf.AssignVariableOp"(%0, %3) {dtype = i32} : (tensor<*x!tf.resource>, tensor<*xi32>) -> ()
+ "tf.AssignVariableOp"(%0, %3) {dtype = i32} : (tensor<*x!tf.resource<tensor<*xi32>>>, tensor<*xi32>) -> ()
tf_device.return %3 : tensor<*xi32>
}) {cluster_attr = "cluster_attr"} : () -> tensor<*xi32>
@@ -89,7 +89,7 @@
func @same_resource_load_and_store_cast() -> tensor<1xi32> {
// CHECK: %[[RES_HANDLE:[0-9]*]] = "tf.VarHandleOp"
- %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
+ %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<*xi32>>>
// CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]])
// CHECK: %[[CLUSTER_RES:[0-9]*]]:2 = "tf_device.cluster"
@@ -101,10 +101,10 @@
// CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[CLUSTER_RES]]#1)
%1 = "tf_device.cluster"() ( {
- %2 = "tf.ReadVariableOp"(%0) {dtype = i32} : (tensor<*x!tf.resource>) -> tensor<1xi32>
+ %2 = "tf.ReadVariableOp"(%0) {dtype = i32} : (tensor<*x!tf.resource<tensor<*xi32>>>) -> tensor<1xi32>
%3 = "tf.SomeComputation"(%2) : (tensor<1xi32>) -> (tensor<*xi32>)
- "tf.AssignVariableOp"(%0, %3) {dtype = i32} : (tensor<*x!tf.resource>, tensor<*xi32>) -> ()
- %4 = "tf.ReadVariableOp"(%0) {dtype = i32} : (tensor<*x!tf.resource>) -> tensor<1xi32>
+ "tf.AssignVariableOp"(%0, %3) {dtype = i32} : (tensor<*x!tf.resource<tensor<*xi32>>>, tensor<*xi32>) -> ()
+ %4 = "tf.ReadVariableOp"(%0) {dtype = i32} : (tensor<*x!tf.resource<tensor<*xi32>>>) -> tensor<1xi32>
tf_device.return %4 : tensor<1xi32>
}) {cluster_attr = "cluster_attr"} : () -> tensor<1xi32>
@@ -123,16 +123,16 @@
%0 = "tf_device.cluster"() ( {
// CHECK: %[[RES_HANDLE:[0-9]*]] = "tf.VarHandleOp"
- %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
+ %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<*xi32>>>
// CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]])
- %2 = "tf.ReadVariableOp"(%1) {dtype = i32} : (tensor<*x!tf.resource>) -> tensor<*xi32>
+ %2 = "tf.ReadVariableOp"(%1) {dtype = i32} : (tensor<*x!tf.resource<tensor<*xi32>>>) -> tensor<*xi32>
// CHECK: %[[COMPUTE_RES:[0-9]*]] = "tf.SomeComputation"(%[[RES_READ_VAL]])
%3 = "tf.SomeComputation"(%2) : (tensor<*xi32>) -> (tensor<*xi32>)
// CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[COMPUTE_RES]])
- "tf.AssignVariableOp"(%1, %3) {dtype = i32} : (tensor<*x!tf.resource>, tensor<*xi32>) -> ()
+ "tf.AssignVariableOp"(%1, %3) {dtype = i32} : (tensor<*x!tf.resource<tensor<*xi32>>>, tensor<*xi32>) -> ()
// CHECK: tf_device.return %[[COMPUTE_RES]]
tf_device.return %3 : tensor<*xi32>
@@ -1006,10 +1006,10 @@
// CHECK: tf_device.return
// CHECK: {cluster_attr = "cluster_attr"}
// CHECK: return
- %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
+ %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<*xi32>>>
%1 = "tf_device.cluster"() ( {
- %2 = "tf.ReadVariableOp"(%0) {dtype = i32} : (tensor<*x!tf.resource>) -> tensor<*xi32>
- "tf.SomeResourceOperation"(%0) : (tensor<*x!tf.resource>) -> ()
+ %2 = "tf.ReadVariableOp"(%0) {dtype = i32} : (tensor<*x!tf.resource<tensor<*xi32>>>) -> tensor<*xi32>
+ "tf.SomeResourceOperation"(%0) : (tensor<*x!tf.resource<tensor<*xi32>>>) -> ()
%3 = "tf.SomeComputation"(%2) : (tensor<*xi32>) -> (tensor<*xi32>)
tf_device.return %3 : tensor<*xi32>
}) {cluster_attr = "cluster_attr"} : () -> tensor<*xi32>
@@ -1032,12 +1032,12 @@
// CHECK-SAME: else_branch = @else_fn, is_stateless = true, then_branch = @then_fn
// CHECK: tf_device.return
// CHECK: return
- %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
- %1 = "tf.VarHandleOp"() {container = "d", shared_name = "w"} : () -> tensor<*x!tf.resource>
+ %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<*xi32>>>
+ %1 = "tf.VarHandleOp"() {container = "d", shared_name = "w"} : () -> tensor<*x!tf.resource<tensor<*xi32>>>
%2 = "tf_device.cluster"() ( {
%3 = "tf.If"(%arg0, %0, %1)
{ else_branch = @else_fn, then_branch = @then_fn, is_stateless = true}
- : (tensor<i1>, tensor<*x!tf.resource>, tensor<*x!tf.resource>) -> tensor<*xi32>
+ : (tensor<i1>, tensor<*x!tf.resource<tensor<*xi32>>>, tensor<*x!tf.resource<tensor<*xi32>>>) -> tensor<*xi32>
tf_device.return %3 : tensor<*xi32>
}) {cluster_attr = "cluster_attr"} : () -> tensor<*xi32>
return %2 : tensor<*xi32>
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir
index fddb45f..1b1ae77 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir
@@ -4135,3 +4135,27 @@
"tf.TPUExecuteAndUpdateVariables"(%arg0, %arg1) {device_var_reads_indices = [0], device_var_updates_indices = [-2]} : (tensor<!tf.resource<tensor<i32>>>, tensor<3x!tf.string>) -> ()
return
}
+
+// -----
+
+// Valid VarHandleOp operation.
+// CHECK-LABEL: func @testVarHandleOp
+func @testVarHandleOp() -> tensor<!tf.resource<tensor<*xf32>>> {
+ %0 = "tf.VarHandleOp"() {
+ container = "",
+ shared_name = "cd2c89b7-88b7-44c8-ad83-06c2a9158347"
+ } : () -> tensor<!tf.resource<tensor<*xf32>>>
+ return %0 : tensor<!tf.resource<tensor<*xf32>>>
+}
+
+// -----
+
+// VarHandleOp operation missing the required resource subtype.
+func @testVarHandleOp() -> tensor<*x!tf.resource> {
+ // expected-error @+1 {{must have exactly one subtype in the result resource type}}
+ %0 = "tf.VarHandleOp"() {
+ container = "",
+ shared_name = "cd2c89b7-88b7-44c8-ad83-06c2a9158347"
+ } : () -> tensor<*x!tf.resource>
+ return %0 : tensor<*x!tf.resource>
+}
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu-dynamic-layout-pass.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu-dynamic-layout-pass.mlir
index 7b670cd..fb82bbe 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tpu-dynamic-layout-pass.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu-dynamic-layout-pass.mlir
@@ -184,11 +184,11 @@
mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
tf_device.return %1#0, %1#1 : tensor<!tf.string>, tensor<2x!tf.string>
}) {device = "/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
- %var = "tf.VarHandleOp"() {container = "c", shared_name = "v", device = "/device:TPU:0"} : () -> tensor<*x!tf.resource>
+ %var = "tf.VarHandleOp"() {container = "c", shared_name = "v", device = "/device:TPU:0"} : () -> tensor<!tf.resource<tensor<3x3x1x32xf32>>>
// CHECK-NOT: "tf.TPUGetLayoutOp"
// CHECK-NOT: "tf.TPUCopyWithLayout"
%2:2 = "tf.IteratorGetNext"(%var) {device = "/device:CPU:0"}
- : (tensor<*x!tf.resource>) -> (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>)
+ : (tensor<!tf.resource<tensor<3x3x1x32xf32>>>) -> (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>)
"tf_device.launch"() ( {
"tf.TPUCompileSucceededAssert"(%compile#0) : (tensor<!tf.string>) -> ()
tf_device.return
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc b/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc
index 0a23912..08dfd31 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc
@@ -103,16 +103,9 @@
return composite_users;
}
-// Checks if `tf.VarHandleOp` has a valid resource subtype and its users are of
-// `tf.ReadVariableOp` and `tf.AssignVariableOp` only.
+// Checks that the only users of `tf.VarHandleOp` are
+// `tf.ReadVariableOp` and `tf.AssignVariableOp`.
mlir::LogicalResult ValidateVarHandle(TF::VarHandleOp var_handle_op) {
- auto resource_type =
- getElementTypeOrSelf(var_handle_op.getType()).cast<TF::ResourceType>();
- if (resource_type.getSubtypes().size() != 1)
- return var_handle_op.emitOpError()
- << "expects resource type to have one subtype, got "
- << resource_type;
-
auto composite_ops = GetCompositeResourceUserNames(var_handle_op);
if (!composite_ops.empty())
return var_handle_op.emitOpError()
diff --git a/tensorflow/compiler/mlir/tfr/integration/node_expansion_pass.cc b/tensorflow/compiler/mlir/tfr/integration/node_expansion_pass.cc
index f3e8780..54ffc57 100644
--- a/tensorflow/compiler/mlir/tfr/integration/node_expansion_pass.cc
+++ b/tensorflow/compiler/mlir/tfr/integration/node_expansion_pass.cc
@@ -45,7 +45,7 @@
// isn't a composite op. The following ops are explicitly skipped here because
// their "no-op" expansion is known to cause problems in some cases.
static const char* kOpsToSkip[] = {"IdentityOp", "NoOp", "OptionalHasValue",
- "OptionalGetValue", "VarHandleOp"};
+ "OptionalGetValue"};
for (const char* skip : kOpsToSkip) {
if (absl::StartsWith(orig_op->op_name(), skip)) return Status::OK();
}