// blob: 70a3b27f19434a937488c48f426da47f2713a285 [file] [log] [blame]
// RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-tpu-cluster-formation | FileCheck %s
// Test ops in cluster only have `_replication_info` and `device` attributes
// removed when moved to a `tf_device.cluster`.
// Other attributes on the clustered op (e.g. `name`) must be preserved.
// CHECK-LABEL: func @cluster_ops_removed_attrs
func.func @cluster_ops_removed_attrs() {
%0 = "tf.opA"() {_xla_compile_device_type = "TPU", _replication_info = "replicate", device = "device", name = "name", is_stateless = true} : () -> tensor<i1>
"tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _replication_info = "replicate", device = "device", num_replicas = 1, topology = "topology"} : () -> ()
func.return
}
// CHECK: "tf.opA"
// CHECK-SAME: name = "name"
// CHECK-NOT: _replication_info = "replicate"
// CHECK-NOT: device = "device"
// CHECK: tf_device.return
// Test TPUReplicateMetadata ops `name` and `num_replicas` attributes are not
// copied over to `tf_device.cluster`.
// (Other metadata attributes, such as `topology`, are copied; see tests below.)
// CHECK-LABEL: func @removed_metadata_attrs
func.func @removed_metadata_attrs() {
%0 = "tf.opA"() {_xla_compile_device_type = "TPU", _replication_info = "replicate", is_stateless = true} : () -> tensor<i1>
"tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _replication_info = "replicate", device = "device", name = "name", num_replicas = 1, topology = "topology"} : () -> ()
func.return
}
// CHECK-NOT: name = "name"
// CHECK-NOT: num_replicas = 1
// Test TPUReplicateMetadata op is removed when forming clusters.
// The metadata op itself must not appear anywhere in the output.
// CHECK-LABEL: func @metadata_op_removed
func.func @metadata_op_removed() {
%0 = "tf.opA"() {_xla_compile_device_type = "TPU", _replication_info = "replicate", is_stateless = true} : () -> tensor<i1>
"tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _replication_info = "replicate", device = "device", num_replicas = 1, topology = "topology"} : () -> ()
func.return
}
// CHECK-NOT: "tf.TPUReplicateMetadata"
// Test ops in a function body with the same `_replication_info` attribute are
// merged under a `tf_device.cluster` op.
// Note: opB and opE carry no `_replication_info`, so they stay outside the
// cluster (they are matched before the cluster below).
// CHECK-LABEL: func @ops_in_func_body
// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<i1>)
func.func @ops_in_func_body(%arg0 : tensor<i1>) -> (tensor<i1>, tensor<i1>, tensor<i1>) {
%0 = "tf.opA"(%arg0) {_xla_compile_device_type = "TPU", _replication_info = "replicate", is_stateless = true} : (tensor<i1>) -> tensor<i1>
%1 = "tf.opB"() {is_stateless = true} : () -> tensor<i1>
%2 = "tf.opC"(%0) {_xla_compile_device_type = "TPU", _replication_info = "replicate", is_stateless = true} : (tensor<i1>) -> tensor<i1>
"tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _replication_info = "replicate", device = "device", num_replicas = 1, topology = "topology"} : () -> ()
%3 = "tf.opD"(%2) {_xla_compile_device_type = "TPU", _replication_info = "replicate", is_stateless = true} : (tensor<i1>) -> tensor<i1>
%4 = "tf.opE"() {is_stateless = true} : () -> tensor<i1>
%5 = "tf.opF"(%arg0) {_xla_compile_device_type = "TPU", _replication_info = "replicate", is_stateless = true} : (tensor<i1>) -> tensor<i1>
func.return %2, %3, %5 : tensor<i1>, tensor<i1>, tensor<i1>
}
// CHECK: "tf.opB"
// CHECK: "tf.opE"
// CHECK: %[[CLUSTER:[0-9]*]]:3 = "tf_device.cluster"() ({
// CHECK-NEXT: %[[OP_A:[0-9]*]] = "tf.opA"(%[[ARG_0]])
// CHECK-NEXT: %[[OP_C:[0-9]*]] = "tf.opC"(%[[OP_A]])
// CHECK-NEXT: %[[OP_D:[0-9]*]] = "tf.opD"(%[[OP_C]])
// CHECK-NEXT: %[[OP_F:[0-9]*]] = "tf.opF"(%[[ARG_0]])
// CHECK-NEXT: tf_device.return %[[OP_C]], %[[OP_D]], %[[OP_F]]
// CHECK-NEXT: _replication_info = "replicate"
// CHECK-SAME: device = "device"
// CHECK-SAME: topology = "topology"
// CHECK: return %[[CLUSTER]]#0, %[[CLUSTER]]#1, %[[CLUSTER]]#2
// Test a nested user of an op in a cluster has its operand be updated to
// `tf_device.cluster` result.
// Here the `tf_device.launch` region uses %0 (opA's result); after clustering
// it must use cluster result #0 instead.
// CHECK-LABEL: func @nested_cluster_op_user
// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<i1>)
func.func @nested_cluster_op_user(%arg0 : tensor<i1>) -> (tensor<i1>) {
%0 = "tf.opA"(%arg0) {_xla_compile_device_type = "TPU", _replication_info = "replicate", is_stateless = true} : (tensor<i1>) -> tensor<i1>
%1 = "tf_device.launch"() ({
tf_device.return %0 : tensor<i1>
}) {device = "device"} : () -> tensor<i1>
%2 = "tf.opB"(%0) {_xla_compile_device_type = "TPU", _replication_info = "replicate", is_stateless = true} : (tensor<i1>) -> tensor<i1>
"tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _replication_info = "replicate", device = "device", num_replicas = 1, topology = "topology"} : () -> ()
func.return %2 : tensor<i1>
}
// CHECK: %[[CLUSTER:[0-9]*]]:2 = "tf_device.cluster"() ({
// CHECK-NEXT: %[[OP_A:[0-9]*]] = "tf.opA"(%[[ARG_0]])
// CHECK-NEXT: %[[OP_B:[0-9]*]] = "tf.opB"(%[[OP_A]])
// CHECK-NEXT: tf_device.return %[[OP_A]], %[[OP_B]]
// CHECK-NEXT: _replication_info = "replicate"
// CHECK-SAME: device = "device"
// CHECK-SAME: topology = "topology"
// CHECK: tf_device.launch
// CHECK-NEXT: tf_device.return %[[CLUSTER]]#0
// CHECK: return %[[CLUSTER]]#1
// Test nested op of a cluster with an operand from an op of the same cluster
// retains its original operand.
// opC (inside opB's region) uses opA's result directly; since both opA and opB
// end up in the same cluster, the nested use is not rewritten.
// CHECK-LABEL: func @nested_cluster_op
// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<i1>)
func.func @nested_cluster_op(%arg0 : tensor<i1>) -> (tensor<i1>) {
%0 = "tf.opA"(%arg0) {_xla_compile_device_type = "TPU", _replication_info = "replicate", is_stateless = true} : (tensor<i1>) -> tensor<i1>
%1 = "tf.opB"() ({
"tf.opC"(%0) {is_stateless = true} : (tensor<i1>) -> tensor<i1>
}) {_xla_compile_device_type = "TPU", _replication_info = "replicate"} : () -> tensor<i1>
"tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _replication_info = "replicate", device = "device", num_replicas = 1, topology = "topology"} : () -> ()
func.return %1 : tensor<i1>
}
// CHECK: %[[CLUSTER:[0-9]*]] = "tf_device.cluster"() ({
// CHECK-NEXT: %[[OP_A:[0-9]*]] = "tf.opA"(%[[ARG_0]])
// CHECK-NEXT: %[[OP_B:[0-9]*]] = "tf.opB"() ({
// CHECK-NEXT: "tf.opC"(%[[OP_A]])
// CHECK: tf_device.return %[[OP_B]]
// CHECK-NEXT: _replication_info = "replicate"
// CHECK-SAME: device = "device"
// CHECK-SAME: topology = "topology"
// CHECK: return %[[CLUSTER]]
// Test multiple clusters interleaved.
// Ops tagged "replicate_0" and "replicate_1" alternate in the input; each
// `_replication_info` value must still produce its own separate cluster.
// CHECK-LABEL: func @interleaved_clusters
// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<i1>)
func.func @interleaved_clusters(%arg0 : tensor<i1>) -> (tensor<i1>, tensor<i1>) {
"tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _replication_info = "replicate_1", device = "device_1", num_replicas = 1, topology = "topology_1"} : () -> ()
%0 = "tf.opA"(%arg0) {_xla_compile_device_type = "TPU", _replication_info = "replicate_0", is_stateless = true} : (tensor<i1>) -> tensor<i1>
%1 = "tf.opB"(%arg0) {_xla_compile_device_type = "TPU", _replication_info = "replicate_1", is_stateless = true} : (tensor<i1>) -> tensor<i1>
%2 = "tf.opC"(%0) {_xla_compile_device_type = "TPU", _replication_info = "replicate_0", is_stateless = true} : (tensor<i1>) -> tensor<i1>
%3 = "tf.opD"(%1) {_xla_compile_device_type = "TPU", _replication_info = "replicate_1", is_stateless = true} : (tensor<i1>) -> tensor<i1>
"tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _replication_info = "replicate_0", device = "device_0", num_replicas = 1, topology = "topology_0"} : () -> ()
func.return %2, %3 : tensor<i1>, tensor<i1>
}
// CHECK: %[[CLUSTER_0:[0-9]*]] = "tf_device.cluster"() ({
// CHECK-NEXT: %[[OP_A:[0-9]*]] = "tf.opA"(%[[ARG_0]])
// CHECK-NEXT: %[[OP_C:[0-9]*]] = "tf.opC"(%[[OP_A]])
// CHECK-NEXT: tf_device.return %[[OP_C]]
// CHECK-NEXT: _replication_info = "replicate_0"
// CHECK-SAME: device = "device_0"
// CHECK-SAME: topology = "topology_0"
// CHECK: %[[CLUSTER_1:[0-9]*]] = "tf_device.cluster"() ({
// CHECK-NEXT: %[[OP_B:[0-9]*]] = "tf.opB"(%[[ARG_0]])
// CHECK-NEXT: %[[OP_D:[0-9]*]] = "tf.opD"(%[[OP_B]])
// CHECK-NEXT: tf_device.return %[[OP_D]]
// CHECK-NEXT: _replication_info = "replicate_1"
// CHECK-SAME: device = "device_1"
// CHECK-SAME: topology = "topology_1"
// CHECK: return %[[CLUSTER_0]], %[[CLUSTER_1]]
// Test operands and results of ops of a cluster that are interleaved between
// other ops of the same cluster are moved before and after the cluster
// properly.
// opC/opE (producing a cluster operand) are hoisted before the cluster, while
// opB/opD (consuming a cluster result) are sunk after it.
// CHECK-LABEL: func @interleaved_cluster_operands_results
func.func @interleaved_cluster_operands_results() {
%0 = "tf.opA"() {_xla_compile_device_type = "TPU", _replication_info = "replicate", is_stateless = true} : () -> tensor<i1>
%1 = "tf.opB"(%0) {is_stateless = true} : (tensor<i1>) -> tensor<i1>
%2 = "tf.opC"() {is_stateless = true} : () -> tensor<i1>
"tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _replication_info = "replicate", device = "device", num_replicas = 1, topology = "topology"} : () -> ()
%3 = "tf.opD"(%1) {is_stateless = true} : (tensor<i1>) -> tensor<i1>
%4 = "tf.opE"(%2) {is_stateless = true} : (tensor<i1>) -> tensor<i1>
%5 = "tf.opF"(%4) {_xla_compile_device_type = "TPU", _replication_info = "replicate", is_stateless = true} : (tensor<i1>) -> tensor<i1>
func.return
}
// CHECK: %[[OP_C:[0-9]*]] = "tf.opC"
// CHECK: %[[OP_E:[0-9]*]] = "tf.opE"(%[[OP_C]])
// CHECK: %[[CLUSTER:[0-9]*]] = "tf_device.cluster"() ({
// CHECK-NEXT: %[[OP_A:[0-9]*]] = "tf.opA"
// CHECK-NEXT: "tf.opF"(%[[OP_E]])
// CHECK-NEXT: tf_device.return %[[OP_A]]
// CHECK-NEXT: _replication_info = "replicate"
// CHECK-SAME: device = "device"
// CHECK-SAME: topology = "topology"
// CHECK: %[[OP_B:[0-9]*]] = "tf.opB"(%[[CLUSTER]])
// CHECK: "tf.opD"(%[[OP_B]])
// Test one replica cluster results in removing of TPUReplicatedInput and
// TPUReplicatedOutput nodes and operands are forwarded to results.
// With `num_replicas = 1` no `tf_device.replicate` is created; the replicated
// IO ops must be gone from the output entirely.
// CHECK-LABEL: func @one_replica
// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<i1>)
func.func @one_replica(%arg0: tensor<i1>) -> tensor<i1> {
%ri = "tf.TPUReplicatedInput"(%arg0) : (tensor<i1>) -> tensor<i1>
%0 = "tf.opA"(%ri) {_xla_compile_device_type = "TPU", _replication_info = "replicate", is_stateless = true} : (tensor<i1>) -> tensor<i1>
%1 = "tf.opB"(%0) {is_stateless = true} : (tensor<i1>) -> tensor<i1>
%2 = "tf.opC"() {is_stateless = true} : () -> tensor<i1>
"tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _replication_info = "replicate", device = "device", num_replicas = 1, topology = "topology"} : () -> ()
%3 = "tf.opD"(%1) {is_stateless = true} : (tensor<i1>) -> tensor<i1>
%4 = "tf.opE"(%2) {is_stateless = true} : (tensor<i1>) -> tensor<i1>
%5 = "tf.opF"(%4) {_xla_compile_device_type = "TPU", _replication_info = "replicate", is_stateless = true} : (tensor<i1>) -> tensor<i1>
%ro = "tf.TPUReplicatedOutput"(%5) : (tensor<i1>) -> tensor<i1>
func.return %ro : tensor<i1>
}
// CHECK: %[[OP_C:[0-9]*]] = "tf.opC"
// CHECK: %[[OP_E:[0-9]*]] = "tf.opE"(%[[OP_C]])
// CHECK: %[[CLUSTER:[0-9]*]]:2 = "tf_device.cluster"() ({
// CHECK-NEXT: %[[OP_A:[0-9]*]] = "tf.opA"(%[[ARG_0]])
// CHECK-NEXT: %[[OP_F:[0-9]*]] = "tf.opF"(%[[OP_E]])
// CHECK-NEXT: tf_device.return %[[OP_A]], %[[OP_F]]
// CHECK-NEXT: _replication_info = "replicate"
// CHECK-SAME: device = "device"
// CHECK-SAME: topology = "topology"
// CHECK: %[[OP_B:[0-9]*]] = "tf.opB"(%[[CLUSTER]]#0)
// CHECK: "tf.opD"(%[[OP_B]])
// CHECK: return %[[CLUSTER]]#1
// CHECK-NOT: "tf.TPUReplicatedInput"
// CHECK-NOT: "tf.TPUReplicatedOutput"
// Test replication with replicated operands and replicated results. The cluster
// will be wrapped in a `tf_device.cluster` first and then by a replicate.
// TPUReplicatedInput and TPUReplicatedOutput nodes will be replaced by the
// replicate operands and results.
// With `num_replicas = 2` the replicate op gets `n = 2` and one block argument
// per TPUReplicatedInput; each TPUReplicatedOutput becomes two replicate results.
// CHECK-LABEL: func @replication
// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<i1>, %[[ARG_1:[a-z0-9]*]]: tensor<i32>, %[[ARG_2:[a-z0-9]*]]: tensor<f32>)
func.func @replication(%arg0: tensor<i1>, %arg1: tensor<i32>, %arg2: tensor<f32>) -> (tensor<i32>, tensor<f32>) {
%0 = "tf.opA"() {is_stateless = true} : () -> tensor<i1>
%ri_0 = "tf.TPUReplicatedInput"(%arg0, %0) : (tensor<i1>, tensor<i1>) -> tensor<i1>
%1 = "tf.opB"() {is_stateless = true} : () -> tensor<i32>
%ri_1 = "tf.TPUReplicatedInput"(%1, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
%2 = "tf.opC"() {is_stateless = true} : () -> tensor<f32>
%3 = "tf.opD"(%ri_0, %ri_1, %arg2, %2) {_xla_compile_device_type = "TPU", _replication_info = "replicate", is_stateless = true} : (tensor<i1>, tensor<i32>, tensor<f32>, tensor<f32>) -> tensor<i32>
%ro_0:2 = "tf.TPUReplicatedOutput"(%3) : (tensor<i32>) -> (tensor<i32>, tensor<i32>)
"tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _replication_info = "replicate", device = "device", num_replicas = 2, topology = "topology"} : () -> ()
%7 = "tf.opE"(%3, %ri_0, %ri_1, %arg2, %2) {_xla_compile_device_type = "TPU", _replication_info = "replicate", is_stateless = true} : (tensor<i32>, tensor<i1>, tensor<i32>, tensor<f32>, tensor<f32>) -> tensor<f32>
%ro_1:2 = "tf.TPUReplicatedOutput"(%7) : (tensor<f32>) -> (tensor<f32>, tensor<f32>)
func.return %ro_0#0, %ro_1#1 : tensor<i32>, tensor<f32>
}
// CHECK: %[[OP_A:[0-9]*]] = "tf.opA"
// CHECK: %[[OP_B:[0-9]*]] = "tf.opB"
// CHECK: %[[OP_C:[0-9]*]] = "tf.opC"
// CHECK: %[[REPLICATE:[0-9]*]]:4 = tf_device.replicate
// CHECK-DAG: [%[[ARG_0]], %[[OP_A]]] as %[[RI_0:[a-z0-9]*]]: tensor<i1>
// CHECK-DAG: [%[[OP_B]], %[[ARG_1]]] as %[[RI_1:[a-z0-9]*]]: tensor<i32>
// CHECK-NOT: _replicated_input_indices
// CHECK-SAME: n = 2 : i32
// CHECK-NEXT: %[[CLUSTER:[0-9]*]]:2 = "tf_device.cluster"() ({
// CHECK: %[[OP_D:[0-9]*]] = "tf.opD"(%[[RI_0]], %[[RI_1]], %[[ARG_2]], %[[OP_C]])
// CHECK: %[[OP_E:[0-9]*]] = "tf.opE"(%[[OP_D]], %[[RI_0]], %[[RI_1]], %[[ARG_2]], %[[OP_C]])
// CHECK: tf_device.return %[[OP_D]], %[[OP_E]]
// CHECK-NEXT: _replication_info = "replicate"
// CHECK-SAME: device = "device"
// CHECK-SAME: topology = "topology"
// CHECK: tf_device.return %[[CLUSTER]]#0, %[[CLUSTER]]#1
// CHECK: return %[[REPLICATE]]#0, %[[REPLICATE]]#3
// Test replication with model parallelism using partitioned resource inputs.
// The cluster will be wrapped in a `tf_device.cluster` first and then by a
// replicate.
// TPUPartitionedInput nodes would be inside the replicate but outside the
// cluster.
// TPUReplicatedInput and TPUReplicatedOutput nodes will be replaced by the
// replicate operands and results.
// CHECK-LABEL: func @replication_with_model_parallelism
// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<!tf_type.resource<tensor<10x3xf32>>>, %[[ARG_1:[a-z0-9]*]]: tensor<!tf_type.resource<tensor<10x3xf32>>>, %[[ARG_2:[a-z0-9]*]]: tensor<!tf_type.resource<tensor<10x3xf32>>>, %[[ARG_3:[a-z0-9]*]]: tensor<!tf_type.resource<tensor<10x3xf32>>>)
!rtype = tensor<!tf_type.resource<tensor<10x3xf32>>>
func.func @replication_with_model_parallelism(%arg0: !rtype, %arg1: !rtype, %arg2: !rtype, %arg3: !rtype) -> (tensor<10x3xf32>, tensor<f32>) {
%0 = "tf.opA"() {is_stateless = true} : () -> tensor<i32>
%1 = "tf.opB"() {is_stateless = true} : () -> tensor<i32>
%2 = "tf.TPUReplicatedInput"(%arg0, %arg2) : (!rtype, !rtype) -> !rtype
%3 = "tf.TPUReplicatedInput"(%arg1, %arg3) : (!rtype, !rtype) -> !rtype
%4 = "tf.TPUPartitionedInput"(%2, %3) {_XlaSharding = "", device = "", partition_dim = -1 : i64} : (!rtype, !rtype) -> !rtype
%5 = "tf.TPUReplicatedInput"(%0, %1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
%6 = "tf.opC"(%4) {_xla_compile_device_type = "TPU", _replication_info = "replicate", is_stateless = true} : (!rtype) -> tensor<10x3xf32>
%7:2 = "tf.TPUReplicatedOutput"(%6) : (tensor<10x3xf32>) -> (tensor<10x3xf32>, tensor<10x3xf32>)
"tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _replication_info = "replicate", device = "device", num_cores_per_replica = 2 : i64, num_replicas = 2 : i64, topology = "topology"} : () -> ()
%8 = "tf.opD"(%5) {_xla_compile_device_type = "TPU", _replication_info = "replicate", is_stateless = true} : (tensor<i32>) -> tensor<f32>
%9:2 = "tf.TPUReplicatedOutput"(%8) : (tensor<f32>) -> (tensor<f32>, tensor<f32>)
func.return %7#0, %9#1 : tensor<10x3xf32>, tensor<f32>
}
// CHECK: %[[OP_A:[0-9]*]] = "tf.opA"
// CHECK: %[[OP_B:[0-9]*]] = "tf.opB"
// CHECK: %[[REPLICATE:[0-9]*]]:4 = tf_device.replicate
// CHECK-DAG: [%[[ARG_0]], %[[ARG_2]]] as %[[RI_0:[a-z0-9]*]]: tensor<!tf_type.resource<tensor<10x3xf32>>>
// CHECK-DAG: [%[[ARG_1]], %[[ARG_3]]] as %[[RI_1:[a-z0-9]*]]: tensor<!tf_type.resource<tensor<10x3xf32>>>
// CHECK-DAG: [%[[OP_A]], %[[OP_B]]] as %[[RI_2:[a-z0-9]*]]: tensor<i32>
// CHECK-NOT: _replicated_input_indices
// CHECK-SAME: n = 2 : i32
// CHECK: %[[PI:[0-9]*]] = "tf.TPUPartitionedInput"(%[[RI_0]], %[[RI_1]])
// CHECK-NEXT: %[[CLUSTER:[0-9]*]]:2 = "tf_device.cluster"() ({
// CHECK: %[[OP_C:[0-9]*]] = "tf.opC"(%[[PI]])
// CHECK: %[[OP_D:[0-9]*]] = "tf.opD"(%[[RI_2]])
// CHECK: tf_device.return %[[OP_C]], %[[OP_D]]
// CHECK-NEXT: _replication_info = "replicate"
// CHECK-SAME: device = "device"
// CHECK-SAME: topology = "topology"
// CHECK: tf_device.return %[[CLUSTER]]#0, %[[CLUSTER]]#1
// CHECK: return %[[REPLICATE]]#0, %[[REPLICATE]]#3
// Test TPUReplicatedInput ops are sorted by their `index` attribute.
// Non-negative `index` should precede `index` of -1, and ordering of ops with
// `index` of -1 does not matter.
// Expected order below: indices 0, 1, 3 (sorted), then the four -1 entries in
// any order (hence the CHECK-DAGs), mirrored by `_replicated_input_indices`.
// CHECK-LABEL: func @sort_replicated_input
// CHECK-SAME: (%[[ARG_0:.*]]: tensor<i1>, %[[ARG_1:.*]]: tensor<i1>, %[[ARG_2:.*]]: tensor<i1>, %[[ARG_3:.*]]: tensor<i1>, %[[ARG_4:.*]]: tensor<i1>, %[[ARG_5:.*]]: tensor<i1>, %[[ARG_6:.*]]: tensor<i1>, %[[ARG_7:.*]]: tensor<i1>)
func.func @sort_replicated_input(%arg0: tensor<i1>, %arg1: tensor<i1>, %arg2: tensor<i1>, %arg3: tensor<i1>, %arg4: tensor<i1>, %arg5: tensor<i1>, %arg6: tensor<i1>, %arg7: tensor<i1>) {
%0 = "tf.TPUReplicatedInput"(%arg0, %arg0) {index = -1 : i64} : (tensor<i1>, tensor<i1>) -> tensor<i1>
%1 = "tf.TPUReplicatedInput"(%arg1, %arg1) {index = 3 : i64} : (tensor<i1>, tensor<i1>) -> tensor<i1>
%2 = "tf.TPUReplicatedInput"(%arg2, %arg2) {index = 0 : i64} : (tensor<i1>, tensor<i1>) -> tensor<i1>
%3 = "tf.TPUReplicatedInput"(%arg3, %arg3) {index = -1 : i64} : (tensor<i1>, tensor<i1>) -> tensor<i1>
%4 = "tf.TPUReplicatedInput"(%arg4, %arg4) {index = 1 : i64} : (tensor<i1>, tensor<i1>) -> tensor<i1>
%5 = "tf.TPUReplicatedInput"(%arg5) {index = -1 : i64, is_packed = true} : (tensor<i1>) -> tensor<i1>
%6 = "tf.TPUReplicatedInput"(%arg6) {index = 2 : i64, is_packed = true} : (tensor<i1>) -> tensor<i1>
%7 = "tf.TPUReplicatedInput"(%arg7, %arg7) {index = -1 : i64} : (tensor<i1>, tensor<i1>) -> tensor<i1>
"tf.opA"(%0, %1, %2, %3, %4, %5, %6, %7) {_xla_compile_device_type = "TPU", _replication_info = "replicate", device = "device", is_stateless = true} : (tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>) -> ()
"tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _replication_info = "replicate", device = "device", num_replicas = 2, topology = "topology"} : () -> ()
func.return
}
// CHECK: tf_device.replicate
// CHECK-SAME: [%[[ARG_2]], %[[ARG_2]]] as %[[RI_2:[a-z0-9]*]]
// CHECK-SAME: [%[[ARG_4]], %[[ARG_4]]] as %[[RI_4:[a-z0-9]*]]
// CHECK-SAME: [%[[ARG_1]], %[[ARG_1]]] as %[[RI_1:[a-z0-9]*]]
// CHECK-DAG: [%[[ARG_0]], %[[ARG_0]]] as %[[RI_0:[a-z0-9]*]]
// CHECK-DAG: [%[[ARG_3]], %[[ARG_3]]] as %[[RI_3:[a-z0-9]*]]
// CHECK-DAG: [%[[ARG_7]], %[[ARG_7]]] as %[[RI_7:[a-z0-9]*]]
// CHECK-DAG: %[[ARG_6]] as %[[RI_6:[a-z0-9]*]]
// CHECK-DAG: %[[ARG_5]] as %[[RI_5:[a-z0-9]*]]
// CHECK-SAME: _replicated_input_indices = [0, 1, 3, -1, -1, -1, 2, -1]
// CHECK: "tf.opA"(%[[RI_0]], %[[RI_1]], %[[RI_2]], %[[RI_3]], %[[RI_4]], %[[RI_5]], %[[RI_6]], %[[RI_7]])
// Test TPUReplicatedInputs with non-contiguous `index` attributes.
// Indexed inputs come first in ascending `index` order (2, 8), then packed
// indexed (4, 6), with unindexed (-1) entries interleaved per their kind.
// CHECK-LABEL: func @non_contiguous_indices
// CHECK-SAME: (%[[ARG_0:.*]]: tensor<i1>, %[[ARG_1:.*]]: tensor<i1>, %[[ARG_2:.*]]: tensor<i1>, %[[ARG_3:.*]]: tensor<i1>, %[[ARG_4:.*]]: tensor<i1>, %[[ARG_5:.*]]: tensor<i1>)
func.func @non_contiguous_indices(%arg0: tensor<i1>, %arg1: tensor<i1>, %arg2: tensor<i1>, %arg3: tensor<i1>, %arg4: tensor<i1>, %arg5: tensor<i1>) {
%0 = "tf.TPUReplicatedInput"(%arg0, %arg0) {index = 8 : i64} : (tensor<i1>, tensor<i1>) -> tensor<i1>
"tf.opA"(%0) {_xla_compile_device_type = "TPU", _replication_info = "replicate", device = "device", name = "name", is_stateless = true} : (tensor<i1>) -> ()
%1 = "tf.TPUReplicatedInput"(%arg1) {index = 6 : i64, is_packed = true} : (tensor<i1>) -> tensor<i1>
"tf.opA"(%1) {_xla_compile_device_type = "TPU", _replication_info = "replicate", device = "device", name = "name", is_stateless = true} : (tensor<i1>) -> ()
%2 = "tf.TPUReplicatedInput"(%arg2, %arg2) : (tensor<i1>, tensor<i1>) -> tensor<i1>
"tf.opB"(%2) {_xla_compile_device_type = "TPU", _replication_info = "replicate", device = "device", name = "name", is_stateless = true} : (tensor<i1>) -> ()
%3 = "tf.TPUReplicatedInput"(%arg3) {is_packed = true} : (tensor<i1>) -> tensor<i1>
"tf.opB"(%3) {_xla_compile_device_type = "TPU", _replication_info = "replicate", device = "device", name = "name", is_stateless = true} : (tensor<i1>) -> ()
%4 = "tf.TPUReplicatedInput"(%arg4, %arg4) {index = 2 : i64} : (tensor<i1>, tensor<i1>) -> tensor<i1>
"tf.opC"(%4) {_xla_compile_device_type = "TPU", _replication_info = "replicate", device = "device", name = "name", is_stateless = true} : (tensor<i1>) -> ()
%5 = "tf.TPUReplicatedInput"(%arg5) {index = 4 : i64, is_packed = true} : (tensor<i1>) -> tensor<i1>
"tf.opC"(%5) {_xla_compile_device_type = "TPU", _replication_info = "replicate", device = "device", name = "name", is_stateless = true} : (tensor<i1>) -> ()
"tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _replication_info = "replicate", device = "device", num_replicas = 2, topology = "topology"} : () -> ()
func.return
}
// CHECK: tf_device.replicate
// CHECK-SAME: [%[[ARG_4]], %[[ARG_4]]] as %{{[a-z0-9]*}}
// CHECK-SAME: [%[[ARG_0]], %[[ARG_0]]] as %{{[a-z0-9]*}}
// CHECK-SAME: [%[[ARG_2]], %[[ARG_2]]] as %{{[a-z0-9]*}}
// CHECK-SAME: %[[ARG_5]] as %{{[a-z0-9]*}}
// CHECK-SAME: %[[ARG_1]] as %{{[a-z0-9]*}}
// CHECK-SAME: %[[ARG_3]] as %{{[a-z0-9]*}}
// CHECK-SAME: _replicated_input_indices = [2, 8, -1, 4, 6, -1]
// Test that the `is_mirrored_variable` attribute is preserved in the
// tf_device.replicate op.
// The replicate op records which replicated inputs were mirrored via the
// `_mirrored_variable_indices` attribute (positions 1 and 2 here).
// CHECK-LABEL: func @mirrored_variables
// CHECK-SAME: (%[[ARG_0:.*]]: tensor<!tf_type.resource<tensor<32xf32>>>, %[[ARG_1:.*]]: tensor<!tf_type.resource<tensor<32xf32>>>, %[[ARG_2:.*]]: tensor<!tf_type.resource<tensor<32xf32>>>, %[[ARG_3:.*]]: tensor<!tf_type.resource<tensor<32xf32>>>, %[[ARG_4:.*]]: tensor<!tf_type.resource<tensor<32xf32>>>)
func.func @mirrored_variables(%arg0: tensor<!tf_type.resource<tensor<32xf32>>>, %arg1: tensor<!tf_type.resource<tensor<32xf32>>>, %arg2: tensor<!tf_type.resource<tensor<32xf32>>>, %arg3: tensor<!tf_type.resource<tensor<32xf32>>>, %arg4: tensor<!tf_type.resource<tensor<32xf32>>>) {
%0 = "tf.TPUReplicatedInput"(%arg0, %arg1) {index = 0 : i64} : (tensor<!tf_type.resource<tensor<32xf32>>>, tensor<!tf_type.resource<tensor<32xf32>>>) -> tensor<!tf_type.resource<tensor<32xf32>>>
%1 = "tf.TPUReplicatedInput"(%arg2, %arg3) {index = 1 : i64, is_mirrored_variable = true} : (tensor<!tf_type.resource<tensor<32xf32>>>, tensor<!tf_type.resource<tensor<32xf32>>>) -> tensor<!tf_type.resource<tensor<32xf32>>>
%2 = "tf.TPUReplicatedInput"(%arg4) {index = 2 : i64, is_mirrored_variable = true, is_packed = true} : (tensor<!tf_type.resource<tensor<32xf32>>>) -> tensor<!tf_type.resource<tensor<32xf32>>>
"tf.opA"(%0, %1, %2) {_xla_compile_device_type = "TPU", _replication_info = "replicate", device = "device", is_stateless = true} : (tensor<!tf_type.resource<tensor<32xf32>>>, tensor<!tf_type.resource<tensor<32xf32>>>, tensor<!tf_type.resource<tensor<32xf32>>>) -> ()
"tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _replication_info = "replicate", device = "device", num_replicas = 2, topology = "topology"} : () -> ()
func.return
}
// CHECK: tf_device.replicate
// CHECK-SAME: [%[[ARG_0]], %[[ARG_1]]] as %{{[a-z0-9]*}}
// CHECK-SAME: %[[ARG_4]] as %{{[a-z0-9]*}}
// CHECK-SAME: _mirrored_variable_indices = [1, 2]
// CHECK-SAME: _replicated_input_indices = [0, 1, 2]
// Test resource usage after resource use in cluster is moved to after the
// cluster.
// %arg0 is read inside the cluster, so its AssignAdd stays after the cluster;
// %arg1 is untouched by the cluster, so its AssignSub may move before it.
// CHECK-LABEL: func @resource_after_cluster
// CHECK-SAME: ([[USED_RESOURCE:%.*]]: tensor<*x!tf_type.resource<tensor<f32>>>, [[UNUSED_RESOURCE:%.*]]: tensor<*x!tf_type.resource<tensor<f32>>>)
func.func @resource_after_cluster(%arg0: tensor<*x!tf_type.resource<tensor<f32>>>, %arg1: tensor<*x!tf_type.resource<tensor<f32>>>) {
// CHECK-NEXT: [[CONST:%.*]] = "tf.Const"
%0 = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
// CHECK-NEXT: "tf.AssignSubVariableOp"([[UNUSED_RESOURCE]], [[CONST]])
// CHECK: "tf_device.cluster"
// CHECK-NEXT: "tf.ReadVariableOp"([[USED_RESOURCE]])
// CHECK-NEXT: "tf.NoOp"
// CHECK-NEXT: tf_device.return
"tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _replication_info = "cluster_test_fn", allow_soft_placement = false, computation_shape = [], device_assignment = [], host_compute_core = [], num_cores_per_replica = 1 : i64, num_replicas = 1 : i64, step_marker_location = "STEP_MARK_AT_ENTRY", topology = "", use_spmd_for_xla_partitioning = false, use_tpu = true} : () -> ()
%1 = "tf.ReadVariableOp"(%arg0) {_xla_compile_device_type = "TPU", _replication_info = "cluster_test_fn"} : (tensor<*x!tf_type.resource<tensor<f32>>>) -> tensor<f32>
"tf.AssignSubVariableOp"(%arg1, %0) : (tensor<*x!tf_type.resource<tensor<f32>>>, tensor<f32>) -> ()
// CHECK: "tf.AssignAddVariableOp"([[USED_RESOURCE]], [[CONST]])
"tf.AssignAddVariableOp"(%arg0, %0) : (tensor<*x!tf_type.resource<tensor<f32>>>, tensor<f32>) -> ()
"tf.NoOp"() {_xla_compile_device_type = "TPU", _replication_info = "cluster_test_fn"} : () -> ()
func.return
}
// Test resource not used by cluster is moved to before the cluster.
// The VarHandleOp/AssignAdd pair has no interaction with the clustered ops,
// so both are hoisted above the `tf_device.cluster`.
// CHECK-LABEL: func @resource_before_cluster
func.func @resource_before_cluster() {
// CHECK-NEXT: [[CONST:%.*]] = "tf.Const"
%0 = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
// CHECK-NEXT: [[UNUSED_RESOURCE:%.*]] = "tf.VarHandleOp"
// CHECK-NEXT: "tf.AssignAddVariableOp"([[UNUSED_RESOURCE]], [[CONST]])
// CHECK: "tf_device.cluster"
// CHECK-NEXT: "tf.NoOp"
// CHECK-NEXT: tf_device.return
"tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _replication_info = "cluster_test_fn", allow_soft_placement = false, computation_shape = [], device_assignment = [], host_compute_core = [], num_cores_per_replica = 1 : i64, num_replicas = 1 : i64, step_marker_location = "STEP_MARK_AT_ENTRY", topology = "", use_spmd_for_xla_partitioning = false, use_tpu = true} : () -> ()
%1 = "tf.VarHandleOp"() {container = "", shape = #tf_type.shape<>, shared_name = "x"} : () -> tensor<*x!tf_type.resource<tensor<f32>>>
"tf.AssignAddVariableOp"(%1, %0) : (tensor<*x!tf_type.resource<tensor<f32>>>, tensor<f32>) -> ()
"tf.NoOp"() {_xla_compile_device_type = "TPU", _replication_info = "cluster_test_fn"} : () -> ()
func.return
}
// Test cluster formation with ops with attached regions within a cluster.
// Nested op's that are moved should get their _replication_info and device
// attributes cleared.
// Both the outer opA and the nested opB must lose the clustering attributes
// while keeping their `name` attributes.
// CHECK-LABEL: func @cluster_ops_with_regions
func.func @cluster_ops_with_regions() {
%0 = "tf.opA"() ({
%1 = "tf.opB"() {_xla_compile_device_type = "TPU", _replication_info = "replicate", device = "device", name = "nameB", is_stateless = true} : () -> (tensor<i32>)
}) {_xla_compile_device_type = "TPU", _replication_info = "replicate", device = "device", name = "nameA"} : () -> tensor<i1>
"tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _replication_info = "replicate", device = "device", num_replicas = 1, topology = "topology"} : () -> ()
func.return
}
// CHECK: "tf.opA"() ({
// CHECK-NEXT: "tf.opB"
// CHECK-NOT: _replication_info = "replicate"
// CHECK-NOT: device = "device"
// CHECK-SAME: name = "nameB"
// CHECK: })
// CHECK-NOT: _replication_info = "replicate"
// CHECK-NOT: device = "device"
// CHECK: name = "nameA"
// CHECK: tf_device.return
// A nested cluster op using result of another cluster op. In the below, opA and
// opB go in a cluster, and opD stays outside.
// opA's result is returned from the cluster so that opD (outside) can consume
// it through the cluster result.
// CHECK-LABEL: func @cluster_nested_op_using_other_op
func.func @cluster_nested_op_using_other_op() {
%0 = "tf.opA"() { _xla_compile_device_type = "TPU", _replication_info = "foo" , is_stateless = true} : () -> tensor<i32>
"tf.opB"() ({
"tf.opC"(%0) {is_stateless = true} : (tensor<i32>) -> ()
}) { _xla_compile_device_type = "TPU", _replication_info = "foo" } : () -> ()
"tf.opD"(%0) {is_stateless = true} : (tensor<i32>) -> ()
"tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _replication_info = "foo", device = "CPU", num_replicas = 1, topology = "topology"} : () -> ()
func.return
}
// CHECK: [[CLUSTER:%.*]] = "tf_device.cluster"() ({
// CHECK: [[OPA:%.*]] = "tf.opA"() {is_stateless = true} : () -> tensor<i32>
// CHECK: "tf.opB"() ({
// CHECK: "tf.opC"([[OPA]])
// CHECK: tf_device.return [[OPA]]
// CHECK: "tf.opD"([[CLUSTER]])
// Preceding user is using resource updated by a nested op.
// The nested AssignAdd (inside opA's region) writes %1, so the standalone
// AssignAdd that also writes %1 must be sunk after the cluster.
!tf_res = tensor<*x!tf_type.resource<tensor<f32>>>
// CHECK-LABEL: func @cluster_nested_op_updating_resource
func.func @cluster_nested_op_updating_resource() {
%0 = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
%1 = "tf.VarHandleOp"() {container = "", shape = #tf_type.shape<>, shared_name = "x"} : () -> !tf_res
"tf.opA"() ({
"tf.AssignAddVariableOp"(%1, %0) : (!tf_res, tensor<f32>) -> ()
"tf.terminator"() : () -> ()
}) { _xla_compile_device_type = "TPU", _replication_info = "foo" } : () -> ()
"tf.AssignAddVariableOp"(%1, %0) : (!tf_res, tensor<f32>) -> ()
"tf.opB"() { _xla_compile_device_type = "TPU", _replication_info = "foo" , is_stateless = true} : () -> ()
"tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _replication_info = "foo", device = "CPU", num_replicas = 1, topology = "topology"} : () -> ()
func.return
}
// CHECK: [[CONST:%.*]] = "tf.Const"
// CHECK: [[VAR:%.*]] = "tf.VarHandleOp"
// CHECK: "tf_device.cluster"() ({
// CHECK: "tf.opA"() ({
// CHECK: "tf.AssignAddVariableOp"([[VAR]], [[CONST]])
// CHECK: })
// CHECK: "tf.opB"()
// CHECK: tf_device.return
// CHECK: })
// CHECK-SAME: _replication_info = "foo"
// CHECK: "tf.AssignAddVariableOp"([[VAR]], [[CONST]])
// Preceding user is using resource updated by the cluster within a nested op.
// Resource is updated by a cluster op, and opA (not in cluster) is using the
// resource in a nested op. We expect opA to be after the cluster.
// CHECK-LABEL: func @cluster_nested_op_using_resource
func.func @cluster_nested_op_using_resource() {
%0 = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
%1 = "tf.VarHandleOp"() {container = "", shape = #tf_type.shape<>, shared_name = "x"} : () -> !tf_res
"tf.AssignAddVariableOp"(%1, %0) { _xla_compile_device_type = "TPU", _replication_info = "foo" } : (!tf_res, tensor<f32>) -> ()
"tf.opA"() ({
"tf.AssignAddVariableOp"(%1, %0) : (!tf_res, tensor<f32>) -> ()
"tf.terminator"() : () -> ()
}) : () -> ()
"tf.opB"() { _xla_compile_device_type = "TPU", _replication_info = "foo" , is_stateless = true} : () -> ()
"tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _replication_info = "foo", device = "CPU", num_replicas = 1, topology = "topology"} : () -> ()
func.return
}
// CHECK: [[CONST:%.*]] = "tf.Const"
// CHECK: [[VAR:%.*]] = "tf.VarHandleOp"
// CHECK: "tf_device.cluster"() ({
// CHECK: "tf.AssignAddVariableOp"([[VAR]], [[CONST]])
// CHECK: "tf.opB"()
// CHECK: tf_device.return
// CHECK: })
// CHECK-SAME: _replication_info = "foo"
// CHECK: "tf.opA"() ({
// CHECK: "tf.AssignAddVariableOp"([[VAR]], [[CONST]])
// -----
!tf_res = tensor<*x!tf_type.resource<tensor<f32>>>
// Test multiple replicated clusters interleaved and uses resource variables.
// Each of the three `_replication_info` values ("a", "b", "c") must produce
// its own tf_device.replicate op.
// CHECK-LABEL: func @multiple_replicated_interleaved
func.func @multiple_replicated_interleaved(%arg0: !tf_res) {
"tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _replication_info = "a", num_replicas = 2, topology = "topology"} : () -> ()
"tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _replication_info = "b", num_replicas = 2, topology = "topology"} : () -> ()
"tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _replication_info = "c", num_replicas = 2, topology = "topology"} : () -> ()
%0 = "tf.TPUReplicatedInput"(%arg0, %arg0) : (!tf_res, !tf_res) -> !tf_res
%1 = "tf.TPUReplicatedInput"(%arg0, %arg0) : (!tf_res, !tf_res) -> !tf_res
%2 = "tf.TPUReplicatedInput"(%arg0, %arg0) : (!tf_res, !tf_res) -> !tf_res
%3 = "tf.ReadVariableOp"(%0) {_xla_compile_device_type = "TPU", _replication_info = "a"} : (!tf_res) -> tensor<f32>
%4 = "tf.ReadVariableOp"(%1) {_xla_compile_device_type = "TPU", _replication_info = "b"} : (!tf_res) -> tensor<f32>
%5 = "tf.ReadVariableOp"(%2) {_xla_compile_device_type = "TPU", _replication_info = "c"} : (!tf_res) -> tensor<f32>
%6 = "tf.Identity"(%3) {_xla_compile_device_type = "TPU", _replication_info = "a"} : (tensor<f32>) -> tensor<f32>
%7 = "tf.Identity"(%4) {_xla_compile_device_type = "TPU", _replication_info = "b"} : (tensor<f32>) -> tensor<f32>
%8 = "tf.Identity"(%5) {_xla_compile_device_type = "TPU", _replication_info = "c"} : (tensor<f32>) -> tensor<f32>
%9:2 = "tf.TPUReplicatedOutput"(%6) : (tensor<f32>) -> (tensor<f32>, tensor<f32>)
%10:2 = "tf.TPUReplicatedOutput"(%7) : (tensor<f32>) -> (tensor<f32>, tensor<f32>)
%11:2 = "tf.TPUReplicatedOutput"(%8) : (tensor<f32>) -> (tensor<f32>, tensor<f32>)
func.return
}
// CHECK: tf_device.replicate
// CHECK: tf_device.replicate
// CHECK: tf_device.replicate
// -----
// Test cluster that is replicated but has a non TPUReplicatedOutput consumer.
// CHECK-LABEL: func @replicated_non_replicated_output
func.func @replicated_non_replicated_output() {
%0 = "tf.opA"() {_xla_compile_device_type = "TPU", _replication_info = "replicate", device = "device", name = "name", is_stateless = true} : () -> tensor<i1>
// opB is not clustered and consumes the cluster result directly, without a
// tf.TPUReplicatedOutput in between.
%1 = "tf.opB"(%0) {is_stateless = true} : (tensor<i1>) -> tensor<i1>
"tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _replication_info = "replicate", device = "device", num_replicas = 2, topology = "topology"} : () -> ()
func.return
}
// CHECK: [[REPLICATE:%.+]]:2 = tf_device.replicate
// CHECK: "tf.opB"([[REPLICATE]]#0)
// -----
// Test cluster with missing `num_replicas` attribute.
func.func @missing_num_replicas() {
%0 = "tf.opA"() {_xla_compile_device_type = "TPU", _replication_info = "replicate", device = "device", name = "name", is_stateless = true} : () -> tensor<i1>
// The metadata op below intentionally omits the `num_replicas` attribute.
// expected-error@+1 {{'tf.TPUReplicateMetadata' op requires attribute 'num_replicas'}}
"tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _replication_info = "replicate", device = "device", topology = "topology"} : () -> ()
func.return
}
// -----
// Test cluster with bad `num_replicas` attribute.
func.func @bad_num_replicas() {
// NOTE: the diagnostic is expected on the first clustered op, not on the
// metadata op carrying the invalid `num_replicas = 0`.
// expected-error@+1 {{requires 'num_replicas' int attribute to be at least 1}}
%0 = "tf.opA"() {_xla_compile_device_type = "TPU", _replication_info = "replicate", device = "device", name = "name", is_stateless = true} : () -> tensor<i1>
"tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _replication_info = "replicate", device = "device", num_replicas = 0, topology = "topology"} : () -> ()
func.return
}
// -----
// Test cluster where the `num_cores_per_replica` attribute does not match the
// number of `tf.TPUPartitionedInput` operands.
!rtype = tensor<!tf_type.resource<tensor<10x3xf32>>>
func.func @replication_with_model_parallelism(%arg0: !rtype, %arg1: !rtype, %arg2: !rtype, %arg3: !rtype) -> (tensor<10x3xf32>) {
%2 = "tf.TPUReplicatedInput"(%arg0, %arg2) : (!rtype, !rtype) -> !rtype
%3 = "tf.TPUReplicatedInput"(%arg1, %arg3) : (!rtype, !rtype) -> !rtype
// `num_cores_per_replica = 4` on the metadata op requires 4 operands on the
// TPUPartitionedInput, but only 2 are provided.
// expected-error@+1 {{'tf.TPUPartitionedInput' op requires 4 operands but found 2}}
%4 = "tf.TPUPartitionedInput"(%2, %3) {_XlaSharding = "", device = "", partition_dim = -1 : i64} : (!rtype, !rtype) -> !rtype
%6 = "tf.opC"(%4) {_xla_compile_device_type = "TPU", _replication_info = "replicate", is_stateless = true} : (!rtype) -> tensor<10x3xf32>
%7:2 = "tf.TPUReplicatedOutput"(%6) : (tensor<10x3xf32>) -> (tensor<10x3xf32>, tensor<10x3xf32>)
"tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _replication_info = "replicate", device = "device", num_cores_per_replica = 4 : i64, num_replicas = 2 : i64, topology = "topology"} : () -> ()
func.return %7#0 : tensor<10x3xf32>
}
// -----
// Test cluster with TPUReplicatedInput where the number of operands does not
// match associated `num_replicas` attribute.
func.func @mismatched_replicated_input(%arg0: tensor<i1>) {
// 3 operands below vs `num_replicas = 2` on the associated metadata op.
// expected-error@+1 {{'tf.TPUReplicatedInput' op requires 2 operands}}
%0 = "tf.TPUReplicatedInput"(%arg0, %arg0, %arg0) : (tensor<i1>, tensor<i1>, tensor<i1>) -> tensor<i1>
%1 = "tf.opA"(%0) {_xla_compile_device_type = "TPU", _replication_info = "replicate", device = "device", name = "name", is_stateless = true} : (tensor<i1>) -> tensor<i1>
"tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _replication_info = "replicate", device = "device", num_replicas = 2, topology = "topology"} : () -> ()
func.return
}
// -----
// Test cluster with TPUReplicatedOutput where the number of results does not
// match associated `num_replicas` attribute.
func.func @mismatched_replicated_output() {
%0 = "tf.opA"() {_xla_compile_device_type = "TPU", _replication_info = "replicate", device = "device", name = "name", is_stateless = true} : () -> tensor<i1>
// 3 results below vs `num_replicas = 2` on the associated metadata op.
// expected-error@+1 {{'tf.TPUReplicatedOutput' op requires 2 results}}
%1:3 = "tf.TPUReplicatedOutput"(%0) : (tensor<i1>) -> (tensor<i1>, tensor<i1>, tensor<i1>)
"tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _replication_info = "replicate", device = "device", num_replicas = 2, topology = "topology"} : () -> ()
func.return
}
// -----
// Test unused TPUReplicatedInput that has more than one operand.
func.func @leftover_replicated_input(%arg0: tensor<i1>) {
// The op is not consumed by any clustered op; no diagnostics are expected.
%0 = "tf.TPUReplicatedInput"(%arg0, %arg0) : (tensor<i1>, tensor<i1>) -> tensor<i1>
func.return
}
// -----
// Test unused TPUReplicatedOutput that has more than one result.
func.func @leftover_replicated_output(%arg0: tensor<i1>) {
// The op is not fed by any clustered op; no diagnostics are expected.
%0:2 = "tf.TPUReplicatedOutput"(%arg0) : (tensor<i1>) -> (tensor<i1>, tensor<i1>)
func.return
}
// -----
// Test bad TPUReplicatedInput negative `index` attribute.
func.func @bad_negative_index_input(%arg0: tensor<i1>) {
// `index` values of -1 and above are accepted; -2 is out of range.
// expected-error@+1 {{'tf.TPUReplicatedInput' op requires index to be at least -1, but got -2}}
%0 = "tf.TPUReplicatedInput"(%arg0, %arg0) {index = -2 : i64} : (tensor<i1>, tensor<i1>) -> tensor<i1>
"tf.opA"(%0) {_xla_compile_device_type = "TPU", _replication_info = "replicate", device = "device", name = "name", is_stateless = true} : (tensor<i1>) -> ()
"tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _replication_info = "replicate", device = "device", num_replicas = 2, topology = "topology"} : () -> ()
func.return
}
// -----
// Test TPUReplicatedInput with conflicting `index` attribute.
func.func @input_index_gaps(%arg0: tensor<i1>) {
%0 = "tf.TPUReplicatedInput"(%arg0, %arg0) {index = 1 : i64} : (tensor<i1>, tensor<i1>) -> tensor<i1>
// Duplicate `index = 1` within the same cluster triggers the uniqueness error.
// expected-error@+1 {{'tf.TPUReplicatedInput' op requires indices to be unique, but found multiple 'tf.TPUReplicatedInput' ops with index 1}}
%1 = "tf.TPUReplicatedInput"(%arg0, %arg0) {index = 1 : i64} : (tensor<i1>, tensor<i1>) -> tensor<i1>
"tf.opA"(%0, %1) {_xla_compile_device_type = "TPU", _replication_info = "replicate", device = "device", name = "name", is_stateless = true} : (tensor<i1>, tensor<i1>) -> ()
"tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _replication_info = "replicate", device = "device", num_replicas = 2, topology = "topology"} : () -> ()
func.return
}
// -----
// CHECK-LABEL: func @cluster_ops_keep_replicated_core_attr
func.func @cluster_ops_keep_replicated_core_attr() {
// Unlike plain `device` attributes, a TPU_REPLICATED_CORE device is kept on
// the op when it is moved into the cluster.
%0 = "tf.opA"() {_xla_compile_device_type = "TPU", _replication_info = "replicate", device = "/device:TPU_REPLICATED_CORE:0", name = "name", is_stateless = true} : () -> tensor<i1>
"tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _replication_info = "replicate", device = "device", num_replicas = 1, topology = "topology"} : () -> ()
func.return
}
// CHECK: "tf.opA"
// CHECK-SAME: device = "/device:TPU_REPLICATED_CORE:0"
// CHECK-SAME: name = "name"
// CHECK: tf_device.return
// -----
// Test op with a '_replication_info' attribute but without the required
// '_xla_compile_device_type' attribute.
func.func @missing_compilation_attribute() {
// expected-error@+1 {{'tf.opA' op has '_replication_info' attribute but not '_xla_compile_device_type' attribute which is unsupported}}
%0 = "tf.opA"() { _replication_info = "replicate", device = "/device:TPU_REPLICATED_CORE:0", name = "name", is_stateless = true} : () -> tensor<i1>
"tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _replication_info = "replicate", device = "device", num_replicas = 1, topology = "topology"} : () -> ()
func.return
}
// -----
// Test op with an empty '_replication_info' attribute value.
func.func @empty_replication_attribute() {
// expected-error@+1 {{'tf.opA' op has an empty '_replication_info' attribute}}
%0 = "tf.opA"() { _xla_compile_device_type = "TPU", _replication_info = "", device = "/device:TPU_REPLICATED_CORE:0", name = "name", is_stateless = true} : () -> tensor<i1>
"tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _replication_info = "replicate", device = "device", num_replicas = 1, topology = "topology"} : () -> ()
func.return
}
// -----
// Test op with an unsupported '_xla_compile_device_type' value.
func.func @invalid_device_type() {
// expected-error@+1 {{'tf.opA' op has invalid '_xla_compile_device_type' value 'XPU'}}
"tf.opA"() { _xla_compile_device_type = "XPU", _replication_info = "replicate", is_stateless = true} : () -> ()
func.return
}
// -----
// Check non-replicated case, including expected attributes at device cluster.
// CHECK: "tf_device.cluster"()
// CHECK: "tf.opA"()
// CHECK: "tf.opB"()
// CHECK: tf_device.return
// CHECK: }) {_replication_info = "__no_replication_cluster", _xla_compile_device_type = "TPU", allow_soft_placement = true, device_assignment = [], num_cores_per_replica = 1 : i32, step_marker_location = "", topology = "", use_spmd_for_xla_partitioning = false}
func.func @valid_compilation_cluster_no_replication() {
// Neither op carries '_replication_info'; together they form a single
// non-replicated cluster with default attributes.
"tf.opA"() { _xla_compile_device_type = "TPU", is_stateless = true} : () -> ()
"tf.opB"() { _xla_compile_device_type = "TPU", is_stateless = true} : () -> ()
func.return
}
// -----
// expected-error@+1 {{found different '_xla_compile_device_type' attribute values (GPU,TPU) in same block which is not supported}}
func.func @invalid_compilation_cluster_mixed_device_types() {
// GPU- and TPU-compiled ops in the same block are not supported; the
// diagnostic is attached to the enclosing function (expected-error above).
"tf.opA"() { _xla_compile_device_type = "GPU", is_stateless = true} : () -> ()
"tf.opB"() { _xla_compile_device_type = "TPU", is_stateless = true} : () -> ()
func.return
}
// -----
// expected-error@+1 {{found different '_xla_compile_device_type' attribute values (CPU,GPU) in same block which is not supported}}
func.func @invalid_compilation_replication_cluster_mixed_device_types() {
// Same '_replication_info' but different '_xla_compile_device_type' values;
// the diagnostic is attached to the enclosing function (expected-error above).
"tf.opA"() { _xla_compile_device_type = "CPU", _replication_info = "cluster", is_stateless = true} : () -> ()
"tf.opB"() { _xla_compile_device_type = "GPU", _replication_info = "cluster", is_stateless = true} : () -> ()
func.return
}
// -----
// expected-error@+1 {{found mixed replicated and non-replicated compiled ops in same block which is not supported}}
func.func @mixed_replicated_non_replicated_ops() {
// opA is compiled but not replicated while opB is replicated; mixing both in
// one block is rejected (expected-error above, attached to the function).
"tf.opA"() { _xla_compile_device_type = "TPU", is_stateless = true} : () -> ()
"tf.opB"() { _xla_compile_device_type = "TPU", _replication_info = "cluster", is_stateless = true} : () -> ()
func.return
}
// -----
// Test unclustered op sandwiched between two ops of the same non-replicated
// cluster; the pass warns about the resulting cyclic dependency.
func.func @cyclic_control_dependency_no_replication() {
"tf.opA"() {_xla_compile_device_type = "CPU"} : () -> ()
// expected-warning@+1 {{op has cyclic dependency with a compilation cluster}}
"tf.opB"() : () -> ()
"tf.opC"() {_xla_compile_device_type = "CPU"} : () -> ()
func.return
}
// -----
// Test unclustered op that consumes a cluster result and feeds another
// cluster op, creating a data-dependency cycle with the cluster.
func.func @cyclic_data_dependency_no_replication() {
%0 = "tf.opA"() {_xla_compile_device_type = "GPU", is_stateless = true} : () -> (tensor<i32>)
// expected-warning@+2 {{op has cyclic dependency with a compilation cluster}}
// expected-error@+1 {{operand #0 does not dominate this use}}
%1 = "tf.opB"(%0) {is_stateless = true} : (tensor<i32>) -> (tensor<i32>)
// expected-note@+1 {{operand defined here (op in the same block)}}
"tf.opC"(%1) {_xla_compile_device_type = "GPU", is_stateless = true} : (tensor<i32>) -> ()
func.return
}
// -----
// Replicated variant of the cyclic control-dependency case: an unclustered op
// between two ops of the same replicated cluster triggers the warning.
func.func @cyclic_control_dependency_replication() {
"tf.opA"() {_xla_compile_device_type = "TPU", _replication_info = "cluster"} : () -> ()
// expected-warning@+1 {{op has cyclic dependency with a compilation cluster}}
"tf.opB"() : () -> ()
"tf.opC"() {_xla_compile_device_type = "TPU", _replication_info = "cluster"} : () -> ()
"tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _replication_info = "cluster", device = "device", num_replicas = 2, topology = "topology"} : () -> ()
func.return
}
// -----
// Replicated variant of the cyclic data-dependency case: unclustered opB
// consumes a cluster result and feeds another compiled op.
func.func @cyclic_data_dependency_replication() {
%0 = "tf.opA"() {_xla_compile_device_type = "TPU", is_stateless = true} : () -> (tensor<i32>)
// expected-warning@+2 {{op has cyclic dependency with a compilation cluster}}
// expected-error@+1 {{operand #0 does not dominate this use}}
%1 = "tf.opB"(%0) {is_stateless = true} : (tensor<i32>) -> (tensor<i32>)
// expected-note@+1 {{operand defined here (op in the same block)}}
"tf.opC"(%1) {_xla_compile_device_type = "TPU", is_stateless = true} : (tensor<i32>) -> ()
"tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _replication_info = "cluster", device = "device", num_replicas = 2, topology = "topology"} : () -> ()
func.return
}
// -----
// expected-warning@+1 {{TPUReplicateMetadata for associated '_replication_info' attribute 'cluster' is missing}}
func.func @missing_metadata() {
// No tf.TPUReplicateMetadata op exists for '_replication_info' "cluster";
// the warning is attached to the function (expected-warning above).
"tf.opA"() {_xla_compile_device_type = "TPU", _replication_info = "cluster"} : () -> ()
func.return
}
// -----
// CHECK-LABEL: func @const_with_attrs
// Test a clustered tf.Const that is also used by unclustered ops; the CHECK
// lines expect a Const and the unclustered Reshape to appear before the
// formed cluster.
func.func @const_with_attrs(%arg0: tensor<*xi32>, %arg1: tensor<?xi64>) -> (tensor<?xi32>, tensor<?xi64>) {
// CHECK: %{{[a-z0-9_]*}} = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
// CHECK-NEXT: %{{[a-z0-9_]*}} = "tf.Reshape"(%arg0
// CHECK-NEXT: %{{.*}} = "tf_device.cluster"() ({
%minus_one = "tf.Const"() {_replication_info = "cluster",
_xla_compile_device_type = "TPU",
value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
"tf.TPUReplicateMetadata"() {_replication_info = "cluster", num_replicas = 1 : i64} : () -> ()
// Unclustered use of the clustered Const.
%1 = "tf.Reshape"(%arg0, %minus_one) : (tensor<*xi32>, tensor<1xi32>) -> tensor<?xi32>
%2 = "tf.Identity"(%1) {_replication_info = "cluster", _xla_compile_device_type = "TPU"} : (tensor<?xi32>) -> tensor<?xi32>
%4 = "tf.Reshape"(%arg1, %minus_one) {_replication_info = "cluster", _xla_compile_device_type = "TPU", device = ""} : (tensor<?xi64>, tensor<1xi32>) -> tensor<?xi64>
%5 = "tf.Identity"(%4) {_replication_info = "cluster", _xla_compile_device_type = "TPU"} : (tensor<?xi64>) -> tensor<?xi64>
func.return %2, %5 : tensor<?xi32>, tensor<?xi64>
}
// -----
// CHECK-LABEL: func @two_clusters
// Test two clusters ("cluster1"/"cluster2") whose Consts are each used by the
// other cluster; the CHECK lines expect both Consts before the first cluster.
func.func @two_clusters(%arg0: tensor<*xi32>, %arg1: tensor<?xi64>) -> (tensor<?xi32>, tensor<?xi64>) {
// CHECK: %{{[a-z0-9_]*}} = "tf.Const"(){{.*}}value = dense<1>
// CHECK-NEXT: %{{[a-z0-9_]*}} = "tf.Const"(){{.*}}value = dense<2>
// CHECK-NEXT: %{{[a-z0-9_]*}} = "tf_device.cluster"
%one = "tf.Const"() {_replication_info = "cluster1",
_xla_compile_device_type = "TPU",
value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
%two = "tf.Const"() {_replication_info = "cluster2",
_xla_compile_device_type = "TPU",
value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
"tf.TPUReplicateMetadata"() {_replication_info = "cluster1", num_replicas = 1 : i64} : () -> ()
"tf.TPUReplicateMetadata"() {_replication_info = "cluster2", num_replicas = 1 : i64} : () -> ()
// cluster2 uses cluster1's Const (%one) and vice versa (%two).
%1 = "tf.Reshape"(%arg0, %one) {_replication_info = "cluster2", _xla_compile_device_type = "TPU", device = ""} : (tensor<*xi32>, tensor<1xi32>) -> tensor<?xi32>
%2 = "tf.Identity"(%1) {_replication_info = "cluster2", _xla_compile_device_type = "TPU"} : (tensor<?xi32>) -> tensor<?xi32>
%3 = "tf.Reshape"(%arg1, %two) {_replication_info = "cluster1", _xla_compile_device_type = "TPU", device = ""} : (tensor<?xi64>, tensor<1xi32>) -> tensor<?xi64>
%4 = "tf.Identity"(%3) {_replication_info = "cluster1", _xla_compile_device_type = "TPU"} : (tensor<?xi64>) -> tensor<?xi64>
func.return %2, %4 : tensor<?xi32>, tensor<?xi64>
}