// RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-tpu-cluster-formation | FileCheck %s --dump-input=fail
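
// The cases below verify TPU cluster formation: ops that share a
// `_tpu_replicate` attribute are grouped under a `tf_device.launch`, selected
// attributes from the matching `tf.TPUReplicateMetadata` op (e.g. `device` and
// `topology`) are attached to the launch, and clusters whose metadata has
// `num_replicas` greater than 1 are additionally wrapped in a
// `tf_device.replicate`.
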
// Test that ops in a cluster have only their `_tpu_replicate` and `device`
// attributes removed when moved to a launch.
// CHECK-LABEL: func @cluster_ops_removed_attrs
func @cluster_ops_removed_attrs() {
%0 = "tf.opA"() {_tpu_replicate = "replicate", device = "device", name = "name"} : () -> tensor<i1>
"tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate", device = "device", num_replicas = 1, topology = "topology"} : () -> ()
return
}
// CHECK: "tf.opA"
// CHECK-SAME: name = "name"
// CHECK-NOT: _tpu_replicate = "replicate"
// CHECK-NOT: device = "device"
// CHECK: tf_device.return

// Test that the TPUReplicateMetadata op's `name` and `num_replicas` attributes
// are not copied over to the launch.
// CHECK-LABEL: func @launch_removed_metadata_attrs
func @launch_removed_metadata_attrs() {
%0 = "tf.opA"() {_tpu_replicate = "replicate"} : () -> tensor<i1>
"tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate", device = "device", name = "name", num_replicas = 1, topology = "topology"} : () -> ()
return
}
// CHECK-NOT: name = "name"
// CHECK-NOT: num_replicas = 1

// Test TPUReplicateMetadata op is removed when forming clusters.
// CHECK-LABEL: func @metadata_op_removed
func @metadata_op_removed() {
%0 = "tf.opA"() {_tpu_replicate = "replicate"} : () -> tensor<i1>
"tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate", device = "device", num_replicas = 1, topology = "topology"} : () -> ()
return
}
// CHECK-NOT: "tf.TPUReplicateMetadata"

// Test ops in an island with the same `_tpu_replicate` attribute are merged
// under a launch.
// CHECK-LABEL: func @simple_island
// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<i1>)
func @simple_island(%arg0 : tensor<i1>) -> tensor<i1> {
%0 = tf_executor.graph {
%1:2 = tf_executor.island {
"tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate", device = "device", num_replicas = 1, topology = "topology"} : () -> ()
%3 = "tf.opA"(%arg0) {_tpu_replicate = "replicate"} : (tensor<i1>) -> tensor<i1>
%4 = "tf.opB"() : () -> tensor<i1>
%5 = "tf.opC"(%3) {_tpu_replicate = "replicate"} : (tensor<i1>) -> tensor<i1>
tf_executor.yield %5 : tensor<i1>
}
tf_executor.fetch %1#0 : tensor<i1>
}
return %0 : tensor<i1>
}
// CHECK: "tf.opB"
// CHECK: %[[LAUNCH:[0-9]*]] = "tf_device.launch"() ( {
// CHECK-NEXT: %[[OP_A:[0-9]*]] = "tf.opA"(%[[ARG_0]])
// CHECK-NEXT: %[[OP_C:[0-9]*]] = "tf.opC"(%[[OP_A]])
// CHECK-NEXT: tf_device.return %[[OP_C]]
// CHECK-NEXT: _tpu_replicate = "replicate"
// CHECK-SAME: device = "device"
// CHECK-SAME: topology = "topology"
// CHECK: tf_executor.yield %[[LAUNCH]]

// Test ops in an island with the same `_tpu_replicate` attribute are merged
// under a launch, even when the associated TPUReplicateMetadata op is in a
// different island.
// CHECK-LABEL: func @simple_island_separate_metadata
// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<i1>)
func @simple_island_separate_metadata(%arg0 : tensor<i1>) -> tensor<i1> {
%0 = tf_executor.graph {
%1 = tf_executor.island {
"tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate", device = "device", num_replicas = 1, topology = "topology"} : () -> ()
}
%2:2 = tf_executor.island {
%3 = "tf.opA"(%arg0) {_tpu_replicate = "replicate"} : (tensor<i1>) -> tensor<i1>
%4 = "tf.opB"() : () -> tensor<i1>
%5 = "tf.opC"(%3) {_tpu_replicate = "replicate"} : (tensor<i1>) -> tensor<i1>
tf_executor.yield %5 : tensor<i1>
}
tf_executor.fetch %2#0 : tensor<i1>
}
return %0 : tensor<i1>
}
// CHECK: "tf.opB"
// CHECK: %[[LAUNCH:[0-9]*]] = "tf_device.launch"() ( {
// CHECK-NEXT: %[[OP_A:[0-9]*]] = "tf.opA"(%[[ARG_0]])
// CHECK-NEXT: %[[OP_C:[0-9]*]] = "tf.opC"(%[[OP_A]])
// CHECK-NEXT: tf_device.return %[[OP_C]]
// CHECK-NEXT: _tpu_replicate = "replicate"
// CHECK-SAME: device = "device"
// CHECK-SAME: topology = "topology"
// CHECK: tf_executor.yield %[[LAUNCH]]

// Test ops in multiple islands with the same `_tpu_replicate` attribute are
// merged under launch ops only within their respective island.
// CHECK-LABEL: func @multiple_islands_separate_metadata
// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<i1>)
func @multiple_islands_separate_metadata(%arg0 : tensor<i1>) -> (tensor<i1>, tensor<i1>) {
%0:2 = tf_executor.graph {
%1 = tf_executor.island {
"tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate", device = "device", num_replicas = 1, topology = "topology"} : () -> ()
}
%2:2 = tf_executor.island {
%3 = "tf.opA"(%arg0) {_tpu_replicate = "replicate"} : (tensor<i1>) -> tensor<i1>
%4 = "tf.opB"() : () -> tensor<i1>
%5 = "tf.opC"(%3) {_tpu_replicate = "replicate"} : (tensor<i1>) -> tensor<i1>
tf_executor.yield %5 : tensor<i1>
}
%6:2 = tf_executor.island {
%7 = "tf.opD"(%2#0) {_tpu_replicate = "replicate"} : (tensor<i1>) -> tensor<i1>
%8 = "tf.opE"() : () -> tensor<i1>
%9 = "tf.opF"(%arg0) {_tpu_replicate = "replicate"} : (tensor<i1>) -> tensor<i1>
tf_executor.yield %9 : tensor<i1>
}
tf_executor.fetch %2#0, %6#0 : tensor<i1>, tensor<i1>
}
return %0#0, %0#1 : tensor<i1>, tensor<i1>
}
// CHECK: %[[ISLAND_1:.*]], %[[ISLAND_1_control:.*]] = tf_executor.island {
// CHECK: "tf.opB"
// CHECK: %[[LAUNCH_0:[0-9]*]] = "tf_device.launch"() ( {
// CHECK-NEXT: %[[OP_A:[0-9]*]] = "tf.opA"(%[[ARG_0]])
// CHECK-NEXT: %[[OP_C:[0-9]*]] = "tf.opC"(%[[OP_A]])
// CHECK-NEXT: tf_device.return %[[OP_C]]
// CHECK-NEXT: _tpu_replicate = "replicate"
// CHECK-SAME: device = "device"
// CHECK-SAME: topology = "topology"
// CHECK: tf_executor.yield %[[LAUNCH_0]]
// CHECK: tf_executor.island {
// CHECK: "tf.opE"
// CHECK: %[[LAUNCH_1:[0-9]*]] = "tf_device.launch"() ( {
// CHECK-NEXT: %[[OP_D:[0-9]*]] = "tf.opD"(%[[ISLAND_1]])
// CHECK-NEXT: %[[OP_F:[0-9]*]] = "tf.opF"(%[[ARG_0]])
// CHECK-NEXT: tf_device.return %[[OP_F]]
// CHECK-NEXT: _tpu_replicate = "replicate"
// CHECK-SAME: device = "device"
// CHECK-SAME: topology = "topology"
// CHECK: tf_executor.yield %[[LAUNCH_1]]

// Test ops in a function body with the same `_tpu_replicate` attribute are
// merged under a launch op.
// CHECK-LABEL: func @ops_in_func_body
// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<i1>)
func @ops_in_func_body(%arg0 : tensor<i1>) -> (tensor<i1>, tensor<i1>, tensor<i1>) {
%0 = "tf.opA"(%arg0) {_tpu_replicate = "replicate"} : (tensor<i1>) -> tensor<i1>
%1 = "tf.opB"() : () -> tensor<i1>
%2 = "tf.opC"(%0) {_tpu_replicate = "replicate"} : (tensor<i1>) -> tensor<i1>
"tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate", device = "device", num_replicas = 1, topology = "topology"} : () -> ()
%3 = "tf.opD"(%2) {_tpu_replicate = "replicate"} : (tensor<i1>) -> tensor<i1>
%4 = "tf.opE"() : () -> tensor<i1>
%5 = "tf.opF"(%arg0) {_tpu_replicate = "replicate"} : (tensor<i1>) -> tensor<i1>
return %2, %3, %5 : tensor<i1>, tensor<i1>, tensor<i1>
}
// CHECK: "tf.opB"
// CHECK: "tf.opE"
// CHECK: %[[LAUNCH:[0-9]*]]:3 = "tf_device.launch"() ( {
// CHECK-NEXT: %[[OP_A:[0-9]*]] = "tf.opA"(%[[ARG_0]])
// CHECK-NEXT: %[[OP_C:[0-9]*]] = "tf.opC"(%[[OP_A]])
// CHECK-NEXT: %[[OP_D:[0-9]*]] = "tf.opD"(%[[OP_C]])
// CHECK-NEXT: %[[OP_F:[0-9]*]] = "tf.opF"(%[[ARG_0]])
// CHECK-NEXT: tf_device.return %[[OP_C]], %[[OP_D]], %[[OP_F]]
// CHECK-NEXT: _tpu_replicate = "replicate"
// CHECK-SAME: device = "device"
// CHECK-SAME: topology = "topology"
// CHECK: return %[[LAUNCH]]#0, %[[LAUNCH]]#1, %[[LAUNCH]]#2

// Test that a nested user of an op in a cluster has its operand updated to the
// launch result.
// CHECK-LABEL: func @nested_cluster_op_user
// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<i1>)
func @nested_cluster_op_user(%arg0 : tensor<i1>) -> (tensor<i1>) {
%0 = "tf.opA"(%arg0) {_tpu_replicate = "replicate"} : (tensor<i1>) -> tensor<i1>
%1 = tf_executor.graph {
tf_executor.fetch %0 : tensor<i1>
}
%2 = "tf.opB"(%0) {_tpu_replicate = "replicate"} : (tensor<i1>) -> tensor<i1>
"tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate", device = "device", num_replicas = 1, topology = "topology"} : () -> ()
return %2 : tensor<i1>
}
// CHECK: %[[LAUNCH:[0-9]*]]:2 = "tf_device.launch"() ( {
// CHECK-NEXT: %[[OP_A:[0-9]*]] = "tf.opA"(%[[ARG_0]])
// CHECK-NEXT: %[[OP_B:[0-9]*]] = "tf.opB"(%[[OP_A]])
// CHECK-NEXT: tf_device.return %[[OP_A]], %[[OP_B]]
// CHECK-NEXT: _tpu_replicate = "replicate"
// CHECK-SAME: device = "device"
// CHECK-SAME: topology = "topology"
// CHECK: tf_executor.graph {
// CHECK-NEXT: tf_executor.fetch %[[LAUNCH]]#0
// CHECK: return %[[LAUNCH]]#1

// Test that an op nested in a cluster op, with an operand from another op of
// the same cluster, retains its original operand.
// CHECK-LABEL: func @nested_cluster_op
// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<i1>)
func @nested_cluster_op(%arg0 : tensor<i1>) -> (tensor<i1>) {
%0 = "tf.opA"(%arg0) {_tpu_replicate = "replicate"} : (tensor<i1>) -> tensor<i1>
%1 = "tf.opB"() ( {
"tf.opC"(%0) : (tensor<i1>) -> tensor<i1>
}) {_tpu_replicate = "replicate"} : () -> tensor<i1>
"tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate", device = "device", num_replicas = 1, topology = "topology"} : () -> ()
return %1 : tensor<i1>
}
// CHECK: %[[LAUNCH:[0-9]*]] = "tf_device.launch"() ( {
// CHECK-NEXT: %[[OP_A:[0-9]*]] = "tf.opA"(%[[ARG_0]])
// CHECK-NEXT: %[[OP_B:[0-9]*]] = "tf.opB"() ( {
// CHECK-NEXT: "tf.opC"(%[[OP_A]])
// CHECK: tf_device.return %[[OP_B]]
// CHECK-NEXT: _tpu_replicate = "replicate"
// CHECK-SAME: device = "device"
// CHECK-SAME: topology = "topology"
// CHECK: return %[[LAUNCH]]

// Test multiple clusters interleaved.
// CHECK-LABEL: func @interleaved_clusters
// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<i1>)
func @interleaved_clusters(%arg0 : tensor<i1>) -> (tensor<i1>, tensor<i1>) {
"tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate_1", device = "device_1", num_replicas = 1, topology = "topology_1"} : () -> ()
%0 = "tf.opA"(%arg0) {_tpu_replicate = "replicate_0"} : (tensor<i1>) -> tensor<i1>
%1 = "tf.opB"(%arg0) {_tpu_replicate = "replicate_1"} : (tensor<i1>) -> tensor<i1>
%2 = "tf.opC"(%0) {_tpu_replicate = "replicate_0"} : (tensor<i1>) -> tensor<i1>
%3 = "tf.opD"(%1) {_tpu_replicate = "replicate_1"} : (tensor<i1>) -> tensor<i1>
"tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate_0", device = "device_0", num_replicas = 1, topology = "topology_0"} : () -> ()
return %2, %3 : tensor<i1>, tensor<i1>
}
// CHECK: %[[LAUNCH_0:[0-9]*]] = "tf_device.launch"() ( {
// CHECK-NEXT: %[[OP_A:[0-9]*]] = "tf.opA"(%[[ARG_0]])
// CHECK-NEXT: %[[OP_C:[0-9]*]] = "tf.opC"(%[[OP_A]])
// CHECK-NEXT: tf_device.return %[[OP_C]]
// CHECK-NEXT: _tpu_replicate = "replicate_0"
// CHECK-SAME: device = "device_0"
// CHECK-SAME: topology = "topology_0"
// CHECK: %[[LAUNCH_1:[0-9]*]] = "tf_device.launch"() ( {
// CHECK-NEXT: %[[OP_B:[0-9]*]] = "tf.opB"(%[[ARG_0]])
// CHECK-NEXT: %[[OP_D:[0-9]*]] = "tf.opD"(%[[OP_B]])
// CHECK-NEXT: tf_device.return %[[OP_D]]
// CHECK-NEXT: _tpu_replicate = "replicate_1"
// CHECK-SAME: device = "device_1"
// CHECK-SAME: topology = "topology_1"
// CHECK: return %[[LAUNCH_0]], %[[LAUNCH_1]]

// Test that ops defining operands of a cluster and ops using results of the
// cluster, interleaved between ops of the same cluster, are properly placed
// before and after the formed cluster.
// CHECK-LABEL: func @interleaved_cluster_operands_results
func @interleaved_cluster_operands_results() {
%0 = "tf.opA"() {_tpu_replicate = "replicate"} : () -> tensor<i1>
%1 = "tf.opB"(%0) : (tensor<i1>) -> tensor<i1>
%2 = "tf.opC"() : () -> tensor<i1>
"tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate", device = "device", num_replicas = 1, topology = "topology"} : () -> ()
%3 = "tf.opD"(%1) : (tensor<i1>) -> tensor<i1>
%4 = "tf.opE"(%2) : (tensor<i1>) -> tensor<i1>
%5 = "tf.opF"(%4) {_tpu_replicate = "replicate"} : (tensor<i1>) -> tensor<i1>
return
}
// CHECK: %[[OP_C:[0-9]*]] = "tf.opC"
// CHECK: %[[OP_E:[0-9]*]] = "tf.opE"(%[[OP_C]])
// CHECK: %[[LAUNCH:[0-9]*]] = "tf_device.launch"() ( {
// CHECK-NEXT: %[[OP_A:[0-9]*]] = "tf.opA"
// CHECK-NEXT: "tf.opF"(%[[OP_E]])
// CHECK-NEXT: tf_device.return %[[OP_A]]
// CHECK-NEXT: _tpu_replicate = "replicate"
// CHECK-SAME: device = "device"
// CHECK-SAME: topology = "topology"
// CHECK: %[[OP_B:[0-9]*]] = "tf.opB"(%[[LAUNCH]])
// CHECK: "tf.opD"(%[[OP_B]])

// Test that a cluster with one replica results in removal of the
// TPUReplicatedInput and TPUReplicatedOutput ops, with their operands
// forwarded to their results.
// CHECK-LABEL: func @one_replica
// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<i1>)
func @one_replica(%arg0: tensor<i1>) -> tensor<i1> {
%ri = "tf.TPUReplicatedInput"(%arg0) : (tensor<i1>) -> tensor<i1>
%0 = "tf.opA"(%ri) {_tpu_replicate = "replicate"} : (tensor<i1>) -> tensor<i1>
%1 = "tf.opB"(%0) : (tensor<i1>) -> tensor<i1>
%2 = "tf.opC"() : () -> tensor<i1>
"tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate", device = "device", num_replicas = 1, topology = "topology"} : () -> ()
%3 = "tf.opD"(%1) : (tensor<i1>) -> tensor<i1>
%4 = "tf.opE"(%2) : (tensor<i1>) -> tensor<i1>
%5 = "tf.opF"(%4) {_tpu_replicate = "replicate"} : (tensor<i1>) -> tensor<i1>
%ro = "tf.TPUReplicatedOutput"(%5) : (tensor<i1>) -> tensor<i1>
return %ro : tensor<i1>
}
// CHECK: %[[OP_C:[0-9]*]] = "tf.opC"
// CHECK: %[[OP_E:[0-9]*]] = "tf.opE"(%[[OP_C]])
// CHECK: %[[LAUNCH:[0-9]*]]:2 = "tf_device.launch"() ( {
// CHECK-NEXT: %[[OP_A:[0-9]*]] = "tf.opA"(%[[ARG_0]])
// CHECK-NEXT: %[[OP_F:[0-9]*]] = "tf.opF"(%[[OP_E]])
// CHECK-NEXT: tf_device.return %[[OP_A]], %[[OP_F]]
// CHECK-NEXT: _tpu_replicate = "replicate"
// CHECK-SAME: device = "device"
// CHECK-SAME: topology = "topology"
// CHECK: %[[OP_B:[0-9]*]] = "tf.opB"(%[[LAUNCH]]#0)
// CHECK: "tf.opD"(%[[OP_B]])
// CHECK: return %[[LAUNCH]]#1
// CHECK-NOT: "tf.TPUReplicatedInput"
// CHECK-NOT: "tf.TPUReplicatedOutput"

// Test replication with replicated operands and replicated results. The
// cluster is wrapped in a launch first and then in a replicate.
// TPUReplicatedInput and TPUReplicatedOutput ops are replaced by the
// replicate's operands and results.
// CHECK-LABEL: func @replication
// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<i1>, %[[ARG_1:[a-z0-9]*]]: tensor<i32>, %[[ARG_2:[a-z0-9]*]]: tensor<f32>)
func @replication(%arg0: tensor<i1>, %arg1: tensor<i32>, %arg2: tensor<f32>) -> (tensor<i32>, tensor<f32>) {
%0 = "tf.opA"() : () -> tensor<i1>
%ri_0 = "tf.TPUReplicatedInput"(%arg0, %0) : (tensor<i1>, tensor<i1>) -> tensor<i1>
%1 = "tf.opB"() : () -> tensor<i32>
%ri_1 = "tf.TPUReplicatedInput"(%1, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
%2 = "tf.opC"() : () -> tensor<f32>
%3 = "tf.opD"(%ri_0, %ri_1, %arg2, %2) {_tpu_replicate = "replicate"} : (tensor<i1>, tensor<i32>, tensor<f32>, tensor<f32>) -> tensor<i32>
%ro_0:2 = "tf.TPUReplicatedOutput"(%3) : (tensor<i32>) -> (tensor<i32>, tensor<i32>)
"tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate", device = "device", num_replicas = 2, topology = "topology"} : () -> ()
%7 = "tf.opE"(%3, %ri_0, %ri_1, %arg2, %2) {_tpu_replicate = "replicate"} : (tensor<i32>, tensor<i1>, tensor<i32>, tensor<f32>, tensor<f32>) -> tensor<f32>
%ro_1:2 = "tf.TPUReplicatedOutput"(%7) : (tensor<f32>) -> (tensor<f32>, tensor<f32>)
return %ro_0#0, %ro_1#1 : tensor<i32>, tensor<f32>
}
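
// In the replicate checks below, `[%a, %b] as %x` binds one operand per
// replica (n = 2) to the region block argument `%x`.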
// CHECK: %[[OP_A:[0-9]*]] = "tf.opA"
// CHECK: %[[OP_B:[0-9]*]] = "tf.opB"
// CHECK: %[[OP_C:[0-9]*]] = "tf.opC"
// CHECK: %[[REPLICATE:[0-9]*]]:4 = tf_device.replicate
// CHECK-DAG: [%[[ARG_0]], %[[OP_A]]] as %[[RI_0:[a-z0-9]*]]: tensor<i1>
// CHECK-DAG: [%[[OP_B]], %[[ARG_1]]] as %[[RI_1:[a-z0-9]*]]: tensor<i32>
// CHECK-SAME: n = 2 : i32
// CHECK-NEXT: %[[LAUNCH:[0-9]*]]:2 = "tf_device.launch"() ( {
// CHECK: %[[OP_D:[0-9]*]] = "tf.opD"(%[[RI_0]], %[[RI_1]], %[[ARG_2]], %[[OP_C]])
// CHECK: %[[OP_E:[0-9]*]] = "tf.opE"(%[[OP_D]], %[[RI_0]], %[[RI_1]], %[[ARG_2]], %[[OP_C]])
// CHECK: tf_device.return %[[OP_D]], %[[OP_E]]
// CHECK-NEXT: _tpu_replicate = "replicate"
// CHECK-SAME: device = "device"
// CHECK-SAME: topology = "topology"
// CHECK: tf_device.return %[[LAUNCH]]#0, %[[LAUNCH]]#1
// CHECK: return %[[REPLICATE]]#0, %[[REPLICATE]]#3

// Test `tf.TPUReplicatedInput` ops are sorted by their `index` attribute.
// Non-negative `index` should precede `index` of -1, and the ordering of ops
// with `index` of -1 does not matter.
// CHECK-LABEL: func @sort_replicated_input
// CHECK-SAME: (%[[ARG_0:.*]]: tensor<i1>, %[[ARG_1:.*]]: tensor<i1>, %[[ARG_2:.*]]: tensor<i1>, %[[ARG_3:.*]]: tensor<i1>, %[[ARG_4:.*]]: tensor<i1>, %[[ARG_5:.*]]: tensor<i1>)
func @sort_replicated_input(%arg0: tensor<i1>, %arg1: tensor<i1>, %arg2: tensor<i1>, %arg3: tensor<i1>, %arg4: tensor<i1>, %arg5: tensor<i1>) {
%0 = "tf.TPUReplicatedInput"(%arg0, %arg0) {index = -1 : i64} : (tensor<i1>, tensor<i1>) -> tensor<i1>
%1 = "tf.TPUReplicatedInput"(%arg1, %arg1) {index = 2 : i64} : (tensor<i1>, tensor<i1>) -> tensor<i1>
%2 = "tf.TPUReplicatedInput"(%arg2, %arg2) {index = 0 : i64} : (tensor<i1>, tensor<i1>) -> tensor<i1>
%3 = "tf.TPUReplicatedInput"(%arg3, %arg3) {index = -1 : i64} : (tensor<i1>, tensor<i1>) -> tensor<i1>
%4 = "tf.TPUReplicatedInput"(%arg4, %arg4) {index = 1 : i64} : (tensor<i1>, tensor<i1>) -> tensor<i1>
%5 = "tf.TPUReplicatedInput"(%arg5, %arg5) {index = -1 : i64} : (tensor<i1>, tensor<i1>) -> tensor<i1>
"tf.opA"(%0, %1, %2, %3, %4, %5) {_tpu_replicate = "replicate", device = "device"} : (tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>) -> ()
"tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate", device = "device", num_replicas = 2, topology = "topology"} : () -> ()
return
}
// CHECK: tf_device.replicate
// CHECK-SAME: [%[[ARG_2]], %[[ARG_2]]] as %{{[a-z0-9]*}}
// CHECK-SAME: [%[[ARG_4]], %[[ARG_4]]] as %{{[a-z0-9]*}}
// CHECK-SAME: [%[[ARG_1]], %[[ARG_1]]] as %{{[a-z0-9]*}}
// CHECK-DAG: [%[[ARG_0]], %[[ARG_0]]] as %{{[a-z0-9]*}}
// CHECK-DAG: [%[[ARG_3]], %[[ARG_3]]] as %{{[a-z0-9]*}}
// CHECK-DAG: [%[[ARG_5]], %[[ARG_5]]] as %{{[a-z0-9]*}}

// -----
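
// The remaining cases exercise the pass's error diagnostics and are verified
// with -verify-diagnostics; each split marker below starts a separate input.
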
// Test cluster with missing `num_replicas` attribute.
func @missing_num_replicas() {
%0 = "tf.opA"() {_tpu_replicate = "replicate", device = "device", name = "name"} : () -> tensor<i1>
// expected-error@+1 {{'tf.TPUReplicateMetadata' op requires attribute 'num_replicas'}}
"tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate", device = "device", topology = "topology"} : () -> ()
return
}

// -----

// Test cluster with bad `num_replicas` attribute.
func @bad_num_replicas() {
// expected-error@+1 {{requires 'num_replicas' int attribute to be at least 1}}
%0 = "tf.opA"() {_tpu_replicate = "replicate", device = "device", name = "name"} : () -> tensor<i1>
"tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate", device = "device", num_replicas = 0, topology = "topology"} : () -> ()
return
}

// -----

// Test cluster with TPUReplicatedInput where the number of operands does not
// match associated `num_replicas` attribute.
func @mismatched_replicated_input(%arg0: tensor<i1>) {
// expected-error@+1 {{'tf.TPUReplicatedInput' op requires 2 operands}}
%0 = "tf.TPUReplicatedInput"(%arg0, %arg0, %arg0) : (tensor<i1>, tensor<i1>, tensor<i1>) -> tensor<i1>
%1 = "tf.opA"(%0) {_tpu_replicate = "replicate", device = "device", name = "name"} : (tensor<i1>) -> tensor<i1>
"tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate", device = "device", num_replicas = 2, topology = "topology"} : () -> ()
return
}

// -----

// Test cluster with TPUReplicatedOutput where the number of results does not
// match associated `num_replicas` attribute.
func @mismatched_replicated_output() {
%0 = "tf.opA"() {_tpu_replicate = "replicate", device = "device", name = "name"} : () -> tensor<i1>
// expected-error@+1 {{'tf.TPUReplicatedOutput' op requires 2 results}}
%1:3 = "tf.TPUReplicatedOutput"(%0) : (tensor<i1>) -> (tensor<i1>, tensor<i1>, tensor<i1>)
"tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate", device = "device", num_replicas = 2, topology = "topology"} : () -> ()
return
}

// -----

// Test cluster that should be replicated where its outputs do not lead to a
// TPUReplicatedOutput.
func @missing_replicated_output() {
// expected-error@+1 {{requires output of tf_device.launch to lead to a 'tf.TPUReplicatedOutput' op}}
%0 = "tf.opA"() {_tpu_replicate = "replicate", device = "device", name = "name"} : () -> tensor<i1>
%1 = "tf.opB"(%0) : (tensor<i1>) -> tensor<i1>
"tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate", device = "device", num_replicas = 2, topology = "topology"} : () -> ()
return
}

// -----

// Test unused TPUReplicatedInput that has more than one operand.
func @leftover_replicated_input(%arg0: tensor<i1>) {
%0 = "tf.TPUReplicatedInput"(%arg0, %arg0) : (tensor<i1>, tensor<i1>) -> tensor<i1>
return
}

// -----

// Test unused TPUReplicatedOutput that has more than one result.
func @leftover_replicated_output(%arg0: tensor<i1>) {
%0:2 = "tf.TPUReplicatedOutput"(%arg0) : (tensor<i1>) -> (tensor<i1>, tensor<i1>)
return
}

// -----

// Test bad TPUReplicatedInput positive `index` attribute.
func @bad_positive_index_input(%arg0: tensor<i1>) {
// expected-error@+1 {{'tf.TPUReplicatedInput' index is not in range [-1, 1), got 1}}
%0 = "tf.TPUReplicatedInput"(%arg0, %arg0) {index = 1 : i64} : (tensor<i1>, tensor<i1>) -> tensor<i1>
"tf.opA"(%0) {_tpu_replicate = "replicate", device = "device", name = "name"} : (tensor<i1>) -> ()
"tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate", device = "device", num_replicas = 2, topology = "topology"} : () -> ()
return
}

// -----

// Test bad TPUReplicatedInput negative `index` attribute.
func @bad_negative_index_input(%arg0: tensor<i1>) {
// expected-error@+1 {{'tf.TPUReplicatedInput' index is not in range [-1, 1), got -2}}
%0 = "tf.TPUReplicatedInput"(%arg0, %arg0) {index = -2 : i64} : (tensor<i1>, tensor<i1>) -> tensor<i1>
"tf.opA"(%0) {_tpu_replicate = "replicate", device = "device", name = "name"} : (tensor<i1>) -> ()
"tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate", device = "device", num_replicas = 2, topology = "topology"} : () -> ()
return
}

// -----

// Test TPUReplicatedInput with conflicting `index` attribute. This will result
// in gaps in the TPUReplicatedInput ordering.
func @input_index_gaps(%arg0: tensor<i1>) {
// expected-error@+1 {{failed to sort 'tf.TPUReplicatedInput' ops, gap(s) found in indices}}
%0 = "tf.TPUReplicatedInput"(%arg0, %arg0) {index = 1 : i64} : (tensor<i1>, tensor<i1>) -> tensor<i1>
%1 = "tf.TPUReplicatedInput"(%arg0, %arg0) {index = 1 : i64} : (tensor<i1>, tensor<i1>) -> tensor<i1>
"tf.opA"(%0, %1) {_tpu_replicate = "replicate", device = "device", name = "name"} : (tensor<i1>, tensor<i1>) -> ()
"tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate", device = "device", num_replicas = 2, topology = "topology"} : () -> ()
return
}