blob: 98d846041b3a4955e2ffcbb7841161601dd03d39 [file] [log] [blame]
// RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-hoist-replicate-invariant-resource-writes | FileCheck %s
!tf_res_i32 = tensor<*x!tf_type.resource<tensor<i32>>>
!tf_res_f32 = tensor<*x!tf_type.resource<tensor<f32>>>
// CHECK-LABEL: func @hoist_tail_assign
// CHECK-SAME: ([[ARG0:%.*]]: tensor<*x!tf_type.resource<tensor<f32>>>)
func.func @hoist_tail_assign(%arg0: !tf_res_f32) {
// CHECK: [[REPLICATE:%.*]]:4 = tf_device.replicate {n = 2 : i32}
%replicate:2 = tf_device.replicate {n = 2 : i32} {
// CHECK: [[OP_A:%.*]] = "tf.OpA"
%op_a = "tf.OpA"() : () -> tensor<f32>
// CHECK: [[OP_B:%.*]] = "tf.OpB"
%op_b = "tf.OpB"() : () -> tensor<i32>
// CHECK-NOT: tf.AssignVariableOp
"tf.AssignVariableOp"(%arg0, %op_a) : (!tf_res_f32, tensor<f32>) -> ()
// CHECK: tf_device.return [[OP_B]], [[OP_A]] : tensor<i32>, tensor<f32>
tf_device.return %op_b : tensor<i32>
}
// CHECK: "tf.AssignVariableOp"([[ARG0]], [[REPLICATE]]#2)
func.return
}
// CHECK-LABEL: func @do_not_hoist_non_tail_assigns
// CHECK-SAME: ([[ARG0:%.*]]: tensor<*x!tf_type.resource<tensor<f32>>>)
func.func @do_not_hoist_non_tail_assigns(%arg0: !tf_res_f32) {
// CHECK: tf_device.replicate {n = 2 : i32}
tf_device.replicate {n = 2 : i32} {
// CHECK: [[OP_A:%.*]] = "tf.OpA"
%op_a = "tf.OpA"() : () -> tensor<f32>
// CHECK: "tf.AssignVariableOp"([[ARG0]], [[OP_A]])
"tf.AssignVariableOp"(%arg0, %op_a) : (!tf_res_f32, tensor<f32>) -> ()
// CHECK: "tf.ResourceAccessOp"([[ARG0]])
"tf.ResourceAccessOp"(%arg0) : (!tf_res_f32) -> tensor<i32>
// CHECK: tf_device.return
tf_device.return
}
// CHECK-NOT: tf.AssignVariableOp
func.return
}
// CHECK-LABEL: func @do_not_hoist_writes_to_explicitly_captured_resources
// CHECK-SAME: ([[ARG0:%.*]]: tensor<*x!tf_type.resource<tensor<f32>>>, [[ARG1:%.*]]: tensor<*x!tf_type.resource<tensor<f32>>>) {
func.func @do_not_hoist_writes_to_explicitly_captured_resources(%arg0: !tf_res_f32, %arg1: !tf_res_f32) {
// CHECK: [[REPLICATE:%.*]]:2 = tf_device.replicate({{\[}}[[ARG0]], [[ARG1]]] as [[RI:%.*]]: tensor<*x!tf_type.resource<tensor<f32>>>) {n = 2 : i32}
%replicate:2 = tf_device.replicate ([%arg0, %arg1] as %ri: tensor<*x!tf_type.resource<tensor<f32>>>) {n = 2 : i32} {
// CHECK: [[OP_A:%.*]] = "tf.OpA"()
%op_a = "tf.OpA"() : () -> tensor<f32>
// CHECK: [[OP_B:%.*]] = "tf.OpB"()
%op_b = "tf.OpB"() : () -> tensor<i32>
// CHECK: "tf.AssignVariableOp"([[RI]], [[OP_A]])
"tf.AssignVariableOp"(%ri, %op_a) : (!tf_res_f32, tensor<f32>) -> ()
// CHECK: tf_device.return [[OP_B]] : tensor<i32>
tf_device.return %op_b : tensor<i32>
}
// CHECK-NOT: tf.AssignVariableOp
func.return
}
// CHECK-LABEL: func @only_hoist_tail_assign
// CHECK-SAME: ([[ARG0:%.*]]: tensor<*x!tf_type.resource<tensor<f32>>>)
func.func @only_hoist_tail_assign(%arg0: !tf_res_f32) {
// CHECK: [[REPLICATE:%.*]]:2 = tf_device.replicate {n = 2 : i32}
tf_device.replicate {n = 2 : i32} {
// CHECK: [[OP_A:%.*]] = "tf.OpA"
%op_a = "tf.OpA"() : () -> tensor<f32>
// CHECK: [[OP_B:%.*]] = "tf.OpB"
%op_b = "tf.OpB"() : () -> tensor<f32>
// CHECK: "tf.AssignVariableOp"([[ARG0]], [[OP_A]])
"tf.AssignVariableOp"(%arg0, %op_a) : (!tf_res_f32, tensor<f32>) -> ()
// CHECK-NOT: tf.AssignVariableOp
"tf.AssignVariableOp"(%arg0, %op_b) : (!tf_res_f32, tensor<f32>) -> ()
// CHECK: tf_device.return [[OP_B]] : tensor<f32>
tf_device.return
}
// CHECK: "tf.AssignVariableOp"([[ARG0]], [[REPLICATE]]#0)
func.return
}