// RUN: tf-opt %s -split-input-file -verify-diagnostics --tf-tpu-rewrite=tpu-compile-metadata-debug | FILECHECK_OPTS="" FileCheck %s
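// Notes on the RUN line: -split-input-file runs each `// -----`-delimited
// module below as an independent test case, and -verify-diagnostics checks the
// `expected-error` annotations against the diagnostics the pass emits. The
// tpu-compile-metadata-debug option makes the pass serialize the
// TPUCompileMetadataProto on `tf._TPUCompileMlir` as a proto debug string,
// which is what the `metadata` CHECK lines below match against.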
// Tests module with missing `tf.versions` attribute.
// expected-error@+1 {{requires attribute 'tf.versions'}}
module attributes {tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
func.func @missing_tf_versions() {
"tf_device.cluster_func"() {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "", topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = [], use_spmd_for_xla_partitioning = false} : () -> ()
func.return
}
func.func @empty_func() {
func.return
}
}
// -----
// Tests that collecting compilation and execution devices results in an error.
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
func.func @bad_devices() {
// expected-error@+1 {{error in fetching TPU compilation/execution devices: no TPU_SYSTEM devices found}}
"tf_device.cluster_func"() {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "", topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = [], use_spmd_for_xla_partitioning = false} : () -> ()
func.return
}
func.func @empty_func() {
func.return
}
}
// -----
// Tests `tf_device.cluster_func` with missing `num_cores_per_replica`
// attribute.
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
func.func @missing_num_cores_per_replica() {
// expected-error@+1 {{requires attribute 'num_cores_per_replica'}}
"tf_device.cluster_func"() {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @empty_func, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = [], use_spmd_for_xla_partitioning = false} : () -> ()
func.return
}
func.func @empty_func() {
func.return
}
}
// -----
// Tests `tf_device.cluster_func` with bad `num_cores_per_replica` attribute.
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
func.func @bad_num_cores_per_replica() {
// expected-error@+1 {{requires attribute 'num_cores_per_replica'}}
"tf_device.cluster_func"() {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @empty_func, num_cores_per_replica = "", step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = [], use_spmd_for_xla_partitioning = false} : () -> ()
func.return
}
func.func @empty_func() {
func.return
}
}
// -----
// Tests `tf_device.cluster_func` with missing `step_marker_location` attribute.
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
func.func @missing_step_marker_location() {
// expected-error@+1 {{requires attribute 'step_marker_location'}}
"tf_device.cluster_func"() {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @empty_func, num_cores_per_replica = 1, topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = [], use_spmd_for_xla_partitioning = false} : () -> ()
func.return
}
func.func @empty_func() {
func.return
}
}
// -----
// Tests `tf_device.cluster_func` with bad `step_marker_location` attribute.
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
func.func @bad_step_marker_location() {
// expected-error@+1 {{requires attribute 'step_marker_location'}}
"tf_device.cluster_func"() {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = 1, topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = [], use_spmd_for_xla_partitioning = false} : () -> ()
func.return
}
func.func @empty_func() {
func.return
}
}
// -----
// Tests `tf_device.cluster_func` with unparsable `step_marker_location` attribute.
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
func.func @unparsable_step_marker_location() {
// expected-error@+1 {{bad 'step_marker_location' attribute with value 'test'}}
"tf_device.cluster_func"() {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "test", topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = [], use_spmd_for_xla_partitioning = false} : () -> ()
func.return
}
func.func @empty_func() {
func.return
}
}
// -----
// Tests `tf_device.cluster_func` with missing `topology` attribute.
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
func.func @missing_topology() {
// expected-error@+1 {{requires attribute 'topology'}}
"tf_device.cluster_func"() {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = [], use_spmd_for_xla_partitioning = false} : () -> ()
func.return
}
func.func @empty_func() {
func.return
}
}
// -----
// Tests `tf_device.cluster_func` with bad `topology` attribute.
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
func.func @bad_topology() {
// expected-error@+1 {{requires attribute 'topology'}}
"tf_device.cluster_func"() {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = 1 : i32, device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = [], use_spmd_for_xla_partitioning = false} : () -> ()
func.return
}
func.func @empty_func() {
func.return
}
}
// -----
// Tests `tf_device.cluster_func` with `topology` attribute resulting in device assignment error.
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
func.func @invalid_topology() {
// expected-error@+1 {{error in fetching TPU compilation/execution devices}}
"tf_device.cluster_func"() {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "test", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = [], use_spmd_for_xla_partitioning = false} : () -> ()
func.return
}
func.func @empty_func() {
func.return
}
}
// -----
// Tests `tf_device.cluster_func` with missing `device_assignment` attribute.
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
func.func @missing_device_assignment() {
// expected-error@+1 {{requires attribute 'device_assignment'}}
"tf_device.cluster_func"() {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "", input_sharding_configuration = [], output_sharding_configuration = [], use_spmd_for_xla_partitioning = false} : () -> ()
func.return
}
func.func @empty_func() {
func.return
}
}
// -----
// Tests `tf_device.cluster_func` with bad `device_assignment` attribute.
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
func.func @bad_device_assignment() {
// expected-error@+1 {{requires attribute 'device_assignment'}}
"tf_device.cluster_func"() {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "", device_assignment = "", input_sharding_configuration = [], output_sharding_configuration = [], use_spmd_for_xla_partitioning = false} : () -> ()
func.return
}
func.func @empty_func() {
func.return
}
}
// -----
// Tests `tf_device.cluster_func` with bad element in `device_assignment` attribute.
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
func.func @bad_element_device_assignment() {
// expected-error@+1 {{bad 'device_assignment' attribute at index 0, not an int}}
"tf_device.cluster_func"() {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "", device_assignment = [""], input_sharding_configuration = [], output_sharding_configuration = [], use_spmd_for_xla_partitioning = false} : () -> ()
func.return
}
func.func @empty_func() {
func.return
}
}
// -----
// The following topology is used in subsequent test cases:
// Proto debug string:
// mesh_shape: 1
// mesh_shape: 1
// mesh_shape: 1
// mesh_shape: 2
// num_tasks: 1
// num_tpu_devices_per_task: 2
// device_coordinates: 0
// device_coordinates: 0
// device_coordinates: 0
// device_coordinates: 0
// device_coordinates: 0
// device_coordinates: 0
// device_coordinates: 0
// device_coordinates: 1
// Serialized string:
// "\0A\04\01\01\01\02\10\01\18\02\22\06\00\00\00\00\00\00\00\01"
// -----
// Tests `tf_device.cluster_func` with `device_assignment` attribute resulting in device assignment error.
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
func.func @invalid_device_assignment() {
// expected-error@+1 {{error in fetching TPU compilation/execution devices}}
"tf_device.cluster_func"() {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "\0A\03\01\01\02\10\01\18\02\22\06\00\00\00\00\00\01", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = [], use_spmd_for_xla_partitioning = false} : () -> ()
func.return
}
func.func @empty_func() {
func.return
}
}
// -----
// Tests `tf_device.cluster_func` with missing `input_sharding_configuration` attribute.
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
func.func @missing_input_sharding_configuration(%arg0: tensor<?xi32>) {
// expected-error@+1 {{requires attribute 'input_sharding_configuration'}}
%0 = "tf_device.cluster_func"(%arg0) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_ENTRY", topology = "", device_assignment = [], output_sharding_configuration = [], use_spmd_for_xla_partitioning = false} : (tensor<?xi32>) -> tensor<?xi32>
func.return
}
func.func @empty_func(%arg0: tensor<?xi32>) -> tensor<?xi32> {
func.return %arg0 : tensor<?xi32>
}
}
// -----
// The following op sharding is used in subsequent test cases:
// Proto debug string:
// type: MAXIMAL
// tile_assignment_dimensions: 1
// tile_assignment_devices: 0
// Serialized string:
// "\08\01\1A\01\01\22\01\00"
// -----
// Tests `tf_device.cluster_func` with bad `input_sharding_configuration` attribute.
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
func.func @bad_input_sharding_configuration(%arg0: tensor<?xi32>) {
// expected-error@+1 {{requires attribute 'input_sharding_configuration'}}
%0 = "tf_device.cluster_func"(%arg0) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "", device_assignment = [], input_sharding_configuration = "", output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor<?xi32>) -> tensor<?xi32>
func.return
}
func.func @empty_func(%arg0: tensor<?xi32>) -> tensor<?xi32> {
func.return %arg0 : tensor<?xi32>
}
}
// -----
// Tests `tf_device.cluster_func` with mismatched `input_sharding_configuration` attribute size.
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
func.func @mismatched_size_input_sharding_configuration(%arg0: tensor<?xi32>) {
// expected-error@+1 {{bad 'input_sharding_configuration' attribute, expected array attribute of size 1, got size 0}}
%0 = "tf_device.cluster_func"(%arg0) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor<?xi32>) -> tensor<?xi32>
func.return
}
func.func @empty_func(%arg0: tensor<?xi32>) -> tensor<?xi32> {
func.return %arg0 : tensor<?xi32>
}
}
// -----
// Tests `tf_device.cluster_func` with unsupported operand type.
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
func.func @unsupported_operand_type(%arg0: tensor<?xi2>) {
// expected-error@+1 {{failed to determine operand type at index 0: Converting i2 to DataType}}
%0 = "tf_device.cluster_func"(%arg0) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_ENTRY", topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor<?xi2>) -> tensor<?xi2>
func.return
}
func.func @empty_func(%arg0: tensor<?xi2>) -> tensor<?xi2> {
func.return %arg0 : tensor<?xi2>
}
}
// -----
// Tests `tf_device.cluster_func` with bad element in `input_sharding_configuration` attribute.
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
func.func @bad_element_input_sharding_configuration(%arg0: tensor<?xi32>) {
// expected-error@+1 {{bad 'input_sharding_configuration' attribute at index 0, not a string}}
%0 = "tf_device.cluster_func"(%arg0) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "", device_assignment = [], input_sharding_configuration = [1], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor<?xi32>) -> tensor<?xi32>
func.return
}
func.func @empty_func(%arg0: tensor<?xi32>) -> tensor<?xi32> {
func.return %arg0 : tensor<?xi32>
}
}
// -----
// Tests `tf_device.cluster_func` with unparsable element in `input_sharding_configuration` attribute.
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
func.func @unparsable_element_input_sharding_configuration(%arg0: tensor<?xi32>) {
// expected-error@+1 {{bad 'input_sharding_configuration' attribute at index 0 with value 'test': failed to parse to xla::OpSharding}}
%0 = "tf_device.cluster_func"(%arg0) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "", device_assignment = [], input_sharding_configuration = ["test"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor<?xi32>) -> tensor<?xi32>
func.return
}
func.func @empty_func(%arg0: tensor<?xi32>) -> tensor<?xi32> {
func.return %arg0 : tensor<?xi32>
}
}
// -----
// Tests `tf_device.cluster_func` with missing `output_sharding_configuration` attribute.
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
func.func @missing_output_sharding_configuration(%arg0: tensor<?xi32>) {
// expected-error@+1 {{requires attribute 'output_sharding_configuration'}}
%0 = "tf_device.cluster_func"(%arg0) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_ENTRY", topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor<?xi32>) -> tensor<?xi32>
func.return
}
func.func @empty_func(%arg0: tensor<?xi32>) -> tensor<?xi32> {
func.return %arg0 : tensor<?xi32>
}
}
// -----
// Tests `tf_device.cluster_func` with bad `output_sharding_configuration` attribute.
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
func.func @bad_output_sharding_configuration(%arg0: tensor<?xi32>) {
// expected-error@+1 {{requires attribute 'output_sharding_configuration'}}
%0 = "tf_device.cluster_func"(%arg0) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = "", use_spmd_for_xla_partitioning = false} : (tensor<?xi32>) -> tensor<?xi32>
func.return
}
func.func @empty_func(%arg0: tensor<?xi32>) -> tensor<?xi32> {
func.return %arg0 : tensor<?xi32>
}
}
// -----
// Tests `tf_device.cluster_func` with mismatched `output_sharding_configuration` attribute size.
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
func.func @mismatched_size_output_sharding_configuration(%arg0: tensor<?xi32>) {
// expected-error@+1 {{bad 'output_sharding_configuration' attribute, expected array attribute of size 1, got size 0}}
%0 = "tf_device.cluster_func"(%arg0) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = [], use_spmd_for_xla_partitioning = false} : (tensor<?xi32>) -> tensor<?xi32>
func.return
}
func.func @empty_func(%arg0: tensor<?xi32>) -> tensor<?xi32> {
func.return %arg0 : tensor<?xi32>
}
}
// -----
// Tests `tf_device.cluster_func` with bad element in `output_sharding_configuration` attribute.
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
func.func @bad_element_output_sharding_configuration(%arg0: tensor<?xi32>) {
// expected-error@+1 {{bad 'output_sharding_configuration' attribute at index 0, not a string}}
%0 = "tf_device.cluster_func"(%arg0) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = [1], use_spmd_for_xla_partitioning = false} : (tensor<?xi32>) -> tensor<?xi32>
func.return
}
func.func @empty_func(%arg0: tensor<?xi32>) -> tensor<?xi32> {
func.return %arg0 : tensor<?xi32>
}
}
// -----
// Tests `tf_device.cluster_func` with unparsable element in `output_sharding_configuration` attribute.
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
func.func @unparsable_element_output_sharding_configuration(%arg0: tensor<?xi32>) {
// expected-error@+1 {{bad 'output_sharding_configuration' attribute at index 0 with value 'test': failed to parse to xla::OpSharding}}
%0 = "tf_device.cluster_func"(%arg0) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["test"], use_spmd_for_xla_partitioning = false} : (tensor<?xi32>) -> tensor<?xi32>
func.return
}
func.func @empty_func(%arg0: tensor<?xi32>) -> tensor<?xi32> {
func.return %arg0 : tensor<?xi32>
}
}
// -----
// Tests `tf_device.cluster_func` with empty `step_marker_location` attribute
// defaults to `STEP_MARK_AT_ENTRY`.
//
// The expected TPUCompileMetadataProto is:
// num_replicas: 1
// num_cores_per_replica: 1
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
// CHECK-LABEL: func @default_step_marker_location
func.func @default_step_marker_location() {
"tf_device.cluster_func"() {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "", topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = [], use_spmd_for_xla_partitioning = false} : () -> ()
// CHECK: metadata
// CHECK-SAME: num_replicas: 1
// CHECK-SAME: num_cores_per_replica: 1
func.return
}
func.func @empty_func() {
func.return
}
}
// -----
// Tests argument with unranked shape. An empty shape should be populated in
// the metadata for the associated argument.
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
// CHECK-LABEL: func @unranked_shape_arg
func.func @unranked_shape_arg(%arg0: tensor<*xi32>) -> tensor<*xi32> {
%0 = "tf_device.cluster_func"(%arg0) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @_func, num_cores_per_replica = 1, step_marker_location = "", topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor<*xi32>) -> tensor<*xi32>
// CHECK: metadata
// CHECK-SAME: shape {\0A unknown_rank: true
func.return %0: tensor<*xi32>
}
func.func @_func(%arg0: tensor<*xi32>) -> tensor<*xi32> {
func.return %arg0 : tensor<*xi32>
}
}
// -----
// Tests argument with partial shape.
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
// CHECK-LABEL: func @partial_shape_arg
func.func @partial_shape_arg(%arg0: tensor<?x?x3xi32>) -> tensor<?x?x3xi32> {
%0 = "tf_device.cluster_func"(%arg0) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @_func, num_cores_per_replica = 1, step_marker_location = "", topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor<?x?x3xi32>) -> tensor<?x?x3xi32>
// CHECK: metadata
// CHECK-SAME: args
// CHECK-SAME: shape {\0A dim {\0A size: -1\0A }\0A dim {\0A size: -1\0A }\0A dim {\0A size: 3\0A }\0A }
func.return %0: tensor<?x?x3xi32>
}
func.func @_func(%arg0: tensor<?x?x3xi32>) -> tensor<?x?x3xi32> {
func.return %arg0 : tensor<?x?x3xi32>
}
}
// -----
// Tests argument with static shape.
// The expected TensorShapeProto is:
// shape {
// dim {
// size: 1
// }
// dim {
// size: 2
// }
// dim {
// size: 3
// }
// }
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
// CHECK-LABEL: func @static_shape_arg
func.func @static_shape_arg(%arg0: tensor<1x2x3xi32>) -> tensor<1x2x3xi32> {
%0 = "tf_device.cluster_func"(%arg0) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @_func, num_cores_per_replica = 1, step_marker_location = "", topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor<1x2x3xi32>) -> tensor<1x2x3xi32>
// CHECK: metadata
// CHECK-SAME: args
// CHECK-SAME: shape
// CHECK-SAME: dim
// CHECK-SAME: size: 1
// CHECK-SAME: dim
// CHECK-SAME: size: 2
// CHECK-SAME: dim
// CHECK-SAME: size: 3
func.return %0: tensor<1x2x3xi32>
}
func.func @_func(%arg0: tensor<1x2x3xi32>) -> tensor<1x2x3xi32> {
func.return %arg0 : tensor<1x2x3xi32>
}
}
// -----
// Tests argument that is a resource variable.
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
// CHECK-LABEL: func @resource_arg
func.func @resource_arg(%arg0: tensor<*x!tf_type.resource>) -> tensor<*x!tf_type.resource> {
%0 = "tf_device.cluster_func"(%arg0) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @_func, num_cores_per_replica = 1, step_marker_location = "", topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor<*x!tf_type.resource>) -> tensor<*x!tf_type.resource>
// CHECK: metadata
// CHECK: dtype: DT_RESOURCE
// CHECK-SAME: kind: VARIABLE
func.return %0: tensor<*x!tf_type.resource>
}
func.func @_func(%arg0: tensor<*x!tf_type.resource>) -> tensor<*x!tf_type.resource> {
func.return %arg0 : tensor<*x!tf_type.resource>
}
}
// -----
// Tests argument that is a parameter.
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
// CHECK-LABEL: func @parameter_arg
func.func @parameter_arg(%arg0: tensor<*xf32>) -> tensor<*xf32> {
%0 = "tf_device.cluster_func"(%arg0) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @_func, num_cores_per_replica = 1, step_marker_location = "", topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor<*xf32>) -> tensor<*xf32>
// CHECK: metadata
// CHECK: dtype: DT_FLOAT
// CHECK-SAME: kind: PARAMETER
func.return %0: tensor<*xf32>
}
func.func @_func(%arg0: tensor<*xf32>) -> tensor<*xf32> {
func.return %arg0 : tensor<*xf32>
}
}
// -----
// Tests metadata is populated correctly based on the `tf_device.cluster_func`
// op and its attributes.
//
// The expected TPUCompileMetadataProto is:
// args {
// dtype: DT_INT32
// shape {
// dim {
// size: 8
// }
// }
// kind: PARAMETER
// sharding {
// type: MAXIMAL
// tile_assignment_dimensions: 1
// tile_assignment_devices: 0
// }
// }
// retvals {
// sharding {
// type: MAXIMAL
// tile_assignment_dimensions: 1
// tile_assignment_devices: 0
// }
// }
// num_replicas: 1
// num_cores_per_replica: 1
// step_marker_location: STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
// CHECK-LABEL: func @metadata
func.func @metadata(%arg0: tensor<8xi32>) -> tensor<8xi32> {
%0 = "tf_device.cluster_func"(%arg0) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor<8xi32>) -> tensor<8xi32>
// CHECK: metadata
// CHECK-SAME: args
// CHECK-SAME: dtype: DT_INT32
// CHECK-SAME: shape
// CHECK-SAME: dim
// CHECK-SAME: size: 8
// CHECK-SAME: kind: PARAMETER
// CHECK-SAME: sharding
// CHECK-SAME: type: MAXIMAL
// CHECK-SAME: tile_assignment_dimensions: 1
// CHECK-SAME: tile_assignment_devices: 0
// CHECK-SAME: retvals
// CHECK-SAME: sharding
// CHECK-SAME: type: MAXIMAL
// CHECK-SAME: tile_assignment_dimensions: 1
// CHECK-SAME: tile_assignment_devices: 0
// CHECK-SAME: num_replicas: 1
// CHECK-SAME: num_cores_per_replica: 1
// CHECK-SAME: step_marker_location: STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP
func.return %0: tensor<8xi32>
}
func.func @tpu0_func(%arg0: tensor<8xi32>) -> tensor<8xi32> {
func.return %arg0 : tensor<8xi32>
}
}
// -----
// Tests metadata is populated correctly when `use_spmd_for_xla_partitioning`
// is true.
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
// CHECK-LABEL: func @metadata_use_spmd
func.func @metadata_use_spmd(%arg0: tensor<8xi32>) -> tensor<8xi32> {
%0 = "tf_device.cluster_func"(%arg0) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = true} : (tensor<8xi32>) -> tensor<8xi32>
// CHECK: metadata
// CHECK-SAME: use_spmd_for_xla_partitioning: true
func.return %0: tensor<8xi32>
}
func.func @tpu0_func(%arg0: tensor<8xi32>) -> tensor<8xi32> {
func.return %arg0 : tensor<8xi32>
}
}
// -----
// Tests metadata is populated correctly for `is_same_data_across_replicas`.
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
// CHECK-LABEL: func @metadata_same_data_across_replicas
func.func @metadata_same_data_across_replicas(%arg0: tensor<8xi32>) -> tensor<8xi32> {
%0 = "tf_device.cluster_func"(%arg0) {_xla_compile_device_type = "TPU", _replication_info = "cluster", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "", topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor<8xi32>) -> tensor<8xi32>
// CHECK: metadata
// CHECK-SAME: is_same_data_across_replicas: true
// CHECK-SAME: mhlo.is_same_data_across_replicas
func.return %0: tensor<8xi32>
}
func.func @tpu0_func(%arg0: tensor<8xi32> {mhlo.is_same_data_across_replicas}) -> tensor<8xi32> {
func.return %arg0 : tensor<8xi32>
}
}
// -----
// Tests shape ops are only generated for operands with non-static shapes.
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
// CHECK-LABEL: func @static_and_dynamic_shapes
// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<*xi32>, %[[ARG_1:[a-z0-9]*]]: tensor<8xi32>, %[[ARG_2:[a-z0-9]*]]: tensor<*xi32>, %[[ARG_3:[a-z0-9]*]]: tensor<8xi32>)
func.func @static_and_dynamic_shapes(%arg0: tensor<*xi32>, %arg1: tensor<8xi32>, %arg2: tensor<*xi32>, %arg3: tensor<8xi32>) -> tensor<8xi32> {
// CHECK-NOT: "tf.Shape"(%[[ARG_1]])
// CHECK-NOT: "tf.Shape"(%[[ARG_3]])
// CHECK: %[[ARG_0_SHAPE:[0-9]*]] = "tf.Shape"(%[[ARG_0]])
// CHECK: %[[ARG_2_SHAPE:[0-9]*]] = "tf.Shape"(%[[ARG_2]])
%0 = "tf_device.cluster_func"(%arg0, %arg1, %arg2, %arg3) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @_func, num_cores_per_replica = 1, step_marker_location = "", topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor<*xi32>, tensor<8xi32>, tensor<*xi32>, tensor<8xi32>) -> tensor<8xi32>
// CHECK: "tf._TPUCompileMlir"(%[[ARG_0_SHAPE]], %[[ARG_2_SHAPE]])
func.return %0: tensor<8xi32>
}
func.func @_func(%arg0: tensor<*xi32>, %arg1: tensor<8xi32>, %arg2: tensor<*xi32>, %arg3: tensor<8xi32>) -> tensor<8xi32> {
func.return %arg1 : tensor<8xi32>
}
}
// -----
// Tests a simple case of `tf_device.cluster_func` on TPU with a single input
// and a single output.
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
// CHECK-LABEL: func @single_tpu_cluster_func
func.func @single_tpu_cluster_func(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
%1 = "tf_device.cluster_func"(%0) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]])
// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
// CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
// CHECK-SAME: metadata
// CHECK-SAME: mlir_module
// CHECK-SAME: func @main
// CHECK-SAME: tf.B
// CHECK-NOT: func = @tpu0_func
// CHECK: device = "/job:worker/replica:0/task:0/device:CPU:0"
// CHECK: "tf_device.launch"
// CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE_OUTPUT]]#0)
// CHECK: device = "/job:worker/replica:0/task:0/device:CPU:0"
// CHECK: %[[EXECUTE_OUTPUT:[0-9]*]] = "tf_device.launch"
// CHECK-NEXT: "tf.TPUExecute"(%[[A_OUTPUT]], %[[COMPILE_OUTPUT]]#1)
// CHECK: device = "/job:worker/replica:0/task:0/device:TPU:0"
%2 = "tf.C"(%1) : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"(%[[EXECUTE_OUTPUT]])
func.return %2 : tensor<?xi32>
// CHECK: return %[[C_OUTPUT]]
}
func.func @tpu0_func(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.B"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
func.return %0 : tensor<?xi32>
}
}
// -----
// Tests a simple case of `tf_device.cluster_func` on TPU with replication.
// Under data parallelism, replicated host devices are also added to the
// `tf_device.replicate`.
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"]} {
// CHECK-LABEL: func @replicated_tpu_cluster_func
// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<?xi32>)
func.func @replicated_tpu_cluster_func(%arg0: tensor<?xi32>) -> tensor<?xi32> {
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
// CHECK-SAME: ([%[[A_OUTPUT]], %[[ARG_0]]] as %[[RI_0:[a-z0-9]*]]: tensor<?xi32>)
// CHECK-SAME: devices = {TPU_REPLICATED_CORE_0 = ["/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"], TPU_REPLICATED_HOST = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:CPU:0"]}
// CHECK-SAME: n = 2
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
// CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[RI_0]])
// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
// CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
// CHECK-SAME: metadata
// CHECK-SAME: mlir_module
// CHECK-SAME: func @main
// CHECK-SAME: tf.B
// CHECK-NOT: func = @tpu0_func
// CHECK: device = "/job:worker/replica:0/task:0/device:CPU:0"
// CHECK: "tf_device.launch"
// CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE_OUTPUT]]#0)
// CHECK: device = "/job:worker/replica:0/task:0/device:CPU:0"
// CHECK: %[[EXECUTE_OUTPUT:[0-9]*]] = "tf_device.launch"
// CHECK-NEXT: "tf.TPUExecute"(%[[RI_0]], %[[COMPILE_OUTPUT]]#1)
%2 = "tf_device.cluster_func"(%ri_0) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: tf_device.return %[[EXECUTE_OUTPUT]]
tf_device.return %2 : tensor<?xi32>
}
// CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"(%[[REPLICATE]]#1)
%2 = "tf.C"(%1#1) : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: return %[[C_OUTPUT]]
func.return %2 : tensor<?xi32>
}
func.func @tpu0_func(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.B"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
func.return %0 : tensor<?xi32>
}
}
// -----
// Tests that a cluster_func without the `_xla_compile_device_type = "TPU"` and
// `_replication_info` attributes is ignored.
module attributes {tf.versions = {producer = 888 : i32}} {
// CHECK-LABEL: func @single_gpu_cluster_func
func.func @single_gpu_cluster_func(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
%1 = "tf_device.cluster_func"(%0) {device = "gpu0", func = @gpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: tf_device.cluster_func
// CHECK-SAME: device = "gpu0"
// CHECK-SAME: func = @gpu0_func
// CHECK-SAME: num_cores_per_replica = 1
// CHECK-SAME: step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP"
// CHECK-NOT: metadata
func.return %1 : tensor<?xi32>
}
func.func @gpu0_func(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.B"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
func.return %0 : tensor<?xi32>
}
}
// -----
// Tests `tf_device.cluster_func` on TPU with nested function calls.
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
// CHECK-LABEL: func @with_nested_func
func.func @with_nested_func(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
%1 = "tf_device.cluster_func"(%0) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]])
// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
// CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
// CHECK-SAME: metadata
// CHECK-SAME: mlir_module
// CHECK-SAME: func @main
// CHECK-SAME: tf.B
// CHECK-SAME: func private @nested_func
// CHECK-SAME: tf.D
// CHECK-NOT: func = @tpu0_func
// CHECK: device = "/job:worker/replica:0/task:0/device:CPU:0"
// CHECK: "tf_device.launch"
// CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE_OUTPUT]]#0)
// CHECK: device = "/job:worker/replica:0/task:0/device:CPU:0"
// CHECK: %[[EXECUTE_OUTPUT:[0-9]*]] = "tf_device.launch"
// CHECK-NEXT: "tf.TPUExecute"(%[[A_OUTPUT]], %[[COMPILE_OUTPUT]]#1)
// CHECK: device = "/job:worker/replica:0/task:0/device:TPU:0"
%2 = "tf.C"(%1) : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"(%[[EXECUTE_OUTPUT]])
func.return %2 : tensor<?xi32>
// CHECK: return %[[C_OUTPUT]]
}
func.func @tpu0_func(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.B"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
%1 = func.call @nested_func(%0) : (tensor<?xi32>) -> tensor<?xi32>
func.return %1 : tensor<?xi32>
}
func.func @nested_func(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.D"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
func.return %0 : tensor<?xi32>
}
}
// -----
// Tests `tf_device.cluster_func` on TPU with a function that is referenced
// through an op attribute rather than a standard call op.
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
// CHECK-LABEL: func @with_referenced_func
func.func @with_referenced_func(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
%1 = "tf_device.cluster_func"(%0) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]])
// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
// CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
// CHECK-SAME: metadata
// CHECK-SAME: mlir_module
// CHECK-SAME: func @main
// CHECK-SAME: tf.B
// CHECK-SAME: func private @referenced_func
// CHECK-SAME: tf.D
// CHECK-NOT: func = @tpu0_func
// CHECK: "tf_device.launch"
// CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE_OUTPUT]]#0)
// CHECK: %[[EXECUTE_OUTPUT:[0-9]*]] = "tf_device.launch"
// CHECK-NEXT: "tf.TPUExecute"(%[[A_OUTPUT]], %[[COMPILE_OUTPUT]]#1)
%2 = "tf.C"(%1) : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"(%[[EXECUTE_OUTPUT]])
func.return %2 : tensor<?xi32>
// CHECK: return %[[C_OUTPUT]]
}
func.func @tpu0_func(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.B"(%arg0) {body = @referenced_func} : (tensor<?xi32>) -> tensor<?xi32>
func.return %0 : tensor<?xi32>
}
func.func @referenced_func(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.D"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
func.return %0 : tensor<?xi32>
}
}
// -----
// Tests rewriting `tf_device.cluster_func` on TPU with a chain of referenced
// functions.
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
// CHECK-LABEL: func @with_referenced_func_chain
func.func @with_referenced_func_chain(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
%1 = "tf_device.cluster_func"(%0) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]])
// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
// CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
// CHECK-SAME: metadata
// CHECK-SAME: mlir_module
// CHECK-SAME: func @main
// CHECK-SAME: tf.B
// CHECK-SAME: @referenced_func1
// CHECK-SAME: tf.D
// CHECK-SAME: @referenced_func2
// CHECK-SAME: tf.E
// CHECK-NOT: func = @tpu0_func
// CHECK: "tf_device.launch"
// CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE_OUTPUT]]#0)
// CHECK: %[[EXECUTE_OUTPUT:[0-9]*]] = "tf_device.launch"
// CHECK-NEXT: "tf.TPUExecute"(%[[A_OUTPUT]], %[[COMPILE_OUTPUT]]#1)
%2 = "tf.C"(%1) : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"(%[[EXECUTE_OUTPUT]])
func.return %2 : tensor<?xi32>
// CHECK: return %[[C_OUTPUT]]
}
func.func @tpu0_func(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.B"(%arg0) {body = @referenced_func1} : (tensor<?xi32>) -> tensor<?xi32>
func.return %0 : tensor<?xi32>
}
func.func @referenced_func1(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.D"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
%1 = func.call @referenced_func2(%0) : (tensor<?xi32>) -> tensor<?xi32>
func.return %1 : tensor<?xi32>
}
func.func @referenced_func2(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.E"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
func.return %0 : tensor<?xi32>
}
}
// -----
// Tests rewriting `tf_device.cluster_func` on TPU with multiple calls to the
// same function.
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
// CHECK-LABEL: func @with_multiple_call_same_referenced_func
func.func @with_multiple_call_same_referenced_func(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
%1 = "tf_device.cluster_func"(%0) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]])
// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
// CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
// CHECK-SAME: metadata
// CHECK-SAME: mlir_module
// CHECK-SAME: func @main
// CHECK-SAME: tf.B
// CHECK-COUNT-2: call @referenced_func
// CHECK-COUNT-1: func private @referenced_func
// CHECK-SAME: tf.D
// CHECK-NOT: func = @tpu0_func
// CHECK: "tf_device.launch"
// CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE_OUTPUT]]#0)
// CHECK: %[[EXECUTE_OUTPUT:[0-9]*]] = "tf_device.launch"
// CHECK-NEXT: "tf.TPUExecute"(%[[A_OUTPUT]], %[[COMPILE_OUTPUT]]#1)
%2 = "tf.C"(%1) : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"(%[[EXECUTE_OUTPUT]])
func.return %2 : tensor<?xi32>
// CHECK: return %[[C_OUTPUT]]
}
func.func @tpu0_func(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.B"(%arg0) {body = @referenced_func1} : (tensor<?xi32>) -> tensor<?xi32>
%1 = func.call @referenced_func(%0) : (tensor<?xi32>) -> tensor<?xi32>
%2 = func.call @referenced_func(%1) : (tensor<?xi32>) -> tensor<?xi32>
func.return %2 : tensor<?xi32>
}
func.func @referenced_func(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%1 = "tf.D"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
func.return %1 : tensor<?xi32>
}
}
// -----
// Tests multiple `tf_device.cluster_func` on TPU with different computations.
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
// CHECK-LABEL: func @multiple_cluster_different_func
func.func @multiple_cluster_different_func(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
%1 = "tf_device.cluster_func"(%0) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @tpu0_func0, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]])
// CHECK: %[[COMPILE0_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
// CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
// CHECK-SAME: metadata
// CHECK-SAME: mlir_module
// CHECK-SAME: func @main
// CHECK-SAME: tf.B
// CHECK-NOT: func = @tpu0_func0
// CHECK: "tf_device.launch"
// CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE0_OUTPUT]]#0)
// CHECK: %[[EXECUTE0_OUTPUT:[0-9]*]] = "tf_device.launch"
// CHECK-NEXT: "tf.TPUExecute"(%[[A_OUTPUT]], %[[COMPILE0_OUTPUT]]#1)
%2 = "tf_device.cluster_func"(%1) {_xla_compile_device_type = "TPU", _replication_info = "cluster1", func = @tpu0_func1, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[EXECUTE0_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[EXECUTE0_OUTPUT]])
// CHECK: %[[COMPILE1_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
// CHECK-NEXT: "tf._TPUCompileMlir"(%[[EXECUTE0_SHAPE_OUTPUT]])
// CHECK-SAME: metadata
// CHECK-SAME: mlir_module
// CHECK-SAME: func @main
// CHECK-SAME: tf.D
// CHECK-NOT: func = @tpu0_func1
// CHECK: "tf_device.launch"
// CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE1_OUTPUT]]#0)
// CHECK: %[[EXECUTE1_OUTPUT:[0-9]*]] = "tf_device.launch"
// CHECK-NEXT: "tf.TPUExecute"(%[[EXECUTE0_OUTPUT]], %[[COMPILE1_OUTPUT]]#1)
%3 = "tf.C"(%2) : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"(%[[EXECUTE1_OUTPUT]])
func.return %3 : tensor<?xi32>
// CHECK: return %[[C_OUTPUT]]
}
func.func @tpu0_func0(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.B"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
func.return %0 : tensor<?xi32>
}
func.func @tpu0_func1(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.D"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
func.return %0 : tensor<?xi32>
}
}
// -----
// Tests multiple `tf_device.cluster_func` on TPU with the same computation.
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
// CHECK-LABEL: func @multiple_cluster_same_func
func.func @multiple_cluster_same_func(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
%1 = "tf_device.cluster_func"(%0) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]])
// CHECK: %[[COMPILE0_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
// CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
// CHECK-SAME: metadata
// CHECK-SAME: mlir_module
// CHECK-SAME: func @main
// CHECK-SAME: tf.B
// CHECK-NOT: func = @tpu0_func
// CHECK: "tf_device.launch"
// CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE0_OUTPUT]]#0)
// CHECK: %[[EXECUTE0_OUTPUT:[0-9]*]] = "tf_device.launch"
// CHECK-NEXT: "tf.TPUExecute"(%[[A_OUTPUT]], %[[COMPILE0_OUTPUT]]#1)
%2 = "tf_device.cluster_func"(%1) {_xla_compile_device_type = "TPU", _replication_info = "cluster1", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[EXECUTE0_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[EXECUTE0_OUTPUT]])
// CHECK: %[[COMPILE1_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
// CHECK-NEXT: "tf._TPUCompileMlir"(%[[EXECUTE0_SHAPE_OUTPUT]])
// CHECK-SAME: metadata
// CHECK-SAME: mlir_module
// CHECK-SAME: func @main
// CHECK-SAME: tf.B
// CHECK-NOT: func = @tpu0_func
// CHECK: "tf_device.launch"
// CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE1_OUTPUT]]#0)
// CHECK: %[[EXECUTE1_OUTPUT:[0-9]*]] = "tf_device.launch"
// CHECK-NEXT: "tf.TPUExecute"(%[[EXECUTE0_OUTPUT]], %[[COMPILE1_OUTPUT]]#1)
%3 = "tf.C"(%2) : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"(%[[EXECUTE1_OUTPUT]])
func.return %3 : tensor<?xi32>
// CHECK: return %[[C_OUTPUT]]
}
func.func @tpu0_func(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.B"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
func.return %0 : tensor<?xi32>
}
}
// -----
// Tests functions referenced by the TPU function via SymbolRefAttr nested in
// ArrayAttr and DictionaryAttr.
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
// CHECK-LABEL: func @single_tpu_cluster_func
func.func @single_tpu_cluster_func(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
%1 = "tf_device.cluster_func"(%0) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]])
// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
// CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
// CHECK-SAME: metadata
// CHECK-SAME: mlir_module
// CHECK-SAME: func @main
// CHECK-SAME: tf.B
// CHECK-SAME: func private @referenced_func3
// CHECK-SAME: tf.I
// CHECK-SAME: func private @referenced_func2
// CHECK-SAME: tf.H
// CHECK-SAME: func private @referenced_func1
// CHECK-SAME: tf.G
// CHECK-SAME: func private @referenced_func0
// CHECK-SAME: tf.F
// CHECK: "tf_device.launch"
// CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE_OUTPUT]]#0)
// CHECK: %[[EXECUTE_OUTPUT:[0-9]*]] = "tf_device.launch"
// CHECK-NEXT: "tf.TPUExecute"(%[[A_OUTPUT]], %[[COMPILE_OUTPUT]]#1)
%2 = "tf.C"(%1) : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"(%[[EXECUTE_OUTPUT]])
func.return %2 : tensor<?xi32>
// CHECK: return %[[C_OUTPUT]]
}
func.func @tpu0_func(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.B"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
%1 = "tf.D"(%0) {array_attr_funcs = [@referenced_func0, @referenced_func1]} : (tensor<?xi32>) -> tensor<?xi32>
%2 = "tf.E"(%1) {dictionary_attr_funcs = {fn1 = @referenced_func2, fn2 = @referenced_func3}} : (tensor<?xi32>) -> tensor<?xi32>
func.return %0 : tensor<?xi32>
}
func.func @referenced_func0(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%1 = "tf.F"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
func.return %1 : tensor<?xi32>
}
func.func @referenced_func1(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%1 = "tf.G"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
func.return %1 : tensor<?xi32>
}
func.func @referenced_func2(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%1 = "tf.H"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
func.return %1 : tensor<?xi32>
}
func.func @referenced_func3(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%1 = "tf.I"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
func.return %1 : tensor<?xi32>
}
}
// -----
// Tests `tf_device.cluster_func` on TPU with pre-split, replicate-sharded
// input/output using `tf.TPUPartitionedInput` and `tf.TPUPartitionedOutput`.
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"]} {
func.func @cluster(%arg0: tensor<!tf_type.resource<tensor<i32>>>, %arg1: tensor<!tf_type.resource<tensor<i32>>>) {
// CHECK: %[[READ_VAR_0:[0-9]*]] = "tf.ReadVariableOp"(%arg0)
%read0 = "tf.ReadVariableOp"(%arg0) : (tensor<!tf_type.resource<tensor<i32>>>) -> tensor<i32>
// CHECK: %[[READ_VAR_1:[0-9]*]] = "tf.ReadVariableOp"(%arg1)
%read1 = "tf.ReadVariableOp"(%arg1) : (tensor<!tf_type.resource<tensor<i32>>>) -> tensor<i32>
// CHECK-NOT: tf.TPUPartitionedInput
%partitioned_input = "tf.TPUPartitionedInput"(%read0, %read1) {N = 2 : i64, partition_dim = -1 : i64} : (tensor<i32>, tensor<i32>) -> tensor<i32>
// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:3 = "tf_device.launch"
// CHECK-NEXT: "tf._TPUCompileMlir"()
// CHECK: "tf_device.launch"
// CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE_OUTPUT]]#0)
// CHECK: [[PARALLEL_EXECUTE_OUTPUT:[0-9]*]]:2 = "tf_device.parallel_execute"
// CHECK: "tf_device.launch"
// CHECK-NEXT: "tf.TPUExecute"(%[[READ_VAR_0]], %[[COMPILE_OUTPUT]]#1)
// CHECK: device = "/job:worker/replica:0/task:0/device:TPU:0"
// CHECK: "tf_device.launch"
// CHECK-NEXT: "tf.TPUExecute"(%[[READ_VAR_1]], %[[COMPILE_OUTPUT]]#2)
// CHECK: device = "/job:worker/replica:0/task:0/device:TPU:1"
%computation = "tf_device.cluster_func"(%partitioned_input) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @computation, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "\0A\04\01\01\01\02\10\01\18\02\22\08\00\00\00\00\00\00\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1], input_sharding_configuration = [""], output_sharding_configuration = [""], use_spmd_for_xla_partitioning = true} : (tensor<i32>) -> tensor<i32>
// CHECK-NOT: tf.TPUPartitionedOutput
%partitioned_output:2 = "tf.TPUPartitionedOutput"(%computation) {N = 2 : i64, partition_dim = -1 : i64} : (tensor<i32>) -> (tensor<i32>, tensor<i32>)
// CHECK: "tf.AssignVariableOp"(%arg0, %[[PARALLEL_EXECUTE_OUTPUT]]#0)
// CHECK: "tf.AssignVariableOp"(%arg1, %[[PARALLEL_EXECUTE_OUTPUT]]#1)
"tf.AssignVariableOp"(%arg0, %partitioned_output#0) : (tensor<!tf_type.resource<tensor<i32>>>, tensor<i32>) -> ()
"tf.AssignVariableOp"(%arg1, %partitioned_output#1) : (tensor<!tf_type.resource<tensor<i32>>>, tensor<i32>) -> ()
func.return
}
func.func @computation(%arg0: tensor<i32>) -> tensor<i32> {
func.return %arg0: tensor<i32>
}
}
// -----
// Tests `tf_device.cluster_func` on TPU with pre-split, tile-sharded
// input/output using `tf.TPUPartitionedInput` and `tf.TPUPartitionedOutput`.
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"]} {
func.func @cluster(%arg0: tensor<!tf_type.resource<tensor<3x2xf32>>>, %arg1: tensor<!tf_type.resource<tensor<3x2xf32>>>) {
// CHECK: %[[READ_VAR_0:[0-9]*]] = "tf.ReadVariableOp"(%arg0)
%read0 = "tf.ReadVariableOp"(%arg0) : (tensor<!tf_type.resource<tensor<3x2xf32>>>) -> tensor<3x2xf32>
// CHECK: %[[READ_VAR_1:[0-9]*]] = "tf.ReadVariableOp"(%arg1)
%read1 = "tf.ReadVariableOp"(%arg1) : (tensor<!tf_type.resource<tensor<3x2xf32>>>) -> tensor<3x2xf32>
// CHECK-NOT: tf.TPUPartitionedInput
%partitioned_input = "tf.TPUPartitionedInput"(%read0, %read1) {_XlaSharding = "\08\03\1A\02\01\02\22\02\00\01", partition_dim = 1 : i64} : (tensor<3x2xf32>, tensor<3x2xf32>) -> tensor<3x4xf32>
// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:3 = "tf_device.launch"
// CHECK-NEXT: "tf._TPUCompileMlir"()
// CHECK: "tf_device.launch"
// CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE_OUTPUT]]#0)
// CHECK: [[PARALLEL_EXECUTE_OUTPUT:[0-9]*]]:2 = "tf_device.parallel_execute"
// CHECK: "tf_device.launch"
// CHECK-NEXT: "tf.TPUExecute"(%[[READ_VAR_0]], %[[COMPILE_OUTPUT]]#1)
// CHECK: device = "/job:worker/replica:0/task:0/device:TPU:0"
// CHECK: "tf_device.launch"
// CHECK-NEXT: "tf.TPUExecute"(%[[READ_VAR_1]], %[[COMPILE_OUTPUT]]#2)
// CHECK: device = "/job:worker/replica:0/task:0/device:TPU:1"
%computation = "tf_device.cluster_func"(%partitioned_input) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @computation, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "\0A\04\01\01\01\02\10\01\18\02\22\08\00\00\00\00\00\00\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1], input_sharding_configuration = ["\08\03\1A\02\01\02\22\02\00\01"], output_sharding_configuration = ["\08\03\1A\02\01\02\22\02\00\01"], use_spmd_for_xla_partitioning = true} : (tensor<3x4xf32>) -> tensor<3x4xf32>
// CHECK-NOT: tf.TPUPartitionedOutput
%partitioned_output:2 = "tf.TPUPartitionedOutput"(%computation) {_XlaSharding = "\08\03\1A\02\01\02\22\02\00\01", partition_dim = 1 : i64} : (tensor<3x4xf32>) -> (tensor<3x2xf32>, tensor<3x2xf32>)
// CHECK: "tf.AssignVariableOp"(%arg0, %[[PARALLEL_EXECUTE_OUTPUT]]#0)
// CHECK: "tf.AssignVariableOp"(%arg1, %[[PARALLEL_EXECUTE_OUTPUT]]#1)
"tf.AssignVariableOp"(%arg0, %partitioned_output#0) : (tensor<!tf_type.resource<tensor<3x2xf32>>>, tensor<3x2xf32>) -> ()
"tf.AssignVariableOp"(%arg1, %partitioned_output#1) : (tensor<!tf_type.resource<tensor<3x2xf32>>>, tensor<3x2xf32>) -> ()
func.return
}
func.func @computation(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> {
func.return %arg0: tensor<3x4xf32>
}
}
// -----
// Tests that an unsupported input sharding type on TPUPartitionedInputOp
// inputs of a ClusterFuncOp results in an error.
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"]} {
func.func @cluster(%arg0: tensor<!tf_type.resource<tensor<i32>>>, %arg1: tensor<!tf_type.resource<tensor<i32>>>) {
%read0 = "tf.ReadVariableOp"(%arg0) : (tensor<!tf_type.resource<tensor<i32>>>) -> tensor<i32>
%read1 = "tf.ReadVariableOp"(%arg1) : (tensor<!tf_type.resource<tensor<i32>>>) -> tensor<i32>
%partitioned_input = "tf.TPUPartitionedInput"(%read0, %read1) {N = 2 : i64, partition_dim = -1 : i64} : (tensor<i32>, tensor<i32>) -> tensor<i32>
// expected-error@+1 {{unsupported input sharding type MAXIMAL for 0-th input}}
%computation = "tf_device.cluster_func"(%partitioned_input) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @computation, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "\0A\04\01\01\01\02\10\01\18\02\22\08\00\00\00\00\00\00\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1],
input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = [""], use_spmd_for_xla_partitioning = true} : (tensor<i32>) -> tensor<i32>
%partitioned_output:2 = "tf.TPUPartitionedOutput"(%computation) {N = 2 : i64, partition_dim = -1 : i64} : (tensor<i32>) -> (tensor<i32>, tensor<i32>)
"tf.AssignVariableOp"(%arg0, %partitioned_output#0) : (tensor<!tf_type.resource<tensor<i32>>>, tensor<i32>) -> ()
"tf.AssignVariableOp"(%arg1, %partitioned_output#1) : (tensor<!tf_type.resource<tensor<i32>>>, tensor<i32>) -> ()
func.return
}
func.func @computation(%arg0: tensor<i32>) -> tensor<i32> {
func.return %arg0: tensor<i32>
}
}
// -----
// Tests that an unsupported output sharding type on TPUPartitionedOutputOp
// outputs of a ClusterFuncOp results in an error.
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"]} {
func.func @cluster(%arg0: tensor<!tf_type.resource<tensor<i32>>>, %arg1: tensor<!tf_type.resource<tensor<i32>>>) {
%read0 = "tf.ReadVariableOp"(%arg0) : (tensor<!tf_type.resource<tensor<i32>>>) -> tensor<i32>
%read1 = "tf.ReadVariableOp"(%arg1) : (tensor<!tf_type.resource<tensor<i32>>>) -> tensor<i32>
%partitioned_input = "tf.TPUPartitionedInput"(%read0, %read1) {N = 2 : i64, partition_dim = -1 : i64} : (tensor<i32>, tensor<i32>) -> tensor<i32>
// expected-error@+1 {{unsupported output sharding type MAXIMAL for 0-th output}}
%computation = "tf_device.cluster_func"(%partitioned_input) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @computation, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "\0A\04\01\01\01\02\10\01\18\02\22\08\00\00\00\00\00\00\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1],
input_sharding_configuration = [""], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = true} : (tensor<i32>) -> tensor<i32>
%partitioned_output:2 = "tf.TPUPartitionedOutput"(%computation) {N = 2 : i64, partition_dim = -1 : i64} : (tensor<i32>) -> (tensor<i32>, tensor<i32>)
"tf.AssignVariableOp"(%arg0, %partitioned_output#0) : (tensor<!tf_type.resource<tensor<i32>>>, tensor<i32>) -> ()
"tf.AssignVariableOp"(%arg1, %partitioned_output#1) : (tensor<!tf_type.resource<tensor<i32>>>, tensor<i32>) -> ()
func.return
}
func.func @computation(%arg0: tensor<i32>) -> tensor<i32> {
func.return %arg0: tensor<i32>
}
}
// -----
// Tests that multiple uses of a ClusterFuncOp output along with a
// TPUPartitionedOutputOp result in an error.
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"]} {
func.func @cluster(%arg0: tensor<!tf_type.resource<tensor<i32>>>, %arg1: tensor<!tf_type.resource<tensor<i32>>>) {
%read0 = "tf.ReadVariableOp"(%arg0) : (tensor<!tf_type.resource<tensor<i32>>>) -> tensor<i32>
%read1 = "tf.ReadVariableOp"(%arg1) : (tensor<!tf_type.resource<tensor<i32>>>) -> tensor<i32>
%partitioned_input = "tf.TPUPartitionedInput"(%read0, %read1) {N = 2 : i64, partition_dim = -1 : i64} : (tensor<i32>, tensor<i32>) -> tensor<i32>
%computation = "tf_device.cluster_func"(%partitioned_input) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @computation, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "\0A\04\01\01\01\02\10\01\18\02\22\08\00\00\00\00\00\00\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1], input_sharding_configuration = [""], output_sharding_configuration = [""], use_spmd_for_xla_partitioning = true} : (tensor<i32>) -> tensor<i32>
// expected-error@+1 {{'tf.TPUPartitionedOutput' op must be a unique user of tf_device.cluster_func output}}
%partitioned_output:2 = "tf.TPUPartitionedOutput"(%computation) {N = 2 : i64, partition_dim = -1 : i64} : (tensor<i32>) -> (tensor<i32>, tensor<i32>)
"tf.AssignVariableOp"(%arg0, %partitioned_output#0) : (tensor<!tf_type.resource<tensor<i32>>>, tensor<i32>) -> ()
"tf.AssignVariableOp"(%arg1, %partitioned_output#1) : (tensor<!tf_type.resource<tensor<i32>>>, tensor<i32>) -> ()
"tf._SomeOp"(%computation) : (tensor<i32>) -> ()
func.return
}
func.func @computation(%arg0: tensor<i32>) -> tensor<i32> {
func.return %arg0: tensor<i32>
}
}
// -----
// Tests that TPUCompilationResult operations are properly rewritten.
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
// CHECK-LABEL: func @tpu_compilation_result
func.func @tpu_compilation_result(%arg0: tensor<?xi32>) -> (tensor<?xi32>, tensor<!tf_type.string>, tensor<!tf_type.string>) {
// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
// CHECK-NEXT: "tf._TPUCompileMlir"
// CHECK: %[[COMPILE_RESULT_0:.*]] = "tf.Identity"(%[[COMPILE_OUTPUT]]#0)
// CHECK: %[[COMPILE_RESULT_1:.*]] = "tf.Identity"(%[[COMPILE_RESULT_0]])
// CHECK: "tf_device.launch"
// CHECK-NEXT: "tf.TPUCompileSucceededAssert"
// CHECK: %[[EXECUTE_OUTPUT:[0-9]*]] = "tf_device.launch"
// CHECK-NEXT: "tf.TPUExecute"
%1 = "tf_device.cluster_func"(%arg0) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "", topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor<?xi32>) -> tensor<?xi32>
%compile_result = "tf.TPUCompilationResult"() {_tpu_compilation_status = "cluster0"} : () -> tensor<!tf_type.string>
%compile_result2 = "tf.TPUCompilationResult"() {_tpu_compilation_status = "cluster0"} : () -> tensor<!tf_type.string>
// CHECK-NOT: "tf.TPUCompilationResult"
// CHECK: return %[[EXECUTE_OUTPUT]], %[[COMPILE_OUTPUT]]#0, %[[COMPILE_OUTPUT]]#0
func.return %1, %compile_result, %compile_result2 : tensor<?xi32>, tensor<!tf_type.string>, tensor<!tf_type.string>
}
func.func @tpu0_func(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.B"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
func.return %0 : tensor<?xi32>
}
}
// -----
// Tests a simple case of `tf_device.cluster_func` on TPU with replication and
// parallel_execute.
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"]} {
// CHECK-LABEL: func @replicated_parallel_tpu_cluster_func
func.func @replicated_parallel_tpu_cluster_func(%arg0: tensor<?xi32>) -> tensor<?xi32> {
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
// CHECK: "tf._TPUCompileMlir"
// CHECK: "tf.TPUCompileSucceededAssert"
// CHECK: "tf_device.parallel_execute"
// CHECK-NOT:"tf._TPUCompileMlirPlaceholderProgramKey"
// CHECK: "tf.D"(%[[COMPILE_OUTPUT]]#1
// CHECK: "tf.TPUExecute"
// CHECK-NOT:"tf._TPUCompileMlirPlaceholderProgramKey"
// CHECK: "tf.E"(%[[COMPILE_OUTPUT]]#1
%3 = "tf_device.parallel_execute"() ({
%program = "tf._TPUCompileMlirPlaceholderProgramKey"() : () -> tensor<3x!tf_type.string>
"tf.D"(%program) : (tensor<3x!tf_type.string>) -> ()
tf_device.return
}, {
%4 = "tf_device.cluster_func"(%ri_0) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor<?xi32>) -> tensor<?xi32>
tf_device.return %4 : tensor<?xi32>
}, {
%program = "tf._TPUCompileMlirPlaceholderProgramKey"() : () -> tensor<3x!tf_type.string>
"tf.E"(%program) : (tensor<3x!tf_type.string>) -> ()
tf_device.return
}) : () -> (tensor<?xi32>)
tf_device.return %3 : tensor<?xi32>
}
%2 = "tf.C"(%1#1) : (tensor<?xi32>) -> tensor<?xi32>
func.return %2 : tensor<?xi32>
}
func.func @tpu0_func(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.B"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
func.return %0 : tensor<?xi32>
}
}
// -----
// Tests devices are set properly for non-replicated model parallelism.
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0"]} {
// CHECK-LABEL: func @non_replicated_parallel_execute
func.func @non_replicated_parallel_execute(%arg0: tensor<8xi32>) -> tensor<8xi32> {
// CHECK: %[[COMPILE:[a-z0-9]+]]:3 = "tf_device.launch"
// CHECK-NEXT: "tf._TPUCompileMlir"()
// CHECK-NEXT: tf_device.return
// CHECK-NEXT: device = "/job:localhost/replica:0/task:0/device:CPU:0"
// CHECK: "tf_device.launch"
// CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE]]#0)
// CHECK-NEXT: tf_device.return
// CHECK-NEXT: device = "/job:localhost/replica:0/task:0/device:CPU:0"
// CHECK: "tf_device.parallel_execute"
// CHECK-NEXT: "tf_device.launch"
// CHECK-NEXT: "tf.TPUExecute"
// CHECK-NEXT: tf_device.return
// CHECK-NEXT: device = "/job:localhost/replica:0/task:0/device:TPU:0"
// CHECK: "tf_device.launch"
// CHECK-NEXT: "tf.TPUExecute"
// CHECK-NEXT: tf_device.return
// CHECK-NEXT: device = "/job:localhost/replica:0/task:0/device:TPU:1"
%0 = "tf_device.cluster_func"(%arg0) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "\0A\04\01\01\01\02\10\01\18\02\22\08\00\00\00\00\00\00\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor<8xi32>) -> tensor<8xi32>
func.return %0 : tensor<8xi32>
}
func.func @tpu0_func(%arg0: tensor<8xi32>) -> tensor<8xi32> {
func.return %arg0 : tensor<8xi32>
}
}
// -----
// The following topology is used in subsequent test cases:
// Proto debug string:
// mesh_shape: 1
// mesh_shape: 2
// mesh_shape: 1
// mesh_shape: 2
// num_tasks: 2
// num_tpu_devices_per_task: 2
// device_coordinates: 0
// device_coordinates: 0
// device_coordinates: 0
// device_coordinates: 0
// device_coordinates: 0
// device_coordinates: 0
// device_coordinates: 0
// device_coordinates: 1
// device_coordinates: 0
// device_coordinates: 1
// device_coordinates: 0
// device_coordinates: 0
// device_coordinates: 0
// device_coordinates: 1
// device_coordinates: 0
// device_coordinates: 1
// Serialized string:
// "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01"
// -----
// Tests devices are set properly for replicated model parallelism. No
// replicated host device should be present.
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0", "/job:localhost/replica:0/task:1/device:CPU:0", "/job:localhost/replica:0/task:1/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU_SYSTEM:0"]} {
// CHECK-LABEL: func @replicated_parallel_execute
func.func @replicated_parallel_execute(%arg0: tensor<8xi32>, %arg1: tensor<8xi32>) -> (tensor<8xi32>, tensor<8xi32>) {
// CHECK: tf_device.replicate
// CHECK-SAME: devices = {TPU_REPLICATED_CORE_0 = ["/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1"], TPU_REPLICATED_CORE_1 = ["/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU:0"]}
%0:2 = tf_device.replicate([%arg0, %arg1] as %ri: tensor<8xi32>) {n = 2 : i32} {
// CHECK-NEXT: %[[COMPILE:[a-z0-9]+]]:3 = "tf_device.launch"
// CHECK-NEXT: "tf._TPUCompileMlir"()
// CHECK-NEXT: tf_device.return
// CHECK-NEXT: device = "/job:localhost/replica:0/task:0/device:CPU:0"
// CHECK: "tf_device.launch"
// CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE]]#0)
// CHECK-NEXT: tf_device.return
// CHECK-NEXT: device = "/job:localhost/replica:0/task:0/device:CPU:0"
// CHECK: "tf_device.parallel_execute"
// CHECK-NEXT: "tf_device.launch"
// CHECK-NEXT: "tf.TPUExecute"
// CHECK-NEXT: tf_device.return
// CHECK-NEXT: device = "TPU_REPLICATED_CORE_0"
// CHECK: "tf_device.launch"
// CHECK-NEXT: "tf.TPUExecute"
// CHECK-NEXT: tf_device.return
// CHECK-NEXT: device = "TPU_REPLICATED_CORE_1"
%1 = "tf_device.cluster_func"(%ri) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor<8xi32>) -> tensor<8xi32>
tf_device.return %1 : tensor<8xi32>
}
func.return %0#0, %0#1 : tensor<8xi32>, tensor<8xi32>
}
func.func @tpu0_func(%arg0: tensor<8xi32>) -> tensor<8xi32> {
func.return %arg0 : tensor<8xi32>
}
}
// -----
// Tests that inputs with maximal and replicated sharding are set properly for
// replicated model parallelism.
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0", "/job:localhost/replica:0/task:1/device:CPU:0", "/job:localhost/replica:0/task:1/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU_SYSTEM:0"]} {
// CHECK-LABEL: func @parallel_execute_with_input_with_sharding_configurations
// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<8xi32>, %[[ARG_1:[a-z0-9]*]]: tensor<8xi32>, %[[ARG_2:[a-z0-9]*]]: tensor<*xi1>, %[[ARG_3:[a-z0-9]*]]: tensor<*xi1>, %[[ARG_4:[a-z0-9]*]]: tensor<*xi32>, %[[ARG_5:[a-z0-9]*]]: tensor<*xi32>)
func.func @parallel_execute_with_input_with_sharding_configurations(%arg0: tensor<8xi32>, %arg1: tensor<8xi32>, %arg2: tensor<*xi1>, %arg3: tensor<*xi1>, %arg4: tensor<*xi32>, %arg5: tensor<*xi32>) -> (tensor<8xi32>, tensor<8xi32>) {
// CHECK: tf_device.replicate
// CHECK-SAME: [%[[ARG_0]], %[[ARG_1]]] as %[[RI_0:[a-z0-9]*]]: tensor<8xi32>
// CHECK-SAME: [%[[ARG_2]], %[[ARG_3]]] as %[[RI_1:[a-z0-9]*]]: tensor<*xi1>
// CHECK-SAME: [%[[ARG_4]], %[[ARG_5]]] as %[[RI_2:[a-z0-9]*]]: tensor<*xi32>
%0:2 = tf_device.replicate([%arg0, %arg1] as %ri: tensor<8xi32>, [%arg2, %arg3] as %ri2: tensor<*xi1>, [%arg4, %arg5] as %ri3: tensor<*xi32>) {n = 2 : i32} {
// CHECK: %[[COMPILE:[a-z0-9]+]]:3 = "tf_device.launch"
// CHECK: "tf._TPUCompileMlir"
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
// CHECK-NEXT: %[[LAUNCH_0_OUTPUT:[0-9]*]] = "tf_device.launch"
// CHECK-NEXT: %[[EXECUTE_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[RI_0]], %[[RI_1]], %[[RI_2]], %[[COMPILE]]#1)
// CHECK-NEXT: tf_device.return %[[EXECUTE_OUTPUT]]
// CHECK-NEXT: device = "TPU_REPLICATED_CORE_0"
// CHECK: "tf_device.launch"
// CHECK-NEXT: "tf.TPUExecute"(%[[RI_1]], %[[RI_2]], %[[COMPILE]]#2)
// CHECK: device = "TPU_REPLICATED_CORE_1"
%1 = "tf_device.cluster_func"(%ri, %ri2, %ri3) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00", "", ""], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor<8xi32>, tensor<*xi1>, tensor<*xi32>) -> tensor<8xi32>
tf_device.return %1 : tensor<8xi32>
}
func.return %0#0, %0#1 : tensor<8xi32>, tensor<8xi32>
}
func.func @tpu0_func(%arg0: tensor<8xi32>, %arg1: tensor<*xi1>, %arg2: tensor<*xi32>) -> tensor<8xi32> {
%1 = "tf.A"(%arg0, %arg1, %arg2) : (tensor<8xi32>, tensor<*xi1>, tensor<*xi32>) -> (tensor<8xi32>)
func.return %1 : tensor<8xi32>
}
}
// -----
// Tests devices are set properly for replicated model parallelism with outputs
// of the TPU computation placed on logical device 0.
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0", "/job:localhost/replica:0/task:1/device:CPU:0", "/job:localhost/replica:0/task:1/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU_SYSTEM:0"]} {
// CHECK-LABEL: func @parallel_execute_with_different_outputs
func.func @parallel_execute_with_different_outputs(%arg0: tensor<8xi32>, %arg1: tensor<8xi32>) -> (tensor<8xi32>, tensor<8xi32>) {
// CHECK: tf_device.replicate
// CHECK-SAME: devices =
// CHECK-SAME: TPU_REPLICATED_CORE_0 = ["/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1"]
// CHECK-SAME: TPU_REPLICATED_CORE_1 = ["/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU:0"]
%0:2 = tf_device.replicate([%arg0, %arg1] as %ri: tensor<8xi32>) {n = 2 : i32} {
// CHECK-NEXT: %[[COMPILE:[a-z0-9]+]]:3 = "tf_device.launch"
// CHECK-NEXT: "tf._TPUCompileMlir"()
// CHECK: device = "/job:localhost/replica:0/task:0/device:CPU:0"
// CHECK: "tf_device.launch"
// CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE]]#0)
// CHECK: device = "/job:localhost/replica:0/task:0/device:CPU:0"
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
// CHECK-NEXT: %[[LAUNCH_0_OUTPUT:[0-9]*]] = "tf_device.launch"
// CHECK-NEXT: %[[EXECUTE_OUTPUT:[0-9]*]] = "tf.TPUExecute"
// CHECK-NEXT: tf_device.return %[[EXECUTE_OUTPUT]]
// CHECK-NEXT: device = "TPU_REPLICATED_CORE_0"
// CHECK: "tf_device.launch"
// CHECK-NEXT: "tf.TPUExecute"
// CHECK: device = "TPU_REPLICATED_CORE_1"
%1 = "tf_device.cluster_func"(%ri) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor<8xi32>) -> tensor<8xi32>
tf_device.return %1 : tensor<8xi32>
}
func.return %0#0, %0#1 : tensor<8xi32>, tensor<8xi32>
}
func.func @tpu0_func(%arg0: tensor<8xi32>) -> tensor<8xi32> {
func.return %arg0 : tensor<8xi32>
}
}
// -----
// Tests devices are set properly for replicated model parallelism where the
// TPU computation has maximal and replicated outputs.
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0", "/job:localhost/replica:0/task:1/device:CPU:0", "/job:localhost/replica:0/task:1/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU_SYSTEM:0"]} {
// CHECK-LABEL: func @parallel_execute_with_replicated_output
func.func @parallel_execute_with_replicated_output(%arg0: tensor<8xi32>, %arg1: tensor<8xi32>) -> (tensor<*xi32>, tensor<*xi1>) {
// CHECK: tf_device.replicate
// CHECK-SAME: devices =
// CHECK-SAME: TPU_REPLICATED_CORE_0 = ["/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1"]
// CHECK-SAME: TPU_REPLICATED_CORE_1 = ["/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU:0"]
%0:2, %1:2 = tf_device.replicate([%arg0, %arg1] as %ri: tensor<8xi32>) {n = 2 : i32} {
// CHECK-NEXT: %[[COMPILE:[a-z0-9]+]]:3 = "tf_device.launch"
// CHECK-NEXT: "tf._TPUCompileMlir"()
// CHECK: device = "/job:localhost/replica:0/task:0/device:CPU:0"
// CHECK: "tf_device.launch"
// CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE]]#0)
// CHECK: device = "/job:localhost/replica:0/task:0/device:CPU:0"
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]]:3 = "tf_device.parallel_execute"
// CHECK-NEXT: %[[LAUNCH_0_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
// CHECK-NEXT: %[[EXECUTE_0_OUTPUT:[0-9]*]]:2 = "tf.TPUExecute"
// CHECK-NEXT: tf_device.return %[[EXECUTE_0_OUTPUT]]
// CHECK-NEXT: device = "TPU_REPLICATED_CORE_0"
// CHECK: %[[LAUNCH_1_OUTPUT:[0-9]*]] = "tf_device.launch"
// CHECK-NEXT: %[[EXECUTE_1_OUTPUT:[0-9]*]] = "tf.TPUExecute"
// CHECK-NEXT: tf_device.return %[[EXECUTE_1_OUTPUT]]
// CHECK: device = "TPU_REPLICATED_CORE_1"
%1, %2 = "tf_device.cluster_func"(%ri) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", ""], use_spmd_for_xla_partitioning = false} : (tensor<8xi32>) -> (tensor<*xi32>, tensor<*xi1>)
tf_device.return %1, %2 : tensor<*xi32>, tensor<*xi1>
}
func.return %0#0, %1#0 : tensor<*xi32>, tensor<*xi1>
}
func.func @tpu0_func(%arg0: tensor<8xi32>) -> (tensor<*xi32>, tensor<*xi1>) {
%1, %2 = "tf.A"(%arg0) : (tensor<8xi32>) -> (tensor<*xi32>, tensor<*xi1>)
%3 = "tf.XlaSharding"(%2) { _XlaSharding = "", sharding = "" } : (tensor<*xi1>) -> tensor<*xi1>
func.return %1, %3 : tensor<*xi32>, tensor<*xi1>
}
}
// -----
// Tests that inputs are correctly split and fed into the TPU computation for
// tiled input sharding.
// The following OpSharding is used for the TPU computation inputs in the test
// below:
// Proto debug string:
// input 0
// type: OTHER
// tile_assignment_dimensions: 1
// tile_assignment_dimensions: 2
// tile_assignment_devices: 0
// tile_assignment_devices: 1
// Serialized string:
// "\08\03\1A\02\01\02\22\02\00\01"
//
// input 1
// type: MAXIMAL
// tile_assignment_dimensions: 1
// tile_assignment_devices: 1
// Serialized string:
// "\08\01\1A\01\01\22\01\01"
// -----
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0", "/job:localhost/replica:0/task:1/device:CPU:0", "/job:localhost/replica:0/task:1/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU_SYSTEM:0"]} {
// CHECK-LABEL: func @parallel_execute_with_tiled_input
// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<128x10xf32>, %[[ARG_1:[a-z0-9]*]]: tensor<128x10xf32>, %[[ARG_2:[a-z0-9]*]]: tensor<*xi32>, %[[ARG_3:[a-z0-9]*]]: tensor<*xi32>)
func.func @parallel_execute_with_tiled_input(%arg0: tensor<128x10xf32>, %arg1: tensor<128x10xf32>, %arg2: tensor<*xi32>, %arg3: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) {
// CHECK: tf_device.replicate
// CHECK-SAME: [%[[ARG_0]], %[[ARG_1]]] as %[[RI_0:[a-z0-9]*]]: tensor<128x10xf32>
// CHECK-SAME: [%[[ARG_2]], %[[ARG_3]]] as %[[RI_1:[a-z0-9]*]]: tensor<*xi32>
// CHECK-SAME: devices =
// CHECK-SAME: TPU_REPLICATED_CORE_0 = ["/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1"]
// CHECK-SAME: TPU_REPLICATED_CORE_1 = ["/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU:0"]
%0:2, %1:2 = tf_device.replicate([%arg0, %arg1] as %ri_1: tensor<128x10xf32>, [%arg2, %arg3] as %ri_2: tensor<*xi32>) {n = 2 : i32} {
// CHECK: %[[COMPILE:[a-z0-9]+]]:3 = "tf_device.launch"
// CHECK-NEXT: "tf._TPUCompileMlir"
// CHECK: device = "/job:localhost/replica:0/task:0/device:CPU:0"
// CHECK: "tf_device.launch"
// CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE]]#0)
// CHECK: device = "/job:localhost/replica:0/task:0/device:CPU:0"
//
// CHECK: %[[CONST_SPLIT_DIM:.*]] = "tf.Const"()
// CHECK: %[[SPLIT_OUT:[a-z0-9]+]]:2 = "tf.Split"(%[[CONST_SPLIT_DIM]], %[[RI_0]])
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]]:3 = "tf_device.parallel_execute"
// CHECK-NEXT: %[[LAUNCH_0_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
//
// CHECK-NEXT: %[[EXECUTE_0_OUTPUT:[0-9]*]]:2 = "tf.TPUExecute"(%[[SPLIT_OUT]]#0, %[[COMPILE]]#1)
// CHECK-NEXT: tf_device.return %[[EXECUTE_0_OUTPUT]]
// CHECK-NEXT: device = "TPU_REPLICATED_CORE_0"
// CHECK: %[[LAUNCH_1_OUTPUT:[0-9]*]] = "tf_device.launch"
// CHECK-NEXT: %[[EXECUTE_1_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[SPLIT_OUT]]#1, %[[RI_1]], %[[COMPILE]]#2)
// CHECK-NEXT: tf_device.return %[[EXECUTE_1_OUTPUT]]
// CHECK: device = "TPU_REPLICATED_CORE_1"
%1, %2 = "tf_device.cluster_func"(%ri_1, %ri_2) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\03\1A\02\01\02\22\02\00\01", "\08\01\1A\01\01\22\01\01"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", ""], use_spmd_for_xla_partitioning = false} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>)
tf_device.return %1, %2 : tensor<*xi32>, tensor<*xi1>
}
func.return %0#0, %1#0 : tensor<*xi32>, tensor<*xi1>
}
func.func @tpu0_func(%arg0: tensor<128x10xf32>, %arg1: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) {
%1, %2 = "tf.A"(%arg0) : (tensor<128x10xf32>) -> (tensor<*xi32>, tensor<*xi1>)
%4 = "tf.B"(%1, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>)
%3 = "tf.XlaSharding"(%2) { _XlaSharding = "", sharding = "" } : (tensor<*xi1>) -> tensor<*xi1>
func.return %4, %3 : tensor<*xi32>, tensor<*xi1>
}
}
// -----
// Tests that outputs are correctly merged and fed from the TPU computation for
// tiled output sharding.
// The following OpSharding is used for the TPU computation outputs in the test
// below:
// Proto debug string:
// output 0
// type: OTHER
// tile_assignment_dimensions: 1
// tile_assignment_dimensions: 2
// tile_assignment_devices: 0
// tile_assignment_devices: 1
// Serialized string:
// "\08\03\1A\02\01\02\22\02\00\01"
//
// output 1
// type: MAXIMAL
// tile_assignment_dimensions: 1
// tile_assignment_devices: 0
// Serialized string:
// "\08\01\1A\01\01\22\01\01"
// -----
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0", "/job:localhost/replica:0/task:1/device:CPU:0", "/job:localhost/replica:0/task:1/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU_SYSTEM:0"]} {
// CHECK-LABEL: func @parallel_execute_with_tiled_output
func.func @parallel_execute_with_tiled_output(%arg0: tensor<128x10xf32>, %arg1: tensor<128x10xf32>, %arg2: tensor<*xi32>, %arg3: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) {
// CHECK: tf_device.replicate
// CHECK-SAME: [%[[ARG_0]], %[[ARG_1]]] as %[[RI_0:[a-z0-9]*]]: tensor<128x10xf32>
// CHECK-SAME: [%[[ARG_2]], %[[ARG_3]]] as %[[RI_1:[a-z0-9]*]]: tensor<*xi32>
// CHECK-SAME: devices =
// CHECK-SAME: TPU_REPLICATED_CORE_0 = ["/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1"]
// CHECK-SAME: TPU_REPLICATED_CORE_1 = ["/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU:0"]
%0:2, %1:2 = tf_device.replicate([%arg0, %arg1] as %ri_1: tensor<128x10xf32>, [%arg2, %arg3] as %ri_2: tensor<*xi32>) {n = 2 : i32} {
// CHECK: %[[COMPILE:[a-z0-9]+]]:3 = "tf_device.launch"
// CHECK-NEXT: "tf._TPUCompileMlir"
// CHECK: device = "/job:localhost/replica:0/task:0/device:CPU:0"
// CHECK: "tf_device.launch"
// CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE]]#0)
// CHECK: device = "/job:localhost/replica:0/task:0/device:CPU:0"
//
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]]:3 = "tf_device.parallel_execute"
// CHECK-NEXT: %[[LAUNCH_0_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
// CHECK-NEXT: %[[EXECUTE_0_OUTPUT:[0-9]*]]:2 = "tf.TPUExecute"
// CHECK-NEXT: tf_device.return %[[EXECUTE_0_OUTPUT]]
// CHECK-NEXT: device = "TPU_REPLICATED_CORE_0"
// CHECK: %[[LAUNCH_1_OUTPUT:[0-9]*]] = "tf_device.launch"
// CHECK-NEXT: %[[EXECUTE_1_OUTPUT:[0-9]*]] = "tf.TPUExecute"
// CHECK-NEXT: tf_device.return %[[EXECUTE_1_OUTPUT]]
// CHECK: device = "TPU_REPLICATED_CORE_1"
//
// CHECK: %[[CONST_CONCAT_DIM:.*]] = "tf.Const"()
// CHECK: %[[CONCAT_OUTPUT:[0-9]*]] = "tf.Concat"(%[[CONST_CONCAT_DIM]], %[[PARALLEL_EXECUTE_OUTPUT]]#0, %[[PARALLEL_EXECUTE_OUTPUT]]#2
%1, %2 = "tf_device.cluster_func"(%ri_1, %ri_2) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "", topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\03\1A\02\01\02\22\02\00\01", "\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\03\1A\02\01\02\22\02\00\01", "\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>)
tf_device.return %1, %2 : tensor<*xi32>, tensor<*xi1>
}
func.return %0#0, %1#0 : tensor<*xi32>, tensor<*xi1>
}
func.func @tpu0_func(%arg0: tensor<128x10xf32>, %arg1: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) {
%1, %2 = "tf.A"(%arg0) : (tensor<128x10xf32>) -> (tensor<*xi32>, tensor<*xi1>)
%4 = "tf.B"(%1, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>)
%3 = "tf.XlaSharding"(%2) { _XlaSharding = "", sharding = "" } : (tensor<*xi1>) -> tensor<*xi1>
func.return %4, %3 : tensor<*xi32>, tensor<*xi1>
}
}
// -----
// Tests that outputs are correctly merged and fed from the TPU computation for
// tiled output sharding, with MAXIMAL sharding for one output and OTHER
// sharding for the other output.
// The following OpSharding is used for the TPU computation outputs in the test
// below:
// Proto debug string:
// output 0
// type: MAXIMAL
// tile_assignment_dimensions: 1
// tile_assignment_devices: 0
// Serialized string:
// "\08\01\1A\01\01\22\01\01"
//
// output 1
// type: OTHER
// tile_assignment_dimensions: 1
// tile_assignment_dimensions: 2
// tile_assignment_devices: 0
// tile_assignment_devices: 1
// Serialized string:
// "\08\03\1A\02\01\02\22\02\00\01"
// -----
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0", "/job:localhost/replica:0/task:1/device:CPU:0", "/job:localhost/replica:0/task:1/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU_SYSTEM:0"]} {
// CHECK-LABEL: func @parallel_execute_with_tiled_output
func.func @parallel_execute_with_tiled_output(%arg0: tensor<128x10xf32>, %arg1: tensor<128x10xf32>, %arg2: tensor<*xi32>, %arg3: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) {
// CHECK: tf_device.replicate
// CHECK-SAME: [%[[ARG_0]], %[[ARG_1]]] as %[[RI_0:[a-z0-9]*]]: tensor<128x10xf32>
// CHECK-SAME: [%[[ARG_2]], %[[ARG_3]]] as %[[RI_1:[a-z0-9]*]]: tensor<*xi32>
// CHECK-SAME: devices =
// CHECK-SAME: TPU_REPLICATED_CORE_0 = ["/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1"]
// CHECK-SAME: TPU_REPLICATED_CORE_1 = ["/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU:0"]
%0:2, %1:2 = tf_device.replicate([%arg0, %arg1] as %ri_1: tensor<128x10xf32>, [%arg2, %arg3] as %ri_2: tensor<*xi32>) {n = 2 : i32} {
// CHECK: %[[COMPILE:[a-z0-9]+]]:3 = "tf_device.launch"
// CHECK-NEXT: "tf._TPUCompileMlir"
// CHECK: device = "/job:localhost/replica:0/task:0/device:CPU:0"
// CHECK: "tf_device.launch"
// CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE]]#0)
// CHECK: device = "/job:localhost/replica:0/task:0/device:CPU:0"
//
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]]:3 = "tf_device.parallel_execute"
// CHECK-NEXT: %[[LAUNCH_0_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
// CHECK-NEXT: %[[EXECUTE_0_OUTPUT:[0-9]*]]:2 = "tf.TPUExecute"
// CHECK-NEXT: tf_device.return %[[EXECUTE_0_OUTPUT]]
// CHECK-NEXT: device = "TPU_REPLICATED_CORE_0"
// CHECK: %[[LAUNCH_1_OUTPUT:[0-9]*]] = "tf_device.launch"
// CHECK-NEXT: %[[EXECUTE_1_OUTPUT:[0-9]*]] = "tf.TPUExecute"
// CHECK-NEXT: tf_device.return %[[EXECUTE_1_OUTPUT]]
// CHECK: device = "TPU_REPLICATED_CORE_1"
//
// CHECK: %[[CONST_CONCAT_DIM:.*]] = "tf.Const"()
// CHECK: %[[CONCAT_OUTPUT:[0-9]*]] = "tf.Concat"(%[[CONST_CONCAT_DIM]], %[[PARALLEL_EXECUTE_OUTPUT]]#1, %[[PARALLEL_EXECUTE_OUTPUT]]#2
%1, %2 = "tf_device.cluster_func"(%ri_1, %ri_2) {
_xla_compile_device_type = "TPU", _replication_info = "cluster0",
func = @tpu0_func, num_cores_per_replica = 2,
step_marker_location = "",
topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01",
device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0],
input_sharding_configuration = ["\08\03\1A\02\01\02\22\02\00\01", "\08\01\1A\01\01\22\01\00"],
output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", "\08\03\1A\02\01\02\22\02\00\01"],
use_spmd_for_xla_partitioning = false} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>)
tf_device.return %1, %2 : tensor<*xi32>, tensor<*xi1>
}
func.return %0#0, %1#0 : tensor<*xi32>, tensor<*xi1>
}
func.func @tpu0_func(%arg0: tensor<128x10xf32>, %arg1: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) {
%1, %2 = "tf.A"(%arg0) : (tensor<128x10xf32>) -> (tensor<*xi32>, tensor<*xi1>)
%4 = "tf.B"(%1, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>)
%3 = "tf.XlaSharding"(%2) { _XlaSharding = "", sharding = "" } : (tensor<*xi1>) -> tensor<*xi1>
func.return %4, %3 : tensor<*xi32>, tensor<*xi1>
}
}
// -----
// The following OpSharding is used for the TPU computation inputs in the test
// below:
// Proto debug string:
// input 0
// type: OTHER
// tile_assignment_dimensions: 1
// tile_assignment_dimensions: 4
// tile_assignment_devices: 0
// tile_assignment_devices: 1
// tile_assignment_devices: 2
// tile_assignment_devices: 3
// Serialized string:
// "\08\03\12\12\10\0b\1a\02\01\04\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\01\04\22\04\00\01\02\03"
//
// input 1
// type: MAXIMAL
// tile_assignment_dimensions: 1
// tile_assignment_devices: 1
// Serialized string:
// "\08\01\1A\01\01\22\01\01"
//
// -----
// Tests that tile sharding of inputs with a number of splits that does not
// evenly divide the input results in an error.
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0", "/job:localhost/replica:0/task:1/device:CPU:0", "/job:localhost/replica:0/task:1/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU_SYSTEM:0"]} {
func.func @uneven_input_sharding_disallowed(%arg0: tensor<128x10xf32>, %arg1: tensor<128x10xf32>, %arg2: tensor<*xi32>, %arg3: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) {
%0:2, %1:2 = tf_device.replicate([%arg0, %arg1] as %ri_1: tensor<128x10xf32>, [%arg2, %arg3] as %ri_2: tensor<*xi32>) {n = 2 : i32} {
// expected-error@+1 {{incorrect input sharding configuration received. 1-th dimension of the input must be evenly divisible by 4}}
%1, %2 = "tf_device.cluster_func"(%ri_1, %ri_2) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\03\12\12\10\0b\1a\02\01\04\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\01\04\22\04\00\01\02\03", "\08\01\1A\01\01\22\01\01"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", ""], use_spmd_for_xla_partitioning = false} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>)
tf_device.return %1, %2 : tensor<*xi32>, tensor<*xi1>
}
func.return %0#0, %1#0 : tensor<*xi32>, tensor<*xi1>
}
func.func @tpu0_func(%arg0: tensor<128x10xf32>, %arg1: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) {
%1, %2 = "tf.A"(%arg0) : (tensor<128x10xf32>) -> (tensor<*xi32>, tensor<*xi1>)
%4 = "tf.B"(%1, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>)
%3 = "tf.XlaSharding"(%2) { _XlaSharding = "", sharding = "" } : (tensor<*xi1>) -> tensor<*xi1>
func.return %4, %3 : tensor<*xi32>, tensor<*xi1>
}
}
// The following OpSharding is used for the TPU computation outputs in the test
// below:
// Proto debug string:
// output 0
// type: OTHER
// tile_assignment_dimensions: 1
// tile_assignment_dimensions: 4
// tile_assignment_devices: 0
// tile_assignment_devices: 1
// tile_assignment_devices: 2
// tile_assignment_devices: 3
// Serialized string:
// "\08\03\12\12\10\0b\1a\02\01\04\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\01\04\22\04\00\01\02\03"
//
// output 1
// type: MAXIMAL
// tile_assignment_dimensions: 1
// tile_assignment_devices: 1
// Serialized string:
// "\08\01\1A\01\01\22\01\01"
//
// -----
// Tests that tile sharding of outputs with a number of splits that exceeds the
// number of logical devices is not allowed.
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0", "/job:localhost/replica:0/task:1/device:CPU:0", "/job:localhost/replica:0/task:1/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU_SYSTEM:0"]} {
func.func @uneven_output_sharding_disallowed(%arg0: tensor<128x10xf32>, %arg1: tensor<128x10xf32>, %arg2: tensor<*xi32>, %arg3: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) {
%0:2, %1:2 = tf_device.replicate([%arg0, %arg1] as %ri_1: tensor<128x10xf32>, [%arg2, %arg3] as %ri_2: tensor<*xi32>) {n = 2 : i32} {
// expected-error@+1 {{incorrect sharding format for outputs. Number of tiled outputs(4) must match the number of logical devices(2)}}
%1, %2 = "tf_device.cluster_func"(%ri_1, %ri_2) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["", ""], output_sharding_configuration = ["\08\03\12\12\10\0b\1a\02\01\04\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\01\04\22\04\00\01\02\03", ""], use_spmd_for_xla_partitioning = false} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>)
tf_device.return %1, %2 : tensor<*xi32>, tensor<*xi1>
}
func.return %0#0, %1#0 : tensor<*xi32>, tensor<*xi1>
}
func.func @tpu0_func(%arg0: tensor<128x10xf32>, %arg1: tensor<*xi32>) -> (tensor<*xi1>, tensor<*xi32>) {
%1, %2 = "tf.A"(%arg0) : (tensor<128x10xf32>) -> (tensor<*xi32>, tensor<*xi1>)
%4 = "tf.B"(%1, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>)
%3 = "tf.XlaSharding"(%2) { _XlaSharding = "", sharding = "" } : (tensor<*xi1>) -> tensor<*xi1>
func.return %3, %4 : tensor<*xi1>, tensor<*xi32>
}
}
// -----
// The following topology is used in subsequent test cases:
// Proto debug string:
// mesh_shape: 2
// mesh_shape: 1
// mesh_shape: 2
// num_tasks: 2
// num_tpu_devices_per_task: 2
// device_coordinates: 0
// device_coordinates: 0
// device_coordinates: 0
// device_coordinates: 0
// device_coordinates: 0
// device_coordinates: 0
// device_coordinates: 0
// device_coordinates: 1
// device_coordinates: 0
// device_coordinates: 1
// device_coordinates: 0
// device_coordinates: 0
// device_coordinates: 0
// device_coordinates: 1
// device_coordinates: 0
// device_coordinates: 1
// The following OpSharding is used for the TPU computation inputs in the test
// below:
// Proto debug string:
// input 0
// type: OTHER
// tile_shape {
// element_type: F32
// dimensions: 2
// dimensions: 2
// layout {
// minor_to_major: 1
// minor_to_major: 0
// format: DENSE
// }
// is_dynamic_dimension: false
// is_dynamic_dimension: false
// }
// tile_assignment_dimensions: 2
// tile_assignment_dimensions: 2
// tile_assignment_devices: 0
// tile_assignment_devices: 1
// tile_assignment_devices: 2
// tile_assignment_devices: 3
// Serialized string:
// "\08\03\12\12\10\0b\1a\02\02\02\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\02\02\22\04\00\01\02\03"
//
// input 1
// type: MAXIMAL
// tile_assignment_dimensions: 1
// tile_assignment_devices: 1
// Serialized string:
// "\08\01\1A\01\01\22\01\01"
// Tests inputs to the TPU computation that are tiled in multiple dimensions.
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0"]} {
// CHECK-LABEL: func @parallel_execute_with_multi_dimension_tiled_input
// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<128x10xf32>, %[[ARG_1:[a-z0-9]*]]: tensor<128x10xf32>, %[[ARG_2:[a-z0-9]*]]: tensor<*xi32>, %[[ARG_3:[a-z0-9]*]]: tensor<*xi32>)
func.func @parallel_execute_with_multi_dimension_tiled_input(%arg0: tensor<128x10xf32>, %arg1: tensor<128x10xf32>, %arg2: tensor<*xi32>, %arg3: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) {
// CHECK: tf_device.replicate
// CHECK-SAME: [%[[ARG_0]], %[[ARG_1]]] as %[[RI_0:[a-z0-9]*]]: tensor<128x10xf32>
// CHECK-SAME: [%[[ARG_2]], %[[ARG_3]]] as %[[RI_1:[a-z0-9]*]]: tensor<*xi32>
%0:2, %1:2 = tf_device.replicate([%arg0, %arg1] as %ri_1: tensor<128x10xf32>, [%arg2, %arg3] as %ri_2: tensor<*xi32>) {n = 2 : i32} {
// CHECK: %[[COMPILE:[a-z0-9]+]]:5 = "tf_device.launch"
// CHECK-NEXT: "tf._TPUCompileMlir"
// CHECK: device = "/job:localhost/replica:0/task:0/device:CPU:0"
// CHECK: "tf_device.launch"
// CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE]]#0)
// CHECK: device = "/job:localhost/replica:0/task:0/device:CPU:0"
// CHECK: %[[CONST_SPLIT_0_DIM:.*]] = "tf.Const"()
// CHECK: %[[SPLIT_0_OUT:[a-z0-9]+]]:2 = "tf.Split"(%[[CONST_SPLIT_0_DIM]], %[[RI_0]])
// CHECK: %[[CONST_SPLIT_1_DIM:.*]] = "tf.Const"()
// CHECK: %[[SPLIT_1_OUT:[a-z0-9]+]]:2 = "tf.Split"(%[[CONST_SPLIT_1_DIM]], %[[SPLIT_0_OUT]]#0)
// CHECK: %[[CONST_SPLIT_2_DIM:.*]] = "tf.Const"()
// CHECK: %[[SPLIT_2_OUT:[a-z0-9]+]]:2 = "tf.Split"(%[[CONST_SPLIT_2_DIM]], %[[SPLIT_0_OUT]]#1)
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]]:5 = "tf_device.parallel_execute"
// CHECK-NEXT: %[[LAUNCH_0_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
// CHECK-NEXT: %[[EXECUTE_0_OUTPUT:[0-9]*]]:2 = "tf.TPUExecute"(%[[SPLIT_1_OUT]]#0, %[[COMPILE]]#1)
// CHECK: tf_device.return %[[EXECUTE_0_OUTPUT]]
// CHECK: %[[LAUNCH_1_OUTPUT:[0-9]*]] = "tf_device.launch"
// CHECK-NEXT: %[[EXECUTE_1_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[SPLIT_1_OUT]]#1, %[[RI_1]], %[[COMPILE]]#2)
// CHECK: tf_device.return %[[EXECUTE_1_OUTPUT]]
// CHECK: %[[LAUNCH_2_OUTPUT:[0-9]*]] = "tf_device.launch"
// CHECK-NEXT: %[[EXECUTE_2_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[SPLIT_2_OUT]]#0, %[[COMPILE]]#3)
// CHECK: tf_device.return %[[EXECUTE_2_OUTPUT]]
// CHECK: %[[LAUNCH_3_OUTPUT:[0-9]*]] = "tf_device.launch"
// CHECK-NEXT: %[[EXECUTE_3_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[SPLIT_2_OUT]]#1, %[[COMPILE]]#4)
// CHECK: tf_device.return %[[EXECUTE_3_OUTPUT]]
%1, %2 = "tf_device.cluster_func"(%ri_1, %ri_2) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @tpu0_func, num_cores_per_replica = 4, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "\0A\04\02\02\01\02\10\01\18\08\22 \00\00\00\00\00\00\00\01\01\00\00\00\01\00\00\01\00\01\00\00\00\01\00\01\01\01\00\00\01\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1], input_sharding_configuration = ["\08\03\12\12\10\0b\1a\02\02\02\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\02\02\22\04\00\01\02\03", "\08\01\1A\01\01\22\01\01"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", ""], use_spmd_for_xla_partitioning = false} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>)
tf_device.return %1, %2 : tensor<*xi32>, tensor<*xi1>
}
func.return %0#0, %1#0 : tensor<*xi32>, tensor<*xi1>
}
func.func @tpu0_func(%arg0: tensor<128x10xf32>, %arg1: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) {
%1, %2 = "tf.A"(%arg0) : (tensor<128x10xf32>) -> (tensor<*xi32>, tensor<*xi1>)
%4 = "tf.B"(%1, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>)
%3 = "tf.XlaSharding"(%2) { _XlaSharding = "", sharding = "" } : (tensor<*xi1>) -> tensor<*xi1>
func.return %4, %3 : tensor<*xi32>, tensor<*xi1>
}
}
// -----
// The following topology is used in subsequent test cases:
// Proto debug string:
// mesh_shape: 2
// mesh_shape: 2
// mesh_shape: 1
// mesh_shape: 2
// num_tasks: 1
// num_tpu_devices_per_task: 8
// device_coordinates: 0
// device_coordinates: 0
// device_coordinates: 0
// device_coordinates: 0
// device_coordinates: 0
// device_coordinates: 0
// device_coordinates: 0
// device_coordinates: 1
// device_coordinates: 1
// device_coordinates: 0
// device_coordinates: 0
// device_coordinates: 0
// device_coordinates: 1
// device_coordinates: 0
// device_coordinates: 0
// device_coordinates: 1
// device_coordinates: 0
// device_coordinates: 1
// device_coordinates: 0
// device_coordinates: 0
// device_coordinates: 0
// device_coordinates: 1
// device_coordinates: 0
// device_coordinates: 1
// device_coordinates: 1
// device_coordinates: 1
// device_coordinates: 0
// device_coordinates: 0
// device_coordinates: 1
// device_coordinates: 1
// device_coordinates: 0
// device_coordinates: 1
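// Decoding note (a sketch, assuming TopologyProto's standard field numbering
// and the protobuf wire format): the `topology` attribute on the
// `tf_device.cluster_func` ops below is this proto serialized and escaped
// byte by byte. "\0A\04" tags the packed mesh_shape field (4 bytes), "\10\01"
// encodes num_tasks = 1, "\18\08" encodes num_tpu_devices_per_task = 8, and
// "\22 " tags the device_coordinates field (the space is 0x20, i.e. a packed
// length of 32 bytes).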
// The following OpSharding is used for the TPU computation inputs in the test below:
// Proto debug string:
// input 0
// type: OTHER
// tile_shape {
// element_type: F32
// dimensions: 2
// dimensions: 2
// layout {
// minor_to_major: 1
// minor_to_major: 0
// format: DENSE
// }
// is_dynamic_dimension: false
// is_dynamic_dimension: false
// }
// tile_assignment_dimensions: 2
// tile_assignment_dimensions: 2
// tile_assignment_devices: 0
// tile_assignment_devices: 1
// tile_assignment_devices: 2
// tile_assignment_devices: 3
// Serialized string:
// "\08\03\12\12\10\0b\1a\02\02\02\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\02\02\22\04\00\01\02\03"
//
// input 1
// type: MAXIMAL
// tile_assignment_dimensions: 1
// tile_assignment_devices: 1
// Serialized string:
// "\08\01\1A\01\01\22\01\01"
// Tests inputs to the TPU computation that are tiled across multiple dimensions.
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU:2", "/job:localhost/replica:0/task:0/device:TPU:3", "/job:localhost/replica:0/task:0/device:TPU:4", "/job:localhost/replica:0/task:0/device:TPU:5", "/job:localhost/replica:0/task:0/device:TPU:6", "/job:localhost/replica:0/task:0/device:TPU:7", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0"]} {
// CHECK-LABEL: func @parallel_execute_with_multi_dimension_tiled_input
// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<128x10xf32>, %[[ARG_1:[a-z0-9]*]]: tensor<128x10xf32>, %[[ARG_2:[a-z0-9]*]]: tensor<*xi32>, %[[ARG_3:[a-z0-9]*]]: tensor<*xi32>)
func.func @parallel_execute_with_multi_dimension_tiled_input(%arg0: tensor<128x10xf32>, %arg1: tensor<128x10xf32>, %arg2: tensor<*xi32>, %arg3: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) {
// CHECK: tf_device.replicate
// CHECK-SAME: [%[[ARG_0]], %[[ARG_1]]] as %[[RI_0:[a-z0-9]*]]: tensor<128x10xf32>
// CHECK-SAME: [%[[ARG_2]], %[[ARG_3]]] as %[[RI_1:[a-z0-9]*]]: tensor<*xi32>
%0:2, %1:2 = tf_device.replicate([%arg0, %arg1] as %ri_1: tensor<128x10xf32>, [%arg2, %arg3] as %ri_2: tensor<*xi32>) {n = 2 : i32} {
// CHECK: %[[COMPILE:[a-z0-9]+]]:5 = "tf_device.launch"
// CHECK-NEXT: "tf._TPUCompileMlir"
// CHECK: device = "/job:localhost/replica:0/task:0/device:CPU:0"
// CHECK: "tf_device.launch"
// CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE]]#0)
// CHECK: device = "/job:localhost/replica:0/task:0/device:CPU:0"
// CHECK: %[[CONST_SPLIT_0_DIM:.*]] = "tf.Const"()
// CHECK: %[[SPLIT_0_OUT:[a-z0-9]+]]:2 = "tf.Split"(%[[CONST_SPLIT_0_DIM]], %[[RI_0]])
// CHECK: %[[CONST_SPLIT_1_DIM:.*]] = "tf.Const"()
// CHECK: %[[SPLIT_1_OUT:[a-z0-9]+]]:2 = "tf.Split"(%[[CONST_SPLIT_1_DIM]], %[[SPLIT_0_OUT]]#0)
// CHECK: %[[CONST_SPLIT_2_DIM:.*]] = "tf.Const"()
// CHECK: %[[SPLIT_2_OUT:[a-z0-9]+]]:2 = "tf.Split"(%[[CONST_SPLIT_2_DIM]], %[[SPLIT_0_OUT]]#1)
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]]:5 = "tf_device.parallel_execute"
// CHECK-NEXT: %[[LAUNCH_0_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
// CHECK-NEXT: %[[EXECUTE_0_OUTPUT:[0-9]*]]:2 = "tf.TPUExecute"(%[[SPLIT_1_OUT]]#0, %[[COMPILE]]#1)
// CHECK: tf_device.return %[[EXECUTE_0_OUTPUT]]
// CHECK: %[[LAUNCH_1_OUTPUT:[0-9]*]] = "tf_device.launch"
// CHECK-NEXT: %[[EXECUTE_1_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[SPLIT_1_OUT]]#1, %[[RI_1]], %[[COMPILE]]#2)
// CHECK: tf_device.return %[[EXECUTE_1_OUTPUT]]
// CHECK: %[[LAUNCH_2_OUTPUT:[0-9]*]] = "tf_device.launch"
// CHECK-NEXT: %[[EXECUTE_2_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[SPLIT_2_OUT]]#0, %[[COMPILE]]#3)
// CHECK: tf_device.return %[[EXECUTE_2_OUTPUT]]
// CHECK: %[[LAUNCH_3_OUTPUT:[0-9]*]] = "tf_device.launch"
// CHECK-NEXT: %[[EXECUTE_3_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[SPLIT_2_OUT]]#1, %[[COMPILE]]#4)
// CHECK: tf_device.return %[[EXECUTE_3_OUTPUT]]
%1, %2 = "tf_device.cluster_func"(%ri_1, %ri_2) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @tpu0_func, num_cores_per_replica = 4, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "\0A\04\02\02\01\02\10\01\18\08\22 \00\00\00\00\00\00\00\01\01\00\00\00\01\00\00\01\00\01\00\00\00\01\00\01\01\01\00\00\01\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1], input_sharding_configuration = ["\08\03\12\12\10\0b\1a\02\02\02\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\02\02\22\04\00\01\02\03", "\08\01\1A\01\01\22\01\01"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", ""], use_spmd_for_xla_partitioning = false} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>)
tf_device.return %1, %2 : tensor<*xi32>, tensor<*xi1>
}
func.return %0#0, %1#0 : tensor<*xi32>, tensor<*xi1>
}
func.func @tpu0_func(%arg0: tensor<128x10xf32>, %arg1: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) {
%1, %2 = "tf.A"(%arg0) : (tensor<128x10xf32>) -> (tensor<*xi32>, tensor<*xi1>)
%4 = "tf.B"(%1, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>)
%3 = "tf.XlaSharding"(%2) { _XlaSharding = "", sharding = "" } : (tensor<*xi1>) -> tensor<*xi1>
func.return %4, %3 : tensor<*xi32>, tensor<*xi1>
}
}
// -----
// Tests that tiled outputs with multi-dimensional sharding work properly.
// The following OpSharding is used for the TPU computation outputs in the test below:
// output 0
// Proto debug string:
// type: OTHER
// tile_shape {
// element_type: F32
// dimensions: 2
// dimensions: 2
// layout {
// minor_to_major: 1
// minor_to_major: 0
// format: DENSE
// }
// is_dynamic_dimension: false
// is_dynamic_dimension: false
// }
// tile_assignment_dimensions: 2
// tile_assignment_dimensions: 2
// tile_assignment_devices: 0
// tile_assignment_devices: 1
// tile_assignment_devices: 2
// tile_assignment_devices: 3
// Serialized string:
// "\08\03\12\12\10\0b\1a\02\02\02\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\02\02\22\04\00\01\02\03"
//
// output 1
// Proto debug string:
// type: MAXIMAL
// tile_assignment_dimensions: 1
// tile_assignment_devices: 0
// Serialized string:
// "\08\01\1A\01\01\22\01\01"
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU:2", "/job:localhost/replica:0/task:0/device:TPU:3", "/job:localhost/replica:0/task:0/device:TPU:4", "/job:localhost/replica:0/task:0/device:TPU:5", "/job:localhost/replica:0/task:0/device:TPU:6", "/job:localhost/replica:0/task:0/device:TPU:7", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0"]} {
// CHECK-LABEL: func @multiple_dimension_output_sharding
// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<128x10xf32>, %[[ARG_1:[a-z0-9]*]]: tensor<128x10xf32>, %[[ARG_2:[a-z0-9]*]]: tensor<*xi32>, %[[ARG_3:[a-z0-9]*]]: tensor<*xi32>)
func.func @multiple_dimension_output_sharding(%arg0: tensor<128x10xf32>, %arg1: tensor<128x10xf32>, %arg2: tensor<*xi32>, %arg3: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) {
// CHECK: tf_device.replicate
// CHECK-SAME: [%[[ARG_0]], %[[ARG_1]]] as %[[RI_0:[a-z0-9]*]]: tensor<128x10xf32>
// CHECK-SAME: [%[[ARG_2]], %[[ARG_3]]] as %[[RI_1:[a-z0-9]*]]: tensor<*xi32>
%0:2, %1:2 = tf_device.replicate([%arg0, %arg1] as %ri_1: tensor<128x10xf32>, [%arg2, %arg3] as %ri_2: tensor<*xi32>) {n = 2 : i32} {
// CHECK: %[[COMPILE:[a-z0-9]+]]:5 = "tf_device.launch"
// CHECK-NEXT: "tf._TPUCompileMlir"
// CHECK: device = "/job:localhost/replica:0/task:0/device:CPU:0"
// CHECK: "tf_device.launch"
// CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE]]#0)
// CHECK: device = "/job:localhost/replica:0/task:0/device:CPU:0"
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]]:5 = "tf_device.parallel_execute"
// CHECK-NEXT: %[[LAUNCH_0_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
// CHECK-NEXT: %[[EXECUTE_0_OUTPUT:[0-9]*]]:2 = "tf.TPUExecute"
// CHECK: tf_device.return %[[EXECUTE_0_OUTPUT]]
// CHECK: %[[LAUNCH_1_OUTPUT:[0-9]*]] = "tf_device.launch"
// CHECK-NEXT: %[[EXECUTE_1_OUTPUT:[0-9]*]] = "tf.TPUExecute"
// CHECK: tf_device.return %[[EXECUTE_1_OUTPUT]]
// CHECK: %[[LAUNCH_2_OUTPUT:[0-9]*]] = "tf_device.launch"
// CHECK-NEXT: %[[EXECUTE_2_OUTPUT:[0-9]*]] = "tf.TPUExecute"(
// CHECK: tf_device.return %[[EXECUTE_2_OUTPUT]]
// CHECK: %[[LAUNCH_3_OUTPUT:[0-9]*]] = "tf_device.launch"
// CHECK-NEXT: %[[EXECUTE_3_OUTPUT:[0-9]*]] = "tf.TPUExecute"(
// CHECK: tf_device.return %[[EXECUTE_3_OUTPUT]]
// CHECK: %[[CONST_CONCAT_DIM:.*]] = "tf.Const"()
// CHECK: %[[CONCAT_OUTPUT:[0-9]*]] = "tf.Concat"(%[[CONST_CONCAT_DIM]], %[[PARALLEL_EXECUTE_OUTPUT]]#0, %[[PARALLEL_EXECUTE_OUTPUT]]#2
// CHECK: %[[CONST_CONCAT2_DIM:.*]] = "tf.Const"()
// CHECK: %[[CONCAT2_OUTPUT:[0-9]*]] = "tf.Concat"(%[[CONST_CONCAT2_DIM]], %[[PARALLEL_EXECUTE_OUTPUT]]#3, %[[PARALLEL_EXECUTE_OUTPUT]]#4
// CHECK: %[[CONST_CONCAT3_DIM:.*]] = "tf.Const"()
// CHECK: %[[CONCAT3_OUTPUT:[0-9]*]] = "tf.Concat"(%[[CONST_CONCAT3_DIM]], %[[CONCAT_OUTPUT]], %[[CONCAT2_OUTPUT]]
%1, %2 = "tf_device.cluster_func"(%ri_1, %ri_2) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @tpu0_func, num_cores_per_replica = 4, step_marker_location = "", topology = "\0A\04\02\02\01\02\10\01\18\08\22 \00\00\00\00\00\00\00\01\01\00\00\00\01\00\00\01\00\01\00\00\00\01\00\01\01\01\00\00\01\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\03\12\12\10\0b\1a\02\02\02\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\02\02\22\04\00\01\02\03", "\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>)
tf_device.return %1, %2 : tensor<*xi32>, tensor<*xi1>
}
func.return %0#0, %1#0 : tensor<*xi32>, tensor<*xi1>
}
func.func @tpu0_func(%arg0: tensor<128x10xf32>, %arg1: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) {
%1, %2 = "tf.A"(%arg0) : (tensor<128x10xf32>) -> (tensor<*xi32>, tensor<*xi1>)
%4 = "tf.B"(%1, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>)
%3 = "tf.XlaSharding"(%2) { _XlaSharding = "", sharding = "" } : (tensor<*xi1>) -> tensor<*xi1>
func.return %4, %3 : tensor<*xi32>, tensor<*xi1>
}
}
// -----
// Tests that the input device assignment order is preserved for tiled input sharding.
// The following OpSharding is used for the TPU computation inputs in the test below:
// Proto debug string:
// input 0
// type: OTHER
// tile_shape {
// element_type: F32
// dimensions: 2
// dimensions: 2
// layout {
// minor_to_major: 1
// minor_to_major: 0
// format: DENSE
// }
// is_dynamic_dimension: false
// is_dynamic_dimension: false
// }
// tile_assignment_dimensions: 2
// tile_assignment_dimensions: 2
// tile_assignment_devices: 3
// tile_assignment_devices: 2
// tile_assignment_devices: 1
// tile_assignment_devices: 0
// Serialized string:
// "\08\03\12\12\10\0b\1a\02\02\02\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\02\02\22\04\03\02\01\00"
//
// input 1
// type: MAXIMAL
// tile_assignment_dimensions: 1
// tile_assignment_devices: 1
// Serialized string:
// "\08\01\1A\01\01\22\01\01"
//
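// Note: input 0 lists its tile_assignment_devices in reverse order
// (3, 2, 1, 0), so tile k of the split input is fed to logical core 3 - k.
// This is why the CHECK lines below expect the execute on core 0 to consume
// SPLIT_2_OUT#1 (the last tile) and the execute on core 3 to consume
// SPLIT_1_OUT#0 (the first tile).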
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU:2", "/job:localhost/replica:0/task:0/device:TPU:3", "/job:localhost/replica:0/task:0/device:TPU:4", "/job:localhost/replica:0/task:0/device:TPU:5", "/job:localhost/replica:0/task:0/device:TPU:6", "/job:localhost/replica:0/task:0/device:TPU:7", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0"]} {
// CHECK-LABEL: func @tiled_input_sharding_with_device_assignment_order
// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<128x10xf32>, %[[ARG_1:[a-z0-9]*]]: tensor<128x10xf32>, %[[ARG_2:[a-z0-9]*]]: tensor<*xi32>, %[[ARG_3:[a-z0-9]*]]: tensor<*xi32>)
func.func @tiled_input_sharding_with_device_assignment_order(%arg0: tensor<128x10xf32>, %arg1: tensor<128x10xf32>, %arg2: tensor<*xi32>, %arg3: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) {
// CHECK: tf_device.replicate
// CHECK-SAME: [%[[ARG_0]], %[[ARG_1]]] as %[[RI_0:[a-z0-9]*]]: tensor<128x10xf32>
// CHECK-SAME: [%[[ARG_2]], %[[ARG_3]]] as %[[RI_1:[a-z0-9]*]]: tensor<*xi32>
%0:2, %1:2 = tf_device.replicate([%arg0, %arg1] as %ri_1: tensor<128x10xf32>, [%arg2, %arg3] as %ri_2: tensor<*xi32>) {n = 2 : i32} {
// CHECK: %[[COMPILE:[a-z0-9]+]]:5 = "tf_device.launch"
// CHECK-NEXT: "tf._TPUCompileMlir"
// CHECK: device = "/job:localhost/replica:0/task:0/device:CPU:0"
// CHECK: "tf_device.launch"
// CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE]]#0)
// CHECK: device = "/job:localhost/replica:0/task:0/device:CPU:0"
// CHECK: %[[CONST_SPLIT_0_DIM:.*]] = "tf.Const"()
// CHECK: %[[SPLIT_0_OUT:[a-z0-9]+]]:2 = "tf.Split"(%[[CONST_SPLIT_0_DIM]], %[[RI_0]])
// CHECK: %[[CONST_SPLIT_1_DIM:.*]] = "tf.Const"()
// CHECK: %[[SPLIT_1_OUT:[a-z0-9]+]]:2 = "tf.Split"(%[[CONST_SPLIT_1_DIM]], %[[SPLIT_0_OUT]]#0)
// CHECK: %[[CONST_SPLIT_2_DIM:.*]] = "tf.Const"()
// CHECK: %[[SPLIT_2_OUT:[a-z0-9]+]]:2 = "tf.Split"(%[[CONST_SPLIT_2_DIM]], %[[SPLIT_0_OUT]]#1)
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]]:5 = "tf_device.parallel_execute"
// CHECK-NEXT: %[[LAUNCH_0_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
// CHECK-NEXT: %[[EXECUTE_0_OUTPUT:[0-9]*]]:2 = "tf.TPUExecute"(%[[SPLIT_2_OUT]]#1, %[[COMPILE]]#1)
// CHECK: tf_device.return %[[EXECUTE_0_OUTPUT]]
// CHECK: %[[LAUNCH_1_OUTPUT:[0-9]*]] = "tf_device.launch"
// CHECK-NEXT: %[[EXECUTE_1_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[SPLIT_2_OUT]]#0, %[[RI_1]], %[[COMPILE]]#2)
// CHECK: tf_device.return %[[EXECUTE_1_OUTPUT]]
// CHECK: %[[LAUNCH_2_OUTPUT:[0-9]*]] = "tf_device.launch"
// CHECK-NEXT: %[[EXECUTE_2_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[SPLIT_1_OUT]]#1, %[[COMPILE]]#3)
// CHECK: tf_device.return %[[EXECUTE_2_OUTPUT]]
// CHECK: %[[LAUNCH_3_OUTPUT:[0-9]*]] = "tf_device.launch"
// CHECK-NEXT: %[[EXECUTE_3_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[SPLIT_1_OUT]]#0, %[[COMPILE]]#4)
// CHECK: tf_device.return %[[EXECUTE_3_OUTPUT]]
%1, %2 = "tf_device.cluster_func"(%ri_1, %ri_2) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @tpu0_func, num_cores_per_replica = 4, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "\0A\04\02\02\01\02\10\01\18\08\22 \00\00\00\00\00\00\00\01\01\00\00\00\01\00\00\01\00\01\00\00\00\01\00\01\01\01\00\00\01\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1], input_sharding_configuration = ["\08\03\12\12\10\0b\1a\02\02\02\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\02\02\22\04\03\02\01\00", "\08\01\1A\01\01\22\01\01"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", ""], use_spmd_for_xla_partitioning = false} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>)
tf_device.return %1, %2 : tensor<*xi32>, tensor<*xi1>
}
func.return %0#0, %1#0 : tensor<*xi32>, tensor<*xi1>
}
func.func @tpu0_func(%arg0: tensor<128x10xf32>, %arg1: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) {
%1, %2 = "tf.A"(%arg0) : (tensor<128x10xf32>) -> (tensor<*xi32>, tensor<*xi1>)
%4 = "tf.B"(%1, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>)
%3 = "tf.XlaSharding"(%2) { _XlaSharding = "", sharding = "" } : (tensor<*xi1>) -> tensor<*xi1>
func.return %4, %3 : tensor<*xi32>, tensor<*xi1>
}
}
// -----
// Tests that the device assignment order is preserved for tile-sharded outputs.
// The following OpSharding is used for the TPU computation outputs in the test below:
// output 0
// Proto debug string:
// type: OTHER
// tile_shape {
// element_type: F32
// dimensions: 2
// dimensions: 2
// layout {
// minor_to_major: 1
// minor_to_major: 0
// format: DENSE
// }
// is_dynamic_dimension: false
// is_dynamic_dimension: false
// }
// tile_assignment_dimensions: 2
// tile_assignment_dimensions: 2
// tile_assignment_devices: 3
// tile_assignment_devices: 2
// tile_assignment_devices: 1
// tile_assignment_devices: 0
// Serialized string:
// "\08\03\12\12\10\0b\1a\02\02\02\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\02\02\22\04\03\02\01\00"
//
// output 1
// Proto debug string:
// type: MAXIMAL
// tile_assignment_dimensions: 1
// tile_assignment_devices: 0
// Serialized string:
// "\08\01\1A\01\01\22\01\01"
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU:2", "/job:localhost/replica:0/task:0/device:TPU:3", "/job:localhost/replica:0/task:0/device:TPU:4", "/job:localhost/replica:0/task:0/device:TPU:5", "/job:localhost/replica:0/task:0/device:TPU:6", "/job:localhost/replica:0/task:0/device:TPU:7", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0"]} {
// CHECK-LABEL: func @device_order_preserved_for_tiled_output
// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<128x10xf32>, %[[ARG_1:[a-z0-9]*]]: tensor<128x10xf32>, %[[ARG_2:[a-z0-9]*]]: tensor<*xi32>, %[[ARG_3:[a-z0-9]*]]: tensor<*xi32>)
func.func @device_order_preserved_for_tiled_output(%arg0: tensor<128x10xf32>, %arg1: tensor<128x10xf32>, %arg2: tensor<*xi32>, %arg3: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) {
// CHECK: tf_device.replicate
// CHECK-SAME: [%[[ARG_0]], %[[ARG_1]]] as %[[RI_0:[a-z0-9]*]]: tensor<128x10xf32>
// CHECK-SAME: [%[[ARG_2]], %[[ARG_3]]] as %[[RI_1:[a-z0-9]*]]: tensor<*xi32>
%0:2, %1:2 = tf_device.replicate([%arg0, %arg1] as %ri_1: tensor<128x10xf32>, [%arg2, %arg3] as %ri_2: tensor<*xi32>) {n = 2 : i32} {
// CHECK: %[[COMPILE:[a-z0-9]+]]:5 = "tf_device.launch"
// CHECK-NEXT: "tf._TPUCompileMlir"
// CHECK: device = "/job:localhost/replica:0/task:0/device:CPU:0"
// CHECK: "tf_device.launch"
// CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE]]#0)
// CHECK: device = "/job:localhost/replica:0/task:0/device:CPU:0"
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]]:5 = "tf_device.parallel_execute"
// CHECK-NEXT: %[[LAUNCH_0_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
// CHECK-NEXT: %[[EXECUTE_0_OUTPUT:[0-9]*]]:2 = "tf.TPUExecute"
// CHECK: tf_device.return %[[EXECUTE_0_OUTPUT]]
// CHECK: %[[LAUNCH_1_OUTPUT:[0-9]*]] = "tf_device.launch"
// CHECK-NEXT: %[[EXECUTE_1_OUTPUT:[0-9]*]] = "tf.TPUExecute"
// CHECK: tf_device.return %[[EXECUTE_1_OUTPUT]]
// CHECK: %[[LAUNCH_2_OUTPUT:[0-9]*]] = "tf_device.launch"
// CHECK-NEXT: %[[EXECUTE_2_OUTPUT:[0-9]*]] = "tf.TPUExecute"(
// CHECK: tf_device.return %[[EXECUTE_2_OUTPUT]]
// CHECK: %[[LAUNCH_3_OUTPUT:[0-9]*]] = "tf_device.launch"
// CHECK-NEXT: %[[EXECUTE_3_OUTPUT:[0-9]*]] = "tf.TPUExecute"(
// CHECK: tf_device.return %[[EXECUTE_3_OUTPUT]]
// CHECK: %[[CONST_CONCAT_DIM:.*]] = "tf.Const"()
// CHECK: %[[CONCAT_OUTPUT:[0-9]*]] = "tf.Concat"(%[[CONST_CONCAT_DIM]], %[[PARALLEL_EXECUTE_OUTPUT]]#4, %[[PARALLEL_EXECUTE_OUTPUT]]#3
// CHECK: %[[CONST_CONCAT2_DIM:.*]] = "tf.Const"()
// CHECK: %[[CONCAT2_OUTPUT:[0-9]*]] = "tf.Concat"(%[[CONST_CONCAT2_DIM]], %[[PARALLEL_EXECUTE_OUTPUT]]#2, %[[PARALLEL_EXECUTE_OUTPUT]]#0
// CHECK: %[[CONST_CONCAT3_DIM:.*]] = "tf.Const"()
// CHECK: %[[CONCAT3_OUTPUT:[0-9]*]] = "tf.Concat"(%[[CONST_CONCAT3_DIM]], %[[CONCAT_OUTPUT]], %[[CONCAT2_OUTPUT]]
%1, %2 = "tf_device.cluster_func"(%ri_1, %ri_2) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @tpu0_func, num_cores_per_replica = 4, step_marker_location = "", topology = "\0A\04\02\02\01\02\10\01\18\08\22 \00\00\00\00\00\00\00\01\01\00\00\00\01\00\00\01\00\01\00\00\00\01\00\01\01\01\00\00\01\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\03\12\12\10\0b\1a\02\02\02\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\02\02\22\04\03\02\01\00", "\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>)
tf_device.return %1, %2 : tensor<*xi32>, tensor<*xi1>
}
func.return %0#0, %1#0 : tensor<*xi32>, tensor<*xi1>
}
func.func @tpu0_func(%arg0: tensor<128x10xf32>, %arg1: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) {
%1, %2 = "tf.A"(%arg0) : (tensor<128x10xf32>) -> (tensor<*xi32>, tensor<*xi1>)
%4 = "tf.B"(%1, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>)
%3 = "tf.XlaSharding"(%2) { _XlaSharding = "", sharding = "" } : (tensor<*xi1>) -> tensor<*xi1>
func.return %4, %3 : tensor<*xi32>, tensor<*xi1>
}
}
// -----
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
func.func @missing_compilation_attribute() {
// expected-error@+1 {{'tf_device.cluster_func' op has '_replication_info' attribute but not '_xla_compile_device_type' attribute which is unsupported}}
"tf_device.cluster_func"() {_replication_info = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "", topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = [], use_spmd_for_xla_partitioning = false} : () -> ()
func.return
}
}
// -----
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
func.func @empty_replication_attribute() {
// expected-error@+1 {{'tf_device.cluster_func' op has an empty '_replication_info' attribute}}
"tf_device.cluster_func"() {_xla_compile_device_type = "TPU", _replication_info = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "", topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = [], use_spmd_for_xla_partitioning = false} : () -> ()
func.return
}
}
// -----
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
func.func @empty_compilation_attribute() {
// expected-error@+1 {{'tf_device.cluster_func' op has invalid '_xla_compile_device_type' value ''}}
"tf_device.cluster_func"() {_xla_compile_device_type = "", _replication_info = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "", topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = [], use_spmd_for_xla_partitioning = false} : () -> ()
func.return
}
}
// -----
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
func.func @invalid_compilation_attribute() {
// expected-error@+1 {{'tf_device.cluster_func' op has invalid '_xla_compile_device_type' value 'XPU'}}
"tf_device.cluster_func"() {_xla_compile_device_type = "XPU", _replication_info = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "", topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = [], use_spmd_for_xla_partitioning = false} : () -> ()
func.return
}
}
// -----
// Tests a `tf.TPUPartitionedInput` op whose output is not consumed by the `tf_device.cluster_func`.
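// Note: the `tf.TPUPartitionedInput` below feeds a `tf.ReadVariableOp`
// instead of the cluster itself, so the rewrite cannot map its per-core
// operands onto the TPU computation and emits the error expected below.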
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"]} {
func.func @cluster(%arg0: tensor<!tf_type.resource<tensor<i32>>>, %arg1: tensor<!tf_type.resource<tensor<i32>>>) {
// expected-error@+1 {{Output of TPUPartitionedInput must be in tpu computation.}}
%partitioned_input = "tf.TPUPartitionedInput"(%arg0, %arg1) {N = 2 : i64, partition_dim = -1 : i64} : (tensor<!tf_type.resource<tensor<i32>>>, tensor<!tf_type.resource<tensor<i32>>>) -> tensor<!tf_type.resource<tensor<i32>>>
%read = "tf.ReadVariableOp"(%partitioned_input) : (tensor<!tf_type.resource<tensor<i32>>>) -> tensor<i32>
%computation = "tf_device.cluster_func"(%read) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @computation, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "\0A\04\01\01\01\02\10\01\18\02\22\08\00\00\00\00\00\00\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1], input_sharding_configuration = [""], output_sharding_configuration = [""], use_spmd_for_xla_partitioning = true} : (tensor<i32>) -> tensor<i32>
%partitioned_output:2 = "tf.TPUPartitionedOutput"(%computation) {N = 2 : i64, partition_dim = -1 : i64} : (tensor<i32>) -> (tensor<i32>, tensor<i32>)
"tf.AssignVariableOp"(%arg0, %partitioned_output#0) : (tensor<!tf_type.resource<tensor<i32>>>, tensor<i32>) -> ()
"tf.AssignVariableOp"(%arg1, %partitioned_output#1) : (tensor<!tf_type.resource<tensor<i32>>>, tensor<i32>) -> ()
func.return
}
func.func @computation(%arg0: tensor<i32>) -> tensor<i32> {
func.return %arg0: tensor<i32>
}
}
// -----
// Tests a `tf.TPUPartitionedOutput` op whose input is not produced by the `tf_device.cluster_func`.
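// Note: the `tf.TPUPartitionedOutput` below consumes the result of a `tf.Add`
// computed outside the cluster rather than a result of the cluster itself, so
// the rewrite cannot associate it with per-core execute results and emits the
// error expected below.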
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"]} {
func.func @cluster(%arg0: tensor<!tf_type.resource<tensor<i32>>>, %arg1: tensor<!tf_type.resource<tensor<i32>>>) {
%read0 = "tf.ReadVariableOp"(%arg0) : (tensor<!tf_type.resource<tensor<i32>>>) -> tensor<i32>
%read1 = "tf.ReadVariableOp"(%arg1) : (tensor<!tf_type.resource<tensor<i32>>>) -> tensor<i32>
%partitioned_input = "tf.TPUPartitionedInput"(%read0, %read1) {N = 2 : i64, partition_dim = -1 : i64} : (tensor<i32>, tensor<i32>) -> tensor<i32>
%computation = "tf_device.cluster_func"(%partitioned_input) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @computation, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "\0A\04\01\01\01\02\10\01\18\02\22\08\00\00\00\00\00\00\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1], input_sharding_configuration = [""], output_sharding_configuration = [""], use_spmd_for_xla_partitioning = true} : (tensor<i32>) -> tensor<i32>
%add_result = "tf.Add"(%computation, %computation) : (tensor<i32>, tensor<i32>) -> tensor<i32>
// expected-error@+1 {{Input of TPUPartitionedOutput must be in tpu computation.}}
%partitioned_output:2 = "tf.TPUPartitionedOutput"(%add_result) {N = 2 : i64, partition_dim = -1 : i64} : (tensor<i32>) -> (tensor<i32>, tensor<i32>)
"tf.AssignVariableOp"(%arg0, %partitioned_output#0) : (tensor<!tf_type.resource<tensor<i32>>>, tensor<i32>) -> ()
"tf.AssignVariableOp"(%arg1, %partitioned_output#1) : (tensor<!tf_type.resource<tensor<i32>>>, tensor<i32>) -> ()
func.return
}
func.func @computation(%arg0: tensor<i32>) -> tensor<i32> {
func.return %arg0: tensor<i32>
}
}